In [1]:
%load_ext autoreload
%autoreload 2

import requests
import torch
from torch.utils.data import DataLoader
import random
import torch.nn as nn 
from torch import optim
# from sklearn.decomposition import TruncatedSVD as svds
from scipy.sparse.linalg import svds
from sklearn.preprocessing import normalize
from torch.nn.utils.rnn import pad_sequence
from spice import SpiceEmbeddingModel
from gru import GRUEncoder, GRUDecoder
from matplotlib import pyplot as plt
from time import time 
import spice

# Spice encoding 

In [2]:


URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
FILE_PATH = "shakespeare.txt"
EMB_DIM = 64
WINDOW = 5


try:
    with open(FILE_PATH, 'r', encoding='utf-8') as f:
        text = f.read()
except FileNotFoundError:
    response = requests.get(URL)
    text = response.text
    with open(FILE_PATH, 'w', encoding='utf-8') as f:
        f.write(text)


# spice_model = SpiceEmbeddingModel(emb_dim=50, window_size=3)

# dataset = spice_model.get_dataset(text)
# # spice_model.save_model(spice_model.embeddings)

# dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
# for batch in dataloader:
#     sentence_embeddings = batch
#     print("Sentence Embeddings: ", len(sentence_embeddings[0]))



# Modèle definition

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_size = 50 # taille des embeddings
hidden_size = 128
lr = 0.001
num_layer=5
num_epochs = 50
batch_size = 64

# def collate_fn(batch):
#     """ permet de rajouter du padding et renvoyer la taille de la séquence dans un batch"""
#     batch = [item for item in batch if len(item) > 0]  # Filtrer les séquences vides
#     lengths = torch.tensor([len(seq) for seq in batch])  # Longueurs originales
#     padded_batch = pad_sequence(batch, batch_first=True, padding_value=0.0)  # Padding des séquences
#     return padded_batch, lengths


spice_model = SpiceEmbeddingModel()
# spice_model.load_model()
dataset = spice_model.get_dataset(text)
# dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=spice.collate_fn_fillers_roles)



encoder = GRUEncoder(input_size, hidden_size, num_layer).to(device)
decoder = GRUDecoder(hidden_size, input_size, num_layer).to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(list(encoder.parameters()) +list(decoder.parameters()), lr=lr)

losses = []
t0 = time()
for epoch in range(num_epochs):
    total_loss = 0.0

    # for batch, lengths in dataloader:
    #     batch = batch.to(device)
    for fillers, roles, lengths in dataloader:
        fillers = fillers.to(device)  # (batch, seq_len, input_size)
        # roles = roles.to(device)        # (batch, seq_len, role_dim)
        lengths = lengths.to(device)    # (batch,)
        # lengths = lengths.to(device)
        optimizer.zero_grad()

        encoded, perm_idx = encoder(fillers, lengths)
        reconstructed = decoder(encoded, lengths)
        
        # Remettre les séquences dans l'ordre original
        # _, reverse_idx = perm_idx.sort()
        # reconstructed = reconstructed[reverse_idx]
        # fillers = fillers[reverse_idx]

        # Calcul de la perte (ignorer le padding)
        loss = criterion(reconstructed, fillers)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    losses.append(total_loss/ len(dataloader))
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader)}")
t1 = time()
torch.save(encoder, "encoder.pth")
torch.save(decoder, "decoder.pth")
print( f"training time : {t1-t0}")
plt.plot(losses)
plt.title("MSE over epoch")
plt.xlabel("epoch")
plt.ylabel("MSE")
plt.plot()


Epoch 1/50, Loss: 0.011461593141906302
Epoch 2/50, Loss: 0.01136212462337095
Epoch 3/50, Loss: 0.011339516691143879
Epoch 4/50, Loss: 0.01135277928551659
Epoch 5/50, Loss: 0.01133617614819245
Epoch 6/50, Loss: 0.01136466958136721
Epoch 7/50, Loss: 0.01134530363329263
Epoch 8/50, Loss: 0.011413543905787677
Epoch 9/50, Loss: 0.011335558985592797
Epoch 10/50, Loss: 0.011345485265006904


KeyboardInterrupt: 

In [8]:


def decode_sequence(spice_model, sequence):
    """decode les embeddings en mots"""
    decoded_words = []
    for embedding in sequence.cpu():
        if embedding.sum() != 0:  # Ignorer le padding
            decoded_words.append(spice_model.decode_embedding(embedding.detach().numpy(), top_n=1)[0])
    return " ".join(decoded_words)

fillers, roles, lengths = next(iter(dataloader))  
idx = random.randint(0, fillers.size(0) - 1)  #  phrase au hasard
input_seq = fillers[idx].unsqueeze(0).to(device)  # le tenseur avec du padding
mask = (input_seq != 0).any(dim=2)  # vérifie si chaque ligne contient des valeurs non nulles
tensor_clean = input_seq[:, mask[0], :] # enlève les vecteurs nuls du padding
length = torch.tensor([tensor_clean.shape[1]]).to(device)  # récupère la longueur originale sans le padding


encoder.eval()
decoder.eval()
with torch.no_grad():
    encoded, _ = encoder(input_seq, length)
    reconstructed_seq = decoder(encoded, length)


original_text = decode_sequence(spice_model, input_seq.squeeze(0))
reconstructed_text = decode_sequence(spice_model, reconstructed_seq.squeeze(0))


print(f"originale**: {original_text}")
print(f"reconstruite**: {reconstructed_text}")


originale**: my lord you shall oerrule my mind for once
reconstruite**: and to to and and to and and to


# test
