In [20]:
from utils import *
from positional_encoding import PositionalEncoding
from my_embedding import MyEmbedding
from transformer import TransformerBlock

In [13]:
dataset_ = load_dataset("daily_dialog")
texts = dataset_["train"]["dialog"]

def get_corpus(texts):

    corpus = []
    for dialog in texts:
        for sentence in dialog:
            if sentence.strip():
                corpus.append(sentence.strip().lower())

    def clean(text):
        text = text.lower()
        text = re.sub(r"[^\w\s]", "", text)  # supprime ponctuation
        text = re.sub(r"\s+", " ", text).strip()
        return text

    corpus = [clean(s) for s in corpus if len(s.strip()) > 0]


    tokens = set(" ".join(corpus).split())
    vocab = {word: i+1 for i, word in enumerate(tokens)}  # +1 pour réserver 0 = padding
    vocab["<PAD>"] = 0
    inv_vocab = {i: w for w, i in vocab.items()}

    encoded_corpus = []
    for lines in corpus:
        encoded_corpus.append([vocab[word] for word in lines.split()])
    
    return encoded_corpus, vocab, inv_vocab, tokens

encoded_corpus, vocab, inv_vocab, tokens = get_corpus(texts)

In [15]:
class MiniTransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=64, max_len=256):
        super().__init__()
        self.embed = MyEmbedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([TransformerBlock(d_model) for _ in range(3)])
        self.to_logits = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embed(x)
        x = self.pos_enc(x)
        for block in self.blocks:
            x = block(x)
        logits = self.to_logits(x)
        return logits

In [16]:
class LanguageDataset(Dataset):
    def __init__(self, encoded_corpus):
        self.data = encoded_corpus

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        input_ids = torch.tensor(seq[:-1])   # tous sauf dernier
        target_ids = torch.tensor(seq[1:])   # tous sauf premier
        return input_ids, target_ids


def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    targets = pad_sequence(targets, batch_first=True, padding_value=0)
    return inputs, targets


In [17]:
dataset = LanguageDataset(encoded_corpus)
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

In [None]:
# device = (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
print(device)

In [21]:
model = MiniTransformerLM(vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):

    if epoch == 0:
        print("start training")
        
    model.train()
    total_loss = 0

    for input_ids, target_ids in loader:
        input_ids = input_ids.to(device).long()
        target_ids = target_ids.to(device).long()

        logits = model(input_ids)  # [batch, seq, vocab]
        logits = logits.view(-1, logits.size(-1))     # [batch*seq, vocab]
        targets = target_ids.view(-1)                 # [batch*seq]

        loss = criterion(logits, targets)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1} - Loss: {total_loss:.4f}")

Epoch 3 - Loss: 3639.0905
Epoch 4 - Loss: 3130.1553
Epoch 5 - Loss: 2757.5247
Epoch 6 - Loss: 2499.6595
Epoch 7 - Loss: 2298.1162
Epoch 8 - Loss: 2146.4273
Epoch 9 - Loss: 2013.1704
Epoch 10 - Loss: 1901.5796
