In [None]:
## Add imports here
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import Counter
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 50

def initialise_projections(in_dim, out_dim):
    """
    Build one bias-free linear projection, usable as a Q, K or V map.

    Args:
        in_dim: size of the incoming feature dimension.
        out_dim: size of the projected feature dimension.

    Returns:
        An ``nn.Linear`` layer mapping ``in_dim`` -> ``out_dim`` with no bias.
    """
    projection = nn.Linear(in_features=in_dim, out_features=out_dim, bias=False)
    return projection

def pairwise_similarities(Q, K):
    """
    Raw (unscaled) dot-product score between every query and every key.

    The last two axes of ``K`` are swapped so the matrix product contracts
    over the feature axis; any leading axes follow matmul broadcasting.
    """
    keys_transposed = K.transpose(-1, -2)
    return Q @ keys_transposed

def attention_scaled(att_scores, d_k):
    """
    Rescale raw attention scores by ``d_k ** -0.5``.

    Args:
        att_scores: tensor of raw scores.
        d_k: key dimensionality used for the scaling factor.

    Returns:
        ``att_scores`` multiplied by the inverse square root of ``d_k``.
    """
    return att_scores * pow(float(d_k), -0.5)

def attention_softmax(scaled_att_scores):
    """
    Turn scaled attention scores into probabilities.

    Softmax is taken over the last axis, so every row along that axis
    sums to one.
    """
    return scaled_att_scores.softmax(dim=-1)

def compute_outputs(att_probs, V):
    """
    Mix the value vectors using the attention probabilities.

    Returns the matrix product ``att_probs @ V``, i.e. each output row is
    the probability-weighted sum of the rows of ``V``.
    """
    weighted_values = att_probs @ V
    return weighted_values

def make_causal_mask(size, device=None):
    """
    Create a mask matrix that masks future context for the attention.

    Args:
        size: sequence length; the mask is ``size x size``.
        device: device to allocate the mask on. Defaults to the module-level
            ``DEVICE`` when omitted.

    Returns:
        A float lower-triangular matrix: 1 where position ``i`` may attend to
        position ``j`` (``j <= i``), 0 elsewhere.
    """
    target_device = DEVICE if device is None else device
    # Allocate straight on the target device instead of building the matrix
    # on the CPU and copying it over on every call.
    return torch.ones(size, size, device=target_device).tril()

def apply_causal_mask(att_scores, mask):
    """
    Block masked-out positions in the attention scores.

    Every position where ``mask`` equals 0 is replaced by ``-inf`` (so it
    receives zero weight after softmax). A new tensor is returned; the
    input scores are not modified.
    """
    blocked = mask.eq(0)
    return att_scores.masked_fill(blocked, float('-inf'))

def split_heads(x, num_heads):
    """
    Reshape ``(batch, seq_len, embed_dim)`` into per-head slices.

    The embedding axis is cut into ``num_heads`` chunks of
    ``embed_dim // num_heads`` features, and the head axis is moved in front
    of the sequence axis.

    Returns:
        Tensor of shape ``(batch, num_heads, seq_len, head_dim)``.
    """
    n_batch, n_tokens, width = x.shape
    per_head = width // num_heads
    return x.view(n_batch, n_tokens, num_heads, per_head).permute(0, 2, 1, 3)

def merge_heads(x):
    """
    Undo ``split_heads``: fold the head axis back into the feature axis.

    Takes ``(batch, num_heads, seq_len, head_dim)`` and returns
    ``(batch, seq_len, num_heads * head_dim)``, with each token's head
    outputs concatenated in head order.
    """
    n_batch, n_heads, n_tokens, per_head = x.shape
    tokens_first = x.permute(0, 2, 1, 3)
    return tokens_first.reshape(n_batch, n_tokens, n_heads * per_head)

def self_attention(x, projection_q, projection_k, projection_v, num_heads):
    """
    Causal multi-head scaled dot-product self-attention.

    Args:
        x: input tensor; its projections are unpacked by ``split_heads`` as
            ``(batch, seq_len, embed_dim)``.
        projection_q, projection_k, projection_v: callables (e.g. linear
            layers) producing the queries, keys and values from ``x``.
        num_heads: number of attention heads the projections are split into.

    Returns:
        The per-head attention outputs merged back into a tensor of shape
        ``(batch, seq_len, num_heads * head_dim)``.
    """
    Q, K, V = split_heads_qkv(
        projection_q(x), projection_k(x), projection_v(x), num_heads
    )

    att_scores = pairwise_similarities(Q, K)
    att_scores = attention_scaled(att_scores, K.size(-1))

    # The mask helper allocates on the global DEVICE; move it next to the
    # scores so attention also works when the input lives on another device
    # (e.g. a CPU model on a CUDA-capable machine). No-op when they match.
    mask = make_causal_mask(x.size(1)).to(att_scores.device)
    # Add leading batch and head axes so the mask broadcasts over both.
    mask = mask.unsqueeze(0).unsqueeze(1)
    att_scores = apply_causal_mask(att_scores, mask)

    att_probs = attention_softmax(att_scores)
    return merge_heads(compute_outputs(att_probs, V))

def split_heads_qkv(Q, K, V, num_heads):
    """
    Apply ``split_heads`` to the query, key and value tensors.

    Returns:
        The tuple ``(Q, K, V)``, each split across ``num_heads`` heads.
    """
    per_head_q, per_head_k, per_head_v = (
        split_heads(tensor, num_heads) for tensor in (Q, K, V)
    )
    return per_head_q, per_head_k, per_head_v

def load_and_preprocess_data(
    train_path="/content/shakespear_train.txt",
    dev_path="/content/shakespear_dev.txt",
):
    """
    Read the train/dev text files and convert them to token-ID sequences.

    The vocabulary is built from the training file only: the three special
    tokens ``<PAD>`` (0), ``<START>`` (1), ``<STOP>`` (2) followed by the
    sorted set of whitespace-separated training words. Words not in the
    vocabulary are mapped to the ``<PAD>`` id.

    Args:
        train_path: path of the training text file (one sentence per line).
        dev_path: path of the validation text file.

    Returns:
        ``(data_train, data_val, tokenizer, tokenizer_inv)`` where the data
        entries are lists of ``[<START>, word ids..., <STOP>]`` sequences,
        ``tokenizer`` maps token -> id and ``tokenizer_inv`` maps id -> token.
    """
    special_tokens = ["<PAD>", "<START>", "<STOP>"]

    def read_lines(path):
        # Lines with no tokens (empty or whitespace-only) are dropped here;
        # otherwise they become empty sequences with no prediction targets.
        with open(path, "r", encoding="utf-8") as f:
            return [ln for ln in f.read().splitlines() if ln.strip()]

    lines_train = read_lines(train_path)
    lines_dev = read_lines(dev_path)

    # Count frequency of each token in the training set
    token_counts = Counter(word for ln in lines_train for word in ln.split())

    # Exclude literal occurrences of the special tokens from the word list so
    # the vocabulary has no duplicates and the special ids stay 0, 1, 2.
    words = sorted(w for w in token_counts if w not in special_tokens)
    vocab = special_tokens + words
    tokenizer = {token: idx for idx, token in enumerate(vocab)}
    tokenizer_inv = {idx: token for token, idx in tokenizer.items()}

    pad_id = tokenizer["<PAD>"]
    start_id = tokenizer["<START>"]
    stop_id = tokenizer["<STOP>"]

    def tokenizer_line(line):
        # Unknown words fall back to the <PAD> id.
        word_ids = [tokenizer.get(w, pad_id) for w in line.split()]
        return [start_id] + word_ids + [stop_id]

    data_train = [tokenizer_line(ln) for ln in lines_train]
    data_val = [tokenizer_line(ln) for ln in lines_dev]

    return data_train, data_val, tokenizer, tokenizer_inv


def pad_to_length(tokens, max_len, tokenizer):
    """
    Force a token list to exactly ``max_len`` entries.

    Shorter lists are right-padded with the ``<PAD>`` id, longer lists are
    cut to their first ``max_len`` entries, and lists that already have the
    right length are returned unchanged.
    """
    shortfall = max_len - len(tokens)
    if shortfall > 0:
        return tokens + [tokenizer["<PAD>"]] * shortfall
    if shortfall < 0:
        return tokens[:max_len]
    return tokens

def tokenize(sentence, pad_to_len=None, tokenizer=None, include_stop=True):
    """
    Convert a sentence to token IDs, optionally padded to a fixed length.

    Args:
        sentence: either a list of token IDs (used as-is) or a string, which
            is split on whitespace and looked up in ``tokenizer``.
        pad_to_len: when given, the result is padded/truncated to exactly
            this length via ``pad_to_length``. When ``None`` the tokens are
            returned at their natural length.
        tokenizer: token -> id mapping; unknown words map to the ``<PAD>``
            id. Required for string input and for padding.
        include_stop: accepted for interface compatibility; not used here.

    Returns:
        A list of token IDs.

    Raises:
        ValueError: if ``sentence`` is a string and no tokenizer is given.
    """
    if isinstance(sentence, list):
        tokens = sentence  # already token IDs
    else:
        if tokenizer is None:
            raise ValueError("tokenizer is required to tokenize a raw string")
        fallback_id = tokenizer["<PAD>"]
        tokens = [tokenizer.get(w, fallback_id) for w in sentence.split()]

    # Compare against None explicitly: the former unconditional length check
    # compared an int with None and raised TypeError when no length was given.
    if pad_to_len is not None:
        tokens = pad_to_length(tokens, pad_to_len, tokenizer)
    return tokens

def decode(tokens, tokenizer_inv, end_at_stop=True, omit_pad=True):
    """
    Map token IDs back to a space-joined string.

    IDs missing from ``tokenizer_inv`` decode to an empty string. With
    ``omit_pad`` the ``<PAD>`` token is skipped; with ``end_at_stop``
    decoding ends right after the first ``<STOP>`` (which is kept).
    """
    words = []
    for token_id in tokens:
        word = tokenizer_inv.get(token_id, "")
        if not (omit_pad and word == "<PAD>"):
            words.append(word)
            if end_at_stop and word == "<STOP>":
                break
    return " ".join(words)

@torch.no_grad()
def evaluate_losses(data, model, tokenizer, bs=32, progress=True, pad_to_len=MAX_LEN):
    """
    Compute the mean next-token negative log-likelihood of every sequence.

    Args:
        data: sequence of examples accepted by ``tokenize`` (token-ID lists
            or strings).
        model: callable returning ``(logits, _)`` for a batch of token IDs;
            it is switched to eval mode and left there.
        tokenizer: token -> id mapping containing ``"<PAD>"``.
        bs: batch size.
        progress: wrap the batch iterator in a ``tqdm`` progress bar.
        pad_to_len: length every example is padded/truncated to.

    Returns:
        A list with one float per example: the average of
        ``-log p(target)`` over that example's non-``<PAD>`` targets. An
        example with no non-pad target yields ``nan``.
    """
    pad_id = tokenizer["<PAD>"]
    batch_starts = range(0, len(data), bs)
    if progress:
        batch_starts = tqdm(batch_starts)

    # Loop-invariant: disable dropout once rather than once per batch.
    model.eval()

    losses = []
    for start in batch_starts:
        batch_ids = [
            tokenize(example, pad_to_len=pad_to_len, tokenizer=tokenizer)
            for example in data[start:start + bs]
        ]
        tokens = torch.tensor(batch_ids, dtype=torch.long, device=DEVICE)
        # Inputs are positions 0..T-2, targets the same sequence shifted by one.
        inputs, targets = tokens[:, :-1].contiguous(), tokens[:, 1:].contiguous()

        logits, _ = model(inputs)
        log_probs = F.log_softmax(logits, dim=-1)
        # Pick the log-probability assigned to each target token.
        target_log_probs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)

        per_sequence = torch.stack([
            -row_log_probs[row_targets != pad_id].mean()
            for row_log_probs, row_targets in zip(target_log_probs, targets)
        ])
        # One host transfer per batch instead of one .item() per sequence.
        losses.extend(per_sequence.tolist())

    return losses
def generate_text(model, tokenizer, tokenizer_inv, context="<START>", gen_tokens=10, temperature=0.6):
    """
    Generate up to ``gen_tokens`` tokens continuing ``context``.

    Args:
        model: callable returning ``(logits, _)`` for a batch of token IDs;
            it is switched to eval mode and left there.
        tokenizer: token -> id mapping with ``"<START>"``, ``"<STOP>"``.
        tokenizer_inv: id -> token mapping used to decode the result.
        context: prompt, either a whitespace-separated string or a list of
            token IDs. An empty prompt falls back to a single ``<START>``.
        gen_tokens: maximum number of tokens to sample.
        temperature: softmax temperature; values ``<= 0`` select the most
            likely token (greedy decoding) instead of sampling.

    Returns:
        The decoded text of context plus generated tokens, ending at the
        first ``<STOP>`` if one was produced.
    """
    ## Tokenize the context
    # tokenize() is called with a pad length, but only the real context
    # tokens are kept: otherwise the model would predict the next word after
    # a run of <PAD> tokens and the actual context would be pushed out of
    # the MAX_LEN window by the first generated token.
    n_context = len(context) if isinstance(context, list) else len(context.split())
    init_ids = tokenize(context, pad_to_len=MAX_LEN, tokenizer=tokenizer)[:n_context]
    if not init_ids:
        init_ids = [tokenizer["<START>"]]
    tokens = torch.tensor([init_ids], dtype=torch.long, device=DEVICE)
    stop_id = tokenizer["<STOP>"]

    model.eval()
    with torch.no_grad():
        for _ in range(gen_tokens):
            ## Get predictions from at most the last MAX_LEN tokens
            logits, _ = model(tokens[:, -MAX_LEN:])

            ## Focus on the last token's predictions
            last_logits = logits[0, -1, :]

            if temperature > 0:
                ## Sample from the temperature-scaled distribution
                probs = F.softmax(last_logits / temperature, dim=0)
                next_tok = torch.multinomial(probs, num_samples=1)
            else:
                ## Zero temperature would divide by zero: decode greedily
                next_tok = last_logits.argmax(dim=0, keepdim=True)

            ## Append to the context
            tokens = torch.cat([tokens, next_tok.unsqueeze(0)], dim=1)

            ## Stop if we generated a STOP token
            if next_tok.item() == stop_id:
                break

    ## Convert back to text
    return decode(tokens.squeeze(0).tolist(), tokenizer_inv)




Using device : cuda

In [None]:
class TransformerLM(nn.Module):
    """
    Language model: token embedding plus a learned positional parameter,
    a stack of ``TransformerBlock`` layers, and a linear vocabulary head.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_len=MAX_LEN, dropout=0.1, ff_dim=128):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional table, one row per position, starting at zero.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(embed_dim, num_heads, dropout, ff_dim))
        self.layers = nn.ModuleList(blocks)

        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the embedding, the output head and every block."""
        for weight in (self.token_embedding.weight, self.fc_out.weight):
            nn.init.xavier_uniform_(weight)
        for block in self.layers:
            block._init_weights()

    def forward(self, x):
        """
        Run the model on a batch of token IDs.

        Returns:
            ``(logits, hidden)``: the vocabulary logits and the output of
            the last block they were computed from.
        """
        seq_len = x.size(1)
        hidden = self.token_embedding(x) + self.positional_encoding[:, :seq_len, :]
        hidden = self.dropout(hidden)
        for block in self.layers:
            hidden = block(hidden)

        return self.fc_out(hidden), hidden


class TransformerBlock(nn.Module):
    """
    One transformer layer: multi-head attention and a two-layer ReLU
    feed-forward network, each followed by dropout, a residual connection
    and layer normalisation (post-norm).
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1, ff_dim=128):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)

        feed_forward_layers = [
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout2 = nn.Dropout(dropout)

    def _init_weights(self):
        """Xavier-initialise the weight of every linear layer in the block."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers to ``x``."""
        # Attention sub-layer: residual add, then normalise.
        x = self.norm1(x + self.dropout1(self.attention(x)))
        # Feed-forward sub-layer: residual add, then normalise.
        x = self.norm2(x + self.dropout2(self.ff(x)))
        return x


class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention layer.

    Holds the query/key/value projections and an output projection; the
    attention computation itself is delegated to ``self_attention``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        """
        Args:
            embed_dim: model width; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            dropout: dropout probability applied to the attention output.

        Raises:
            ValueError: if ``embed_dim`` cannot be split evenly across heads.
        """
        super().__init__()
        # Fail early with a clear message instead of a reshape error deep
        # inside the first forward pass.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.proj_q = nn.Linear(embed_dim, embed_dim)
        self.proj_k = nn.Linear(embed_dim, embed_dim)
        self.proj_v = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` and apply dropout and the output projection."""
        # self_attention applies the projections itself, so they are not
        # computed here as well (that would triple the projection cost for
        # results that are never used).
        attn_out = self_attention(x, self.proj_q, self.proj_k, self.proj_v, self.num_heads)
        attn_out = self.dropout(attn_out)
        return self.fc_out(attn_out)


def _next_token_loss(model, loss_fn, token_ids):
    """Return the next-token loss of ``model`` on one padded ID sequence."""
    batch = torch.tensor(token_ids, dtype=torch.long, device=DEVICE).unsqueeze(0)
    # Inputs are positions 0..T-2, targets the same sequence shifted by one.
    inputs, targets = batch[:, :-1], batch[:, 1:]
    logits, _ = model(inputs)
    return loss_fn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


def train_model(model, train_dataset, val_dataset, tokenizer, tokenizer_inv,
                num_epochs=10, lr=0.0001, weight_decay=0.01):
    """
    Train ``model`` one sequence at a time and track validation metrics.

    Uses AdamW with a cosine-annealed learning rate (one scheduler step per
    epoch) and cross-entropy that ignores ``<PAD>`` targets. After every
    epoch the losses and a generated sample are printed.

    Args:
        model: the language model; called as ``model(ids) -> (logits, _)``.
        train_dataset, val_dataset: iterables of examples accepted by
            ``tokenize``.
        tokenizer, tokenizer_inv: token -> id and id -> token mappings.
        num_epochs: number of passes over the training data.
        lr: initial AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        ``(model, train_losses, val_losses, val_perplexities, tokenizer,
        tokenizer_inv)`` with one list entry per epoch.
    """
    pad_id = tokenizer["<PAD>"]
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, num_epochs))
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)

    def has_targets(token_ids):
        # With every target ignored the mean cross-entropy is NaN, and one
        # optimizer step on a NaN loss corrupts all weights - skip those.
        return any(t != pad_id for t in token_ids[1:])

    def mean_or_nan(total, count):
        return total / count if count else float("nan")

    train_losses = []
    val_losses = []
    val_perplexities = []

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        n_train = 0

        for sample in train_dataset:
            token_ids = tokenize(sample, pad_to_len=MAX_LEN, tokenizer=tokenizer)
            if not has_targets(token_ids):
                continue

            optimizer.zero_grad()
            loss = _next_token_loss(model, loss_fn, token_ids)
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
            n_train += 1

        avg_train_loss = mean_or_nan(running_train_loss, n_train)
        train_losses.append(avg_train_loss)

        model.eval()
        total_val_loss = 0.0
        n_val = 0
        with torch.no_grad():
            for sample in val_dataset:
                token_ids = tokenize(sample, pad_to_len=MAX_LEN, tokenizer=tokenizer)
                if not has_targets(token_ids):
                    continue
                total_val_loss += _next_token_loss(model, loss_fn, token_ids).item()
                n_val += 1

        avg_val_loss = mean_or_nan(total_val_loss, n_val)
        val_losses.append(avg_val_loss)

        # Compute and store perplexity for validation set
        val_perplexity = np.exp(avg_val_loss)
        val_perplexities.append(val_perplexity)

        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}, Val Perplexity = {val_perplexity:.2f}")

        # Generate a sample text
        generated = generate_text(model, tokenizer, tokenizer_inv, context="<START>", gen_tokens=10)
        print(f"Sample text: {generated}")

        # Update learning rate according to cosine schedule
        scheduler.step()

    return model, train_losses, val_losses, val_perplexities, tokenizer, tokenizer_inv



def main():
    """
    Train the transformer language model end to end.

    Loads the data, builds the model, trains it, plots the loss curves,
    saves the weights to ``transformer_model.pth`` and prints the dev
    perplexity of the trained model.
    """
    ## Load and preprocess the training and validation datasets along with the tokenizers
    train_dataset, val_dataset, tokenizer, tokenizer_inv = load_and_preprocess_data()

    ## Define model hyperparameters
    vocab_size = len(tokenizer)
    embed_dim = 128
    num_heads = 4
    num_layers = 4
    dropout = 0.1
    ff_dim = 256

    model = TransformerLM(
        vocab_size,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        dropout=dropout,
        ff_dim=ff_dim
    ).to(DEVICE)

    ## Display the model structure
    print(model)

    ## Train the model
    model, train_losses, val_losses, val_perplexities, tokenizer, tokenizer_inv = train_model(
        model, train_dataset, val_dataset, tokenizer, tokenizer_inv
    )

    ## Plot training and validation losses over epochs
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    ## Save the trained model parameters for later use
    torch.save(model.state_dict(), "transformer_model.pth")

    ## Report the validation perplexity of the trained model.
    # Use the last epoch only: averaging the losses of all epochs mixes in
    # the early, barely trained checkpoints and does not describe the model
    # that was just saved.
    final_val_perplexity = val_perplexities[-1]
    print(f"\nDev Perplexity: {final_val_perplexity}")

    # Test-set evaluation is done separately by inference().


if __name__ == "__main__":
    main()

In [None]:
def inference(model_path, test_file, tokenizer, tokenizer_inv, gen_tokens=10, temperature=0.6):
    """
    Load a saved model, score a test file and generate continuations.

    Args:
        model_path: path of the ``state_dict`` checkpoint to load.
        test_file: text file with one sentence per line; blank lines are
            skipped.
        tokenizer, tokenizer_inv: token -> id and id -> token mappings; the
            vocabulary size must match the checkpoint.
        gen_tokens: maximum number of tokens generated per line.
        temperature: sampling temperature passed to ``generate_text``.

    Returns:
        ``(generated_texts, perplexity)``: one generated string per test
        line (prompted with the line's first 10 words) and
        ``exp(mean per-line loss)`` (``nan`` for an empty test file).
    """
    ## Load the saved model (hyperparameters must match those used in main())
    model = TransformerLM(
        vocab_size=len(tokenizer),
        embed_dim=128,
        num_heads=4,
        num_layers=4,
        dropout=0.1,
        ff_dim=256
    ).to(DEVICE)
    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device this process runs on (e.g. a CPU-only machine).
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    ## Read and process the input from test.txt
    with open(test_file, 'r', encoding="utf-8") as f:
        test_lines = [line.strip() for line in f if line.strip()]

    pad_id = tokenizer["<PAD>"]
    start_id = tokenizer["<START>"]
    stop_id = tokenizer["<STOP>"]
    processed_data = [
        pad_to_length(
            [start_id] + [tokenizer.get(word, pad_id) for word in line.split()] + [stop_id],
            MAX_LEN,
            tokenizer,
        )
        for line in test_lines
    ]

    ## Generate text and calculate perplexity
    test_losses = evaluate_losses(processed_data, model, tokenizer)
    perplexity = np.exp(np.mean(test_losses)) if test_losses else float("nan")

    # NOTE(review): the prompt below has no leading <START>, unlike the
    # training sequences - confirm whether that is intended.
    generated_texts = []
    for line in test_lines:
        context = ' '.join(line.split()[:10])
        generated_texts.append(
            generate_text(
                model, tokenizer, tokenizer_inv,
                context=context,
                gen_tokens=gen_tokens,
                temperature=temperature
            )
        )

    return generated_texts, perplexity


# Example usage: rebuild the vocabulary, then score and sample with the saved model
model_path = "transformer_model.pth"
test_file = "shakespear_test.txt"
tokenizer, tokenizer_inv = load_and_preprocess_data()[2:]
generated_texts, ppl = inference(model_path, test_file, tokenizer, tokenizer_inv)

## Report the perplexity, then the first ten generated samples
print(f"\nTest Perplexity: {ppl:.2f}")
print("\nGenerated Examples:")
for i, text in enumerate(generated_texts[:10], start=1):
    print(f"Sample text: {i}: {text}")
