In [None]:
!pip install pytorch-lightning

import os
import glob
import math
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Transformer
from torch.utils.data import Dataset, DataLoader
import numpy as np
import sentencepiece as spm
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import pytorch_lightning as pl
from sklearn.utils import shuffle
############################## DATASET

class SaturdayDataset(Dataset):
    """Pairs of tokenized source/target sequences padded to a fixed length.

    Each item is ``(source, target)``: two 1-D ``torch.long`` tensors of
    length ``max_len``.
    - The source is right-padded with ``tokenizer.pad_id()``.
    - The target is wrapped as ``[bos] + target + [eos]`` and then
      right-padded.

    Pairs that would not fit are dropped at construction time rather than
    truncated. A pair does not fit when the source is longer than
    ``max_len`` or the target is longer than ``max_len - 2``.

    Args:
        source_sentences: already-tokenized source sequences (lists, tuples
            or 1-D numpy arrays of ints).
        target_sentences: matching tokenized target sequences.
        max_len: fixed length of every returned tensor.
        tokenizer: object exposing ``bos_id()``, ``eos_id()`` and ``pad_id()``
            (e.g. a SentencePiece processor).
    """

    def __init__(self, source_sentences, target_sentences, max_len, tokenizer):
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.filter_invalid_pairs()

    def __len__(self):
        return len(self.source_sentences)

    def pad_sequence(self, original_source, original_target):
        """Return padded copies of one pair; the inputs are never modified.

        Returns:
            ``(source, target)`` as Python lists. Each is padded up to
            ``max_len``. A sequence already longer than ``max_len`` is
            returned unpadded and untruncated.
        """
        pad = self.tokenizer.pad_id()
        # list() copies the input and also accepts tuples / numpy arrays,
        # which have no .insert/.append methods.
        source = list(original_source)
        target = [self.tokenizer.bos_id(), *original_target, self.tokenizer.eos_id()]

        source += [pad] * max(0, self.max_len - len(source))
        target += [pad] * max(0, self.max_len - len(target))
        return source, target

    def filter_invalid_pairs(self):
        """Keep only the pairs whose padded form fits in ``max_len``."""
        # +2 accounts for the BOS/EOS tokens pad_sequence adds to the target.
        kept = [
            (src, trg)
            for src, trg in zip(self.source_sentences, self.target_sentences)
            if len(src) <= self.max_len and len(trg) + 2 <= self.max_len
        ]
        self.source_sentences = [src for src, _ in kept]
        self.target_sentences = [trg for _, trg in kept]

    def __getitem__(self, idx):
        src_indexes, trg_indexes = self.pad_sequence(
            self.source_sentences[idx], self.target_sentences[idx]
        )
        # Explicit long dtype: nn.Embedding and CrossEntropyLoss targets need
        # int64, and numpy inputs could otherwise yield int32 tensors.
        return (torch.tensor(src_indexes, dtype=torch.long),
                torch.tensor(trg_indexes, dtype=torch.long))
################################# MODEL
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position vectors to embeddings, then apply dropout.

    The table is stored in the ``pos_embedding`` buffer with shape
    ``(maxlen, 1, emb_size)``. Even feature columns hold ``sin`` and odd
    columns hold ``cos`` of ``position * 10000 ** (-2i / emb_size)``.

    Inputs must be sequence-first, ``(seq_len, batch, emb_size)``: the table
    is sliced by dimension 0 of the input and broadcast over dimension 1.
    Callers holding batch-first ``(batch, seq, emb)`` tensors must transpose
    before and after the call. Otherwise the batch index is used as the
    position.

    Args:
        emb_size: embedding width (the table assumes it is even).
        dropout: dropout probability applied after the addition.
        maxlen: longest sequence length the table covers.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()
        inv_freq = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        positions = torch.arange(0, maxlen).unsqueeze(1)
        angles = positions * inv_freq

        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = angles.sin()
        table[:, 1::2] = angles.cos()

        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer: saved in state_dict and moved with .to(),
        # but not trained.
        self.register_buffer('pos_embedding', table.unsqueeze(1))

    def forward(self, token_embedding: Tensor):
        seq_len = token_embedding.size(0)
        encoded = token_embedding + self.pos_embedding[:seq_len]
        return self.dropout(encoded)

class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``.

    Token ids of any integer dtype are accepted; they are cast to ``long``
    before the lookup. The output has the input's shape plus a trailing
    ``emb_size`` dimension.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        scale = math.sqrt(self.emb_size)
        vectors = self.embedding(tokens.long())
        return vectors * scale

# Seq2Seq Network
class SaturdayTransformer(nn.Module):
    """Encoder-decoder (seq2seq) Transformer over token ids.

    Wraps ``torch.nn.Transformer`` with ``batch_first=True``. Token tensors
    given to :meth:`forward`, :meth:`encode` and :meth:`decode` are
    ``(batch, seq_len)``. :meth:`forward` returns logits of shape
    ``(batch, tgt_seq_len, tgt_vocab_size)``.

    Args:
        num_encoder_layers: depth of the encoder stack.
        num_decoder_layers: depth of the decoder stack.
        emb_size: model width (``d_model``); must be divisible by ``nhead``.
        nhead: number of attention heads.
        src_vocab_size: vocabulary size of the source embedding.
        tgt_vocab_size: vocabulary size of the target embedding and of the
            output projection.
        dim_feedforward: hidden size of the position-wise feed-forward nets.
        dropout: dropout used by the transformer and the positional encoding.
        device: fallback device for masks built without a reference tensor.
            Masks built from inputs always follow the inputs' device.
        pad_idx: token id treated as padding when building key-padding
            masks. It must equal the tokenizer's pad id and the loss's
            ``ignore_index``.

    NOTE: earlier versions produced an inverted causal mask and added
    positional encodings along the batch axis. Weights trained with those
    versions do not match this model's behaviour and should be retrained.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1,
                 device=None,
                 pad_idx: int = 0,
                 ):
        super(SaturdayTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first=True
                                       )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)
        self.device = device
        self.pad_idx = pad_idx

    def _embed(self, tokens: Tensor, embedding: nn.Module) -> Tensor:
        """Embed ``(batch, seq)`` token ids and add positional encodings.

        ``PositionalEncoding`` slices its table by dimension 0 of its input
        (sequence-first). The batch-first embeddings are therefore
        transposed to ``(seq, batch, emb)`` for the addition and then
        transposed back.
        """
        emb = embedding(tokens).transpose(0, 1)
        return self.positional_encoding(emb).transpose(0, 1)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return a ``(sz, sz)`` boolean causal mask for decoder self-attention.

        ``nn.Transformer`` reads ``True`` in a boolean mask as "may NOT
        attend". ``True`` is therefore placed strictly above the diagonal:
        position ``i`` may attend to ``0..i`` and never to a future position.

        Args:
            sz: target sequence length.
            device: device for the mask; defaults to ``self.device``.
        """
        if device is None:
            device = self.device
        return torch.triu(torch.ones((sz, sz), device=device), diagonal=1).bool()

    def create_mask(self, src, tgt):
        """Build the four masks consumed by ``nn.Transformer``.

        Args:
            src: ``(batch, src_seq_len)`` source token ids.
            tgt: ``(batch, tgt_seq_len)`` decoder-input token ids.

        Returns:
            A tuple ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
            - ``src_mask``: all-``False`` ``(src_seq_len, src_seq_len)``
              mask, i.e. no restriction in the encoder.
            - ``tgt_mask``: causal ``(tgt_seq_len, tgt_seq_len)`` mask.
            - ``src_padding_mask`` / ``tgt_padding_mask``:
              ``(batch, seq_len)`` masks that are ``True`` where the token
              equals ``self.pad_idx``.

            Every mask is created on the same device as its input.
        """
        src_seq_len = src.shape[1]
        tgt_seq_len = tgt.shape[1]

        tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len, device=tgt.device)
        src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=src.device)

        src_padding_mask = src == self.pad_idx
        tgt_padding_mask = tgt == self.pad_idx

        return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

    def forward(self, src: Tensor, tgt: Tensor):
        """Return next-token logits ``(batch, tgt_seq_len, tgt_vocab_size)``."""
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = self.create_mask(src, tgt)

        src_emb = self._embed(src, self.src_tok_emb)
        tgt_emb = self._embed(tgt, self.tgt_tok_emb)

        outs = self.transformer(src_emb, tgt_emb,
                                src_mask=src_mask,
                                tgt_mask=tgt_mask,
                                memory_mask=None,
                                src_key_padding_mask=src_padding_mask,
                                tgt_key_padding_mask=tgt_padding_mask,
                                # Cross-attention must ignore source padding too.
                                memory_key_padding_mask=src_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor, src_padding_mask=None):
        """Run the encoder on ``(batch, src_seq_len)`` ids and return the memory.

        ``src_padding_mask`` is optional; pass ``src == pad_idx`` to ignore
        padded positions.
        """
        return self.transformer.encoder(self._embed(src, self.src_tok_emb), src_mask,
                                        src_key_padding_mask=src_padding_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor,
               tgt_padding_mask=None, memory_key_padding_mask=None):
        """Run the decoder on ``(batch, tgt_seq_len)`` ids against ``memory``.

        ``tgt_mask`` is normally ``generate_square_subsequent_mask(tgt_len)``.
        Both padding masks are optional.
        """
        return self.transformer.decoder(self._embed(tgt, self.tgt_tok_emb), memory,
                                        tgt_mask,
                                        tgt_key_padding_mask=tgt_padding_mask,
                                        memory_key_padding_mask=memory_key_padding_mask)

################################# LIGHTNING
class LightningTransformer(pl.LightningModule):
    """Lightning module that trains a :class:`SaturdayTransformer` seq2seq model.

    Args:
        embed_dim: model width, passed as ``emb_size``.
        vocab_size: shared source/target vocabulary size.
        num_layers: number of encoder layers and of decoder layers.
        n_heads: number of attention heads.
        learning_rate: Adam learning rate.
        sp_model: SentencePiece processor. It supplies ``pad_id()`` for the
            loss's ``ignore_index`` and is handed to :class:`SaturdayDataset`.
        max_len: fixed sequence length of every dataset item.
        train_data: indexable ``(source_ids, target_ids)`` pairs for
            training; required.
        val_data: the same for validation; required.
        batch_size: batch size of both dataloaders.
        num_workers: worker processes per dataloader.
    """

    def __init__(
            self,
            embed_dim,
            vocab_size,
            num_layers,
            n_heads,
            learning_rate,
            sp_model,
            max_len,
            train_data=None,
            val_data=None,
            batch_size=50,
            num_workers=5,
    ):
        super().__init__()
        self.dataset = None
        self.val_dataset = None
        # self.device is the module's initial device here (normally CPU),
        # not the device training will later run on.
        # NOTE(review): SaturdayTransformer builds its padding masks from a
        # fixed pad token id (0 by default), while the loss below ignores
        # sp_model.pad_id(). Confirm the tokenizer uses pad_id=0, otherwise
        # the two disagree.
        self.model = SaturdayTransformer(emb_size=embed_dim,
                                         src_vocab_size=vocab_size,
                                         tgt_vocab_size=vocab_size,
                                         num_encoder_layers=num_layers,
                                         num_decoder_layers=num_layers,
                                         nhead=n_heads,
                                         device=self.device
                                         )
        self.criterion = nn.CrossEntropyLoss(ignore_index=sp_model.pad_id())
        self.learning_rate = learning_rate
        self.sp_model = sp_model
        self.max_len = max_len
        self.inputs = None
        self.targets = None
        self.val_inputs = None
        self.val_targets = None
        self.train_data = train_data
        self.val_data = val_data
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.tokenizer = sp_model
        self.pad_id = sp_model.pad_id()

    def forward(self, src, trg_input):
        return self.model(src, trg_input)

    def _shared_step(self, batch):
        """Return the teacher-forced cross-entropy loss of one ``(src, trg)`` batch.

        ``src`` and ``trg`` are ``(batch, seq_len)`` token ids. The decoder
        reads ``trg[:, :-1]`` and is scored against ``trg[:, 1:]``, so each
        position predicts the next token. Padding is excluded via the loss's
        ``ignore_index``.
        """
        src, trg = batch
        src = src.to(self.device)
        trg = trg.to(self.device)
        tgt_input = trg[:, :-1]
        tgt_out = trg[:, 1:]

        logits = self(src, tgt_input)
        return self.criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # Epoch-level only: per-step validation values would be plotted against
        # the training step counter. "val_loss" is the epoch mean that
        # ModelCheckpoint / EarlyStopping monitor.
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def prepare_data(self):
        """Fail fast when training or validation data is missing.

        Lightning calls this hook from a single process. It therefore assigns
        no state; the datasets are built in :meth:`setup`, which runs in
        every process.
        """
        if self.train_data is None:
            raise ValueError("Train Data not provided.")
        if self.val_data is None:
            raise ValueError("Validation Data not provided.")

    @staticmethod
    def _split_pairs(data):
        """Split ``(source, target)`` pairs into two aligned lists.

        No shuffling happens here. The train dataloader already shuffles
        every epoch. A per-process shuffle would also give each distributed
        rank a different ordering for its sampler to shard.
        """
        return [pair[0] for pair in data], [pair[1] for pair in data]

    def setup(self, stage=None):
        """Build the datasets needed by the given Lightning stage."""
        if stage in ('fit', None):
            self.inputs, self.targets = self._split_pairs(self.train_data)
            self.dataset = SaturdayDataset(self.inputs, self.targets, self.max_len, self.sp_model)
        if stage in ('fit', 'validate', None):
            self.val_inputs, self.val_targets = self._split_pairs(self.val_data)
            self.val_dataset = SaturdayDataset(self.val_inputs, self.val_targets, self.max_len, self.sp_model)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers)
########################### INIT
tokenizer_path = '/content/drive/MyDrive/models/tokenizer/model.model'
sp_model = spm.SentencePieceProcessor()
sp_model.Load(tokenizer_path)

# Load the training / validation (source, target) pairs.
# allow_pickle=True is needed for object arrays but unpickles arbitrary
# objects: only load files you created yourself.
train_data = np.load('/content/drive/MyDrive/dataset/transformer/dataset.npy', allow_pickle=True)
validation_data = np.load('/content/drive/MyDrive/dataset/transformer/dataset_validation.npy', allow_pickle=True)

# Model hyper-parameters
embed_dim = 512
vocab_size = sp_model.vocab_size()
num_layers = 2
n_heads = 8
learning_rate = 0.0001
max_len = 350

# Decide once whether CUDA is usable so checkpoint loading and the Trainer agree.
use_cuda = torch.cuda.is_available()

model = LightningTransformer(
    embed_dim=embed_dim,
    vocab_size=vocab_size,
    num_layers=num_layers,
    n_heads=n_heads,
    learning_rate=learning_rate,
    sp_model=sp_model,
    max_len=max_len,
    train_data=train_data,
    val_data=validation_data,
    batch_size=40
)

model_dir = os.path.join("/content/drive/MyDrive/models/transformer/", "discussion.ckpt")

# Optional warm start from a previously saved checkpoint.
if os.path.exists(model_dir):
    try:
        pretrained = torch.load(model_dir, map_location="cuda" if use_cuda else "cpu")
        load_result = model.load_state_dict(pretrained["state_dict"], strict=False)
    except Exception as e:
        # Best effort: report the failure and continue from the current weights.
        print(f"Errore nel caricare il modello: {e}")
    else:
        # strict=False silently skips mismatched keys; surface them so a
        # partially loaded model is not mistaken for a fully restored one.
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"Chiavi mancanti: {load_result.missing_keys}")
            print(f"Chiavi inattese: {load_result.unexpected_keys}")
        else:
            print("Il modello è stato caricato correttamente.")
else:
    print("Il file del modello non esiste nella posizione specificata.")

checkpoint_dir = "/content/drive/MyDrive/models/transformer/checkpoints/"

# Callback configuration
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    dirpath=checkpoint_dir,
    filename="transformer-{epoch:02d}-{val_loss:.2f}",
    save_top_k=10,  # keep the 10 checkpoints with the lowest val_loss
    every_n_epochs=1  # consider saving after every epoch
)

# NOTE(review): patience equals max_epochs below, so early stopping cannot
# trigger within this run. Lower patience if it is meant to act.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=10,
    verbose=True,
    mode="min"
)

trainer = pl.Trainer(
    max_epochs=10,
    # Fall back to CPU instead of crashing when no GPU is attached.
    accelerator="gpu" if use_cuda else "cpu",
    callbacks=[checkpoint_callback, early_stopping_callback]
)

# Find every .ckpt file in the checkpoint folder
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))

if checkpoint_files:
    # Resume from the most recently modified checkpoint. Its weights replace
    # any warm-start weights loaded above. If it already reached max_epochs,
    # fit() has nothing left to do.
    latest_checkpoint = max(checkpoint_files, key=os.path.getmtime)
    print(f"Checkpoint trovato: Riprendo l'allenamento da {latest_checkpoint}")
    trainer.fit(model, ckpt_path=latest_checkpoint)
else:
    print("Checkpoint non trovato: Parto da zero.")
    trainer.fit(model)