In [1]:
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt

from utils.funs import count_outliers
from utils.symmetry_scores import get_scores

from transformers import AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the pickled results cache; the save cell further down writes to the same path.
dir = '../../../_data/fig-symmetry-language-models/TinyGPT-query-key.pkl'

# Reuse previously stored results when the cache file exists, otherwise start empty.
# pickle.load can execute arbitrary code, so only point this at a cache file you trust.
if not os.path.isfile(dir):
    models = {}
else:
    with open(dir, 'rb') as file:
        models = pickle.load(file)

In [4]:
# Attribute-path fragments handed to get_scores: a prefix plus the query- and key-projection
# suffixes. Presumably get_scores inserts the layer index between the prefix and each suffix --
# confirm in utils.symmetry_scores.
path = ["transformer.h[", "].attn.attention.q_proj.weight", "].attn.attention.k_proj.weight"]

# Head dimension applied to every checkpoint below; the head count h is derived as d // dh.
# NOTE(review): h is derived from this fixed dh rather than read from the model config --
# confirm it matches the checkpoints' actual number of attention heads.
dh = 64

# One row per checkpoint: (key stored in `models`, Hugging Face checkpoint id, layers l, width d).
tinygpt_checkpoints = [
    ('TinyGPT-1m', "roneneldan/TinyStories-1M", 8, 64),             # l = 8, d = 64,   h = 1  ; 1M parameters
    ('TinyGPT-3m', "roneneldan/TinyStories-3M", 8, 128),            # l = 8, d = 128,  h = 2  ; 3M parameters
    ('TinyGPT-8m', "roneneldan/TinyStories-8M", 8, 256),            # l = 8, d = 256,  h = 4  ; 8M parameters
    ('TinyGPT-21m', "roneneldan/TinyStories-1Layer-21M", 1, 1024),  # l = 1, d = 1024, h = 16 ; 21M parameters
    ('TinyGPT-28m', "roneneldan/TinyStories-28M", 8, 512),          # l = 8, d = 512,  h = 8  ; 28M parameters
]

# Load each checkpoint in turn, score it, and record [l, d, h, dh, scores] under its key.
# Rebinding `model` on every iteration drops the reference to the previous checkpoint.
for model_key, checkpoint, l, d in tinygpt_checkpoints:
    h = d // dh

    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    score_List = get_scores(d, l, h, dh, model, path)
    models[model_key] = [l, d, h, dh, score_List]

In [5]:
# Save the collected results to the cache path that the loading cell reads from.
# Create the parent directory first: open(..., 'wb') raises FileNotFoundError when the folder
# is missing, which would discard the scores computed above.
cache_parent = os.path.dirname(dir)
if cache_parent:
    os.makedirs(cache_parent, exist_ok=True)

with open(dir, 'wb') as file:
    pickle.dump(models, file)

In [None]:
from utils.visualization import (
    symmetry_score_boxplot,
    symmetry_score_scatter,
    symmetry_score_outliers,
)

# Pass the full results mapping to each of the three visualization helpers in turn.
symmetry_score_boxplot(models)
symmetry_score_scatter(models)
symmetry_score_outliers(models)