<a href="https://colab.research.google.com/github/luigiantonelli/DeepLearning-Project/blob/main/Deep_Learning_Project_Antonelli_Cuconasu_Gaudenzi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installations and imports

In [1]:
!pip install pytorch-lightning --quiet
!pip install torchmetrics --quiet
!pip install gdown==4.5.4 --no-cache-dir --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m827.8/827.8 KB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 KB[0m [31m24.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m27.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m158.8/158.8 KB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m264.6/264.6 KB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.2/114.2 KB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.2/199.2 KB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import glob
import math
import pickle
from typing import *
from datetime import datetime

import gdown
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split

import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar

# For reproducibility
seed_everything(10, workers=True)

INFO:lightning_fabric.utilities.seed:Global seed set to 10


10

In [3]:
# url = "https://drive.google.com/drive/folders/1LrGmpT6nVvcWOk-gy656xlFqmH8fIY7k?usp=sharing"
# gdown.download_folder(url=url, quiet=True, use_cookies=False, remaining_ok=True)

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
#dataset_folder_path = "/content/drive/MyDrive/Colab Notebooks/Deep Learning/DeepLearningProject-Shared"
dataset_folder_path = "/content/drive/MyDrive/Deep_Learning_Project"
# dataset_folder_path = "/content/DeepLearning-Shared"
os.chdir(dataset_folder_path)

In [5]:
!ls

datasets		  modules.txt  prova.txt
mathematics_dataset-v1.0  prova2.txt   training


# Vocabulary

In this section we analyzed all the dataset files to retrieve the characters that will compose the vocabulary. Indeed, we wanted to be sure that our vocabulary contains all the files characters regardless the module we are working on.

Moreover, after this pre-processing phase we decided to add the special token `<unk>` (i.e., unknown). Thus, if during inference we are using characters that are not in the vocabulary, we are still able to pre-processes the input, since whathever unknown character is replaced by that special token.  

In [6]:
def read_dataset(text_path: str, lowercase: bool=True) -> Tuple[List[str], List[str]]:
    questions = []
    answers = []
    with open(text_path, 'r') as f:
        for idx, line in enumerate(f):
            row = line.rstrip().lower() if lowercase else line.rstrip()
            # Questions
            if idx % 2 == 0:
                questions.append(row) 
            # Answers
            else: 
                answers.append(row)
    return questions, answers

In [7]:
def get_vocabulary(lists_of_texts: List[List[str]]) -> Set[str]:
    unified_text = []
    
    for l in lists_of_texts:
        unified_text += l

    return Counter(" ".join(unified_text)).keys()

In [8]:
# Get all files
folders = ['extrapolate', 'interpolate', 'train-easy', 'train-medium', 'train-hard']
files = []

for fold in folders:
    files += glob.glob(f"./mathematics_dataset-v1.0/{fold}/*.txt")

In [9]:
files[:5]

['./mathematics_dataset-v1.0/extrapolate/arithmetic__add_sub_multiple_longer.txt',
 './mathematics_dataset-v1.0/extrapolate/algebra__polynomial_roots_big.txt',
 './mathematics_dataset-v1.0/extrapolate/arithmetic__add_or_sub_big.txt',
 './mathematics_dataset-v1.0/extrapolate/arithmetic__div_big.txt',
 './mathematics_dataset-v1.0/extrapolate/arithmetic__mul_div_multiple_longer.txt']

In [10]:
def get_files_vocabulary(files: List[str], save: bool=False) -> List[str]:
    vocabulary = {}
    all_lists = []

    i = 0
    for f in files:
        train, test = read_dataset(f)
        all_lists += train
        all_lists += test
        
        # Set union
        vocabulary |= get_vocabulary(all_lists)
        all_lists = []

        # Save the vocabulary up to now
        if save and i % 10 == 0:
            vocabulary = sorted(list(vocabulary))
            with open('./datasets/pre_vocabulary.pkl', 'wb') as f:
                pickle.dump(vocabulary, f)

    # Save sorted vocabulary
    vocabulary = sorted(list(vocabulary))
    with open('./datasets/pre_vocabulary.pkl', 'wb') as f:
        pickle.dump(vocabulary, f)

    return vocabulary

This operation requires quite a bit of time (~ 25 min), as we are scanning all the files. So, it is commented to avoid executing it.

    vocabulary = get_files_vocabulary(files)

In [11]:
def create_vocabulary_from_set(voc):
    vocabulary = {'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3}
    i = 4
    for v in voc:
        vocabulary[v] = i
        i += 1
    return vocabulary

In [12]:
with open('./datasets/vocabulary.pkl', 'rb') as f:
    vocabulary = pickle.load(f)

In [13]:
len(vocabulary)

54

In [14]:
v = create_vocabulary_from_set(vocabulary)

In [15]:
v

{'<pad>': 0,
 '<bos>': 1,
 '<eos>': 2,
 '<unk>': 3,
 ' ': 4,
 '!': 5,
 "'": 6,
 '(': 7,
 ')': 8,
 '*': 9,
 '+': 10,
 ',': 11,
 '-': 12,
 '.': 13,
 '/': 14,
 '0': 15,
 '1': 16,
 '2': 17,
 '3': 18,
 '4': 19,
 '5': 20,
 '6': 21,
 '7': 22,
 '8': 23,
 '9': 24,
 ':': 25,
 '<': 26,
 '=': 27,
 '>': 28,
 '?': 29,
 'a': 30,
 'b': 31,
 'c': 32,
 'd': 33,
 'e': 34,
 'f': 35,
 'g': 36,
 'h': 37,
 'i': 38,
 'j': 39,
 'k': 40,
 'l': 41,
 'm': 42,
 'n': 43,
 'o': 44,
 'p': 45,
 'q': 46,
 'r': 47,
 's': 48,
 't': 49,
 'u': 50,
 'v': 51,
 'w': 52,
 'x': 53,
 'y': 54,
 'z': 55,
 '{': 56,
 '}': 57}

# Dataset

In [16]:
def get_train_module_paths(modules: List[str], difficulty: List[str]) -> List[str]:
    paths = []

    folders = ['train-easy', 'train-medium', 'train-hard']
    
    if difficulty is not None and set(difficulty).issubset(set(folders)):
        folders = difficulty

    for module in modules:
        for fold in folders:
            paths += glob.glob(f"./mathematics_dataset-v1.0/{fold}/{module}.txt")

    return paths

In [17]:
module_files = get_train_module_paths(["algebra__linear_1d", "algebra__linear_2d"], difficulty=['train-easy'])
module_files

['./mathematics_dataset-v1.0/train-easy/algebra__linear_1d.txt',
 './mathematics_dataset-v1.0/train-easy/algebra__linear_2d.txt']

In [18]:
def get_test_module_paths(modules: List[str]) -> List[str]:
    paths = []

    for module in modules:
        paths += glob.glob(f"./mathematics_dataset-v1.0/interpolate/{module}.txt")

    return paths

In [19]:
# algebra_train, algebra_test = read_all_module_files("algebra__linear_1d")

In [20]:
# len(algebra_train)

In [21]:
algebra_path = "./mathematics_dataset-v1.0/train-easy/algebra__linear_1d.txt"
probability_path = "./mathematics_dataset-v1.0/train-easy/probability__swr_p_level_set.txt"
prime_path = "./mathematics_dataset-v1.0/train-easy/numbers__is_prime.txt"

In [22]:
"""
questions_easy_algebra, answers_easy_algebra = read_dataset(algebra_path)
questions_easy_probability, answers_easy_probability = read_dataset(probability_path)
questions_easy_prime, answers_easy_prime = read_dataset(prime_path)
"""

'\nquestions_easy_algebra, answers_easy_algebra = read_dataset(algebra_path)\nquestions_easy_probability, answers_easy_probability = read_dataset(probability_path)\nquestions_easy_prime, answers_easy_prime = read_dataset(prime_path)\n'

There is no substantial difference in time between loading the entire dataset and pre-processing the data in the __getitem__ method:

40 - 60 microsec vs 200 - 300 microsec

In [23]:
# d = Mathematics_Dataset(module_files, v)

In [24]:
# %%time
# d[8]

In [25]:
class Mathematics_Dataset(Dataset):
    def __init__(self, modules_paths: List[str], vocabulary: Dict[str, int], max_len_question: int=160, max_len_answer: int=30):
        super().__init__()
        self.modules_paths = modules_paths
        
        self.questions = []
        self.answers = []
        
        for m in self.modules_paths:
            q_m, a_m = self.read_dataset(m)
            self.questions += q_m
            self.answers += a_m
        
        self.max_len_question = max_len_question
        self.max_len_answer = max_len_answer
        self.vocabulary = vocabulary

    def read_dataset(self, text_path: str, lowercase: bool=True) -> Tuple[List[str], List[str]]:
        questions = []
        answers = []
        with open(text_path, 'r') as f:
            for idx, line in enumerate(f):
                row = line.rstrip().lower() if lowercase else line.rstrip()
                # Questions
                if idx % 2 == 0:
                    questions.append(row) 
                # Answers
                else: 
                    answers.append(row)
        return questions, answers

    def convert_chars_to_ids(self, sentence: str, max_len: int) -> torch.tensor:
        sentence_ids = np.full(max_len + 2, self.vocabulary['<pad>'])

        # Start with <bos>
        sentence_ids[0] = self.vocabulary['<bos>']

        for i, char in enumerate(sentence):
            sentence_ids[i + 1] = self.vocabulary.get(char, self.vocabulary['<unk>'])
            
        # End with <eos>
        sentence_ids[len(sentence) + 1] = self.vocabulary['<eos>']

        return torch.from_numpy(sentence_ids).long()


    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        assert idx < len(self.questions)
        
        q, a = self.questions[idx], self.answers[idx]

        question = self.convert_chars_to_ids(q, self.max_len_question)
        answer = self.convert_chars_to_ids(a, self.max_len_answer)
        
        return question, answer

In [26]:
# d2 = Mathematics_Dataset(module_files, v)

In [27]:
# %%time
# d2[8]

In [28]:
class Mathematics_DataModule(pl.LightningDataModule):
    def __init__(self, modules: List[str], difficulty: List[str]=None, batch_size: int=32):
        super().__init__()
        self.modules = modules
        self.batch_size = batch_size
        self.load_vocabulary()

        self.train_modules_paths = get_train_module_paths(self.modules, difficulty)  
        self.test_modules_paths = get_test_module_paths(self.modules)        

    
    def load_vocabulary(self):
        with open('./datasets/vocabulary.pkl', 'rb') as f:
            v = pickle.load(f)
        self.vocabulary = create_vocabulary_from_set(v)

    def setup(self, stage=None):
        if stage == "fit":
            self.math_train = Mathematics_Dataset(self.train_modules_paths, self.vocabulary)
            self.math_val = Mathematics_Dataset(self.test_modules_paths, self.vocabulary)

        if stage == "test":
            self.math_test = Mathematics_Dataset(self.test_modules_paths, self.vocabulary)
    
    def train_dataloader(self):
        return DataLoader(self.math_train, batch_size=self.batch_size, shuffle=True, num_workers=2, pin_memory=True)

    def val_dataloader(self):                                                              
        return DataLoader(self.math_val, batch_size=self.batch_size, num_workers=2, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.math_test, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        pass

In [29]:
# dm = Mathematics_DataModule(['algebra__linear_1d'], batch_size = 64)

# Modules

In [30]:
def scaled_dot_product_attention(query, key, value, sqrt_q, mask, dropout_layer = None):
    t = torch.matmul(query, key.transpose(-2, -1)) / sqrt_q
    """
    t [batch_size, self.num_heads, query.size(-2), key.size(-2)]
    mask [batch_size, self.num_head, 1 or query.size(-2), key.size(-2)]
    """
    t = t.masked_fill(mask == False, -1e10) #-1e10 acts like -infinity, so that the softmax will consider these tokens less important
    t = F.softmax(t, dim = -1)
    if dropout_layer is not None:
        t = dropout_layer(t)
    return torch.matmul(t, value)

In [31]:
class MultiHeadAttention(nn.Module): 
    def __init__(self, embedding_dim, num_heads, dropout = 0.2, tp_attention = False):
        super(MultiHeadAttention, self).__init__()
        assert embedding_dim % num_heads == 0
        self.tp_attention = tp_attention
        self.dim_head = embedding_dim // num_heads #single head dimension
        self.sqrt_q = math.sqrt(self.dim_head)
        self.num_heads = num_heads
        self.W_q = nn.Linear(embedding_dim, embedding_dim, bias = True) #stack of num_heads matrices of dimension (d, dim_head), one for each head
        self.W_k = nn.Linear(embedding_dim, embedding_dim, bias = True)
        self.W_v = nn.Linear(embedding_dim, embedding_dim, bias = True)
        self.W_o = nn.Linear(embedding_dim, embedding_dim, bias = True)
        if self.tp_attention:
            self.W_r = nn.Linear(embedding_dim, embedding_dim, bias = True) #ruolo
        
        self.dropout = nn.Dropout(dropout)
        #self.dropout = nn.Dropout(0.15)

        self._init_weights()

    
    def _init_weights(self):
        nn.init.xavier_uniform_(self.W_q.weight)
        nn.init.xavier_uniform_(self.W_k.weight)
        nn.init.xavier_uniform_(self.W_v.weight)
        nn.init.xavier_uniform_(self.W_o.weight)

        if self.tp_attention:
            nn.init.normal_(self.W_r.weight, mean=0, std=1./self.sqrt_q)


    def forward(self, query, key, value, mask): #query, key, value
        batch_size = query.size(0)

        q = self.W_q(query).view(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
        k = self.W_k(key).view(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
        v = self.W_v(value).view(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)

        """
        in the encoder:
            q,k,v ([batch_size, self.num_heads, max_len_question, self.dim_head])
            mask (src_mask): [batch_size, 1, 1, max_len_question]

        in the decoder (MASKED MULTI-HEAD ATTENTION):
            seq_len = current_len_answer if inference else max_len_answer
                q,k,v ([batch_size, self.num_head, seq_len, self.dim_head])
                mask (trg_mask): [batch_size, 1, seq_len, current_len_answer]
                
        in the decoder (MULTI-HEAD ATTENTION):
            seq_len = current_len_answer if inference else max_len_answer
                q ([batch_size, self.num_head, seq_len, self.dim_head])
                k,v ([batch_size, self.num_head, max_len_question, self.dim_head])
                mask (src_mask): [batch_size, 1, 1, max_len_question]
        """

        attention_value = scaled_dot_product_attention(q, k, v, self.sqrt_q, mask, self.dropout)
            #attention_value ([batch_size, self.num_heads, q.size(-2), v.size(-1)])

        
        if self.tp_attention:
            role = self.W_r(query).view(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
            attention_value *= role  #element-wise product between attention value and role before the final projection
        return self.W_o(attention_value.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads*self.dim_head))
            #output : ([batch_size, q.size(-2)=query.size(-2),embedding_dim)])

In [32]:
class TransformerBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads, hidden_size = None, dropout=0.2, tp_attention = False):
        super(TransformerBlock, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.attention = MultiHeadAttention(embedding_dim, num_heads, dropout, tp_attention)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        
        hidden_size = 4*embedding_dim if hidden_size is None else hidden_size
        self.ff = nn.Sequential(nn.Linear(embedding_dim, hidden_size, bias = True), 
                                nn.ReLU(),
                                nn.Dropout(dropout),
                                nn.Linear(hidden_size, embedding_dim, bias = True))
        self.dropout2 = nn.Dropout(dropout)
        self._init_weights()


    def _init_weights(self):
        for p in self.ff:
            if isinstance(p, nn.Linear):
                nn.init.xavier_uniform_(p.weight)
                if p.bias is not None:
                    nn.init.constant_(p.bias, 0)


    def forward(self, query, key, value, mask): #query, key, value
        """
        if this is a TransformerBlock of the encoder:
            query, key, value = x ([batch_size, max_len_question, embedding_dim])

        if this is a TransformerBlock of the decoder:
            seq_len = current_len_answer if inference else max_len_answer
            MASKED MULTI HEAD ATTENTION:
                query, key, value = y ([batch_size, seq_len, embedding_dim])
            MULTI HEAD ATTENTION:
                query: ([batch_size, seq_len, embedding_dim])
                key, value: ([batch_size, max_len_question, embedding_dim])

        """

        x = query + self.dropout1(self.attention(query, key, value, mask)) #query as res conn because the decoder block requires it and it doesn't matter for encoder blocks
        x = self.norm1(x)
        x = x + self.dropout2(self.ff(x))
        x = self.norm2(x)

        return x

In [33]:
class DecoderBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads, hidden_size, dropout = 0.2, tp_attention = False):
        super(DecoderBlock, self).__init__()
        self.masked_attention = MultiHeadAttention(embedding_dim, num_heads, dropout, tp_attention)
        self.norm = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(dropout)
        self.transformer_block = TransformerBlock(embedding_dim, num_heads, hidden_size, dropout, tp_attention)

    def forward(self, output_encoder, src_mask, y, trg_mask):
        y = y + self.dropout(self.masked_attention(y, y, y, trg_mask)) #masked attention (y = query = key = value) + residual connection
        y = self.norm(y)
        return self.transformer_block(y, output_encoder, output_encoder, src_mask)#query from the masked mha and key and value from the encoder

In [34]:
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, max_len=256):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * -(math.log(10000.0) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        return x + Variable(self.pe[:, :x.size(1)], requires_grad = False)

In [35]:
class TransformerEncoder(nn.Module):
    def __init__(self, embedding_dim, num_heads, hidden_size, dropout, num_blocks = 6, tp_attention = False):
        super(TransformerEncoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.encoder = nn.ModuleList(
            [TransformerBlock(embedding_dim, num_heads, hidden_size, dropout, tp_attention) for _ in range(num_blocks)]
            )

    def forward(self, x, mask): 
        # x ([batch_size, max_len_question, embedding_dim])
        for block in self.encoder:
            x = block(x, x, x, mask)

        return x

In [36]:
class TransformerDecoder(nn.Module):
    def __init__(self, embedding_dim, num_heads, hidden_size, dropout = 0.2, num_blocks = 6, tp_attention = False):
        super(TransformerDecoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.decoder = nn.ModuleList(
            [DecoderBlock(embedding_dim, num_heads, hidden_size, dropout, tp_attention) for _ in range(num_blocks)]
            )

    def forward(self, output_encoder, src_mask, y, trg_mask): 
        for block in self.decoder:
            y = block(output_encoder, src_mask, y, trg_mask)
        return y

In [67]:
class Transformer(pl.LightningModule):
    def __init__(
        self, 
        special_idxs: Dict[str, int], 
        optimizer_params: dict,
        learning_rate: float=1e-4,
        num_heads: int=4, 
        embedding_dim: int=256, 
        hidden_size: int=512, 
        vocabulary_size: int=58,
        max_len_question: int=162,
        max_len_answer: int=32,
        num_blocks_encoder: int=6, 
        num_blocks_decoder: int=6, 
        dropout: float=0.2, 
        gradient_clip_val: float=0.9,
        tp_attention: bool=False
    ):
        super(Transformer, self).__init__()
        self.save_hyperparameters()

        self.bos_id = special_idxs['<bos>']
        self.eos_id = special_idxs['<eos>']
        self.pad_id = special_idxs['<pad>']
        self.optimizer_params = optimizer_params
        self.learning_rate = learning_rate
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.vocabulary_size = vocabulary_size
        
        self.token_embedding = nn.Embedding(vocabulary_size, embedding_dim, padding_idx=self.pad_id)
        self.positional_embedding = PositionalEncoding(embedding_dim)
        self.encoder = TransformerEncoder(embedding_dim, num_heads, hidden_size, dropout, num_blocks_encoder, tp_attention)
        self.decoder = TransformerDecoder(embedding_dim, num_heads, hidden_size, dropout, num_blocks_decoder, tp_attention)
        self.to_logits = nn.Linear(embedding_dim, vocabulary_size)
        
        self.max_len_question = max_len_question
        self.max_len_answer = max_len_answer

        self.train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocabulary_size, ignore_index=self.pad_id)
        self.train_accuracy2 = torchmetrics.Accuracy(task="multiclass", num_classes=vocabulary_size, ignore_index=self.pad_id)
        self.train_accuracy3 = torchmetrics.Accuracy(task="multiclass", num_classes=vocabulary_size, ignore_index=self.pad_id)
        self.val_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocabulary_size, ignore_index=self.pad_id)
        self.test_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocabulary_size, ignore_index=self.pad_id)

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embedding.weight, 
                        mean=0, std=1./math.sqrt(self.embedding_dim))
        
        nn.init.normal_(self.to_logits.weight, 
                        mean=0, std=1./math.sqrt(self.vocabulary_size))

        if self.to_logits.bias is not None:
            nn.init.constant_(self.to_logits.bias, 0)



    def create_trg_mask(self, y): #compute a mask so that the prediction of the next token can only depend on the previous tokens
        # #[batch_size, 1, len, len] & [batch_size, 1, 1, len]
        return self.create_causal_mask(y) & self.create_padding_mask(y)


    def create_causal_mask(self, y):
        batch_size, seq_len = y.shape
        mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device = self.device)).expand(
            batch_size, 1, seq_len, seq_len)
        return mask


    def create_padding_mask(self, x):
        batch_size, seq_len = x.shape
        mask = (x != self.pad_id).unsqueeze(-2).unsqueeze(-2)
        return mask

    def greedy_decode(self,x):
        batch_size = x.size(0)
        src_mask = self.create_padding_mask(x)
        # src_mask ([batch_size, 1, 1, self.max_len_question]), 
                #la dimensione [-2] è 1 perché per ogni token della domanda la maschera è la stessa (broadcasting)
        x = self.token_embedding(x)
        x = self.positional_embedding(x)
        #x ([batch_size, self.max_len_question, self.embedding_dim])

        output_encoder = self.encoder(x, src_mask)
        #output_encoder : ([batch_size, self.max_len_question, embedding_dim]) 

        output = torch.ones(batch_size, 1, dtype=torch.int64, device = self.device).fill_(self.bos_id)
        #output: ([batch_size, 1]) 
        done = torch.zeros(batch_size, dtype = torch.uint8, device = self.device)
        for _ in range(self.max_len_answer - 1): 
            trg_mask = self.create_trg_mask(output)
            # tgr_mask ([batch_size, 1, len_current_answer, len_current_answer])

            output_embedding = self.token_embedding(output)
            output_embedding = self.positional_embedding(output_embedding)
            #output_embedding ([batch_size, len_current_answer, self.embedding_dim])

            out = self.decoder(output_encoder, src_mask, output_embedding, trg_mask)
            #out ([batch_size, len_current_answer, self.embedding_dim])
            out = self.to_logits(out)
            #out ([batch_size, len_current_answer, self.vocabulary_size])
            out = torch.argmax(out[:,[-1],:], dim = -1)
            output = torch.cat([output, out], dim = 1) #we concatenate the new token to the output answer
            eos_reached = out.squeeze(1) == self.eos_id
            done |= eos_reached
            if done.sum() == batch_size:
                break
        return output


    def inference(self, x):
        #encode and then generate the output token by token greedily
        self.eval()
        with torch.no_grad():
            return self.greedy_decode(x)

    def forward(self, x, y):
        
        # x ([batch_size, self.max_len_question])
        # y ([batch_size, self.max_len_answer])

        src_mask = self.create_padding_mask(x)

        # src_mask ([batch_size, 1, 1, self.max_len_question]),
                    #la dimensione [-3] è 1 perché successivamente viene effettuato broadcasting per ogni head della MULTI-HEAD ATTENTION 
                    #la dimensione [-2] è 1 perché per ogni token della domanda la maschera è la stessa (broadcasting)

        trg_mask = self.create_trg_mask(y)

        # tgr_mask ([batch_size, 1, self.max_len_answer-1, self.max_len_answer-1]),
                    #la dimensione [-3] è 1 perché successivamente viene effettuato broadcasting per ogni head della MULTI-HEAD ATTENTION 
                    #la dimensione [-2] è self.max_len_answer perché per ogni token della domanda la maschera è diversa (maschera causale)


        x = self.token_embedding(x)

        x = self.positional_embedding(x)

        #x ([batch_size, self.max_len_question, self.embedding_dim])

        y = self.token_embedding(y)
        y = self.positional_embedding(y)

        #y ([batch_size, self.max_len_answer-1, self.embedding_dim])

        output_encoder = self.encoder(x, src_mask)

        #output_encoder : ([batch_size, self.max_len_question, embedding_dim])        

        output_decoder = self.decoder(output_encoder, src_mask, y, trg_mask)

        #output_decoder : ([batch_size, self.max_len_answer-1, embedding_dim]))

        return self.to_logits(output_decoder).transpose(1,2)
    
    def configure_optimizers(self):# learning rate = 1x10^-4; beta1 =0.9; beta2 = 0.995 dal paper
        betas = self.optimizer_params['betas']
        return torch.optim.Adam(self.parameters(), self.learning_rate, betas)


    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x, y[:, :-1])
        loss = F.cross_entropy(y_pred, y[:, 1:], ignore_index = self.pad_id)
        
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        self.train_accuracy.update(y_pred, y[:, 1:])
        self.log('train_accuracy_forward', self.train_accuracy.compute(), on_step=False, on_epoch=True, prog_bar=True, logger=True)

        y_pred2 = self.greedy_decode(x)  #[batch_size, max_eos_found]
        y_pred2 = F.pad(y_pred2, (0, self.max_len_answer - y_pred2.shape[1]), mode='constant', value=self.pad_id) #[batch_size, max_len_answer]
        self.train_accuracy2.update(y_pred2[:, 1:], y[:, 1:])
        self.log('train_accuracy_greedydecode', self.train_accuracy2.compute(), on_step=False, on_epoch=True, prog_bar=True, logger=True)

        y_pred3 = self.inference(x)  #[batch_size, max_eos_found]
        y_pred3 = F.pad(y_pred3, (0, self.max_len_answer - y_pred3.shape[1]), mode='constant', value=self.pad_id) #[batch_size, max_len_answer]
        self.train_accuracy3.update(y_pred3[:, 1:], y[:, 1:])
        self.log('train_accuracy_inference', self.train_accuracy3.compute(), on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return loss, self.train_accuracy.compute(), self.compute_accuracy(p.transpose(1,2),y,0)

    def compute_accuracy(self, logits, targets, pad_value):
        """
        Compute full sequence accuracy of a batch.
        :param logits: the model logits (batch_size, seq_len, out_dim)
        :param targets: the true targets (batch_size, seq_len)
        :param pad_value: PAD value used to fill end of target seqs
        :return: continous accuracy between 0.0 and 1.0
        """
        trg_shifted = targets[:, 1:]              # drop the SOS from targets
        y_hat = torch.argmax(logits, dim=-1)      # get index predictions from logits

        # count matches in batch, masking out pad values in each target
        matches = (torch.eq(trg_shifted,y_hat) | (trg_shifted==pad_value)).all(1).sum().item()
        
        acc_percent = matches / len(logits)
        return acc_percent

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.inference(x)  #[batch_size, max_eos_found]
        y_pred = F.pad(y_pred, (0, self.max_len_answer - y_pred.shape[1]), mode='constant', value=self.pad_id) #[batch_size, max_len_answer]
        self.val_accuracy.update(y_pred[:, 1:], y[:, 1:]) #y_pred, y nel caso volessimo contare <bos> come carattere corretto
        self.log('val_accuracy_step', self.val_accuracy.compute(), on_step=True, on_epoch=False, prog_bar=True, logger=True)


    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.inference(x)  #[batch_size, max_eos_found]
        y_pred = F.pad(y_pred, (0, self.max_len_answer - y_pred.shape[1]), mode='constant', value=self.pad_id) #[batch_size, max_len_answer]
        self.test_accuracy.update(y_pred[:, 1:], y[:, 1:])
        

    def validation_epoch_end(self, outputs):
        self.log('val_accuracy_epoch', self.val_accuracy.compute(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.val_accuracy.reset()
        
        # Also reset the training accuracy
        self.train_accuracy.reset()
        self.train_accuracy2.reset()
        self.train_accuracy3.reset()

    
    def test_epoch_end(self, outputs):
        self.log('test_accuracy_epoch', self.test_accuracy.compute(), on_epoch=True, prog_bar=True, logger=True)
        self.test_accuracy.reset()


In [68]:

vocabulary = v
LEARNING_RATE = 1e-4 # 2.558585886905645e-05
BATCH_SIZE = 128
EMBEDDING_DIM = 16
NUM_HEADS = 4
assert EMBEDDING_DIM % NUM_HEADS == 0
# HIDDEN_SIZE = 2048
HIDDEN_SIZE = 16
DROP_PROB = 0.5
GRADIENT_CLIP_VAL = 0.5
NUM_BLOCKS_ENCODER = 1
NUM_BLOCKS_DECODER = 1
SPECIAL_CHAR_DICT = {'<bos>': vocabulary['<bos>'], '<eos>': vocabulary['<eos>'], '<pad>': vocabulary['<pad>']}
OPTIMIZER_PARAMS = {'betas': (0.9, 0.995)}

tp_transformer_hyperparams = {
    "special_idxs": SPECIAL_CHAR_DICT,
    "optimizer_params": OPTIMIZER_PARAMS,
    "learning_rate": LEARNING_RATE,
    "num_heads": NUM_HEADS,
    "embedding_dim": EMBEDDING_DIM,
    "hidden_size": HIDDEN_SIZE,
    "vocabulary_size": len(vocabulary),
    "num_blocks_encoder": NUM_BLOCKS_ENCODER,
    "num_blocks_decoder": NUM_BLOCKS_DECODER,
    "dropout": DROP_PROB,
    "gradient_clip_val": GRADIENT_CLIP_VAL, # Added just to be saved
    "tp_attention": True
}

tp_transformer = Transformer(**tp_transformer_hyperparams)
modules = ['algebra__linear_1d']
math_dm = Mathematics_DataModule(modules, batch_size=4)


In [39]:
""" #possibile alternativa a greedy_decode
def beam_search(self, x, k = 3):
    batch_size = x.size(0)
    src_mask = self.create_padding_mask(x)
    # src_mask [batch_size, 1, 1, self.max_len_question] 
         
    x = self.token_embedding(x)
    x = self.positional_embedding(x)
    #x [batch_size, self.max_len_question, self.embedding_dim]

    output_encoder = self.encoder(x, src_mask)
    #output_encoder [batch_size, self.max_len_question, embedding_dim]

    # Create initial input for the decoder
    start = torch.ones(batch_size, 1, dtype=torch.int64, device=self.device).fill_(self.bos_id)

    sequences = [(start, 0)] * batch_size

    for _ in range(self.max_len_answer - 1): 
        candidates = [] # List of candidate sequences for each example in the batch

        for sequence, score in sequences:
            # If sequence is already ended, add it to the candidate list and continue with the next sequence
            if sequence.squeeze(1)[-1] == self.eos_id:
                candidates.append((sequence, score))
                continue

            trg_mask = self.create_trg_mask(sequence)
            # tgr_mask [batch_size, 1, len_current_answer, len_current_answer]

            output_embedding = self.token_embedding(sequence)
            output_embedding = self.positional_embedding(output_embedding)
            #output_embedding [batch_size, len_current_answer, self.embedding_dim]

            out = self.decoder(output_encoder, src_mask, output_embedding, trg_mask)
            #out [batch_size, len_current_answer, self.embedding_dim]
            out = self.to_logits(out)
            #out [batch_size, len_current_answer, self.vocabulary_size]

            # Get top-k most likely next tokens and their scores
            scores, indices = torch.topk(out[:, -1, :], k=k)
            for i in range(beam_size):
                token = indices[:, i].unsqueeze(1)
                prob = scores[:, i].unsqueeze(1)
                new_sequence = torch.cat([sequence, token], dim=1)
                new_score = score - torch.log(prob) # use log-probability to prevent underflow
                candidates.append((new_sequence, new_score))

        # Select the k best sequences for each example in the batch
        candidates = sorted(candidates, key=lambda x: x[1], reverse=True)
        sequences = candidates[:k*batch_size]

    # Select the best sequence for each example in the batch
    outputs = []
    for sequence, score in sequences:
        outputs.append(sequence)
    outputs = torch.stack(outputs, dim=0)
    return outputs
"""

' #possibile alternativa a greedy_decode\ndef beam_search(self, x, k = 3):\n    batch_size = x.size(0)\n    src_mask = self.create_padding_mask(x)\n    # src_mask [batch_size, 1, 1, self.max_len_question] \n         \n    x = self.token_embedding(x)\n    x = self.positional_embedding(x)\n    #x [batch_size, self.max_len_question, self.embedding_dim]\n\n    output_encoder = self.encoder(x, src_mask)\n    #output_encoder [batch_size, self.max_len_question, embedding_dim]\n\n    # Create initial input for the decoder\n    start = torch.ones(batch_size, 1, dtype=torch.int64, device=self.device).fill_(self.bos_id)\n\n    sequences = [(start, 0)] * batch_size\n\n    for _ in range(self.max_len_answer - 1): \n        candidates = [] # List of candidate sequences for each example in the batch\n\n        for sequence, score in sequences:\n            # If sequence is already ended, add it to the candidate list and continue with the next sequence\n            if sequence.squeeze(1)[-1] == self.e

In [75]:
acc

[(tensor(0.0833), 0.0),
 (tensor(0.0870), 0.0),
 (tensor(0.0909), 0.0),
 (tensor(0.0714), 0.0),
 (tensor(0.0962), 0.0),
 (tensor(0.0833), 0.0)]

# SOTA

In [None]:
vocabulary = v

In [None]:
name = "Luigi"

root_dir = "./training/checkpoints"
logger_dir = "./training/tensorboard/logs"
checkpoint_dir = "./training/checkpoints/PRO_" + name + "_tp_transformer_checkpoints"
#checkpoint_dir = "./training/checkpoints/Standard_Luigi_tp_transformer_checkpoints"
EPOCHS = 3
LEARNING_RATE = 1e-4 # 2.558585886905645e-05
BATCH_SIZE = 128
EMBEDDING_DIM = 64
HIDDEN_SIZE = 64
NUM_HEADS = 4
assert EMBEDDING_DIM % NUM_HEADS == 0

DROP_PROB = 0.5
GRADIENT_CLIP_VAL = 0.5
NUM_BLOCKS_ENCODER = 3
NUM_BLOCKS_DECODER = 3
SPECIAL_CHAR_DICT = {'<bos>': vocabulary['<bos>'], '<eos>': vocabulary['<eos>'], '<pad>': vocabulary['<pad>']}
OPTIMIZER_PARAMS = {'betas': (0.9, 0.995)}

In [None]:
tp_transformer_hyperparams = {
    "special_idxs": SPECIAL_CHAR_DICT,
    "optimizer_params": OPTIMIZER_PARAMS,
    "learning_rate": LEARNING_RATE,
    "num_heads": NUM_HEADS,
    "embedding_dim": EMBEDDING_DIM,
    "hidden_size": HIDDEN_SIZE,
    "vocabulary_size": len(vocabulary),
    "num_blocks_encoder": NUM_BLOCKS_ENCODER,
    "num_blocks_decoder": NUM_BLOCKS_DECODER,
    "dropout": DROP_PROB,
    "gradient_clip_val": GRADIENT_CLIP_VAL, # Added just to be saved
    "tp_attention": True
}

#now = datetime.now().strftime("%H.%M")

logger = TensorBoardLogger(logger_dir, name="PRO_tp_transformer")
checkpoint_callback = ModelCheckpoint(
    dirpath = checkpoint_dir,
    filename='tp_transformer_{epoch:02d}_{step:06d}_{val_accuracy_epoch:.3f}',
    save_top_k=6,
    monitor='val_accuracy_epoch',
    mode='max',
    verbose=True,
    save_last=True
)
callbacks = [checkpoint_callback, TQDMProgressBar(refresh_rate=20)]
trainer_hyperparams = {
    "default_root_dir": root_dir,
    "accelerator": "auto",
    "devices": 1,
    "precision": 32, #16 if torch.cuda.is_available() else 32, # ADDED
    "log_every_n_steps": 100,
    # "val_check_interval": 0.5, # validation step called 2 times during a training epoch
    "val_check_interval": 1.0, 
    "gradient_clip_val": GRADIENT_CLIP_VAL,
    "max_epochs": EPOCHS,
    "logger": logger,
    "callbacks": callbacks,
    # "deterministic": True,
}

#modules = ['algebra__linear_1d', 'probability__swr_p_level_set', 'numbers__is_prime']
modules = ['algebra__linear_1d']
math_dm = Mathematics_DataModule(modules, batch_size=BATCH_SIZE)

In [None]:
tp_transformer = Transformer(**tp_transformer_hyperparams)
trainer = Trainer(**trainer_hyperparams)

#trainer.fit(tp_transformer, datamodule=math_dm)
math_dm.setup("fit")
trainer.fit(tp_transformer, train_dataloaders=math_dm.train_dataloader(), val_dataloaders=math_dm.val_dataloader())

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                 | Type               | Params
------------------------------------------------------------
0 | token_embedding      | Embedding          | 3.7 K 
1 | positional_embedding | PositionalEncoding | 0     
2 | encoder              | TransformerEncoder | 88.1 K
3 | decoder              | Transforme

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

In [None]:
trained_hyperparams = torch.load(f"{checkpoint_dir}/last-v3.ckpt")
trained_hyperparams['hyper_parameters']

In [None]:
test_model = Transformer(**tp_transformer_hyperparams)
test_trainer = Trainer(**trainer_hyperparams)
"""
# test_trainer.test(test_model, datamodule=math_dm, ckpt_path=f"{checkpoint_dir}/tp_transformer_epoch=00_step=001464_val_accuracy_epoch=0.287.ckpt", verbose=True)
test_trainer.test(test_model, datamodule=math_dm, ckpt_path="best", verbose=True) # ckpt_path="best"
"""
math_dm.setup("test")
test_trainer.test(test_model, dataloaders = math_dm.test_dataloader(), ckpt_path="best", verbose=True)

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./training/tensorboard/logs/TP-Transformer #modifica in base al tuo path

In [None]:
checkpoint_dir = "./training/checkpoints/tp_transformer_checkpoints"  
logger = TensorBoardLogger(logger_dir, name="TP-Transformer", log_graph=True)



ckpt_path = checkpoint_dir + "/last.ckpt"  #attenzione che in caso di nuovi last checkpoint il nome è diverso
checkpoint_dir_fineTuning = "./training/checkpoints/tp_transformer_checkpoints_fineTuning"
checkpoint_callback = ModelCheckpoint(
    dirpath = checkpoint_dir_fineTuning,
    filename='tp_transformer_{epoch:02d}_{step:06d}',
    save_top_k=3,
    monitor='accuracy_epoch',
    mode='max',
    save_last=True
)
tp_transformer_ckpt = Transformer.load_from_checkpoint(ckpt_path)


callbacks = [checkpoint_callback, TQDMProgressBar(refresh_rate=20)]

trainer = pl.Trainer(log_every_n_steps=1, default_root_dir=root_dir, accelerator='auto', devices=1, gradient_clip_val = 0.1, max_epochs = EPOCHS + ADDITIONAL_EPOCHS, logger = logger, callbacks = callbacks)
math_dm = Mathematics_DataModule(['algebra__linear_1d'], batch_size = BATCH_SIZE)
trainer.fit(tp_transformer_ckpt, datamodule = math_dm, ckpt_path = ckpt_path)



# NON-SOTA (Transformer)

In [None]:
vocabulary = v

In [None]:
root_dir = "./training/checkpoints"
logger_dir = "./training/tensorboard/logs"
checkpoint_dir = "./training/checkpoints/transformer_vanilla_checkpoints.ckpt"
EPOCHS = 3
BATCH_SIZE = 4
EMBEDDING_DIM = 256
NUM_HEADS = 8
assert EMBEDDING_DIM % NUM_HEADS == 0
HIDDEN_SIZE = 512
DROP_PROB = 0.2
NUM_BLOCKS_ENCODER = 6
NUM_BLOCKS_DECODER = 6
SPECIAL_CHAR_DICT = {'<bos>': vocabulary['<bos>'], '<eos>': vocabulary['<eos>'], '<pad>': vocabulary['<pad>']}


ADDITIONAL_EPOCHS = 5

In [None]:
transformer_vanilla = Transformer(
    SPECIAL_CHAR_DICT, embedding_dim = EMBEDDING_DIM, num_heads = NUM_HEADS, hidden_size = HIDDEN_SIZE, 
    dropout = DROP_PROB, vocabulary_size = len(vocabulary), num_blocks_encoder = NUM_BLOCKS_ENCODER,
    num_blocks_decoder = NUM_BLOCKS_DECODER
    )

logger = TensorBoardLogger(logger_dir, name="Transformer-Vanilla", log_graph=True)

callbacks = [checkpoint_callback, TQDMProgressBar(refresh_rate=20)]

checkpoint_callback = ModelCheckpoint(
    dirpath = checkpoint_dir,
    filename='transformer_vanilla_{epoch:02d}_{step:06d}',
    save_top_k=3,
    monitor='accuracy_epoch',
    mode='max',
    save_last=True
)


trainer = pl.Trainer(default_root_dir=root_dir, accelerator='auto', devices=1, gradient_clip_val = 0.1, max_epochs = EPOCHS+ADDITIONAL_EPOCHS, logger = logger, callbacks = callbacks)
math_dm = Mathematics_DataModule(['algebra__linear_1d'], batch_size = BATCH_SIZE)
trainer.fit(transformer_vanilla, datamodule = math_dm)


In [None]:
trainer.validate(datamodule=dm)
trainer.test(datamodule=dm)

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./training/tensorboard/logs/Transformer-Vanilla


da fare:

-  controllare l'architettura GRU (FATTO)

-  controllare teacher forcing (FATTO)



-  utilizzare stage (parametro di setup) per caricare anche un solo dataset se stage = "train" ad esempio 
   (https://colab.research.google.com/drive/1oJrA-Q-neOl1fCQJhIWR_GmxpYaG-cFx?authuser=1#scrollTo=JM57yq7bJS0E)

-  aggiungere predict_step nel pl.LightningModule dove si chiama inference e relativo predict dataloader nel Lightning data module


-  RNN fatte molto bene:
    https://github.com/georgeyiasemis/Recurrent-Neural-Networks-from-scratch-using-PyTorch 
    https://towardsdatascience.com/building-a-lstm-by-hand-on-pytorch-59c02a4ec091
    https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html


In [None]:
vocabulary = v

In [None]:
class GRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(GRUCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        self.Wx = nn.Linear(input_size, 3*hidden_size, bias=True)
        self.Wh_gate = nn.Linear(hidden_size, 2*hidden_size, bias=True)

        self.Whh = nn.Linear(hidden_size, hidden_size, bias = True)
        
        #nell'implementazione del git c'è anche una funzione reset_parameters

    def forward(self, x, h):

        x_reset, x_update, x_candidate = torch.tensor_split(self.Wx(x), 3, dim=-1)
        """
        print(f"shape h {h.shape}")
        print(f"shape {x.shape}")
        print(f"hidden_size {self.hidden_size}")
        print(f"input_size {self.input_size}")
        """
        h_reset, h_update = torch.tensor_split(self.Wh_gate(h), 2, dim=-1)

        reset_gate = torch.sigmoid(x_reset + h_reset)

        update_gate = torch.sigmoid(x_update + h_update)

        h_candidate = torch.tanh(x_candidate + self.Whh(reset_gate * h))   

        h_t = update_gate * h + (1-update_gate) * h_candidate

        return h_t




In [None]:
class GRU(nn.Module):
    
    def __init__ (self, input_size, hidden_size, num_cells=2):
        assert num_cells>0

        super(GRU, self).__init__()

        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_cells = num_cells
        
        self.GRU_cells = nn.ModuleList(
            [GRUCell(input_size, hidden_size)]+[GRUCell(hidden_size, hidden_size) for _ in range(1, num_cells)])
        
    def forward(self, x, h=None):
        "x è una sequenza [batch, seq_len, embedding_dim]"

        """
        
        """
        batch_size, seq_len, _ = x.shape

        "h [self.num_cells, batch_size, self.hidden_size]"
        #output_states = torch.stack([torch.zeros(batch_size, self.hidden_size) for _ in range(seq_len)], dim = 0)
        output_states = torch.zeros(batch_size, seq_len, self.hidden_size)

        if(h!=None):
            hidden_states = h
        else:
            hidden_states = torch.zeros(self.num_cells, batch_size, self.hidden_size)
        
        #hidden_states = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_cells)]

        #hidden_states = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_cells)]
        
        for t in range(seq_len):

            x_t = x[:,t,:]

            hidden_states[0] = self.GRU_cells[0](x_t, hidden_states[0])

            for l in range(1, self.num_cells):

                hidden_states[l] = self.GRU_cells[l](hidden_states[l-1], hidden_states[l])


            output_states[:,t,:] = hidden_states[self.num_cells - 1]

        #output_state : [batch_size, seq_len, hidden_size]
        return output_states, hidden_states #(così si prende output_states[:,-1,:] da dare al decoder e tutto output_states per il linear dopo il decoder)
        



In [None]:
""" prove
t = torch.cat([torch.zeros(1, 2, 3) for _ in range(3)], dim = 0)
print(t.shape)
t[1] = torch.ones(2,3)
print(t)
t[-1].shape

#equivalente a 

t = torch.stack([torch.zeros(2, 3) for _ in range(3)], dim = 0)
print(t.shape)
t[1] = torch.ones(2,3)
print(t)
t[-1].shape
"""

' prove\nt = torch.cat([torch.zeros(1, 2, 3) for _ in range(3)], dim = 0)\nprint(t.shape)\nt[1] = torch.ones(2,3)\nprint(t)\nt[-1].shape\n\n#equivalente a \n\nt = torch.stack([torch.zeros(2, 3) for _ in range(3)], dim = 0)\nprint(t.shape)\nt[1] = torch.ones(2,3)\nprint(t)\nt[-1].shape\n'

In [None]:
class GRUEncoderDecoder(pl.LightningModule): #oppure Seq2Seq (informiamoci sui nomi)
    def __init__(self, special_idxs, embedding_dim = 256, hidden_size = 512, dropout = 0.2, vocabulary_size = 58, num_cells = 2):
        super(GRUEncoderDecoder, self).__init__()

        #CONTROLLARE LE DIMENSIONI
        self.save_hyperparameters()

        self.bos_id = special_idxs['<bos>']
        self.eos_id = special_idxs['<eos>']
        self.pad_id = special_idxs['<pad>']


        print(special_idxs)
        self.token_embedding = nn.Embedding(vocabulary_size, embedding_dim, padding_idx = self.pad_id)

        self.GRU_encoder = GRU(embedding_dim, hidden_size, num_cells) 

        self.GRU_decoder = GRU(embedding_dim + hidden_size, hidden_size, num_cells)

        self.to_logits = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                       nn.ReLU(), 
                                       nn.Dropout(dropout),
                                       nn.Linear(hidden_size, vocabulary_size)) 
                                       

        self.max_len_question = 162
        self.max_len_answer = 32

        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocabulary_size, ignore_index = self.pad_id)

        
        #embedding
    #encoder GRU
    #decoder GRU
    #ff per classification
    #decoder dovrebbe poter utilizzare teacher forcing credo -> metodo inference come Transformer

    def inference(self, x):

        #encode and then generate the output token by token greedily

        self.eval()
        with torch.no_grad():
            batch_size = x.size(0)
            x = self.token_embedding(x)

            output_encoder, previous_state = self.GRU_encoder(x)

            last_state_encoder = previous_state[-1].unsqueeze(dim=1)


            output = torch.ones(batch_size, 1, dtype=torch.int64, device = self.device).fill_(self.bos_id)
            done = torch.zeros(batch_size, dtype = torch.uint8, device = self.device)




            for _ in range(self.max_len_answer - 1):
                #last_state_encoder_repeated = last_state_encoder.repeat(current_output.shape[1],1,1).transpose(0,1) 

                    #output.shape[1] è 1, non è un problema visto che nel decoder consideriamo sempre un token

                current_output = output[:,-1].unsqueeze(dim=1)

                current_output_embedding = self.token_embedding(current_output)

                input_decoder = torch.cat((last_state_encoder, current_output_embedding), dim=-1)

                out, previous_state = self.GRU_decoder(input_decoder, previous_state)

                out = self.to_logits(out)

                out = torch.argmax(out[:,[-1],:], dim = -1)

                output = torch.cat([output, out], dim = 1)

                eos_reached = out.squeeze(1) == self.eos_id
                done |= eos_reached
                if done.sum() == batch_size:
                    break

            return output

    def forward(self, x, y):

        x = self.token_embedding(x)

        y = self.token_embedding(y)

        print(f"shape x {x.shape}")
        print(f"shape y {y.shape}")


        output_encoder, state_encoder = self.GRU_encoder(x) 

        print(f"shape output_encoder {output_encoder.shape}")
        print(f"shape state_encoder {state_encoder.shape}")

        last_state_encoder = state_encoder[-1]

        print(f"shape last_state_encoder {last_state_encoder.shape}")
        #state_encoder = self.GRU_encoder(x)[:,-1,:]

        last_state_encoder_repeated = last_state_encoder.repeat(y.shape[1],1,1).transpose(0,1) #CONTROLLA

        print(f"shape last_state_encoder_repeated {last_state_encoder_repeated.shape}")

        input_decoder = torch.cat((last_state_encoder_repeated, y), dim=-1)

        print(f"shape input_decoder {input_decoder.shape}")
        output_decoder, _ = self.GRU_decoder(input_decoder, state_encoder)

        print(f"shape output_decoder {output_decoder.shape}")
        return self.to_logits(output_decoder).transpose(1,2)
        
    
    def configure_optimizers(self):# learning rate = 1x10^-4; beta1 =0.9; beta2 = 0.995 dal paper
        return torch.optim.Adam(self.parameters(), lr=1e-4, betas=(0.9, 0.995))
    
        pass

    def training_step(self, batch, batch_idx):#DOVREBBE RIMANERE COSì
        x, y = batch
        y_pred = self(x, y)
        loss = F.cross_entropy(y_pred, y, ignore_index = self.pad_id)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):#DOVREBBE RIMANERE COSì
        x, y = batch
        y_pred = self.inference(x)  #[batch_size, max_eos_found]
        y_pred = F.pad(y_pred, (0, self.max_len_answer - y_pred.shape[1]), mode='constant', value=self.pad_id) #[batch_size, max_len_answer]
        self.accuracy.update(y_pred, y)

    def test_step(self, batch, batch_idx):#DOVREBBE RIMANERE COSì
        x, y = batch
        y_pred = self.inference(x)  #[batch_size, max_eos_found]
        y_pred = F.pad(y_pred, (0, self.max_len_answer - y_pred.shape[1]), mode='constant', value=self.pad_id) #[batch_size, max_len_answer]
        self.accuracy.update(y_pred, y)

    def validation_epoch_end(self, outputs):#DOVREBBE RIMANERE COSì
        self.log('accuracy_epoch', self.accuracy.compute())
        self.accuracy.reset()


    
    

In [None]:
dm.setup("")
train_loader = dm.train_dataloader()
x = next(iter(train_loader))


In [None]:
x

[tensor([[ 1, 48, 44,  ...,  0,  0,  0],
         [ 1, 48, 44,  ...,  0,  0,  0],
         [ 1, 48, 44,  ...,  0,  0,  0],
         ...,
         [ 1, 48, 44,  ...,  0,  0,  0],
         [ 1, 48, 44,  ...,  0,  0,  0],
         [ 1, 48, 44,  ...,  0,  0,  0]]),
 tensor([[ 1, 20,  2,  ...,  0,  0,  0],
         [ 1, 16, 15,  ...,  0,  0,  0],
         [ 1, 19, 22,  ...,  0,  0,  0],
         ...,
         [ 1, 12, 20,  ...,  0,  0,  0],
         [ 1, 18, 18,  ...,  0,  0,  0],
         [ 1, 12, 19,  ...,  0,  0,  0]])]

In [None]:

root_dir = "./training/GRU/checkpoints"
logger_dir = "./training/GRU/tensorboard/logs"
checkpoint_dir = "./training/GRU/checkpoints/gru_seq2seq_checkpoints"

EPOCHS = 2

BATCH_SIZE = 4
EMBEDDING_DIM = 256
NUM_HEADS = 8
HIDDEN_SIZE = 512
DROP_PROB = 0.2
NUM_CELLS = 2
SPECIAL_CHAR_DICT = {'<bos>': vocabulary['<bos>'], '<eos>': vocabulary['<eos>'], '<pad>': vocabulary['<pad>']}


ADDITIONAL_EPOCHS = 5

In [None]:
gru_seq2seq = GRUEncoderDecoder(
    SPECIAL_CHAR_DICT, embedding_dim = EMBEDDING_DIM, hidden_size = HIDDEN_SIZE, 
    dropout = DROP_PROB, vocabulary_size = len(vocabulary), num_cells = NUM_CELLS
    )


{'<bos>': 1, '<eos>': 2, '<pad>': 0}


In [None]:
x[0]

tensor([[ 1, 48, 44,  ...,  0,  0,  0],
        [ 1, 48, 44,  ...,  0,  0,  0],
        [ 1, 48, 44,  ...,  0,  0,  0],
        ...,
        [ 1, 48, 44,  ...,  0,  0,  0],
        [ 1, 48, 44,  ...,  0,  0,  0],
        [ 1, 48, 44,  ...,  0,  0,  0]])

In [None]:
gru_seq2seq(x[0], x[1]).shape

shape x torch.Size([64, 162, 256])
shape y torch.Size([64, 32, 256])
shape output_encoder torch.Size([64, 162, 512])
shape state_encoder torch.Size([2, 64, 512])
shape last_state_encoder torch.Size([64, 512])
shape last_state_encoder_repeated torch.Size([64, 32, 512])
shape input_decoder torch.Size([64, 32, 768])
shape output_decoder torch.Size([64, 32, 512])


torch.Size([64, 58, 32])

In [None]:
F.cross_entropy(gru_seq2seq(x[0], x[1]), x[1])

shape x torch.Size([64, 162, 256])
shape y torch.Size([64, 32, 256])
shape output_encoder torch.Size([64, 162, 512])
shape state_encoder torch.Size([2, 64, 512])
shape last_state_encoder torch.Size([64, 512])
shape last_state_encoder_repeated torch.Size([64, 32, 512])
shape input_decoder torch.Size([64, 32, 768])
shape output_decoder torch.Size([64, 32, 512])


tensor(4.0678, grad_fn=<NllLoss2DBackward0>)

In [None]:
gru_seq2seq.inference(x[0]).shape

torch.Size([64, 32])

In [None]:
"""
logger = TensorBoardLogger(logger_dir, name="GRU_SEQ2SEQ", log_graph=True)
checkpoint_callback = ModelCheckpoint(
    dirpath = checkpoint_dir,
    filename='gru_seq2seq_{epoch:02d}_{step:06d}',
    save_top_k=3,
    monitor='accuracy_epoch',
    mode='max',
    save_last=True
)
callbacks = [checkpoint_callback, TQDMProgressBar(refresh_rate=20)]
trainer = pl.Trainer(default_root_dir=root_dir, accelerator='cpu', devices=1, gradient_clip_val = 0.1, max_epochs = EPOCHS, logger = logger, callbacks = callbacks)
math_dm = Mathematics_DataModule(['algebra__linear_1d'], batch_size = BATCH_SIZE)
trainer.fit(gru_seq2seq, datamodule = math_dm)
"""

'\nlogger = TensorBoardLogger(logger_dir, name="GRU_SEQ2SEQ", log_graph=True)\ncheckpoint_callback = ModelCheckpoint(\n    dirpath = checkpoint_dir,\n    filename=\'gru_seq2seq_{epoch:02d}_{step:06d}\',\n    save_top_k=3,\n    monitor=\'accuracy_epoch\',\n    mode=\'max\',\n    save_last=True\n)\ncallbacks = [checkpoint_callback, TQDMProgressBar(refresh_rate=20)]\ntrainer = pl.Trainer(default_root_dir=root_dir, accelerator=\'cpu\', devices=1, gradient_clip_val = 0.1, max_epochs = EPOCHS, logger = logger, callbacks = callbacks)\nmath_dm = Mathematics_DataModule([\'algebra__linear_1d\'], batch_size = BATCH_SIZE)\ntrainer.fit(gru_seq2seq, datamodule = math_dm)\n'

In [None]:
"""
da fare:

-  utilizzare stage (parametro di setup) per caricare anche un solo dataset se stage = "train" ad esempio 
   (https://colab.research.google.com/drive/1oJrA-Q-neOl1fCQJhIWR_GmxpYaG-cFx?authuser=1#scrollTo=JM57yq7bJS0E)

-  aggiungere predict_step nel pl.LightningModule dove si chiama inference e relativo predict dataloader nel Lightning data module


-  RNN fatte molto bene:
    https://github.com/georgeyiasemis/Recurrent-Neural-Networks-from-scratch-using-PyTorch 
    https://towardsdatascience.com/building-a-lstm-by-hand-on-pytorch-59c02a4ec091
    https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
"""

'\nda fare:\n\n-  utilizzare stage (parametro di setup) per caricare anche un solo dataset se stage = "train" ad esempio \n   (https://colab.research.google.com/drive/1oJrA-Q-neOl1fCQJhIWR_GmxpYaG-cFx?authuser=1#scrollTo=JM57yq7bJS0E)\n\n-  aggiungere predict_step nel pl.LightningModule dove si chiama inference e relativo predict dataloader nel Lightning data module\n\n\n-  RNN fatte molto bene:\n    https://github.com/georgeyiasemis/Recurrent-Neural-Networks-from-scratch-using-PyTorch \n    https://towardsdatascience.com/building-a-lstm-by-hand-on-pytorch-59c02a4ec091\n    https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html\n'

#GRU Modules


da fare:

-  controllare l'architettura GRU (FATTO)

-  controllare teacher forcing (FATTO)



-  utilizzare stage (parametro di setup) per caricare anche un solo dataset se stage = "train" ad esempio 
   (https://colab.research.google.com/drive/1oJrA-Q-neOl1fCQJhIWR_GmxpYaG-cFx?authuser=1#scrollTo=JM57yq7bJS0E)

-  aggiungere predict_step nel pl.LightningModule dove si chiama inference e relativo predict dataloader nel Lightning data module


-  RNN fatte molto bene:
    https://github.com/georgeyiasemis/Recurrent-Neural-Networks-from-scratch-using-PyTorch 
    https://towardsdatascience.com/building-a-lstm-by-hand-on-pytorch-59c02a4ec091
    https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html


In [None]:
class GRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(GRUCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        self.Wx = nn.Linear(input_size, 3*hidden_size, bias=True)
        self.Wh_gate = nn.Linear(hidden_size, 2*hidden_size, bias=True)

        self.Whh = nn.Linear(hidden_size, hidden_size, bias = True)

        
        #nell'implementazione del git c'è anche una funzione reset_parameters

    def forward(self, x, h):

        x_reset, x_update, x_candidate = torch.tensor_split(self.Wx(x), 3, dim=-1)
        """
        print(f"shape h {h.shape}")
        print(f"shape {x.shape}")
        print(f"hidden_size {self.hidden_size}")
        print(f"input_size {self.input_size}")
        """
        #h = h.clone()           #INSERITO

        h_reset, h_update = torch.tensor_split(self.Wh_gate(h), 2, dim=-1)

        reset_gate = torch.sigmoid(x_reset + h_reset)

        update_gate = torch.sigmoid(x_update + h_update)

        h_candidate = torch.tanh(x_candidate + self.Whh(reset_gate * h))   

        h_t = update_gate * h + (1-update_gate) * h_candidate

        return h_t




In [None]:
class GRU(nn.Module):
    
    def __init__ (self, input_size, hidden_size, num_cells=2, device=None):
        assert num_cells>0

        super(GRU, self).__init__()

        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_cells = num_cells
        

        self.GRU_cells = nn.ModuleList(
            [GRUCell(input_size, hidden_size)]+[GRUCell(hidden_size, hidden_size) for _ in range(1, num_cells)])
        
    def forward(self, x, h=None):
        "x è una sequenza [batch, seq_len, embedding_dim]"

        """
        
        """
        batch_size, seq_len, _ = x.shape

        "h [self.num_cells, batch_size, self.hidden_size]"
        #output_states = torch.stack([torch.zeros(batch_size, self.hidden_size) for _ in range(seq_len)], dim = 0)
        output_states = torch.zeros((batch_size, seq_len, self.hidden_size), device=x.device)

        #print(self.device)
        #print(f"output_states device {output_states.device}")

        if(h!=None):
            hidden_states = h
        else:
            #hidden_states = torch.zeros((self.num_cells, batch_size, self.hidden_size), device=x.device)
            hidden_states = [torch.zeros((batch_size, self.hidden_size), device=x.device) for _ in range(self.num_cells)]

        #print(f"hidden_states device {hidden_states.device}")

        
        #hidden_states = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_cells)]

        #hidden_states = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_cells)]
        
        for t in range(seq_len):

            x_t = x[:,t,:]

            #hidden_states = hidden_states.clone()

            hidden_states[0] = self.GRU_cells[0](x_t, hidden_states[0])

            for l in range(1, self.num_cells):
                #hidden_states = hidden_states.clone()

                hidden_states[l] = self.GRU_cells[l](hidden_states[l-1], hidden_states[l])

            #output_states = output_states.clone()

            output_states[:,t,:] = hidden_states[self.num_cells - 1]

        #output_state : [batch_size, seq_len, hidden_size]
        return output_states, hidden_states #(così si prende output_states[:,-1,:] da dare al decoder e tutto output_states per il linear dopo il decoder)
        



In [None]:
""" prove
t = torch.cat([torch.zeros(1, 2, 3) for _ in range(3)], dim = 0)
print(t.shape)
t[1] = torch.ones(2,3)
print(t)
t[-1].shape

#equivalente a 

t = torch.stack([torch.zeros(2, 3) for _ in range(3)], dim = 0)
print(t.shape)
t[1] = torch.ones(2,3)
print(t)
t[-1].shape
"""

' prove\nt = torch.cat([torch.zeros(1, 2, 3) for _ in range(3)], dim = 0)\nprint(t.shape)\nt[1] = torch.ones(2,3)\nprint(t)\nt[-1].shape\n\n#equivalente a \n\nt = torch.stack([torch.zeros(2, 3) for _ in range(3)], dim = 0)\nprint(t.shape)\nt[1] = torch.ones(2,3)\nprint(t)\nt[-1].shape\n'

In [None]:
class GRUEncoderDecoder(pl.LightningModule): #oppure Seq2Seq (informiamoci sui nomi)
    def __init__(self,
        special_idxs: Dict[str, int],
        optimizer_params: dict,
        learning_rate: float=1e-4,
        embedding_dim: float=256,
        hidden_size: int=512,
        vocabulary_size: int=58,
        max_len_question: int=162,
        max_len_answer: int=32,
        num_cells: int=2,
        dropout: float=0.2,
    ):
        super(GRUEncoderDecoder, self).__init__()

        #CONTROLLARE LE DIMENSIONI
        self.save_hyperparameters()

        self.bos_id = special_idxs['<bos>']
        self.eos_id = special_idxs['<eos>']
        self.pad_id = special_idxs['<pad>']

        self.optimizer_params = optimizer_params
        self.learning_rate = learning_rate
        self.embedding_dim = embedding_dim
        self.vocabulary_size = vocabulary_size

        self.token_embedding = nn.Embedding(vocabulary_size, embedding_dim, padding_idx = self.pad_id)

        self.GRU_encoder = GRU(embedding_dim, hidden_size, num_cells) 

        self.GRU_decoder = GRU(embedding_dim + hidden_size, hidden_size, num_cells)

        self.to_logits = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                       nn.ReLU(), 
                                       nn.Dropout(dropout),
                                       nn.Linear(hidden_size, vocabulary_size)) 
        
        self.max_len_question = max_len_question
        self.max_len_answer = max_len_answer

        self.train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocabulary_size, ignore_index=self.pad_id)
        self.val_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocabulary_size, ignore_index=self.pad_id)
        self.test_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocabulary_size, ignore_index=self.pad_id)


        
        #embedding
    #encoder GRU
    #decoder GRU
    #ff per classification
    #decoder dovrebbe poter utilizzare teacher forcing credo -> metodo inference come Transformer

    def inference(self, x):

        #encode and then generate the output token by token greedily

        self.eval()
        with torch.no_grad():
            batch_size = x.size(0)
            x = self.token_embedding(x)

            #print(f"x device {x.device}")

            output_encoder, previous_state = self.GRU_encoder(x)

            #print(f"output_encoder device {output_encoder.device}")
            #print(f"previous_state device {output_encoder.device}")


            last_state_encoder = previous_state[-1].unsqueeze(dim=1)


            output = torch.ones(batch_size, 1, dtype=torch.int64, device = self.device).fill_(self.bos_id)
            done = torch.zeros(batch_size, dtype = torch.uint8, device = self.device)




            for _ in range(self.max_len_answer - 1):
                #last_state_encoder_repeated = last_state_encoder.repeat(current_output.shape[1],1,1).transpose(0,1) 

                    #output.shape[1] è 1, non è un problema visto che nel decoder consideriamo sempre un token

                current_output = output[:,-1].unsqueeze(dim=1)

                current_output_embedding = self.token_embedding(current_output)

                input_decoder = torch.cat((last_state_encoder, current_output_embedding), dim=-1)

                out, previous_state = self.GRU_decoder(input_decoder, previous_state)

                out = self.to_logits(out)

                out = torch.argmax(out[:,[-1],:], dim = -1)

                output = torch.cat([output, out], dim = 1)

                eos_reached = out.squeeze(1) == self.eos_id
                done |= eos_reached
                if done.sum() == batch_size:
                    break

            return output

    def forward(self, x, y):

        x = self.token_embedding(x)

        y = self.token_embedding(y)

        #print(f"shape x {x.shape}")
        #print(f"shape y {y.shape}")


        output_encoder, state_encoder = self.GRU_encoder(x) 

        #print(f"shape output_encoder {output_encoder.shape}")
        #print(f"shape state_encoder {state_encoder.shape}")

        last_state_encoder = state_encoder[-1]

        #print(f"shape last_state_encoder {last_state_encoder.shape}")
        #state_encoder = self.GRU_encoder(x)[:,-1,:]

        last_state_encoder_repeated = last_state_encoder.repeat(y.shape[1],1,1).transpose(0,1) #CONTROLLA

        #print(f"shape last_state_encoder_repeated {last_state_encoder_repeated.shape}")

        input_decoder = torch.cat((last_state_encoder_repeated, y), dim=-1)

        #print(f"shape input_decoder {input_decoder.shape}")
        output_decoder, _ = self.GRU_decoder(input_decoder, state_encoder)

        #print(f"shape output_decoder {output_decoder.shape}")
        return self.to_logits(output_decoder).transpose(1,2)
        
    
    def configure_optimizers(self):# learning rate = 1x10^-4; beta1 =0.9; beta2 = 0.995 dal paper
        betas = self.optimizer_params['betas']
        return torch.optim.Adam(self.parameters(), self.learning_rate, betas)


    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x, y)
        loss = F.cross_entropy(y_pred, y, ignore_index = self.pad_id)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        self.train_accuracy.update(y_pred, y)
        self.log('train_accuracy_epoch', self.train_accuracy.compute(), on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.inference(x)  #[batch_size, max_eos_found]
        y_pred = F.pad(y_pred, (0, self.max_len_answer - y_pred.shape[1]), mode='constant', value=self.pad_id) #[batch_size, max_len_answer]
        self.val_accuracy.update(y_pred, y)
        self.log('val_accuracy_step', self.val_accuracy.compute(), on_step=True, on_epoch=False, prog_bar=True, logger=True)


    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.inference(x)  #[batch_size, max_eos_found]
        y_pred = F.pad(y_pred, (0, self.max_len_answer - y_pred.shape[1]), mode='constant', value=self.pad_id) #[batch_size, max_len_answer]
        self.test_accuracy.update(y_pred, y)
        

    def validation_epoch_end(self, outputs):
        self.log('val_accuracy_epoch', self.val_accuracy.compute(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.val_accuracy.reset()
        
        # Also reset the training accuracy
        self.train_accuracy.reset()

    
    def test_epoch_end(self, outputs):
        self.log('test_accuracy_epoch', self.test_accuracy.compute(), on_epoch=True, prog_bar=True, logger=True)
        self.test_accuracy.reset()


    
    

#NON SOTA (GRU)

In [None]:
"""
dm.setup("")
train_loader = dm.train_dataloader()
x = next(iter(train_loader))
"""

'\ndm.setup("")\ntrain_loader = dm.train_dataloader()\nx = next(iter(train_loader))\n'

In [None]:
#x

In [None]:
vocabulary = v

In [None]:
root_dir = "./training/GRU/checkpoints"
logger_dir = "./training/GRU/tensorboard/logs"
checkpoint_dir = "./training/GRU/checkpoints/gru_seq2seq_checkpoints"


EPOCHS = 2
LEARNING_RATE = 1e-4
BATCH_SIZE = 128
EMBEDDING_DIM = 256

# HIDDEN_SIZE = 2048
HIDDEN_SIZE = 512
DROP_PROB = 0.2
GRADIENT_CLIP_VAL = 0.1
NUM_CELLS = 2
SPECIAL_CHAR_DICT = {'<bos>': vocabulary['<bos>'], '<eos>': vocabulary['<eos>'], '<pad>': vocabulary['<pad>']}
OPTIMIZER_PARAMS = {'betas': (0.9, 0.995)}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
GRU_hyperparams = {
    "special_idxs": SPECIAL_CHAR_DICT,
    "optimizer_params": OPTIMIZER_PARAMS,
    "learning_rate": LEARNING_RATE,
    "embedding_dim": EMBEDDING_DIM,
    "hidden_size": HIDDEN_SIZE,
    "vocabulary_size": len(vocabulary),
    "num_cells": NUM_CELLS,
    "dropout": DROP_PROB#,
    #"device" : device
    }

logger = TensorBoardLogger(logger_dir, name="GRU")
checkpoint_callback = ModelCheckpoint(
    dirpath = checkpoint_dir,
    filename='GRU_{epoch:02d}_{step:06d}_{val_accuracy_epoch:.3f}',
    save_top_k=6,
    monitor='val_accuracy_epoch',
    mode='max',
    verbose=True,
    save_last=True
)
callbacks = [checkpoint_callback, TQDMProgressBar(refresh_rate=20)]
trainer_hyperparams = {
    "default_root_dir": root_dir,
    "accelerator": "auto",
    "devices": 1,
    "precision": 16, # ADDED
    "log_every_n_steps": 10,
    # "val_check_interval": 0.5, # validation step called 2 times during a training epoch
    "val_check_interval": 1.0, 
    "gradient_clip_val": GRADIENT_CLIP_VAL,
    "max_epochs": EPOCHS,
    "logger": logger,
    "callbacks": callbacks,
    # "deterministic": True,
}

modules = ['algebra__linear_1d', 'probability__swr_p_level_set', 'numbers__is_prime']
# modules = ['algebra__linear_1d']
math_dm = Mathematics_DataModule(modules, difficulty=['train-easy'], batch_size=BATCH_SIZE)

In [None]:
gru_seq2seq = GRUEncoderDecoder(**GRU_hyperparams)


trainer = Trainer(**trainer_hyperparams)

trainer.fit(gru_seq2seq, datamodule=math_dm)

INFO:pytorch_lightning.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
INFO:pytorch_lightning.callbacks.model_summary:
  | Name            | Type               | Params
-------------------------------------------------------
0 | token_embedding | Embedding          | 14.8 K
1 | GRU_encoder     | GRU                | 2.8 M 
2 | GRU_decoder     | GRU                | 3.5 M 
3 | to_logits       | Sequential         | 292 K 
4 | train_accuracy  | MulticlassAccuracy | 0     
5 | val_accuracy    | MulticlassAccur

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 1: 'val_accuracy_epoch' reached 0.42857 (best 0.42857), saving model to '/content/drive/.shortcut-targets-by-id/1IS7xxoH06-zPLbTk07CAGSmtMTFE85h_/Deep_Learning_Project/training/GRU/checkpoints/gru_seq2seq_checkpoints/GRU_epoch=00_step=000001_val_accuracy_epoch=0.429.ckpt' as top 6


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 2: 'val_accuracy_epoch' reached 0.28571 (best 0.42857), saving model to '/content/drive/.shortcut-targets-by-id/1IS7xxoH06-zPLbTk07CAGSmtMTFE85h_/Deep_Learning_Project/training/GRU/checkpoints/gru_seq2seq_checkpoints/GRU_epoch=01_step=000002_val_accuracy_epoch=0.286.ckpt' as top 6
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.


In [None]:
x[0]

NameError: ignored

In [None]:
gru_seq2seq(x[0], x[1]).shape

In [None]:
F.cross_entropy(gru_seq2seq(x[0], x[1]), x[1])

In [None]:
gru_seq2seq.inference(x[0]).shape

In [None]:
"""
logger = TensorBoardLogger(logger_dir, name="GRU_SEQ2SEQ", log_graph=True)
checkpoint_callback = ModelCheckpoint(
    dirpath = checkpoint_dir,
    filename='gru_seq2seq_{epoch:02d}_{step:06d}',
    save_top_k=3,
    monitor='accuracy_epoch',
    mode='max',
    save_last=True
)
callbacks = [checkpoint_callback, TQDMProgressBar(refresh_rate=20)]
trainer = pl.Trainer(default_root_dir=root_dir, accelerator='cpu', devices=1, gradient_clip_val = 0.1, max_epochs = EPOCHS, logger = logger, callbacks = callbacks)
math_dm = Mathematics_DataModule(['algebra__linear_1d'], batch_size = BATCH_SIZE)
trainer.fit(gru_seq2seq, datamodule = math_dm)
"""

In [None]:
"""
da fare:

-  utilizzare stage (parametro di setup) per caricare anche un solo dataset se stage = "train" ad esempio 
   (https://colab.research.google.com/drive/1oJrA-Q-neOl1fCQJhIWR_GmxpYaG-cFx?authuser=1#scrollTo=JM57yq7bJS0E)

-  aggiungere predict_step nel pl.LightningModule dove si chiama inference e relativo predict dataloader nel Lightning data module


-  RNN fatte molto bene:
    https://github.com/georgeyiasemis/Recurrent-Neural-Networks-from-scratch-using-PyTorch 
    https://towardsdatascience.com/building-a-lstm-by-hand-on-pytorch-59c02a4ec091
    https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
"""