In [1]:
import torch
import pandas as pd
from collections import defaultdict
import random

In [2]:
from utils.data.dataloaders import make_mnist_dataloaders
from utils.models.supernet import SuperNet 
from utils.trainer import Trainer

In [3]:
data_dir = './data/'
models_dir = './models/'
results_dir = './results/'
n_epoch = 30
batch_size = 512
num_workers = 0
lr = 1e-2
weight_decay = 0.0
strategies = ['random', 'mean', 'sum', 'dropout_0.1', 'dropout_0.3', 'dropout_0.5', 'dropout_0.7']
seed = 0

In [4]:
random.seed(seed)

In [5]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [6]:
train_dataloader, val_dataloader, test_dataloader = make_mnist_dataloaders(data_dir=data_dir, random_state=seed, batch_size=batch_size, num_workers=num_workers)

In [7]:
optimizers = { 'sgd_m': lambda parameters, weight_decay, lr: torch.optim.SGD(parameters, lr=lr, weight_decay=weight_decay, momentum=0.9),
               'sgd': torch.optim.SGD,
               'adam': torch.optim.Adam}

default_val_params = {'test_dataloader': val_dataloader, 'show_results': False, 'verbose': False, 'log': True}
default_test_params = {'test_dataloader': test_dataloader, 'show_results': False, 'verbose': False, 'log': True}

In [8]:
accuracies_valid = defaultdict(list)
accuracies_test = defaultdict(list)
for optimizer in optimizers:
    for strategy in strategies:    
        default_train_params = {'train_dataloader': train_dataloader, 'val_dataloader': val_dataloader, 'n_epoch': n_epoch, 
                        'optim': optimizers[optimizer], 'weight_decay': weight_decay, 'schedul': None, 'loss': torch.nn.CrossEntropyLoss, 
                        'weighted': True, 'lr': lr, 'show_results': False, 'saved_models_dir': models_dir, 'verbose': False, 
                        'early_stopping': True, 'max_gap': 2, 'gamma': None}
        
        #обучение суперсети для каждой стратегии обучения
        model = SuperNet(strategy=strategy)
        trainer = Trainer(model, device=device, label_names=list(range(10)), model_name=f'supernet_{strategy}_{optimizer}')
        trainer.train(**default_train_params)
    
        #сэмплирование
        model1 = model.sampler(1)
        model2 = model.sampler(2)
    
        #вычисление метрик сэмплов суперсети
        trainer1 = Trainer(model1, device=device, label_names=list(range(10)), model_name=f'sample1_{strategy}_{optimizer}')
        trainer2 = Trainer(model2, device=device, label_names=list(range(10)), model_name=f'sample2_{strategy}_{optimizer}')
    
        _, _, history = trainer1.test(**default_val_params)
        accuracies_valid[f'{strategy} {optimizer}'].append(history['accuracy'])
        _, _, history = trainer1.test(**default_test_params)
        accuracies_test[f'{strategy} {optimizer}'].append(history['accuracy'])
  
        _, _, history = trainer2.test(**default_val_params)
        accuracies_valid[f'{strategy} {optimizer}'].append(history['accuracy'])
        _, _, history = trainer2.test(**default_test_params)
        accuracies_test[f'{strategy} {optimizer}'].append(history['accuracy'])
    
        #случайная инициализация сэмплов суперсети
        model1.random_init()
        model2.random_init()  
    
        #обучение сэмплов суперсети с нуля
        trainer1.train(**default_train_params)
        trainer2.train(**default_train_params)
    
        #вычисление метрик сэмплов суперсети
        _, _, history = trainer1.test(**default_val_params)
        accuracies_valid[f'{strategy} {optimizer}'].append(history['accuracy'])
        _, _, history = trainer1.test(**default_test_params)
        accuracies_test[f'{strategy} {optimizer}'].append(history['accuracy'])
    
        _, _, history = trainer2.test(**default_val_params)
        accuracies_valid[f'{strategy} {optimizer}'].append(history['accuracy'])
        _, _, history = trainer2.test(**default_test_params)
        accuracies_test[f'{strategy} {optimizer}'].append(history['accuracy'])

In [10]:
model_names = ['sample_1 before', 'sample_2 before', 'sample_1 after', 'sample_2 after']
results_valid = pd.DataFrame(model_names, columns=['model'])
results_test = pd.DataFrame(model_names, columns=['model'])
for optimizer in optimizers:
    for strategy in strategies:
        results_valid[f'{strategy} {optimizer}'] = accuracies_valid[f'{strategy} {optimizer}']
        results_test[f'{strategy} {optimizer}'] = accuracies_test[f'{strategy} {optimizer}']

In [11]:
results_valid.to_excel(f'{results_dir}/тестовая_выборка.xls', index=False)

In [12]:
results_test.to_excel(f'{results_dir}/отложенная_выборка.xls', index=False)

In [13]:
results_valid

Unnamed: 0,model,random sgd_m,mean sgd_m,sum sgd_m,dropout_0.1 sgd_m,dropout_0.3 sgd_m,dropout_0.5 sgd_m,dropout_0.7 sgd_m,random sgd,mean sgd,...,dropout_0.3 sgd,dropout_0.5 sgd,dropout_0.7 sgd,random adam,mean adam,sum adam,dropout_0.1 adam,dropout_0.3 adam,dropout_0.5 adam,dropout_0.7 adam
0,sample_1 before,0.9391,0.8179,0.625,0.1963,0.8007,0.917,0.903,0.2821,0.2403,...,0.6508,0.7226,0.4183,0.9559,0.9072,0.5837,0.8434,0.8893,0.6234,0.4256
1,sample_2 before,0.9447,0.7341,0.6601,0.5647,0.808,0.9469,0.8992,0.2637,0.308,...,0.8097,0.8252,0.6922,0.9684,0.8299,0.5271,0.8868,0.8963,0.8903,0.5382
2,sample_1 after,0.9587,0.9641,0.9639,0.9567,0.9571,0.9549,0.9553,0.8552,0.835,...,0.8685,0.9004,0.8851,0.9707,0.9732,0.9753,0.9758,0.9749,0.9682,0.9774
3,sample_2 after,0.9633,0.9616,0.9625,0.9639,0.9617,0.9614,0.9599,0.863,0.8924,...,0.8817,0.8878,0.888,0.9765,0.9753,0.9718,0.9735,0.9786,0.975,0.9778


In [14]:
results_test

Unnamed: 0,model,random sgd_m,mean sgd_m,sum sgd_m,dropout_0.1 sgd_m,dropout_0.3 sgd_m,dropout_0.5 sgd_m,dropout_0.7 sgd_m,random sgd,mean sgd,...,dropout_0.3 sgd,dropout_0.5 sgd,dropout_0.7 sgd,random adam,mean adam,sum adam,dropout_0.1 adam,dropout_0.3 adam,dropout_0.5 adam,dropout_0.7 adam
0,sample_1 before,0.9406,0.8176,0.6244,0.193,0.8113,0.917,0.9076,0.2748,0.2406,...,0.6576,0.7291,0.4203,0.959,0.9072,0.598,0.8396,0.8874,0.632,0.4353
1,sample_2 before,0.9461,0.74,0.6569,0.5609,0.8163,0.9508,0.8997,0.2597,0.3114,...,0.8199,0.8336,0.6948,0.9696,0.8299,0.5279,0.8846,0.9001,0.8893,0.5568
2,sample_1 after,0.9627,0.9673,0.9656,0.9581,0.9587,0.9564,0.9542,0.8627,0.8394,...,0.8717,0.902,0.89,0.9722,0.9724,0.9771,0.9751,0.9794,0.9688,0.9774
3,sample_2 after,0.9649,0.9634,0.9652,0.9646,0.9629,0.9651,0.9587,0.8691,0.8919,...,0.8936,0.893,0.8962,0.9773,0.9791,0.9758,0.9757,0.9816,0.974,0.979
