In [1]:
import os, sys, json, numpy as np, pandas as pd, pickle, itertools, logging
from sklearn.preprocessing import MinMaxScaler
from typing import TypeVar, Union, List, Literal, Tuple

In [2]:
from data.make_numpy import preprocessing_pipeline, UNZIPPED_FILE, DATA_DIR
from fundamental_domain_projections.example1 import fundamental_domain_projection
from learn.utils import FundamentalDomainProjectionDataset, generate_dataloaders

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from learn.mlp1 import FDPCNN, FDPNN, train_loop, test_loop

In [3]:
X, y = preprocessing_pipeline(UNZIPPED_FILE, extraction_key='h_1,1')

# Load Dataset

In [12]:
use_precomputed_dirichlet = False
apply_random_permutation = False
transformation = fundamental_domain_projection
perturbation = np.random.rand(4, 26)

In [13]:
perturbation

array([[0.56498019, 0.78951296, 0.15023034, 0.7396962 , 0.65125309,
        0.85071186, 0.94757288, 0.36609931, 0.42502001, 0.54070844,
        0.04448247, 0.42150368, 0.34812036, 0.57289827, 0.95171061,
        0.61174848, 0.24286761, 0.75310309, 0.39754917, 0.01757864,
        0.89722829, 0.65842866, 0.39222438, 0.53650258, 0.76958349,
        0.81699585],
       [0.41770018, 0.30949753, 0.85399435, 0.53040931, 0.66100004,
        0.00934852, 0.32210419, 0.45654338, 0.97181452, 0.20711923,
        0.69753242, 0.69184575, 0.3070336 , 0.51813746, 0.35300409,
        0.39336356, 0.68164905, 0.01508491, 0.314941  , 0.85398818,
        0.63423562, 0.41627382, 0.43258216, 0.88398507, 0.3607944 ,
        0.1000546 ],
       [0.25797729, 0.3551683 , 0.07248914, 0.4393289 , 0.92821462,
        0.42453822, 0.53887573, 0.14983352, 0.28542722, 0.83152432,
        0.18723898, 0.22226257, 0.58073391, 0.90446258, 0.71904151,
        0.47307668, 0.63247218, 0.010486  , 0.39024732, 0.99783077,
      

In [14]:
dataset = FundamentalDomainProjectionDataset(
    apply_random_permutation=apply_random_permutation,
    use_fixed_perturbation=True,
    perturbation=perturbation,
    transformation='combinatorial'
)

In [6]:
if use_precomputed_dirichlet:
    with open(os.path.join(DATA_DIR, 'dircichlet_X.pickle'), 'rb') as f:
        X = pickle.load(f)
    
    # random permutation and dirichlet don't currently work together
    if apply_random_permutation:
        x = []
        for _x in X:
            x.append(np.transpose(np.random.permutation(np.transpose(np.random.permutation(_x)))))
    
    dataset.X = X

In [15]:
train_loader, valid_loader = generate_dataloaders(dataset)

# Fully Connected Linear Network

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [17]:
class FDPNN(nn.Module):
    def __init__(self, data_shape, num_classes):
        super(FDPNN, self).__init__()
        
        self.flatten = nn.Flatten()
        self.linears = nn.Sequential(
            nn.Linear(np.multiply(*data_shape), 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, num_classes),            
        )
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.flatten(x)
        x = self.linears(x)
        logits = self.softmax(x)
        return logits

In [18]:
model = FDPNN(dataset.data_shape, dataset.num_classes)
model.cuda()

FDPNN(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linears): Sequential(
    (0): Linear(in_features=104, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=43, bias=True)
  )
  (softmax): Softmax(dim=1)
)

In [19]:
learning_rate = 0.001

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
epochs = 100
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_loader, model, loss_fn, optimizer)
    test_loop(valid_loader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 3.760965  [    0/54927]
loss: 3.653866  [ 6400/54927]
loss: 3.648849  [12800/54927]
loss: 3.628884  [19200/54927]
loss: 3.658034  [25600/54927]
loss: 3.670732  [32000/54927]
loss: 3.633056  [38400/54927]
loss: 3.656667  [44800/54927]
loss: 3.585806  [51200/54927]
Test Error: 
 Accuracy: 17.0%, Avg loss: 3.630758 

Epoch 2
-------------------------------
loss: 3.603501  [    0/54927]
loss: 3.624758  [ 6400/54927]
loss: 3.594375  [12800/54927]
loss: 3.627121  [19200/54927]
loss: 3.646091  [25600/54927]
loss: 3.635314  [32000/54927]
loss: 3.641171  [38400/54927]
loss: 3.611329  [44800/54927]
loss: 3.569936  [51200/54927]
Test Error: 
 Accuracy: 17.0%, Avg loss: 3.629097 

Epoch 3
-------------------------------
loss: 3.591466  [    0/54927]
loss: 3.616492  [ 6400/54927]
loss: 3.646902  [12800/54927]
loss: 3.554010  [19200/54927]
loss: 3.658225  [25600/54927]
loss: 3.585628  [32000/54927]
loss: 3.656567  [38400/54927]
loss: 3.576707  [44800/549

loss: 3.567011  [38400/54927]
loss: 3.481462  [44800/54927]
loss: 3.578669  [51200/54927]
Test Error: 
 Accuracy: 18.5%, Avg loss: 3.614111 

Epoch 24
-------------------------------
loss: 3.595233  [    0/54927]
loss: 3.531355  [ 6400/54927]
loss: 3.522427  [12800/54927]
loss: 3.536384  [19200/54927]
loss: 3.614417  [25600/54927]
loss: 3.589731  [32000/54927]
loss: 3.518151  [38400/54927]
loss: 3.601320  [44800/54927]
loss: 3.462253  [51200/54927]
Test Error: 
 Accuracy: 18.7%, Avg loss: 3.613029 

Epoch 25
-------------------------------
loss: 3.555867  [    0/54927]
loss: 3.590589  [ 6400/54927]
loss: 3.557575  [12800/54927]
loss: 3.485072  [19200/54927]
loss: 3.492744  [25600/54927]
loss: 3.478560  [32000/54927]
loss: 3.517209  [38400/54927]
loss: 3.512943  [44800/54927]
loss: 3.614679  [51200/54927]
Test Error: 
 Accuracy: 18.3%, Avg loss: 3.615366 

Epoch 26
-------------------------------
loss: 3.518381  [    0/54927]
loss: 3.515082  [ 6400/54927]
loss: 3.446522  [12800/54927]
l

loss: 3.536309  [ 6400/54927]
loss: 3.593083  [12800/54927]
loss: 3.532802  [19200/54927]
loss: 3.561984  [25600/54927]
loss: 3.519983  [32000/54927]
loss: 3.445631  [38400/54927]
loss: 3.518586  [44800/54927]
loss: 3.484207  [51200/54927]
Test Error: 
 Accuracy: 18.7%, Avg loss: 3.612903 

Epoch 47
-------------------------------
loss: 3.537087  [    0/54927]
loss: 3.572087  [ 6400/54927]
loss: 3.518876  [12800/54927]
loss: 3.556049  [19200/54927]
loss: 3.458176  [25600/54927]
loss: 3.577660  [32000/54927]
loss: 3.516002  [38400/54927]
loss: 3.603724  [44800/54927]
loss: 3.469369  [51200/54927]
Test Error: 
 Accuracy: 18.4%, Avg loss: 3.615147 

Epoch 48
-------------------------------
loss: 3.462733  [    0/54927]
loss: 3.480727  [ 6400/54927]
loss: 3.535088  [12800/54927]
loss: 3.415194  [19200/54927]
loss: 3.551581  [25600/54927]
loss: 3.527924  [32000/54927]
loss: 3.603667  [38400/54927]
loss: 3.581209  [44800/54927]
loss: 3.508894  [51200/54927]
Test Error: 
 Accuracy: 18.7%, Avg

loss: 3.432130  [51200/54927]
Test Error: 
 Accuracy: 18.4%, Avg loss: 3.615089 

Epoch 69
-------------------------------
loss: 3.473228  [    0/54927]
loss: 3.517974  [ 6400/54927]
loss: 3.537590  [12800/54927]
loss: 3.272710  [19200/54927]
loss: 3.583933  [25600/54927]
loss: 3.538459  [32000/54927]
loss: 3.432802  [38400/54927]
loss: 3.443125  [44800/54927]
loss: 3.508122  [51200/54927]
Test Error: 
 Accuracy: 18.6%, Avg loss: 3.613587 

Epoch 70
-------------------------------
loss: 3.572545  [    0/54927]
loss: 3.479675  [ 6400/54927]
loss: 3.451295  [12800/54927]
loss: 3.565490  [19200/54927]
loss: 3.488529  [25600/54927]
loss: 3.409922  [32000/54927]
loss: 3.509307  [38400/54927]
loss: 3.503063  [44800/54927]
loss: 3.508873  [51200/54927]
Test Error: 
 Accuracy: 18.7%, Avg loss: 3.613662 

Epoch 71
-------------------------------
loss: 3.494461  [    0/54927]
loss: 3.539883  [ 6400/54927]
loss: 3.535284  [12800/54927]
loss: 3.549049  [19200/54927]
loss: 3.457091  [25600/54927]
l

loss: 3.567782  [19200/54927]
loss: 3.466147  [25600/54927]
loss: 3.491699  [32000/54927]
loss: 3.458293  [38400/54927]
loss: 3.519391  [44800/54927]
loss: 3.411931  [51200/54927]
Test Error: 
 Accuracy: 18.6%, Avg loss: 3.614726 

Epoch 92
-------------------------------
loss: 3.443698  [    0/54927]
loss: 3.534945  [ 6400/54927]
loss: 3.518998  [12800/54927]
loss: 3.530058  [19200/54927]
loss: 3.393504  [25600/54927]
loss: 3.532301  [32000/54927]
loss: 3.549994  [38400/54927]
loss: 3.480695  [44800/54927]
loss: 3.489657  [51200/54927]
Test Error: 
 Accuracy: 18.6%, Avg loss: 3.613618 

Epoch 93
-------------------------------
loss: 3.445326  [    0/54927]
loss: 3.371163  [ 6400/54927]
loss: 3.455224  [12800/54927]
loss: 3.520553  [19200/54927]
loss: 3.565429  [25600/54927]
loss: 3.437530  [32000/54927]
loss: 3.450952  [38400/54927]
loss: 3.441002  [44800/54927]
loss: 3.451412  [51200/54927]
Test Error: 
 Accuracy: 18.4%, Avg loss: 3.615615 

Epoch 94
-------------------------------
l

# CNN Network

In [12]:
model = FDPCNN(dataset.data_shape, dataset.num_classes)
model.cuda()

FDPCNN(
  (linears): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=192, out_features=43, bias=True)
  )
  (softmax): Softmax(dim=1)
)

In [39]:
learning_rate = 0.001

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [40]:
epochs = 100
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_loader, model, loss_fn, optimizer, add_channel_dim=True)
    test_loop(valid_loader, model, loss_fn, add_channel_dim=True)
print("Done!")

Epoch 1
-------------------------------
loss: 3.759778  [    0/54927]
loss: 3.743992  [ 6400/54927]
loss: 3.688791  [12800/54927]
loss: 3.661777  [19200/54927]
loss: 3.658791  [25600/54927]
loss: 3.685580  [32000/54927]
loss: 3.689259  [38400/54927]
loss: 3.653208  [44800/54927]
loss: 3.735170  [51200/54927]
Test Error: 
 Accuracy: 11.7%, Avg loss: 3.674383 

Epoch 2
-------------------------------
loss: 3.646764  [    0/54927]
loss: 3.693906  [ 6400/54927]
loss: 3.709287  [12800/54927]
loss: 3.688607  [19200/54927]
loss: 3.673328  [25600/54927]
loss: 3.629618  [32000/54927]
loss: 3.676247  [38400/54927]
loss: 3.698577  [44800/54927]
loss: 3.678238  [51200/54927]
Test Error: 
 Accuracy: 14.1%, Avg loss: 3.660840 

Epoch 3
-------------------------------
loss: 3.720447  [    0/54927]
loss: 3.644556  [ 6400/54927]
loss: 3.618338  [12800/54927]
loss: 3.572728  [19200/54927]
loss: 3.614452  [25600/54927]
loss: 3.693798  [32000/54927]
loss: 3.618841  [38400/54927]
loss: 3.668601  [44800/549

loss: 3.650540  [38400/54927]
loss: 3.684364  [44800/54927]
loss: 3.630692  [51200/54927]
Test Error: 
 Accuracy: 17.6%, Avg loss: 3.625009 

Epoch 24
-------------------------------
loss: 3.690070  [    0/54927]
loss: 3.691133  [ 6400/54927]
loss: 3.619085  [12800/54927]
loss: 3.643157  [19200/54927]
loss: 3.586827  [25600/54927]
loss: 3.625211  [32000/54927]
loss: 3.588610  [38400/54927]
loss: 3.583448  [44800/54927]
loss: 3.557784  [51200/54927]
Test Error: 
 Accuracy: 17.9%, Avg loss: 3.621614 

Epoch 25
-------------------------------
loss: 3.614865  [    0/54927]
loss: 3.611125  [ 6400/54927]
loss: 3.551447  [12800/54927]
loss: 3.584301  [19200/54927]
loss: 3.604898  [25600/54927]
loss: 3.592354  [32000/54927]
loss: 3.663563  [38400/54927]
loss: 3.611835  [44800/54927]
loss: 3.651559  [51200/54927]
Test Error: 
 Accuracy: 18.1%, Avg loss: 3.619879 

Epoch 26
-------------------------------
loss: 3.595016  [    0/54927]
loss: 3.541649  [ 6400/54927]
loss: 3.486603  [12800/54927]
l

loss: 3.586523  [ 6400/54927]
loss: 3.514801  [12800/54927]
loss: 3.534796  [19200/54927]
loss: 3.589982  [25600/54927]
loss: 3.634238  [32000/54927]
loss: 3.609514  [38400/54927]
loss: 3.465649  [44800/54927]
loss: 3.603566  [51200/54927]
Test Error: 
 Accuracy: 16.9%, Avg loss: 3.631691 

Epoch 47
-------------------------------
loss: 3.613777  [    0/54927]
loss: 3.637755  [ 6400/54927]
loss: 3.613050  [12800/54927]
loss: 3.617853  [19200/54927]
loss: 3.523091  [25600/54927]
loss: 3.580540  [32000/54927]
loss: 3.579091  [38400/54927]
loss: 3.601344  [44800/54927]
loss: 3.628162  [51200/54927]
Test Error: 
 Accuracy: 16.1%, Avg loss: 3.638427 

Epoch 48
-------------------------------
loss: 3.576707  [    0/54927]
loss: 3.570631  [ 6400/54927]
loss: 3.553907  [12800/54927]
loss: 3.589515  [19200/54927]
loss: 3.576469  [25600/54927]
loss: 3.601410  [32000/54927]
loss: 3.537340  [38400/54927]
loss: 3.604764  [44800/54927]
loss: 3.588950  [51200/54927]
Test Error: 
 Accuracy: 17.0%, Avg

loss: 3.567839  [51200/54927]
Test Error: 
 Accuracy: 17.8%, Avg loss: 3.621897 

Epoch 69
-------------------------------
loss: 3.514839  [    0/54927]
loss: 3.600299  [ 6400/54927]
loss: 3.538693  [12800/54927]
loss: 3.578027  [19200/54927]
loss: 3.594599  [25600/54927]
loss: 3.580712  [32000/54927]
loss: 3.614744  [38400/54927]
loss: 3.619502  [44800/54927]
loss: 3.654986  [51200/54927]
Test Error: 
 Accuracy: 17.7%, Avg loss: 3.622615 

Epoch 70
-------------------------------
loss: 3.587069  [    0/54927]
loss: 3.624381  [ 6400/54927]
loss: 3.585296  [12800/54927]
loss: 3.575721  [19200/54927]
loss: 3.588619  [25600/54927]
loss: 3.574337  [32000/54927]
loss: 3.463785  [38400/54927]
loss: 3.587032  [44800/54927]
loss: 3.609793  [51200/54927]
Test Error: 
 Accuracy: 17.4%, Avg loss: 3.625998 

Epoch 71
-------------------------------
loss: 3.673259  [    0/54927]
loss: 3.595327  [ 6400/54927]
loss: 3.636677  [12800/54927]
loss: 3.598176  [19200/54927]
loss: 3.625344  [25600/54927]
l

loss: 3.593695  [19200/54927]
loss: 3.653789  [25600/54927]
loss: 3.576796  [32000/54927]
loss: 3.607773  [38400/54927]
loss: 3.548815  [44800/54927]
loss: 3.535245  [51200/54927]
Test Error: 
 Accuracy: 18.2%, Avg loss: 3.618533 

Epoch 92
-------------------------------
loss: 3.706573  [    0/54927]
loss: 3.458714  [ 6400/54927]
loss: 3.639597  [12800/54927]
loss: 3.619517  [19200/54927]
loss: 3.547512  [25600/54927]
loss: 3.692246  [32000/54927]
loss: 3.586010  [38400/54927]
loss: 3.632626  [44800/54927]
loss: 3.577582  [51200/54927]
Test Error: 
 Accuracy: 18.2%, Avg loss: 3.618475 

Epoch 93
-------------------------------
loss: 3.563393  [    0/54927]
loss: 3.443159  [ 6400/54927]
loss: 3.535588  [12800/54927]
loss: 3.630744  [19200/54927]
loss: 3.548685  [25600/54927]
loss: 3.529392  [32000/54927]
loss: 3.503311  [38400/54927]
loss: 3.581547  [44800/54927]
loss: 3.410629  [51200/54927]
Test Error: 
 Accuracy: 18.3%, Avg loss: 3.618079 

Epoch 94
-------------------------------
l