In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import time
import collections
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor, tensor
import pandas as pd
import openml

#from aeon.regression.sklearn import RotationForestRegressor
from sklearn.metrics import root_mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split


# Compact NumPy reprs: 3 decimals, and summarise arrays with more than 5 elements.
np.set_printoptions(precision=3, threshold=5)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines
# without CUDA. Kept as a plain string; every later `.to(device)` accepts it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# MNIST

In [2]:
from torchvision import datasets, transforms


def normalize_mean_std_traindata(X_train: Tensor, X_test: Tensor, eps: float = 1e-8) -> Tuple[Tensor, Tensor]:
    """Standardise both splits with per-feature statistics of the training split.

    The mean and standard deviation are computed over dim 0 of ``X_train`` only
    and then applied to both ``X_train`` and ``X_test``. The standardised values
    are finally clipped to ``[-5, 5]``.

    Args:
        X_train: Training data; statistics are taken over its first dimension.
        X_test: Test data, transformed with the training statistics.
        eps: Features whose training std is not strictly greater than ``eps``
            (constant features, or NaN std when there is a single training row)
            are divided by 1 instead, so they map to 0 rather than NaN/inf.

    Returns:
        The tuple ``(X_train, X_test)`` of standardised, clipped tensors.
    """
    mean = X_train.mean(dim=0)
    std = X_train.std(dim=0)
    # `std > eps` is False for both 0 and NaN, so degenerate features fall back
    # to a unit divisor; `torch.clip` alone would let NaN through unchanged.
    std = torch.where(std > eps, std, torch.ones_like(std))

    X_train = torch.clip((X_train - mean) / std, -5, 5)
    X_test = torch.clip((X_test - mean) / std, -5, 5)
    return X_train, X_test


# Per-image preprocessing: convert to a tensor, then standardise with a fixed
# mean / std (0.1307 / 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Dataset location. Override with the MNIST_PATH environment variable; the
# original hardcoded location is kept as the default for backward compatibility.
mnist_path = os.environ.get("MNIST_PATH", "/home/nikita/hdd/MNIST")

# Download and load the training data (one batch containing the whole split)
trainset = datasets.MNIST(mnist_path, download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=len(trainset), shuffle=False)

# Download and load the test data (one batch containing the whole split)
testset = datasets.MNIST(mnist_path, download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False)

# Flatten each image into a single feature vector
X_train, y_train_cat = next(iter(trainloader))
X_train = X_train.view(len(trainset), -1).to(device)
X_test, y_test_cat = next(iter(testloader))
X_test = X_test.view(len(testset), -1).to(device)

# Convert train and test labels to one-hot encoding; the integer labels are
# kept as `y_*_cat` for the accuracy computations further down.
y_train = nn.functional.one_hot(y_train_cat, num_classes=10).float().to(device)
y_test = nn.functional.one_hot(y_test_cat, num_classes=10).float().to(device)
y_train_cat = y_train_cat.to(device)
y_test_cat = y_test_cat.to(device)

# Normalize by mean and std of the training split
X_train, X_test = normalize_mean_std_traindata(X_train, X_test)
print(f"Train data shape: {X_train.shape}")
print(f"Train labels shape: {y_train.shape}")
print(f"Test data shape: {X_test.shape}")
print(f"Test labels shape: {y_test.shape}")

Train data shape: torch.Size([60000, 784])
Train labels shape: torch.Size([60000, 10])
Test data shape: torch.Size([10000, 784])
Test labels shape: torch.Size([10000, 10])


# Logistic Regression

In [3]:
from models.base import LogisticRegression

# Baseline: L2-regularised multinomial logistic regression on the raw pixels.
model = LogisticRegression(
        n_classes = 10,
        l2_lambda = 0.001,
        max_iter = 300,
    )
X_train_pred = model.fit_transform(X_train, y_train)
# Evaluation only: skip autograd so no graph is built for the test forward pass.
with torch.no_grad():
    X_test_pred = model(X_test)

print("X_test_pred", X_test_pred)

# Accuracy = fraction of samples whose arg-max logit matches the integer label.
train_accuracy = (torch.argmax(X_train_pred, dim=1) == y_train_cat).float().mean().item()
test_accuracy = (torch.argmax(X_test_pred, dim=1) == y_test_cat).float().mean().item()

print(f"Train accuracy: {train_accuracy}")
print(f"Test accuracy: {test_accuracy}")

X_test_pred tensor([[ -0.2475, -10.2454,   0.3929,  ...,  11.1283,   0.0446,   3.3311],
        [  5.8921,   1.3840,  13.1938,  ..., -18.7759,   4.5716, -11.9226],
        [ -5.7093,   6.3725,   1.9279,  ...,   0.8355,   0.3640,  -1.5679],
        ...,
        [ -7.5508,  -7.3104,  -2.6633,  ...,   2.3283,   4.0663,   4.8250],
        [ -2.7671,  -1.8676,  -3.1231,  ...,  -4.0745,   6.4971,  -3.2343],
        [  2.6899, -10.5171,   4.8326,  ...,  -7.0530,  -0.4949,  -4.2397]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Train accuracy: 0.9335166811943054
Test accuracy: 0.9265999794006348


# GradientRFRBoost

In [7]:
from models.random_feature_representation_boosting import GradientRFRBoostClassifier

model = GradientRFRBoostClassifier(
    in_dim = X_train.shape[1],  # taken from the data, as in the End2End cell
    hidden_dim = 512,
    n_classes = 10,
    randfeat_xt_dim = 512,
    randfeat_x0_dim = 512,
    n_layers = 3,
    l2_cls =  0.0001,
    l2_ghat = 0.0001,
    feature_type="SWIM",
    upscale_type = "SWIM",
    lbfgs_max_iter = 300,
    boost_lr = 0.1,
    use_batchnorm=True,
    SWIM_scale=1.0,
    )
X_train_pred = model.fit_transform(X_train, y_train)
# Evaluation only: skip autograd so no graph is built for the test forward pass.
with torch.no_grad():
    X_test_pred = model(X_test)

# Accuracy = fraction of samples whose arg-max logit matches the integer label.
train_accuracy = (torch.argmax(X_train_pred, dim=1) == y_train_cat).float().mean().item()
test_accuracy = (torch.argmax(X_test_pred, dim=1) == y_test_cat).float().mean().item()

print(f"Train accuracy: {train_accuracy}")
print(f"Test accuracy: {test_accuracy}")

#TODO NEXT: add xtx0 to the classification case

logits tensor([[  0.57382404804229736328,  -3.73551511764526367188,
           0.82371711730957031250,  ...,
           1.09647154808044433594,  -0.22811865806579589844,
          -0.30708503723144531250],
        [ 11.18465900421142578125, -11.50838088989257812500,
           1.40315723419189453125,  ...,
          -0.33516967296600341797,   0.50863718986511230469,
           0.28789615631103515625],
        [ -2.97072553634643554688,  -5.19392490386962890625,
           0.83330798149108886719,  ...,
           1.90902304649353027344,  -1.86906218528747558594,
           2.21786260604858398438],
        ...,
        [ -0.85270190238952636719,  -2.78479862213134765625,
          -4.78819513320922851562,  ...,
          -1.44127058982849121094,   2.85936808586120605469,
           1.85477983951568603516],
        [  2.45700097084045410156,  -6.29847288131713867188,
           1.08437561988830566406,  ...,
          -2.30308389663696289062,  -2.91267466545104980469,
          -0.77890306

In [5]:
def _report_layer_metrics(layer, pred_train, pred_test, y_train, y_test, loss_fn):
    """Print train/test loss and accuracy for the predictions made at one layer."""
    ce = loss_fn(pred_train, y_train)
    ce_test = loss_fn(pred_test, y_test)
    acc = (pred_train.argmax(1) == y_train.argmax(1)).float().mean()
    acc_test = (pred_test.argmax(1) == y_test.argmax(1)).float().mean()
    print(f"Train ce at layer {layer}: {ce}")
    print(f"Test ce at layer {layer}: {ce_test}")
    print(f"Train acc at layer {layer}: {acc}")
    print(f"Test acc at layer {layer}: {acc_test}")
    print()


def see_results_for_every_layer(X_train: Tensor, y_train: Tensor, X_test: Tensor, y_test: Tensor,
                                model, loss_fn: Callable[[Tensor, Tensor], Tensor]) -> None:
    """Replay a fitted boosting model layer by layer and print metrics after each layer.

    Layer 0 is the upscaled input passed through ``model.top_level_modules[0]``.
    Each later layer adds ``model.boost_lr * ghat_layer(feat_layer(X, X0))`` to
    the current representation, applies the matching entry of
    ``model.batchnorms`` and classifies with the matching entry of
    ``model.top_level_modules[1:]``.

    Args:
        X_train, X_test: Input features of the two splits. They are not modified.
        y_train, y_test: Targets with the class axis at dim 1 (``argmax(1)`` is
            used to recover the label).
        model: Object exposing ``upscale``, ``top_level_modules``,
            ``random_feature_layers``, ``ghat_boosting_layers``, ``batchnorms``
            and ``boost_lr``.
        loss_fn: Called as ``loss_fn(predictions, targets)``.

    Returns:
        None. Results are printed.
    """
    with torch.no_grad():
        X0_train, X0_test = X_train, X_test

        X_train = model.upscale(X0_train)
        X_test = model.upscale(X0_test)

        head = model.top_level_modules[0]
        _report_layer_metrics(0, head(X_train), head(X_test), y_train, y_test, loss_fn)

        layers = zip(model.random_feature_layers,
                     model.ghat_boosting_layers,
                     model.top_level_modules[1:],
                     model.batchnorms)
        for t, (feat_layer, ghat_layer, classifier, batchnorm) in enumerate(layers, start=1):
            features_train = feat_layer(X_train, X0_train)
            features_test = feat_layer(X_test, X0_test)
            # Out-of-place update on purpose: if `upscale` hands back its input
            # unchanged, an in-place `+=` would overwrite the caller's data and
            # the X0 tensors that later layers still read.
            X_train = batchnorm(X_train + model.boost_lr * ghat_layer(features_train))
            X_test = batchnorm(X_test + model.boost_lr * ghat_layer(features_test))

            _report_layer_metrics(t, classifier(X_train), classifier(X_test), y_train, y_test, loss_fn)


# Per-layer train/test cross-entropy and accuracy of the model fitted above.
see_results_for_every_layer(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    model=model,
    loss_fn=nn.functional.cross_entropy,
)

Train ce at layer 0: 0.27760976552963257
Test ce at layer 0: 0.27482572197914124
Train acc at layer 0: 0.9209666848182678
Test acc at layer 0: 0.9220999479293823

Train ce at layer 1: 0.09838391095399857
Test ce at layer 1: 0.13978564739227295
Train acc at layer 1: 0.9709666967391968
Test acc at layer 1: 0.9599999785423279

Train ce at layer 2: 0.06271514296531677
Test ce at layer 2: 0.12089692056179047
Train acc at layer 2: 0.9815833568572998
Test acc at layer 2: 0.9641000032424927

Train ce at layer 3: 0.045359883457422256
Test ce at layer 3: 0.11288446933031082
Train acc at layer 3: 0.9872333407402039
Test acc at layer 3: 0.9679999947547913



# End2End

In [6]:
from models.end2end import End2EndMLPResNet

# End-to-end trained MLP ResNet for comparison with the boosted model above.
model = End2EndMLPResNet(
    in_dim = X_train.shape[1],
    hidden_dim = 128,
    bottleneck_dim = 32,
    out_dim = 10,
    n_blocks = 4,
    lr = 0.01,
    end_lr_factor = 0.01,
    n_epochs = 20,
    weight_decay = 0.001,
    batch_size = 512
    )
X_train_pred = model.fit_transform(X_train, y_train)
# Evaluation only: skip autograd so no graph is built for the test forward pass.
with torch.no_grad():
    X_test_pred = model(X_test)

print("X_test_pred", X_test_pred)

# Accuracy = fraction of samples whose arg-max output matches the integer label.
train_accuracy = (torch.argmax(X_train_pred, dim=1) == y_train_cat).float().mean().item()
test_accuracy = (torch.argmax(X_test_pred, dim=1) == y_test_cat).float().mean().item()

print(f"Train accuracy: {train_accuracy}")
print(f"Test accuracy: {test_accuracy}")

100%|██████████| 20/20 [00:18<00:00,  1.07it/s]

X_test_pred tensor([[-1.97553634643554687500e-03, -5.60329854488372802734e-03,
         -3.84935736656188964844e-03,  ...,
          1.01216506958007812500e+00,  4.82936203479766845703e-03,
          2.54796445369720458984e-03],
        [-1.76844596862792968750e-02,  1.61109864711761474609e-02,
          1.00505971908569335938e+00,  ...,
         -8.94658267498016357422e-04,  3.86652350425720214844e-03,
         -1.08496844768524169922e-03],
        [ 1.16206407546997070312e-02,  9.91596341133117675781e-01,
          1.04889273643493652344e-04,  ...,
          5.59684634208679199219e-03, -7.64999538660049438477e-03,
         -5.87991625070571899414e-03],
        ...,
        [ 3.46951186656951904297e-04, -3.10655683279037475586e-03,
         -2.05048918724060058594e-03,  ...,
         -3.10361385345458984375e-04,  1.76037102937698364258e-03,
         -5.23433834314346313477e-03],
        [-8.93914699554443359375e-03, -6.42771273851394653320e-03,
         -1.36475265026092529297e-03,  .


