In [None]:
import numpy as np
from typing import List

import torch
import math
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F

import argparse
from tqdm import tqdm

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

from statistics import mean, stdev

from source import *
from MNIST_source import *

In [None]:
# Number of original images
n = 3
# Number of noisy samples for each original sample
m = 10
# Widths of both neural networks: 28*28, 28*28 + 1, 28*28 + 2, then a final width of n
widths = [28 * 28 + offset for offset in range(3)] + [n]

# Training length per run, and how many independent runs to perform
num_epochs, num_trials = 100, 2


In [None]:
# Per-trial final metrics; each list receives one entry per trial.
radnet_final_losses, radnet_final_accuracies = [], []
relunet_final_losses, relunet_final_accuracies = [], []

for trial in tqdm(range(num_trials)):
    # train_both hands back four sequences; only the last entry of each is kept.
    rad_los, rad_acc, relu_los, relu_acc = train_both(
        num_samples=n,
        m_copies=m,
        dim_vector=widths,
        num_epochs=num_epochs,
        verbose=False,
    )

    # Losses are rounded to 5 decimals before storage; accuracies are stored unrounded.
    radnet_final_losses.append(round(rad_los[-1].item(), 5))
    relunet_final_losses.append(round(relu_los[-1].item(), 5))
    radnet_final_accuracies.append(rad_acc[-1].item())
    relunet_final_accuracies.append(relu_acc[-1].item())

print("Train both networks for %d epochs" % num_epochs)
print("Step ReLU radial network losses:", radnet_final_losses)
print("Step ReLU radial network accuracies:", radnet_final_accuracies)
print("")
print("ReLU MLP losses:", relunet_final_losses)
print("ReLU MLP accuracies:", relunet_final_accuracies)

In [None]:

print("Over %d trials:" % num_trials)

# Each template only uses positional fields 1 (mean) and 2 (sample standard
# deviation); field 0 (the raw list) is passed but never rendered.
# statistics.stdev needs at least two values, i.e. num_trials >= 2.
for template, values in (
    ("Radnet Loss = {1:.3e} +/- {2:.3e}", radnet_final_losses),
    ("Radnet Accuracy = {1:.3e} +/- {2:.3e}", radnet_final_accuracies),
    ("ReLU MLP Loss = {1:.3e} +/- {2:.3e}", relunet_final_losses),
    ("ReLU MLP Accuracy = {1:.3e} +/- {2:.3e}", relunet_final_accuracies),
):
    print(template.format(values, mean(values), stdev(values)))

In [None]:
# Disabled scratch cell: manual version of the per-trial bookkeeping
# (rounded final losses, unrounded final accuracies). Never executes.
if False:
    for store, final_value in (
        (radnet_final_losses, round(rad_los[-1].item(), 5)),
        (radnet_final_accuracies, rad_acc[-1].item()),
        (relunet_final_losses, round(relu_los[-1].item(), 5)),
        (relunet_final_accuracies, relu_acc[-1].item()),
    ):
        store.append(final_value)

In [None]:
# Disabled scratch cell: alternative bookkeeping that would store the raw final
# entries (no .item() call, no rounding). Never executes, but it must still
# parse — the last append was missing its closing parenthesis.
if False:
    radnet_final_losses.append(rad_los[-1])
    radnet_final_accuracies.append(rad_acc[-1])
    relunet_final_losses.append(relu_los[-1])
    relunet_final_accuracies.append(relu_acc[-1])

# Experiments to do

In [None]:
# Planned experiment: a single (n, m) setting with one architecture.
d = 28 * 28
ns, ms = [3], [1000]
# Widths d, d+1, d+2, d+3, followed by a final width of 3
dim_vec = [d + offset for offset in range(4)] + [3]

# TODO: 10 runs for each setting; use a seed for reproducibility.

# Desired metric: mean squared error or accuracy (look up the PyTorch accuracy metric).
# Maybe use cross entropy instead of MSE.

In [None]:
# Planned experiment: hyperparameter search for the learning rate.
n, m = 3, 1000

# Test a range of learning rates for each of radnet and the ReLU MLP;
# the "optimal" learning rate for the two networks could be different.

# If the accuracy of both is the same, compare the rate of convergence
# (eyeballing the plot) with each network at its optimized learning rate.

In [None]:
# Generalization

n = 3
# Written as 1000 + 1000 originally — presumably 1000 training plus 1000 test
# copies per original image; confirm when the split is implemented.
m = 2 * 1000

# Split into training and test sets with equal numbers of each.

# Use the best learning rate for each of radnet and the ReLU MLP, then
# compare the test-set loss and accuracy of the two.


In [None]:
# Planned sweep: three values of n, four values of m, three architectures.
d = 28 * 28
ns = [3, 4, 5]
ms = [100, 500, 1000, 10000]
# Each architecture is d, d+1, ..., d+(depth-1) followed by a final width of 1.
dim_vecs = [[d + offset for offset in range(depth)] + [1] for depth in (4, 5, 6)]

# 12 combinations, and do 10 runs for each

# Desired metric: mean squared error or accuracy
# Maybe use cross entropy instead of MSE

Change data set:

* overlap circles
* choose one sample from each MNIST label (0-9), or just 0,1,2,3



Change model:

* If the models are tied, we can make the problem harder by reducing the number of parameters (or by applying the data set changes listed above).


# Network for learning all of MNIST

In [None]:
# Collapse every dimension after the first into one, then show the resulting shape.
train_features_flat = torch.flatten(train_features, start_dim=1)
train_features_flat.shape

In [None]:
# Bare expression: evaluated but not displayed, since it is not the cell's last line.
train_labels.shape
# One-hot encode the labels over 10 classes.
train_labels_onehot = torch.nn.functional.one_hot(train_labels, num_classes=10)

In [None]:
# Bias-free RadNet with a sigmoid eta and widths 784 -> 784 -> 28 -> 28 -> 10.
radnet = RadNet(
    eta=torch.sigmoid,
    dims=[28 * 28, 28 * 28, 28, 28, 10],
    has_bias=False,
)

In [None]:
# Fit the RadNet on the flattened features against the one-hot labels.
model_trained, model_losses = training_loop(
    model=radnet,
    params=list(radnet.parameters()),
    x_train=train_features_flat,
    y_train=train_labels_onehot,
    n_epochs=3000,
    learning_rate=0.05,
    verbose=True,
)

In [None]:
# ReLU MLP with layer widths 784 -> 784 -> 28 -> 28 -> 10.
# Linear layers keep PyTorch's default bias term.
relu_net = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=28 * 28),
    nn.ReLU(),
    nn.Linear(in_features=28 * 28, out_features=28),
    nn.ReLU(),
    nn.Linear(in_features=28, out_features=28),
    nn.ReLU(),
    nn.Linear(in_features=28, out_features=10),
)

In [None]:
# Fit the ReLU MLP with the same data and hyperparameters used for the RadNet run.
relu_model_trained, relu_model_losses = training_loop(
    model=relu_net,
    params=list(relu_net.parameters()),
    x_train=train_features_flat,
    y_train=train_labels_onehot,
    n_epochs=3000,
    learning_rate=0.05,
    verbose=True,
)

# Get noisy sample

In [None]:
# Passed to add_noise as n and m respectively.
num_samples, m_copies = 3, 100

noisy_threes, noisy_labels = add_noise(
    label=3,
    n=int(num_samples),
    m=int(m_copies),
    verbose=False,
)
print(noisy_threes.shape, noisy_labels.shape)

In [None]:
# Collapse every dimension after the first into one.
noisy_threes_flat = torch.flatten(noisy_threes, start_dim=1)

# Train radnet with the noise

In [None]:
d = 28 * 28
# Widths d, d+1, d+2, d+3, followed by a final width of num_samples.
dim_vector = [d + offset for offset in range(4)] + [num_samples]

# Bias-free RadNet using the step-ReLU eta.
radnet = RadNet(
    eta=stepReLU_eta,
    dims=dim_vector,
    has_bias=False,
)

In [None]:
# Fit the RadNet on the flattened noisy samples; also collects an accuracy sequence.
model_trained, model_losses, model_accuracies = ce_training_loop(
    model=radnet,
    params=list(radnet.parameters()),
    x_train=noisy_threes_flat,
    y_train=noisy_labels,
    n_epochs=1000,
    learning_rate=0.05,
    verbose=True,
)

In [None]:
# Plot the first 20 recorded loss and accuracy values of the radial network.
# Titles and axis labels are set so the two figures can be told apart.
# Presumably one value is recorded per epoch — TODO confirm against ce_training_loop.
plt.plot(torch.tensor(model_losses).detach()[:20])
plt.title("Step ReLU radial network: training loss (first 20 values)")
plt.xlabel("Recorded training step")
plt.ylabel("Loss")
plt.show()

plt.plot(torch.tensor(model_accuracies).detach()[:20])
plt.title("Step ReLU radial network: training accuracy (first 20 values)")
plt.xlabel("Recorded training step")
plt.ylabel("Accuracy")
plt.show()

In [None]:
# Disabled: variant of the radnet training run that moves model and data to `device`.
# NOTE(review): `device` is not defined in any visible cell — define it before enabling.
if False:
    radnet = RadNet(eta=stepReLU_eta, dims=dim_vector, has_bias=False)
    radnet = radnet.to(device)
    model_trained, model_losses, model_accuracies = ce_training_loop(
        model=radnet,
        params=list(radnet.parameters()),
        x_train=noisy_threes_flat.to(device),
        y_train=noisy_labels.to(device),
        n_epochs=2000,
        learning_rate=0.05,
        verbose=True,
    )


# Train ReLU net with noise

In [None]:
# ReLU MLP whose hidden widths are taken from dim_vector[1:4];
# the final layer has num_samples outputs.
relu_net = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=dim_vector[1]),
    nn.ReLU(),
    nn.Linear(in_features=dim_vector[1], out_features=dim_vector[2]),
    nn.ReLU(),
    nn.Linear(in_features=dim_vector[2], out_features=dim_vector[3]),
    nn.ReLU(),
    nn.Linear(in_features=dim_vector[3], out_features=num_samples),
)

In [None]:
# Fit the ReLU MLP on the same noisy samples and hyperparameters as the radnet run.
relu_model_trained, relu_model_losses, relu_model_accuracies = ce_training_loop(
    model=relu_net,
    params=list(relu_net.parameters()),
    x_train=noisy_threes_flat,
    y_train=noisy_labels,
    n_epochs=1000,
    learning_rate=0.05,
    verbose=True,
)

In [None]:
# Plot the first 20 recorded loss and accuracy values of the ReLU MLP.
# Titles and axis labels are set so the two figures can be told apart.
# Presumably one value is recorded per epoch — TODO confirm against ce_training_loop.
plt.plot(torch.tensor(relu_model_losses).detach()[:20])
plt.title("ReLU MLP: training loss (first 20 values)")
plt.xlabel("Recorded training step")
plt.ylabel("Loss")
plt.show()

plt.plot(torch.tensor(relu_model_accuracies).detach()[:20])
plt.title("ReLU MLP: training accuracy (first 20 values)")
plt.xlabel("Recorded training step")
plt.ylabel("Accuracy")
plt.show()

# Train both nets

In [None]:
# Narrow architecture: two layers of width 2 between the 784-wide input and 3 outputs.
train_both(num_samples=3, m_copies=10, dim_vector=[28 * 28, 2, 2, 3])

In [None]:
# Wide architecture: widths 784, 785, 786 followed by 3 outputs.
train_both(
    num_samples=3,
    m_copies=10,
    dim_vector=[28 * 28 + offset for offset in range(3)] + [3],
)

In [None]:



# Disabled sweep over the planned settings. The original indexed the integer `n`
# and an undefined `dims`; the lists defined in this notebook are `ns` and `dim_vecs`.
# They are paired element-wise: 3 (n, architecture) pairs x 4 values of m gives the
# "12 combinations" mentioned next to their definition.
# NOTE(review): every entry of dim_vecs ends with width 1 while num_samples is 3-5 —
# confirm the intended output width before enabling.
if False:
    for num_originals, layer_widths in zip(ns, dim_vecs):
        # Separate loop name so the global `m` is not clobbered.
        for copies_per_original in ms:
            train_both(
                num_samples=num_originals,
                m_copies=copies_per_original,
                dim_vector=layer_widths)
            


# Scraps

In [None]:
# Calculate distances
# Find the smallest pairwise norm among the first n entries of `threes`.
radius = float('inf')
for i in range(n):
    for j in range(i + 1, n):
        pair_distance = torch.linalg.norm(threes[i] - threes[j]).item()
        radius = min(radius, pair_distance)

# Shrink the minimum distance by a factor of 2.5
# (presumably to keep noise around different originals from overlapping — TODO confirm).
radius /= 2.5
radius

In [None]:
# Build n*m noisy images and their one-hot labels: row i*m + j holds original i
# plus noise sample j, and its label is row i of the n x n identity matrix.
# Both buffers start uninitialised and are filled completely by the loop below.
noisy_threes = torch.empty(int(n * m), 1, 28, 28)
noisy_labels = torch.empty(int(n * m), n)
one_hot_rows = torch.eye(n)  # hoisted: identical on every iteration
for i in range(n):
    for j in range(m):
        # The image index previously used i*n + j, which overlapped rows and
        # left others unwritten; it must match the label index.
        row = i * m + j
        noisy_threes[row] = threes[i] + noise[j]
        noisy_labels[row] = one_hot_rows[i]
        


# Disabled debug output: would print the shapes of the noisy images and labels.
if False:
    print(noisy_threes.shape, noisy_labels.shape)
if False:
    plt.imshow(threes[0][0], cmap="gray")
    plt.show()
    plt.imshow(noisy_threes[0][0], cmap="gray")
    plt.show()
    plt.imshow(noisy_threes[1][0], cmap="gray")