In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import functools
import sys

import datasets
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import tqdm
import os
import torch.nn.functional as F

In [3]:
seed = 0
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(seed)

<torch._C.Generator at 0x7fee04d9e0d0>

In [4]:
import math
sin_wave = np.array([math.sin(x) for x in np.arange(200)])
plt.plot(sin_wave[:50])

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7fedfae162e0>]

In [105]:
X = []
Y = []

seq_len = 50
num_records = len(sin_wave) - seq_len
val_num = 50
for i in range(num_records - val_num):
    X.append(sin_wave[i:i+seq_len])
    Y.append(sin_wave[i+seq_len])
    
X = np.array(X)
X = np.expand_dims(X, axis=2)

Y = np.array(Y)
Y = np.expand_dims(Y, axis=1)

In [106]:
X_val = []
Y_val = []

for i in range(num_records - val_num, num_records):
    X_val.append(sin_wave[i:i+seq_len])
    Y_val.append(sin_wave[i+seq_len])
    
X_val = np.array(X_val)
X_val = np.expand_dims(X_val, axis=2)

Y_val = np.array(Y_val)
Y_val = np.expand_dims(Y_val, axis=1)

In [107]:
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len = 80):
        super().__init__()
        self.d_model = d_model
        
        # create constant 'pe' matrix with values dependant on 
        # pos and i
        pe = np.zeros((max_seq_len, d_model))
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[pos, i] =  math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i] =  pos
                
#         pe = pe.unsqueeze(0)
        self.pe = np.expand_dims(pe, axis=0)
#         self.register_buffer('pe', pe)
#         print(pe)
 
    
    def forward(self, x):
        # make embeddings relatively larger
        x = x * 0.5
        #add constant to embedding
        seq_len = x.shape[1]
    
        x = x + self.pe[:,:seq_len]
        return x
pos_encoder = PositionalEncoder(d_model=1, max_seq_len=50)

In [99]:
pos_encoder.pe.shape

(1, 50, 1)

In [97]:
X.shape

(100, 50, 1)

In [108]:
X_val = pos_encoder(X_val)
X = pos_encoder(X)

In [109]:
X

array([[[ 0.        ],
        [ 1.42073549],
        [ 2.45464871],
        ...,
        [47.06178656],
        [47.61587267],
        [48.52312367]],

       [[ 0.42073549],
        [ 1.45464871],
        [ 2.07056   ],
        ...,
        [46.61587267],
        [47.52312367],
        [48.86881257]],

       [[ 0.45464871],
        [ 1.07056   ],
        [ 1.62159875],
        ...,
        [46.52312367],
        [47.86881257],
        [49.33511459]],

       ...,

       [[ 0.18980387],
        [ 0.71330906],
        [ 1.50039658],
        ...,
        [46.7544892 ],
        [48.23387258],
        [49.49823459]],

       [[-0.28669094],
        [ 0.50039658],
        [ 1.74681718],
        ...,
        [47.23387258],
        [48.49823459],
        [49.30452201]],

       [[-0.49960342],
        [ 0.74681718],
        [ 2.22601289],
        ...,
        [47.49823459],
        [48.30452201],
        [48.8308333 ]]])

In [102]:
from torch.utils.data import TensorDataset, DataLoader
batch_size = 32
X_total = np.concatenate((X, X_val), axis=0)
Y_total = np.concatenate((Y, Y_val), axis=0)
train_dataset = TensorDataset(torch.Tensor(X), torch.Tensor(Y))
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
val_dataset = TensorDataset(torch.Tensor(X_val), torch.Tensor(Y_val))
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=True)
total_dataset = TensorDataset(torch.Tensor(X_total), torch.Tensor(Y_total))
total_dataloader = DataLoader(total_dataset, 1, shuffle=False)

In [63]:
import torch.nn.functional as F
def weights_init_normal(m):
    '''Takes in a module and initializes all linear layers with weight
       values taken from a normal distribution.'''

    classname = m.__class__.__name__
    # for every Linear layer in a model
    if classname.find('Linear') != -1:
        y = m.in_features
        # m.weight.data shoud be taken from a normal distribution
        m.weight.data.normal_(0.0, 1 / np.sqrt(y))
#         m.weight.data.normal_(0.0, 1)
        # m.bias.data should be 0
        m.bias.data.fill_(0)
        
class Transformer2(nn.Module):
    def __init__(self, d, d_in, depth, temp=1, vu=1, vw=1, vv=1, vb=0):
        super().__init__()
        self.d_in = d_in
        self.d = d
        self.depth = depth
        self.temp = temp
        self.vu = vu
        self.vw = vw
        self.vb = vb

        def paramwrap(l):
            return nn.ParameterList([nn.Parameter(p) for p in l])

        self.Us = paramwrap([
            torch.randn(d, d) * np.sqrt(vu)
            for _ in range(depth - 1)])
        self.W1s = paramwrap([torch.randn(d, d) * np.sqrt(vw)
                              for _ in range(depth)])
        self.W2s = paramwrap([torch.randn(d, d) * np.sqrt(vw)
                              for _ in range(depth)])
        self.embedding = nn.Parameter(
            torch.randn(d_in, d) * np.sqrt(vu))
        self.readout = nn.Parameter(
            torch.randn(d) * np.sqrt(vv))
        self.reset_data()

    def forward(self, seq):
        '''
        Input:
            seq: seqlen x tokensize array, for any seqlen and tokensize
        Output:
            out: seqlen x self.d_in array, for the same seqlen as input
        '''
        d = self.d
        d_in = self.d_in
        self.xs.append(seq)
        inseq = seq @ self.embedding / d_in ** 0.5
        inseq.retain_grad()
        self.ks.append(inseq)
        for l in range(self.depth):
            if l > 0:
                inseq = inseq @ self.Us[l - 1] / d ** 0.5
                inseq.retain_grad()
                self.ks.append(inseq)
            # self attn
            gram = inseq @ inseq.T / inseq.shape[1]
            weights = torch.softmax(gram / self.temp, dim=1)
            self.As.append(weights)
            # weights @ inseq gives vectors returned by attention
            # inseq + weights @ inseq is the residual connection
            post_attn = self.layernorm(inseq + weights @ inseq)
            post_attn.retain_grad()
            self.zs.append(post_attn)
            # self.post_attn = post_attn

            # FF
            inseq = post_attn @ self.W1s[l] / d ** 0.5
            inseq.retain_grad()
            self.gs.append(inseq)
            inseq = torch.relu(inseq) @ self.W2s[l] / d ** 0.5
            inseq.retain_grad()
            self.hs.append(inseq)
            inseq = self.layernorm(inseq + post_attn)
            inseq.retain_grad()
            self.xs.append(inseq)
        return (inseq @ self.readout / d ** 0.5).mean()

    def reset_data(self):
        self.xs = []
        self.ks = []
        self.ys = []
        self.zs = []
        self.gs = []
        self.hs = []
        self.As = []

    def layernorm(self, seq):
        '''inplace layernorm
        Input:
            seq: seqlen x tokensize array, for any seqlen and tokensize
        Output:
            out: seqlen x tokensize array
                Means and standard deviation computed over the `tokensize` dimension
        '''
        seq = seq - torch.mean(seq, dim=1, keepdim=True)
        seq = seq / torch.std(seq, dim=1, keepdim=True)
        return seq
class Transformer(nn.Module):
    def __init__(self, d, d_in, depth, temp=1, vu=1, vw=1, vv=1, vb=0):
        super().__init__()
        self.d_in = d_in
        self.d = d
        self.depth = depth
        self.temp = temp
        self.vu = vu
        self.vw = vw
        self.vb = vb
        

        def paramwrap(l):
            return nn.ParameterList([nn.Parameter(p) for p in l])
        
        self.Us = paramwrap([
                    torch.randn(d, d) * np.sqrt(vu).item()
                   for _ in range(depth-1)])

        nn.Parameter(torch.tensor(torch.randn(d, d) * np.sqrt(vw).item()))
        self.W1s = paramwrap([torch.randn(d, d) * np.sqrt(vw).item() 
                   for _ in range(depth)])
        self.W2s = paramwrap([torch.randn(d, d) * np.sqrt(vw).item() 
                   for _ in range(depth)])
        self.embedding = nn.Parameter(
            torch.randn(d_in, d) * np.sqrt(vu).item())
        self.readout = nn.Parameter(
            torch.randn(d) * np.sqrt(vv).item())
        
    def forward(self, seq):
        '''
        Input:
            seq: seqlen x tokensize array, for any seqlen and tokensize
        Output:
            out: seqlen x self.d_in array, for the same seqlen as input
        '''
        d = self.d
        d_in = self.d_in

        inseq = torch.matmul(seq, self.embedding) / d_in**0.5

        for l in range(self.depth):
            if l > 0:

                inseq = torch.matmul(inseq, self.Us[l-1])/ d**0.5

            # self attn
            gram = torch.matmul(inseq, torch.transpose(inseq,1,2)) / inseq.shape[2]
            weights = torch.softmax(gram / self.temp, dim=2)

            # weights @ inseq gives vectors returned by attention
            # inseq + weights @ inseq is the residual connection
            post_attn = self.layernorm(inseq + torch.matmul(weights, inseq))

            
            # FF
            inseq = torch.matmul(post_attn, self.W1s[l]) / d**0.5
            inseq = torch.matmul(F.relu(inseq), self.W2s[l]) / d**0.5
            inseq = self.layernorm(inseq + post_attn)
            
            
            inseq = torch.mean(inseq, dim=1)
            out = torch.matmul(inseq, self.readout)/ d**0.5

        return out
        
    def layernorm(self, seq):
        '''inplace layernorm
        Input:
            seq: seqlen x tokensize array, for any seqlen and tokensize
        Output:
            out: seqlen x tokensize array
                Means and standard deviation computed over the `tokensize` dimension
        '''
        seq = seq - torch.mean(seq, dim=2, keepdim=True)
        seq = seq / torch.std(seq, dim=2, keepdim=True)
        return seq

# class Transformer(nn.Module):
#     # d_model : number of features
#     def __init__(self,feature_size=8,num_layers=1,dropout=0):
#         super(Transformer, self).__init__()
#         self.linear = nn.Linear(1,feature_size)
#         self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=4, dropout=dropout, batch_first=True)
#         self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)        
#         self.decoder = nn.Linear(feature_size,1)
#         self.init_weights()

#     def init_weights(self):
#         initrange = 0.1    
#         self.decoder.bias.data.zero_()
#         self.decoder.weight.data.uniform_(-initrange, initrange)

#     def _generate_square_subsequent_mask(self, sz):
#         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
#         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
#         return mask

#     def forward(self, src):
#         device = src.device
#         src = self.linear(src)
#         mask = self._generate_square_subsequent_mask(src.shape[1]).to(device)
#         # mask = self._generate_square_subsequent_mask(len(src)).to(device)
# #         mask = mask.repeat(src.shape[0],1,1)
#         output = self.transformer_encoder(src,mask)
#         output = torch.mean(output, dim=1)
#         output = self.decoder(output)
#         return output    

def get_weights_list(model):
    weights = []
    for name, weight in model.named_parameters():
        weights.append(weight.detach().cpu().numpy().copy())
    return weights

def calculate_weights_diff(weights1, weights2):
    diffs = []
    for weight1, weight2 in zip(weights1, weights2):
        diffs.append(np.linalg.norm(weight2 - weight1)/np.linalg.norm(weight2))
    return diffs

def train(dataloader, model):
    total_loss = 0
    loss_func = nn.MSELoss()
    for data in dataloader:
        x = data[0].to(device)
        y = data[1].to(device)

        output =model(x)
        loss = loss_func(output,y)
        optimizer.zero_grad()
        weights1 = get_weights_list(model)
        loss.backward()
        optimizer.step()
        weights2 = get_weights_list(model)
#         print(calculate_weights_diff(weights1, weights2))
        total_loss += loss.item()
    total_loss /= len(dataloader)
    return total_loss

def validation(dataloader, model):
    total_loss = 0
    loss_func = nn.MSELoss()
    for data in dataloader:
        x = data[0].to(device)
        y = data[1].to(device)

        output =model(x)
        loss = loss_func(output,y)
        
        total_loss += loss.item()
    total_loss /= len(dataloader)
    return total_loss

In [64]:
learning_rate = 0.0001
nepoch = 300              
T = 50                   # length of sequence
hidden_dim = 512
output_dim = 1

model = Transformer(d=hidden_dim, d_in=1, depth=1)
# model = Transformer(feature_size=hidden_dim)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')


model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for name, p in model.named_parameters(): 
    print(name)

The model has 525,312 trainable parameters
embedding
readout
W1s.0
W2s.0


  nn.Parameter(torch.tensor(torch.randn(d, d) * np.sqrt(vw).item()))


In [65]:
train_loss = validation(train_dataloader, model)
val_loss = validation(val_dataloader, model)
print("before train",train_loss, val_loss)

for epoch in range(nepoch):
    # check loss on train
    total_loss = train(train_dataloader, model)
    val_loss = validation(val_dataloader, model)
    print(epoch, total_loss, val_loss)

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


before train 0.6987088173627853 0.6260978877544403
0 0.5480006039142609 0.563891738653183
1 0.5441502928733826 0.49781468510627747
2 0.476975753903389 0.5188007056713104
3 0.5257359966635704 0.4986109435558319
4 0.5846603959798813 0.4873873144388199
5 0.4543536603450775 0.4694352000951767
6 0.4910196512937546 0.5110757052898407
7 0.5414416193962097 0.47709789872169495
8 0.498797282576561 0.485203355550766
9 0.466965489089489 0.5012508630752563
10 0.47775010764598846 0.46833017468452454
11 0.4943982660770416 0.476937860250473
12 0.5267538726329803 0.4945252686738968
13 0.45893697440624237 0.504584014415741
14 0.5022848471999168 0.521232932806015
15 0.5463365465402603 0.4932989776134491
16 0.4810086339712143 0.47526903450489044
17 0.5320923775434494 0.4926516264677048
18 0.48571035265922546 0.4772530496120453
19 0.4579365476965904 0.517876461148262
20 0.47843386232852936 0.49355193972587585
21 0.5057856813073158 0.5164413750171661
22 0.4867498502135277 0.4714328348636627
23 0.48181618005

203 0.4981394112110138 0.4909120947122574
204 0.46992824226617813 0.4975423663854599
205 0.48022623360157013 0.49234993755817413
206 0.4240020737051964 0.4927043914794922
207 0.4672076255083084 0.5247144401073456
208 0.44170161336660385 0.4827924221754074
209 0.5378305837512016 0.49723365902900696
210 0.5533502623438835 0.4952593594789505
211 0.4848663806915283 0.5174336731433868
212 0.46646997332572937 0.48481035232543945
213 0.5073786228895187 0.5217767953872681
214 0.39317941665649414 0.5044716596603394
215 0.4779067263007164 0.501975953578949
216 0.44694188982248306 0.48080530762672424
217 0.5158707424998283 0.4788657873868942
218 0.48081184923648834 0.5398350059986115
219 0.47763747721910477 0.4945495277643204
220 0.4548671841621399 0.5066149830818176
221 0.5662543997168541 0.500445157289505
222 0.49411917477846146 0.46346502006053925
223 0.5198100656270981 0.49259985983371735
224 0.4880713224411011 0.4887019842863083
225 0.5045457780361176 0.48734332621097565
226 0.48676852881908

In [66]:
import gc
import time

def clone_grads(net):
    d = {}
    for name, p in net.named_parameters():
        if p.grad is not None:
            d[name] = p.grad.clone().detach().cpu()
#     d = torch.cat([d[k].reshape(-1) for k in d], dim=0)
    return d


def paramdot(d1, d2):

    ans = sum(
        torch.dot(d1[k].reshape(-1), d2[k].reshape(-1))
        for k in d1)

    return ans


def normalize_matrix(matrix):
    m = np.max(matrix)
    out = matrix / m
    return out

def get_finite_ntk_trained(model, dataloader):
    grads = []
    M = len(dataloader)
    print(M)
    i = 0
    loss_func = nn.MSELoss()
    for data in tqdm.tqdm(dataloader):
        if model.__class__.__name__ == "RNN":
            x = data[0].to(device)
            y = data[1].to(device)
        else:
            x = data[0].squeeze(dim=2).to(device)
            y = data[1].to(device)

        model.train()
        model.zero_grad()
        output =model(x)
        loss = loss_func(output,y)
        scale = 2*(output.item()-y.item())
        loss /= scale
        loss.backward()
        grads.append(clone_grads(model))
    
    finite_ntk = np.zeros((M,M))
    for i in tqdm.tqdm(range(M)):
        for j in range(i+1):
            finite_ntk[i, j] = finite_ntk[j, i] = paramdot(grads[i], grads[j])
    
    return finite_ntk

def get_finite_ntk(model, dataloader):
    grads = []
    M = len(dataloader)
    print(M)
    i = 0

    for data in tqdm.tqdm(dataloader):
        x = data[0].to(device)
        y = data[1].to(device)
        
        model.train()
        model.zero_grad()
        loss =model(x)
        loss.backward()
        grads.append(clone_grads(model))
    
    finite_ntk = np.zeros((M,M))
    for i in tqdm.tqdm(range(M)):
        for j in range(i+1):
            finite_ntk[i, j] = finite_ntk[j, i] = paramdot(grads[i], grads[j])
    
    return finite_ntk

# model = RNN(input_dim=1, hidden_dim = 16, output_dim=1).to(device)
# model = SimpleRNN(indim=1, statedim=4096).to(device)
finite_ntk = get_finite_ntk(model, total_dataloader)

150


100%|██████████████████████████████████████████████████████████████████████████████████| 150/150 [00:00<00:00, 211.38it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 150/150 [00:01<00:00, 100.20it/s]


In [110]:
from kernels import RNTK,TNTK
from kernels.utils import VErf3, VDerErf3
varw = 1
varu = 2
varb = 0.2
varv = 1
avgpool = True
inps = X_total
inpcov = np.einsum('ais,bjs->aibj', inps, inps) / inps.shape[-1]
# inpcov = np.moveaxis(inpcov, 1, 2)
# thcov = RNTK(inpcov, VErf3, VDerErf3, varw, varu, varb, varv, avgpool=avgpool)
inf_ntk = TNTK(inpcov, depth=1)
print(inf_ntk.shape)

# inf_ntk = RNTK(inpcov)

(150, 150)


In [111]:
import numpy as np
import pandas as pd
from kernels import svr_search

train_fold_idx = np.array([[i for i in range(50)]])
test_fold_idx = np.array([[i for i in range(50,150)]])
# gram = finite_ntk
gram = inf_ntk
# gram = thcov
labels = Y_total.squeeze(axis=1)
results = svr_search(gram, labels, train_fold_idx, test_fold_idx)
results

Unnamed: 0,C,normalized,train,test
0,0.0001,False,-0.49936,-0.497499
1,0.000336,False,-0.49936,-0.497499
2,0.001129,False,-0.49936,-0.497499
3,0.003793,False,-0.499359,-0.497499
4,0.012743,False,-0.499359,-0.4975
5,0.042813,False,-0.499356,-0.497504
6,0.143845,False,-0.499347,-0.497518
7,0.483293,False,-0.499318,-0.497563
8,1.623777,False,-0.499224,-0.49772
9,5.455595,False,-0.498956,-0.498304


In [None]:
finite_ntk

In [None]:
error = np.linalg.norm(finite_ntk-inf_ntk) / np.linalg.norm(inf_ntk)
print(error)