In [47]:
from utils import *
from positional_encoding import PositionalEncoding
from my_embedding import MyEmbedding
from transformer import TransformerBlock

In [48]:
dataset_ = load_dataset("daily_dialog")
texts = dataset_["train"]["dialog"]
vals = dataset_["validation"]["dialog"]

def get_corpus(texts):

    corpus = []
    for dialog in texts:
        for sentence in dialog:
            if sentence.strip():
                corpus.append(sentence.strip().lower())

    def clean(text):
        text = text.lower()
        text = re.sub(r"[^\w\s]", "", text)  # supprime ponctuation
        text = re.sub(r"\s+", " ", text).strip()
        return text

    corpus = [clean(s) for s in corpus if len(s.strip()) > 0]


    tokens = set(" ".join(corpus).split())
    vocab = {word: i+1 for i, word in enumerate(tokens)}  # +1 pour réserver 0 = padding
    vocab["<PAD>"] = 0
    inv_vocab = {i: w for w, i in vocab.items()}

    encoded_corpus = []
    for lines in corpus:
        encoded_corpus.append([vocab[word] for word in lines.split()])
    
    return encoded_corpus, vocab, inv_vocab, tokens

encoded_corpus, vocab, inv_vocab, tokens = get_corpus(texts)

encoded_val, _, _, _ = get_corpus(vals)

In [49]:
encoded_corpus = [line for line in encoded_corpus if len(line) <= 30]
encoded_val = [line for line in encoded_val if len(line) <= 30]
        

In [50]:
class MiniTransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=64, max_len=30):
        super().__init__()
        self.embed = MyEmbedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([TransformerBlock(d_model) for _ in range(2)])
        self.to_logits = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embed(x)
        x = self.pos_enc(x)
        for block in self.blocks:
            x = block(x)
        logits = self.to_logits(x)
        return logits

In [51]:
class LanguageDataset(Dataset):
    def __init__(self, encoded_corpus):
        self.data = encoded_corpus

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        input_ids = torch.tensor(seq[:-1])   # tous sauf dernier
        target_ids = torch.tensor(seq[1:])   # tous sauf premier
        return input_ids, target_ids


def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    targets = pad_sequence(targets, batch_first=True, padding_value=0)
    return inputs, targets


In [52]:
dataset = LanguageDataset(encoded_corpus)
loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

val_dataset = LanguageDataset(encoded_val)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

In [53]:
# device = (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
print(device)

cuda


In [54]:
model = MiniTransformerLM(vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    # === TRAINING ===
    model.train()
    total_loss = 0
    num_batches = 0

    for input_ids, target_ids in loader:
        input_ids = input_ids.to(device).long()
        target_ids = target_ids.to(device).long()

        logits = model(input_ids)
        logits = logits.view(-1, logits.size(-1))
        targets = target_ids.view(-1)

        loss = criterion(logits, targets)
        total_loss += loss.item()
        num_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_train_loss = total_loss / num_batches

    # === VALIDATION ===
    model.eval()
    val_loss = 0
    val_batches = 0
    with torch.no_grad():
        for input_ids, target_ids in val_loader:
            input_ids = input_ids.to(device).long()
            target_ids = target_ids.to(device).long()

            logits = model(input_ids)
            logits = logits.view(-1, logits.size(-1))
            targets = target_ids.view(-1)

            loss = criterion(logits, targets)
            val_loss += loss.item()
            val_batches += 1

    avg_val_loss = val_loss / val_batches

    print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")



Epoch 1 - Train Loss: 5.5466 | Val Loss: 12.4816
Epoch 2 - Train Loss: 4.6704 | Val Loss: 13.4510
Epoch 3 - Train Loss: 4.3493 | Val Loss: 14.6148
Epoch 4 - Train Loss: 4.1317 | Val Loss: 15.1319
Epoch 5 - Train Loss: 3.9628 | Val Loss: 15.3893
Epoch 6 - Train Loss: 3.8245 | Val Loss: 16.6777


KeyboardInterrupt: 

In [15]:
def generate_text(model, vocab, inv_vocab, start, max_len=30, device="cpu"):
    model.eval()
    tokens = [vocab.get(word, vocab["<PAD>"]) for word in start.split()]
    tokens = torch.tensor(tokens, device=device).unsqueeze(0)
    logits = model(tokens)
    next_token = logits[0, -1].argmax().item()
    tokens = torch.cat([tokens, torch.tensor([[next_token]], device=device)], dim=1)
    return " ".join([inv_vocab[tok.item()] for tok in tokens[0]])

# Exemple d’utilisation :
print(generate_text(model, vocab, inv_vocab, "i", device=device))

i you
