# Project for Deep Learning course: Question Answering on Mathematics Dataset

# Task
> Mathematical reasoning – a
core ability within human intelligence – presents some unique challenges as a domain:
we do not come to understand and solve mathematical problems primarily on the back
of experience and evidence, but on the basis of inferring, learning, and exploiting laws,
axioms, and symbol manipulation rules. The task consists of many different types
of mathematics problems, with the motivation that it should be harder for a model
to do well across a range of problem types without possessing at least some part of
these abilities that allow for algebraic generalization.



# Import and installation

Here we install and import all the libraries that we will utilize


In [1]:
!pip install pytorch_lightning transformers datasets wget gdown

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.1.0-py3-none-any.whl (774 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m774.6/774.6 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers
  Downloading transformers-4.34.0-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting wget
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m44.7 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-ut

In [2]:
import json
import pytorch_lightning as pl
import numpy as np
import os
import wget
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer
from transformers import BertTokenizer
from torch.utils.data import DataLoader #random_split
from pytorch_lightning import Trainer, LightningDataModule, LightningModule
from datasets import load_dataset

# Parameters

Here we define all the parameters needed. The most important to choose is the dataset ("num") which allows to change the portion of the math-dataset used, and so to train the model to solve different operations. [num can vary from 0 to 4]

In [3]:
#define parameters
SEED = 42 #random seed
BATCH_SIZE = 32 #batch size
train_samples = 1000 #number of samples for training and validation set (respectively 90 and 10 %)
test_samples = 100 #number of samples for test set
lista_dataset = ["algebra__linear_1d", "arithmetic__add_or_sub", "arithmetic__div", "arithmetic__mul", "arithmetic_simplify_surd"] #list of possible dataset
device = "cuda" if torch.cuda.is_available() else "cpu" #device used

In [4]:
#choose the dataset
num = 0 #from 0 to 4

# Utils

Here we define some functions that will be utilized by each model presented. More in specific we define: a class to convert dataset into Tensors and an Accuracy to evaluate the results during and after the training

In [5]:
class Dataset_class:
    def __init__(self, q, att_q, a, att_a):
        self.input_ids = q
        self.attention_mask = att_q
        self.answer_ids = a
        self.answer_attention_mask = att_a

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, item):
        return [
            torch.tensor(self.input_ids[item], dtype=torch.long).to(device),
            torch.tensor(self.attention_mask[item], dtype=torch.long).to(device),
            torch.tensor(self.answer_ids[item], dtype=torch.long).to(device),
            torch.tensor(self.answer_attention_mask[item], dtype=torch.long).to(device)
        ]


In [6]:
#accuracy written to evaluate the model
def evaluate_accuracy(predicted, target):
    predicted = torch.argmax(predicted, dim=2)

    score_per_char = 0
    total_score = 0
    count = 0
    size_0 = predicted.size(0)
    size_1 = predicted.size(1)

    for i in range(0, size_0):
        score = 0
        for j in range(0, size_1):
            if predicted[i][j].item() == target[i][j].item():
                score += 1
            count += 1
            if target[i][j].item() == 1:
                break
        if score == count:
            total_score += 1
        score_per_char += score

    total_accuracy = total_score / size_0
    total_acc_char = score_per_char / count

    return total_accuracy, total_acc_char


# Load Dataset and tokenization

Here we define some functions to create the dictionaries, codify the dataset with the dictionaries created and create the Dataloaders. At the end we load the partiotion of the dataset chose and randomly seeelect the data needed between the data present in the partition loaded (based on the number of samples required).

In [7]:
def build_dict(data):
    char_to_idx = {}
    idx_to_char = {}
    all_characters = set()
    question_max_length = 0
    answer_max_length = 0

    for split_name in ["train", "test"]:
        split_dataset = data[split_name]

        for row in split_dataset:
            question = row['question']
            answer = row['answer']
            question_length = len(row['question'])
            answer_length = len(row['answer'])

            all_characters.update(question)
            all_characters.update(answer)

            if question_length > question_max_length:
                question_max_length = question_length

            if answer_length > answer_max_length:
                answer_max_length = answer_length

    char_idx = 0  # Initialize a variable to keep track of character indices

    # Add special tokens
    special_tokens = ['<start>', '<end>', 'PAD', 'UNK']

    for token in special_tokens:
        char_to_idx[token] = char_idx
        idx_to_char[char_idx] = token
        char_idx += 1  # Increment the index for the next special token

    # Assign each character a unique index

    for char in all_characters:
        char_to_idx[char] = char_idx
        idx_to_char[char_idx] = char
        char_idx += 1  # Increment the index for the next character

    return char_to_idx, idx_to_char, question_max_length, answer_max_length


In [8]:
def Tokenize(data, char_to_idx, max_length_q, max_length_a):
    max_length_q -= 3
    padded_question = []
    attention_mask_q = []
    for string in data["question"]:
        numerical_indices_q = [char_to_idx['<start>']] + [char_to_idx[char] if char in char_to_idx else char_to_idx['UNK'] for char in string][2:len(string) - 3] + [char_to_idx['<end>']]
        if len(numerical_indices_q) < max_length_q:
            numerical_indices_q += [char_to_idx['PAD']] * (max_length_q - len(numerical_indices_q))
        padded_question.append(numerical_indices_q)
        att = [1 for _ in range(0, len(string) - 3)] + [0 for _ in range(0, max_length_q - (len(string) - 3))]
        attention_mask_q.append(att)

    max_length_a -= 3
    padded_answer = []
    attention_mask_a = []
    for string in data["answer"]:
        numerical_indices_a = [char_to_idx['<start>']] + [char_to_idx[char] if char in char_to_idx else char_to_idx['UNK'] for char in string][2:len(string) - 3] + [char_to_idx['<end>']]
        if len(numerical_indices_a) < max_length_a:
            numerical_indices_a += [char_to_idx['PAD']] * (max_length_a - len(numerical_indices_a))
        padded_answer.append(numerical_indices_a)
        att = [1 for _ in range(0, len(string) - 3)] + [0 for _ in range(0, max_length_a - (len(string) - 3))]
        attention_mask_a.append(att)

    return Dataset_class(padded_question, attention_mask_q, padded_answer, attention_mask_a)


In [9]:
def create_dataset(data, batch_size):
    char_to_idx, idx_to_char, max_length_q, max_length_a = build_dict(data)
    traindataset = Tokenize(data['train'][:int(len(data['train']) * 0.9)], char_to_idx, max_length_q, max_length_a)
    vaaliddataset = Tokenize(data['train'][int(len(data['train']) * 0.9):], char_to_idx, max_length_q, max_length_a)
    testdataset = Tokenize(data['test'], char_to_idx, max_length_q, max_length_a)

    train_dataloaders = torch.utils.data.DataLoader(traindataset, shuffle=True, batch_size=batch_size)
    valid_dataloaders = torch.utils.data.DataLoader(vaaliddataset, shuffle=False, batch_size=batch_size)
    test_dataloaders = torch.utils.data.DataLoader(testdataset, shuffle=False, batch_size=batch_size)

    return train_dataloaders, valid_dataloaders, test_dataloaders, [char_to_idx, idx_to_char, max_length_q - 3, max_length_a - 3]


In [10]:
from datasets import DatasetDict
# Load the dataset from HuggingFace
dataset = load_dataset('math_dataset', lista_dataset[num])

# Resize the dataset on the portion needed
dataset = DatasetDict({
    "train": dataset["train"].shuffle(seed=SEED).select(range(train_samples)),
    "test": dataset["test"].shuffle(seed=SEED).select(range(test_samples))
})

# Create the dataloaders and save the parameters needed for the model
train_dataloaders, valid_dataloaders, test_dataloaders, params = create_dataset(dataset, BATCH_SIZE)


Downloading builder script:   0%|          | 0.00/8.40k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/102k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/24.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1999998 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [11]:
import pickle
with open("/paramas.pkl", 'wb') as f:
    pickle.dump(params, f)

# Basile model: LSTM

This is a baseline model to evaluate better the performance of the proposed model (Our solution). This is a very basic solution that employes an lstm to predict the answer

## Model

In [12]:
class LSTMmodel(LightningModule):
    def __init__(self, char_to_idx, idx_to_char, question_max_length=68, answer_max_length=10, batch_size=128,
                 embedding_size=512, hidden_units=2048, num_encoder_layers=6, padding_idx=2, p=0.):
        super().__init__()

        self.pad_token = "PAD"
        self.dict_size = len(char_to_idx)
        self.question_max_length = question_max_length
        self.answer_max_length = answer_max_length
        self.padding_idx = padding_idx
        self.idx_to_char = idx_to_char
        self.dropout_probability = p
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.num_encoder_layers = num_encoder_layers

        self.embedding = nn.Embedding(self.dict_size, self.embedding_size, padding_idx=self.padding_idx)
        self.lstm = nn.LSTM(input_size=self.embedding_size, hidden_size=self.hidden_units,
                            num_layers=self.num_encoder_layers, dropout=self.dropout_probability, batch_first=True)
        self.output_layer = nn.Linear(self.hidden_units, self.dict_size)
        self.list_loss = []
        self.list_accuracy = []
    def forward(self, batch):
        input_tensor = batch[0]
        embedded = self.embedding(input_tensor)
        h0 = torch.zeros(self.num_encoder_layers, embedded.size(0), self.hidden_units).to(embedded.device)
        c0 = torch.zeros(self.num_encoder_layers, embedded.size(0), self.hidden_units).to(embedded.device)
        lstm_out, _ = self.lstm(embedded, (h0, c0))
        sequence_length = batch[2].size(1)
        lstm_out = lstm_out[:, :sequence_length, :]
        logits = self.output_layer(lstm_out)

        return logits

    def training_step(self, batch, _):
        batch_answers = batch[2][:, 1:].flatten(0, 1)
        batch[2] = batch[2][:, :-1]
        pred = self(batch).flatten(0, 1)

        loss = F.cross_entropy(pred, batch_answers, ignore_index=self.padding_idx)
        print(loss.item())
        self.list_loss.append(loss.item())
        return loss

    def validation_step(self, batch, _):
        batch_answers = batch[2]
        pred = self.predict(batch)

        accuracy = evaluate_accuracy(pred, batch_answers)
        print("val_tot_accuracy", accuracy)
        self.list_accuracy.append(accuracy)
        return accuracy

    def test_step(self, batch, _):
        batch_answers = batch[2]
        pred = self.predict(batch.copy())

        accuracy = evaluate_accuracy(pred, batch_answers)
        print("test_tot_accuracy", accuracy)
        return accuracy

    def predict(self, batch):
        batch[2] = torch.tensor([[1] for j in range(len(batch[2]))], device=device).long()

        for i in range(self.answer_max_length-1):
            transformer_result = self(batch)
            batch[2] = torch.tensor([], device=device).long()

            for j in range(len(transformer_result)):
                predicted_chars = torch.argmax(transformer_result[j], 1)

                if len(predicted_chars.shape) == 1:
                    predicted_chars = predicted_chars.unsqueeze(0)

                start_line_char = torch.tensor([1], device=device).unsqueeze(0).long()
                predicted_chars = torch.cat((start_line_char, predicted_chars), 1)
                batch[2] = torch.cat((batch[2], predicted_chars), 0)

        return F.one_hot(batch[2], num_classes=self.dict_size)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4, betas=(0.9, 0.995))
    def get_loss(self):
        # Questo metodo restituirà il valore di self.data
        return self.list_loss

In [13]:
#Define the model
model_LSTM = LSTMmodel(params[0], params[1], params[2], params[3], BATCH_SIZE)

## Training


In [None]:
#Train the model
trainer = pl.Trainer(max_epochs=30)
#Train the model using the train_dataloaders
trainer.fit(model_LSTM, train_dataloaders, valid_dataloaders)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name         | Type      | Params
-------------------------------------------
0 | embedding    | Embedding | 25.1 K
1 | lstm         | LSTM      | 188 M 
2 | output_layer | Linear    | 100 K 
-------------------------------------------
188 M     Trainable params
0         Non-trainable params
188 M     Total params
755.870   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


val_tot_accuracy (0.0, 0.0)
val_tot_accuracy (0.0, 0.0)


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (29) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

3.8969576358795166
3.8879988193511963
3.8735616207122803
3.86171817779541
3.8468563556671143
3.827470302581787
3.803588628768921
3.7659928798675537
3.70888352394104
3.5773258209228516
3.335123300552368
3.1225593090057373
2.9140169620513916
2.862685441970825
2.9516003131866455
2.9510395526885986
2.9019458293914795
2.9311227798461914
2.7999625205993652
2.746664047241211
2.640627145767212
2.606916666030884
2.6422128677368164
2.562087297439575
2.558417558670044
2.5037543773651123
2.4494071006774902
2.362496852874756
2.356485605239868


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.2644628099173554)
val_tot_accuracy (0.0, 0.23529411764705882)
val_tot_accuracy (0.0, 0.25396825396825395)
val_tot_accuracy (0.0, 0.26666666666666666)
2.2752394676208496
2.2071900367736816
2.1892879009246826
2.240363121032715
2.223520278930664
2.169602155685425
2.0998027324676514
2.1081554889678955
2.0350356101989746
1.9462220668792725
1.9674466848373413
1.9876042604446411
1.9634841680526733
1.8905251026153564
1.8880850076675415
1.938786506652832
1.9244765043258667
1.829635500907898
1.7823525667190552
1.8324154615402222
1.7534596920013428
1.7209539413452148
1.692206621170044
1.6848784685134888
1.8077044486999512
1.6555531024932861
1.8435536623001099
1.846274733543396
1.3032797574996948


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.816026210784912
1.713876724243164
1.6584751605987549
1.804100751876831
1.7820377349853516
1.7609074115753174
1.6934468746185303
1.729027271270752
1.7194445133209229
1.7666676044464111
1.8371943235397339
1.7529979944229126
1.7536664009094238
1.7123386859893799
1.7541567087173462
1.636962890625
1.7370193004608154
1.798683524131775
1.7799930572509766
1.6509528160095215
1.8809757232666016
1.683017611503601
1.6406686305999756
1.690755009651184
1.8660529851913452
1.7582881450653076
1.7956258058547974
1.9281078577041626
1.9284149408340454


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.6981475353240967
1.6665372848510742
1.802316665649414
1.7932217121124268
1.7631683349609375
1.6691852807998657
1.7512894868850708
1.7820581197738647
1.767864465713501
1.705924153327942
1.7822779417037964
1.6804797649383545
1.7858428955078125
1.760151982307434
1.6877132654190063
1.8427448272705078
1.858712911605835
1.6973365545272827
1.7898973226547241
1.8328181505203247
1.7464303970336914
1.7276111841201782
1.7086050510406494
1.6885210275650024
1.5466127395629883
1.7708295583724976
1.7576528787612915
1.818656086921692
1.8207186460494995


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.5980507135391235
1.6981536149978638
1.7586427927017212
1.5042617321014404
1.7115172147750854
1.7720893621444702
1.6469244956970215
1.9171489477157593
1.7058264017105103
1.790114164352417
1.8101434707641602
1.611451268196106
1.6151447296142578
1.740618109703064
1.7782381772994995
1.7801284790039062
1.759332537651062
1.6698271036148071
1.7305426597595215
1.7616959810256958
1.7413649559020996
1.7523493766784668
1.8131983280181885
1.9023692607879639
1.7248932123184204
1.7706485986709595
1.7982659339904785
1.7502145767211914
1.6843948364257812


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7644753456115723
1.63496732711792
1.7157279253005981
1.7023341655731201
1.7041198015213013
1.7760061025619507
1.6625219583511353
1.7658472061157227
1.677363634109497
1.6867812871932983
1.7397347688674927
1.7331267595291138
1.795647382736206
1.6901752948760986
1.7984665632247925
1.7895842790603638
1.5960440635681152
1.8085695505142212
1.6724185943603516
1.6473498344421387
1.7471060752868652
1.6982412338256836
1.7267580032348633
1.8042982816696167
1.9183769226074219
1.6586185693740845
1.764817237854004
1.9511479139328003
1.6108300685882568


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7095664739608765
1.7961777448654175
1.7604016065597534
1.656821608543396
1.7754727602005005
1.720383882522583
1.6987650394439697
1.7597479820251465
1.5531319379806519
1.6455297470092773
1.794670820236206
1.723116397857666
1.8007630109786987
1.8408842086791992
1.6847667694091797
1.8113845586776733
1.85149347782135
1.744337797164917
1.660198450088501
1.811233401298523
1.8016704320907593
1.6968637704849243
1.7357118129730225
1.7725234031677246
1.6669400930404663
1.8675764799118042
1.7164479494094849
1.682442545890808
2.120819330215454


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7035863399505615
1.7654740810394287
1.6510976552963257
1.7332595586776733
1.7538751363754272
1.8646214008331299
1.599143147468567
1.7844266891479492
1.7352988719940186
1.715777039527893
1.6291335821151733
1.6890172958374023
1.764297366142273
1.7509845495224
1.7419639825820923
1.8383989334106445
1.7881519794464111
1.757515788078308
1.6151297092437744
1.6731959581375122
1.6668431758880615
1.6862307786941528
1.7377156019210815
1.6672872304916382
1.8019684553146362
1.671613335609436
1.8665130138397217
1.7479876279830933
1.8853702545166016


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7214486598968506
1.6250519752502441
1.6387214660644531
1.7332489490509033
1.8552089929580688
1.7196943759918213
1.7988340854644775
1.6890820264816284
1.5731278657913208
1.776436686515808
1.64327871799469
1.804667353630066
1.8368358612060547
1.762991189956665
1.6111091375350952
1.7446582317352295
1.7762939929962158
1.7139253616333008
1.717331051826477
1.5511821508407593
1.770087718963623
1.9019474983215332
1.6919745206832886
1.7008204460144043
1.6902345418930054
1.7775920629501343
1.706153154373169
1.6850076913833618
1.5618537664413452


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7937010526657104
1.6680155992507935
1.6933205127716064
1.688450574874878
1.638924479484558
1.70357084274292
1.7075822353363037
1.7995338439941406
1.7820874452590942
1.690475583076477
1.8194572925567627
1.7645008563995361
1.6901417970657349
1.6435487270355225
1.7246041297912598
1.646053671836853
1.7173480987548828
1.6340700387954712
1.7054810523986816
1.6825188398361206
1.814186453819275
1.765315055847168
1.7696127891540527
1.787921667098999
1.7546273469924927
1.7361587285995483
1.8363381624221802
1.623977541923523
1.7556740045547485


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7832649946212769
1.757527470588684
1.7919957637786865
1.6513400077819824
1.8422236442565918
1.7528489828109741
1.6870038509368896
1.812006950378418
1.7539199590682983
1.6930075883865356
1.761580467224121
1.705617904663086
1.6535652875900269
1.7757785320281982
1.7696034908294678
1.6129764318466187
1.7199691534042358
1.6885242462158203
1.7471102476119995
1.652410626411438
1.6236602067947388
1.688844919204712
1.7017470598220825
1.7035984992980957
1.724880337715149
1.7374976873397827
1.6997666358947754
1.7597756385803223
1.695858120918274


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7324879169464111
1.6824756860733032
1.6837081909179688
1.7140575647354126
1.7118005752563477
1.7069859504699707
1.7990206480026245
1.7089548110961914
1.716727614402771
1.854631781578064
1.843241572380066
1.6917924880981445
1.735866904258728
1.7640506029129028
1.6711628437042236
1.7232149839401245
1.6765739917755127
1.8127832412719727
1.8473879098892212
1.616601824760437
1.842806100845337
1.7517191171646118
1.652291178703308
1.7580770254135132
1.6520975828170776
1.6727104187011719
1.6878494024276733
1.6396872997283936
1.4787062406539917


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.684814453125
1.745964765548706
1.700502634048462
1.767661690711975
1.7751531600952148
1.7633028030395508
1.8080453872680664
1.7049434185028076
1.7199822664260864
1.798871397972107
1.6820796728134155
1.7401645183563232
1.734605312347412
1.7376865148544312
1.684941053390503
1.6079808473587036
1.670154094696045
1.6739732027053833
1.785433053970337
1.668479084968567
1.7896645069122314
1.7849947214126587
1.6079864501953125
1.7278627157211304
1.704791784286499
1.7913391590118408
1.6786750555038452
1.6849009990692139
1.600299596786499


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7956268787384033
1.7284451723098755
1.7710789442062378
1.65638267993927
1.8400826454162598
1.738875150680542
1.6691783666610718
1.6994720697402954
1.6974196434020996
1.7085676193237305
1.7785365581512451
1.6268343925476074
1.7569382190704346
1.70641028881073
1.8135825395584106
1.7063450813293457
1.8325610160827637
1.8756020069122314
1.7168267965316772
1.7169419527053833
1.657342791557312
1.5830645561218262
1.6899889707565308
1.749876618385315
1.7842185497283936
1.6648179292678833
1.7632156610488892
1.672512173652649
1.4063602685928345


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.6448553800582886
1.7115141153335571
1.7322925329208374
1.741705298423767
1.6726089715957642
1.7327383756637573
1.6816524267196655
1.740859866142273
1.7110331058502197
1.8625845909118652
1.7832305431365967
1.6899056434631348
1.8168861865997314
1.7822935581207275
1.8249765634536743
1.7668415307998657
1.5585018396377563
1.831036925315857
1.6338777542114258
1.713582992553711
1.733622670173645
1.645047903060913
1.727766990661621
1.6474381685256958
1.74905264377594
1.7312296628952026
1.6793617010116577
1.6484402418136597
1.9893810749053955


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.6936407089233398
1.7135288715362549
1.6493974924087524
1.5893456935882568
1.742882490158081
1.7044224739074707
1.5993648767471313
1.6787528991699219
1.8249084949493408
1.6364208459854126
1.781744122505188
1.6496657133102417
1.7809665203094482
1.6325817108154297
1.83611261844635
1.6600911617279053
1.748839020729065
1.7674607038497925
1.8445175886154175
1.7008095979690552
1.7242528200149536
1.811877965927124
1.7106564044952393
1.6631088256835938
1.6630303859710693
1.7480261325836182
1.832345962524414
1.7087846994400024
1.6116794347763062


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.722086787223816
1.7794557809829712
1.8141316175460815
1.6888772249221802
1.6342570781707764
1.682554006576538
1.8356986045837402
1.6640785932540894
1.5922623872756958
1.6897002458572388
1.8092960119247437
1.6391842365264893
1.6947107315063477
1.7444981336593628
1.7610716819763184
1.8837637901306152
1.6654607057571411
1.7269995212554932
1.5955369472503662
1.5948312282562256
1.7040940523147583
1.7531806230545044
1.7108509540557861
1.749360203742981
1.784529209136963
1.8709635734558105
1.671657919883728
1.7013245820999146
1.504860520362854


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.6730055809020996
1.6784071922302246
1.7064207792282104
1.7125589847564697
1.9360191822052002
1.5220837593078613
1.7918411493301392
1.6533757448196411
1.77715265750885
1.6971641778945923
1.7172050476074219
1.7044204473495483
1.6669135093688965
1.7170568704605103
1.7318724393844604
1.7355303764343262
1.7622320652008057
1.7632073163986206
1.7933104038238525
1.8148393630981445
1.7374045848846436
1.672385811805725
1.6815584897994995
1.631148099899292
1.7796804904937744
1.6306605339050293
1.7385469675064087
1.810896396636963
1.873198390007019


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7255597114562988
1.84595787525177
1.7176564931869507
1.6448246240615845
1.7164368629455566
1.6504825353622437
1.7236759662628174
1.5920628309249878
1.7618893384933472
1.8096840381622314
1.744449257850647
1.6507569551467896
1.7233322858810425
1.676469087600708
1.7264469861984253
1.6919140815734863
1.7766612768173218
1.6798588037490845
1.7669823169708252
1.7388370037078857
1.695871114730835
1.6502796411514282
1.7045092582702637
1.643044352531433
1.7185993194580078
1.8118336200714111
1.9075976610183716
1.6766510009765625
1.975657343864441


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.6416034698486328
1.6518796682357788
1.663169503211975
1.6473898887634277
1.7319377660751343
1.8063926696777344
1.7873589992523193
1.6555354595184326
1.6126492023468018
1.635686993598938
1.5731934309005737
1.7671257257461548
1.7537082433700562
1.7850672006607056
1.6731082201004028
1.6739088296890259
1.625775933265686
1.7242417335510254
1.7079179286956787
1.6139373779296875
1.74942946434021
1.9004886150360107
1.82137930393219
1.7797294855117798
1.7005162239074707
1.743730902671814
1.804053783416748
1.7705399990081787
1.9843666553497314


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.675279140472412
1.6914591789245605
1.6253629922866821
1.6743327379226685
1.6224745512008667
1.7083429098129272
1.7097759246826172
1.6295546293258667
1.7961598634719849
1.7524112462997437
1.6807066202163696
1.7614867687225342
1.6370580196380615
1.646901249885559
1.710123062133789
1.703669548034668
1.7200545072555542
1.770705223083496
1.789122223854065
1.7568854093551636
1.7313134670257568
1.8587892055511475
1.8595298528671265
1.8028725385665894
1.6249319314956665
1.7622984647750854
1.7744224071502686
1.677443504333496
1.6649476289749146


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7431706190109253
1.7214711904525757
1.6626023054122925
1.8110042810440063
1.6791917085647583
1.6280975341796875
1.733545184135437
1.7040784358978271
1.7994952201843262
1.8318551778793335
1.757673740386963
1.6971540451049805
1.6832044124603271
1.7258026599884033
1.841215968132019
1.7135423421859741
1.6432496309280396
1.6574769020080566
1.713495135307312
1.672706961631775
1.678814172744751
1.7252757549285889
1.7102781534194946
1.7182142734527588
1.6189414262771606
1.7278379201889038
1.818221926689148
1.691207766532898
1.809334397315979


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7170003652572632
1.7888225317001343
1.629233717918396
1.6806212663650513
1.7764456272125244
1.6587048768997192
1.7062894105911255
1.7424572706222534
1.7173793315887451
1.7227349281311035
1.708547592163086
1.7277486324310303
1.7065887451171875
1.7881050109863281
1.8796056509017944
1.5716125965118408
1.7598716020584106
1.7420856952667236
1.6667711734771729
1.6299574375152588
1.7299758195877075
1.701576590538025
1.7357709407806396
1.7229949235916138
1.7711145877838135
1.7458950281143188
1.7104617357254028
1.6945308446884155
1.5564794540405273


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.5693717002868652
1.7612370252609253
1.7041431665420532
1.6221822500228882
1.8274106979370117
1.6498843431472778
1.6189124584197998
1.693996548652649
1.815364956855774
1.649567723274231
1.7277377843856812
1.6772600412368774
1.7458381652832031
1.815657615661621
1.6946136951446533
1.7031210660934448
1.6911953687667847
1.8013453483581543
1.7020524740219116
1.743507981300354
1.6704459190368652
1.7166008949279785
1.7462691068649292
1.7137919664382935
1.6219350099563599
1.726859450340271
1.8100053071975708
1.7747962474822998
1.9831737279891968


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7129138708114624
1.7814046144485474
1.622707486152649
1.6935865879058838
1.7394627332687378
1.764604091644287
1.6993536949157715
1.68552565574646
1.7190444469451904
1.716978907585144
1.7124601602554321
1.769168734550476
1.7118995189666748
1.7836438417434692
1.7243620157241821
1.7776129245758057
1.8170101642608643
1.7245906591415405
1.6782305240631104
1.6573421955108643
1.7036864757537842
1.6694871187210083
1.7197527885437012
1.7019593715667725
1.7582285404205322
1.776970624923706
1.6108416318893433
1.8064465522766113
1.4454364776611328


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.744168758392334
1.7048094272613525
1.7058656215667725
1.6242725849151611
1.7620083093643188
1.672094702720642
1.6423417329788208
1.7086776494979858
1.6816186904907227
1.6246416568756104
1.7376461029052734
1.7001523971557617
1.7086344957351685
1.7529786825180054
1.7297306060791016
1.6417397260665894
1.7201941013336182
1.7662851810455322
1.6399933099746704
1.7788612842559814
1.704483151435852
1.7616479396820068
1.7507202625274658
1.7579072713851929
1.7327344417572021
1.7243480682373047
1.7296088933944702
1.7649658918380737
1.7686551809310913


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.6219964027404785
1.6345981359481812
1.8314083814620972
1.718103051185608
1.7703561782836914
1.643235683441162
1.7122725248336792
1.7847105264663696
1.7854580879211426
1.742811679840088
1.73508882522583
1.7281023263931274
1.7156153917312622
1.7082535028457642
1.6794511079788208
1.749292254447937
1.7897958755493164
1.8316245079040527
1.7426806688308716
1.7092543840408325
1.7203545570373535
1.744147777557373
1.6206759214401245
1.8181883096694946
1.579842448234558
1.6886515617370605
1.696829915046692
1.6646168231964111
1.993947148323059


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.7126041650772095
1.7722980976104736
1.6553688049316406
1.7312949895858765
1.663194179534912
1.6890183687210083
1.663509726524353
1.6624895334243774
1.7787718772888184
1.8799036741256714
1.6997743844985962
1.6512444019317627
1.6758073568344116
1.8331094980239868
1.7834193706512451
1.6014641523361206
1.6592665910720825
1.6335698366165161
1.6603124141693115
1.7737518548965454
1.7231712341308594
1.6617710590362549
1.6798114776611328
1.79102623462677
1.6996790170669556
1.8334022760391235
1.6699769496917725
1.742565393447876
2.0712246894836426


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.6986850500106812
1.7207006216049194
1.5762450695037842
1.6122294664382935
1.739147424697876
1.6467193365097046
1.69852614402771
1.7099339962005615
1.630467414855957
1.778730034828186
1.788041353225708
1.7643738985061646
1.7597448825836182
1.7059258222579956
1.7320225238800049
1.7414520978927612
1.7720154523849487
1.7344906330108643
1.8013993501663208
1.6616111993789673
1.7041575908660889
1.6381571292877197
1.7344719171524048
1.7884294986724854
1.6936160326004028
1.5367827415466309
1.842029094696045
1.8302178382873535
1.5101709365844727


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)
1.6663438081741333
1.7911977767944336
1.7343666553497314
1.709606409072876
1.6431143283843994
1.7283122539520264
1.77446711063385
1.7981071472167969
1.6423132419586182
1.687936544418335
1.6739435195922852
1.6592034101486206
1.724266529083252
1.7953318357467651
1.7309556007385254
1.7600703239440918
1.6485469341278076
1.7507280111312866
1.719417691230774
1.8026560544967651
1.618649959564209
1.7704342603683472
1.6720045804977417
1.7633750438690186
1.697740077972412
1.728188157081604
1.7393819093704224
1.6290329694747925
1.817699909210205


Validation: |          | 0/? [00:00<?, ?it/s]

val_tot_accuracy (0.0, 0.36363636363636365)
val_tot_accuracy (0.0, 0.38235294117647056)
val_tot_accuracy (0.0, 0.35714285714285715)
val_tot_accuracy (0.0, 0.3333333333333333)


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


In [None]:
#Save the model
torch.save(model_LSTM.state_dict(), "/content/drive/MyDrive/NewCheckpoints/model_checkpoint_LSTM10.pth")

RuntimeError: ignored

# State of the art: TP Transformar

## model


This is the state of the art model for this task right now. It is made with a Transformer architecture and based on SelfAttention mechanism

In [None]:
class TotalEmbeddings(LightningModule):
    def __init__(self, dict_size, sizes, embedding_size, padding_idx=2, p=0.1):
        super().__init__()

        self.dict_size = dict_size
        self.max_length = max(sizes[0], sizes[1])
        self.embedding_size = embedding_size
        self.padding_idx = padding_idx

        self.dropout_probability = p
        self.scale = math.sqrt(self.embedding_size)
        self.embedding_layer = nn.Embedding(self.dict_size, self.embedding_size, padding_idx=self.padding_idx)
        self.dropout = nn.Dropout(p=self.dropout_probability)

    def forward(self, input):
        input = input.to(self.embedding_layer.weight.device)
        char_embedding = self.embedding_layer(input) * self.scale

        # Sinusoidal Positional Embeddings Computation
        positional_embedding = self._initialize_pos_embedding(char_embedding.size(1)).to(char_embedding.device)

        return self.dropout(char_embedding + positional_embedding)

    def _initialize_pos_embedding(self, max_len):
        positional_embeddings = torch.zeros(max_len, self.embedding_size, device=self.embedding_layer.weight.device)

        for pos in range(max_len):
            for i in range(0, self.embedding_size, 2):
                positional_embeddings[pos, i] = math.sin(pos / (10000 ** (i / self.embedding_size)))
                positional_embeddings[pos, i + 1] = math.cos(pos / (10000 ** (i / self.embedding_size)))

        return positional_embeddings


In [None]:
class PositionwiseFFN(LightningModule):
    def __init__(self, embedding_size, hidden_units):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units

        self.linear_1 = nn.Linear(self.embedding_size, self.hidden_units)
        self.linear_2 = nn.Linear(self.hidden_units, self.embedding_size)
        self.relu = nn.ReLU()

    def forward(self, input):
        output = self.linear_1(input)
        output = self.relu(output)
        output = self.linear_2(output)
        return output


In [None]:
class Decoder(LightningModule):

    def __init__(self, embedding_size, hidden_units, num_layers, p_AT, p=0.1):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.dropout_probability = p
        self.p_attention = p_AT

        self.layers = nn.ModuleList([DecoderLayer(self.embedding_size,  self.hidden_units, p_AT= self.p_attention, p=self.dropout_probability) for _ in range(self.num_layers)])

    def forward(self, encoder_batch, outputs_batch, src_padding_mask, tgt_padding_mask):
        decoder_layer_output = outputs_batch

        for decoder_layer in self.layers:
            decoder_layer_output = decoder_layer(encoder_batch, decoder_layer_output, src_padding_mask, tgt_padding_mask)

        return decoder_layer_output


In [None]:
class DecoderLayer(LightningModule):

    def __init__(self, embedding_size, hidden_units, p_AT, p=0.1):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.dropout_probability = p
        self.p_attention = p_AT

        self.ffn = PositionwiseFFN(self.embedding_size, self.hidden_units)
        self.selfAttn = SelfAttention(self.p_attention)
        self.layer_norm_1 = nn.LayerNorm(self.embedding_size, device=device)
        self.layer_norm_2 = nn.LayerNorm(self.embedding_size, device=device)
        self.layer_norm_3 = nn.LayerNorm(self.embedding_size, device=device)
        self.layer_norm_4 = nn.LayerNorm(self.embedding_size, device=device)
        self.densefilter = PositionwiseFeedforward(hid_dim=self.p_attention.d_x,
                                                   pf_dim=self.p_attention.d_r,
                                                   dropout=self.p_attention.dropout)
        self.dropout = nn.Dropout(p=self.dropout_probability)

    def forward(self, src, trg, src_mask, trg_mask):
        z = self.layer_norm_2(trg)
        z = self.selfAttn(z, z, z, trg_mask)
        z = self.dropout(z)
        trg = trg + z

        # Encoder attention
        z = self.layer_norm_2(trg)
        z = self.selfAttn(z, src, src, src_mask)
        z = self.dropout(z)
        trg = trg + z

        # Dense filter
        z = self.layer_norm_3(trg)
        z = self.densefilter(z)
        z = self.dropout(z)
        trg = trg + z

        return self.layer_norm_4(trg)


In [None]:
class Encoder(LightningModule):

    def __init__(self, embedding_size, hidden_units, num_layers, p_AT, p=0.1):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.dropout_probability = p
        self.p_attention = p_AT

        self.layers = nn.ModuleList([EncoderLayer(self.embedding_size, self.hidden_units, self.p_attention, p=self.dropout_probability) for _ in range(self.num_layers)])

    def forward(self, batch, src_padding_mask):

        encoder_layer_output = batch

        for encoder_layer in self.layers:
            encoder_layer_output = encoder_layer(encoder_layer_output, src_padding_mask)

        return encoder_layer_output


In [None]:
class EncoderLayer(LightningModule):

    def __init__(self, embedding_size, hidden_units, p_AT, p=0.1):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.dropout_probability = p
        self.p_attention = p_AT

        self.selfAttn = SelfAttention(self.p_attention)
        self.layer_norm_1 = nn.LayerNorm(self.embedding_size)
        self.dense_filter = PositionwiseFeedforward(self.embedding_size, self.hidden_units, self.dropout_probability)
        self.layer_norm_2 = nn.LayerNorm(self.embedding_size)
        self.dropout = nn.Dropout(p=self.dropout_probability)
        self.layer_norm_3 = nn.LayerNorm(self.embedding_size)

    def forward(self, src, src_mask):
        # Sublayer 1
        z = self.layer_norm_1(src)
        z = self.selfAttn(z, z, z, src_mask)
        z = self.dropout(z)
        src = src + z

        # Sublayer 2
        z = self.layer_norm_2(src)
        z = self.dense_filter(z)
        z = self.dropout(z)
        src = src + z

        return self.layer_norm_3(src)


In [None]:
class PositionwiseFeedforward(LightningModule):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        self.pf_dim = pf_dim

        self.linear1 = nn.Linear(hid_dim, pf_dim)
        self.linear2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def forward(self, x):
        x = self.linear1(x)
        x = self.dropout(F.relu(x))
        x = self.linear2(x)

        return x

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)


In [None]:
class SelfAttention(pl.LightningModule):

    def __init__(self, p):
        super().__init__()

        self.p = p
        self.d_h = p.d_x
        self.n_I = p.n_I

        self.W_q = nn.Linear(self.d_h, p.d_q * p.n_I)
        self.W_k = nn.Linear(self.d_h, p.d_k * p.n_I)
        self.W_v = nn.Linear(self.d_h, p.d_v * p.n_I)
        self.W_r = nn.Linear(self.d_h, p.d_r * p.n_I)

        self.W_o = nn.Linear(p.d_v * p.n_I, p.d_x)

        self.dropout = nn.Dropout(p.dropout)
        self.dot_scale = torch.FloatTensor([math.sqrt(p.d_k)])
        self.mul_scale = torch.FloatTensor([1.0 / math.sqrt(math.sqrt(2) - 1)])

    def forward(self, query, key, value, mask=None):
        bsz = query.shape[0]

        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        R = self.W_r(query)

        Q = Q.view(bsz, -1, self.n_I, self.p.d_q).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_I, self.p.d_k).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_I, self.p.d_v).permute(0, 2, 1, 3)
        R = R.view(bsz, -1, self.n_I, self.p.d_r).permute(0, 2, 1, 3)
        dot = torch.einsum('bhid,bhjd->bhij', Q, K) / self.dot_scale.to(key.device)

        attention = self.dropout(F.softmax(dot, dim=-1))

        v_bar = torch.einsum('bhjd,bhij->bhid', V, attention)

        new_v = v_bar * R
        new_v = new_v.permute(0, 2, 1, 3).contiguous()

        new_v = new_v.view(bsz, -1, self.n_I * self.p.d_v)

        x = self.W_o(new_v)

        return x

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W_q.weight)
        nn.init.xavier_uniform_(self.W_k.weight)
        nn.init.xavier_uniform_(self.W_v.weight)
        nn.init.xavier_uniform_(self.W_o.weight)
        nn.init.normal_(self.W_r.weight, mean=0, std=1.0 / math.sqrt(self.p.d_r))


In [None]:
class SelfAttentionConfig:
    def __init__(self, d_x, d_q, d_k, d_v, d_r, dropout, n_I):
        self.d_x = d_x
        self.d_q = d_q
        self.d_k = d_k
        self.d_v = d_v
        self.d_r = d_r
        self.dropout = dropout
        self.n_I = n_I


In [None]:
class TP_Transformer_SelfAttention(LightningModule):
    def __init__(self, char_to_idx, idx_to_char, question_max_length=68, answer_max_length=10, batch_size=128,
                 embedding_size=512, hidden_units=2048, num_heads=8, num_encoder_layers=6, num_decoder_layers=6,
                 padding_idx=2, p=0.):
        super().__init__()

        self.pad_token = "PAD"
        self.dict_size = len(char_to_idx)
        self.question_max_length = question_max_length
        self.answer_max_length = answer_max_length
        self.padding_idx = padding_idx
        self.idx_to_char = idx_to_char
        self.dropout_probability = p
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.p_attention = SelfAttentionConfig(d_x=512, d_q=64, d_k=64, d_v=64, d_r=64, dropout=self.dropout_probability, n_I=8)

        self.total_embedding = TotalEmbeddings(self.dict_size, [self.question_max_length, self.answer_max_length],
                                               self.embedding_size, self.padding_idx, p=self.dropout_probability)
        self.encoder = Encoder(self.embedding_size, hidden_units=self.hidden_units, num_layers=self.num_encoder_layers,
                               p=self.dropout_probability, p_AT=self.p_attention)
        self.decoder = Decoder(self.embedding_size, hidden_units=self.hidden_units, num_layers=self.num_decoder_layers,
                               p=self.dropout_probability, p_AT=self.p_attention)

    def forward(self, batch):
        questions = (batch[0]).to(device)
        answers = (batch[2]).to(device)

        source_pad_mask = (batch[1]).to(device)
        target_pad_mask = (batch[3]).to(device)

        embedded_questions = self.total_embedding(questions)
        embedded_answers = self.total_embedding(answers)
        encoder_output = self.encoder(embedded_questions, src_padding_mask=source_pad_mask)
        decoder_output = self.decoder(encoder_output, embedded_answers, src_padding_mask=source_pad_mask,
                                       tgt_padding_mask=target_pad_mask)
        return torch.matmul(decoder_output, torch.transpose(self.total_embedding.embedding_layer.weight, 0, 1))

    def training_step(self, batch, _):
        batch_answers = batch[2][:, 1:].flatten(0, 1)
        batch[2] = batch[2][:, :-1]
        pred = self(batch).flatten(0, 1)

        loss = F.cross_entropy(pred, batch_answers, ignore_index=self.padding_idx)
        return loss

    def validation_step(self, batch, _):
        batch_answers = batch[2]
        pred = self.predict(batch)

        accuracy = evaluate_accuracy(pred, batch_answers)
        print("Validation accuracy: ", accuracy)
        return accuracy

    def test_step(self, batch, _):
        batch_answers = batch[2]
        pred = self.predict(batch.copy())

        accuracy = evaluate_accuracy(pred, batch_answers)
        print("Test accuracy: ", accuracy)
        return accuracy

    def predict(self, batch):
        batch[2] = torch.tensor([[1] for j in range(len(batch[2]))], device=device).long()

        for i in range(self.answer_max_length - 1):
            transformer_result = self(batch)
            batch[2] = torch.tensor([], device=device).long()

            for j in range(len(transformer_result)):
                predicted_chars = torch.argmax(transformer_result[j], 1)

                if len(predicted_chars.shape) == 1:
                    predicted_chars = predicted_chars.unsqueeze(0)

                start_line_char = torch.tensor([1], device=device).unsqueeze(0).long()
                predicted_chars = torch.cat((start_line_char, predicted_chars), 1)

                batch[2] = torch.cat((batch[2], predicted_chars), 0)

        return F.one_hot(batch[2], num_classes=self.dict_size)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4, betas=(0.9, 0.995))



In [None]:
#Define the model
model_TP = TP_Transformer_SelfAttention(params[0], params[1], params[2], params[3], BATCH_SIZE)

## Training


In [None]:
#Train the model
trainer = pl.Trainer(max_epochs=30)
#Train the model using the train_dataloaders
trainer.fit(model_TP, train_dataloaders, valid_dataloaders)

In [None]:
#Save the model
torch.save(model_TP.state_dict(), "/content/drive/MyDrive/NewCheckpoints/model_checkpoint_TPnew4.pth")

# Our solution: Transformer with Hybrid attention

This is the state of the art model for this task right now. It is made with a Transformer architecture and based on SelfAttention mechanism

## Model

In [None]:
class TotalEmbeddings(LightningModule):
    def __init__(self, dict_size, sizes, embedding_size, padding_idx=2, p=0.1):
        super().__init__()

        self.dict_size = dict_size
        self.max_length = max(sizes[0], sizes[1])
        self.embedding_size = embedding_size
        self.padding_idx = padding_idx
        self.dropout_probability = p

        self.scale = math.sqrt(self.embedding_size)
        self.embedding_layer = nn.Embedding(self.dict_size, self.embedding_size, padding_idx=self.padding_idx)
        self.dropout = nn.Dropout(p=self.dropout_probability)

    def forward(self, input):
        input = input.to(self.embedding_layer.weight.device)
        char_embedding = self.embedding_layer(input) * self.scale

        # Sinusoidal Positional Embeddings Computation
        positional_embedding = self._initialize_pos_embedding(char_embedding.size(1)).to(char_embedding.device)

        return self.dropout(char_embedding + positional_embedding)

    def _initialize_pos_embedding(self, max_len):
        positional_embeddings = torch.zeros(max_len, self.embedding_size, device=self.embedding_layer.weight.device)

        for pos in range(max_len):

            for i in range(0, self.embedding_size, 2):
                positional_embeddings[pos, i] = math.sin(pos / (10000 ** (i / self.embedding_size)))
                positional_embeddings[pos, i + 1] = math.cos(pos / (10000 ** (i / self.embedding_size)))

        return positional_embeddings


In [None]:
class PositionwiseFFN(LightningModule):
    def __init__(self, embedding_size, hidden_units):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units

        self.linear_1 = nn.Linear(self.embedding_size, self.hidden_units)
        self.linear_2 = nn.Linear(self.hidden_units, self.embedding_size)
        self.relu = nn.ReLU()

    def forward(self, input):
        output = self.linear_1(input)
        output = self.relu(output)
        output = self.linear_2(output)
        return output


In [None]:
class Decoder(LightningModule):
    def __init__(self, embedding_size, hidden_units, num_layers, p_AT, p=0.1):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.dropout_probability = p
        self.p_attention = p_AT

        self.layers = nn.ModuleList([DecoderLayer(self.embedding_size, self.hidden_units, p_AT=self.p_attention, p=self.dropout_probability) for _ in range(self.num_layers)])

    def forward(self, encoder_batch, outputs_batch, src_padding_mask, tgt_padding_mask):
        decoder_layer_output = outputs_batch

        for decoder_layer in self.layers:
            decoder_layer_output = decoder_layer(encoder_batch, decoder_layer_output, src_padding_mask, tgt_padding_mask)

        return decoder_layer_output


In [None]:
class DecoderLayer(LightningModule):
    def __init__(self, embedding_size, hidden_units, p_AT, p=0.1):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.dropout_probability = p
        self.p_attention = p_AT

        self.ffn = PositionwiseFFN(self.embedding_size, self.hidden_units)
        self.selfAttn = SelfAttention(self.p_attention)
        self.layer_norm_1 = nn.LayerNorm(self.embedding_size, device=device)
        self.layer_norm_2 = nn.LayerNorm(self.embedding_size, device=device)
        self.layer_norm_3 = nn.LayerNorm(self.embedding_size, device=device)
        self.layer_norm_4 = nn.LayerNorm(self.embedding_size, device=device)
        self.densefilter = PositionwiseFeedforward(hid_dim=self.p_attention.d_x,
                                                   pf_dim=self.p_attention.d_r,
                                                   dropout=self.p_attention.dropout)
        self.dropout = nn.Dropout(p=self.dropout_probability)

    def forward(self, src, trg, src_mask, trg_mask):
        z = self.layer_norm_2(trg)
        z = self.selfAttn(z, z, z, trg_mask)
        z = self.dropout(z)
        trg = trg + z
        z = self.layer_norm_2(trg)
        z = self.selfAttn(z, src, src, src_mask)
        z = self.dropout(z)
        trg = trg + z
        z = self.layer_norm_3(trg)
        z = self.densefilter(z)
        z = self.dropout(z)
        trg = trg + z

        return self.layer_norm_4(trg)


In [None]:
class Encoder(LightningModule):
    def __init__(self, embedding_size, hidden_units, num_layers, p_AT, p=0.1):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.dropout_probability = p
        self.p_attention = p_AT

        self.layers = nn.ModuleList([EncoderLayer(self.embedding_size, self.hidden_units, self.p_attention, p=self.dropout_probability) for _ in range(self.num_layers)])

    def forward(self, batch, src_padding_mask):
        encoder_layer_output = batch

        for encoder_layer in self.layers:
            encoder_layer_output = encoder_layer(encoder_layer_output, src_padding_mask)

        return encoder_layer_output


In [None]:
class EncoderLayer(LightningModule):
    def __init__(self, embedding_size, hidden_units, p_AT, p=0.1):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.dropout_probability = p
        self.p_attention = p_AT

        self.selfAttn = HybridAttention(self.p_attention, self.embedding_size)
        self.layer_norm_1 = nn.LayerNorm(self.embedding_size)
        self.dense_filter = PositionwiseFeedforward(self.embedding_size, self.hidden_units, self.dropout_probability)
        self.layer_norm_2 = nn.LayerNorm(self.embedding_size)
        self.dropout = nn.Dropout(p=self.dropout_probability)
        self.layer_norm_3 = nn.LayerNorm(self.embedding_size)

    def forward(self, src, src_mask):

        z = self.layer_norm_1(src)
        z = self.selfAttn(z, z, z, src_mask)
        z = self.dropout(z)
        src = src + z
        z = self.layer_norm_2(src)
        z = self.dense_filter(z)
        z = self.dropout(z)
        src = src + z

        return self.layer_norm_3(src)


In [None]:
class PositionwiseFeedforward(LightningModule):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        self.pf_dim = pf_dim

        self.linear1 = nn.Linear(hid_dim, pf_dim)
        self.linear2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def forward(self, x):
        x = self.linear1(x)
        x = self.dropout(F.relu(x))
        x = self.linear2(x)

        return x

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)


In [None]:
class HybridAttention(pl.LightningModule):
    def __init__(self, attention, embedding_size, num_heads=8, masked=False):
        super().__init__()

        self.embedding_size = embedding_size
        self.num_heads = num_heads
        attention.d_x
        self.key_embedding_size = attention.d_x
        self.masked = masked
        self.dk = self.embedding_size // self.num_heads

        # Initialize the MultiHeadAttention and the new attention mechanism
        self.multi_head_attention = MultiHeadAttention(self.embedding_size, self.num_heads, self.key_embedding_size, self.masked)
        self.new_attention_mechanism = AdditiveAttention(self.embedding_size, self.key_embedding_size)

        # Learnable parameter for weighted combination
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, batch_keys, batch_queries, batch_values, trg_mask):
        # Calculate attention from the MultiHeadAttention
        multi_head_attention_result = self.multi_head_attention(batch_keys, batch_queries, batch_values, trg_mask)

        # Calculate attention from the new attention mechanism
        new_attention_result = self.new_attention_mechanism(batch_keys, batch_queries, batch_values, trg_mask)

        # Combine the results with a learnable parameter to weight them
        hybrid_attention_result = self.alpha * multi_head_attention_result + (1 - self.alpha) * new_attention_result.sum(dim=1, keepdim=True)

        return hybrid_attention_result


class MultiHeadAttention(pl.LightningModule):
    def __init__(self, embedding_size, num_heads, key_embedding_size, masked=False):
        super().__init__()

        self.embedding_size = embedding_size
        self.num_heads = num_heads
        self.key_embedding_size = key_embedding_size
        self.masked = masked
        self.dk = self.embedding_size // self.num_heads

        self.attention_heads = nn.ModuleList([AttentionHead(self.dk, self.embedding_size, self.key_embedding_size, self.masked) for _ in range(self.num_heads)])

    def forward(self, batch_keys, batch_queries, batch_values, padding_mask):
        attention_heads_results = []
        for head in self.attention_heads:
            attention_heads_results.append(head(batch_keys, batch_queries, batch_values, padding_mask))

        # Sum the results from different attention heads
        return torch.sum(torch.stack(attention_heads_results), dim=0)



class AttentionHead(pl.LightningModule):
    def __init__(self, dk, embedding_size, key_embedding_size, masked):
        super().__init__()

        self.dk = dk
        self.masked = masked
        self.embedding_size = embedding_size
        self.key_embedding_size = key_embedding_size

        self.Wk = nn.Linear(key_embedding_size, dk, bias=False)
        self.Wq = nn.Linear(embedding_size, dk, bias=False)
        self.Wv = nn.Linear(embedding_size, dk, bias=False)
        self.Wr = nn.Linear(embedding_size, dk, bias=False)
        self.Wo = nn.Linear(dk, embedding_size, bias=False)

    def forward(self, batch_keys, batch_queries, batch_values, padding_mask):
        # Project keys, queries, and values
        keys = self.Wk(batch_keys)
        queries = self.Wq(batch_queries)
        values = self.Wv(batch_values)
        roles = self.Wr(batch_queries)
        # Calculate unnormalized attention scores
        unnormalized_attention_score = torch.matmul(queries, keys.transpose(1, 2)) / math.sqrt(self.dk)

        if self.masked:
            input_len_1 = batch_queries.shape[1]
            input_len_2 = batch_keys.shape[1]
            target_mask = torch.triu(torch.full((input_len_1, input_len_2), float("-inf"), device=device), diagonal=1)
            unnormalized_attention_score = unnormalized_attention_score + target_mask

        # Apply padding mask
        padding_mask = padding_mask.bool()
        padding_mask = padding_mask.unsqueeze(1).repeat(1, batch_queries.shape[1], 1)
        unnormalized_attention_score = unnormalized_attention_score.masked_fill_(padding_mask, float("-inf"))
        attention_score = F.softmax(unnormalized_attention_score, 2)
        filler = torch.matmul(attention_score, values)

        # Apply the output linear layer with role interaction
        return self.Wo(filler * roles)


class AdditiveAttention(pl.LightningModule):
    def __init__(self, query_embedding_size, key_embedding_size):
        super().__init__()

        self.query_embedding_size = query_embedding_size
        self.key_embedding_size = key_embedding_size

        self.Wq = nn.Linear(query_embedding_size, query_embedding_size)
        self.Wk = nn.Linear(key_embedding_size, key_embedding_size)
        self.V = nn.Linear(query_embedding_size, 1)

    def forward(self, batch_keys, batch_queries, batch_values, padding_mask):
        # Project queries and keys to the same embedding size
        queries = self.Wq(batch_queries)
        keys = self.Wk(batch_keys)
        additive_scores = self.V(torch.tanh(queries + keys)).squeeze(dim=-1)
        if padding_mask is not None:
            padding_mask = padding_mask.bool()
            additive_scores = additive_scores.masked_fill(padding_mask, float("-inf"))

        # Compute attention scores and the weighted sum of values
        attention_scores = F.softmax(additive_scores / math.sqrt(self.query_embedding_size), dim=-1)
        weighted_sum = torch.matmul(attention_scores, batch_values)

        return weighted_sum


In [None]:
class SelfAttention(pl.LightningModule):
    def __init__(self, p):
        super().__init__()

        self.p = p
        self.d_h = p.d_x
        self.n_I = p.n_I

        self.W_q = nn.Linear(self.d_h, p.d_q * p.n_I)
        self.W_k = nn.Linear(self.d_h, p.d_k * p.n_I)
        self.W_v = nn.Linear(self.d_h, p.d_v * p.n_I)
        self.W_r = nn.Linear(self.d_h, p.d_r * p.n_I)

        self.W_o = nn.Linear(p.d_v * p.n_I, p.d_x)

        self.dropout = nn.Dropout(p.dropout)
        self.dot_scale = torch.FloatTensor([math.sqrt(p.d_k)])
        self.mul_scale = torch.FloatTensor([1.0 / math.sqrt(math.sqrt(2) - 1)])

    def forward(self, query, key, value, mask=None):
        bsz = query.shape[0]

        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        R = self.W_r(query)

        Q = Q.view(bsz, -1, self.n_I, self.p.d_q).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_I, self.p.d_k).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_I, self.p.d_v).permute(0, 2, 1, 3)
        R = R.view(bsz, -1, self.n_I, self.p.d_r).permute(0, 2, 1, 3)
        dot = torch.einsum('bhid,bhjd->bhij', Q, K) / self.dot_scale.to(key.device)
        attention = self.dropout(F.softmax(dot, dim=-1))
        v_bar = torch.einsum('bhjd,bhij->bhid', V, attention)
        new_v = v_bar * R
        new_v = new_v.permute(0, 2, 1, 3).contiguous()

        new_v = new_v.view(bsz, -1, self.n_I * self.p.d_v)

        x = self.W_o(new_v)

        return x

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W_q.weight)
        nn.init.xavier_uniform_(self.W_k.weight)
        nn.init.xavier_uniform_(self.W_v.weight)
        nn.init.xavier_uniform_(self.W_o.weight)
        nn.init.normal_(self.W_r.weight, mean=0, std=1.0 / math.sqrt(self.p.d_r))


In [None]:
class SelfAttentionConfig:
    def __init__(self, d_x, d_q, d_k, d_v, d_r, dropout, n_I):
        self.d_x = d_x
        self.d_q = d_q
        self.d_k = d_k
        self.d_v = d_v
        self.d_r = d_r
        self.dropout = dropout
        self.n_I = n_I


In [None]:
class Transformer(LightningModule):
    def __init__(self, char_to_idx, idx_to_char, question_max_length=68, answer_max_length=10, batch_size=128,
                 embedding_size=512, hidden_units=2048, num_heads=8, num_encoder_layers=6, num_decoder_layers=6,
                 padding_idx=2, p=0.):
        super().__init__()

        self.pad_token = "PAD"
        self.dict_size = len(char_to_idx)
        self.question_max_length = question_max_length
        self.answer_max_length = answer_max_length
        self.padding_idx = padding_idx
        self.idx_to_char = idx_to_char
        self.dropout_probability = p
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.p_attention = SelfAttentionConfig(d_x=512, d_q=64, d_k=64, d_v=64, d_r=64, dropout=self.dropout_probability, n_I=8)

        self.total_embedding = TotalEmbeddings(self.dict_size, [self.question_max_length, self.answer_max_length],
                                               self.embedding_size, self.padding_idx, p=self.dropout_probability)
        self.encoder = Encoder(self.embedding_size, hidden_units=self.hidden_units, num_layers=self.num_encoder_layers,
                               p=self.dropout_probability, p_AT=self.p_attention)
        self.decoder = Decoder(self.embedding_size, hidden_units=self.hidden_units, num_layers=self.num_decoder_layers,
                               p=self.dropout_probability, p_AT=self.p_attention)

    def forward(self, batch):
        questions = (batch[0]).to(device)
        answers = (batch[2]).to(device)

        source_pad_mask = (batch[1]).to(device)
        target_pad_mask = (batch[3]).to(device)

        embedded_questions = self.total_embedding(questions)

        embedded_answers = self.total_embedding(answers)

        encoder_output = self.encoder(embedded_questions, src_padding_mask=source_pad_mask)

        decoder_output = self.decoder(encoder_output, embedded_answers, src_padding_mask=source_pad_mask,
                                       tgt_padding_mask=target_pad_mask)

        return torch.matmul(decoder_output, torch.transpose(self.total_embedding.embedding_layer.weight, 0, 1))

    def training_step(self, batch, _):
        batch_answers = batch[2][:, 1:].flatten(0, 1)
        batch[2] = batch[2][:, :-1]
        pred = self(batch).flatten(0, 1)

        loss = F.cross_entropy(pred, batch_answers, ignore_index=self.padding_idx)
        return loss

    def validation_step(self, batch, _):
        batch_answers = batch[2]

        pred = self.predict(batch)
        print(pred)
        accuracy = evaluate_accuracy(pred, batch_answers)
        print("Validation accuracy: ", accuracy)
        return accuracy

    def test_step(self, batch, _):
        batch_answers = batch[2]
        pred = self.predict(batch.copy())

        accuracy = evaluate_accuracy(pred, batch_answers)
        print("Test accuracy: ", accuracy)
        return accuracy

    def predict(self, batch):
        batch[2] = torch.tensor([[1] for j in range(len(batch[2]))], device=device).long()

        for i in range(self.answer_max_length - 1):
            transformer_result = self(batch)
            batch[2] = torch.tensor([], device=device).long()

            for j in range(len(transformer_result)):
                predicted_chars = torch.argmax(transformer_result[j], 1)

                if len(predicted_chars.shape) == 1:
                    predicted_chars = predicted_chars.unsqueeze(0)

                start_line_char = torch.tensor([1], device=device).unsqueeze(0).long()
                predicted_chars = torch.cat((start_line_char, predicted_chars), 1)

                batch[2] = torch.cat((batch[2], predicted_chars), 0)

        return F.one_hot(batch[2], num_classes=self.dict_size)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4, betas=(0.9, 0.995))



In [None]:
#Define the model
model = Transformer(params[0], params[1], params[2], params[3], BATCH_SIZE)

## Training

In [None]:
#Train the model
trainer = pl.Trainer(max_epochs=30)
#Train the model using the train_dataloaders
trainer.fit(model, train_dataloaders, valid_dataloaders)

In [None]:
#Save the model
torch.save(model.state_dict(), "/content/drive/MyDrive/NewCheckpoints/model_checkpoint_TP2.pth")

# Download pretrained weights

In [None]:
!gdown 1-3N-VGHo4sjJxqFXwDGMWe_mSsF2hHk_ -O /content/lstm_baseline.ckpt11
!gdown 1-GknyRsaGTSFgUvxjIJ2qMP1jNcY6ta9 -O /content/lstm_baseline_params.pkl1
#!gdown 1-JMz-uwBTGg7obOkJiW8m9hGH9TzfnPX -O /content/lstm_baseline.ckpt12

#!gdown 1-1ZTuDRNO4XW_ZCHZpcu5Qf1zgrPxRtn -O /content/lstm_baseline_params.pkl3
#!gdown 1-04MIfMdRN56lbOMy4B-J-OeD50gR1-S -O /content/lstm_baseline.ckpt32

In [None]:
#!gdown 1-VobZkuYbLgju0zs0ehJ6o9apXb7X6BW -O /content/transformer_our_model.ckpt
#!gdown 1-wjp76oP-YQMBowpMlC5G2c4fFLGl6rC -O /content/transformer_our_model_params.pkl1

!gdown 1-SgDaJEY_vl3UnSk5GM1-UE5-NTVln17 -O /content/transformer_our_model_params.pkl2
!gdown 1-SL8qHMdrCXHtAmS9L3dIw5JxY_1N6u0 -O /content/transformer_our_model.ckpt2

#!gdown 1-7C16v8WQggSvZEIYO-n-mOxDHiAli5S -O /content/transformer_our_model_params.pkl3
#!gdown 1-4A-x2guysHk8Jq3Rg1qq4wXZLJbVOzk -O /content/transformer_our_model.ckpt3

In [None]:
#!gdown 1-4A-x2guysHk8Jq3Rg1qq4wXZLJbVOzk -O /content/tp_transformer.ckpt1
#!gdown 1-7C16v8WQggSvZEIYO-n-mOxDHiAli5S -O /content/tp_transformer_params.pkl1
#!gdown 1-LIma3I3RgwEIZyjevez90fC5MJquWo- -O /content/tp_transformer_params.pkl2f
#!gdown 1-AkKDtd25uPV4JVODz2AwGVan4t9f6wq -O /content/tp_transformer.ckpt21
#!gdown 1-IOVAF8xa8qy6H-KDvzAm1-ligamv5Ae -O /content/tp_transformer.ckpt22

#!gdown 1-jvpEDLcdt8ul8Euk3hFmt0AiabSp_-b -O /content/tp_transformer.ckpt31
#!gdown 1-608w9B9BZIMiP9PGsHd-8BxNbvWFon6 -O /content/tp_transformer_params.pkl3
#!gdown 1-DX3uG5LJd-LipedxGO4R-JD9HE-re_ls -O /content/tp_transformer.ckpt32


!gdown 1-LIma3I3RgwEIZyjevez90fC5MJquWo- -O /content/tp_transformer_params.pkl12


!gdown 1-AkKDtd25uPV4JVODz2AwGVan4t9f6wq -O /content/tp_transformer_checkpoint.pkl13
#!gdown 1-IOVAF8xa8qy6H-KDvzAm1-ligamv5Ae -O /content/tp_transformer_checkpoint.pkl14
#!gdown 1-SL8qHMdrCXHtAmS9L3dIw5JxY_1N6u0 -O /content/tp_transformer_checkpoint.pkl15

# Testing


In [None]:
# Choose your model
choice = 2
trainer = pl.Trainer(max_epochs=10)
import pickle
params = None
if choice == 1:
    with open("/content/lstm_baseline_params.pkl1", 'rb') as file:
        params = pickle.load(file)
    model_LSTM.load_state_dict(torch.load("/content/lstm_baseline.ckpt11"), strict=False)
    model_LSTM.eval()
    trainer.test(model_LSTM, test_dataloaders)
elif choice == 2:
    with open("/content/tp_transformer_params.pkl12", 'rb') as file:
        params = pickle.load(file)
    model_TP.load_state_dict(torch.load("/content/tp_transformer_checkpoint.pkl13"), strict=False)
    model_TP.eval()
    trainer.test(model_TP, test_dataloaders)
elif choice == 3:
    with open("/content/transformer_our_model_params.pkl3", 'rb') as file:
        params = pickle.load(file)
    model.load_state_dict(torch.load("/content/transformer_our_model.ckpt3"), strict=False)
    model.eval()
    trainer.test(model, test_dataloaders)
else:
    print("No model selected")




# Prediction

Before running the code it is necessary to previously run the section where the selected model is defined

In [None]:
choice = 2
trainer = pl.Trainer(max_epochs=10)
import pickle
params = None
if choice == 1:
    with open("/content/lstm_baseline_params.pkl1", 'rb') as file:
        params = pickle.load(file)
    model_LSTM.load_state_dict(torch.load("/content/lstm_baseline.ckpt11"), strict=False)
    model_LSTM.eval()
elif choice == 2:
    with open("/content/tp_transformer_params.pkl12", 'rb') as file:
        params = pickle.load(file)
    model_TP.load_state_dict(torch.load("/content/tp_transformer_checkpoint.pkl13"), strict=False)
    model_TP.eval()
elif choice == 3:
    with open("/content/transformer_our_model_params.pkl3", 'rb') as file:
        params = pickle.load(file)
    model.load_state_dict(torch.load("/content/transformer_our_model.ckpt3"), strict=False)
    model.eval()
else:
    print("No model selected")

### Prediction functions


In [None]:
def tokenize( input,char_to_idx,question_max_length):
  question = []
  question.append(1)  # adding start of line char

  for char in input:
    if char not in char_to_idx.keys():
      question.append(1)
    else:
      question.append(char_to_idx[char])  # turning character into index

  question.append(2)  # adding end of line char

  for i in range(question_max_length-len(input)-2):
      question.append(0)  # adding padding after

  question = torch.tensor(question, device=device).unsqueeze(0)

  return {"question": question}

In [None]:
import numpy as np

def create_attention_mask(input_tensor):
    """
    Create a fully connected attention mask for a given tensor.

    Args:
        input_tensor (numpy.ndarray): The input tensor.

    Returns:
        numpy.ndarray: The attention mask with 1s indicating positions to attend to.
    """
    N, M = input_tensor.shape
    attention_mask = np.ones((N, M), dtype=int)
    return attention_mask


In [None]:
def encode_question(inputs1,inputs2,inputs3,inputs4):

    tesst = Dataset_class3(inputs1,inputs2,inputs3,inputs4)
    dataloaders = torch.utils.data.DataLoader(tesst, shuffle=False, batch_size=16)

    for batch in dataloaders:
      pred = model_LSTM.predict(batch.copy())
    predicted = torch.argmax(pred, dim=2)

    return predicted


In [None]:
def tensor_to_string(tensor, idx_to_char):
    # Extract the elements from the tensor and map them to characters
    char_list = [idx_to_char[idx.item()] for idx in tensor.squeeze()]
    # Join the characters to form the string
    return ''.join(char_list)




In [None]:
class Dataset_class3:
  def __init__(self, inputs1,inputs2,inputs3,inputs4):
    self.input_ids = inputs1
    self.attention_mask = inputs2
    self.answer_ids = inputs3
    self.answer_attention_mask = inputs4

  def __len__(self):
    return len(self.input_ids)

  def __getitem__(self, item):
    return [torch.tensor(self.input_ids[item], dtype=torch.long),
            torch.tensor(self.attention_mask[item], dtype=torch.long),
            torch.tensor(self.answer_ids[item], dtype=torch.long),
            torch.tensor(self.answer_attention_mask[item], dtype=torch.long)]

In [None]:
def answerquestion(question):


  string_question = tokenize( question,params[0],params[2])
  string_answer = tokenize( "b'-19\n'",params[0],params[3])
  TensorPredictedAnswer = encode_question(string_question["question"],create_attention_mask(string_question["question"]),string_answer["question"],create_attention_mask(string_answer["question"]))
  decoded_answer = tensor_to_string(TensorPredictedAnswer, params[1])
  return decoded_answer

In [None]:
Example_inthe_dataset = 22
test = dataset["test"][Example_inthe_dataset]

In [None]:
string_question = tokenize( test["question"],params[0],params[2])
string_answer = tokenize( test["answer"],params[0],params[3])


In [None]:
TensorPredictedAnswer = encode_question(string_question["question"],create_attention_mask(string_question["question"]),string_answer["question"],create_attention_mask(string_answer["question"]))

In [None]:
# Translate the tensors to strings
decoded_answer = tensor_to_string(TensorPredictedAnswer, params[1])

### Preditction

In [None]:
print("Question:",dataset["test"]["question"][Example_inthe_dataset])
print("Real answer:",dataset["test"]["answer"][Example_inthe_dataset])
print("Predicted answer:",decoded_answer)

In [None]:
user_question = input("Type equation: ")
variable = input("For which varible?: ")
user_question_modifed = "b'Solve "+user_question+" for "+variable+ ".\n'"

answer = answerquestion(user_question_modifed)
print("The answer to your question is",answer)