In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

In [8]:
def generate_synthetic_dataset(N, D, M, P, seed=0):
    """Sample inputs and softmax-attention targets from a random teacher.

    A teacher matrix ``W_star`` of shape ``[2P, D]`` stacks two projections:
    the first ``P`` rows act as a "query" map and the last ``P`` rows as a
    "key" map. For every example ``n`` the target is
    ``softmax(z1 @ z2.T, dim=-1)`` where ``z1, z2`` are the ``[M, P]``
    projections of the ``M`` tokens of ``x[n]``.

    Note: this calls ``torch.manual_seed(seed)``, i.e. it reseeds the *global*
    torch RNG, so random draws made after this call (e.g. model
    initialisation) are also determined by ``seed``.

    Args:
        N: number of examples.
        D: token embedding dimension.
        M: number of tokens per example.
        P: width of each teacher projection.
        seed: seed passed to ``torch.manual_seed``.

    Returns:
        x: ``[N, D, M]`` standard-normal inputs.
        y: ``[N, M, M]`` target attention matrices; each row sums to 1.
        W_star: ``[2P, D]`` standard-normal teacher weights.
    """
    torch.manual_seed(seed)

    # Teacher weights: rows [0, P) project to z1, rows [P, 2P) project to z2.
    W_star = torch.randn(2 * P, D)

    # Inputs x ∈ ℝ^{N, D, M}
    x = torch.randn(N, D, M)

    # Projections z ∈ ℝ^{N, 2P, M}; the 1/sqrt(D) factor keeps each entry of z
    # of order 1 for standard-normal x and W_star.
    z = torch.einsum("pd,ndm->npm", W_star, x) / D**0.5

    # Batched form of `z1 @ z2.T` for every example at once:
    # scores[n, i, j] = sum_p z[n, p, i] * z[n, P + p, j]
    scores = torch.einsum("npi,npj->nij", z[:, :P], z[:, P:])
    y = torch.softmax(scores, dim=-1)

    return x, y, W_star



In [16]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head dot-product attention layer with a scaled-identity skip connection.

    Each head ``h`` projects the tokens of ``x`` with its own query and key
    matrices ``W_q[h]``, ``W_k[h]`` (shape ``[P, D]``, standard-normal init),
    scaled by ``1/sqrt(D)``. The row-softmaxed score matrices are averaged over
    heads, ``c * I`` is added, and the result mixes the tokens of ``x``. There
    is no value or output projection.

    Args:
        D: token embedding dimension.
        P: query/key projection width of each head.
        M: number of tokens the layer was built for. It is kept as an attribute;
            ``forward`` reads the token count from its input, so other
            sequence lengths are accepted too.
        H: number of heads.
        c: strength of the identity skip connection.
    """

    def __init__(self, D, P, M, H, c):
        super().__init__()
        self.D = D
        self.P = P
        self.M = M
        self.H = H
        self.c = c

        # Per-head query and key parameters, stacked along the leading head axis.
        self.W_q = nn.Parameter(torch.randn(H, P, D))
        self.W_k = nn.Parameter(torch.randn(H, P, D))

    def forward(self, x):
        """Apply the layer.

        Args:
            x: tensor of shape ``[N, D, T]`` (features first, tokens last).

        Returns:
            Tensor of shape ``[N, D, T]``.
        """
        n_tokens = x.shape[-1]
        scale = self.D ** 0.5

        # All heads in one einsum: Q[n, h, m, p] = sum_d W_q[h, p, d] x[n, d, m] / sqrt(D)
        Q = torch.einsum("hpd,ndm->nhmp", self.W_q, x) / scale
        K = torch.einsum("hpd,ndm->nhmp", self.W_k, x) / scale

        # Scores summed over the projection dimension: [N, H, T, T]
        scores = torch.einsum("nhip,nhjp->nhij", Q, K)
        # Softmax over the last axis (rows sum to 1), then average over heads: [N, T, T]
        attn_mean = F.softmax(scores, dim=-1).mean(dim=1)

        # Skip connection c * I, sized from the input rather than the stored M.
        skip = self.c * torch.eye(n_tokens, device=x.device, dtype=attn_mean.dtype)
        combined = attn_mean + skip  # [N, T, T]

        # output[n, :, k] = sum_m x[n, :, m] * combined[n, m, k]
        # NOTE(review): this mixes tokens with *column* k of a row-stochastic
        # matrix (columns need not sum to 1). Standard attention would use row k,
        # i.e. "ndm,nkm->ndk". Confirm which convention the reference model uses.
        return torch.einsum("ndm,nmk->ndk", x, combined)

In [17]:
class DeepSMIMultiHead(nn.Module):
    """Stack of multi-head attention layers followed by a per-token linear read-out.

    Every layer keeps the ``[N, D, M]`` shape. The read-out maps each token's
    ``D`` features to ``M`` logits. A softmax over those logits then produces
    one row of the predicted ``[M, M]`` attention matrix per token.

    Args:
        D: token embedding dimension.
        P_list: per-layer query/key projection widths.
        M: number of tokens (and read-out width).
        H_list: per-layer number of heads; paired element-wise with ``P_list``.
        c: skip-connection strength shared by all layers.
    """

    def __init__(self, D, P_list, M, H_list, c):
        super().__init__()
        blocks = []
        for p_dim, n_heads in zip(P_list, H_list):
            blocks.append(MultiHeadAttentionLayer(D, p_dim, M, n_heads, c))
        self.layers = nn.ModuleList(blocks)
        # Per-token read-out: D features -> M attention logits.
        self.output_layer = nn.Linear(D, M)

    def forward(self, x):
        """Map ``x`` of shape ``[N, D, M]`` to row-stochastic matrices ``[N, M, M]``."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)

        # Tokens first so the linear read-out acts on the feature axis: [N, M, D]
        tokens_first = hidden.transpose(1, 2)
        # Softmax over the last axis so every predicted row sums to 1.
        return self.output_layer(tokens_first).softmax(dim=-1)


In [19]:
def cross_entropy_attention(pred, target, eps=1e-8):
    """Mean row-wise cross-entropy between predicted and target attention matrices.

    Args:
        pred: predicted probabilities, last axis = distribution (e.g. ``[N, M, M]``).
        target: target probabilities with the same shape as ``pred``.
        eps: probabilities are clamped to ``[eps, 1 - eps]`` before the log so
            that zeros do not produce ``-inf``.

    Returns:
        Scalar tensor: ``-sum(target * log(pred))`` over the last axis,
        averaged over all remaining axes.
    """
    safe_log = pred.clamp(min=eps, max=1. - eps).log()
    per_row = -(target * safe_log).sum(dim=-1)
    return per_row.mean()


In [28]:
import torch.nn.functional as F

def cosine_attention_similarity(pred, target):
    """Average cosine similarity between corresponding rows of two tensors.

    Both inputs are L2-normalised along the last axis, the row-wise dot
    products are taken, and their mean is returned as a scalar tensor.

    Args:
        pred: tensor such as a predicted attention matrix ``[N, M, M]``.
        target: tensor with the same shape as ``pred``.
    """
    unit_pred = F.normalize(pred, p=2, dim=-1)
    unit_target = F.normalize(target, p=2, dim=-1)
    row_cosines = torch.sum(unit_pred * unit_target, dim=-1)  # one value per row
    return row_cosines.mean()


In [29]:
def train_model(model, x_train, y_train, x_test, y_test, n_epochs=100, lr=1e-3, log_every=10):
    """Full-batch Adam training on attention-matrix targets with periodic logging.

    Each epoch does one optimisation step on the whole training set using
    ``cross_entropy_attention``. Every ``log_every`` epochs, and on the final
    epoch, it prints the train/test cross-entropy and the mean row-wise cosine
    similarity. The train metrics are those of the forward pass *before* that
    epoch's parameter update.

    Args:
        model: module mapping inputs to predicted attention matrices.
        x_train, y_train: training inputs and target attention matrices.
        x_test, y_test: held-out inputs and targets (evaluation only).
        n_epochs: number of full-batch optimisation steps.
        lr: Adam learning rate.
        log_every: logging period in epochs (must be >= 1).

    Raises:
        ValueError: if ``log_every`` < 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(n_epochs):
        # Training step
        model.train()
        optimizer.zero_grad()
        y_pred = model(x_train)  # [N, M, M]
        train_loss = cross_entropy_attention(y_pred, y_train)
        train_loss.backward()
        optimizer.step()

        # Only spend a test-set forward pass on epochs that are actually reported.
        if epoch % log_every != 0 and epoch != n_epochs - 1:
            continue

        model.eval()
        with torch.no_grad():
            # Detached so the metric does not extend the autograd graph of y_pred.
            train_cos_sim = cosine_attention_similarity(y_pred.detach(), y_train)
            y_test_pred = model(x_test)
            test_loss = cross_entropy_attention(y_test_pred, y_test)
            test_cos_sim = cosine_attention_similarity(y_test_pred, y_test)

        print(f"Epoch {epoch:03d} | "
              f"Train CE: {train_loss.item():.4f} | Train CosSim: {train_cos_sim.item():.4f} | "
              f"Test CE: {test_loss.item():.4f} | Test CosSim: {test_cos_sim.item():.4f}")


In [27]:
# Parameters
D, M = 64, 10          # token embedding dimension, tokens per example
P = 8                  # query/key projection width (teacher and student)
N = 1000               # total number of examples
N_LAYERS = 2           # student attention layers
N_HEADS = 4            # heads per student layer
SKIP_C = 1.0           # skip-connection strength c
TEST_SIZE = 0.2
DATA_SEED = 0          # seed for the synthetic teacher and inputs
SPLIT_SEED = 42        # seed for the train/test split
MODEL_SEED = 0         # seed for the student's initialisation
N_EPOCHS = 200
LR = 1e-3

# Synthetic dataset generation
x_all, y_all, W_star = generate_synthetic_dataset(N, D, M, P, seed=DATA_SEED)
x_all, y_all = x_all.float(), y_all.float()

# Train / test split
x_train, x_test, y_train, y_test = train_test_split(
    x_all, y_all, test_size=TEST_SIZE, random_state=SPLIT_SEED
)
assert len(x_train) + len(x_test) == N

print("x_train:", x_train.shape)
print("y_train:", y_train.shape)

# Student model: N_LAYERS attention layers with N_HEADS heads each.
# Seed explicitly so that initialisation does not silently depend on how many
# random numbers the data generator consumed from the global RNG.
torch.manual_seed(MODEL_SEED)
model = DeepSMIMultiHead(D, P_list=[P] * N_LAYERS, M=M, H_list=[N_HEADS] * N_LAYERS, c=SKIP_C)

# Training
train_model(model, x_train, y_train, x_test, y_test, n_epochs=N_EPOCHS, lr=LR)


x_train: torch.Size([800, 64, 10])
y_train: torch.Size([800, 10, 10])


NameError: name 'cosine_attention_similarity' is not defined