In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Pick the best available accelerator so the notebook also runs off Apple silicon:
# Apple-silicon GPU first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model weights are randomly initialised later on; seed so runs are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Dataset Building

In [None]:
dataset = [
    "dont forget to like share and subscribe",
    "dont forget machine learning is fun",
    "machine learning is fun and awesome",
    "if you like machine learning i like you",
    "i like you more than machine learning"
]

special_tokens = ["<pad>", "<start>", "<end>"]
# Sorted rather than raw set order: str hashing is randomised per interpreter, so
# set iteration order (and therefore every token id) would change between kernel
# restarts. Special tokens are excluded so each vocab entry appears exactly once.
corpus_words = sorted(
    {word for sentence in dataset for word in sentence.split()} - set(special_tokens)
)
vocab = special_tokens + corpus_words

# Inverse lookup: word -> integer token id (its position in `vocab`).
vocab_to_index = {word: index for index, word in enumerate(vocab)}

# Tokenizing Dataset

In [None]:
def encode(sentence: str):
    """Return the vocabulary id of every whitespace-separated word in ``sentence``.

    A word that is not in ``vocab_to_index`` raises ``KeyError``.
    """
    token_ids = []
    for word in sentence.split():
        token_ids.append(vocab_to_index[word])
    return token_ids

def encode_batch(sentences: list[str]):
    """Encode sentences as ``<start> ... <end>`` id lists, right-padded with ``<pad>``.

    Every returned row has the length of the longest framed sentence, so the result
    can be passed directly to ``torch.tensor``. An empty ``sentences`` gives ``[]``.

    Raises:
        KeyError: if a sentence contains a word missing from ``vocab_to_index``.
    """
    # Look the special ids up once instead of once per sentence.
    start_id = vocab_to_index["<start>"]
    end_id = vocab_to_index["<end>"]
    pad_id = vocab_to_index["<pad>"]

    framed = [[start_id, *encode(sentence), end_id] for sentence in sentences]
    # default=0 keeps an empty batch from raising ValueError in max().
    max_length = max((len(ids) for ids in framed), default=0)
    return [ids + [pad_id] * (max_length - len(ids)) for ids in framed]

def decode(tokens: list[int]):
    """Map token ids back to their ``vocab`` entries, joined by single spaces."""
    return " ".join(vocab[token_id] for token_id in tokens)

# Frame and pad the whole corpus, then offset by one position for next-token
# prediction: inputs drop the last column, targets drop the first.
tokenized_dataset = torch.tensor(encode_batch(dataset))
input_tokens = tokenized_dataset[:, :-1].to(device)
target_tokens = tokenized_dataset[:, 1:].to(device)

# Modelling

In [None]:
class SinusoidalPositionEncoding(nn.Module):
    """Adds fixed sine/cosine position signals (Vaswani et al., 2017) to embeddings.

    Args:
        embed_size: width of the embeddings the encoding is added to (odd is allowed).
        max_seq_length: number of positions precomputed; longer inputs are rejected.
    """

    def __init__(self, embed_size: int, max_seq_length: int):
        super().__init__()
        position = torch.arange(max_seq_length, dtype=torch.float32).unsqueeze(1)  # (max_seq_length, 1)
        # Geometric frequencies 1 / 10000^(2i / embed_size), one per sin/cos column pair.
        div_term = torch.exp(
            torch.arange(0, embed_size, 2, dtype=torch.float32) * (-math.log(10000.0) / embed_size)
        )
        angles = position * div_term  # (max_seq_length, ceil(embed_size / 2))
        # Table sized from max_seq_length (it was hard-coded to 20 rows before).
        pe = torch.zeros(max_seq_length, embed_size)
        pe[:, 0::2] = torch.sin(angles)
        # For odd embed_size there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])
        # Buffer, not Parameter: moves with .to(device) but is never trained.
        self.register_buffer('positional_embedding', pe)

    def forward(self, x: torch.Tensor):
        """Add the encoding of positions ``0 .. x.size(1) - 1`` to ``x``.

        ``x`` is indexed with dimension 1 as the sequence axis, e.g. (batch, seq, embed).

        Raises:
            ValueError: if ``x`` has more positions than ``max_seq_length``.
        """
        seq_length = x.size(1)
        max_seq_length = self.positional_embedding.size(0)
        if seq_length > max_seq_length:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_length {max_seq_length}"
            )
        return x + self.positional_embedding[:seq_length, :]

class _CausalDecoderBlock(nn.Module):
    """One post-norm transformer block: masked self-attention, then a feed-forward MLP.

    Uses the same residual + LayerNorm-after-each-sub-layer layout as the default
    ``nn.TransformerEncoderLayer``, but can also hand back the attention weights.
    """

    def __init__(self, embed_size: int, num_heads: int, dropout: float):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(
            embed_size, num_heads, dropout=dropout, batch_first=True
        )
        self.attention_norm = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size),
        )
        self.feed_forward_norm = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, causal_mask: torch.Tensor, need_weights: bool):
        """Apply the block to ``x`` of shape (batch, seq, embed).

        Returns the transformed ``x`` and, if ``need_weights``, per-head attention
        weights of shape (batch, num_heads, seq, seq); otherwise ``None``.
        """
        attended, weights = self.self_attention(
            x, x, x,
            attn_mask=causal_mask,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        x = self.attention_norm(x + self.dropout(attended))
        x = self.feed_forward_norm(x + self.dropout(self.feed_forward(x)))
        return x, weights


class CausalLanguageModel(nn.Module):
    """Decoder-only (GPT-style) language model with tied input/output embeddings.

    Every position attends only to itself and earlier positions, so the logits at
    position ``t`` are the model's prediction for token ``t + 1``.

    Args:
        embed_size: model width; must be divisible by ``num_heads``.
        vocab_size: number of token ids.
        num_layers: number of stacked attention + feed-forward blocks.
        num_heads: attention heads per block.
        max_seq_length: longest input accepted by the positional encoding.
        dropout: dropout probability inside every block.

    Raises:
        ValueError: if ``embed_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_size: int, vocab_size: int, num_layers: int,
                 num_heads: int = 2, max_seq_length: int = 20, dropout: float = 0.0):
        super().__init__()
        if embed_size % num_heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
            )
        # One matrix is both the token lookup table and (transposed) the output projection.
        self.embedding_layer = nn.Parameter(torch.randn(vocab_size, embed_size))
        self.positional_encoding = SinusoidalPositionEncoding(embed_size, max_seq_length=max_seq_length)
        # Registered on self so the blocks are part of .parameters() and .to(device).
        self.layers = nn.ModuleList(
            _CausalDecoderBlock(embed_size, num_heads, dropout) for _ in range(num_layers)
        )

    def forward(self, x: torch.Tensor, return_attention_scores: bool = False):
        """Compute next-token logits for a batch of token ids.

        Args:
            x: integer token ids of shape (batch, seq).
            return_attention_scores: also return every layer's attention weights.

        Returns:
            ``logits`` of shape (batch, seq, vocab_size). If ``return_attention_scores``
            is true, returns ``(logits, attention_scores)``, where ``attention_scores`` is a
            list with one (batch, num_heads, seq, seq) tensor per layer.
        """
        seq_length = x.size(1)
        hidden = torch.nn.functional.embedding(x, self.embedding_layer)
        hidden = self.positional_encoding(hidden)
        # True above the diagonal marks future positions that must not be attended to.
        causal_mask = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        attention_scores = []
        for layer in self.layers:
            hidden, weights = layer(hidden, causal_mask, need_weights=return_attention_scores)
            attention_scores.append(weights)
        logits = torch.matmul(hidden, self.embedding_layer.T)
        if return_attention_scores:
            return logits, attention_scores
        return logits

In [None]:
# A deliberately tiny model for the five-sentence toy corpus.
causal_language_model = CausalLanguageModel(
    embed_size=6,
    vocab_size=len(vocab),
    num_layers=2,
)
causal_language_model = causal_language_model.to(device)

# Training

In [None]:
NUM_STEPS = 600
LOG_EVERY = 10
pad_id = vocab_to_index["<pad>"]

optimizer = torch.optim.Adam(causal_language_model.parameters(), lr=2e-3)
# Re-enable dropout etc. in case an inference cell left the model in eval mode.
causal_language_model.train()

for step in range(NUM_STEPS):
    logits = causal_language_model(input_tokens)
    # Next-token prediction is classification over the vocabulary, hence cross-entropy.
    # reshape (not view): target_tokens is a column slice and can be non-contiguous (e.g. on CPU).
    # ignore_index skips positions whose target is <pad>, so padding does not dilute the loss.
    loss = F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        target_tokens.reshape(-1),
        ignore_index=pad_id,
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % LOG_EVERY == 0:
        print(f"Epoch {step}, Loss: {loss.item()}")

# Inference

In [None]:
def generation(prefix, max_length=10):
    """Greedily extend ``prefix`` with the trained model and return the decoded text.

    Each predicted token is printed as it is produced.

    Args:
        prefix: space-separated words, all of which must be in the vocabulary.
        max_length: maximum number of new tokens to generate.

    Returns:
        The decoded sequence, starting with ``<start>``. It ends with ``<end>`` if the
        model emitted it within ``max_length`` steps.
    """
    end_id = vocab_to_index["<end>"]
    # Local name `tokens` avoids shadowing the global training `input_tokens`.
    tokens = torch.tensor([[vocab_to_index["<start>"]] + encode(prefix)], device=device)  # (1, seq)

    # eval() disables dropout for deterministic decoding; the caller's mode is restored afterwards.
    was_training = causal_language_model.training
    causal_language_model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_length):
                logits = causal_language_model(tokens)
                # keepdim keeps shape (1, 1) so it concatenates along the sequence axis.
                next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
                print(decode(next_token[0].tolist()))
                tokens = torch.cat([tokens, next_token], dim=1)
                if next_token.item() == end_id:
                    break
    finally:
        causal_language_model.train(was_training)
    return decode(tokens[0].tolist())

In [None]:
generation(prefix="i like")