In [41]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):

    def __init__(self, emb_dim, max_len = 1000):
        super().__init__()
        pe = torch.zeros(max_len, emb_dim)
        position = torch.arange(0,max_len).unsqueeze(1)
        # Scaling term for the positional encoding done with Sine and Cosine
        div_term = torch.exp(torch.arange(0, emb_dim, 2) * (-torch.log(torch.tensor(10000.0))/emb_dim))

        # Add sine and cosine to even and odd positions respectively
        pe[:, 0::2] = torch.sin(position * div_term)
        if emb_dim % 2 == 0:
            pe[:, 1::2] = torch.cos(position * div_term)
        else:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])  # Skip last value if mismatch


        self.pe = pe.unsqueeze(0)
        
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :].to(x.device)
        return x


# Test it
# PositionalEncoding(20, 100)


In [42]:
class CharTransformer(nn.Module):
    def __init__(self, vocab_size, emb_dim=256, n_heads=4, n_layers = 2, ff_dim = 512, dropout = 0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.pos_enc = PositionalEncoding(emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.fc_out = nn.Linear(emb_dim, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x) # (B,T,D)
        x = self.pos_enc(x) # Add positional encoding into the embedding
        x = x.permute(1,0,2) # Switch into shape (T, B, D) to train the transformer

        seq_len = x.size(0)
        mask = torch.triu(torch.ones(seq_len,seq_len, device=x.device), diagonal=1).bool() # causal mask
        out = self.transformer(x, mask=mask)
        out = out.permute(1,0,2) # Change the shape back to (B,T,D)
        logits = self.fc_out(out) #(B,T,V)
        return logits     

        

In [43]:
import functions_main as fm
from time import time

# Setup
seq_length = 30
batch_size = 200
hidden_size = 64
epochs = 5
learning_rate = 0.003
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layers_LSTM = 1
amount_chars = 10000

# Training block
dataloader, vocab, char2idx, idx2char, text_as_int = fm.get_dataloader(seq_length, batch_size, amount_chars=amount_chars)
model = CharTransformer(vocab_size=len(vocab)).to(device)
optimizer = torch.optim.Adam(model.parameters(), learning_rate)
criterion = nn.CrossEntropyLoss()

start_time = time()

model.train()
for epoch in range(epochs):
    start_epoch = time()
    for x_batch, y_batch in dataloader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()

        logits = model(x_batch)
        loss = criterion(logits.view(-1, logits.size(-1)), y_batch.view(-1))
        loss.backward()
        optimizer.step()
    end_epoch = time()
    epoch_time = end_epoch - start_epoch
    print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}, Total time past {epoch_time:.2f}')
print(f'Model trained in {time() - start_time:.2f}')



Text of len 10000 is being processed.

Epoch 1, Loss: 2.1362, Total time past 8.03
Epoch 2, Loss: 1.6053, Total time past 7.69
Epoch 3, Loss: 1.3401, Total time past 7.57
Epoch 4, Loss: 1.0718, Total time past 7.66
Epoch 5, Loss: 0.9138, Total time past 7.72
Model trained in 38.67


In [44]:
import torch.nn.functional as F
def generate_transformer_text(model, start_string, char2idx, idx2char, length=200, temperature=1.0, device='cpu'):
    model.eval()
    input_ids = torch.tensor([char2idx[c] for c in start_string]).unsqueeze(0)
    generated = input_ids.clone()

    with torch.no_grad():
        for _ in range(length):
            if generated.size(1)>512:
                generated = generated[:,-512:] # truncate large context
            
            logits = model(generated) # shape (1,T, vocab_size)
            next_logits = logits[:, -1, :] / temperature # adjust temperature
            probs = F.softmax(next_logits, dim = -1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat((generated, next_token),dim=1)

    # Convert final tensor to string
    output = ''.join([idx2char[token.item()] for token in generated[0]])

    return output


In [45]:
text = generate_transformer_text(model, start_string="To be, or not to be", char2idx=char2idx, idx2char=idx2char, length=300, temperature=0.8, device=device)
print("\nGenerated Text:\n")
print(text)


Generated Text:

To be, or not to be sloned them?
Ale ale Weale deseseselesesey Con I Con The witheses aking akin, an weakicopleak seakicheak win win don weak.
Fon weak wim weak, wicheakseang wicon deak weak, wineang y y deang wicheakeak, wak weangean on weak, an win win in win gean win win winseangrin d weak--------------------------
