In [None]:
## import stuff

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils.prng_data as prngs_data
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import colors as mcolors
from mpl_toolkits.axes_grid1 import make_axes_locatable
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from matplotlib.colors import LinearSegmentedColormap

from utils.gpt2 import GPT, GPTConfig, MLP, CausalSelfAttention
from utils.prng_data import lcg_vectorized, find_as, find_coprimes
from utils.datasets import PRNGsDataset

# usual imports
import pickle as pl
import pandas as pd
import argparse
import os
import random
import copy

from tqdm.auto import tqdm

# Global plotting defaults shared by the figure cells of this notebook.
# NOTE(review): sns.set_theme() applies seaborn's own rc settings after the line
# below and may therefore overwrite this font size -- TODO confirm whether the
# two calls should be swapped.
plt.rcParams["font.size"] = 20
sns.set_theme(style="whitegrid")
dpi = 300          # default resolution for saved figures
cmap = 'coolwarm'  # default colormap name

# Internal precision of float32 matrix multiplications, see
# https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
#   "highest": float32 matmuls use the float32 datatype (24 mantissa bits, 23 explicitly stored).
#   "high":    float32 matmuls may use TensorFloat32 (10 explicit mantissa bits) or treat each
#              float32 number as the sum of two bfloat16 numbers.
torch.set_float32_matmul_precision('high')

# Prefer bfloat16 when a CUDA device supports it; fp16 needs a further change of the code.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float32

## Toggle to true if you want to use GPU
USE_GPU = False

# Start from the CPU and only upgrade when a GPU was requested AND is present.
device = 'cpu'
if USE_GPU:
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'

In [None]:
def parse_args():
    """Build the hyperparameter namespace used throughout the notebook.

    The parser is fed a fixed, hard-coded argument list (not ``sys.argv``), so
    the function can be called from inside a notebook kernel.

    Returns:
        argparse.Namespace: one attribute per option registered below.
    """
    # (option name, type, default) -- registered in this exact order.
    arg_specs = [
        ('main_seed', int, 1),               # main seed for the experiments
        ### Dataset hyperparams
        ('p_eval', int, 2048),               # p for mod p
        ('num_as', int, 16),                 # number of as
        ('num_cs', int, 16),                 # number of cs
        ('num_examples_per_prng', int, 1),   # number of examples per PRNG
        ('total_examples', int, 1000_000),   # total number of examples
        ('context_len', int, 256),           # context length
        ('chunk_size', int, 32),             # chunk size
        ('period_min', int, 0),              # min period of training
        ('period_max', int, 512),            # max period of training
        ### Model hyperparams
        ('n_layer', int, 1),                 # number of layers
        ('n_head', int, 1),                  # number of heads
        ('n_embd', int, 768),                # embedding dimension
        ('head_dim', int, 768),              # head dimension
        ('act_name', str, 'relu'),           # activation
        ### Optimization hyperparams
        ('num_steps', int, 100_000),         # number of training steps
        ('warmup_steps', int, 2048),         # number of warmup steps
        ('lr_trgt', float, 3e-4),            # the target learning rate
        ('lr_init', float, 1e-6),            # initial learning rate
        ('lr_min', float, 1e-6),             # final learning rate
        ('batch_size', int, 256),            # batch size
        # adamw hyperparams
        ('weight_decay', float, 1.0),        # weight decay
        ('beta1', float, 0.9),               # beta1
        ('beta2', float, 0.99),              # beta2
        ### Evaluation hyperparams
        ('results_dir', str, './results'),
        ('plots_dir', str, './plots'),
        # Other
        ('shifts', int, 0),                  # position of 1 to p_eval numbers in the sequence
    ]

    parser = argparse.ArgumentParser(description = 'Hyperparameters')
    for option, option_type, option_default in arg_specs:
        parser.add_argument(f'--{option}', type = option_type, default = option_default)

    # Explicit overrides identifying the checkpoint this notebook analyses.
    notebook_argv = [
        "--act_name=relu", "--context_len=256", "--batch_size=256", "--n_layer=1",
        "--p_eval=2048", "--total_examples=1000000", "--n_embd=768",
        "--n_head=1", "--head_dim=768", "--warmup_steps=2048", "--num_steps=100000",
        "--num_examples_per_prng=1", "--lr_trgt=3e-04", "--weight_decay=1.0",
    ]
    return parser.parse_args(notebook_argv)

# Materialise the hyperparameters; the vocabulary has one token per residue mod p_eval.
config = parse_args()
config.vocab_size = config.p_eval

# # if I am not wrong, this seed only takes care of torch and not numpy
# np.random.seed(config.main_seed)
# torch.manual_seed(config.main_seed)
# torch.cuda.manual_seed(config.main_seed)

# Color: stitch equally sized slices of several sequential colormaps into one
# ListedColormap with at least p_eval entries.
base_cmaps = ['Greys', 'Purples', 'Reds', 'Oranges', 'Blues', 'Greens']
n_base = len(base_cmaps)

# number of colors to extract from each of the base_cmaps above
N = (config.p_eval // n_base) + 1

# we go from 0.2 to 0.8 below to avoid having several whites and blacks in the resulting cmaps
colors = np.concatenate(
    [plt.get_cmap(name)(np.linspace(0.2, 0.8, N)) for name in base_cmaps],
    axis=0,
)
custom_cmap = mcolors.ListedColormap(colors)

In [None]:
## Useful functions

def create_model(config):
    """Instantiate a GPT from the notebook hyperparameters and move it to `device`.

    Args:
        config: namespace providing ``context_len``, ``n_embd``, ``n_head``,
            ``vocab_size``, ``n_layer`` and ``act_name``.

    Returns:
        The constructed ``GPT`` instance, placed on the module-level ``device``.
    """
    # NOTE(review): config.head_dim is not forwarded to GPTConfig here.
    gptconfig = GPTConfig(
        block_size=config.context_len,
        vocab_size=config.vocab_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        act_name=config.act_name,
    )
    model = GPT(gptconfig)
    model.to(device)
    return model


def find_prng_parameters_test(config):
    """Draw up to 64 test multipliers `a` and 64 test increments `c` for modulus ``config.p_eval``.

    Candidates come from ``prngs_data.find_as`` and ``prngs_data.find_coprimes``
    (Hull–Dobell Theorem: https://en.wikipedia.org/wiki/Linear_congruential_generator).
    The draw uses NumPy's global RNG, so seed it beforehand for reproducibility.

    Returns:
        (val_as, val_cs): two NumPy arrays of sampled values.
    """
    modulus = config.p_eval
    candidate_as = prngs_data.find_as(modulus)
    candidate_cs = prngs_data.find_coprimes(modulus)

    n_as = min(64, len(candidate_as))
    n_cs = min(64, len(candidate_cs))

    # np.random.choice samples WITH replacement by default, so repeats are possible.
    val_as = np.random.choice(candidate_as, n_as)
    val_cs = np.random.choice(candidate_cs, n_cs)

    return val_as, val_cs


@torch.inference_mode()
def lcg_vectorized_with_fixed_seed(p: int = 512, length: int = 8, a_list: list = [45], c_list: list = [123]) -> torch.Tensor:
    """
    Vectorized version of lcg function with fixed (0th) seed.
    It supports multiple 'a' and 'c' values.

    Every sequence follows the recurrence x_{t+1} = (a * x_t + c) mod p and
    starts from the single seed 0.

    Args:
        p: modulus of the generator.
        length: number of terms per sequence (the seed counts as the first term).
        a_list: multipliers; every (a, c) combination with ``c_list`` is generated.
        c_list: increments.

    Returns:
        int64 tensor reshaped to (len(a_list) * len(c_list), -1, length);
        presumably the middle dimension is 1 because only one seed is used.
    """
    ## Create mesh grid and flatten
    # Cartesian product of a_list x c_list, flattened to two parallel 1-D tensors.
    a_mesh, c_mesh = np.meshgrid(a_list, c_list)
    a_flat = torch.tensor(a_mesh.flatten(), dtype=torch.int64)
    c_flat = torch.tensor(c_mesh.flatten(), dtype=torch.int64)

    ## Generate initial seed(s)
    # The [:1] slice keeps only the first entry of arange(p), i.e. the seed 0.
    initial_seeds = torch.arange(p, dtype=torch.int64)[:1]

    # Unrolls one LCG sequence for a single (a, c, seed) triple; batching is added by vmap below.
    def single_lcg(a, c, seed):
        # NOTE(review): torch.compile is applied to a closure that is then traced under
        # torch.vmap -- confirm this combination is supported by the installed torch version.
        @torch.compile
        def next_value(prev):
            return (a*prev + c) % p
        
        # The seed is the first term, so only length - 1 further steps are needed.
        sequence = [seed]
        for _ in range(length - 1):
            sequence.append(next_value(sequence[-1]))
        
        return torch.stack(sequence)

    ## Vectorize over a, c, and initial seeds
    # Inner vmap maps over seeds (a, c shared); outer vmap maps over (a, c) pairs jointly
    # (seeds shared), processing 16 pairs per chunk.
    results = torch.vmap(torch.vmap(single_lcg, in_dims=(None, None, 0)), in_dims=(0, 0, None), chunk_size=16)(a_flat, c_flat, initial_seeds)
    
    ## Reshape to combine all sequences
    return results.reshape(a_flat.size(0), -1, length)

"""PCA"""
@torch.inference_mode()
def pca_1d(embedding: torch.Tensor, components: tuple) -> torch.Tensor:
    """Compute selected principal directions of an embedding matrix.

    Prints the explained-variance ratio of each selected component.

    Args:
        embedding: 2-D tensor whose rows are samples and whose columns are
            features (the caller passes the (vocab, n_embd) token embedding).
        components: indices of the principal components to extract.

    Returns:
        Tensor of shape (len(components), n_features): the selected
        right-singular vectors, each with the per-feature mean added.
    """
    embedding = embedding.T  # (n_embd, vocab)
    mean = embedding.mean(dim=1, keepdim=True)  # (n_embd, 1)
    centered_data = embedding - mean

    U, S, Vt = torch.linalg.svd(centered_data.T, full_matrices=False)

    # Resolve the SVD sign ambiguity BEFORE picking rows: advanced indexing
    # copies, so flipping afterwards would not reach the selected components.
    U, Vt = svd_flip(U, Vt)

    # Select the specified components
    component_idx = list(components)
    selected_components = Vt[component_idx, :]  # (len(components), n_embd)

    # Explained variance is proportional to the SQUARED singular values.
    explained_variance = S ** 2
    total_variance = torch.sum(explained_variance)
    selected_variance = explained_variance[component_idx]
    variance_ratio = (selected_variance / total_variance)

    print(f'Selected components {components} explain {variance_ratio} of the total variance')

    # NOTE(review): the feature mean is added to the direction vectors, so
    # projecting onto the result also includes a projection onto the mean --
    # kept as in the original; confirm this is intended.
    return (selected_components + mean.T)


@torch.inference_mode()
def pca_2d(embedding: torch.Tensor, components: tuple = (0, 1)) -> torch.Tensor:
    """Compute a pair (or any tuple) of principal directions of an embedding matrix.

    Prints the combined explained-variance ratio of the selected components.

    Args:
        embedding: 2-D tensor whose rows are samples and whose columns are
            features (the caller passes the (vocab, n_embd) token embedding).
        components: indices of the principal components to extract.

    Returns:
        Tensor of shape (len(components), n_features): the selected
        right-singular vectors, each with the per-feature mean added.
    """
    embedding = embedding.T  # (n_embd, vocab)
    mean = embedding.mean(dim=1, keepdim=True)  # (n_embd, 1)
    centered_data = embedding - mean

    U, S, Vt = torch.linalg.svd(centered_data.T, full_matrices=False)

    # Resolve the SVD sign ambiguity BEFORE picking rows: advanced indexing
    # copies, so flipping afterwards would not reach the selected components.
    U, Vt = svd_flip(U, Vt)

    # Select the specified components
    component_idx = list(components)
    selected_components = Vt[component_idx, :]  # (len(components), n_embd)

    # Explained variance is proportional to the SQUARED singular values.
    explained_variance = S ** 2
    total_variance = torch.sum(explained_variance)
    selected_variance = torch.sum(explained_variance[component_idx])
    variance_ratio = (selected_variance / total_variance).item()

    print(f'Selected components {components} explain {variance_ratio:.4f} of the total variance')

    # NOTE(review): the feature mean is added to the direction vectors, so
    # projecting onto the result also includes a projection onto the mean --
    # kept as in the original; confirm this is intended.
    return (selected_components + mean.T)


@torch.inference_mode()
def svd_flip(u, v):
    """Make SVD signs deterministic; modifies `u` and `v` in place.

    For every column of `u`, the entry with the largest magnitude is made
    positive; the matching row of `v` is flipped by the same sign so the
    product ``u @ diag(s) @ v`` is unchanged.

    Args:
        u: left singular vectors, one per column.
        v: right singular vectors, one per row.

    Returns:
        The same (mutated) `u` and `v` objects.
    """
    n_cols = u.shape[1]
    col_index = torch.arange(n_cols, device=u.device)

    # Row position of the dominant entry in each column of u.
    dominant_rows = u.abs().argmax(dim=0)
    signs = u[dominant_rows, col_index].sign()

    u.mul_(signs)               # scale columns of u
    v.mul_(signs.unsqueeze(1))  # scale rows of v
    return u, v

In [None]:
## create model and load checkpoints
model = create_model(config)

# The checkpoint file name encodes the full training configuration.
ckpt_path = (
    f'{config.results_dir}/chkpt_p{config.p_eval}_Tn{config.context_len}'
    f'_N{config.total_examples}_ne{config.num_examples_per_prng}'
    f'_n{config.n_embd}_h{config.n_head}_d{config.n_layer}'
    f'_I{config.main_seed}_lr{config.lr_trgt:0.6f}'
    f'_Tw{config.warmup_steps}_T{config.num_steps}'
    f'_B{config.batch_size}_wd{config.weight_decay}.pth'
)

# NOTE(security): torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# plot_path = f'{config.plots_dir}/chkpt_p{config.p_eval}_Tn{config.context_len}_N{config.total_examples}_ne{config.num_examples_per_prng}_n{config.n_embd}_h{config.n_head}_d{config.n_layer}_I{config.main_seed}_lr{config.lr_trgt:0.6f}_Tw{config.warmup_steps}_T{config.num_steps}_B{config.batch_size}_wd{config.weight_decay}.pth'

In [None]:
## Fix random seeds (NumPy, torch CPU, torch CUDA) with the same main seed
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(config.main_seed)

## Generate test (a,c) using the Hull-Dobell theorem
a_list, c_list = find_prng_parameters_test(config)
# Use the second sampled multiplier / increment as the single test generator.
a = a_list[1]
c = c_list[1]

## Generate the test dataset: one sequence of context_len + 1 terms for the chosen (a, c)
seq_collections = lcg_vectorized_with_fixed_seed(
    p=config.p_eval,
    length=config.context_len + 1,
    a_list=[a],
    c_list=[c],
)
test_dataset = seq_collections[0]

In [None]:
## 1d pca

## Tuple of PCA components to be computed
components = tuple([0])


## Extract the embedding matrix, compute 1d PCA, and project the embedding matrix onto the PCA
wte = model.transformer.wte.weight.data.cpu()  # (vocab, n_embd)
wte_pca = pca_1d(wte, components=components)
results = wte @ wte_pca.t()  # (vocab, len(components))


## Plot PCA and its Fourier transform
nrows = len(components)
ncols = 2
n_upto = 20  # We only plot a part of the PCA vector, for aesthetic reasons.

# Fourier amplitude along the token axis, computed once for all components.
ft = np.abs(np.fft.fft(results.numpy(), axis=0))

# Create the full grid up front and draw on explicit Axes. Previously a single
# full-figure Axes was created and plt.subplot() drew on top of it; newer
# Matplotlib releases no longer remove such an overlapped Axes.
fig, axs = plt.subplots(nrows, ncols, figsize=(4*ncols, nrows*2), tight_layout=True, squeeze=False)
# fig.suptitle("PCA of Token Embeddings")

for i_c in range(len(components)):
    ax_pca, ax_ft = axs[i_c]

    # Left panel: projection of the first n_upto tokens onto the component.
    ax_pca.grid(False)
    ax_pca.tick_params(axis='x', which='both', bottom=True, labelbottom=True, labelsize=15)
    ax_pca.tick_params(axis='y', which='both', left=True, labelleft=True, labelsize=15)
    projection = results[:n_upto, i_c].numpy()
    ax_pca.plot(projection, '.', color='purple')
    ax_pca.plot(projection, '-', alpha=0.2, color='purple')
    ax_pca.set_xlabel('number', fontsize=15)
    ax_pca.set_ylabel('value', fontsize=15)

    # Right panel: Fourier amplitude of the projection over the whole vocabulary.
    amplitude = ft[:, i_c]
    ax_ft.grid(False)
    ax_ft.tick_params(axis='x', which='both', bottom=True, labelbottom=True, labelsize=15)
    ax_ft.tick_params(axis='y', which='both', left=True, labelleft=True, labelsize=15)
    ax_ft.set_xticks(config.p_eval * np.array([0, 1/4, 1/2, 3/4, 1]))
    ax_ft.set_xticklabels(["$0$", "$\\frac{\\pi}{2}$", "$\\pi$", "$\\frac{3\\pi}{2}$", "$2\\pi$"])
    ax_ft.plot(amplitude, '-', alpha=0.4, color='purple')
    # Mark the dominant frequency.
    ax_ft.scatter(amplitude.argmax(), amplitude.max(), s=20, color='purple')
    ax_ft.set_xlabel('frequency', fontsize=15)
    ax_ft.set_ylabel('Fourier\n amplitude', fontsize=15)


## Save the figure. Replace the filename and directory with your own.
# plt.savefig(f"{config.plots_dir}/pca_embedding_1d_p{config.p_eval}_d{config.n_layer}_h{config.n_head}_n{config.n_embd}.pdf", dpi=400, format='pdf', bbox_inches="tight")

plt.show()

In [None]:
## 2d PCA

## The two PCA components to be computed
components = (1,2)


## Extract the embedding matrix, and compute 2d PCA
wte = model.transformer.wte.weight.data.cpu()  # (vocab, n_embd)
wte_pca = pca_2d(wte, components=components)


## Make a Scatter plot of 2d PCA
colors_2d = ['red', 'blue', 'green', 'purple']
fig = plt.figure(figsize=(3.5, 3), constrained_layout=True)

# Project every token at once and draw a single scatter collection instead of
# one scatter call (and one artist) per token.
results = wte @ wte_pca.t()  # (vocab, 2)
# Tokens are coloured by their value modulo the number of colours (4).
point_colors = [colors_2d[number % len(colors_2d)] for number in range(results.shape[0])]
plt.scatter(results[:, 0].numpy(), results[:, 1].numpy(), alpha=0.2, color=point_colors, s=10)

plt.grid(False)
plt.tick_params(axis='x', which='both', bottom=True, labelbottom=True, labelsize=15)
plt.tick_params(axis='y', which='both', left=True, labelleft=True, labelsize=15)
# These labels correspond to components = (1, 2); update them if `components` changes.
plt.xlabel("2nd PCA component", fontsize=15)
plt.ylabel("3rd PCA component", fontsize=15)
# plt.xlabel(f'PCA Component {components[0] + 1}')
# plt.ylabel(f'PCA Component {components[1] + 1}')
# plt.title("PCA of Token Embeddings")


## Save the figure. Replace the filename and directory with your own.
# plt.savefig(f"{config.plots_dir}/pca_embedding_2d_p{config.p_eval}_d{config.n_layer}_h{config.n_head}_n{config.n_embd}.pdf", dpi=400, format='pdf', bbox_inches='tight')

plt.show()

In [None]:
## Autocorrelation embedding matrix

wte = model.transformer.wte.weight.data.cpu()  # (vocab, n_embd)

# Only the first `xlim` tokens are displayed along each axis.
xlim = 700
ylim = xlim
xticks = np.arange(0, xlim, 256)
yticks = xticks

# Gram matrix of the embeddings, normalised by squared row norms.
# NOTE(review): the division broadcasts over the last axis, i.e. column j is
# divided by ||w_j||^2 -- this is not the symmetric cosine similarity; confirm intended.
gram = wte @ wte.T
squared_norms = wte.norm(dim=-1) ** 2
autocorrelation = (gram / squared_norms)[:xlim, :xlim]

plt.figure(figsize=(5,4), tight_layout=True)
plt.imshow(autocorrelation, origin='lower', vmin=0, vmax=0.2)
plt.xticks(ticks=xticks)
plt.yticks(ticks=yticks)
plt.tick_params(
    axis='both', which='both',
    bottom=True, labelbottom=True,
    left=True, labelleft=True,
    labelsize=15,
)
plt.grid(False)

cbar = plt.colorbar(fraction=0.05, aspect=18)
cbar.ax.tick_params(labelsize=15)
cbar.set_ticks([0.0, 0.05, 0.1, 0.15, 0.2])
# cbar.set_ticks([0.0, 0.1, 0.2, 0.3, 0.4])
# cbar.set_ticklabels(['0%', '20%', '40%', '60%', '80%', '100%'])


## Save the figure. Replace the filename and directory with your own.
# plt.savefig(f"{config.plots_dir}/{config.act_name}/embedding_autocorrelation_p{config.p_eval}_d{config.n_layer}_h{config.n_head}_n{config.n_embd}.pdf", dpi=400, format='pdf', bbox_inches='tight')

plt.show()