# Transformer Encoder-Decoder Demo
#### We are just checking whether we get an "output" from the model. Remember that this model is untrained: we are only seeing whether it runs without errors and produces "words". It's like seeing if an infant can babble.

In [39]:
import torch

from modules.transformer import TransformerEncoderDecoder

#### Here we are defining a basic vocabulary for source and target languages
##### The most important entries are the <start>, <end>, and <pad> tokens: <start> and <end> mark the beginning and end of a sentence, and <pad> is used for padding. <start> and <end> are also the "triggers" that tell the model when to start and stop generating.

In [29]:
src_vocab = {"<pad>": 0, "<start>": 1, "<end>": 2, "hello": 3, "llms": 4, "are": 5, "great": 6}
tgt_vocab = {"<pad>": 0, "<start>": 1, "<end>": 2, "hi": 3, "there": 4, "llms": 5, "are": 6, "awesome": 7}

#### Here we define a simple tokenize function to convert sentences into token IDs based on the vocabulary

In [30]:
def tokenize(sentence, vocab):
    tokens = sentence.lower().split()
    return [vocab.get(token, vocab["<pad>"]) for token in tokens]
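
Note that `tokenize` maps any word missing from the vocabulary to `<pad>`, so an unknown word becomes indistinguishable from padding. A minimal sketch of one alternative, assuming a hypothetical `<unk>` entry is added to the vocabulary (neither `<unk>` in the vocabulary nor `tokenize_with_unk` is part of this notebook):

```python
# Hypothetical vocabulary: the source vocabulary plus an assumed "<unk>" entry.
vocab_with_unk = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3,
                  "hello": 4, "llms": 5, "are": 6, "great": 7}

def tokenize_with_unk(sentence, vocab):
    """Map each whitespace-separated token to its ID; unknown tokens map to <unk>."""
    return [vocab.get(token, vocab["<unk>"]) for token in sentence.lower().split()]

# "world" is not in the vocabulary, so it becomes the <unk> ID (3), not <pad> (0).
ids = tokenize_with_unk("<start> hello world <end>", vocab_with_unk)
print(ids)  # [1, 4, 3, 2]
```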

#### Now we will create a source and target sentence and convert them into token IDs

In [31]:
src_sentence = "<start> hello llms <end>"
tgt_sentence = "<start>"

In [32]:
src_ids = torch.tensor([tokenize(src_sentence, src_vocab)], dtype=torch.long)
tgt_ids = torch.tensor([tokenize(tgt_sentence, tgt_vocab)], dtype=torch.long)

print(f"Source IDs: {src_ids}")
print(f"Target IDs: {tgt_ids}")

Source IDs: tensor([[1, 3, 4, 2]])
Target IDs: tensor([[1]])
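
This demo uses a single sentence, so no padding is needed. With a batch of sentences of different lengths, `<pad>` fills the shorter ones so they stack into one rectangular tensor. A sketch in plain Python (the `pad_batch` helper is hypothetical, not part of this notebook):

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad every sequence with pad_id up to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]

# Two tokenized sentences of different lengths; 0 is the <pad> ID.
batch = pad_batch([[1, 3, 4, 2], [1, 3, 2]])
print(batch)  # [[1, 3, 4, 2], [1, 3, 2, 0]]
```

The resulting nested list can be passed to `torch.tensor(batch, dtype=torch.long)` in the same way as the single-sentence case.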


#### Here we set up the model parameters and instantiate the Transformer model

In [33]:
src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)
max_seq_length = 10
embedding_dim = 32
num_encoder_layers = 2
num_decoder_layers = 2
num_heads = 4
feed_forward_dim = 64
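
In the standard multi-head attention formulation, the embedding is split evenly across heads, so `embedding_dim` is normally a multiple of `num_heads`. Whether `TransformerEncoderDecoder` enforces this is an assumption here; the check below only illustrates the arithmetic for the values chosen above:

```python
embedding_dim = 32
num_heads = 4

# Each head works on an equal slice of the embedding vector.
assert embedding_dim % num_heads == 0, "embedding_dim should be divisible by num_heads"
head_dim = embedding_dim // num_heads
print(head_dim)  # 8
```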

In [34]:
model = TransformerEncoderDecoder(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    max_seq_length=max_seq_length,
    embedding_dim=embedding_dim,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    num_heads=num_heads,
    feed_forward_dim=feed_forward_dim
)

#### Now we will run the model to generate output logits

In [35]:
# Generate output greedily (effectively random at this stage, since the model is untrained).
# Run until we get the <end> token or reach the max sequence length.
outputs = []
while True:
    output_logits = model(src_ids, tgt_ids)
    outputs.append(output_logits)
    predicted_ids = torch.argmax(output_logits, dim=-1)
    tgt_ids = torch.cat([tgt_ids, predicted_ids[:, -1:]], dim=1)
    if predicted_ids[0, -1].item() == tgt_vocab["<end>"] or tgt_ids.size(1) >= max_seq_length:
        break
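
The loop above is greedy autoregressive decoding: at each step only the prediction for the last position is kept and appended to the target sequence. The same control flow as a plain-Python sketch, with a hypothetical `fake_next_token_scores` function standing in for the model (it returns made-up scores, not real model outputs):

```python
END_ID = 2    # ID of <end> in the target vocabulary
MAX_LEN = 10  # same role as max_seq_length

def fake_next_token_scores(tgt_ids):
    """Stand-in for the model: one made-up score per vocabulary entry (8 entries)."""
    scores = [0.0] * 8
    # Favor token 5 while the sequence is short, then favor <end>.
    scores[5 if len(tgt_ids) < 3 else END_ID] = 1.0
    return scores

def greedy_decode(start_ids):
    """Append the highest-scoring token until <end> is produced or MAX_LEN is reached."""
    tgt_ids = list(start_ids)
    while True:
        scores = fake_next_token_scores(tgt_ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        tgt_ids.append(next_id)
        if next_id == END_ID or len(tgt_ids) >= MAX_LEN:
            return tgt_ids

print(greedy_decode([1]))  # [1, 5, 5, 2]
```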

In [36]:
def decode_ids(ids, vocab):
    inv_vocab = {v: k for k, v in vocab.items()}
    return [inv_vocab.get(token_id.item(), "<unk>") for token_id in ids]

#### Now we will decode the predicted token IDs to see the output. At each step we simply chose the highest-probability token (greedy decoding), which is a simple way to see if the model can produce any output.
##### Greedy decoding is generally not the best way to generate text (it tends to produce repetitive output), but for this demo we just want to see if the model can produce something.
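
A common alternative to always taking the highest-scoring token is to sample from the softmax distribution, optionally with a temperature. A minimal sketch in plain Python, using made-up logits (the `softmax` and `sample_with_temperature` helpers are hypothetical and not used elsewhere in this notebook):

```python
import math
import random

def softmax(logits, temperature=1.0):
    """Convert raw scores to probabilities; lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_with_temperature(logits, temperature=1.0, rng=random):
    """Draw one token ID at random, weighted by the softmax probabilities."""
    probs = softmax(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

logits = [0.1, 2.0, 0.3, 1.5]  # made-up scores for a 4-token vocabulary
greedy_id = max(range(len(logits)), key=logits.__getitem__)  # always picks ID 1
sampled_id = sample_with_temperature(logits, temperature=0.8, rng=random.Random(0))
print(greedy_id, sampled_id)
```

Greedy selection always returns the same ID for the same logits, while sampling can return any ID with nonzero probability.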

In [37]:
decoded_sentence = " ".join(decode_ids(tgt_ids[0], tgt_vocab))

In [38]:
print("Input Sentence:", src_sentence)
print("Output Sentence:", decoded_sentence)

Input Sentence: <start> hello llms <end>
Output Sentence: <start> <start> <start> llms llms llms llms awesome <end>


#### We got "an" output, and it is valid in the sense that every word comes from our vocabulary. This means the model can run without errors and produce some output.
##### P.S. Did you really expect it to produce something meaningful?