In [10]:
import torch

# 1. Charger le fichier texte
with open('exemple.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    
text = text.strip().lower()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
block_size = 64

def get_batch(batch_size=4):
    ix = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+1+block_size] for i in ix])
    return x, y

print("Texte encodé. Taille vocabulaire :", vocab_size)

Texte encodé. Taille vocabulaire : 36


In [13]:
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embed, n_head):
        super().__init__()
        self.sa = nn.MultiheadAttention(embed_dim=n_embed, num_heads=n_head, batch_first=True)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ffwd = FeedForward(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        attn_output, _ = self.sa(self.ln1(x), self.ln1(x), self.ln1(x), need_weights=False)
        x = x + attn_output
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTMini(nn.Module):
    def __init__(self, vocab_size, block_size, n_embed=128, n_head=4, n_layer=2):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, n_embed)
        self.pos_embed = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        token_emb = self.token_embed(idx)             # (B, T, n_embed)
        pos = torch.arange(T, device=idx.device)
        pos_emb = self.pos_embed(pos).unsqueeze(0)    # (1, T, n_embed)
        x = token_emb + pos_emb                       # (B, T, n_embed)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)                         # (B, T, vocab_size)
        return logits

In [15]:
# Init
model = GPTMini(vocab_size, block_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Entraînement rapide
for step in range(300):
    x, y = get_batch()
    logits = model(x)
    B, T, C = logits.shape
    loss = loss_fn(logits.view(B*T, C), y.view(B*T))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"Étape {step} – Perte : {loss.item():.4f}")

print("✅ Entraînement terminé.")

Étape 0 – Perte : 3.6951
Étape 100 – Perte : 2.0235
Étape 200 – Perte : 1.5482
✅ Entraînement terminé.


In [17]:
@torch.no_grad()
def generate(model, prompt, max_new_tokens=100, temperature=1.0):
    model.eval()
    idx = torch.tensor([encode(prompt)], dtype=torch.long)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]
        logits = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return decode(idx[0].tolist())

# Génération
prompt = "un matin,"
print("\n🔮 Texte généré :\n")
print(generate(model, prompt, max_new_tokens=200, temperature=0.8))


🔮 Texte généré :

un matin,  n nmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm
