In [4]:
# libraries
from utils import get_cifar10_loaders, train_and_test, plot_pairwise, eval_on_dataloader
import torch
import numpy as np
from resnet import resnet18
from neural_collapse import NC
import copy
from tqdm import tqdm, trange
import pickle
import os
import pandas as pd
import time


# --- Experiment configuration ---------------------------------------------
# Optimisation settings.
batch_size = 128
lr = 0.001

# Warm-up phase: epoch budget and the spacing (in epochs) between checkpoints.
max_epoch_warm_up = 350
log_interval = 5

# Full-dataset phase: epoch budget and the training accuracy that stops training.
max_epoch_full = 100
stop_acc = 0.99

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [5]:
def train_step(model: torch.nn.Module, train_loader, criterion, optimizer):
    '''
    : param model: torch.nn.Module
    : param train_loader: iterable of (inputs, targets) batches, e.g. a torch.utils.data.DataLoader
    : param criterion: torch.nn.Module
    : param optimizer: torch.optim.Optimizer
    : return: float

    Trains the model for one epoch on the training set.
    Returns the accuracy of the epoch: the fraction of samples whose argmax
    prediction (taken from the forward pass used for the update) equals the target.
    Returns 0.0 when the loader yields no samples.
    '''

    device = next(model.parameters()).device
    # Running counters kept on the model's device so that no host/device
    # synchronisation is needed inside the loop.
    n_correct = torch.zeros((), dtype=torch.long, device=device)
    n_seen = 0
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        # Detach before bookkeeping: only the predicted class is needed, and
        # accumulating the raw logits would keep autograd history alive and
        # re-copy an ever-growing tensor on every batch.
        n_correct += (y_pred.detach().argmax(1) == y).sum()
        n_seen += y.numel()
    if n_seen == 0:
        return 0.0
    return n_correct.item() / n_seen

In [6]:
# NOTE(review): this definition shadows the `train_and_test` imported from
# `utils` at the top of the notebook; every later call uses this version.
def train_and_test(model, train_loader, test_loader, criterion, optimizer, max_epochs = 100, stop_acc = 0.99, seed = 42):
    '''
    : param model: torch.nn.Module to train (modified in place)
    : param train_loader: loader passed to train_step for every epoch
    : param test_loader: loader passed to eval_on_dataloader after training
    : param criterion: loss passed to train_step
    : param optimizer: optimizer passed to train_step
    : param max_epochs: int maximum number of training epochs
    : param stop_acc: float training accuracy above which training stops early; None disables early stopping
    : param seed: int random seed for torch and numpy
    : return: dict with keys 'test_acc' (result of eval_on_dataloader) and 'train_time' (seconds spent in the training loop)

    Trains the model until the epoch accuracy exceeds stop_acc or max_epochs is reached,
    then evaluates it on the test loader.
    '''
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    converged = False
    model.train()
    start_time = time.time()
    for _ in range(max_epochs):
        train_acc = train_step(model, train_loader, criterion, optimizer)
        if stop_acc is not None and train_acc > stop_acc:
            converged = True
            break
    # Stop the clock before any reporting so only training is timed.
    end_time = time.time()
    # Only warn when early stopping was actually requested; with stop_acc=None
    # running all max_epochs is the expected outcome, not a failure.
    if stop_acc is not None and not converged:
        print('Convergence not reached, increase epochs')
    test_acc = eval_on_dataloader(model, test_loader)
    return {'test_acc': test_acc, 'train_time': end_time - start_time}

In [7]:
def warm_up (seed : int,
            batch_size : int = 128,
            max_epoch_warm_up : int = 350,
            log_interval : int = 5,
            Optimizer : torch.optim.Optimizer = torch.optim.SGD,
            lr : float = 0.001,
            dir : str = 'results'):

    '''
    : param seed: int random seed
    : param batch_size: int
    : param max_epoch_warm_up: int number of warm-up epochs
    : param log_interval: int number of epochs between checkpoints (must be >= 1)
    : param Optimizer: optimizer class, called as Optimizer(model.parameters(), lr = lr)
    : param lr: float
    : param dir: directory to save results
    : return: None
    : raises ValueError: if log_interval is smaller than 1

    Trains a fresh resnet18 on half of the dataset for max_epoch_warm_up epochs.
    Saves the model state dict every log_interval epochs to
    {dir}/seed{seed}/checkpoints/warm_up_{epoch}.pt.
    '''

    # Fail before any work is done: a zero interval would otherwise only
    # surface as a ZeroDivisionError after the first full epoch of training.
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')

    seed_path = f'{dir}/seed{seed}'
    checkpoints_path = f'{seed_path}/checkpoints'
    # creating directories (makedirs also creates the parent seed directory)
    os.makedirs(checkpoints_path, exist_ok = True)

    # setting seeds
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # getting the training loader (only the train split is used here)
    loaders_half = get_cifar10_loaders(0.5, seed = seed, batch_size= batch_size) # half dataset
    train_loader_half = loaders_half['train_loader']

    # setting up model training
    model = resnet18(num_classes = 10)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = Optimizer(model.parameters(), lr = lr)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    progress = trange(max_epoch_warm_up, position=0)
    for epoch in progress:
        # training step on half dataset
        progress.set_description(f"Epoch {epoch+1} of {max_epoch_warm_up} (warm up)")
        train_step(model, train_loader_half, criterion, optimizer)
        if (epoch+1) % log_interval == 0:
            # saving the model
            progress.set_description(f"Epoch {epoch+1} of {max_epoch_warm_up} (saving the model)")
            torch.save(model.state_dict(), f'{checkpoints_path}/warm_up_{epoch+1}.pt')

In [8]:
def measure_NC(seed: int, dir: str = 'results', batch_size: int = 128):
    '''
    : param seed: int random seed
    : param dir: directory to save results
    : param batch_size: int batch size of the half-dataset loader
    : return: None

    Measures the Neural Collapse of the model trained on half of the dataset.
    One entry is produced per warm-up checkpoint (every log_interval epochs up to
    max_epoch_warm_up, both read from the notebook configuration).
    Saves the results in a pickle file ({dir}/seed{seed}/results.pkl).
    '''

    seed_path = f'{dir}/seed{seed}'
    checkpoints_path = f'{seed_path}/checkpoints'

    # setting seeds
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # getting the training loader (only the train split is used here)
    loaders_half = get_cifar10_loaders(0.5, seed = seed, batch_size= batch_size) # half dataset
    train_loader_half = loaders_half['train_loader']

    results = []

    for epoch in trange(log_interval, max_epoch_warm_up + 1, log_interval):
        # setting up model
        model = resnet18(num_classes = 10)

        # loading the model; map_location keeps checkpoints written on a GPU
        # loadable on a CPU-only machine
        state_dict = torch.load(f'{checkpoints_path}/warm_up_{epoch}.pt', map_location = 'cpu')
        model.load_state_dict(state_dict)
        model = model.to(device)

        # measuring the neural collapse
        # NOTE(review): the model is still in its default (training) mode here;
        # confirm that NC() switches to eval mode if that is required.
        nc = NC(model, train_loader_half)
        results.append(dict(nc))

        # saving results after every checkpoint so partial progress survives an interruption
        with open(f'{seed_path}/results.pkl', 'wb') as f:
            pickle.dump(results, f)

In [12]:
# Neural Collapse metrics for every warm-up checkpoint of seed 7.
measure_NC(seed=7)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 70/70 [43:28<00:00, 37.26s/it]


In [13]:
def measure_PL(seed: int,
               batch_size: int = 128,
               max_epoch_full: int = 100,
               stop_acc: float = 0.99,
               dir : str = 'results'):
    '''
    : param seed: int random seed
    : param batch_size: int
    : param max_epoch_full: int maximum number of epochs to train the model on the full dataset
    : param stop_acc: float accuracy on the training set to stop training
    : param dir: directory to save results
    : return: None

    Measures the Performance Loss of the model trained on the full dataset.
    Each warm-up checkpoint is loaded, trained on the full dataset with SGD and
    evaluated; the 'test_acc' and 'train_time' values returned by train_and_test
    are merged into the matching entry of {dir}/seed{seed}/results.pkl.
    Saves the results in a pickle file.
    '''

    seed_path = f'{dir}/seed{seed}'
    checkpoints_path = f'{seed_path}/checkpoints'

    # load results (written by measure_NC; trusted, locally produced file)
    with open(f'{seed_path}/results.pkl', 'rb') as f:
        results = pickle.load(f)

    # setting seeds
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # getting data loaders
    loaders_full = get_cifar10_loaders(seed = seed, batch_size= batch_size) # full dataset
    train_loader_full, test_loader_full = loaders_full['train_loader'], loaders_full['test_loader']

    for epoch in trange(log_interval, max_epoch_warm_up + 1, log_interval):
        # position of this checkpoint in the results list (one entry per log_interval epochs)
        pos_res = epoch // log_interval - 1

        # setting up model
        model = resnet18(num_classes = 10)

        # loading the model; map_location keeps checkpoints written on a GPU
        # loadable on a CPU-only machine
        state_dict = torch.load(f'{checkpoints_path}/warm_up_{epoch}.pt', map_location = 'cpu')
        model.load_state_dict(state_dict)
        model = model.to(device)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr = lr)

        # measuring the performance loss; the epoch budget and the stopping
        # accuracy are forwarded so the function's arguments actually take effect
        results[pos_res].update(train_and_test(model,
                                               train_loader_full,
                                               test_loader_full,
                                               criterion= criterion,
                                               optimizer= optimizer,
                                               max_epochs= max_epoch_full,
                                               stop_acc= stop_acc,
                                               seed = seed))

        # show the first entry as a quick sanity check
        if pos_res == 0: print(results[pos_res])

        # saving results after every checkpoint so partial progress survives an interruption
        with open(f'{seed_path}/results.pkl', 'wb') as f:
            pickle.dump(results, f)


In [14]:
# Full-dataset training and test accuracy for every warm-up checkpoint of seed 7.
measure_PL(seed=7)

Files already downloaded and verified
Files already downloaded and verified


  1%|▏         | 1/70 [10:09<11:40:47, 609.38s/it]

{'NC1': 15.429203987121582, 'NC2': 0.37371334433555603, 'NC3': 0.9818062782287598, 'NC4': 0.54756, 'test_acc': 0.607693829113924, 'train_time': 607.2459115982056}


100%|██████████| 70/70 [4:33:52<00:00, 234.75s/it]  


In [15]:
def SP_NC_PL(seed: int,
            batch_size= 128,
            dir: str = 'results',
            shrink: float = 0.6,
            perturb: float = 0.01,
            max_epoch_full=100,
            stop_acc=0.99,
            max_epochs_warm_up = 350):
    '''
    : param seed: int random seed
    : param batch_size: int
    : param dir: directory to save results
    : param shrink: float factor every loaded parameter is multiplied by
    : param perturb: float weight of the freshly initialised parameters added to the shrunk ones
    : param max_epoch_full: int maximum number of epochs to train the model on the full dataset
    : param stop_acc: float accuracy on the training set to stop training
    : param max_epochs_warm_up: int maximum number of epochs to train the model on the half dataset
    : return: None

    Measures the Neural Collapse and the Performance Loss of the model trained on the full dataset after Shrink and Perturb.
    For every warm-up checkpoint: parameters become shrink * loaded + perturb * fresh_init,
    NC is measured on the half-dataset training loader, then the model is trained and
    tested on the full dataset. All values are stored under their original key plus the
    suffix '_SP' in the matching entry of {dir}/seed{seed}/results.pkl.
    Saves the results in a pickle file.
    '''

    seed_path = f'{dir}/seed{seed}'
    checkpoints_path = f'{seed_path}/checkpoints'
    # load results (written by the earlier measurement steps; trusted, locally produced file)
    with open(f'{seed_path}/results.pkl', 'rb') as f:
        results = pickle.load(f)

    # setting seeds, so the random perturbation models do not depend on
    # whatever was executed earlier in the kernel
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # getting the loaders
    loaders_half = get_cifar10_loaders(0.5, seed = seed, batch_size= batch_size) # half dataset
    train_loader_half = loaders_half['train_loader']
    loaders_full = get_cifar10_loaders(seed = seed, batch_size= batch_size) # complete dataset
    train_loader_full, test_loader_full = loaders_full['train_loader'], loaders_full['test_loader']

    progress = trange(log_interval, max_epochs_warm_up + 1, log_interval)
    for epoch in progress:
        # position of this checkpoint in the results list (one entry per log_interval epochs)
        pos_res = epoch // log_interval - 1

        # load model; map_location keeps checkpoints written on a GPU
        # loadable on a CPU-only machine
        model = resnet18(num_classes = 10)
        state_dict = torch.load(f'{checkpoints_path}/warm_up_{epoch}.pt', map_location = 'cpu')
        model.load_state_dict(state_dict)
        dummy_model = resnet18(num_classes = 10)
        # shrink and perturb the model (both models are still on the CPU here)
        with torch.no_grad():
            for real_parameter, random_parameter in zip(model.parameters(), dummy_model.parameters()):
                real_parameter.mul_(shrink).add_(random_parameter, alpha=perturb)

        # compute NC
        model = model.to(device)
        progress.set_description(f"Epoch {epoch} (measuring NC)")
        nc = NC(model, train_loader_half)
        SP_NC = {f'{key}_SP': value for key, value in nc.items()}

        # compute PL; the epoch budget and the stopping accuracy are forwarded
        # so the function's arguments actually take effect
        progress.set_description(f"Epoch {epoch} (computing PL)")
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
        PL = train_and_test(model,
                            train_loader_full,
                            test_loader_full,
                            criterion= criterion,
                            optimizer= optimizer,
                            max_epochs= max_epoch_full,
                            stop_acc= stop_acc,
                            seed = seed)
        SP_PL = {f'{key}_SP': value for key, value in PL.items()}

        results[pos_res].update(SP_NC)
        results[pos_res].update(SP_PL)

        # show the first entry as a quick sanity check
        if pos_res == 0: print(results[pos_res])

        # saving results after every checkpoint so partial progress survives an interruption
        with open(f'{seed_path}/results.pkl', 'wb') as f:
            pickle.dump(results, f)

In [16]:
# TODO: full run still to do, approx 8 hours
SP_NC_PL(seed=7)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Epoch 5 (computing PL):   0%|          | 0/70 [08:30<?, ?it/s]

{'NC1': 15.429203987121582, 'NC2': 0.37371334433555603, 'NC3': 0.9818062782287598, 'NC4': 0.54756, 'test_acc': 0.607693829113924, 'train_time': 607.2459115982056, 'NC1_SP': 15.265457153320312, 'NC2_SP': 0.37433314323425293, 'NC3_SP': 0.9828238487243652, 'NC4_SP': 0.55548, 'test_acc_SP': 0.6729628164556962, 'train_time_SP': 470.26064682006836}





# TODO:

* Plot the results as usual, include the correlation between NC and PL, and add a plot of test accuracy vs. NC.
* Run the experiments with S&P and measure both NC and PL.
* Run experiments with NC regularizer.
* Euler (ask Iasonas)

In [3]:
# Load the metrics pickled for seed 7 and tabulate them, one row per checkpoint.
with open('results/seed7/results.pkl', 'rb') as f:
    results = pd.DataFrame(pickle.load(f))

results

Unnamed: 0,NC1,NC2,NC3,NC4,test_acc,train_time
0,15.429204,0.373713,0.981806,0.54756,0.607694,607.245912
1,10.460551,0.375583,0.966333,0.62512,0.607397,566.765264
2,7.417085,0.370429,0.940527,0.66608,0.596816,534.982836
3,5.435144,0.360533,0.913593,0.69532,0.595134,505.533059
4,4.094497,0.347900,0.882195,0.72988,0.597211,463.956973
...,...,...,...,...,...,...
65,0.413934,0.202133,0.455983,0.99984,0.571697,168.083418
66,0.411723,0.202033,0.453426,0.99992,0.578026,168.336947
67,0.433576,0.201920,0.456003,0.99964,0.579608,169.117640
68,0.416714,0.201801,0.453658,0.99988,0.578916,168.409798
