In [1]:
import torch as t
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
import numpy as np
import seaborn as sns
from train import train
from models import Transformer, Low_rank
from utils import generate_data, entropy, power_unif_law, generate_each, last_position_law, gen_d_law, almost_rank_d
from interp import back_track, attention_map, by_attention, every_attention, new_computation_basis

In [None]:
#training transformer

#model params
N = 100
d = 10
nb_layers = 1
width = 0
depth = 0
para = 5
nb_head = 1
n_gram = 3
context_window = n_gram

#distribution params
alphas = [1, 1, 1]
nb_tokens=[40, 40, 1]
t.manual_seed(2222)
pi = power_unif_law(alphas, nb_tokens, N)

#learning params
batch_size=2**10
num_batch=5000
lr=1e-3

low_batch_size=2**10
low_num_batch=4000
low_lr=1e-3

device='cpu'
cosim = t.nn.CosineSimilarity(dim=-1)

Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

model = Transformer(d, N, nb_layers, width, depth, para, nb_head, context_window, pi)
dict1 = train(model, Data, lr=lr, next_token=False)

plt.plot(dict1['Loss'], label=f'Full divergence')
print(sum(dict1['Loss'][-100:-1])/100)
print(sum(dict1['Acc'][-100:-1])/100)

model = Transformer(d, N, nb_layers, width, depth, para, nb_head, context_window, pi)
dict2 = train(model, Data, lr=lr, next_token=True)

plt.plot(dict2['Loss'], label=f'Next token')
print(sum(dict2['Loss'][-100:-1])/100)
print(sum(dict2['Acc'][-100:-1])/100)


#upper bound
ent=entropy(pi)
plt.plot([np.log(N)-ent for _ in dict2['Loss']], label='Random baseline', color='red')

#lower bound
model_low = Low_rank(d, N, context_window, pi)
Data = generate_data(low_batch_size, low_num_batch, pi, context_window)
dict_low = train(model_low, Data, lr=low_lr)
best_loss = sum(dict_low['Loss'][-100:-1])/100
plt.plot([best_loss for _ in dict2['Loss']], label='Optimal baseline', color='green')
plt.legend()
plt.xlabel("Batch number")
plt.ylabel("Divergence")
plt.title("Transformer's learning dynamics")
plt.show()

plt.plot(dict2['Acc'], label=f'Next token')
plt.plot(dict1['Acc'], label=f'Full divergence')
plt.plot([1/N for _ in dict2['Acc']], color='black', label='Random baseline')
plt.legend()
plt.xlabel("Batch number")
plt.ylabel("Accuracy")
plt.ylim(top=1+0.1, bottom=0-0.1)
plt.title("Transformer's learning dynamics")
plt.show()


In [None]:
"""Freezing different components in the training process."""

#Model parameters
N = 10
d = 5
nb_layers = 1
width = 0
depth = 0
para = 20
nb_head = 1
n_gram = 3
context_window = n_gram

#Distribution parameters
alphas = [1, 1, 1]
nb_tokens=[N, N, 1]
t.manual_seed(666)
pi = power_unif_law(alphas, nb_tokens, N)

#Learning parameters
batch_size=2**10
num_batch=5000
lr=1e-3

Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

Freezer = [
    (True, True, False, True), 
    (False, False, True, False),
]

for freeze_E, freeze_QKV, freeze_O, freeze_U in Freezer:
    freezer = {
        'freeze_E': freeze_E,
        'freeze_QKV': freeze_QKV,
        'freeze_O': freeze_O,
        'freeze_U': freeze_U,
    }
    model = Transformer(d, N, nb_layers, width, depth, para, nb_head, context_window, pi)
    model.freeze(freezer)
    dict = train(model, Data, lr=lr)
    print(freezer)
    print(sum(dict['Loss'][-100:-1])/100)

In [None]:
"""Measuring different scaling laws."""

loss=[]
alpha_1_list=[]
alpha_2_list=[]
N_list=[]
d_list=[]
width_list=[]
nb_layers_list=[]
nb_head_list=[]
para_list=[]
unif_loss=[]
best_loss=[]

context_window = 3
batch_size = 2**9
num_batch = 3000
seed=2222

count = 0
max_count = 99//8 #Counting the number of iteration to keep track of expected time.

for N in [100]:
    for alpha_1, alpha_2 in product([0.9],[0.6]):
        t.manual_seed(seed)
        alphas = [alpha_1, alpha_1, alpha_2]
        nb_tokens = [N, N, N]
        pi = power_unif_law(alphas, nb_tokens, N)
        Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)
        ent = entropy(pi).item()
        for d in [10]:
            model_low = Low_rank(N, d, n_gram, context_window, pi)
            dict_low = train(model_low, Data, lr=1e-3)
            best = sum(dict_low['Loss'][-100:-1])/100
            for width in [0]:
                for depth in [1]:
                    for nb_head in [1]:
                        for nb_layers in [1]:
                            for parallel_heads in [5]:
                                N_list.append(N)
                                d_list.append(d)
                                width_list.append(width)
                                nb_head_list.append(nb_head)
                                nb_layers_list.append(nb_layers)
                                para_list.append(parallel_heads)
                                alpha_1_list.append(alpha_1)
                                alpha_2_list.append(alpha_2)
                                unif_loss.append(-ent+np.log(N))
                                best_loss.append(best)

                                model = Transformer(d, N, nb_layers, width, depth, para, nb_head, context_window, pi)
                                dict = train(model, Data, lr=1e-3)
                                loss.append(sum(dict['Loss'][-100:-1])/100)

                                count+=1
                                print(count/max_count)

dict={
    'alpha_1': alpha_1_list,
    'alpha_2': alpha_2_list,
    'N': N_list,
    'd': d_list,
    'width': width_list,
    'nb_layers': nb_layers_list,
    'nb_head': nb_head_list,
    'para_head': para_list,
    'loss': loss,
    'unif_loss': unif_loss,
    'best_loss': best_loss,
}

data = pd.DataFrame(dict)
data.to_csv('scaling_csv_v3/.csv', index=False)

In [None]:
"""Measuring the effect of different architectural modifications."""

#Model parameters
N = 100
d = 10
h = 0
depth = 1
nb_layers = 1
nb_head = 1
para = 10
n_gram = 3
max_seq_len = n_gram
assert max_seq_len >= n_gram
assert n_gram == 3

#Distribution parameters
alphas = [1, 1, 1]
nb_tokens=[40, 40, 1]
t.manual_seed(666)
pi = power_unif_law(alphas, nb_tokens, N)

#Learning parameters
batch_size=2**10
num_batch=5000
lr=1e-3
seed=333

Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

for skip_res_connection in [False, True]:
    for skip_pos_QK in [False]:
        for skip_emb_QK in [False]:
            for skip_pos_OV in [False]:
                for skip_emb_OV in [False]:
                    skips = {
                                'skip_res_connection': skip_res_connection,
                                'skip_pos_QK': skip_pos_QK,
                                'skip_emb_QK': skip_emb_QK,
                                'skip_pos_OV': skip_pos_OV,
                                'skip_emb_OV': skip_emb_OV,
                            }
                    t.manual_seed(seed)
                    model = Transformer(d, N, nb_layers, width, depth, para, nb_head, context_window, pi)
                    if not((skip_pos_OV and skip_emb_OV) or (skip_pos_QK and skip_emb_QK)):
                        dict = train(model, Data, lr=lr)
                        print(skips)
                        print(sum(dict['Loss'][-100:-1])/100)

In [None]:
"""Looking at the shape of the learned attention pattern."""

#Model parameters
N = 100
d = 10
nb_layers = 1
width = 0
depth = 1
para = 3
nb_head = 1
n_gram = 3
context_window = n_gram

#Learning parameters
batch_size=2**10
num_batch=5000
lr=1e-3

#Distribution parameters
alphas = [1, 1, 1]
nb_tokens=[10, 10, 1]
t.manual_seed(666)
pi = power_unif_law(alphas, nb_tokens, N)

skips = {
    'skip_res_connection': True, #We skip the residual connection.
    'skip_pos_QK': False,
    'skip_emb_QK': False,
    'skip_pos_OV': False,
    'skip_emb_OV': False,
}

model = Transformer(d, N, nb_layers, width, depth, para, nb_head, context_window, pi, skips=skips)
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)
dict = train(model, Data, lr=lr)
print(sum(dict['Loss'][-100:-1])/100)

for i in range(para):
    map = attention_map(model, 0, i).detach()
    sns.heatmap(map, vmin=0, vmax=1)
    plt.xlabel('Position 1')
    plt.ylabel('Position 2')
    plt.title(f"Probability of head {i} given to the first token for an input pair.")
    plt.show()

    S = t.linalg.svdvals(map)
    print(S)

In [None]:
"""Measuring each head's contribution to the next token prediction.""" #TODO: clean

#Model parameters
N = 10
d = 5
nb_layers = 1
width = 0
depth = 0
para = 2
nb_head = 1
n_gram = 3
context_window = n_gram

#Learning parameters
batch_size=2**9
num_batch=4000
lr=1e-3

#Distribution parameters
alphas = [1, 1, 1]
nb_tokens=[10, 5, 1]
seed = 2222
t.manual_seed(seed)
pi = power_unif_law(alphas, nb_tokens, N)

skips = {
    'skip_res_connection': True, #We skip the residual connection.
    'skip_pos_QK': False,
    'skip_emb_QK': False,
    'skip_pos_OV': False,
    'skip_emb_OV': False,
}

model = Transformer(d, N, nb_layers, width, depth, para, nb_head, context_window, pi, skips=skips)
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)
dict = train(model, Data, lr=lr, seed=seed)
print(sum(dict['Loss'][-10:-1])/10)

examples = generate_each(pi)
contribution = back_track(model, examples)
new_contribution = new_computation_basis(model, examples)

ex = 10
print(new_contribution[f'para_{0}_layer_{0}'][ex, 2], t.norm(new_contribution[f'para_{0}_layer_{0}'][ex, 2]))
print(new_contribution[f'para_{1}_layer_{0}'][ex, 2], t.norm(new_contribution[f'para_{1}_layer_{0}'][ex, 2]))
print(new_contribution[f'para_{1}_layer_{0}'][ex, 2]+new_contribution[f'para_{0}_layer_{0}'][ex, 2])

W_U = model.unemb.weight.detach()
map = t.einsum('Nd, nd -> Nn', W_U, W_U)
sns.heatmap(map, center=0, cmap='bwr')
plt.title(r'$W_U^TW_U$')
plt.show()

map = map*(new_contribution[f'para_{1}_layer_{0}'][ex, 2]+new_contribution[f'para_{0}_layer_{0}'][ex, 2]).unsqueeze(0)
sns.heatmap(map, center=0, cmap='bwr')
plt.title(r'$W_U^TW_UX$')
plt.show()

every_attention(contribution, examples[ex].unsqueeze(0), 0)

_, computations = model.forward(examples, out_computation=True)
for i in range(N):
    #print(computations['logits'][examples[:, 2] == 0].mean(0)[2] - (computations['logits'][examples[:, 2] == 0].mean(0)[2]).mean())
    A = W_U@computations[f'para_{0}_layer_{0}'][examples[:, 2] == i].mean(0)[2] - (W_U@computations[f'para_{0}_layer_{0}'][examples[:, 2] == i].mean(0)[2]).mean()
    B = W_U@computations[f'para_{1}_layer_{0}'][examples[:, 2] == i].mean(0)[2] - (W_U@computations[f'para_{1}_layer_{0}'][examples[:, 2] == i].mean(0)[2]).mean()
    #print(A)
    #print(B)
    print(t.logical_and(t.logical_and(A*B > 0, A > 0), B > 0))

for i in range(para):
    by_attention(contribution, examples, 0, i, sort=False, method='solo')

by_attention(contribution, examples, 0, i, sort=False, method="group")

map = attention_map(model, 0, 0).detach()
sns.heatmap(map, vmin=0, vmax=1)
plt.xlabel('Position 1')
plt.ylabel('Position 2')
plt.title(f"Probability of head {0} given to the first token for an input pair.")
plt.show()

S = t.linalg.svdvals(map)
print(S)

map = attention_map(model, 0, 1).detach()
sns.heatmap(map, vmin=0, vmax=1)
plt.xlabel('Position 1')
plt.ylabel('Position 2')
plt.title(f"Probability of head {1} given to the first token for an input pair.")
plt.show()

S = t.linalg.svdvals(map)
print(S)

In [None]:
"""Show the Superposition phenomenon one low entropy distribution.""" #TODO: clean

#Model parameters
N = 10
d = 5
context_window = 3

#Learning parameters
batch_size=2**9
num_batch=5000
lr=1e-3

t.manual_seed(2108) #2108 and 1
nb_tokens = [N, N]
pi = last_position_law([N, N], N, 0, 0.95)

model = Low_rank(d, N, context_window, pi)
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)
dict = train(model, Data, lr=lr)
print(sum(dict['Loss'][-100:-1])/100)


W_U = model.unemb.weight.detach()
W_U = W_U#/t.norm(W_U, dim=-1, keepdim=True)
map = t.einsum('Nd, nd -> Nn', W_U, W_U)
sns.heatmap(map, center=0, cmap='bwr')
plt.title(r'$W_U^TW_U$ for quadratic importance')
plt.show()

W_E = model.word_emb.weight.detach()
w_e = []
examples = generate_each(pi, eps=0.1/(N**2))
indices = []
for ex in examples:
    indices.append(ex[0]+ex[1]*N)
indices = t.Tensor(indices).to(t.int)
W_E = W_E[indices]#/t.norm(W_E[indices], dim=-1, keepdim=True)
map = W_E@W_E.mH
sns.heatmap(map, center=0, cmap='bwr')
plt.title(r'$W_E^TW_E$')
plt.show()

#Prints the cosimilarity between the w_e(z) and the expected right w_u(g(z))
print(((W_U/t.norm(W_U, dim=-1, keepdim=True))@(W_E/t.norm(W_E, dim=-1, keepdim=True)).mH)[examples[t.arange(N**2), 2], t.arange(N**2)])

model = Transformer(d, N, nb_layers=1, width=0, depth=0, parallel_heads=5, nb_head=1, context_window=context_window, pi=pi)
#model.unemb.weight = t.nn.Parameter(W_U, requires_grad=False)
dict = train(model, Data, lr=lr)
print(sum(dict['Loss'][-100:-1])/100)

W_U = model.unemb.weight.detach()
W_U = W_U#/t.norm(W_U, dim=-1, keepdim=True)
map = t.einsum('Nd, nd -> Nn', W_U, W_U)
sns.heatmap(map, center=0, cmap='bwr')
plt.title(r'$W_U^TW_U$ for quadratic importance')
plt.show()

with t.no_grad():
    _, computations = model.forward(examples, out_computation=True)
W_E = computations[f'res_after_mlp_layer_{0}'].detach()[:, 2, :]
W_E = W_E[indices]#/t.norm(W_E[indices], dim=-1, keepdim=True)
map = W_E@W_E.mH
sns.heatmap(map, center=0, cmap='bwr')
plt.title(r'$A^TA$')
plt.show()

#Prints the cosimilarity between the w_e(z) and the expected right w_u(g(z))
print(((W_U/t.norm(W_U, dim=-1, keepdim=True))@(W_E/t.norm(W_E, dim=-1, keepdim=True)).mH)[examples[t.arange(N**2), 2], t.arange(N**2)])


#Prints the cosimilarity between all pairs W_E + POS
W_E = model.word_emb.weight.detach()
POS = model.pos_emb.weight.detach()[:2]
W_E = (W_E.unsqueeze(0) + POS.unsqueeze(1)).flatten(0, 1)
map = W_E@W_E.mH
sns.heatmap(map, center=0, cmap='bwr')
plt.title(r'$(W_E+POS)^T(W_E+POS)$')
plt.show()

In [34]:
"""Show that near """ #TODO: clean

N = 10
d = 5
context_window = 3

#learning params
batch_size=2**9
num_batch=5000
lr=1e-3

#distribution params
axis_aligned = []
non_axis_aligned = []
random = []
nb_tokens=[N, N]
num_rep = 3
t.manual_seed(666)
for dim in range(d, N+1):
    mean = 0.
    for _ in range(num_rep):
        pi, W = gen_d_law(nb_tokens, N, dim, axis_aligned=True)
        U, S, V = t.linalg.svd(W)
        W = (U[:, :d]*S[:d])@V[:d, :d]

        model = Low_rank(d, N, context_window, pi)
        Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)
        dict = train(model, Data, lr=lr, seed=(_+1)*(dim+1))

        W_U = model.unemb.weight.detach()
        mean += t.norm(W@t.linalg.inv(W.mH@W)@W.mH-W_U@t.linalg.inv(W_U.mH@W_U)@W_U.mH).item()
    axis_aligned.append(mean/num_rep)

    mean = 0.
    for _ in range(num_rep):
        pi, W = gen_d_law(nb_tokens, N, dim, axis_aligned=False)
        U, S, V = t.linalg.svd(W)
        W = (U[:, :d]*S[:d])@V[:d, :d]

        model = Low_rank(d, N, context_window, pi)
        Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)
        dict = train(model, Data, lr=lr, seed=(_+1)*(dim+1))

        W_U = model.unemb.weight.detach()
        mean += t.norm(W@t.linalg.inv(W.mH@W)@W.mH-W_U@t.linalg.inv(W_U.mH@W_U)@W_U.mH).item()
    non_axis_aligned.append(mean/num_rep)

    mean = 0.
    for _ in range(30):
        Q = t.randn_like(W)
        mean += t.norm(Q@t.linalg.inv(Q.mH@Q)@Q.mH-W@t.linalg.inv(W.mH@W)@W.mH).item()
    random.append(mean/30)

X = [dim for dim in range(d, N+1)]
plt.plot(X, axis_aligned, label='axis aligned')
plt.plot(X, non_axis_aligned, label='axis unaligned')
plt.plot(X, random, label='random baseline')
plt.legend()
plt.xlabel('Embedding dimension')
plt.ylabel('L2 matrix distance')
plt.title('Distance between W_U and low rank log-prob')
plt.show()

100%|██████████| 5000/5000 [00:06<00:00, 724.16it/s]
100%|██████████| 5000/5000 [00:07<00:00, 704.32it/s]
100%|██████████| 5000/5000 [00:06<00:00, 714.74it/s]
 33%|███▎      | 1659/5000 [00:02<00:04, 693.06it/s]


KeyboardInterrupt: 

In [5]:
#d direction but nice #TODO: clean
N = 10
d = 5
context_window = 3

#learning params
batch_size=2**10
num_batch=5000
lr=5e-4

#distribution params
axis_aligned = []
non_axis_aligned = []
random = []
nb_tokens=[N, N]
t.manual_seed(66)

pi = almost_rank_d(nb_tokens, N, d+1, axis_aligned=False)

model = Low_rank(d, N, context_window, pi)
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)
dict = train(model, Data, lr=lr)
print(sum(dict['Loss'][-100:-1])/100)

W_E = model.word_emb.weight.detach()
W_U = model.unemb.weight.detach()

f1 = W_E@W_U.mH #real
L = t.log(pi[2].flatten(0, 1))-t.log(pi[2].flatten(0,1)).mean(-1, keepdim=True)
PI = pi[2].flatten(0, 1)
Z = f1 - L - ((f1-L)*PI).sum(-1, keepdim=True)
print(t.log((t.exp(Z)*PI).sum(-1)).mean())

EL = (L*PI).sum(-1, keepdim=True)
COV = W_U.mH@((L*PI-EL*PI).mH)
VAR = t.einsum('Nd, ND, MN -> MdD', W_U, W_U, PI) - t.einsum('Nd, nD, MN, Mn -> MdD', W_U, W_U, PI, PI)
w_e = t.einsum('MdD, dM -> MD', t.linalg.inv(VAR), COV)
cosim = t.nn.CosineSimilarity(dim=-1)
#print(cosim(w_e, W_E))

f2 = w_e@W_U.mH #th
Z = f2 - L - ((f2-L)*PI).sum(-1, keepdim=True)
print(t.log((t.exp(Z)*PI).sum(-1)).mean())

U, S, V = t.linalg.svd(t.log(pi[2].flatten(0,1))-t.log(pi[2].flatten(0,1)).mean(-1, keepdim=True))
f = t.einsum('Md, d, dN -> MN', U[:, :d], S[:d], V[:d]) #least-square
Z = f - L - ((f-L)*PI).sum(-1, keepdim=True)
print(t.log((t.exp(Z)*PI).sum(-1)).mean())

100%|██████████| 5000/5000 [00:11<00:00, 421.38it/s]

0.026685121059417723
tensor([ 1.0415,  1.9363,  3.1598,  5.0069,  1.9707,  1.5509, 12.1780,  2.3745,
         5.9765,  5.2128,  3.2085,  1.0149,  1.8449, 11.2013, 13.2905,  3.2395,
         7.7813,  3.9956,  1.3935, 13.2777,  2.3417,  1.5386,  1.0061,  3.3038,
         7.0918,  4.2550,  5.5623,  4.0351,  1.6059,  3.8768,  3.6817, 14.5698,
         4.6992,  1.0136,  3.2237,  1.2154,  3.1014,  2.1661,  1.7571, 17.6890,
         2.9061, 17.2756,  3.9689,  3.5767,  1.0110,  1.8904,  5.1125,  4.2290,
         2.6369,  3.5960,  1.4123,  2.9827,  7.4260,  1.2031,  1.8818,  1.0039,
         1.2683,  6.4286,  6.0821,  2.9564, 18.0613,  6.1706,  3.1082,  3.6319,
         8.0174,  1.2654,  1.0123,  2.1825,  2.7111,  4.1964,  2.7695,  4.0028,
         6.4231,  2.2761,  3.9487,  8.7244,  2.2082,  1.0183,  4.2089,  3.5555,
        17.7348,  1.4154,  1.5374,  2.2965,  4.1503,  2.8886,  2.9644,  3.0730,
         1.1668,  1.5287,  6.6125, 10.9166,  5.9564, 10.5734,  2.3835,  2.6210,
         3.6969,  4


