# Crash Course: the Double Descent Curve with an OPU

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import RidgeClassifier
import torch
import torch.nn as nn 
import torch.optim as optim
import torchvision 
from tqdm.notebook import  tqdm 

from lightonopu import OPU

### Learning a binary encoding

Here we define the architecture of the autoencoder used to learn a binary encoding of the images.

In [None]:
class Autoencoder(nn.Module):
    """Single-layer convolutional autoencoder whose code can be binarized.

    The encoder is one strided ``Conv2d`` and the decoder the matching
    ``ConvTranspose2d``. During training the code is squashed with
    ``tanh(beta * z) / beta``; as ``beta`` grows, ``tanh(beta * z)`` approaches
    ``sign(z)``, i.e. the binary code returned by :meth:`encode`.

    Args:
        kernel_depth: number of encoder output channels (code channels).
        channel_input: number of channels of the input images.
        kernel_size: convolution kernel size, shared by encoder and decoder.
        stride: convolution stride, shared by encoder and decoder.
        padding: convolution padding, shared by encoder and decoder.
        beta: initial sharpness of the tanh relaxation used in ``forward``.
    """

    def __init__(self, kernel_depth=24, channel_input=1, kernel_size=6, stride=2, padding=2, beta=1.):
        super().__init__()
        self.encoder = nn.Conv2d(channel_input, kernel_depth, kernel_size, stride, padding)
        self.decoder = nn.ConvTranspose2d(kernel_depth, channel_input, kernel_size, stride, padding)
        self.beta = beta
        # Xavier-uniform init with gain 5/3, the gain PyTorch recommends for tanh.
        nn.init.xavier_uniform_(self.encoder.weight, gain=5 / 3)
        nn.init.xavier_uniform_(self.decoder.weight, gain=5 / 3)

    def forward(self, x):
        """Reconstruct ``x`` through the smooth, tanh-relaxed code."""
        h = torch.tanh(self.beta * self.encoder(x)) / self.beta
        return self.decoder(h)

    def encode(self, x):
        """Return the binary code of ``x``: 1 where the encoder output is positive, else 0.

        The result has the encoder's output shape and dtype. ``z > 0`` is used
        instead of ``(sign(z) + 1) / 2`` so that an encoder output of exactly
        zero maps to 0 rather than 0.5, keeping the code strictly in {0, 1}.
        """
        z = self.encoder(x)
        return (z > 0).to(z.dtype)

In [None]:
# Instantiate the autoencoder, spelling out its hyper-parameters explicitly.
ae = Autoencoder(kernel_depth=24, channel_input=1, kernel_size=6, stride=2, padding=2, beta=1.)

#### Training the autoencoder

Now we download the data. We train the autoencoder on the full MNIST training set, while to recover the double descent curve
we use a fixed random subsample of 10,000 training images.

In [None]:
# Standardize pixels with the MNIST mean and standard deviation.
normalize = torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,))
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    normalize,
])
# Build the train split first, then the test split.
trainset, testset = (
    torchvision.datasets.MNIST(root='./data', train=split_is_train, download=True, transform=transforms)
    for split_is_train in (True, False)
)

# dataloader to train the AE: full training set in shuffled mini-batches of 128
mnist_dl_ae_train = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
# dataloader to extract test data in a single batch of 10000
mnist_dl_test = torch.utils.data.DataLoader(testset, batch_size=10000, shuffle=False)

# dataloader to extract a sub sample of the training data: 10000 indices picked
# with a fixed numpy seed, delivered as one (randomly ordered) batch
np.random.seed(1234)
indices = [*range(len(trainset))]
np.random.shuffle(indices)
train_idx = indices[:10000]
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
mnist_dl_train = torch.utils.data.DataLoader(trainset, batch_size=10000, sampler=train_sampler)

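We train the autoencoder with the mean squared error loss. We double `beta` at each epoch of the training: since `tanh(beta * z)` approaches `sign(z)` as `beta` grows, this progressively pushes the training-time code towards the binary code produced by `encode`.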

In [None]:
# Train the autoencoder to reconstruct its input under the mean squared error.
mse = nn.MSELoss()
# Use the GPU when one is available; fall back to the CPU instead of failing on `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Move the model before constructing the optimizer, as the torch.optim docs recommend,
# so the optimizer is built on the parameters in their final location.
ae = ae.to(device)
optimizer = optim.Adam(params=ae.parameters(), lr=0.001)
n_epochs = 10

for epoch in tqdm(range(n_epochs)):
    # Labels are not needed: the target is the input image itself.
    for x, _ in mnist_dl_ae_train:
        x = x.to(device)
        optimizer.zero_grad()
        x_tilde = ae(x)
        loss = mse(x_tilde, x)
        loss.backward()
        optimizer.step()

    # Sharpen the tanh relaxation used in `forward` for the next epoch.
    ae.beta = ae.beta * 2

In [None]:
# Draw a single batch from each loader: batch_size=10000 covers the whole
# 10000-image training subset and the MNIST test split in one go.
# `next(iter(...))` replaces the Python-2-style `.next()` method, which
# DataLoader iterators no longer provide in recent PyTorch releases.
X_train, y_train = next(iter(mnist_dl_train))
X_test, y_test = next(iter(mnist_dl_test))

In [None]:
# Here we finally obtain the binary representation we sought for our data.
ae = ae.to('cpu')
# Inference only: don't build an autograd graph for these large batches.
with torch.no_grad():
    # Flatten each (channels, height, width) code map into one binary vector per image.
    X_train_binary = ae.encode(X_train).flatten(start_dim=1)
    X_test_binary = ae.encode(X_test).flatten(start_dim=1)
print(X_train_binary.shape, X_test_binary.shape)

### Recover the double descent curve

In [None]:
# We choose the numbers of random projections to evaluate, and project the data
# once with the OPU at the largest size (smaller sizes use its leading columns).
uniform_sampling = list(range(2000, 20001, 500))  # 2000, 2500, ..., 20000
# Extra points densifying the grid around 10000 features, the training-subset size.
interpolation_point_sampling = [9250, 9750, 10250, 10750]
rps_list = sorted([*uniform_sampling, *interpolation_point_sampling])
max_rps = rps_list[-1]

opu = OPU(n_components=max_rps)
with opu:
    X_train_rf = opu.transform1d(X_train_binary.int())
    X_test_rf = opu.transform1d(X_test_binary.int())

In [None]:
# For each number of random features, fit a ridge classifier on the leading
# columns of the projections and record its train and test accuracy.
train_accuracy = np.zeros(len(rps_list))
test_accuracy = np.zeros_like(train_accuracy)

for idx, n_features in enumerate(tqdm(rps_list)):
    features_train = X_train_rf[:, :n_features]
    clf = RidgeClassifier().fit(features_train, y_train)
    train_accuracy[idx] = clf.score(features_train, y_train)
    test_accuracy[idx] = clf.score(X_test_rf[:, :n_features], y_test)

In [None]:
# ta da! Plot the error (1 - accuracy) against the number of random features.
ax = plt.gca()
for accuracy, label in ((train_accuracy, 'OPU training error'),
                        (test_accuracy, 'OPU test error')):
    ax.plot(rps_list, 1 - accuracy, label=label)
ax.set_title('MNIST, subset of 10.000 samples')
ax.set_ylabel('Error')
ax.set_xlabel('Random features')
ax.legend()
ax.grid()
plt.show()