In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split


In [38]:
def generate_synthetic_dataset(N, D, M, P, seed=0):
    torch.manual_seed(seed)

    # Poids du professeur : 2 blocs (query/key)
    W_q = torch.randn(P, D)
    W_k = torch.randn(P, D)

    x = torch.randn(N, D, M)  # N séquences d’entrée

    y = torch.zeros(N, M, D)
    for i in range(N):
        x_i = x[i]  # [D, M]

        # Projections query/key : [P, M]
        Zq = W_q @ x_i / D**0.5
        Zk = W_k @ x_i / D**0.5

        # Attention : [M, M]
        A = torch.softmax(Zq.T @ Zk, dim=-1)

        # Application : [M, D] = [M, M] @ [M, D]
        y_i = A @ x_i.T
        y[i] = y_i  # [M, D]

    return x, y, (W_q, W_k)


In [None]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, D, P, M, H, c):
        super().__init__()
        self.D = D
        self.P = P
        self.M = M
        self.H = H
        self.c = c

        # Paramètres pour chaque tête (query et key)
        self.W_q = nn.Parameter(torch.randn(H, P, D))
        self.W_k = nn.Parameter(torch.randn(H, P, D))

    def forward(self, x):  # x: [N, D, M]
        N = x.shape[0]
        device = x.device

        attn_sum = 0.0
        for h in range(self.H):
            # Projection pour la tête h : on obtient Zq et Zk de forme [N, P, M]
            Zq = torch.einsum("pd,ndm->npm", self.W_q[h], x) / (self.D**0.5)
            Zk = torch.einsum("pd,ndm->npm", self.W_k[h], x) / (self.D**0.5)
            
            # Transposition pour obtenir Q et K de forme [N, M, P]
            Q = Zq.transpose(1, 2)  # [N, M, P]
            K = Zk.transpose(1, 2)  # [N, M, P]
            
            # Calcul des scores d'attention : "nip, njp -> nij" donne [N, M, M]
            scores = torch.einsum("nip, njp -> nij", Q, K)
            attn = F.softmax(scores, dim=-1)
            attn_sum += attn

        attn_mean = attn_sum / self.H  # [N, M, M]
        # Ajout de la skip connection (même pour chaque batch)
        skip = self.c * torch.eye(self.M, device=device).unsqueeze(0)  # [1, M, M]
        combined = attn_mean + skip  # [N, M, M]
        
        # Appliquer la matrice d'attention aux tokens
        # x est de forme [N, D, M] et on souhaite conserver la dimension D
        output = torch.einsum("ndm, nmk -> ndk", x, combined)
        return output  # [N, D, M]



In [43]:
class DeepSMIMultiHead(nn.Module):
    def __init__(self, D, P_list, M, H_list, c):
        super().__init__()
        self.layers = nn.ModuleList([
            MultiHeadAttentionLayer(D, P, M, H, c)
            for P, H in zip(P_list, H_list)
        ])
        self.output_layer = nn.Linear(D, M)

    def forward(self, x):  # x: [N, D, M]
        for layer in self.layers:
            x = layer(x)  # [N, D, M]
        
        x_transposed = x.transpose(1, 2)  # [N, M, D] — utilisé plus tard dans la multiplication

        # Prédiction de la matrice d'attention
        attention_logits = self.output_layer(x_transposed)  # [N, M, M]
        attention_probs = torch.softmax(attention_logits, dim=-1)  # [N, M, M]

        # Sortie : application de la matrice d’attention sur x
        # attention_probs [N, M, M] × x [N, D, M]ᵀ → [N, M, D]
        y = torch.bmm(attention_probs, x_transposed)  # [N, M, D]

        return y  # chaque ligne = prédiction du token à chaque position


In [44]:
def train_model(model, x_train, y_train, x_test, y_test, n_epochs=100, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for epoch in range(n_epochs):
        # Mode entraînement
        model.train()
        optimizer.zero_grad()
        y_pred = model(x_train)  # [N_train, M, M]
        loss = loss_fn(y_pred, y_train)
        loss.backward()
        optimizer.step()

        # Accuracy entraînement
        pred_labels = y_pred.argmax(dim=-1)     # [N_train, M]
        true_labels = y_train.argmax(dim=-1)    # [N_train, M]
        train_acc = (pred_labels == true_labels).float().mean()

        # Mode évaluation
        model.eval()
        with torch.no_grad():
            y_test_pred = model(x_test)  # [N_test, M, M]
            test_loss = loss_fn(y_test_pred, y_test)
            pred_test = y_test_pred.argmax(dim=-1)
            true_test = y_test.argmax(dim=-1)
            test_acc = (pred_test == true_test).float().mean()

        # Affichage
        if epoch % 10 == 0 or epoch == n_epochs - 1:
            print(f"Epoch {epoch:03d} | "
                  f"Train Loss: {loss.item():.4f} | Train Acc: {train_acc.item()*100:.2f}% | "
                  f"Test Loss: {test_loss.item():.4f} | Test Acc: {test_acc.item()*100:.2f}%")


In [46]:
# Paramètres
D, M = 64, 10
P = 8
N = 10000

# Dataset
x, y, W_star = generate_synthetic_dataset(N, D, M, P)
x_all, y_all = x.float(), y.float()
# Supposons que x_all, y_all ont la forme torch.Tensor
x_train, x_test, y_train, y_test = train_test_split(
    x_all, y_all, test_size=0.2, random_state=42
)


# Modèle élève
model = DeepSMIMultiHead(D, P_list=[P]*2, M=M, H_list=[4]*2, c=1.0)

# Entraînement
train_model(model, x_train, y_train, x_test, y_test, n_epochs=200, lr=1e-3)


Epoch 000 | Train Loss: 1.4586 | Train Acc: 12.66% | Test Loss: 1.4646 | Test Acc: 12.76%
Epoch 010 | Train Loss: 1.4250 | Train Acc: 12.75% | Test Loss: 1.4344 | Test Acc: 12.82%
Epoch 020 | Train Loss: 1.3983 | Train Acc: 12.85% | Test Loss: 1.4105 | Test Acc: 12.80%
Epoch 030 | Train Loss: 1.3782 | Train Acc: 12.88% | Test Loss: 1.3923 | Test Acc: 12.94%
Epoch 040 | Train Loss: 1.3635 | Train Acc: 12.92% | Test Loss: 1.3791 | Test Acc: 13.08%
Epoch 050 | Train Loss: 1.3529 | Train Acc: 12.97% | Test Loss: 1.3696 | Test Acc: 13.15%
Epoch 060 | Train Loss: 1.3454 | Train Acc: 12.98% | Test Loss: 1.3629 | Test Acc: 13.17%
Epoch 070 | Train Loss: 1.3401 | Train Acc: 12.96% | Test Loss: 1.3582 | Test Acc: 13.18%
Epoch 080 | Train Loss: 1.3362 | Train Acc: 12.97% | Test Loss: 1.3549 | Test Acc: 13.28%
Epoch 090 | Train Loss: 1.3333 | Train Acc: 12.97% | Test Loss: 1.3526 | Test Acc: 13.29%
Epoch 100 | Train Loss: 1.3311 | Train Acc: 12.96% | Test Loss: 1.3510 | Test Acc: 13.30%
Epoch 110 

KeyboardInterrupt: 

In [31]:
print(y_all[0])

tensor([[4.9146e-03, 8.0455e-01, 7.4404e-04, 3.6880e-03, 4.5673e-05, 1.3396e-02,
         9.8182e-04, 3.0481e-04, 2.7697e-02, 1.4368e-01],
        [8.8791e-04, 6.5693e-02, 9.0829e-01, 6.4671e-05, 3.2075e-03, 1.1098e-03,
         3.4199e-03, 3.4873e-03, 6.7115e-03, 7.1247e-03],
        [6.7212e-04, 7.5002e-05, 3.5989e-03, 8.2282e-03, 2.4692e-05, 1.1135e-03,
         1.6930e-05, 2.3824e-05, 2.7449e-01, 7.1176e-01],
        [1.3887e-03, 1.2862e-04, 1.1318e-04, 9.9506e-01, 1.2304e-06, 2.7359e-04,
         2.8722e-03, 8.4135e-05, 7.1825e-05, 2.7311e-06],
        [7.2928e-03, 3.3191e-04, 2.1810e-02, 3.6675e-01, 5.8519e-03, 5.3821e-03,
         6.4585e-03, 5.8602e-01, 6.7511e-05, 3.8685e-05],
        [1.4963e-03, 3.4123e-05, 7.0987e-01, 1.0078e-01, 1.4058e-03, 2.1199e-04,
         1.7471e-03, 1.8445e-01, 4.3007e-06, 1.3284e-06],
        [1.5859e-02, 1.4389e-03, 7.9806e-02, 1.9862e-01, 7.8504e-02, 8.6380e-03,
         2.0008e-01, 8.1958e-02, 3.3271e-01, 2.3911e-03],
        [1.1363e-01, 2.1498