In [1]:
import torch
from torch import nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from tqdm import tqdm
import numpy as np
import os

In [2]:
from torch.utils.data import DataLoader

# Preprocessing: convert images to tensors, then normalise with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
# train=False selects the MNIST test split; it is downloaded to ./datasets if missing.
dataset = MNIST('./datasets', train=False, download=True, transform=transform)

# Evaluation subset: the samples at indices 0..9999 of `dataset`.
valset = torch.utils.data.Subset(dataset, indices=np.arange(10000))
mnistloader = DataLoader(dataset=valset, batch_size=256, shuffle=True)

In [3]:
from mnist_cnn import MNIST_CNN

# Instantiate the classifier, then move it to the first CUDA device.
model = MNIST_CNN()
model = model.to('cuda:0')

In [4]:
# Directory listing of the raw checkpoint folder; every '.pt' entry counts as one model.
weights = os.listdir('datasets/PretrainedWeights/raw')
num_models = sum(1 for fname in weights if fname.endswith('.pt'))

In [5]:
from train_val import mnist_validation

In [9]:
# Accumulates one accuracy value per evaluated checkpoint.
accs = list()

In [10]:
# Evaluate each checkpoint model0.pt .. model{num_models-1}.pt on the MNIST loader
# and record its accuracy in `accs`.
pbar = tqdm(range(num_models))
for i in pbar:
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    weights = torch.load(
        f'datasets/PretrainedWeights/raw/model{i}.pt',
        map_location='cuda',
    )
    model.load_state_dict(weights)
    model.eval()
    acc = mnist_validation(mnistloader, model)
    accs.append(acc)
    pbar.set_postfix(acc=acc)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [07:49<00:00,  1.07it/s, acc=0.947]


In [None]:
# Input width of each per-tensor projection in Dict2Vec (one nn.Linear per entry).
# Presumably the flattened sizes of the MNIST_CNN parameter tensors,
# in (weight, bias) pairs - TODO confirm against mnist_cnn.
shape_list = [
    144, 16,
    2304, 16,
    25600, 64,
    640, 10,
]

In [None]:
class Dict2Vec(nn.Module):
    """Map a list of tensors to one scalar score per sample.

    Each input tensor is projected to ``hidden`` features by its own linear
    layer followed by ReLU. The projections are then merged either by
    averaging (``reduce='add'``) or by concatenation along the last dimension
    (``reduce='cat'``), and passed through a two-layer head.

    Args:
        shape_list: last-dimension size of each expected input tensor; one
            linear projection is created per entry, in order.
        hidden: width of every projection and of the hidden head layer.
        reduce: ``'add'`` to average the projections, ``'cat'`` to
            concatenate them.

    Raises:
        AssertionError: if ``reduce`` is neither ``'add'`` nor ``'cat'``.
    """

    def __init__(self, shape_list, hidden=256, reduce='add'):
        assert reduce in ['add', 'cat']
        super().__init__()
        self.reduce = reduce

        # One projection per input tensor; creation order matches the original
        # so seeded initialisation is unchanged.
        self.list = nn.ModuleList([nn.Linear(s, hidden) for s in shape_list])

        # Concatenation widens the head input by the number of projections.
        fc1_in = hidden if reduce == 'add' else hidden * len(shape_list)
        self.fc1 = nn.Linear(fc1_in, hidden)
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, weights):
        """Score a batch.

        Args:
            weights: sequence of tensors; ``weights[i]`` is fed to the i-th
                projection, so its last dimension must equal ``shape_list[i]``.

        Returns:
            The head output with its trailing size-1 dimension squeezed away.
        """
        # nn.functional is used directly: the notebook never imports an `F` alias.
        embedded = [
            nn.functional.relu(self.list[i](v)) for i, v in enumerate(weights)
        ]

        if self.reduce == 'add':
            # Average over the number of supplied tensors.
            x = sum(embedded) / len(weights)
        else:
            x = torch.cat(embedded, dim=-1)

        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x.squeeze(-1)

In [None]:
# Scoring network over the per-tensor inputs, placed on the current CUDA device.
mlp = Dict2Vec(shape_list, hidden=256).cuda()

In [None]:
# Binary cross-entropy on raw scores (the sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()
# Adam over all trainable parameters of `mlp`.
optimizer = torch.optim.Adam(params=mlp.parameters(), lr=1e-3)

In [None]:
from weight_datasets import PretrainedWeights, AdditiveNoise, SignFlip

In [None]:
# Training data read from the pretrained-weights folder.
# Presumably AdditiveNoise(0.5) perturbs each sample with noise of scale 0.5 -
# TODO confirm in weight_datasets.
trainset = PretrainedWeights(
    'datasets/PretrainedWeights/',
    transform=AdditiveNoise(0.5),
)

In [None]:
# Shuffled mini-batches of 64 samples drawn from `trainset`.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# NOTE(review): `valset` is the MNIST image Subset created in the data-loading
# cell, yet this loader is later passed to validation() together with `mlp`,
# which is trained on `trainset` batches. Presumably a held-out
# PretrainedWeights split was intended here - confirm and replace.
valloader = torch.utils.data.DataLoader(valset, batch_size=256, shuffle=False)

In [None]:
@torch.no_grad()
def validation(dataloader, model, device='cuda:0'):
    """Compute the binary accuracy of ``model`` over ``dataloader``.

    Args:
        dataloader: iterable yielding ``(inputs, labels)`` pairs, where
            ``inputs`` is a sequence of tensors and ``labels`` is a tensor of
            0/1 targets.
        model: callable taking the list of input tensors and returning one
            raw score per sample; a score above 0 counts as class 1.
        device: device the batches are moved to before the forward pass.
            Defaults to ``'cuda:0'``, the previously hard-coded value.

    Returns:
        A 0-dim tensor holding the fraction of correctly classified samples.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    corrects = 0
    counts = 0
    for inputs, labels in dataloader:
        inputs = [ins.to(device) for ins in inputs]
        labels = labels.to(device)

        # Forward pass only; gradients are disabled by the decorator.
        outputs = model(inputs)
        preds = (outputs > 0.).to(torch.float)
        corrects += (preds == labels).sum()
        # Batch size is read from the first input tensor's leading dimension.
        counts += inputs[0].shape[0]

    return corrects / counts

In [None]:
# Train `mlp` for 100 epochs, reporting loss and train/validation accuracy per epoch.
pbar = tqdm(range(100))
for epoch in pbar:  # loop over the dataset multiple times
    losses = 0.
    counts = 0
    corrects = 0
    for inputs, labels in trainloader:
        # Each batch is (list of input tensors, labels); move everything to the GPU.
        inputs = [ins.to('cuda:0') for ins in inputs]
        labels = labels.to('cuda:0')

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = mlp(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch value is a per-sample mean.
        batch_size = inputs[0].shape[0]
        losses += loss.item() * batch_size
        counts += batch_size
        # Accuracy uses the scores computed before this optimisation step.
        preds = (outputs > 0.).detach().to(torch.float)
        corrects += (preds == labels).sum()

    losses /= counts
    train_acc = corrects / counts

    val_acc = validation(valloader, mlp)

    # Convert to plain floats so the bar shows numbers rather than tensor reprs.
    pbar.set_postfix({
        'loss': losses,
        'train_acc': float(train_acc),
        'val_acc': float(val_acc),
    })