In [1]:
import torch
import numpy as np
from torch import nn
from torch import asin
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import math

In [2]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
"""
Download data from FashionMNIST
"""
# Both splits are stored under the same root and are downloaded on demand;
# only the `train` flag differs between them.
_fashion_mnist_common = dict(root="data", download=True)

training_data = datasets.FashionMNIST(train=True, transform=ToTensor(), **_fashion_mnist_common)

test_data = datasets.FashionMNIST(train=False, transform=ToTensor(), **_fashion_mnist_common)

In [30]:
class Linear_Arcsine(nn.Linear):
    """Linear layer whose output is the arcsine of a cosine similarity.

    The input is augmented with a constant 1 and each weight row with its
    bias entry (when ``bias=True``). Both are L2-normalised, and the layer
    returns ``asin(<x_aug, w_aug>)`` for every (input, output-unit) pair.

    Args:
        in_features: size of each input sample.
        out_features: number of output units.
        bias: if True, a learnable bias column of shape
            ``(out_features, 1)`` is appended to the weight matrix.
        device, dtype: forwarded to the parameter constructors.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
        # nn.Linear stores in_features/out_features, allocates `weight` and
        # calls reset_parameters() once.
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        if bias:
            # Replace nn.Linear's 1-D bias by a column vector so that it can be
            # concatenated to `weight` along dim=1 in forward().
            self.bias = nn.Parameter(torch.empty((out_features, 1), device=device, dtype=dtype))
            self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise weight and bias exactly like ``nn.Linear`` does."""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``asin(cos(input_aug, W_aug))`` with shape ``[..., out_features]``.

        ``input`` has shape ``[..., in_features]``.
        """
        if self.bias is not None:
            W = torch.cat([self.weight, self.bias], dim=1)  # [out, in + 1]
            # new_ones() inherits the device and dtype of `input`, so the layer
            # does not depend on any notebook-level `device` variable.
            ones = input.new_ones((*input.shape[:-1], 1))
            input = torch.cat([input, ones], dim=-1)  # [..., in + 1]
        else:
            W = self.weight
        cosine = nn.functional.normalize(input, dim=-1) @ nn.functional.normalize(W, dim=1).T
        # Rounding can push |cosine| marginally above 1, where asin returns NaN.
        return torch.asin(cosine.clamp(-1.0, 1.0))

In [44]:
"""
arcsin
nn with arcsin activation
"""
class arcsinNN(nn.Module):
    """Classifier for 28x28 inputs: three Linear_Arcsine layers and a linear head."""

    def __init__(self):
        super().__init__()
        self.Flatten = nn.Flatten()  # collapse every non-batch dimension
        arcsine_stack = [
            Linear_Arcsine(28 * 28, 1024),
            Linear_Arcsine(1024, 1024),
            Linear_Arcsine(1024, 512),
        ]
        output_head = nn.Linear(512, 10)
        self.Layers = nn.Sequential(*arcsine_stack, output_head)

    def forward(self, x):
        """Flatten ``x`` and return the 10 logits produced by ``self.Layers``."""
        return self.Layers(self.Flatten(x))

In [None]:
"""
arcsin
nn with arcsin activation
Layer 1: d_in:784, d_out:512 with arcsin activation
Layer 2: d_in:512, d_out:10 with no activation
"""
from torch import arcsin


class arcsinNN(nn.Module):
    """Single-hidden-layer network with an arcsine-of-cosine activation.

    ``Linear1`` (784 -> 512, no bias) only provides the weight matrix: the
    hidden activation is ``arcsin`` of the cosine similarity between each
    flattened input and each weight row. ``Linear2`` (512 -> 10) maps the
    hidden activations to logits.

    NOTE(review): this cell reuses the class name ``arcsinNN`` of the
    Sequential-based definition in an earlier cell; when both cells are run
    top to bottom this definition is the one left bound to the name.
    """

    def __init__(self):
        super(arcsinNN, self).__init__()
        self.Flatten = nn.Flatten()  # flatten the input
        self.Linear1 = nn.Linear(28*28, 512, bias=False)
        self.Linear2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return logits of shape ``[n, 10]`` for a batch ``x`` of n images."""
        x_flat = self.Flatten(x)  # [n, d]

        # Normalise rows directly; no transposed copy of the batch is needed
        # and the affine output of Linear1 is never used.
        x_unit = nn.functional.normalize(x_flat, dim=1)  # [n, d]
        w_unit = nn.functional.normalize(self.Linear1.weight, dim=1)  # [D, d]

        # Clamp guards against |cosine| > 1 caused by rounding (asin -> NaN).
        hidden = torch.arcsin((x_unit @ w_unit.T).clamp(-1.0, 1.0))  # [n, D]

        # second layer
        logits = self.Linear2(hidden)

        return logits

In [45]:
# Build the network, move its parameters to the selected device and show its layers.
arcsin_model = arcsinNN()
arcsin_model = arcsin_model.to(device)
print(arcsin_model)

arcsinNN(
  (Flatten): Flatten(start_dim=1, end_dim=-1)
  (Layers): Sequential(
    (0): Linear_Arcsine(in_features=784, out_features=1024, bias=True)
    (1): Linear_Arcsine(in_features=1024, out_features=1024, bias=True)
    (2): Linear_Arcsine(in_features=1024, out_features=512, bias=True)
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [51]:
# Optimisation settings for the arcsin network.
epochs = 20
batch_size = 64
learning_rate = 5 * 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(arcsin_model.parameters(), lr=learning_rate)

In [47]:
def _collate_to_device(samples):
    """Collate ``samples`` with the default rule and move every tensor to ``device``."""
    return tuple(tensor.to(device) for tensor in default_collate(samples))


train_loader = DataLoader(training_data, batch_size=batch_size, collate_fn=_collate_to_device)
test_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=_collate_to_device)

In [48]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: yields ``(X, y)`` batches; ``dataloader.dataset`` must
            support ``len()``.
        model: module being trained; it is switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer holding the parameters of ``model``.

    Prints the current loss every 100 batches.
    """
    size = len(dataloader.dataset)
    # Make sure mode-dependent layers (dropout, batch-norm) behave as in
    # training even if the model was left in eval mode by earlier code.
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X.float())
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X.float())
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [52]:
"""
Train the arcsin NN.
"""
# One epoch = a full training pass followed by an evaluation on the test split.
for t in range(epochs):
    epoch_banner = f"Epoch {t+1}\n-------------------------------"
    print(epoch_banner)
    train_loop(dataloader=train_loader, model=arcsin_model, loss_fn=loss_fn, optimizer=optimizer)
    test_loop(dataloader=test_loader, model=arcsin_model, loss_fn=loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 0.788042  [    0/60000]
loss: 0.792586  [ 6400/60000]
loss: 0.594242  [12800/60000]
loss: 0.744619  [19200/60000]
loss: 0.606815  [25600/60000]
loss: 0.507904  [32000/60000]
loss: 0.541784  [38400/60000]
loss: 0.656298  [44800/60000]
loss: 0.511460  [51200/60000]
loss: 0.621085  [57600/60000]
Test Error: 
 Accuracy: 79.9%, Avg loss: 0.560171 

Epoch 2
-------------------------------
loss: 0.455348  [    0/60000]
loss: 0.522918  [ 6400/60000]
loss: 0.435473  [12800/60000]
loss: 0.526622  [19200/60000]
loss: 0.601740  [25600/60000]
loss: 0.446505  [32000/60000]
loss: 0.421663  [38400/60000]
loss: 0.581674  [44800/60000]
loss: 0.526054  [51200/60000]
loss: 0.506725  [57600/60000]
Test Error: 
 Accuracy: 79.8%, Avg loss: 0.552932 

Epoch 3
-------------------------------
loss: 0.430264  [    0/60000]
loss: 0.486433  [ 6400/60000]
loss: 0.350681  [12800/60000]
loss: 0.464652  [19200/60000]
loss: 0.564004  [25600/60000]
loss: 0.403191  [32000/600

In [55]:
def indicator(x):
    """Return a float tensor holding 1.0 where ``x > 0`` and 0.0 elsewhere."""
    return torch.gt(x, 0).float()

In [80]:
"""
random feature map for arcsin
"""
class RandomFeatureMap(nn.Module):
    """Frozen Gaussian sign features used to estimate the arcsine kernel."""

    def __init__(self, size_in):
        """
        Create a feature map whose output dimension equals its input dimension.

        size_in: input dimension d
        """
        super().__init__()
        self.size_in = size_in
        # A single frozen d x d matrix with standard-normal entries, created
        # on the notebook-level `device` and kept in a ParameterList.
        gaussian = np.random.normal(loc=0.0, scale=1.0, size=(self.size_in, self.size_in))
        projection = nn.Parameter(torch.tensor(gaussian, device=device), requires_grad=False)
        self.weights = nn.ParameterList([projection])

    def forward(self, x):
        """
        Sign features for arcsin: for every column Z of the weight matrix the
        feature is sign(Z dot x) = 2*indicator(Z dot x) - 1.
        With these features, phix^T phiy estimates d*(2/pi)arcsin(x^T y / ||x||||y||).
        """
        per_matrix = []
        for weight in self.weights:
            per_matrix.append(2 * indicator(x @ weight) - 1)
        # With the single matrix built in __init__ the stacked axis has size 1
        # and is squeezed away again.
        return torch.stack(per_matrix).squeeze(0)




In [61]:
"""
Sanity check
"""
# Push a random batch through a 784-dimensional feature map and check the
# output shape. RandomFeatureMap creates its projection matrix on `device`,
# so the probe batch has to be created there as well; a CPU batch fails
# with a device mismatch when `device` is "cuda".
F = RandomFeatureMap(784).float()
x = torch.rand((20, 784), device=device)
print(F(x).shape)

torch.Size([20, 784])


In [None]:
# Number of hidden units = number of rows of the first layer's weight matrix.
# NOTE(review): the path 'Linear1.weight' only exists on the arcsinNN variants
# that define `Linear1`; the Sequential-based arcsinNN registers its layers
# under `Layers` instead.
first_layer_weight = arcsin_model.get_parameter(target='Linear1.weight')
D = first_layer_weight.shape[0]
D

In [68]:
# Print each parameter tensor's shape followed by its position in parameters().
for i, p in enumerate(arcsin_model.parameters()):
    shape = p.shape
    print(shape, i)

torch.Size([1024, 784]) 0
torch.Size([1024, 1]) 1
torch.Size([1024, 1024]) 2
torch.Size([1024, 1]) 3
torch.Size([512, 1024]) 4
torch.Size([512, 1]) 5
torch.Size([10, 512]) 6
torch.Size([10]) 7


In [131]:
"""
approximate the arcsin nn using the feature map
"""
class ApproxNN(nn.Module):
    """Random-feature approximation of the three-layer arcsine network.

    Each arcsine layer ``arcsin(cos(x_aug, W_aug))`` is replaced by
    ``(pi/2) * phi(x_aug) @ phi(W_aug).T / m``, where ``phi`` is a
    ``RandomFeatureMap`` producing ``m`` sign features and ``x_aug`` /
    ``W_aug`` are the input and the weights augmented with a constant 1 and
    the bias. A final ``nn.Linear`` maps the result to 10 logits.

    Args:
        model: optional trained network whose ``parameters()`` yield, in
            order, (weight, bias) for the three arcsine layers and for the
            output layer. The values are copied, so training this module
            does not modify ``model``.
    """

    def __init__(self, model: nn.Module = None):
        super(ApproxNN, self).__init__()
        self.Flatten = nn.Flatten()  # flatten the input

        self.Linear_Arcsines = nn.ModuleList([nn.Linear(28*28, 1024), nn.Linear(1024, 1024), nn.Linear(1024, 512)])
        self.Linear = nn.Linear(512, 10)
        # Registered as a ModuleList (not a plain list) so the fixed random
        # projections follow .to(...) and are stored in the state_dict.
        self.RandomFeatureMaps = nn.ModuleList(
            [RandomFeatureMap(28*28 + 1).float(), RandomFeatureMap(1024 + 1).float()]
        )
        if model is not None:
            params = list(model.parameters())
            layers = [*self.Linear_Arcsines, self.Linear]
            for index, layer in enumerate(layers):
                weight, bias = params[2 * index], params[2 * index + 1]
                # Copy instead of sharing the Parameter objects: sharing would
                # let an optimizer on this module overwrite the source model.
                layer.weight = nn.Parameter(weight.detach().clone())
                # Source biases may be stored as [out, 1] columns; nn.Linear
                # expects a flat [out] vector.
                layer.bias = nn.Parameter(bias.detach().clone().reshape(-1))

    def _approx_arcsine(self, x, layer, feature_map):
        """Approximate one arcsine layer with sign random features."""
        ones = x.new_ones((x.shape[0], 1))  # same device/dtype as x
        x_aug = torch.cat([x, ones], dim=1)  # [n, d+1]
        W_aug = torch.cat([layer.weight, layer.bias.reshape(-1, 1)], dim=1)  # [D, d+1]
        phi_x = feature_map(x_aug)  # [n, d+1]
        phi_W = feature_map(W_aug)  # [D, d+1]
        # phi_x @ phi_W.T estimates m * (2/pi) * arcsin(.), with m the number
        # of features, so the rescaling divides by m (= d+1) for every layer.
        num_features = phi_x.shape[1]
        return (np.pi / 2) * ((phi_x @ phi_W.T) / num_features)  # [n, D]

    def forward(self, x):
        """Return logits of shape ``[n, 10]`` for a batch ``x`` of n images."""
        x = self.Flatten(x)  # [n, d]
        x = self._approx_arcsine(x, self.Linear_Arcsines[0], self.RandomFeatureMaps[0])
        x = self._approx_arcsine(x, self.Linear_Arcsines[1], self.RandomFeatureMaps[1])
        # The third layer reuses the second feature map: both take 1024 + 1 inputs.
        x = self._approx_arcsine(x, self.Linear_Arcsines[2], self.RandomFeatureMaps[1])

        # output layer
        logits = self.Linear(x)

        return logits

In [100]:
# Wrap the trained arcsin network in its random-feature approximation.
approx_model = ApproxNN(model=arcsin_model)
approx_model = approx_model.to(device)
approx_model

ApproxNN(
  (Flatten): Flatten(start_dim=1, end_dim=-1)
  (Linear_Arcsines): ModuleList(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
    (2): Linear(in_features=1024, out_features=512, bias=True)
  )
  (Linear): Linear(in_features=512, out_features=10, bias=True)
)

In [102]:
# Accuracy of the approximation before any further training.
test_loop(dataloader=test_loader, model=approx_model, loss_fn=loss_fn)

Test Error: 
 Accuracy: 36.8%, Avg loss: 3.193804 



In [103]:
# Optimisation settings for fine-tuning the approximation initialised from arcsin_model.
epochs = 20
batch_size = 64
learning_rate = 5 * 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=approx_model.parameters(), lr=learning_rate)

In [104]:
# One epoch = a full training pass followed by an evaluation on the test split.
for t in range(epochs):
    epoch_banner = f"Epoch {t+1}\n-------------------------------"
    print(epoch_banner)
    train_loop(dataloader=train_loader, model=approx_model, loss_fn=loss_fn, optimizer=optimizer)
    test_loop(dataloader=test_loader, model=approx_model, loss_fn=loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.805723  [    0/60000]
loss: 0.902240  [ 6400/60000]
loss: 0.787267  [12800/60000]
loss: 1.077304  [19200/60000]
loss: 0.799184  [25600/60000]
loss: 0.807200  [32000/60000]
loss: 0.851136  [38400/60000]
loss: 0.780177  [44800/60000]
loss: 0.707347  [51200/60000]
loss: 0.846230  [57600/60000]
Test Error: 
 Accuracy: 72.0%, Avg loss: 0.795850 

Epoch 2
-------------------------------
loss: 0.618774  [    0/60000]
loss: 0.739327  [ 6400/60000]
loss: 0.665057  [12800/60000]
loss: 0.999990  [19200/60000]
loss: 0.723274  [25600/60000]
loss: 0.777493  [32000/60000]
loss: 0.825820  [38400/60000]
loss: 0.762647  [44800/60000]
loss: 0.677188  [51200/60000]
loss: 0.795998  [57600/60000]
Test Error: 
 Accuracy: 72.6%, Avg loss: 0.770942 

Epoch 3
-------------------------------
loss: 0.603661  [    0/60000]
loss: 0.709153  [ 6400/60000]
loss: 0.632986  [12800/60000]
loss: 0.970057  [19200/60000]
loss: 0.689899  [25600/60000]
loss: 0.764959  [32000/600

In [132]:
# Approximation network with freshly initialised weights (no source model).
approx_model_2 = ApproxNN()
approx_model_2 = approx_model_2.to(device)
approx_model_2

ApproxNN(
  (Flatten): Flatten(start_dim=1, end_dim=-1)
  (Linear_Arcsines): ModuleList(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
    (2): Linear(in_features=1024, out_features=512, bias=True)
  )
  (Linear): Linear(in_features=512, out_features=10, bias=True)
)

In [133]:
# Optimisation settings for training the randomly initialised approximation.
epochs = 20
batch_size = 64
learning_rate = 5 * 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=approx_model_2.parameters(), lr=learning_rate)

In [134]:
# One epoch = a full training pass followed by an evaluation on the test split.
for t in range(epochs):
    epoch_banner = f"Epoch {t+1}\n-------------------------------"
    print(epoch_banner)
    train_loop(dataloader=train_loader, model=approx_model_2, loss_fn=loss_fn, optimizer=optimizer)
    test_loop(dataloader=test_loader, model=approx_model_2, loss_fn=loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.309504  [    0/60000]
loss: 1.278397  [ 6400/60000]
loss: 0.853149  [12800/60000]
loss: 0.977041  [19200/60000]
loss: 0.830342  [25600/60000]
loss: 0.846256  [32000/60000]
loss: 0.857329  [38400/60000]
loss: 0.803575  [44800/60000]
loss: 0.769239  [51200/60000]
loss: 0.762467  [57600/60000]
Test Error: 
 Accuracy: 73.6%, Avg loss: 0.735464 

Epoch 2
-------------------------------
loss: 0.638553  [    0/60000]
loss: 0.810593  [ 6400/60000]
loss: 0.507958  [12800/60000]
loss: 0.758793  [19200/60000]
loss: 0.715028  [25600/60000]
loss: 0.726271  [32000/60000]
loss: 0.771497  [38400/60000]
loss: 0.778427  [44800/60000]
loss: 0.736700  [51200/60000]
loss: 0.701074  [57600/60000]
Test Error: 
 Accuracy: 74.6%, Avg loss: 0.690089 

Epoch 3
-------------------------------
loss: 0.577960  [    0/60000]
loss: 0.756343  [ 6400/60000]
loss: 0.461069  [12800/60000]
loss: 0.705386  [19200/60000]
loss: 0.672133  [25600/60000]
loss: 0.674270  [32000/600

In [None]:
"""
arcsin
nn with arcsin activation
Layer 1: d_in:784, d_out:512 with arcsin activation
Layer 2: d_in:512, d_out:10 with no activation
"""
class arcsinNN(nn.Module):
    """Single-hidden-layer arcsine network initialised from ``approx_model``.

    ``Linear1`` (784 -> 512, no bias) only provides the weight matrix: the
    hidden activation is ``arcsin`` of the cosine similarity between each
    flattened input and each weight row. ``Linear2`` (512 -> 10) maps the
    hidden activations to logits.

    NOTE(review): ``__init__`` reads the parameter paths 'Linear1.weight',
    'Linear2.weight' and 'Linear2.bias' from the notebook-level
    ``approx_model``. The ``ApproxNN`` class in this notebook registers
    ``Linear_Arcsines`` and ``Linear`` instead, so this cell presumably
    targets an earlier ApproxNN layout; confirm before running.
    """

    def __init__(self):
        super(arcsinNN, self).__init__()
        self.Flatten = nn.Flatten()  # flatten the input

        # Initialize two linear layers with weights from the trained neural network.
        self.Linear1 = nn.Linear(28*28, 512, bias=False)
        self.Linear1.weight = nn.Parameter(torch.clone(approx_model.get_parameter(target='Linear1.weight')))
        self.Linear2 = nn.Linear(512, 10)
        self.Linear2.weight = nn.Parameter(torch.clone(approx_model.get_parameter(target='Linear2.weight')))
        self.Linear2.bias = nn.Parameter(torch.clone(approx_model.get_parameter(target='Linear2.bias')))

    def forward(self, x):
        """Return logits of shape ``[n, 10]`` for a batch ``x`` of n images."""
        x_flat = self.Flatten(x)  # [n, d]

        # Normalise rows directly; no transposed copy of the batch is needed
        # and the affine output of Linear1 is never used.
        x_unit = nn.functional.normalize(x_flat, dim=1)  # [n, d]
        w_unit = nn.functional.normalize(self.Linear1.weight, dim=1)  # [D, d]

        # Clamp guards against |cosine| > 1 caused by rounding (asin -> NaN).
        hidden = torch.arcsin((x_unit @ w_unit.T).clamp(-1.0, 1.0))  # [n, D]

        # second layer
        logits = self.Linear2(hidden)

        return logits

In [None]:
# Rebuild the arcsin network and show its layers.
arcsin_model = arcsinNN().to(device)
# The variable is `arcsin_model`; `arcsinmodel` is undefined and raised NameError.
print(arcsin_model)

In [None]:
# Evaluate the rebuilt arcsin network on the test split.
test_loop(dataloader=test_loader, model=arcsin_model, loss_fn=loss_fn)