In [1]:
from utils import *
from positional_encoding import PositionalEncoding
from my_embedding import MyEmbedding
from transformer import TransformerBlock

In [2]:
from datasets import load_dataset

# Fetch the DailyDialog corpus and pull the "dialog" column of the two splits
# used below (concat_dialogs iterates each entry as a sequence of strings).
dataset = load_dataset("daily_dialog")
train_dialogs, val_dialogs = (
    dataset[split_name]["dialog"] for split_name in ("train", "validation")
)

def concat_dialogs(dialogs, group_size=3):
    """Merge consecutive utterances of each dialog into space-joined lines.

    Args:
        dialogs: iterable of dialogs, each an iterable of utterance strings.
        group_size: number of non-blank utterances joined into one line.

    Returns:
        A flat list of strings. Utterances are stripped, blank ones are
        dropped, and groups never span two dialogs; a dialog's trailing
        utterances that do not fill a whole group form a shorter final line.
    """
    merged = []
    for turns in dialogs:
        pending = []
        for raw in turns:
            text = raw.strip()
            if not text:
                continue  # skip blank / whitespace-only utterances
            pending.append(text)
            if len(pending) != group_size:
                continue
            merged.append(" ".join(pending))
            pending = []
        if pending:  # flush the leftover, incomplete group
            merged.append(" ".join(pending))
    return merged

# Merge every 3 consecutive utterances of a dialog into one text line.
train_corpus, val_corpus = (
    concat_dialogs(split_dialogs, group_size=3)
    for split_dialogs in (train_dialogs, val_dialogs)
)

from tokenizers import ByteLevelBPETokenizer

# The tokenizer is trained from a file path, so dump the training lines to
# disk first (one line of text per corpus entry).
with open("train_corpus.txt", "w", encoding="utf-8") as corpus_file:
    corpus_file.writelines(text + "\n" for text in train_corpus)

# Fit the BPE vocabulary on the training split only.
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(
    files="train_corpus.txt",
    vocab_size=6000,
    min_frequency=2,
    special_tokens=["<PAD>", "<BOS>", "<EOS>", "<UNK>"],
)
tokenizer.enable_padding(pad_token="<PAD>", pad_id=tokenizer.token_to_id("<PAD>"))
# Encodings are cut to at most 64 tokens.
tokenizer.enable_truncation(max_length=64)

def encode_batch(corpus, tokenizer):
    """Tokenize every line of ``corpus`` wrapped in ``<BOS>`` / ``<EOS>`` markers.

    Args:
        corpus: iterable of text lines.
        tokenizer: object whose ``encode(text)`` returns something with an
            ``ids`` attribute.

    Returns:
        A list with one ``ids`` value per input line, in input order.
    """
    return [tokenizer.encode(f"<BOS> {text} <EOS>").ids for text in corpus]

# Turn both text splits into lists of token-id sequences with the same tokenizer.
train_encoded, val_encoded = (
    encode_batch(split_corpus, tokenizer)
    for split_corpus in (train_corpus, val_corpus)
)






In [16]:
class MiniTransformerLM(nn.Module):
    """Small Transformer language model.

    Pipeline: ``MyEmbedding`` -> ``PositionalEncoding`` -> a stack of
    ``TransformerBlock`` -> linear projection to vocabulary logits.

    Args:
        vocab_size: passed to ``MyEmbedding`` and used as the output width of
            the final linear layer.
        d_model: hidden width shared by the embedding, the positional
            encoding, every block and the input of the output projection.
        max_len: passed to ``PositionalEncoding``.
        num_layers: how many ``TransformerBlock`` modules to stack. Defaults
            to 2, the depth that used to be hard-coded.
    """

    def __init__(self, vocab_size, d_model=128, max_len=64, num_layers=2):
        super().__init__()
        self.embed = MyEmbedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([TransformerBlock(d_model) for _ in range(num_layers)])
        self.to_logits = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        """Run the model and return per-position vocabulary logits.

        Args:
            x: model input fed to the embedding layer
                (presumably a (batch, seq) tensor of token ids - TODO confirm
                against ``MyEmbedding``).
            pad_mask: optional mask forwarded unchanged to every block as
                ``pad_mask``; its exact semantics are defined by
                ``TransformerBlock``.

        Returns:
            The output of the final linear layer; its last dimension is
            ``vocab_size``.
        """
        x = self.embed(x)
        x = self.pos_enc(x)
        for block in self.blocks:
            x = block(x, pad_mask=pad_mask)
        logits = self.to_logits(x)
        return logits

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class LanguageDataset(Dataset):
    """Next-token prediction pairs built from pre-tokenized sequences.

    Item ``i`` is ``(input_ids, target_ids)`` where the target is the same
    sequence shifted left by one token.
    """

    def __init__(self, encoded_corpus):
        # Stored by reference; tensors are only built on access.
        self.data = encoded_corpus

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        token_ids = self.data[idx]
        # Input drops the last token, target drops the first, so that
        # position i of the input is paired with token i + 1.
        return torch.tensor(token_ids[:-1]), torch.tensor(token_ids[1:])

def collate_fn(batch):
    """Pad a list of ``(input_ids, target_ids)`` pairs into batch tensors.

    Uses the module-level ``tokenizer`` to find the ``<PAD>`` id.

    Returns:
        ``(inputs, targets, pad_mask)``: inputs and targets are right-padded
        with the pad id (batch dimension first); ``pad_mask`` is a boolean
        tensor shaped like ``inputs`` that is True where ``inputs`` equals the
        pad id.
    """
    pad_id = tokenizer.token_to_id("<PAD>")
    input_seqs, target_seqs = zip(*batch)
    inputs = pad_sequence(input_seqs, batch_first=True, padding_value=pad_id)
    targets = pad_sequence(target_seqs, batch_first=True, padding_value=pad_id)
    # pad_mask: True wherever the position holds padding
    pad_mask = inputs.eq(pad_id)
    return inputs, targets, pad_mask

# Wrap both encoded splits; only the training loader shuffles its samples.
train_dataset, val_dataset = (
    LanguageDataset(split_ids) for split_ids in (train_encoded, val_encoded)
)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn,
)


In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# (Alternative for Apple Silicon: pick "mps" when torch.backends.mps.is_available().)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [17]:
vocab_size = tokenizer.get_vocab_size()
model = MiniTransformerLM(vocab_size=vocab_size).to(device)

# Padded target positions must not contribute to the loss. The pad id is
# looked up from the tokenizer (the same id collate_fn pads with) instead of
# being hard-coded as 0.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("<PAD>"))
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

for epoch in range(20):
    # === TRAINING ===
    model.train()
    total_loss = 0
    num_batches = 0

    for input_ids, target_ids, pad_mask in train_loader:
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)
        pad_mask = pad_mask.to(device)

        logits = model(input_ids, pad_mask=pad_mask)
        # Flatten (batch, seq, vocab) -> (batch * seq, vocab) for the loss.
        logits = logits.view(-1, logits.size(-1))
        targets = target_ids.view(-1)

        loss = criterion(logits, targets)
        total_loss += loss.item()
        num_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean of the per-batch losses (not weighted by token count).
    avg_train_loss = total_loss / num_batches

    # === VALIDATION ===
    model.eval()
    val_loss = 0
    val_batches = 0
    with torch.no_grad():
        for input_ids, target_ids, pad_mask in val_loader:
            input_ids = input_ids.to(device).long()
            target_ids = target_ids.to(device).long()
            pad_mask = pad_mask.to(device)

            # Pass the padding mask exactly as in training; without it the
            # model sees padded validation batches under different masking
            # and the two losses are not comparable.
            logits = model(input_ids, pad_mask=pad_mask)
            logits = logits.view(-1, logits.size(-1))
            targets = target_ids.view(-1)

            loss = criterion(logits, targets)
            val_loss += loss.item()
            val_batches += 1

    avg_val_loss = val_loss / val_batches

    print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")



Epoch 1 - Train Loss: 5.2167 | Val Loss: 4.4650
Epoch 2 - Train Loss: 4.3550 | Val Loss: 4.1346
Epoch 3 - Train Loss: 4.0997 | Val Loss: 3.9531
Epoch 4 - Train Loss: 3.9274 | Val Loss: 3.8291
Epoch 5 - Train Loss: 3.7982 | Val Loss: 3.7598
Epoch 6 - Train Loss: 3.6963 | Val Loss: 3.6823
Epoch 7 - Train Loss: 3.6146 | Val Loss: 3.6405
Epoch 8 - Train Loss: 3.5446 | Val Loss: 3.6078
Epoch 9 - Train Loss: 3.4873 | Val Loss: 3.5740
Epoch 10 - Train Loss: 3.4370 | Val Loss: 3.5504
Epoch 11 - Train Loss: 3.3913 | Val Loss: 3.5313
Epoch 12 - Train Loss: 3.3518 | Val Loss: 3.5274
Epoch 13 - Train Loss: 3.3158 | Val Loss: 3.5054
Epoch 14 - Train Loss: 3.2850 | Val Loss: 3.4906
Epoch 15 - Train Loss: 3.2557 | Val Loss: 3.4814
Epoch 16 - Train Loss: 3.2280 | Val Loss: 3.4769
Epoch 17 - Train Loss: 3.2033 | Val Loss: 3.4665
Epoch 18 - Train Loss: 3.1810 | Val Loss: 3.4605
Epoch 19 - Train Loss: 3.1590 | Val Loss: 3.4642
Epoch 20 - Train Loss: 3.1397 | Val Loss: 3.4651


In [18]:
import torch.nn.functional as F
def generate_text(model, tokenizer, start, max_len=30, device="cpu", temperature=1.0, top_k=10):
    """Sample a continuation of ``start`` from ``model`` using top-k sampling.

    Args:
        model: callable mapping a (1, seq) tensor of token ids to logits
            indexed as ``logits[0, -1]`` for the last position.
        tokenizer: provides ``encode(text).ids``, ``token_to_id`` and ``decode``.
        start: prompt text; it is encoded with a leading ``<BOS>`` marker.
        max_len: maximum number of tokens to generate.
        device: device on which the id tensors are created.
        temperature: divisor applied to the last-position logits before sampling.
        top_k: number of highest-scoring tokens to sample from; values larger
            than the vocabulary size are clamped to it.

    Returns:
        The decoded text of prompt plus generated tokens, without the leading
        ``<BOS>`` and cut at the first ``<EOS>`` if one is present.
    """
    model.eval()
    bos_id = tokenizer.token_to_id("<BOS>")
    eos_id = tokenizer.token_to_id("<EOS>")

    ids = tokenizer.encode(f"<BOS> {start}").ids
    input_ids = torch.tensor([ids], device=device)

    # Inference only: disable autograd so no graph is built per generated token.
    with torch.no_grad():
        for _ in range(max_len):
            logits = model(input_ids)
            logits = logits[0, -1] / temperature
            # Top-k sampling; topk() raises if k exceeds the vocabulary size,
            # so clamp it.
            k = min(top_k, logits.size(-1))
            topk = logits.topk(k)
            probs = F.softmax(topk.values, dim=-1)
            choice = torch.multinomial(probs, 1).item()
            next_token_id = topk.indices[choice].item()
            input_ids = torch.cat(
                [input_ids, torch.tensor([[next_token_id]], device=device)], dim=1
            )
            if next_token_id == eos_id:
                break

    ids = input_ids[0].tolist()
    if ids and ids[0] == bos_id:
        ids = ids[1:]
    if eos_id in ids:
        ids = ids[:ids.index(eos_id)]
    return tokenizer.decode(ids)


# Quick qualitative check of the trained model. generate_text draws tokens with
# torch.multinomial, so the printed continuation can differ from run to run.
print(generate_text(model, tokenizer, "Yesterday, he", device=device))


 Yesterday, he was a big town , and so he was . 
