In [None]:
!pip install pytorch-lightning

import os
import glob
import math
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Transformer
from torch.utils.data import Dataset, DataLoader
import numpy as np
import sentencepiece as spm
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import pytorch_lightning as pl
from sklearn.utils import shuffle
############################## DATASET

class SaturdayDataset(Dataset):
    """Fixed-length dataset of already-tokenized (source, target) id sequences.

    Each item is a pair of ``torch.long`` tensors:

    * source: the source ids, right-padded with ``tokenizer.pad_id()``;
    * target: ``[bos] + target ids + [eos]``, right-padded the same way.

    Pairs whose source, or whose target including BOS/EOS, would be longer
    than ``max_len`` are dropped once, at construction time.

    Args:
        source_sentences: sequence of token-id sequences (list, tuple or
            array-like) for the encoder side.
        target_sentences: sequence of token-id sequences for the decoder side,
            aligned index-by-index with ``source_sentences``.
        max_len: length every returned sequence is padded to.
        tokenizer: object exposing ``pad_id()``, ``bos_id()`` and ``eos_id()``.
    """

    def __init__(self, source_sentences, target_sentences, max_len, tokenizer):
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.filter_invalid_pairs()

    def __len__(self):
        """Return the number of pairs that survived length filtering."""
        return len(self.source_sentences)

    def pad_sequence(self, original_source, original_target):
        """Return padded copies of one pair as plain Python lists.

        The inputs are never mutated. Sequences longer than ``max_len`` are
        returned unpadded and untruncated.

        Args:
            original_source: token ids of the source sentence.
            original_target: token ids of the target sentence.

        Returns:
            ``(source, target)`` where target is wrapped in BOS/EOS and both
            are right-padded with the pad id up to ``max_len``.
        """
        pad_id = self.tokenizer.pad_id()

        # list(...) instead of [:] so tuples and NumPy arrays are accepted too;
        # slicing an array yields a view without list semantics.
        source = list(original_source)
        target = [self.tokenizer.bos_id(), *original_target, self.tokenizer.eos_id()]

        # A non-positive repeat count yields an empty list, so over-long
        # sequences are simply left as they are.
        source += [pad_id] * (self.max_len - len(source))
        target += [pad_id] * (self.max_len - len(target))
        return source, target

    def filter_invalid_pairs(self):
        """Drop every pair that cannot fit in ``max_len`` positions."""
        valid_sources = []
        valid_targets = []
        for src, trg in zip(self.source_sentences, self.target_sentences):
            # The target gains two tokens (BOS and EOS) when it is padded.
            # Comparing raw lengths avoids building padded copies just to
            # measure them.
            if len(src) <= self.max_len and len(trg) + 2 <= self.max_len:
                valid_sources.append(src)
                valid_targets.append(trg)

        # Replace the stored sequences with the valid ones only.
        self.source_sentences = valid_sources
        self.target_sentences = valid_targets

    def __getitem__(self, idx):
        """Return the padded ``(source, target)`` pair at ``idx`` as long tensors."""
        src_indexes, trg_indexes = self.pad_sequence(
            self.source_sentences[idx], self.target_sentences[idx]
        )
        # Explicit dtype: the ids are consumed as indices, so they must not
        # depend on what torch infers from the element types.
        return (
            torch.tensor(src_indexes, dtype=torch.long),
            torch.tensor(trg_indexes, dtype=torch.long),
        )
################################# MODEL
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with learned projections.

    Args:
        d_model: width of the input and output features; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads; each head works on
            ``d_model // num_heads`` features.
    """

    def __init__(self, d_model, num_heads) -> None:
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Input projections for queries, keys and values.
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        # Projection applied after the heads are concatenated again.
        self.W_out = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Attend over ``v`` using softmax(q k^T / sqrt(d_k)).

        Args:
            q, k, v: tensors of shape (batch, heads, seq_len, head_dim).
            mask: optional tensor broadcastable to the score matrix; positions
                where it equals 0 receive a score of -1e9 before the softmax.

        Returns:
            Tensor of shape (batch, heads, query_len, head_dim).
        """
        scale = math.sqrt(q.size(-1))

        # (batch, heads, query_len, key_len)
        logits = torch.matmul(q, k.transpose(-2, -1)) / scale

        if mask is not None:
            # The mask may have been created on another device than q.
            blocked = mask.to(q.device) == 0
            logits = logits.masked_fill(blocked, -1e9)

        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, v)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, head_dim)."""
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Project q/k/v, attend per head, merge the heads and project again.

        Args:
            q, k, v: tensors of shape (batch, seq_len, d_model).
            mask: optional attention mask, see
                ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch, query_len, d_model).
        """
        batch_size = q.size(0)

        q_heads = self._split_heads(self.W_Q(q), batch_size)
        k_heads = self._split_heads(self.W_K(k), batch_size)
        v_heads = self._split_heads(self.W_V(v), batch_size)

        context = self.scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)

        # Undo the head split: (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        n_batch, _, seq_length, _ = context.size()
        merged = context.transpose(1, 2).contiguous().view(n_batch, seq_length, self.d_model)
        return self.W_out(merged)


class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    Computes ``linear2(dropout(relu(linear1(x))))``.

    Args:
        d_model: input and output feature width.
        d_ff: width of the hidden layer.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, project back to ``d_model``."""
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)


class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each post-normalized.

    Both sub-layers follow the same pattern: dropout on the sub-layer output,
    residual addition, then LayerNorm.

    Args:
        d_model: feature width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout probability used throughout the layer.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1) -> None:
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to self-attention."""
        # Self-attention sub-layer with residual connection.
        attended = self.dropout1(self.self_attn(x, x, x, mask))
        x = self.norm1(x + attended)

        # Feed-forward sub-layer with residual connection.
        transformed = self.dropout2(self.feed_forward(x))
        return self.norm2(x + transformed)


class Encoder(nn.Module):
    """Token embedding + sinusoidal positions followed by a stack of encoder layers.

    Args:
        d_model: feature width.
        num_heads: attention heads per layer.
        d_ff: hidden width of each feed-forward sub-layer.
        num_layers: number of stacked ``EncoderLayer`` blocks.
        input_vocab_size: size of the source vocabulary.
        max_len: number of positions in the positional table; inputs must not
            be longer than this.
        dropout: dropout probability.
    """

    def __init__(self, d_model, num_heads, d_ff, num_layers, input_vocab_size, max_len=512, dropout=0.1):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_vocab_size, d_model)
        # Stored as a frozen Parameter so it follows the module across devices
        # and is saved in the state dict without ever being trained.
        self.positional_encoding = nn.Parameter(self._get_positional_encoding(max_len, d_model), requires_grad=False)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)

    def _get_positional_encoding(self, max_len, d_model):
        """Build the sinusoidal position table of shape (1, max_len, d_model).

        Even feature columns hold sines and odd columns hold cosines of
        ``position / 10000 ** (2i / d_model)``. Works for odd ``d_model`` too.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so div_term is trimmed; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, x, mask=None):
        """Encode a batch of token ids.

        Args:
            x: integer tensor of token ids; ``x.size(1)`` is the sequence length.
            mask: optional attention mask passed to every layer.

        Returns:
            The output of the last encoder layer.
        """
        # Only the first seq_len rows of the position table are needed.
        x = self.embedding(x) + self.positional_encoding[:, :x.size(1), :]
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x


class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention and feed-forward.

    Every sub-layer is followed by dropout, a residual addition and LayerNorm.

    Args:
        d_model: feature width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout probability used throughout the layer.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Run the block.

        Args:
            x: decoder-side features.
            enc_output: encoder output used as keys and values for
                cross-attention.
            src_mask: mask applied in cross-attention.
            tgt_mask: mask applied in decoder self-attention.
        """
        # Self-attention over the decoder sequence.
        self_context = self.dropout1(self.self_attn(x, x, x, tgt_mask))
        x = self.norm1(x + self_context)

        # Cross-attention: decoder queries, encoder keys/values.
        cross_context = self.dropout2(self.cross_attn(x, enc_output, enc_output, src_mask))
        x = self.norm2(x + cross_context)

        # Position-wise feed-forward.
        transformed = self.dropout3(self.feed_forward(x))
        return self.norm3(x + transformed)


class Decoder(nn.Module):
    """Token embedding + sinusoidal positions, decoder layers and a vocabulary projection.

    Args:
        d_model: feature width.
        num_heads: attention heads per layer.
        d_ff: hidden width of each feed-forward sub-layer.
        num_layers: number of stacked ``DecoderLayer`` blocks.
        output_vocab_size: size of the target vocabulary (also the width of
            the returned logits).
        max_len: number of positions in the positional table; inputs must not
            be longer than this.
        dropout: dropout probability.
    """

    def __init__(self, d_model, num_heads, d_ff, num_layers, output_vocab_size, max_len=512, dropout=0.1):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_vocab_size, d_model)
        # Stored as a frozen Parameter so it follows the module across devices
        # and is saved in the state dict without ever being trained.
        self.positional_encoding = nn.Parameter(self._get_positional_encoding(max_len, d_model), requires_grad=False)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(d_model, output_vocab_size)

    def _get_positional_encoding(self, max_len, d_model):
        """Build the sinusoidal position table of shape (1, max_len, d_model).

        Even feature columns hold sines and odd columns hold cosines of
        ``position / 10000 ** (2i / d_model)``. Works for odd ``d_model`` too.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so div_term is trimmed; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: integer tensor of target token ids; ``x.size(1)`` is the
                sequence length.
            enc_output: encoder output passed to every layer.
            src_mask: mask forwarded to each layer for cross-attention.
            tgt_mask: mask forwarded to each layer for self-attention.

        Returns:
            Unnormalized logits over the target vocabulary, one vector per
            target position.
        """
        # Only the first seq_len rows of the position table are needed.
        x = self.embedding(x) + self.positional_encoding[:, :x.size(1), :]
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        x = self.fc_out(x)
        return x


class FridayTransformer(nn.Module):
    """Encoder-decoder Transformer returning per-position target-vocabulary logits.

    Args:
        embed_dim: feature width shared by encoder and decoder.
        src_vocab_size: size of the source vocabulary.
        target_vocab_size: size of the target vocabulary.
        num_layers: number of layers in both the encoder and the decoder.
        d_ff: hidden width of the feed-forward sub-layers.
        n_heads: number of attention heads.
        pad_id: token id treated as padding when building the source mask.
            Defaults to 0, the value that used to be hard-coded.
    """

    def __init__(self, embed_dim, src_vocab_size, target_vocab_size, num_layers=6, d_ff=2048, n_heads=8,
                 pad_id=0) -> None:
        super(FridayTransformer, self).__init__()
        self.target_vocab_size = target_vocab_size
        self.encoder = Encoder(embed_dim, n_heads, d_ff, num_layers, src_vocab_size)
        self.decoder = Decoder(embed_dim, n_heads, d_ff, num_layers, target_vocab_size)
        self.num_heads = n_heads
        self.pad_id = pad_id

    def generate_mask(self, src, trg):
        """Build the attention masks for one batch.

        Args:
            src: (batch, src_len) tensor of source token ids.
            trg: (batch, trg_len) tensor of target token ids.

        Returns:
            ``(src_mask, trg_mask)``:

            * ``src_mask`` is a boolean tensor of shape
              (batch, 1, 1, src_len), False where ``src`` equals the pad id.
            * ``trg_mask`` is a lower-triangular matrix of ones of shape
              (batch, 1, trg_len, trg_len), so position i may only attend to
              positions <= i. Target padding is not masked here.
        """
        src_mask = (src != self.pad_id).unsqueeze(1).unsqueeze(2)
        batch_size, trg_len = trg.shape
        # Create the causal mask directly on the input's device rather than on
        # the CPU, so it does not have to be copied over on every use.
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        trg_mask = causal.expand(batch_size, 1, trg_len, trg_len)
        return src_mask, trg_mask

    def forward(self, src, trg):
        """Encode ``src``, decode ``trg`` against it and return the logits."""
        src_mask, trg_mask = self.generate_mask(src, trg)
        enc_out = self.encoder(src, src_mask)

        # src_mask is reused in the decoder so cross-attention ignores padding.
        outputs = self.decoder(trg, enc_out, src_mask, trg_mask)
        return outputs

################################# LIGHTNING
class LightningTransformer(pl.LightningModule):
    """Lightning wrapper that trains ``FridayTransformer`` with teacher forcing.

    The module owns its data: it receives raw (source, target) pairs, builds
    ``SaturdayDataset`` instances in ``setup`` and exposes the train and
    validation dataloaders itself.

    Args:
        embed_dim: model feature width.
        vocab_size: vocabulary size, used for both source and target.
        num_layers: number of encoder layers and of decoder layers.
        n_heads: number of attention heads.
        learning_rate: learning rate passed to Adam.
        sp_model: tokenizer exposing ``pad_id()``, ``bos_id()`` and
            ``eos_id()``; its pad id is ignored by the loss.
        max_len: fixed sequence length produced by the datasets.
        train_data: iterable of ``(source_ids, target_ids)`` pairs.
        val_data: iterable of ``(source_ids, target_ids)`` pairs.
        batch_size: batch size of both dataloaders.
    """

    def __init__(
            self,
            embed_dim,
            vocab_size,
            num_layers,
            n_heads,
            learning_rate,
            sp_model,
            max_len,
            train_data=None,
            val_data=None,
            batch_size=50
    ):
        super().__init__()
        self.dataset = None
        self.val_dataset = None
        self.model = FridayTransformer(embed_dim=embed_dim,
                                         src_vocab_size=vocab_size,
                                         target_vocab_size=vocab_size,
                                         num_layers=num_layers,
                                         n_heads=n_heads
                                         )
        # Padded target positions do not contribute to the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=sp_model.pad_id())
        self.learning_rate = learning_rate
        self.sp_model = sp_model
        self.max_len = max_len
        self.inputs = None
        self.targets = None
        self.val_inputs = None
        self.val_targets = None
        self.train_data = train_data
        self.val_data = val_data
        self.batch_size = batch_size
        self.tokenizer = sp_model
        self.pad_id = sp_model.pad_id()

    def forward(self, src, trg_input):
        """Return the logits of the wrapped model for ``(src, trg_input)``."""
        return self.model(src, trg_input)

    def _teacher_forced_loss(self, batch):
        """Compute the cross-entropy loss of one ``(src, trg)`` batch.

        The decoder is fed the target without its last token and is trained
        to predict the target without its first token, i.e. the next token at
        every position.
        """
        src, trg = batch
        # src and trg have shape (batch_size, sequence_length).
        src = src.to(self.device)
        trg = trg.to(self.device)

        decoder_input = trg[:, :-1]
        expected = trg[:, 1:]

        logits = self(src, decoder_input)
        # Flatten to (batch * positions, vocab) vs (batch * positions,).
        return self.criterion(logits.reshape(-1, logits.shape[-1]), expected.reshape(-1))

    def training_step(self, batch, batch_idx):
        """Return the training loss of ``batch`` and log it as ``train_loss``."""
        loss = self._teacher_forced_loss(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss of ``batch`` and log it as ``val_loss``."""
        with torch.no_grad():
            loss = self._teacher_forced_loss(batch)
            self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

            return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def _require_data(self):
        """Raise ``ValueError`` when training or validation data is missing."""
        if self.train_data is None:
            raise ValueError("Train Data not provided.")
        if self.val_data is None:
            raise ValueError("Validation Data not provided.")

    @staticmethod
    def _split_pairs(pairs):
        """Split an iterable of ``(source, target)`` pairs into two lists."""
        inputs = [pair[0] for pair in pairs]
        targets = [pair[1] for pair in pairs]
        return inputs, targets

    def prepare_data(self):
        """Check that the data was provided.

        No state is assigned here: Lightning runs this hook on a single
        process per node, so attributes set here would be missing on the other
        processes. The per-process state is built in ``setup``.

        Raises:
            ValueError: if ``train_data`` or ``val_data`` is ``None``.
        """
        self._require_data()

    def setup(self, stage=None):
        """Build the training and validation datasets for the ``fit`` stage.

        The pairs are kept in their given order. The training dataloader
        already reshuffles every epoch, and an independent random shuffle in
        each process would give the processes differently ordered datasets.

        Raises:
            ValueError: if ``train_data`` or ``val_data`` is ``None``.
        """
        if stage == 'fit' or stage is None:
            # Validate here as well: prepare_data is not run on every process.
            self._require_data()
            self.inputs, self.targets = self._split_pairs(self.train_data)
            self.val_inputs, self.val_targets = self._split_pairs(self.val_data)
            self.dataset = SaturdayDataset(self.inputs, self.targets, self.max_len, self.sp_model)
            self.val_dataset = SaturdayDataset(self.val_inputs, self.val_targets, self.max_len, self.sp_model)

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=5)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=5)
########################### INIT
# Load the SentencePiece tokenizer shared by the source and target side.
tokenizer_path = '/content/drive/MyDrive/models/tokenizer/model.model'
sp_model = spm.SentencePieceProcessor()
sp_model.Load(tokenizer_path)

# Load the (input, target) pairs.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
# dataset files from a trusted location.
train_data = np.load('/content/drive/MyDrive/dataset/transformer/multitask/dataset.npy', allow_pickle=True)
validation_data = np.load('/content/drive/MyDrive/dataset/transformer/multitask/dataset_validation.npy', allow_pickle=True)

# Model hyper-parameters
embed_dim = 512
vocab_size = sp_model.vocab_size()
num_layers = 3
n_heads = 8
learning_rate = 0.0001
max_len = 350

model = LightningTransformer(
    embed_dim=embed_dim,
    vocab_size=vocab_size,
    num_layers=num_layers,
    n_heads=n_heads,
    learning_rate=learning_rate,
    sp_model=sp_model,
    max_len=max_len,
    train_data=train_data,
    val_data=validation_data,
    batch_size=25
)

# Path of an optional pretrained checkpoint file (despite the name, a file).
model_dir = os.path.join("/content/drive/MyDrive/models/transformer/multitask/", "multitask.ckpt")

if os.path.exists(model_dir):
    # Best effort: a checkpoint that cannot be loaded is reported and training
    # continues from the current weights.
    try:
        pretrained = torch.load(model_dir, map_location="cuda" if torch.cuda.is_available() else "cpu")
        # strict=False tolerates missing or unexpected keys in the checkpoint.
        model.load_state_dict(pretrained["state_dict"], strict=False)
        print("Il modello è stato caricato correttamente.")
    except Exception as e:
        print(f"Errore nel caricare il modello: {e}")
else:
    print("Il file del modello non esiste nella posizione specificata.")

checkpoint_dir = "/content/drive/MyDrive/models/transformer/multitask/checkpoints/"

# Callback configuration
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",  # Rank checkpoints by validation loss
    dirpath=checkpoint_dir,
    filename="transformer-{epoch:02d}-{val_loss:.2f}",  # File name pattern
    save_top_k=10,  # Keep only the 10 best checkpoints
    every_n_epochs=1  # Save after every epoch
)

early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=10,
    verbose=True,
    mode="min"
)

trainer = pl.Trainer(
    max_epochs=20,
    # Fall back to the CPU when no CUDA device is available, consistently with
    # the map_location used when loading the pretrained weights above.
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    callbacks=[checkpoint_callback, early_stopping_callback]
)

# Find every .ckpt file in the checkpoint directory
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))

if checkpoint_files:
    # Resume from the most recently modified checkpoint
    latest_checkpoint = max(checkpoint_files, key=os.path.getmtime)
    print(f"Checkpoint trovato: Riprendo l'allenamento da {latest_checkpoint}")
    trainer.fit(model, ckpt_path=latest_checkpoint)
else:
    print("Checkpoint non trovato: Parto da zero.")
    trainer.fit(model)