In [1]:
import torch
from matplotlib import pyplot as plt
import sys
import time
import os
import numpy as np
import math


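##############################################################################################################
# Trains linear Transformers with a configurable number of layers (the experiment cell below uses 1, 3, 5, 7)
# and saves their parameter histories under `log/` for later analysis.
# The intended analysis plots the test loss of each trained Transformer against 1, 2, 3, 4 steps of gradient
# descent (with and without preconditioning); that plotting is done separately, not in these cells.
##############################################################################################################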

# Use CUDA if available, else fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Pin the default CUDA device to GPU 0. Only done when CUDA exists:
    # torch.cuda.set_device raises without a GPU, which would defeat the
    # CPU fallback selected on the line above.
    torch.cuda.set_device(0)
# import the model and some useful functions
from linear_transformer import Transformer_F, attention, in_context_loss, generate_data_inplace, Transformer_C

# set up some print options
np.set_printoptions(precision=2, suppress=True)
torch.set_printoptions(precision=2)

# begin logging: run checkpoints are written under this directory
cur_dir = 'log'
os.makedirs(cur_dir, exist_ok=True)
# To also redirect stdout into a log file, uncomment:
# f = open(cur_dir + '/rotation.log', "a", 1)
# sys.stdout = f

# Set up common problem parameters

lr = 0.002   # default learning rate; NOTE: the experiment cell reassigns `lr` before training
clip_r = lr  # NOTE(review): the training loops pass clip_r to clip_and_step explicitly, so this global is not read there
alg = 'adam'
mode = 'sphere'

d = 5        # dimension


n_head = 1  # 1-headed attention
B = 30000  # minibatch size (first dimension of the data tensor Z)
var = 0.0001  # initialization scale of transformer parameters
shape_k = 0.1  # shape_k: parameter for Gamma distributed covariates
max_iters = 2100  # Number of Iterations to run
# Derive the message from max_iters so it cannot drift out of sync with the value.
print(f'setting max_iters to {max_iters} for quick run, need about 40100 for 7 layers')
hist_stride = 1  # stride for saved model parameters in `train.ipynb'
stride = 400  # the training loops snapshot model.allparam every `stride` iterations

# a convenience function for taking a step and clipping
def clip_and_step(allparam, optimizer, clip_r = None):
    """Clip each leading-axis slice of ``allparam.grad`` in place, then take one optimizer step.

    Slice ``i`` (``grad[i]``) is rescaled so that its norm is at most
    ``clip_r / (8 - i + 1) ** 2``. Earlier slices therefore get tighter limits
    (``clip_r / 81`` for ``i = 0``, ``clip_r / 64`` for ``i = 1``, ...).
    NOTE(review): the hard-coded 8 makes ``i = 9`` divide by zero, so at most
    9 slices are supported; presumably the leading axis indexes layers, so
    confirm against linear_transformer.

    Args:
        allparam: parameter tensor whose ``.grad`` was populated by ``backward()``.
        optimizer: optimizer that owns ``allparam``; ``optimizer.step()`` is called once.
        clip_r: base clipping radius. ``None`` disables clipping (plain step)
            instead of failing on ``None / int``.

    Returns:
        The scale factor applied to the LAST slice: 1.0 if that slice was not
        clipped, if clipping is disabled, or if there are no slices.
    """
    fraction = 1.0  # defined even if the loop below does not run
    if clip_r is not None:
        grad_all = allparam.grad
        for i in range(allparam.shape[0]):
            limit = clip_r / (8 - i + 1) ** 2
            norm_p = grad_all[i].norm().item()
            if norm_p > limit:
                fraction = limit / norm_p
                grad_all[i].mul_(fraction)
            else:
                fraction = 1.0
    optimizer.step()
    return fraction

# Template for per-run checkpoint files. Placeholders, in order:
# data_activation, kernel_activation, N, seed, n_layer.
# It starts with '/' because callers prepend the log directory by string concatenation.
filename_format = '/run_' + '_'.join(['{}'] * 5) + '.pth'
# Default layer counts; NOTE: the experiment cell reassigns n_layers before training.
n_layers = [3]

In [2]:
def _train_and_save(make_model, data_activation, kernel_activation, N, sd, n_layer):
    """Train one model for a single (seed, n_layer) key and save its parameter history.

    Reads the notebook-level globals ``lr``, ``d``, ``B``, ``max_iters``, ``stride``,
    ``device``, ``cur_dir`` and ``filename_format`` at call time.

    Args:
        make_model: callable ``make_model(n_layer)`` returning a fresh model. It is
            called right after ``torch.manual_seed(sd)`` so initialization is
            reproducible for a given seed.
        data_activation: passed to ``generate_data_inplace``; also part of the filename.
        kernel_activation: passed to ``in_context_loss``; also part of the filename.
        N: context length; the data tensor ``Z`` has shape ``(B, N + 1, d + 1)``.
        sd: seed used both for model initialization and for sampling the problem.
        n_layer: number of transformer layers.

    Side effects:
        Prints progress and ``torch.save``s ``{'hist', 'U', 'D'}`` to
        ``cur_dir + filename_format.format(...)``.
    """
    filename = cur_dir + filename_format.format(data_activation, kernel_activation, N, sd, n_layer)
    prob_seed = sd
    opt_seed = sd

    # Seed, then build the model, so its initialization depends only on opt_seed.
    torch.manual_seed(opt_seed)
    model = make_model(n_layer)
    model.to(device)
    # Important: use beta = 0.9 for Adam; 0.999 is very slow. (Earlier code built
    # AdamW with (0.99, 0.99) and then overwrote the betas with (0.9, 0.9) before
    # every step, so (0.9, 0.9) is the value that was actually used.)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.9), weight_decay=0)

    # Re-seed and sample the problem. U holds the left singular vectors of a random
    # uniform d x d matrix (orthogonal; called a "rotation" in this notebook).
    np.random.seed(prob_seed)
    torch.manual_seed(prob_seed)
    # Draw on the CPU generator first (as torch.FloatTensor(...).uniform_ did),
    # then move to `device`, so the sample is the same on CPU and GPU runs.
    gaus = torch.empty(d, d, dtype=torch.float32).uniform_(-1, 1).to(device)
    U = torch.linalg.svd(gaus)[0]
    # D = torch.diag(torch.tensor([1, 1, 1/2, 1/1.5, 1])).to(device)  # anisotropic alternative
    D = torch.eye(d, device=device)
    Z = torch.zeros([B, N + 1, d + 1], device=device)
    Z, y = generate_data_inplace(Z, data_activation, U=U, D=D)

    current_lr = lr
    eff_t = 0  # never updated; kept so the log line keeps its `efft:` field
    hist = []
    for t in range(max_iters):
        # Halve the learning rate every 4000 iterations once past iteration 20000.
        # (Previously the halved value was computed but never handed to the
        # optimizer, so the decay silently had no effect.)
        if t % 4000 == 0 and t > 20000:
            current_lr *= 0.5
        optimizer.param_groups[0]['lr'] = current_lr
        start = time.time()
        # Snapshot parameters every `stride` iterations, before this step's update.
        if t % stride == 0:
            hist.append(model.allparam.clone().detach())

        loss = in_context_loss(model, Z, y, kernel_activation)
        loss.backward()
        # Looser clipping radius after the first 8000 iterations.
        step_clip_r = 1e-7 if t > 8000 else 1e-8
        clip_fraction = clip_and_step(model.allparam, optimizer, clip_r=step_clip_r)
        model.zero_row_col()
        # NOTE: measured after clip_and_step rescaled the gradient in place, so this
        # is the post-clipping norm, not the raw gradient norm.
        norms = model.allparam.grad.norm().item()
        optimizer.zero_grad()

        # Resample the training batch every 10 iterations.
        if t % 10 == 0:
            Z, y = generate_data_inplace(Z, data_activation, U=U, D=D)

        end = time.time()
        if t % 100 == 0 or t < 5:
            print('iter {} | Loss: {:.3}  time: {:.2}  gradnorm: {:.2}, fraction:{:.2}, efft:{}'
                  .format(t, loss.item(), end - start, norms, clip_fraction, eff_t))
    torch.save({'hist': hist, 'U': U, 'D': D}, filename)


def run_exp(data_activation, kernel_activation, N, keys):
    """Train a ``Transformer_F`` for every ``(seed, n_layer)`` in ``keys``; save each run.

    Args:
        data_activation: data-generation mode for ``generate_data_inplace``.
        kernel_activation: attention kernel name for ``in_context_loss``.
        N: context length.
        keys: iterable of ``(seed, n_layer)`` tuples.
    """
    print(f"lr: {lr}")
    for key in keys:
        sd = key[0]
        n_layer = key[1]
        print(key)
        # NOTE(review): the literals 1 and 0.1 may be meant to be `n_head` / `var`
        # from the config cell; confirm against linear_transformer.Transformer_F.
        _train_and_save(lambda nl: Transformer_F(nl, 1, d, 0.1, N=N),
                        data_activation, kernel_activation, N, sd, n_layer)


def run_comb_exp(data_activation, N, keys):
    """Train a ``Transformer_C`` with kernel ``'combact3'`` for every ``(seed, n_layer)`` in ``keys``.

    Args:
        data_activation: data-generation mode for ``generate_data_inplace``.
        N: context length.
        keys: iterable of ``(seed, n_layer)`` tuples.
    """
    print(f"lr: {lr}")
    kernel_activation = 'combact3'
    for key in keys:
        sd = key[0]
        n_layer = key[1]
        print(key)
        _train_and_save(lambda nl: Transformer_C(nl, 2, d, 0.1, N=N),
                        data_activation, kernel_activation, N, sd, n_layer)

In [5]:
####################################
# run experiments
####################################
Ns = [8]    # context lengths to sweep
seeds = [0, 1, 2]
n_layers = [1, 3, 5, 7]
# One training run per (seed, n_layer) pair; seeds vary slowest.
keys = []
for s in seeds:
    for n_layer in n_layers:
        keys += [(s, n_layer)]

# Overrides the config-cell learning rate; the training functions read the
# global `lr` at the moment they are called.
lr = 0.1
for N in Ns:
    # Only the 'euclidean' data activation is run here. Other data activations
    # ('exp', 'comb') were tried with calls of the same shape, e.g.
    # run_comb_exp('exp', N, keys) or run_exp('comb', 'linear', N, keys).
    run_comb_exp('euclidean', N, keys)
    run_exp('euclidean', 'exp', N, keys)
    run_exp('euclidean', 'linear', N, keys)
        

lr: 0.1
(0, 1)
iter 0 | Loss: 0.989  time: 0.1  gradnorm: 1.2e-10, fraction:8.9e-09, efft:0
iter 1 | Loss: 0.996  time: 0.073  gradnorm: 1.2e-10, fraction:1.6e-08, efft:0
iter 2 | Loss: 0.996  time: 0.03  gradnorm: 1.2e-10, fraction:1.6e-08, efft:0
iter 3 | Loss: 0.996  time: 0.011  gradnorm: 1.2e-10, fraction:1.6e-08, efft:0
iter 4 | Loss: 0.996  time: 0.011  gradnorm: 1.2e-10, fraction:1.6e-08, efft:0
iter 100 | Loss: 0.988  time: 0.013  gradnorm: 1.2e-10, fraction:2.7e-08, efft:0
iter 200 | Loss: 1.01  time: 0.013  gradnorm: 1.2e-10, fraction:1.3e-08, efft:0
iter 300 | Loss: 1.0  time: 0.013  gradnorm: 1.2e-10, fraction:2.7e-08, efft:0
iter 400 | Loss: 0.987  time: 0.014  gradnorm: 1.2e-10, fraction:4.7e-08, efft:0
iter 500 | Loss: 1.0  time: 0.013  gradnorm: 1.2e-10, fraction:1.6e-08, efft:0
iter 600 | Loss: 1.0  time: 0.014  gradnorm: 1.2e-10, fraction:2.5e-08, efft:0
iter 700 | Loss: 0.996  time: 0.013  gradnorm: 1.2e-10, fraction:2.3e-08, efft:0
iter 800 | Loss: 1.0  time: 0.014