In [1]:
import torch
from helpers.data import get_dataloaders
from helpers.train import TrainingManager
from helpers.loss_accuracy import accuracy
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import numpy as np
from tqdm import tqdm

In [2]:
# Layer factories shared by every block defined in this notebook.
# _Conv2d builds a bias-free 3x3 convolution with stride 1 and padding 1.
# NOTE: keyword overrides replace only the keyword that is passed, so a call
# that overrides kernel_size still inherits padding=1 unless padding is
# overridden as well.
_Conv2d = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1, bias=False)
# Alias for the 2-D batch-norm layer class.
_BN2d = nn.BatchNorm2d
# _act() builds a ReLU that operates in place on its input.
_act = partial(nn.ReLU, inplace=True)

In [3]:
# get_dataloaders is a project helper (helpers.data) that returns the train and
# test loaders as a pair.
# NOTE(review): the first argument looks like a dataset identifier and the
# second (32) is presumably the batch size -- confirm against helpers.data.
trainloader, testloader = get_dataloaders('cifar10_resnet110_output', 32)

In [4]:
# Run on the GPU when one is available; otherwise fall back to the CPU so that
# a fresh "Restart & Run All" does not fail on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Both flags are forwarded unchanged to TrainingManager(load=..., is_trial=...).
is_trial = True
load = False

In [5]:
class BasicBlock(nn.Module):
    '''
    Residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, summed with a
    shortcut branch and passed through a final ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution and of the projection
            shortcut (default 1).

    The shortcut is the identity when the block keeps both the channel count
    and the resolution; otherwise it is a 1x1 convolution followed by BN.
    '''
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.seq = nn.Sequential(
            _Conv2d(in_planes, planes, stride=stride), _BN2d(planes), _act(),
            _Conv2d(planes, planes), _BN2d(planes),
        )
        if stride != 1 or in_planes != planes:
            # padding=0 is required here: _Conv2d defaults to padding=1 (meant
            # for 3x3 kernels). A 1x1 kernel with padding=1 would yield a
            # spatial size different from self.seq and break the addition.
            self.shortcut = nn.Sequential(
                _Conv2d(in_planes, planes, kernel_size=1, stride=stride, padding=0),
                _BN2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        '''Return relu(residual_branch(x) + shortcut(x)).'''
        return F.relu(self.seq(x) + self.shortcut(x))

class WeightChangerBlock(nn.Module):
    '''
    Generates one weight tensor per sample from a feature map.

    The feature map is globally average-pooled and max-pooled, the two pooled
    vectors are concatenated along the channel axis, squashed with tanh and
    mapped by a single linear layer to prod(w_size) values, which are reshaped
    to (-1, *w_size).

    Args:
        in_c: channel count of the feature map passed to forward().
        w_size: shape of the weight tensor produced for each sample.
    '''
    def __init__(self, in_c, w_size):
        super().__init__()
        self.w_size = w_size
        self.tanh = nn.Tanh()
        self.avgpool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.maxpool = nn.Sequential(nn.AdaptiveMaxPool2d(1), nn.Flatten())
        # Avg- and max-pooled vectors are concatenated, hence twice in_c inputs.
        n_inputs = in_c * 2
        n_weights = np.prod(w_size)
        self.w_creator = nn.Linear(n_inputs, n_weights)

    def forward(self, last_output):
        '''Return a tensor of shape (-1, *w_size) computed from last_output.'''
        pooled_avg = self.avgpool(last_output)
        pooled_max = self.maxpool(last_output)
        features = self.tanh(torch.cat([pooled_avg, pooled_max], dim=1))
        flat_weights = self.w_creator(features)
        return flat_weights.reshape(-1, *self.w_size)
        
class BasicBlockFB(nn.Module):
    '''
    Residual block whose first 3x3 convolution uses per-sample weights that are
    generated from a feedback tensor by a WeightChangerBlock.

    Args:
        in_planes: number of input channels.
        planes: number of output channels; must equal in_planes.
        stride: stride of the generated-weight convolution and of the
            projection shortcut (default 1).
    '''
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlockFB, self).__init__()
        assert(in_planes == planes)
        self.stride = stride
        self.weight_changer_block = WeightChangerBlock(planes, (in_planes, planes, 3, 3))
        self.seq = nn.Sequential(_BN2d(planes), _act(), _Conv2d(planes, planes), _BN2d(planes))
        if stride != 1 or in_planes != planes:
            # padding=0 is required here: _Conv2d defaults to padding=1 (meant
            # for 3x3 kernels). A 1x1 kernel with padding=1 would yield a
            # spatial size different from the main branch.
            self.shortcut = nn.Sequential(
                _Conv2d(in_planes, planes, kernel_size=1, stride=stride, padding=0),
                _BN2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def metaconv(self, x, w):
        '''
        Forward pass of a meta convolution layer.
        Note that we do not conv all batch with the same set of conv weights.
        The trick is to use group convolutions for convolving each input with its own set of conv weights.

        Args:
            x: input of shape (B, C_in, H, W).
            w: per-sample weights of shape (B, C_out, C_in, kH, kW); dimension 1
                is used as the output-channel axis of the grouped convolution.

        Returns:
            Tensor of shape (B, C_out, H_out, W_out), where the spatial size is
            whatever the convolution produces for self.stride and padding=1.
        '''
        batch = x.shape[0]
        out_c = w.shape[1]
        # Fold the batch axis into channels: one convolution group per sample.
        w = w.reshape(w.shape[0] * w.shape[1], *w.shape[2:])
        x = x.reshape(1, x.shape[0] * x.shape[1], *x.shape[2:])
        out = F.conv2d(x, w, None, stride=self.stride, groups=batch, padding=1)
        # Use the convolution's own output size: it differs from the input
        # size whenever stride != 1.
        return out.reshape(batch, out_c, *out.shape[2:])

    def forward(self, x, last_output):
        '''
        Args:
            x: block input.
            last_output: feedback tensor from which the convolution weights
                are generated.

        Returns:
            relu(seq(metaconv(x, generated_weights)) + shortcut(x)).
        '''
        w = self.weight_changer_block(last_output)
        out = self.metaconv(x, w)
        out = F.relu(self.seq(out) + self.shortcut(x), inplace=True)
        return out

class SuffixModel(nn.Module):
    '''
    Applies one BasicBlockFB repeatedly to the same input, feeding each output
    back in as the feedback tensor of the next pass, then classifies the final
    output.

    Args:
        in_planes: number of input channels of the feedback block.
        planes: number of output channels of the feedback block.
        num_loops: how many times the block is applied.
        num_classes: number of outputs of the final linear layer.
    '''
    def __init__(self, in_planes, planes, num_loops, num_classes):
        super(SuffixModel, self).__init__()
        self.num_loops = num_loops
        self.block = BasicBlockFB(in_planes, planes)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(planes, num_classes)
        )
        # Learnable feedback tensor used for the first pass. Its channel count
        # must match `planes`, because the block's weight generator is built
        # for that many channels.
        self.first_last_output = nn.Parameter(torch.zeros(1, planes, 6, 6))

    def forward(self, x):
        '''Return class scores for x after num_loops feedback passes.'''
        # Broadcast the single learned feedback tensor over the batch.
        last_output = self.first_last_output.expand(x.shape[0], -1, -1, -1)
        for _ in range(self.num_loops):
            last_output = self.block(x, last_output)

        return self.classifier(last_output)

In [7]:
# Build the feedback model (one pass through the feedback block, 10 outputs),
# move it to the selected device, and create its Adam optimizer.
model = SuffixModel(in_planes=64, planes=64, num_loops=1, num_classes=10)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=2e-4)

In [8]:
# Name under which this run is handed to TrainingManager.
trial_name = 'resnet110_pretrained_feedback_1_with_parameter_as_first_inputs'

In [9]:
# TrainingManager is a project helper (helpers.train); it receives the trial
# name plus the `load` and `is_trial` flags from the configuration cell.
# NOTE(review): `load` presumably resumes a saved trial and `is_trial`
# presumably marks a throw-away run -- confirm against helpers.train.
tm = TrainingManager(trial_name, load=load, is_trial=is_trial)

In [10]:
# Start training: model, optimizer, both data loaders, two separate
# CrossEntropyLoss instances and the accuracy function (passed twice).
# NOTE(review): the paired loss/metric arguments are presumably the train and
# test variants, in that order -- confirm against TrainingManager.train.
# The run is requested for 100000 iterations; the recorded output below was
# interrupted manually (KeyboardInterrupt) after 152 of them.
tm.train(model, optimizer, 
         trainloader, testloader, 
         CrossEntropyLoss(), CrossEntropyLoss(), 
         accuracy, accuracy, device=device, no_iterations=100000)

  0%|          | 0/100000 [00:00<?, ?it/s]

Start training trial: [34mresnet110_pretrained_feedback_100_with_parameter_as_first_inputs[0m [31mis_trial[0m


{tr_loss: 0.80762, tr_acc: 0.88625, te_loss: 0.58180, te_acc: 0.91063}:   0%|          | 152/100000 [00:51<9:21:48,  2.96it/s]


KeyboardInterrupt: 