In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.linalg
import scipy
import copy


# This notebook is a simplified version of toy_model.ipynb, only focusing on neural network training.


#################### Part I: Hyperparameters ####################
eta1 = 1e-4 # encoder learning rate (gradient) / representation learning rate (natural gradient)
eta2 = 1e-3 # decoder learning rate
seed = 4 # random seed
input_dim = 10 # dimension of each input random vector (encoder input)
latent_dim = 1 # dimension of representation space
output_dim = 10 # dimension of each output random vector (decoder target)
p = 10 # base. i,j are integers in {0,1,2,...,p-1}.
epochs = 10000 # training iterations
log = 100 # logging frequency (in epochs)
dec_w = 200 # decoder width
wd = 0 # decoder weight decay
train_num = 45 # size of training set, sampled without replacement (full dataset size = p(p+1)/2, i.e. 55 for p=10)
modulo = False # If true, o=i+j(mod p); else o=i+j.
natural_gradient = True # If true, optimize the representation directly ("natural gradient"); else use the common parameter gradient.



################### Part III: training neural networks ######################
np.random.seed(seed)
torch.manual_seed(seed)

# dataset
# D0_id is the full dataset of unordered pairs (i, j) with i <= j:
# D0_id = [(0,0),(0,1),...,(p-1,p-1)], which contains p*(p+1)/2 samples.
D0_id = [(i, j) for i in range(p) for j in range(i, p)]
xx_id = np.array([pair[0] for pair in D0_id]) # first operand i of each sample
yy_id = np.array([pair[1] for pair in D0_id]) # second operand j of each sample

all_num = p * (p + 1) // 2
train_id = np.random.choice(all_num, train_num, replace=False) # training-set sample ids
# The remaining ids form the test set. np.setdiff1d keeps an integer dtype (and a
# sorted order) even when the test set is empty, whereas np.array(list(set()))
# would be float64 and unusable as an index.
test_id = np.setdiff1d(np.arange(all_num), train_id)

# parallelogram set
# A parallelogram is an unordered pair of distinct samples (i,j), (m,n) with
# i+j == m+n, i.e. two samples whose operand sums coincide.
P0 = [] # all possible parallelograms, each a frozenset of two (i,j) tuples
P0_id = [] # running index of each parallelogram in P0
for a in range(all_num):
    for b in range(a + 1, all_num):
        if sum(D0_id[a]) == sum(D0_id[b]):
            P0.append(frozenset({D0_id[a], D0_id[b]}))
            P0_id.append(len(P0_id))

P0_num = len(P0_id)

# inputs: one random template per operand value (encoder inputs, input_dim wide)
# and one per possible output value (decoder targets, output_dim wide)
x_templates = np.random.normal(0, 1, size=(p, input_dim))
n_outputs = p if modulo else 2 * p - 1 # i+j ranges over 0..2p-2, or 0..p-1 when taken mod p
y_templates = np.random.normal(0, 1, size=(n_outputs, output_dim))

x_templates = torch.tensor(x_templates, dtype=torch.float, requires_grad=True)
y_templates = torch.tensor(y_templates, dtype=torch.float, requires_grad=True)

# labels
inputs_id = np.transpose(np.array([xx_id, yy_id])) # (all_num, 2): operand ids of each sample
out_id = (xx_id + yy_id) % p if modulo else xx_id + yy_id # output-template id of each sample

# Label tensors are sliced from y_templates and detached: targets need no
# gradient, and calling torch.tensor() on an existing tensor emits a
# copy-construct UserWarning.

# training set
labels_train = y_templates[out_id[train_id]].detach().clone()
inputs_train = torch.cat([x_templates[xx_id[train_id]], x_templates[yy_id[train_id]]], dim=1)
out_id_train = out_id[train_id]

# testing set
labels_test = y_templates[out_id[test_id]].detach().clone()
inputs_test = torch.cat([x_templates[xx_id[test_id]], x_templates[yy_id[test_id]]], dim=1)
out_id_test = out_id[test_id]

# Define neural networks
class NET(nn.Module): # base MLP model
    """Three-layer MLP: Linear -> tanh -> Linear -> tanh -> Linear.

    After each forward pass the activations are cached on the module:
    ``x1`` and ``x2`` are the two tanh hidden layers, ``x3`` is the output.
    """

    def __init__(self, input_dim, output_dim, w=200):
        super().__init__()
        # layer creation order is kept (l1, l2, l3) so random init is unchanged
        self.l1 = nn.Linear(input_dim, w)
        self.l2 = nn.Linear(w, w)
        self.l3 = nn.Linear(w, output_dim)

    def forward(self, x):
        """Map ``x`` through the MLP, caching intermediate activations."""
        self.x1 = torch.tanh(self.l1(x))
        self.x2 = torch.tanh(self.l2(self.x1))
        self.x3 = self.l3(self.x2)
        return self.x3

class DEC(nn.Module): # Decoder
    """Decoder: adds the representations of the two operands of each sample
    and maps the sum through an MLP (``NET``).

    NOTE(review): the ``w`` argument is accepted but not used; the MLP width
    is taken from the module-level hyperparameter ``dec_w``.
    """

    def __init__(self, input_dim, output_dim, w=200):
        super().__init__()
        self.net = NET(input_dim, output_dim, w=dec_w)

    def forward(self, latent, x_id):
        """Decode row k of ``x_id`` from ``latent[x_id[k, 0]] + latent[x_id[k, 1]]``.

        ``latent`` is indexed by operand id; the first two columns of ``x_id``
        hold the operand ids of each sample. The two gathered operand
        representations, their sum and the output are cached on the module as
        ``add1``, ``add2``, ``add`` and ``out``.
        """
        first_ids = x_id[:, 0]
        second_ids = x_id[:, 1]
        self.add1 = latent[first_ids]
        self.add2 = latent[second_ids]
        # operands are combined by addition in representation space
        self.add = self.add1 + self.add2
        self.out = self.net(self.add)
        return self.out


class AE(nn.Module):
    """Encoder/decoder pair.

    ``enc`` maps each input vector to a representation of size ``latent_dim``
    (the module-level hyperparameter); ``dec`` decodes the operand pairs
    selected by ``x_id`` from those representations.
    """

    def __init__(self, w=200, input_dim=1, output_dim=1):
        super().__init__()
        # encoder is built before decoder so random init order is unchanged
        self.enc = NET(input_dim, latent_dim, w=w)
        self.dec = DEC(latent_dim, output_dim, w=w)

    def forward(self, x, x_id):
        """Encode ``x`` and decode the pairs in ``x_id``; caches ``latent`` and ``out``."""
        representation = self.enc(x)
        self.latent = representation
        decoded = self.dec(representation, x_id)
        self.out = decoded
        return decoded


model = AE(input_dim=input_dim, output_dim=output_dim, w=200)

# The input templates are fixed data; detach them so encoder forward passes do
# not accumulate unused gradients into x_templates.grad.
enc_inputs = x_templates.detach()

if natural_gradient:
    # "Natural gradient" mode: optimize the representation itself. It is
    # initialized from the (untrained) encoder output as a leaf Parameter.
    latent = torch.nn.parameter.Parameter(model.enc(enc_inputs).detach().clone())
    # Optimizers expect an ordered collection of parameters (a list, not a set).
    optimizer1 = torch.optim.Adam([latent], lr=eta1)
else:
    # Parameter-gradient mode: train the encoder weights; `latent` is
    # recomputed from the encoder at every step of the training loop.
    optimizer1 = torch.optim.Adam(model.enc.parameters(), lr=eta1)

optimizer2 = torch.optim.AdamW(model.dec.parameters(), lr=eta2, weight_decay=wd)


def _nearest_template_id(outputs):
    """Return, for each row of ``outputs``, the index of the closest row of
    ``y_templates`` in squared Euclidean distance."""
    sq_dist = torch.sum((outputs.unsqueeze(dim=1) - y_templates.unsqueeze(dim=0))**2, dim=2)
    return torch.argmin(sq_dist, dim=1)


reach_acc_test = False
reach_acc_train = False
reach_rqi = False

### training ###

test_acc_epochs = [] # test accuracy per epoch
train_acc_epochs = [] # train accuracy per epoch

for epoch in range(epochs):

    if not natural_gradient:
        # Encode with the current encoder weights so the loss gradient reaches
        # them. (Previously `latent` was first assigned only after an update,
        # raising NameError at epoch 0, and the reused graph had been
        # invalidated by optimizer1.step()'s in-place weight update.)
        latent = model.enc(enc_inputs)

    optimizer1.zero_grad()
    optimizer2.zero_grad()

    # update model parameters
    outputs_train = model.dec(latent, inputs_id[train_id])
    with torch.no_grad(): # test outputs are only evaluated, never backpropagated
        outputs_test = model.dec(latent, inputs_id[test_id])
    loss_train = torch.mean((outputs_train-labels_train)**2)
    loss_train.backward()
    optimizer1.step()
    optimizer2.step()

    # accuracy via nearest-neighbor decoding (outputs were computed before this step's update)
    with torch.no_grad():
        pred_train_id = _nearest_template_id(outputs_train)
        pred_test_id = _nearest_template_id(outputs_test)
    acc_nn_train = np.mean(pred_train_id.numpy() == out_id_train) # training acc
    acc_nn_test = np.mean(pred_test_id.numpy() == out_id_test) # testing acc
    acc_nn = (acc_nn_train*train_id.shape[0] + acc_nn_test*test_id.shape[0])/all_num # whole-dataset acc
    test_acc_epochs.append(acc_nn_test)
    train_acc_epochs.append(acc_nn_train)

    # first epoch at which each accuracy reaches 0.9 (grokking time)
    if not reach_acc_train and acc_nn_train >= 0.9:
        reach_acc_train = True
        iter_train = epoch
    if not reach_acc_test and acc_nn_test >= 0.9:
        reach_acc_test = True
        iter_test = epoch

    # Count parallelograms satisfied by the (post-update) representation
    PR = [] # parallelograms from P0 satisfied by the representation
    PR_id = [] # their indices into P0
    with torch.no_grad():
        if not natural_gradient:
            latent = model.enc(enc_inputs) # representation after this step's update
        # divide each latent dimension by its std, making the check below
        # invariant to per-dimension rescaling of the representation
        latent_scale = latent/torch.std(latent,dim=0).unsqueeze(dim=0)
        for ii in range(P0_num):
            (i, j), (m, n) = P0[ii]
            dist = latent_scale[i] + latent_scale[j] - latent_scale[m] - latent_scale[n]
            if torch.mean(dist**2) < 0.01:
                PR_id.append(ii)
                PR.append(P0[ii])
    count = len(PR) # number of satisfied parallelograms

    rqi = len(PR)/P0_num # fraction of all parallelograms satisfied

    # first epoch at which RQI exceeds 0.99 (time scale of representation learning)
    if not reach_rqi and rqi > 0.99:
        reach_rqi = True
        iter_rqi = epoch

    # logging
    if epoch % log == 0:
        print("epoch: %d  | loss: %.8f "%(epoch, loss_train.item()))

# if a threshold is never reached within the compute budget, use the last epoch
if not reach_acc_test:
    iter_test = epoch

if not reach_acc_train:
    iter_train = epoch

if not reach_rqi:
    iter_rqi = epoch

  labels_train = torch.tensor(y_templates[out_id[train_id]], dtype=torch.float, requires_grad=True)
  labels_test = torch.tensor(y_templates[out_id[test_id]], dtype=torch.float, requires_grad=True)


epoch: 0  | loss: 0.82493395 
epoch: 100  | loss: 0.76302254 
epoch: 200  | loss: 0.73361290 
epoch: 300  | loss: 0.69287986 
epoch: 400  | loss: 0.62817329 
epoch: 500  | loss: 0.52927905 
epoch: 600  | loss: 0.42162097 
epoch: 700  | loss: 0.34450689 
epoch: 800  | loss: 0.28353387 
epoch: 900  | loss: 0.22333899 
epoch: 1000  | loss: 0.17154217 
epoch: 1100  | loss: 0.13550892 
epoch: 1200  | loss: 0.10768575 
epoch: 1300  | loss: 0.08296081 
epoch: 1400  | loss: 0.06331909 
epoch: 1500  | loss: 0.04918868 
epoch: 1600  | loss: 0.03941762 
epoch: 1700  | loss: 0.03237715 
epoch: 1800  | loss: 0.02737385 
epoch: 1900  | loss: 0.02388324 
epoch: 2000  | loss: 0.02099634 
epoch: 2100  | loss: 0.01801854 
epoch: 2200  | loss: 0.01477067 
epoch: 2300  | loss: 0.01096826 
epoch: 2400  | loss: 0.00682222 
epoch: 2500  | loss: 0.00335687 
epoch: 2600  | loss: 0.00138202 
epoch: 2700  | loss: 0.00056325 
epoch: 2800  | loss: 0.00034914 
epoch: 2900  | loss: 0.00020545 
epoch: 3000  | loss: 0