In [1]:
import torch
import numpy as np
from typing import List
from collections import OrderedDict
from torch import nn
from torch import asin
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import math

In [2]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
"""
Download data from FashionMNIST
"""
# Build the train and test splits of FashionMNIST under ./data (download=True),
# applying the ToTensor() transform to every sample.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=use_train_split,
        download=True,
        transform=ToTensor(),
    )
    for use_train_split in (True, False)
)

In [4]:
from layers import LinearArcsine, RandomFeatureMap

In [5]:
"""
arcsin
nn with arcsine activation
"""
class ArcsinNN(nn.Module):
    """
    Feed-forward network built from LinearArcsine hidden layers.

    The input is flattened, passed through one LinearArcsine layer per entry of
    hidden_features, and mapped to out_features logits by a plain nn.Linear.
    """
    def __init__(self, in_features: int, out_features: int, hidden_features: List[int] = None):
        """
        Initialize an ArcsinNN
        in_features: size of the flattened input
        out_features: number of output logits
        hidden_features: a list contains the dimension of hidden layers (at least one entry)

        Raises ValueError when hidden_features is None or empty.
        """
        super(ArcsinNN, self).__init__()

        # Copy into a list so tuples / other sequences work with the list concatenation below.
        hidden_features = list(hidden_features) if hidden_features is not None else []
        if not hidden_features:
            # Validate up front, before any layer is allocated.
            raise ValueError("Missing hidden_features!")

        self.num_hidden_layers = len(hidden_features)
        self.Flatten = nn.Flatten() # flatten the input

        # dims[k] -> dims[k + 1] is the shape of the k-th hidden layer.
        dims = [in_features] + hidden_features
        Layers = [
            (f'LinearArcsine{i}', LinearArcsine(in_features = in_dim, out_features = out_dim))
            for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:]))
        ]
        Layers.append(('Output', nn.Linear(in_features = dims[-1], out_features = out_features)))
        self.Layers = nn.Sequential(OrderedDict(Layers))

    def forward(self, x):
        """
        x: input batch; every dimension after the first is flattened
        returns: the logits produced by the final nn.Linear layer
        """
        x = self.Flatten(x)
        logits = self.Layers(x)

        return logits

In [6]:
# Exact arcsine network: flattened 28x28 inputs, hidden widths 1024 -> 1024 -> 512, 10 logits.
model = ArcsinNN(
    in_features=28 * 28,
    out_features=10,
    hidden_features=[1024, 1024, 512],
)
model = model.to(device)
print(model)

ArcsinNN(
  (Flatten): Flatten(start_dim=1, end_dim=-1)
  (Layers): Sequential(
    (LinearArcsine0): LinearArcsine(in_features=784, out_features=1024, bias=True)
    (LinearArcsine1): LinearArcsine(in_features=1024, out_features=1024, bias=True)
    (LinearArcsine2): LinearArcsine(in_features=1024, out_features=512, bias=True)
    (Output): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [30]:
class ApproxArcsineNN(nn.Module):
    """
    Given a valid ArcsinNN model, approximate the ArcsinNN using RandomFeatureMap for each LinearArcsine layers

    Every hidden layer is evaluated as (pi/2) * phi(x') @ phi(W').T / (D + 1), where x' is the
    layer input with a column of ones appended, W' is the layer weight stacked with its bias,
    D is the layer input width and phi is the RandomFeatureMap created for that width.
    Hidden and output weights are cloned from the given model.
    """
    def __init__(self, model: ArcsinNN = None):
        """
        model: the ArcsinNN whose weights are copied (subclasses of ArcsinNN are accepted)

        Raises ValueError when model is missing or is not an ArcsinNN.
        """
        super(ApproxArcsineNN, self).__init__()
        # isinstance also covers model=None and, unlike an exact type check, allows subclasses.
        if not isinstance(model, ArcsinNN):
            raise ValueError("Missing input ArcsinNN model!")

        self.num_hidden_layers = model.num_hidden_layers
        hidden_layers = [model.Layers[i] for i in range(self.num_hidden_layers)]
        output_layer = model.Layers[-1]

        self.Flatten = nn.Flatten() # flatten the input

        # create random feature maps: one per distinct hidden-layer input width (+1 for the bias column)
        # NOTE(review): these live in a plain dict, so if RandomFeatureMap is an nn.Module it is not
        # registered as a submodule (.to(), parameters() and state_dict() will not see it) - confirm intended.
        dims = set(layer.in_features for layer in hidden_layers)
        self.RandomFeatureMaps = {d: RandomFeatureMap(d+1, device=model.Layers[0].weight.device).float() for d in dims}

        # copy and paste weights
        # NOTE(review): forward() stacks weight and bias along dim 0, which presumably relies on
        # LinearArcsine storing weight as [D_in, D_out] and bias as [1, D_out] - verify against layers.py.
        self.Linears = nn.ModuleList([nn.Linear(in_features=layer.in_features, out_features=layer.out_features) for layer in hidden_layers])
        for linear, layer in zip(self.Linears, hidden_layers):
            linear.weight = nn.Parameter(layer.weight.clone().detach())
            linear.bias = nn.Parameter(layer.bias.clone().detach())
        self.Output = nn.Linear(in_features=output_layer.in_features, out_features=output_layer.out_features)
        self.Output.weight = nn.Parameter(output_layer.weight.clone().detach())
        self.Output.bias = nn.Parameter(output_layer.bias.clone().detach())

    def forward(self, x):
        """
        x: input batch; every dimension after the first is flattened
        returns: logits from the copied output layer
        """
        x = self.Flatten(x) # [n, D_in]

        for i in range(self.num_hidden_layers):
            n, D = x.shape[0], x.shape[1]
            # new_ones keeps the bias column on x's device and in x's dtype
            x = torch.concat([x, x.new_ones((n, 1))], dim = 1) # [n, D_in + 1]
            W = torch.concat([self.Linears[i].weight, self.Linears[i].bias], dim = 0) # [D_in + 1, D_out]
            # the same feature map must be applied to inputs and weights of a layer
            phi_x = self.RandomFeatureMaps[D](x) # [n, D_in + 1]
            phi_W = self.RandomFeatureMaps[D](W.T) # [D_out, D_in + 1]
            x = (torch.pi/2)*((phi_x @ phi_W.T)/(D+1))

        # output layer
        logits = self.Output(x)

        return logits

In [75]:
class RepresentArcsineNN(nn.Module):
    """
    Given a valid ArcsinNN model, approximate the ArcsinNN using RandomFeatureMap for each LinearArcsine layers, and represent using composition of feature maps.

    The input is pushed through the chain of feature maps (one per hidden layer, each one
    dimension wider than the previous because a ones column is appended first), while the
    hidden weights are folded into a single matrix W by the same chain. The hidden
    representation is x @ W, which is then fed to the copied output layer.
    """
    def __init__(self, model: ArcsinNN = None):
        """
        model: the ArcsinNN whose weights are copied (subclasses of ArcsinNN are accepted)

        Raises ValueError when model is missing or is not an ArcsinNN.
        """
        super(RepresentArcsineNN, self).__init__()
        # isinstance also covers model=None and, unlike an exact type check, allows subclasses.
        if not isinstance(model, ArcsinNN):
            raise ValueError("Missing input ArcsinNN model!")

        self.num_hidden_layers = model.num_hidden_layers
        hidden_layers = [model.Layers[i] for i in range(self.num_hidden_layers)]
        output_layer = model.Layers[-1]

        self.Flatten = nn.Flatten() # flatten the input

        # create random feature maps: map i acts on vectors of size input_dim + i + 1
        # NOTE(review): these live in a plain dict, so if RandomFeatureMap is an nn.Module it is not
        # registered as a submodule (.to(), parameters() and state_dict() will not see it) - confirm intended.
        self.input_dim = model.Layers[0].in_features
        self.RandomFeatureMaps = {i: RandomFeatureMap(self.input_dim + i + 1, device = model.Layers[0].weight.device).float() for i in range(self.num_hidden_layers)}

        # copy and paste weights
        # NOTE(review): forward() stacks weight and bias along dim 0, which presumably relies on
        # LinearArcsine storing weight as [D_in, D_out] and bias as [1, D_out] - verify against layers.py.
        self.Linears = nn.ModuleList([nn.Linear(in_features=layer.in_features, out_features=layer.out_features) for layer in hidden_layers])
        for linear, layer in zip(self.Linears, hidden_layers):
            linear.weight = nn.Parameter(layer.weight.clone().detach())
            linear.bias = nn.Parameter(layer.bias.clone().detach())
        self.Output = nn.Linear(in_features=output_layer.in_features, out_features=output_layer.out_features)
        self.Output.weight = nn.Parameter(output_layer.weight.clone().detach())
        self.Output.bias = nn.Parameter(output_layer.bias.clone().detach())

    def forward(self, x):
        """
        x: input batch; every dimension after the first is flattened
        returns: logits from the copied output layer
        """
        x = self.Flatten(x)

        n = x.shape[0]
        W = None
        for i in range(self.num_hidden_layers):
            # compute phi(phi(...phi(x)))
            # new_ones keeps the bias column on x's device and in x's dtype
            x = torch.concat([x, x.new_ones((n, 1))], dim = 1)
            x = self.RandomFeatureMaps[i](x)

            # compute phi(phi(...phi(W)))
            # The first layer starts from its own weight; multiplying by an identity matrix
            # first would only add a dense matmul without changing the product.
            weight = self.Linears[i].weight
            W = weight if W is None else torch.matmul(W, weight)
            W = torch.concat([W, self.Linears[i].bias], dim = 0)
            W = (torch.pi/2/W.shape[0]) * self.RandomFeatureMaps[i](W.T).T

        x = torch.matmul(x, W)

        logits = self.Output(x)
        return logits

In [9]:
# Optimisation settings for the exact arcsine network.
learning_rate = 5e-3
batch_size, epochs = 64, 20
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [10]:
# One loader per split; the collate function moves every collated tensor to `device`.
# NOTE(review): shuffle is left at its default for the training loader, so batches are
# served in dataset order every epoch - confirm this is intended.
train_loader, test_loader = (
    DataLoader(
        split,
        batch_size=batch_size,
        collate_fn=lambda samples: tuple(t.to(device) for t in default_collate(samples)),
    )
    for split in (training_data, test_data)
)

In [11]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """
    Run one optimisation pass over dataloader, printing the loss every 100 batches.

    dataloader: iterable of (X, y) batches that exposes a sized .dataset
    model: nn.Module being trained; it is switched to training mode first
    loss_fn: callable mapping (pred, y) to a scalar loss tensor
    optimizer: optimizer holding the parameters to update
    """
    size = len(dataloader.dataset)
    # An earlier evaluation may have left the model in eval mode.
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X.float())
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # Use separate names so the loss tensor is not rebound to a float.
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X.float())
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [12]:
"""
Train the arcsin NN.
"""
# Fit the exact arcsine network, scoring it on the test loader after every epoch.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(dataloader=train_loader, model=model, loss_fn=loss_fn, optimizer=optimizer)
    test_loop(dataloader=test_loader, model=model, loss_fn=loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.298501  [    0/60000]
loss: 0.608058  [ 6400/60000]
loss: 0.428516  [12800/60000]
loss: 0.697981  [19200/60000]
loss: 0.560684  [25600/60000]
loss: 0.475368  [32000/60000]
loss: 0.480378  [38400/60000]
loss: 0.529320  [44800/60000]
loss: 0.520500  [51200/60000]
loss: 0.524951  [57600/60000]
Test Error: 
 Accuracy: 80.7%, Avg loss: 0.553309 

Epoch 2
-------------------------------
loss: 0.429452  [    0/60000]
loss: 0.458487  [ 6400/60000]
loss: 0.385785  [12800/60000]
loss: 0.557985  [19200/60000]
loss: 0.591877  [25600/60000]
loss: 0.452306  [32000/60000]
loss: 0.435526  [38400/60000]
loss: 0.494277  [44800/60000]
loss: 0.525467  [51200/60000]
loss: 0.521806  [57600/60000]
Test Error: 
 Accuracy: 81.6%, Avg loss: 0.520884 

Epoch 3
-------------------------------
loss: 0.371641  [    0/60000]
loss: 0.450556  [ 6400/60000]
loss: 0.336251  [12800/60000]
loss: 0.529340  [19200/60000]
loss: 0.557300  [25600/60000]
loss: 0.435475  [32000/600

In [81]:
# Random-feature approximation built from `model`'s weights.
approx_model = ApproxArcsineNN(model=model)
approx_model = approx_model.to(device)
approx_model

ApproxArcsineNN(
  (Flatten): Flatten(start_dim=1, end_dim=-1)
  (Linears): ModuleList(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
    (2): Linear(in_features=1024, out_features=512, bias=True)
  )
  (Output): Linear(in_features=512, out_features=10, bias=True)
)

In [82]:
# Score the approximate model with the copied weights as-is.
test_loop(dataloader=test_loader, model=approx_model, loss_fn=loss_fn)

Test Error: 
 Accuracy: 36.0%, Avg loss: 3.267939 



In [83]:
# Optimisation settings for the approximate model; the optimizer now owns its parameters.
learning_rate, batch_size, epochs = 5e-3, 64, 20
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=approx_model.parameters(), lr=learning_rate)

In [84]:
# Train the approximate model, scoring it on the test loader after every epoch.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(dataloader=train_loader, model=approx_model, loss_fn=loss_fn, optimizer=optimizer)
    test_loop(dataloader=test_loader, model=approx_model, loss_fn=loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 3.196406  [    0/60000]
loss: 1.034924  [ 6400/60000]
loss: 0.702310  [12800/60000]
loss: 0.918660  [19200/60000]
loss: 0.943056  [25600/60000]
loss: 0.704924  [32000/60000]
loss: 0.729974  [38400/60000]
loss: 0.718146  [44800/60000]
loss: 0.934986  [51200/60000]
loss: 0.795974  [57600/60000]
Test Error: 
 Accuracy: 69.5%, Avg loss: 0.855636 

Epoch 2
-------------------------------
loss: 0.646713  [    0/60000]
loss: 0.827655  [ 6400/60000]
loss: 0.590898  [12800/60000]
loss: 0.832639  [19200/60000]
loss: 0.890666  [25600/60000]
loss: 0.716206  [32000/60000]
loss: 0.722073  [38400/60000]
loss: 0.720469  [44800/60000]
loss: 0.879605  [51200/60000]
loss: 0.781933  [57600/60000]
Test Error: 
 Accuracy: 70.3%, Avg loss: 0.826669 

Epoch 3
-------------------------------
loss: 0.639981  [    0/60000]
loss: 0.791992  [ 6400/60000]
loss: 0.569536  [12800/60000]
loss: 0.805711  [19200/60000]
loss: 0.858490  [25600/60000]
loss: 0.718397  [32000/600

In [76]:
# Composite-feature-map representation built from `model`'s weights.
composite_model = RepresentArcsineNN(model=model)
composite_model = composite_model.to(device)
composite_model

RepresentArcsineNN(
  (Flatten): Flatten(start_dim=1, end_dim=-1)
  (Linears): ModuleList(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
    (2): Linear(in_features=1024, out_features=512, bias=True)
  )
  (Output): Linear(in_features=512, out_features=10, bias=True)
)

In [77]:
# Score the composite model with the copied weights as-is.
test_loop(dataloader=test_loader, model=composite_model, loss_fn=loss_fn)

Test Error: 
 Accuracy: 21.7%, Avg loss: 2.477477 



In [78]:
# Optimisation settings for the composite model; the optimizer now owns its parameters.
learning_rate, batch_size, epochs = 5e-3, 64, 20
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=composite_model.parameters(), lr=learning_rate)

In [80]:
# Train the composite model, scoring it on the test loader after every epoch.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(dataloader=train_loader, model=composite_model, loss_fn=loss_fn, optimizer=optimizer)
    test_loop(dataloader=test_loader, model=composite_model, loss_fn=loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.464188  [    0/60000]
loss: 1.463477  [ 6400/60000]
loss: 0.953338  [12800/60000]
loss: 1.205656  [19200/60000]
loss: 0.873847  [25600/60000]
loss: 1.017071  [32000/60000]
loss: 0.904895  [38400/60000]
loss: 0.802798  [44800/60000]
loss: 0.957256  [51200/60000]
loss: 0.843616  [57600/60000]
Test Error: 
 Accuracy: 71.0%, Avg loss: 0.816899 

Epoch 2
-------------------------------
loss: 0.656620  [    0/60000]
loss: 1.014019  [ 6400/60000]
loss: 0.550010  [12800/60000]
loss: 0.983556  [19200/60000]
loss: 0.716611  [25600/60000]
loss: 0.833702  [32000/60000]
loss: 0.779663  [38400/60000]
loss: 0.751554  [44800/60000]
loss: 0.873554  [51200/60000]
loss: 0.751323  [57600/60000]
Test Error: 
 Accuracy: 73.2%, Avg loss: 0.743759 

Epoch 3
-------------------------------
loss: 0.575610  [    0/60000]
loss: 0.926225  [ 6400/60000]
loss: 0.481549  [12800/60000]
loss: 0.923675  [19200/60000]
loss: 0.660986  [25600/60000]
loss: 0.773763  [32000/600