## Assignment 2: 

First task for Assignment Part 2 is to code up the teacher model, and use it to generate some data sets of toy sentences for a few different values of the embedding dimensionality 𝐷(e.g., 𝐷=100 and 𝐷 = 1000).
All the hyperparameter settings are initially to be kept the same as in the Cui et al. paper (the same values were also mentioned in class). But try to write your code in a way where these settings can be easily changed later on, to enable future experiments with the same code.
Data generation part at least should be pretty straightforward. After that, the next step should be to get familiar with how to train simple transformer models in PyTorch, so that you can use that to train the student model on the generated data sets. We will discuss the student model further next week.

- Use bash for running code?

### References:
[Cui et al](https://web.iitd.ac.in/~sumeet/Cui_24.pdf)


In [1]:
import numpy as np
import matplotlib.pyplot as plt 

spinodal = {
    'alpha_cross': np.sort([0.5480511155124588,0.7071599112677189,0.9277860822926854,1.2410177405149914,1.06,1.4541109691229464,1.93929723421423,1.7505718608649372,0.8180015676221762]),
    'omegas':np.sort([0.,.1,.2,.3,.25,.35,.425,.4,.15]),
}

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from math import *
import pandas as pd
from pathlib import Path
import warnings
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colors
warnings.simplefilter("ignore", UserWarning)

# different colors for attention values
c_semantic='deeppink'
c_positional='royalblue'
c_att = 'rebeccapurple'

# different colors for the lienar student
c_lin = 'orange'

# colors for the phase transition borders
c_attlin = 'crimson' 
c_spinodal = 'forestgreen'

# color for changing specific parameters
c_no_col = 'black'


cmap_uninf = LinearSegmentedColormap.from_list('INF-UNINF',
                                                   [mcolors.to_rgba(c_semantic)[:3], (1, 1, 1), mcolors.to_rgba(c_positional)[:3]], N=100)
cmap_attlin = LinearSegmentedColormap.from_list('INF-UNINF',
                                                   [mcolors.to_rgba(c_lin)[:3], (1, 1, 1), mcolors.to_rgba(c_att)[:3]], N=100)
cmap_pos = LinearSegmentedColormap.from_list('INF-UNINF',
                                                   [mcolors.to_rgba('#8DE0A8')[:3],
                                                    mcolors.to_rgba('#93FFE0')[:3],
                                                    mcolors.to_rgba('#85EFFF')[:3],
                                                    mcolors.to_rgba('#6BBFFF')[:3],
                                                    mcolors.to_rgba(c_positional)[:3]], N=100)
cmap_att = LinearSegmentedColormap.from_list('INF-UNINF',
                                                   [mcolors.to_rgba('#FFE0B6')[:3],
                                                    mcolors.to_rgba('#FFB48C')[:3],
                                                    mcolors.to_rgba('#FF8166')[:3],
                                                    mcolors.to_rgba('#FF3E3B')[:3],
                                                    mcolors.to_rgba(c_semantic)[:3]], N=100)

In [3]:
import torch
import numpy as np

def generate_word_embeddings(L, D, seed=None):
    """
    Generate word embeddings x_l ~ N(0, 0.25 I_D)
    """
    if seed is not None:
        torch.manual_seed(seed)
    
    # Generate embeddings with std dev = sqrt(0.25) = 0.5
    embeddings = torch.normal(mean=0.0, std=0.5, size=(L, D))
    return embeddings

def positional_component(l, m):
    return 0.6 if l == m else 0.4

def compute_attention_weights(X, W_Q, omega, DK):
    """
    Compute attention weights of teacher model as a mixture of semantic and positional components
    """
    L, D = X.shape
    
    # Precompute W_Q^T @ W_Q
    WQ_TWQ = W_Q.T @ W_Q  # shape: (D, D)
    
    # Compute semantic attention scores
    semantic_scores = torch.zeros((L, L))
    for l in range(L):
        for m in range(L):
            semantic_scores[l, m] = (X[l] @ WQ_TWQ @ X[m]) / torch.sqrt(torch.tensor(DK))

    # Row-wise softmax
    softmax_semantic = torch.nn.functional.softmax(semantic_scores, dim=1)

    # Positional scores f(l, m)
    positional_scores = torch.tensor([[positional_component(l, m) for m in range(L)] for l in range(L)])

    # Final mixture
    attention_weights = (1 - omega) * softmax_semantic + omega * positional_scores

    return attention_weights

def generate_single_datapoint(D, L, DK, omega, W_Q, seed=None):
    """
    Generate a single datapoint (X, T)
    """
    X = generate_word_embeddings(L, D, seed)        # (L,D)
    A = compute_attention_weights(X, W_Q, omega, DK)# (L,L)
    T = torch.matmul(A, X) #(L,D)
    return X, T

def generate_teacher_dataset(N, D=100, L=2, DK=1, omega=0.3, seed=None):
    """
    Generate a teacher dataset of size N
    """
    if seed is not None:
        torch.manual_seed(seed)
    
    # Generate shared query weight matrix
    W_Q = torch.normal(mean=0.0, std=1.0, size=(DK,D))
    
    X_all = torch.zeros((N, L, D))
    T_all = torch.zeros((N, L, D))
    
    for n in range(N):
        # Use a different seed for each datapoint
        datapoint_seed = seed + n if seed is not None else None
        X, T = generate_single_datapoint(D, L, DK, omega, W_Q, datapoint_seed)
        X_all[n] = X
        T_all[n] = T
    
    return X_all, T_all, W_Q


# parameters
L = 2                # Number of words per sentence
D = 1000             # Dimensionality
N = int(2.2 * D)     # Dataset size
DK = 1               # Dimensionality of key/query vectors
omega = 0.3          # weight for positional component
seed = 42            

# Generate teacher dataset
X, T, W_Q_teacher = generate_teacher_dataset(
    N=N, D=D, L=L, DK=DK, omega=omega, seed=seed
)


print(f"X_teacher shape: {X.shape}") #(N,L,D)
print(f"T_teacher shape: {T.shape}") #(N,L,D)
print(f"W_Q_teacher shape: {W_Q_teacher.shape}") #(D_k,D)
W=  W_Q_teacher.clone()



X_teacher shape: torch.Size([2200, 2, 1000])
T_teacher shape: torch.Size([2200, 2, 1000])
W_Q_teacher shape: torch.Size([1, 1000])


In [6]:
''' Student Model '''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

class StudentAttentionModel(nn.Module):
    def __init__(self, D=100, DK=1, L=2, init_type='random', W_Q_teacher=None):
        super().__init__()
        self.D = D
        self.DK = DK
        self.L = L

        # Positional encoding: r1 = -r2 = 1_D
        r = torch.ones(D)
        self.register_buffer('R', torch.stack([r, -r], dim=0))  # (L, D)

        self.W_Q = nn.Parameter(torch.empty(DK, D))

        # initialization of W
        if init_type == 'semantic':
            self.W_Q = torch.nn.Parameter(W_Q_teacher.reshape(-1, 1))
        elif init_type == 'positional':
            self.W_Q = torch.nn.Parameter(torch.ones(D).reshape(-1, 1))
        else:
            self.W_Q = torch.nn.Parameter(0.1*torch.randn(D, DK))
    def forward(self, X):
        """
        X: (batch_size, L, D)
        Returns:
            Y: (batch_size, L, D)
        """
        # Add positional encodings
        X_pos = X + self.R.unsqueeze(0)  # (B, L, D)
        
        xQ = torch.einsum("imk,kl->iml", X_pos, self.W_Q/np.sqrt(self.D)) # why have they divided by D
        A = torch.nn.Softmax(dim=-1)(torch.einsum("iml,inl->imn", xQ, xQ))
        Y = torch.einsum("imn,inj->imj", A, X_pos)

        return Y


def loss_SSE(Y, T):
    loss = F.mse_loss(Y, T, reduction='sum') / (2 * X.shape[2])  # SSE
    return loss
# change num epochs

def train_student(X_train, T_train, X_test, T_test, lam=1e-2, lr=0.15, epochs=5,init_type='semantic',W_Q_teacher = None,DK=1,batch_size = D):
    '''
    Trains a student attention model with specified initialization. 
    lam : L2 regularisation parameter
    '''

    # Create a DataLoader for mini-batching
    train_dataset = TensorDataset(X_train, T_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    N, L, D = X_train.shape
    init_type= None
    model = StudentAttentionModel(
        D=D, DK=DK, L=L,
        init_type=init_type,
        W_Q_teacher=W_Q_teacher
    )

    optimizer = torch.optim.SGD([{'params': [model.W_Q],"weight_decay":lam }], lr=0.15)

    for epoch in range(epochs):
        for X_batch, T_batch in train_loader:
            optimizer.zero_grad()
            loss = loss_SSE(model(X_batch), T_batch)
            loss.backward()
            optimizer.step()
    gen_error = loss_SSE(model(X_test),T_test).item()/T_test.shape[0]
    W_Q_flat = model.W_Q.flatten()
    train_error = loss.item()+lam/2*float(torch.sum(W_Q_flat**2))
    r1 = torch.ones(D)
    m = np.abs(float(r1@ W_Q_flat /D))
    theta = np.abs(float(W_Q_teacher @ W_Q_flat/D))

    return model,gen_error,train_error,m,theta 


# TODO:
-   Do CV for lambda
-   train for both inits
-   Reproduce Fig. 2: (part a done)
-       alpha (0,2) adjust 
-       look at both training & test error (clear the loss vs MSE confusion)
-       alpha vs delta error
-       alpha vs summary stats (compute them with modular fns)
-       dense linear model
-   Reproduce Fig. 3:
-       summary stat concentration plot
-       color map of delta e wrt omega and alpha change
-       attention model - linear baseline wrt omega and alpha change


In [7]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy


def run_fig_2A(X, T, lam=1e-3, lr=0.15, epochs=500, DK=1,W_Q_teacher = None):
    """
    Runs the generalization experiment comparing semantic vs positional initialization
    across sample complexity α = N_train / D.
    
    Args:
        X, T: torch.Tensor of shape (N, L, D)
        lam: regularization strength
        lr: learning rate
        epochs: number of training epochs
        DK: key/query dimension

    Returns:
        results: list of dicts with alpha, semantic_loss, positional_loss, and delta
    """
    D = X.shape[2]
    L = X.shape[1]
    alphas = np.linspace(0.2, 2.0, 10)
    results = []

    # Fixed test set
    N_test = int(0.2 * D)
    X_test = X[-N_test:]
    T_test = T[-N_test:]

    for alpha in alphas:
        N_train = int(alpha * D)
        X_train = X[:N_train]
        T_train = T[:N_train]

        print(f"\n=== Alpha = {alpha:.2f} | N_train = {N_train} ===")

        # Semantic Init
        print("[Semantic Init]")
        model_sem, e_gen_sem,e_train_sem,m_sem,theta_sem = train_student(X_train, T_train, X_test, T_test, lam, lr, epochs,
                                  init_type='semantic', W_Q_teacher=W_Q_teacher, DK=DK)

        # Positional Init
        print("[Positional Init]")
        model_pos, e_gen_pos,e_train_pos,m_pos,theta_pos = train_student(X_train, T_train, X_test, T_test, lam, lr, epochs,
                                  init_type='positional', DK=DK,W_Q_teacher=W_Q_teacher)

        # Compute losses

        delta_gen = e_gen_pos - e_gen_sem

        delta_train = e_train_pos- e_train_sem

        print(f'sem {e_gen_sem} pos {e_gen_pos}')
        print(f'delta {delta_gen}')

        results.append({
            'alpha': alpha,
            'semantic_loss_gen':e_gen_sem ,
            'positional_loss_gen': e_gen_pos,
            'semantic_loss_train': e_train_sem,
            'positional_loss_train': e_train_pos,
            'delta_gen': delta_gen,
            'delta_train': delta_train,
        })

    return results

def plot_fig_2A(results):
    """
    Plots delta (positional - semantic) loss vs alpha.
    """
    alphas = [r['alpha'] for r in results]
    deltas_gen = [r['delta_gen'] for r in results]
    deltas_train = [r['delta_train'] for r in results]

    fig = plt.figure(figsize=(4, 3.8))
    plt.plot(alphas, deltas_gen, c='grey', label='Generalization')
    #plt.plot(alphas, deltas_train, c='black', linestyle='--', label='Training')
    plt.axhline(0, linestyle='--', color='grey')
    plt.xlabel(r'$\alpha$')
    plt.ylabel(r'$\Delta \epsilon_t$')
    plt.xlim(0.0, 2.0)
    plt.legend()
    plt.tight_layout()
    plt.show()

    plt.show()


def plot_fig_2B(results):
    """
    Plots summary statistics(theta,m) vs alpha.
    """
    alphas = [r['alpha'] for r in results]
    deltas_gen = [r['delta_gen'] for r in results]
    deltas_train = [r['delta_train'] for r in results]

   
    plt.xlabel(r'$\alpha$')
    plt.legend()
    plt.xlim(0.0,2.0)



#plt.plot(df_semantic[df_semantic["alpha"]>=1.24].alpha,df_semantic[df_semantic["alpha"]>=1.24].theta/sigma**2, color=c_semantic)
#plt.plot(df_positional[df_positional["alpha"]>=1.24].alpha,df_positional[df_positional["alpha"]>=1.24].theta/sigma**2, color=c_positional)
#plt.scatter(dfi[dfi["alpha"]>=1.24].alpha,dfi[dfi["alpha"]>=1.24].attention_theta_mean,label=r'$\theta/\sigma^2$', color=c_semantic)
#plt.scatter(dfu[dfu["alpha"]<=1.24].alpha,dfu[dfu["alpha"]<=1.24].attention_theta_mean, color=c_semantic)

#plt.plot(df_semantic[df_semantic["alpha"]<=1.24].alpha,df_semantic[df_semantic["alpha"]<=1.24].m/sigma**2,  color=c_semantic)
#plt.plot(df_positional[df_positional["alpha"]<=1.24].alpha,df_positional[df_positional["alpha"]<=1.24].m/sigma**2,color=c_positional)
#plt.scatter(dfu[dfu["alpha"]<=1.24].alpha,dfu[dfu["alpha"]<=1.24].attention_mag_mean/sigma**2, color=c_positional)
#plt.scatter(dfi[dfi["alpha"]>=1.24].alpha,dfi[dfi["alpha"]>=1.24].attention_mag_mean/sigma**2,label=r'$m/\sigma^2$', color=c_positional)

results = run_fig_2A(X, T,W_Q_teacher=W_Q_teacher)
plot_fig_2A(results)






=== Alpha = 0.20 | N_train = 200 ===
[Semantic Init]
[Positional Init]
sem 1.0815787506103516 pos 1.0815787506103516
delta 0.0

=== Alpha = 0.40 | N_train = 400 ===
[Semantic Init]


KeyboardInterrupt: 