In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import pairwise_distances
from functools import reduce
import operator

import pandas as pd
import sys
import os

import warnings
# Suppress DeprecationWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning) 

# Notebook-wide seed value; the training cell prints it in its final summary line.
seed=26

In [4]:
def make_loaders(seed, dataset, batch_size=32):
    """Build the train and test ``DataLoader`` pair for a torchvision dataset.

    Args:
        seed: Value handed to ``torch.manual_seed`` before the datasets are built.
        dataset: One of ``'MNIST'``, ``'FMNIST'``, ``'CIFAR10'``, ``'SVHN'`` or
            ``'CIFAR100'``. Any other value raises ``KeyError``.
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Both loaders are created with
        ``shuffle=True``, apply only ``ToTensor`` and request ``download=True``
        into a dataset-specific directory.
    """
    registry = {'MNIST': ['mnist_data/', datasets.MNIST],
                'FMNIST': ['fashion_mnist_data/', datasets.FashionMNIST],
                'CIFAR10': ['cifar10_data/', datasets.CIFAR10],
                'SVHN': ['svhn_data/', datasets.SVHN],
                'CIFAR100': ['cifar100_data/', datasets.CIFAR100]}

    torch.manual_seed(seed)
    root, dataset_cls = registry[dataset]

    # SVHN picks its partition with ``split=...``; every other dataset here
    # uses the boolean ``train=...`` keyword instead.
    if dataset == 'SVHN':
        partitions = ({'split': 'train'}, {'split': 'test'})
    else:
        partitions = ({'train': True}, {'train': False})

    def build_loader(partition_kwargs):
        data = dataset_cls(root, download=True,
                           transform=transforms.Compose([transforms.ToTensor()]),
                           **partition_kwargs)
        return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

    # Train partition is built first, then test, as separate loaders.
    train_loader, test_loader = (build_loader(kwargs) for kwargs in partitions)
    return train_loader, test_loader



def train(net, train_loader, tracker, optimizer, epoch, device, n_classes=10):
    """Train an ensemble network for one epoch and record the loss in ``tracker``.

    The output of ``net`` is split along dim 1 into chunks of ``n_classes``
    columns (one chunk per member model) which are stacked along the batch
    dimension; the labels are repeated ``net.N`` times to match. The mean
    cross-entropy over that stacked batch is multiplied by ``net.N``.

    Args:
        net: Model exposing an integer attribute ``N`` (number of members).
        train_loader: Iterable of ``(data, label)`` batches that supports ``len``.
        tracker: Called as ``tracker(net, epoch)`` after every backward pass and
            before the optimizer step; its ``loss_tracker[epoch]`` entry is set
            to the epoch loss.
        optimizer: Optimizer whose gradients are zeroed and stepped once per batch.
        epoch: Epoch index, used for the progress-bar label and the tracker.
        device: Device that inputs, predictions and labels are moved to.
        n_classes: Number of logits per member model. Defaults to 10, the value
            that used to be hard-coded.

    Returns:
        The per-batch loss averaged over the batches of ``train_loader``.
    """
    net.train()
    criterion = nn.CrossEntropyLoss()

    train_loss = 0
    progress = tqdm(enumerate(train_loader), total=len(train_loader), leave=False, desc =f'Epoch: {epoch}')
    for batch_id, (data, label) in progress:
        optimizer.zero_grad()

        pred = net(data.to(device))

        # One chunk of n_classes logits per member, stacked along the batch axis.
        Preds = torch.cat(pred.split(n_classes, dim=1))
        Labels = torch.cat([label for _ in range(net.N)])

        loss = criterion(Preds.to(device), Labels.to(device)) * net.N
        loss.backward()

        # The tracker runs while the fresh gradients are still attached to the
        # parameters and before the weights are updated.
        tracker(net, epoch)

        optimizer.step()

        train_loss += loss.item()

    epoch_loss = train_loss/len(train_loader)
    tracker.loss_tracker[epoch] = epoch_loss

    return epoch_loss

def evaluate_single(net, train_loader, device):
    """Return the mean per-batch cross-entropy loss of a single (non-ensemble) model.

    Args:
        net: Model called as ``net(data)``; its output is passed directly to
            ``nn.CrossEntropyLoss``.
        train_loader: Iterable of ``(data, label)`` batches that supports ``len``.
        device: Device that inputs and labels are moved to.

    Returns:
        Sum of the batch losses divided by ``len(train_loader)``.

    Note:
        The train/eval mode of ``net`` is left untouched; the function evaluates
        the model in whatever mode the caller has set.
    """
    criterion = nn.CrossEntropyLoss()

    total_loss = 0
    # Pure evaluation: disable autograd so no graph is built for each batch.
    with torch.no_grad():
        for data, label in train_loader:
            pred = net(data.to(device))
            loss = criterion(pred, label.to(device))
            total_loss += loss.item()

    return total_loss/len(train_loader)


def evaluate(net, train_loader, device, n_classes=10):
    """Compute loss and accuracy of an ensemble network over a data loader.

    NOTE: a later cell of this notebook defines another ``evaluate`` (taking a
    ``dataset`` argument); when the cells are run top to bottom that later
    definition replaces this one.

    The output of ``net`` is split along dim 1 into chunks of ``n_classes``
    columns (one per member model) that are stacked along the batch dimension,
    and the labels are repeated ``net.N`` times to match.

    Args:
        net: Model exposing an integer attribute ``N`` (number of members).
        train_loader: Iterable of ``(data, label)`` batches that supports ``len``.
        device: Device that inputs, predictions and labels are moved to.
        n_classes: Number of logits per member model. Defaults to 10, the value
            that used to be hard-coded.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the per-batch loss (mean cross-entropy
        times ``net.N``) averaged over batches. ``accuracy`` is the fraction of
        correct predictions over all samples and members, so a smaller final
        batch is weighted by its size rather than counted as a full batch.

    The network is switched to eval mode for the pass and put back into train
    mode before returning.
    """
    criterion = nn.CrossEntropyLoss()
    net.eval()

    total_loss = 0
    n_correct = 0
    n_seen = 0

    # Pure evaluation: disable autograd so no graph is built for each batch.
    with torch.no_grad():
        for data, label in train_loader:
            pred = net(data.to(device))

            Preds = torch.cat(pred.split(n_classes, dim=1)).to(device)
            Labels = torch.cat([label for _ in range(net.N)]).to(device)

            hits = Preds.argmax(dim=1).reshape(-1) == Labels.reshape(-1)
            n_correct += int(hits.sum().item())
            n_seen += hits.numel()

            loss = criterion(Preds, Labels) * net.N
            total_loss += loss.item()

    net.train()

    return total_loss/len(train_loader), n_correct/n_seen

# Split Networks

In [7]:
def factorize_net(net):
    """Flatten the parameters of each member model into one column of an array.

    Every parameter tensor of ``net`` is reshaped to ``(net.N, -1)``, i.e. it is
    cut into ``net.N`` equal, consecutive slices, and slice ``i`` is appended to
    the vector of member ``i``.

    Args:
        net: Module exposing an integer attribute ``N``; the element count of
            each parameter must be divisible by ``N``.

    Returns:
        ``numpy.ndarray`` of shape ``(parameters_per_member, net.N)`` whose
        column ``i`` holds the concatenated parameters of member ``i`` in
        ``net.parameters()`` order.

    NOTE(review): only parameters are exported. Buffers (for example batch-norm
    running statistics, if the members have any) are not part of the result —
    confirm that is intended.
    """
    n_models = net.N
    per_model_chunks = [[] for _ in range(n_models)]

    for param in net.parameters():
        slices = param.cpu().detach().numpy().reshape(n_models, -1)

        for model_idx in range(n_models):
            per_model_chunks[model_idx].append(slices[model_idx].reshape(-1, 1))

    model_vectors = [np.vstack(chunks) for chunks in per_model_chunks]

    return np.hstack(model_vectors)

In [6]:
def train_scheduler(net, train_loader, optimizer, scheduler, epoch, device, dataset='CIFAR100', criterion=None):
    """Train an ensemble network for one epoch, updating the learning rate per batch.

    The output of ``net`` is split along dim 1 into one chunk of logits per
    member model; the chunks are stacked along the batch dimension and the
    labels are repeated ``net.N`` times to match.

    Args:
        net: Model exposing an integer attribute ``N`` (number of members).
        train_loader: Iterable of ``(data, label)`` batches that supports ``len``.
        optimizer: Optimizer to step; the ``lr`` of its first parameter group is
            overwritten after every step.
        scheduler: Object providing ``triangle_scheduler(batch_id, epoch)``,
            whose return value becomes the new learning rate.
        epoch: Epoch index, used for the progress-bar label and the scheduler.
        device: Device that inputs, predictions and labels are moved to.
        dataset: ``'CIFAR100'`` selects 100 logits per member; any other value
            selects 10.
        criterion: Loss callable. Defaults to ``nn.CrossEntropyLoss()``. The
            function previously read a module-level ``criterion`` and raised
            ``NameError`` unless another cell had defined it first.

    Returns:
        The per-batch loss (criterion value times ``net.N``) averaged over the
        batches of ``train_loader``.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    # The class count depends only on the dataset name, so resolve it once.
    n_classes = 100 if dataset == 'CIFAR100' else 10

    net.train()

    train_loss = 0
    progress = tqdm(enumerate(train_loader), total=len(train_loader), leave=False, desc=f'Epoch: {epoch}')
    for batch_id, (data, label) in progress:
        optimizer.zero_grad()

        pred = net(data.to(device))

        Preds = torch.cat(pred.split(n_classes, dim=1))
        Labels = torch.cat([label for _ in range(net.N)])

        loss = criterion(Preds.to(device), Labels.to(device)) * net.N
        loss.backward()

        optimizer.step()
        # The rate is written after the step, so the value computed for
        # (batch_id, epoch) is the one used by the following step.
        optimizer.param_groups[0]['lr'] = scheduler.triangle_scheduler(batch_id, epoch)

        train_loss += loss.item()

    epoch_loss = train_loss/len(train_loader)

    return epoch_loss


def evaluate(net, train_loader, device, dataset='CIFAR100'):
    """Compute loss and accuracy of an ensemble network over a data loader.

    The output of ``net`` is split along dim 1 into one chunk of logits per
    member model; the chunks are stacked along the batch dimension and the
    labels are repeated ``net.N`` times to match.

    Args:
        net: Model exposing an integer attribute ``N`` (number of members).
        train_loader: Iterable of ``(data, label)`` batches that supports ``len``.
        device: Device that inputs, predictions and labels are moved to.
        dataset: ``'CIFAR100'`` selects 100 logits per member; any other value
            selects 10.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the per-batch loss (mean cross-entropy
        times ``net.N``) averaged over batches. ``accuracy`` is the fraction of
        correct predictions over all samples and members, so a smaller final
        batch is weighted by its size rather than counted as a full batch.

    The network is switched to eval mode for the pass and put back into train
    mode before returning.
    """
    criterion = nn.CrossEntropyLoss()
    # The class count depends only on the dataset name, so resolve it once.
    n_classes = 100 if dataset == 'CIFAR100' else 10

    net.eval()

    total_loss = 0
    n_correct = 0
    n_seen = 0

    # Pure evaluation: disable autograd so no graph is built for each batch.
    with torch.no_grad():
        for data, label in train_loader:
            pred = net(data.to(device))

            Preds = torch.cat(pred.split(n_classes, dim=1)).to(device)
            Labels = torch.cat([label for _ in range(net.N)]).to(device)

            hits = Preds.argmax(dim=1).reshape(-1) == Labels.reshape(-1)
            n_correct += int(hits.sum().item())
            n_seen += hits.numel()

            loss = criterion(Preds, Labels) * net.N
            total_loss += loss.item()

    net.train()

    return total_loss/len(train_loader), n_correct/n_seen




In [None]:
# FIXME: os.makedirs() requires a path argument; called with none, this line raises TypeError.
os.makedirs()

In [1]:
import os

In [None]:
# Create the output directory. exist_ok=True lets this cell be re-run without
# failing with FileExistsError when the directory is already there.
os.makedirs('Case1/', exist_ok=True)

N = 8
dataset = 'SVHN'
device = 'cuda:2'
batch_size = 256

train_loader, test_loader = make_loaders(26, batch_size=batch_size, dataset=dataset)

batches_in_epoch = len(train_loader)

epochs = 2001


for lr in [1e-4]:

    # All three values equal ``lr`` — presumably a constant schedule; confirm
    # against the TriangleLR implementation.
    knots = [0, 1, 2]
    vals = [lr, lr, lr]

    best_loss_train = 100

    # N ResNet9 members, each built with its own seed, wrapped in one module.
    net = MultipleShallowNets(
        [ResNet(dataset, mode='resnet9', seed=26+index) for index in range(N)],
        kernel_size=3).to(device)

    logs = 'Case1/'

    criterion = nn.CrossEntropyLoss().to(device)
    # lr=0 is a placeholder: train_scheduler overwrites the rate after every step.
    optimizer = optim.SGD(net.parameters(), lr=0, momentum=0.9, weight_decay=0)

    tlr = TriangleLR(optimizer, batches_in_epoch, knots, vals)

    df = []

    for epoch in range(epochs):

        train_loss = train_scheduler(net, train_loader, optimizer, tlr, epoch, device, dataset=dataset)

        # Keep the checkpoint with the lowest training loss seen so far.
        if train_loss < best_loss_train:
            best_loss_train = train_loss

            path = logs + f'ResNet9_{lr}_reg0.torch'
            torch.save(net.state_dict(), path)

        if epoch % 5 == 0:
            test_loss, test_acc = evaluate(net, test_loader, device, dataset=dataset)

            # Losses are divided by N to report a per-member value.
            print('Epoch', epoch + 1, 'Train loss', train_loss/N, 'Test loss', test_loss/N, 'Test acc', test_acc)

            df.append([epoch, train_loss/N, test_loss/N, test_acc])

            # Rewrite the whole log on every evaluation so a partial run still
            # leaves a usable CSV behind.
            df_ = pd.DataFrame(df)
            df_.columns = ['epoch', 'train_loss', 'test_loss', 'test_acc']

            df_.to_csv(f'{logs}ResNet9_{lr}_reg0.csv', index=False)

            if epoch in [200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200]:
                # Milestone snapshots get an epoch-tagged name. They used to be
                # written to the best-loss file and could replace the best
                # checkpoint with a worse one.
                snapshot_path = logs + f'ResNet9_{lr}_reg0_epoch{epoch}.torch'
                torch.save(net.state_dict(), snapshot_path)

    print('!!!!!!!!!!!!NET CONVERGED', 'Seed:', seed, 'Best loss:', best_loss_train/N)
    
device = 'cuda:1'
N = 8
folder = 'Case1'

models_path = f'{folder}/Models/'
# makedirs(exist_ok=True) replaces the isdir()/mkdir() pair: one call, and no
# gap between the existence check and the creation.
os.makedirs(models_path, exist_ok=True)

# Rebuild the ensemble with the same architecture and load the trained weights.
net = MultipleShallowNets([ResNet('SVHN', mode='resnet9', seed=26+index) for index in range(N)], kernel_size=3).to(device)
net.load_state_dict(torch.load(f'{folder}/ResNet9_0.0001_reg0.torch', map_location=device))

# Column i of W is the flattened parameter vector of member i.
W = factorize_net(net)
for i in range(N):
    # A fresh single ResNet receives member i's vector through from_vector
    # (presumably loading it as the model's weights — confirm against ResNet)
    # and is saved as its own state dict.
    model = ResNet('SVHN', mode='resnet9', seed=26)
    model.from_vector(W[:, i])
    torch.save(model.state_dict(), models_path + f'model{i}.torch')

HBox(children=(FloatProgress(value=0.0, description='Epoch: 0', max=287.0, style=ProgressStyle(description_wid…

Epoch 1 Train loss 2.331176830085728 Test loss 2.2580858235265695 Test acc 0.17669209910958422


HBox(children=(FloatProgress(value=0.0, description='Epoch: 1', max=287.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 2', max=287.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 3', max=287.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 4', max=287.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 5', max=287.0, style=ProgressStyle(description_wid…

Epoch 6 Train loss 1.8938416379669403 Test loss 1.8274339893284965 Test acc 0.425085819235035


HBox(children=(FloatProgress(value=0.0, description='Epoch: 6', max=287.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 7', max=287.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 8', max=287.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 9', max=287.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 10', max=287.0, style=ProgressStyle(description_wi…

Epoch 11 Train loss 1.4959249812136128 Test loss 1.4429262502520692 Test acc 0.5589057797310399


HBox(children=(FloatProgress(value=0.0, description='Epoch: 11', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 12', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 13', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 14', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 15', max=287.0, style=ProgressStyle(description_wi…

Epoch 16 Train loss 1.192174022621394 Test loss 1.1522153896443985 Test acc 0.6621098103476506


HBox(children=(FloatProgress(value=0.0, description='Epoch: 16', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 17', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 18', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 19', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 20', max=287.0, style=ProgressStyle(description_wi…

Epoch 21 Train loss 0.9790710908610646 Test loss 0.9786285600241493 Test acc 0.7164304466808543


HBox(children=(FloatProgress(value=0.0, description='Epoch: 21', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 22', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 23', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 24', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 25', max=287.0, style=ProgressStyle(description_wi…

Epoch 26 Train loss 0.8323155298880999 Test loss 0.837547280624801 Test acc 0.7626326452283299


HBox(children=(FloatProgress(value=0.0, description='Epoch: 26', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 27', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 28', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 29', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 30', max=287.0, style=ProgressStyle(description_wi…

Epoch 31 Train loss 0.7316996197667271 Test loss 0.7493292011466681 Test acc 0.7840600107230392


HBox(children=(FloatProgress(value=0.0, description='Epoch: 31', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 32', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 33', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 34', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 35', max=287.0, style=ProgressStyle(description_wi…

Epoch 36 Train loss 0.6606868522092441 Test loss 0.6872992930459041 Test acc 0.8031351835119958


HBox(children=(FloatProgress(value=0.0, description='Epoch: 36', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 37', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 38', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 39', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 40', max=287.0, style=ProgressStyle(description_wi…

Epoch 41 Train loss 0.6079459907908888 Test loss 0.6322909225435818 Test acc 0.8158348477354237


HBox(children=(FloatProgress(value=0.0, description='Epoch: 41', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 42', max=287.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 43', max=287.0, style=ProgressStyle(description_wi…

# Factorize