In [3]:
import os, shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.transforms.v2.functional as F_v2
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from Utils.dataset import PreloadedDataset
import matplotlib.pyplot as plt
from tqdm import tqdm

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

from Methods.HEPA.train import train as train_hepa
from Methods.HEPA.model import HEPA
# from Methods.SSMAugPC.model import SSMAugPC
# from Methods.SSMAugPC.train import train as train_ssmaugpc
from Methods.BYOL.train import train as train_byol
from Methods.BYOL.model import BYOL
# from Methods.DINO.train import train as train_dino
# from Methods.DINO.model import DINO
# from Methods.SimCLR.train import train as train_simclr
# from Methods.SimCLR.model import SimCLR
# from Methods.SimSiam.train import train as train_simsiam
# from Methods.SimSiam.model import SimSiam
# from Methods.LAugPC2.train import train as train_laugpc2
# from Methods.LAugPC2.model import LAugPC2
# from Methods.VQVAE.train import train as train_vae
# from Methods.VQVAE.model import VAE
from Methods.AE.train import train as train_ae
from Methods.AE.model import AE
from Methods.MAE.train import train as train_mae
from Methods.MAE.model import MAE
from Methods.GPAViT.train import train as train_gpa
from Methods.GPAViT.model import GPAViT
from Methods.GPAMAE.train import train as train_gpamae
from Methods.GPAMAE.model import GPAMAE
from Methods.VAE.train import train as train_vae
from Methods.VAE.model import VAE
from Methods.Supervised.model import Supervised
from Methods.Supervised.train import train as train_supervised

from Examples.ModelNet10.evals import linear_probing, eval_representations
from Examples.ModelNet10.dataset import ModelNet10
from Utils.functional import get_optimiser

In [4]:
# Reproducibility: seed torch's RNGs (CPU and all CUDA devices) before anything random runs
# (weight init, RandomAffine augmentation and -- if it draws from torch's RNG -- the train/val split).
# cudnn.benchmark=True trades determinism for speed, so GPU runs may still differ slightly.
SEED = 0
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.backends.cudnn.benchmark = True
device

'cuda'

In [5]:
# Load the ModelNet10 training images onto `device` and split them into train / validation
# subsets (split(0.8) -- presumably 80% train, 20% val; confirm in ModelNet10.split).
root = '../Datasets/ModelNet10_pics/'
train_set, val_set = ModelNet10(root, 'train', device=device).split(0.8)

In [None]:
# Preview the first few training pairs: top row = first view, bottom row = second view.
# Each sample is fetched exactly once so both rows show the *same* pair; indexing twice
# could return different views if the dataset samples augmentations in __getitem__.
n_preview = 5
fig, axes = plt.subplots(2, n_preview, figsize=(15, 5))
for i in range(n_preview):
    (img1, _, _), (img2, _, _) = train_set[i]
    for ax, img in zip(axes[:, i], (img1, img2)):
        ax.imshow(img.squeeze().cpu(), cmap='gray')
        ax.axis('off')
plt.show()

# Value range of the preloaded images (sanity check on normalisation).
print('Max value:', train_set.transformed_images.max())
print('Min value:', train_set.transformed_images.min())

In [1]:
def _make_optimiser(model, lr=3e-4):
    """Build the AdamW optimiser used for every method (wd=0.004, exclude_bias/exclude_bn flags set)."""
    return get_optimiser(
        model,
        'AdamW',
        lr=lr,
        wd=0.004,
        exclude_bias=True,
        exclude_bn=True,
    )


def _affine_aug(degrees):
    """Return a Compose with one RandomAffine (rotation +/-`degrees`, 10% shift, 0.9-1.1 scale, shear 10)."""
    return transforms.Compose([
        transforms.RandomAffine(degrees=degrees, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10),
    ])


def _next_run_no(log_dir):
    """Return the smallest N >= 1 for which `<log_dir>encoder/run_N` does not exist yet."""
    run_no = 1
    while os.path.exists(log_dir + 'encoder/' + f'run_{run_no}'):
        run_no += 1
    return run_no


def _load_checkpoint(model, save_dir):
    """Load the state dict stored at `save_dir` into `model`.

    Returns True when the checkpoint was loaded and False when the file does not
    exist (the caller should then train from scratch). Any other error, e.g. a
    key mismatch in `load_state_dict`, propagates.
    """
    try:
        # map_location lets a checkpoint written on GPU be restored on a CPU-only machine.
        sd = torch.load(save_dir, map_location=device)
    except FileNotFoundError:
        print('Model not found, training new model')
        return False
    # Rename "project" keys to "transition" (presumably checkpoints saved before a module rename).
    for key in list(sd.keys()):
        if 'project' in key:
            sd[key.replace('project', 'transition')] = sd.pop(key)
    model.load_state_dict(sd)
    print('Model loaded successfully')
    return True


def _train_model(model, optimiser, cfg, writer, save_dir):
    """Set the global `train_set.transform` for `model`'s method and run that method's trainer.

    Branches are independent `if`s, so a model matching several classes is trained
    by each matching trainer. GPAViT and GPAMAE share a single branch because they
    use the same trainer and arguments, so such a model is never trained twice.
    """
    common = dict(num_epochs=250, writer=writer, save_dir=save_dir, save_every=5)

    if isinstance(model, HEPA):
        train_set.transform = transforms.Compose([])
        train_hepa(
            model, optimiser, train_set, val_set,
            batch_size=256,
            stop_at=0,
            train_aug_scaler='none',
            val_aug_scaler='none',
            loss_fn='mse',
            **common,
        )
    if isinstance(model, (GPAViT, GPAMAE)):
        train_set.transform = transforms.Compose([])
        train_gpa(
            model, optimiser, train_set, val_set,
            batch_size=256,
            train_aug_scaler='none',
            val_aug_scaler='none',
            **common,
        )
    if isinstance(model, BYOL):
        train_set.transform = transforms.Compose([])
        # BYOL is trained with a 10x lower learning rate than the default optimiser.
        optimiser = _make_optimiser(model, lr=3e-5)
        train_byol(
            model, optimiser, train_set, val_set,
            batch_size=256,
            # `augmentation` is no longer read from an (undefined) global:
            # BYOL configs must supply it under the 'augmentation' key.
            augmentation=cfg['augmentation'],
            beta=None,
            **common,
        )
    if isinstance(model, AE):
        train_set.transform = _affine_aug(15)
        train_ae(
            model, optimiser, train_set, val_set,
            batch_size=256,
            loss_fn='mse',
            beta=None,
            **common,
        )
    if isinstance(model, VAE):
        train_set.transform = _affine_aug(15)
        train_vae(
            model, optimiser, train_set, val_set,
            batch_size=256,
            beta=0.5,
            **common,
        )
    if isinstance(model, MAE):
        train_set.transform = _affine_aug(15)
        train_mae(
            model, optimiser, train_set, val_set,
            batch_size=256,
            mask_ratio=0.75,
            **common,
        )
    if isinstance(model, Supervised):
        train_set.transform = _affine_aug(30)
        # NOTE(review): unlike the other trainers no dataset is passed here, and
        # cfg must provide 'batch_size' / 'subset_size' -- confirm train_supervised's signature.
        train_supervised(
            model, optimiser,
            batch_size=cfg['batch_size'],
            subset_size=cfg['subset_size'],
            learn_on_ss=False,
            **common,
        )


# N_REPEATS runs per method, AE runs first then VAE. Only the first run of each method
# is checkpointed (and reloaded instead of retrained on later executions of this cell).
N_REPEATS = 5
cfgs = [
    {'name': name, 'model': model_cls, 'save': repeat == 0}
    for name, model_cls in [('AE', AE), ('VAE', VAE)]
    for repeat in range(N_REPEATS)
]

experiment = 'modelnet10'
for cfg in cfgs:
    Model = cfg['model']
    experiment_name = cfg['name']
    # Set log_dir = None to disable TensorBoard logging.
    log_dir = f'Examples/ModelNet10/out/logs/{experiment}/{experiment_name}/'
    save_dir = None
    if cfg['save']:
        save_dir = f'Examples/ModelNet10/out/models/{experiment}/{experiment_name}.pth'

    if Model == VAE:
        model = Model(1, 256).to(device)
    elif Model in (AE, BYOL, MAE):
        model = Model(1).to(device)
    else:
        model = Model(1, 5).to(device)

    optimiser = _make_optimiser(model)

    to_train = save_dir is None or not _load_checkpoint(model, save_dir)

    # Run number is needed for the probe logs even when training is skipped
    # because a checkpoint was loaded.
    run_no = _next_run_no(log_dir) if log_dir is not None else None
    writer = None

    if to_train:
        if log_dir is not None:
            # Remove stale reduction files from a previous run, if any.
            for sub in ('encoder', 'classifier'):
                reduction_path = log_dir + f'{sub}/reduction.csv'
                if os.path.exists(reduction_path):
                    os.remove(reduction_path)
            writer = SummaryWriter(log_dir + 'encoder/' + f'run_{run_no}')

        _train_model(model, optimiser, cfg, writer, save_dir)

        print('Finished training')
        if save_dir is not None:
            print('Run cell again to load best (val_acc) model.')

        # Evaluate inter-neuron correlations
        rep_metrics = eval_representations(model, flatten=False)
        if writer is not None:
            writer.add_scalar('Encoder/test_feature_corr', rep_metrics['corr'])
            writer.add_scalar('Encoder/test_feature_std', rep_metrics['std'])
            writer.close()

    # Linear probing with n labelled examples.
    # NOTE(review): `linear_probing` (imported from Examples.ModelNet10.evals) replaces the
    # undefined MNIST-era `mnist_linear_eval`; confirm it accepts (model, n, writer, flatten=..., test=...).
    for n in [1, 10, 100, 1000]:
        dest = f'Examples/ModelNet10/out/logs/n{n}-{experiment}/{experiment_name}/'
        writer = SummaryWriter(dest + f'classifier/run_{run_no}') if log_dir is not None else None
        linear_probing(model, n, writer, flatten=False, test=True)
        if writer is not None:
            writer.close()

    # TODO: a semi-supervised (fine-tuning) evaluation over the same n values was previously
    # stubbed out here; re-add it via the ModelNet10 eval helpers if they support finetuning.

NameError: name 'AE' is not defined