# 2024 COMP90042 Project
*Make sure you change the file name with your group id.*

# Readme
*If there is something to be noted for the marker, please mention here.*

*If you are planning to implement a program in an Object Oriented Programming style, please put that code at the bottom of this ipynb file.*

# 1.DataSet Processing
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [232]:
!pip3 install torchtext==0.4.0

import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
import spacy
from torch.utils.tensorboard import SummaryWriter
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

import tensorflow as tf
import torch.nn.functional as F
import keras
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.data import Field, LabelField, Example, Dataset

## NLP preprocessing libraries
import nltk
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer

##others
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import string
import json



In [233]:
## Relative paths of the raw JSON inputs.
## NOTE: trainClaimsFile, devClaimsFile, evidenceFile and testFile are rebound
## to CSV paths in a later cell, so these values only hold until that cell runs.

## baseline system - will not be used for any training/evaluation
devClaimsBaselineFile='./data/dev-claims-baseline.json'
## use this for model training
trainClaimsFile='./data/train-claims.json'
## use this set for hyperparameter tuning and evaluation metrics
devClaimsFile='./data/dev-claims.json'
## the evidence file is too big to be uploaded to GitHub; download it from
## https://drive.google.com/file/d/1JlUzRufknsHzKzvrEjgw8D3n_IRpjzo6/view?usp=sharing
evidenceFile='./data/evidence.json'
## unlabelled test dataset
testFile='./data/test-claims-unlabelled.json'

In [234]:
# Load the JSON data.
# The encoding is given explicitly: without it open() decodes with the
# platform's locale encoding, which is not guaranteed to be UTF-8.
with open(trainClaimsFile, 'r', encoding='utf-8') as file:
    trainClaims=json.load(file)
with open(devClaimsFile, 'r', encoding='utf-8') as file:
    devClaims=json.load(file)
with open(evidenceFile, 'r', encoding='utf-8') as file:
    evidenceData=json.load(file)

## Preprocessing data -- lowercase, tokenize, and stopword removal
stopwords=set(nltk.corpus.stopwords.words('english'))
tokenizer=get_tokenizer('basic_english')
# A set makes the punctuation test an exact single-character match. Testing
# membership in the raw string would be a substring search, which also drops
# multi-character tokens such as "<=" that happen to occur inside that string.
punctuations=set(string.punctuation)

def preprocess(text):
    """Return `text` lower-cased with stopword and punctuation tokens removed.

    The text is split with the module-level `tokenizer`; every token found in
    `stopwords` or `punctuations` is discarded and the remaining tokens are
    joined with single spaces, so the result is a string.
    """
    tokens=tokenizer(text.lower())
    return ' '.join(t for t in tokens if t not in stopwords and t not in punctuations)

# Clean every evidence passage in place. Only values are replaced, so the
# dictionary keeps its size while it is being iterated.
for ids, texts in evidenceData.items():
    evidenceData[ids]=preprocess(texts)

# Build a DataFrame with one row per claim, merging in the text of its evidence passages
def createDF(claims, evidence):
    """Return a DataFrame describing every claim together with its evidence.

    `claims` maps a claim id to a record holding 'claim_text', 'evidences'
    (a list of evidence ids) and 'claim_label'. `evidence` maps an evidence id
    to its text. Evidence ids that are missing from `evidence` are skipped when
    the passages are joined, but they still appear in the 'evidence_id' column.
    The claim text is passed through `preprocess`; evidence text is used as given.
    """
    def toRow(claimID, claim):
        linkedIDs=claim['evidences']
        return {
            'claim_id': claimID,
            'claim_text': preprocess(claim['claim_text']),
            'evidence_id': linkedIDs,
            'evidence_text': " ".join(evidence[i] for i in linkedIDs if i in evidence),
            'claim_label': claim['claim_label'],
        }

    rows=[toRow(claimID, claim) for claimID, claim in claims.items()]
    return pd.DataFrame(rows)

# Create CSV Files: merge claims with their evidence text and write them out
trainFullMerged=createDF(trainClaims,evidenceData)
devFullMerged=createDF(devClaims,evidenceData)
trainFullMerged.to_csv("data/trainFullMerged.csv", index=False)
devFullMerged.to_csv("data/devFullMerged.csv", index=False)

# Convert evidence into csv as well
# NOTE(review): a later cell reads './data/evidencePreprocessed.csv'; these
# disabled lines are the only visible code that would write that file.
# evidenceFinal=pd.DataFrame(list(evidenceData.items()),columns=['evidence_id','evidence_text'])
# evidenceFinal.to_csv('data/evidencePreprocessed.csv',index=False)

# Convert unlabelled Data into CSV as well.
# The encoding is explicit so the JSON is not decoded with the platform's
# locale encoding.
with open(testFile, 'r', encoding='utf-8') as file:
    testData=json.load(file)
# Replace each test record by its preprocessed claim text (values only, so the
# dictionary keeps its size during iteration).
for ids, texts in testData.items():
    claim_text = texts['claim_text']
    testData[ids]=preprocess(claim_text)
testFinal=pd.DataFrame(list(testData.items()), columns=['claim_id', 'claim_text'])
testFinal.to_csv('data/testPreprocessed.csv',index=False)

In [235]:
## From here on the path variables point at the preprocessed CSV files
## (the same names previously held the raw JSON paths).

## use this for model training
trainClaimsFile='./data/trainFullMerged.csv'
## use this set for hyperparameter tuning and evaluation metrics
devClaimsFile='./data/devFullMerged.csv'
## evidence files need to be downloaded through Google Drive (https://drive.google.com/file/d/1OyihwdAWfqHIOueCB4bLBkYg4hTN_OKm/view?usp=sharing)
## NOTE(review): the only visible code that writes this CSV is commented out,
## so the file has to exist on disk before this cell runs.
evidenceFile='./data/evidencePreprocessed.csv'
## unlabelled test dataset
testFile='./data/testPreprocessed.csv'

trainDataframe=pd.read_csv(trainClaimsFile)
devDataframe=pd.read_csv(devClaimsFile)
evidenceDataframe=pd.read_csv(evidenceFile)
testDataframe=pd.read_csv(testFile)

# NOTE(review): this assigns the column to itself and changes nothing.
trainDataframe['claim_text']=trainDataframe['claim_text']
# The evidence id list is turned into text: cast to str, strip the surrounding
# brackets and outer quotes, then run through preprocess.
# NOTE(review): the next line concatenates the result with a string, so this
# relies on the preprocess version that returns a string; a later cell
# redefines preprocess to return a list of tokens.
trainDataframe['evidence_id']=trainDataframe['evidence_id'].astype(str).str.strip('[]').str.strip("'").apply(preprocess)
# Model target text: the cleaned evidence ids followed by the evidence text.
trainDataframe['combined_evidence']=trainDataframe['evidence_id']+" "+trainDataframe['evidence_text']

# Same steps for the development split.
# NOTE(review): the self-assignment below is also a no-op.
devDataframe['claim_text']=devDataframe['claim_text']
devDataframe['evidence_id']=devDataframe['evidence_id'].astype(str).str.strip('[]').str.strip("'").apply(preprocess)
devDataframe['combined_evidence']=devDataframe['evidence_id'] +" "+ devDataframe['evidence_text']

# For the evidence table the id and text are joined first and preprocess is
# applied to the combined string afterwards.
evidenceDataframe['combined_evidence']=evidenceDataframe['evidence_id']+" "+evidenceDataframe['evidence_text']
evidenceDataframe['combined_evidence']=evidenceDataframe['combined_evidence'].astype(str).str.strip("'").apply(preprocess)


In [236]:
# Claim texts of the unlabelled test set, as read from the preprocessed test CSV
testDataClaims=testDataframe['claim_text']

In [252]:
## Preprocessing data -- tokenize, then drop stopwords and punctuation
stopwords=set(nltk.corpus.stopwords.words('english'))
tokenizer=get_tokenizer('basic_english')
punctuations=set(string.punctuation)

def preprocess(text):
    """Tokenize `text` and return the tokens that are neither stopwords nor punctuation.

    The result is a list of tokens, not a string.
    NOTE: this rebinds the name `preprocess`; the definition earlier in the
    notebook returns a space-joined string instead.
    """
    keptTokens=[]
    for word in tokenizer(text):
        if word in stopwords or word in punctuations:
            continue
        keptTokens.append(word)
    return keptTokens

def createDatasetEncoderInput(dataframe,textTransform,labelTransform):
    """Wrap a DataFrame into a torchtext Dataset of (claim, evidence) examples.

    Each row contributes one Example: its 'claim_text' is bound to the
    'reviewTextInput' field (handled by `textTransform`) and its
    'combined_evidence' to the 'evidenceTextOutput' field (handled by
    `labelTransform`).
    """
    fieldSpec=[('reviewTextInput',textTransform),('evidenceTextOutput',labelTransform)]
    examples=[
        Example.fromlist([row['claim_text'],row['combined_evidence']], fieldSpec)
        for _, row in dataframe.iterrows()
    ]
    return Dataset(examples,fields=fieldSpec)

# Both fields tokenize with `preprocess`, lower-case the tokens, wrap each
# sequence in <sos>/<eos> markers and produce batch-first tensors.
TEXT = Field(tokenize=preprocess, lower=True, init_token='<sos>', eos_token='<eos>', batch_first=True)
LABEL = Field(tokenize=preprocess, lower=True, init_token='<sos>', eos_token='<eos>', batch_first=True)

# Claims go through TEXT, combined evidence through LABEL.
trainTensor = createDatasetEncoderInput(trainDataframe, TEXT, LABEL)
devTensor = createDatasetEncoderInput(devDataframe, TEXT, LABEL)

# Vocabularies are built from the training split only.
# NOTE(review): presumably dev tokens absent from training map to the unknown
# index -- confirm against the torchtext version in use.
TEXT.build_vocab(trainTensor)
LABEL.build_vocab(trainTensor)

# Print vocabulary of the TEXT field
#print(TEXT.vocab.stoi)  # Shows the mapping of string to their respective indices

# Print vocabulary of the LABEL field
#print(LABEL.vocab.stoi)

def print_examples(dataset, num_examples=5):
    """Print the first `num_examples` entries of `dataset`, one line per field.

    Every entry is introduced by a 1-based "Example N:" header, followed by the
    value of each name listed in `dataset.fields`.
    """
    for position, record in enumerate(dataset, start=1):
        if position > num_examples:
            return
        print(f"Example {position}:")
        for name in dataset.fields:
            print(f"{name}:", getattr(record, name))

# Usage
#print_examples(trainTensor)

def extract_and_process_data(dataset, text_field, label_field):
    """Collect the tokenized inputs and targets of `dataset` and process them.

    Returns a pair: `text_field.process` applied to the list of every example's
    `reviewTextInput`, and `label_field.process` applied to the list of every
    example's `evidenceTextOutput`.
    """
    # Gather both token lists first, then hand each to its field
    claim_tokens = [example.reviewTextInput for example in dataset]
    evidence_tokens = [example.evidenceTextOutput for example in dataset]

    processed_claims = text_field.process(claim_tokens)
    processed_evidence = label_field.process(evidence_tokens)
    return processed_claims, processed_evidence

# Turn the training examples into tensors (claims via TEXT, evidence via LABEL)
# and print them for inspection
train_text_tensors, train_label_tensors = extract_and_process_data(trainTensor, TEXT, LABEL)
print("Training Text Tensors:")
print(train_text_tensors)
print("Training Label Tensors:")
print(train_label_tensors)

# Same for the development examples
dev_text_tensors, dev_label_tensors = extract_and_process_data(devTensor, TEXT, LABEL)
print("Development Text Tensors:")
print(dev_text_tensors)
print("Development Label Tensors:")
print(dev_label_tensors)

def convert_to_transformer_input(text_tensor, label_tensor, fixed_length=8,
                                 pad_token=0, sos_token=2, eos_token=3):
    """Turn two batches of token ids into fixed-length [text, label] pairs.

    Each row of `text_tensor` is paired with the row of `label_tensor` at the
    same position (extra rows of the longer batch are ignored). Every row is
    right-padded with `pad_token` or truncated to `fixed_length`, converted to
    float32, and then its first and last entries are overwritten with
    `sos_token` and `eos_token`.

    Args:
        text_tensor: 2-D tensor of token ids, one row per example.
        label_tensor: 2-D tensor of token ids, one row per example.
        fixed_length: length of every returned sequence; must be at least 2.
        pad_token: id used to fill rows shorter than `fixed_length`.
        sos_token: id written to the first position of every sequence.
        eos_token: id written to the last position of every sequence.

    Returns:
        A list with one `[text_array, label_array]` entry per example, both
        float32 NumPy arrays of length `fixed_length`.

    Raises:
        ValueError: if `fixed_length` is smaller than 2.
    """
    if fixed_length < 2:
        raise ValueError("fixed_length must be at least 2 to hold the start and end markers")

    def _to_fixed(row):
        # Pad short rows on the right, then cut long rows down to fixed_length
        missing = max(0, fixed_length - len(row))
        fixed = np.pad(row, (0, missing), mode='constant', constant_values=pad_token)[:fixed_length]
        # astype copies, so the markers below never write into the input tensor
        fixed = fixed.astype(np.float32)
        # The markers replace whatever occupied the first and last slots
        fixed[0] = sos_token
        fixed[-1] = eos_token
        return fixed

    # detach()/cpu() make the NumPy conversion work for tensors that track
    # gradients or live on another device
    text_numpy = text_tensor.detach().cpu().numpy()
    label_numpy = label_tensor.detach().cpu().numpy()

    return [[_to_fixed(text), _to_fixed(label)] for text, label in zip(text_numpy, label_numpy)]

# Build fixed-length (default length 8) claim/evidence pairs for both splits
transformer_data_train = convert_to_transformer_input(train_text_tensors, train_label_tensors)
transformer_data_dev=convert_to_transformer_input(dev_text_tensors,dev_label_tensors)

# Print the development pairs for inspection
print(transformer_data_dev)

Training Text Tensors:
tensor([[   2,  104,   38,  ...,    1,    1,    1],
        [   2,  152,  270,  ...,    1,    1,    1],
        [   2, 1891,  519,  ...,    1,    1,    1],
        ...,
        [   2,   21, 3152,  ...,    1,    1,    1],
        [   2, 1225,   10,  ...,    1,    1,    1],
        [   2, 3630, 3287,  ...,    1,    1,    1]])
Training Label Tensors:
tensor([[   2, 4467, 4383,  ...,    1,    1,    1],
        [   2, 1511, 1150,  ...,    1,    1,    1],
        [   2, 8677, 4635,  ...,    1,    1,    1],
        ...,
        [   2, 8350, 1355,  ...,    1,    1,    1],
        [   2, 9745, 8773,  ...,    1,    1,    1],
        [   2, 9737, 8479,  ...,    1,    1,    1]])
Development Text Tensors:
tensor([[  2,   0,   0,  ...,   1,   1,   1],
        [  2, 468,  54,  ...,   1,   1,   1],
        [  2, 180,  75,  ...,   1,   1,   1],
        ...,
        [  2,  84, 461,  ...,   1,   1,   1],
        [  2,  64, 105,  ...,   1,   1,   1],
        [  2, 947,  70,  ...,   

# 2. Model Implementation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [246]:
class Transformer(nn.Module):
    """Minimal shell around ``nn.Transformer`` with a placeholder forward pass.

    NOTE: a class with the same name is defined again further down in this
    file, so this definition is replaced once that later one runs.
    """

    def __init__(
        self,
        num_tokens,
        dim_model,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        dropout_p,
    ):
        super().__init__()

        # `num_tokens` is accepted but not used by this shell
        core_settings = dict(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
        )
        self.transformer = nn.Transformer(**core_settings)

    def forward(self):
        # Placeholder: no computation is performed
        return None
  
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    The table is stored as a non-trainable buffer of shape
    (max_len, 1, dim_model). `forward` slices it by the size of the input's
    first axis, so that axis is treated as the position axis and the table is
    broadcast over the second axis.

    Args:
        dim_model: size of each embedding vector (odd sizes are supported).
        dropout_p: dropout probability applied after the addition.
        max_len: number of positions for which an encoding is precomputed.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()

        # Info
        self.dropout = nn.Dropout(dropout_p)

        # Encoding - From formula
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)  # 0, 1, 2, 3, 4, 5, ...
        # 1 / 10000^(2i/dim_model), computed through exp/log
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)
        angles = positions_list * division_term

        pos_encoding = torch.zeros(max_len, dim_model)

        # PE(pos, 2i) = sin(pos/10000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(angles)

        # PE(pos, 2i + 1) = cos(pos/10000^(2i/dim_model))
        # With an odd dim_model there is one odd column fewer than even
        # columns, so only the first dim_model // 2 angle columns are used.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : dim_model // 2])

        # Saving buffer (same as parameter without gradients needed);
        # reshaped from (max_len, dim_model) to (max_len, 1, dim_model)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Return `token_embedding` plus the encodings of its leading positions, with dropout."""
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
    
class Transformer(nn.Module):
    """
    Model from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1

    Sequence-to-sequence model: one embedding table shared by source and
    target, sinusoidal positional encoding, an ``nn.Transformer`` core and a
    linear projection back to vocabulary size.

    Args:
        num_tokens: vocabulary size (embedding rows and output logits).
        dim_model: embedding / model width.
        num_heads: number of attention heads.
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        dropout_p: dropout probability for the encoding and the transformer.
    """
    # Constructor
    def __init__(
        self,
        num_tokens,
        dim_model,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        dropout_p,
    ):
        super().__init__()

        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model

        # LAYERS
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.embedding = nn.Embedding(num_tokens, dim_model)
        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
        )
        self.out = nn.Linear(dim_model, num_tokens)

    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        """Return logits of shape (tgt sequence length, batch_size, num_tokens).

        Args:
            src: token ids of shape (batch_size, src sequence length).
            tgt: token ids of shape (batch_size, tgt sequence length).
            tgt_mask: optional additive attention mask, e.g. from `get_tgt_mask`.
            src_pad_mask: optional boolean padding mask for `src`.
            tgt_pad_mask: optional boolean padding mask for `tgt`.
        """
        # Embedding - Out size = (batch_size, sequence length, dim_model)
        src = self.embedding(src) * math.sqrt(self.dim_model)
        tgt = self.embedding(tgt) * math.sqrt(self.dim_model)

        # nn.Transformer is built without batch_first, so it expects
        # (sequence length, batch_size, dim_model).
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)

        # The positional encoder indexes its table by the FIRST axis, so it has
        # to run after the permute. Applied to batch-first tensors it would
        # give every token of a sequence the same encoding, chosen by the
        # sequence's index in the batch rather than by token position.
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)

        # Transformer blocks - Out size = (sequence length, batch_size, dim_model),
        # projected to (sequence length, batch_size, num_tokens)
        transformer_out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask)
        out = self.out(transformer_out)

        return out

    def get_tgt_mask(self, size) -> torch.Tensor:
        """Return a (size, size) additive mask that hides future positions."""
        # Square matrix where each row allows one more position to be seen:
        # -inf strictly above the diagonal, 0 on and below it.
        #
        # EX for size=5:
        # [[0., -inf, -inf, -inf, -inf],
        #  [0.,   0., -inf, -inf, -inf],
        #  [0.,   0.,   0., -inf, -inf],
        #  [0.,   0.,   0.,   0., -inf],
        #  [0.,   0.,   0.,   0.,   0.]]
        return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)

    def create_pad_mask(self, matrix: torch.Tensor, pad_token: int) -> torch.Tensor:
        """Return a boolean mask that is True wherever `matrix` equals `pad_token`."""
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        return (matrix == pad_token)


# 3.Testing and Evaluation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [250]:
def generate_random_data(n):
    """Build a shuffled toy dataset of `[X, y]` sequence pairs.

    `n // 3` pairs are produced for each of three patterns: all ones, all
    zeros, and alternating 0/1 with a random phase. Every sequence has 8
    payload values framed by the start value 2 and the end value 3.
    """
    SOS_token = np.array([2])
    EOS_token = np.array([3])
    length = 8
    per_pattern = n // 3

    def framed(payload):
        # Surround the payload with the start and end markers
        return np.concatenate((SOS_token, payload, EOS_token))

    data = []

    # 1,1,1,1,1,1 -> 1,1,1,1,1
    data.extend([framed(np.ones(length)), framed(np.ones(length))] for _ in range(per_pattern))

    # 0,0,0,0 -> 0,0,0,0
    data.extend([framed(np.zeros(length)), framed(np.zeros(length))] for _ in range(per_pattern))

    # 1,0,1,0 -> 1,0,1,0,1
    for _ in range(per_pattern):
        source = np.zeros(length)
        source[random.randint(0, 1)::2] = 1

        # The target starts with a 1 when the source ends on a 0, and vice versa
        target = np.zeros(length)
        first_one = 0 if source[-1] == 0 else 1
        target[first_one::2] = 1

        data.append([framed(source), framed(target)])

    np.random.shuffle(data)

    return data


def batchify_data(data, batch_size=16, padding=False, padding_token=-1):
    """Split `data` into consecutive int64 NumPy batches of exactly `batch_size` items.

    Trailing items that do not fill a whole batch are dropped; a final batch
    that ends exactly at the end of `data` is kept. The number of batches is
    printed before returning.

    Args:
        data: sequence of samples.
        batch_size: number of samples per batch.
        padding: when True, the sequences of each batch are right-padded with
            `padding_token` to the length of the longest one in that batch.
            Padding is done on copies, so `data` itself is left unchanged.
        padding_token: value used for padding.

    Returns:
        A list of arrays, each `np.array(batch).astype(np.int64)`.
    """
    batches = []
    # The stop value lets idx reach len(data) - batch_size, so only a
    # genuinely incomplete tail is skipped.
    for idx in range(0, len(data) - batch_size + 1, batch_size):
        chunk = data[idx : idx + batch_size]

        if padding:
            # Normalize every sequence to the longest one in this batch
            max_batch_length = max(len(seq) for seq in chunk)
            chunk = [
                list(seq) + [padding_token] * (max_batch_length - len(seq))
                for seq in chunk
            ]

        batches.append(np.array(chunk).astype(np.int64))

    print(f"{len(batches)} batches of size {batch_size}")

    return batches


# Synthetic pattern data for training and validation.
# NOTE(review): the model below is fitted on this toy data; the claim/evidence
# pairs built earlier are not passed to it in this cell.
train_data = generate_random_data(9000)
val_data = generate_random_data(100)
# Prints the whole training list (9000 pairs) -- very long output
print(train_data)
# Default batch size of 16; incomplete trailing batches are dropped by batchify_data
train_dataloader = batchify_data(train_data)
val_dataloader = batchify_data(val_data)

# Use the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
# num_tokens=4 covers the values produced by generate_random_data:
# 0, 1, 2 (start marker) and 3 (end marker)
model = Transformer(
    num_tokens=4, dim_model=8, num_heads=2, num_encoder_layers=3, num_decoder_layers=3, dropout_p=0.1
).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

def train_loop(model, opt, loss_fn, dataloader):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1

    Run one training epoch over `dataloader` and return the mean loss per batch.

    Each batch is indexed as `batch[:, 0]` (source sequences) and
    `batch[:, 1]` (target sequences). Tensors are created on the module-level
    `device`.
    """

    model.train()
    total_loss = 0

    for batch in dataloader:
        X, y = batch[:, 0], batch[:, 1]
        # Token ids must be integer tensors for the embedding lookup, so the
        # dtype is forced here exactly as validation_loop does.
        X = torch.tensor(X, dtype=torch.long, device=device)
        y = torch.tensor(y, dtype=torch.long, device=device)

        # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
        y_input = y[:,:-1]
        y_expected = y[:,1:]

        # Get mask to mask out the next words
        sequence_length = y_input.size(1)
        tgt_mask = model.get_tgt_mask(sequence_length).to(device)

        # Standard training except we pass in y_input and tgt_mask
        pred = model(X, y_input, tgt_mask)

        # Permute pred to have batch size first again: the loss expects the
        # class dimension second, i.e. (batch, num_tokens, sequence length)
        pred = pred.permute(1, 2, 0)
        loss = loss_fn(pred, y_expected)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.detach().item()

    return total_loss / len(dataloader)

def validation_loop(model, loss_fn, dataloader):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1

    Evaluate `model` on `dataloader` without gradient tracking and return the
    mean loss per batch.
    """

    model.eval()
    running_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            source = torch.tensor(batch[:, 0], dtype=torch.long, device=device)
            target = torch.tensor(batch[:, 1], dtype=torch.long, device=device)

            # The decoder sees the target without its last token and has to
            # produce the target without its first token
            decoder_input, expected = target[:,:-1], target[:,1:]

            # Mask that hides the next words
            causal_mask = model.get_tgt_mask(decoder_input.size(1)).to(device)

            # Bring the batch dimension first again before computing the loss
            logits = model(source, decoder_input, causal_mask).permute(1, 2, 0)
            running_loss += loss_fn(logits, expected).detach().item()

    return running_loss / len(dataloader)

def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1

    Train for `epochs` epochs, validating after each one, and return the pair
    (training losses, validation losses) with one entry per epoch.
    """

    # Loss histories, used for plotting later on
    history = {"train": [], "validation": []}

    print("Training and validating model")
    for epoch in range(epochs):
        print("-"*25, f"Epoch {epoch + 1}","-"*25)

        train_loss = train_loop(model, opt, loss_fn, train_dataloader)
        history["train"].append(train_loss)

        validation_loss = validation_loop(model, loss_fn, val_dataloader)
        history["validation"].append(validation_loss)

        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {validation_loss:.4f}")
        print()

    return history["train"], history["validation"]
    
# Train for a single epoch and keep the per-epoch loss histories
train_loss_list, validation_loss_list = fit(model, opt, loss_fn, train_dataloader, val_dataloader, 1)

[[array([2., 1., 1., 1., 1., 1., 1., 1., 1., 3.]), array([2., 1., 1., 1., 1., 1., 1., 1., 1., 3.])], [array([2., 0., 0., 0., 0., 0., 0., 0., 0., 3.]), array([2., 0., 0., 0., 0., 0., 0., 0., 0., 3.])], [array([2., 0., 0., 0., 0., 0., 0., 0., 0., 3.]), array([2., 0., 0., 0., 0., 0., 0., 0., 0., 3.])], [array([2., 1., 0., 1., 0., 1., 0., 1., 0., 3.]), array([2., 1., 0., 1., 0., 1., 0., 1., 0., 3.])], [array([2., 0., 0., 0., 0., 0., 0., 0., 0., 3.]), array([2., 0., 0., 0., 0., 0., 0., 0., 0., 3.])], [array([2., 1., 0., 1., 0., 1., 0., 1., 0., 3.]), array([2., 1., 0., 1., 0., 1., 0., 1., 0., 3.])], [array([2., 1., 0., 1., 0., 1., 0., 1., 0., 3.]), array([2., 1., 0., 1., 0., 1., 0., 1., 0., 3.])], [array([2., 0., 0., 0., 0., 0., 0., 0., 0., 3.]), array([2., 0., 0., 0., 0., 0., 0., 0., 0., 3.])], [array([2., 1., 1., 1., 1., 1., 1., 1., 1., 3.]), array([2., 1., 1., 1., 1., 1., 1., 1., 1., 3.])], [array([2., 1., 1., 1., 1., 1., 1., 1., 1., 3.]), array([2., 1., 1., 1., 1., 1., 1., 1., 1., 3.])],

In [240]:
def predict(model, input_sequence, max_length=15, SOS_token=2, EOS_token=3):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1

    Greedily decode an output sequence for `input_sequence`.

    Decoding starts from `SOS_token` and repeatedly appends the highest-scoring
    token at the last position. It stops once `EOS_token` is produced or after
    `max_length` tokens have been generated.

    Args:
        model: callable as `model(src, tgt, tgt_mask)`, providing `eval()` and
            `get_tgt_mask(size)`.
        input_sequence: tensor of source token ids.
        max_length: maximum number of tokens to generate.
        SOS_token: id the decoded sequence starts with.
        EOS_token: id that ends decoding.

    Returns:
        The decoded token ids as a list, including the leading `SOS_token`.
    """
    model.eval()

    # Work on the device the input already lives on, so the function does not
    # depend on a module-level `device` variable.
    device = input_sequence.device
    y_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)

    # Inference only: no autograd graph is needed for the repeated forward passes
    with torch.no_grad():
        for _ in range(max_length):
            # Mask that hides future target positions
            tgt_mask = model.get_tgt_mask(y_input.size(1)).to(device)

            pred = model(input_sequence, y_input, tgt_mask)

            # Index of the highest score at the last position
            next_item = pred.topk(1)[1].view(-1)[-1].item()

            # Concatenate previous input with predicted best word
            y_input = torch.cat((y_input, torch.tensor([[next_item]], device=device)), dim=1)

            # Stop if model predicts end of sentence
            if next_item == EOS_token:
                break

    return y_input.view(-1).tolist()
  
# NOTE(review): `test` is not referenced anywhere else in the visible code.
test='this is a test'
# Here we test some examples to observe how the model predicts.
# Every example is one sequence (batch of 1) framed by 2 (start) and 3 (end);
# the last two have lengths other than the 8 payload values used for training data.
examples = [
    torch.tensor([[2, 0, 0, 0, 0, 0, 0, 0, 0, 3]], dtype=torch.long, device=device),
    torch.tensor([[2, 1, 1, 1, 1, 1, 1, 1, 1, 3]], dtype=torch.long, device=device),
    torch.tensor([[2, 1, 0, 1, 0, 1, 0, 1, 0, 3]], dtype=torch.long, device=device),
    torch.tensor([[2, 0, 1, 0, 1, 0, 1, 0, 1, 3]], dtype=torch.long, device=device),
    torch.tensor([[2, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3]], dtype=torch.long, device=device),
    torch.tensor([[2, 0, 1, 3]], dtype=torch.long, device=device)
]

for idx, example in enumerate(examples):
    result = predict(model, example)
    print(f"Example {idx}")
    # [1:-1] removes the first and last element of each list before printing.
    # NOTE(review): for `result` the last element is only an end marker when
    # predict stopped on one; otherwise a generated token is cut off -- confirm.
    print(f"Input: {example.view(-1).tolist()[1:-1]}")
    print(f"Continuation: {result[1:-1]}")
    print()

Example 0
Input: [0, 0, 0, 0, 0, 0, 0, 0]
Continuation: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Example 1
Input: [1, 1, 1, 1, 1, 1, 1, 1]
Continuation: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

Example 2
Input: [1, 0, 1, 0, 1, 0, 1, 0]
Continuation: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]

Example 3
Input: [0, 1, 0, 1, 0, 1, 0, 1]
Continuation: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]

Example 4
Input: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
Continuation: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]

Example 5
Input: [0, 1]
Continuation: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]



## Object Oriented Programming codes here

*You can use multiple code snippets. Just add more if needed*