In [35]:
from helpers.data import get_dataloaders
from helpers.train import TrainingManager
from helpers.loss_accuracy import accuracy
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.nn import CrossEntropyLoss
import models.resnet
import random

In [36]:
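# NOTE: the argparse setup below is commented out and unused in this notebook; the values it
# would supply (batch size, learning rate, load/trial flags) are hard-coded in the later cells.
# parser = argparse.ArgumentParser()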
# parser.add_argument('-in_c', type=int, required=True)
# parser.add_argument('-r', type=int, required=True, nargs=3)
# parser.add_argument('-b', type=int, required=True, nargs=3)
# parser.add_argument('-c', type=int, required=True, nargs=3)
# parser.add_argument('-bs', type=int, required=True)
# parser.add_argument('-lr', type=float, required=True)
# parser.add_argument('--trial', dest='is_trial', action='store_true')
# parser.add_argument('--load', dest='load', action='store_true')
# parser.set_defaults(is_trial=False, load=False)
# args = parser.parse_args()
# in_planes_parameter = args.in_c
# repeats_parameter = args.r
# num_blocks_parameters = args.b
# num_channels_parameters = args.c
# is_trial = args.is_trial
# load = args.load
# lr = args.lr
# bs = args.bs

In [37]:
# Experiment configuration.
BATCH_SIZE = 8
SEED = 0

# Seed the RNGs this notebook draws from: Python's `random` (MyModel samples its number of
# feedback passes with it) and torch (weight initialisation, presumably data shuffling).
random.seed(SEED)
torch.manual_seed(SEED)

trainloader, testloader = get_dataloaders('cifar10', BATCH_SIZE)
# Fall back to the CPU so the notebook still runs on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Files already downloaded and verified


In [57]:
# Number of ResNet-110 backbones combined by MyModel, and the checkpoint used to initialise them.
NUM_BACKBONES = 3
PRETRAINED_CHECKPOINT = 'pretrained_models/resnet110-1d1ed7c2_new.th'

my_models = [models.resnet.resnet110() for _ in range(NUM_BACKBONES)]

# Read the checkpoint once, onto the CPU where the freshly built models still live, so it also
# loads on hosts without a GPU. load_state_dict copies the tensors into each model's own
# parameters, so sharing one state dict between backbones is safe.
pretrained_state = torch.load(PRETRAINED_CHECKPOINT, map_location='cpu')

# Every backbone except the last gets the pretrained weights; the last keeps its fresh random
# initialisation (NOTE(review): presumably intentional, to add a diverse member — confirm).
for model in my_models[:-1]:
    model.load_state_dict(pretrained_state)

In [58]:
class MyModel(nn.Module):
    """Feedback ensemble that mixes several backbones stage by stage.

    Every backbone must expose ``pre_layer``, ``layer1``, ``layer2``, ``layer3`` and
    ``classifier`` sub-modules. At each stage the outputs of all backbones are combined with
    per-sample softmax weights. The whole pass is repeated several times. After each pass
    (except the last) the mixing weights are recomputed from pooled ``layer3`` outputs of
    every backbone, so later passes can re-weight the ensemble per sample.

    Args:
        my_models: backbones to combine (at least one).
        repeats: number of feedback passes (>= 1). In training mode, half of the forward
            calls instead use a random count in ``[1, 2 * repeats]``. In eval mode the
            count is always ``repeats``.
        feature_channels: channel count of each backbone's ``layer3`` output (64 was
            hard-coded previously — TODO confirm for other backbones).
        feature_dim: size of the per-backbone feature vector fed to the weight calculator.
        pool_size: spatial size ``layer3`` outputs are pooled to before the feature layer.

    Raises:
        ValueError: if ``my_models`` is empty or ``repeats < 1``.
    """

    def __init__(self, my_models, repeats, feature_channels=64, feature_dim=16, pool_size=3):
        super().__init__()
        num_models = len(my_models)
        if num_models == 0:
            raise ValueError("my_models must contain at least one model")
        if repeats < 1:
            # With 0 passes forward would have no result to return.
            raise ValueError(f"repeats must be >= 1, got {repeats}")
        self.repeats = repeats
        # Zero logits -> the first pass mixes all backbones uniformly.
        self.start_weights = nn.Parameter(torch.zeros(1, num_models))
        self.models = nn.ModuleList(my_models)
        self.feature_calculator = nn.Sequential(
            nn.AdaptiveAvgPool2d((pool_size, pool_size)), nn.Flatten(),
            nn.Linear(pool_size * pool_size * feature_channels, feature_dim), nn.ReLU(inplace=True))
        self.weights_calculator = nn.Sequential(
            nn.Linear(feature_dim * num_models, num_models))
        self.softmax = nn.Softmax(dim=1)

    def combine(self, current_weights, outputs):
        """Return the per-sample weighted sum of ``outputs``.

        Args:
            current_weights: ``(batch, len(outputs))`` mixing weights.
            outputs: one tensor per backbone, each with the batch on dim 0.
        """
        return sum(
            # Reshape column i to (batch, 1, ..., 1) so it broadcasts over the output.
            out * current_weights[:, i].reshape(-1, *([1] * (out.dim() - 1)))
            for i, out in enumerate(outputs)
        )

    def forward(self, x):
        """Run the feedback ensemble on ``x`` and return the mixed classifier output."""
        # Randomising the pass count is a training-time regulariser only; evaluation always
        # uses the configured count so test metrics are deterministic.
        if self.training and random.uniform(0, 1) >= .5:
            repeats = random.randint(1, 2 * self.repeats)
        else:
            repeats = self.repeats

        current_weights = self.softmax(self.start_weights.expand(x.shape[0], -1))
        for step in range(repeats):
            # Each pass restarts from the raw input; only the mixing weights carry over.
            h = self.combine(current_weights, [model.pre_layer(x) for model in self.models])
            h = self.combine(current_weights, [model.layer1(h) for model in self.models])
            h = self.combine(current_weights, [model.layer2(h) for model in self.models])
            last_features = [model.layer3(h) for model in self.models]
            h = self.combine(current_weights, last_features)
            res = self.combine(current_weights, [model.classifier(h) for model in self.models])
            if step < repeats - 1:
                # Feedback: new per-sample mixing weights from every backbone's layer3 output.
                # Skipped after the final pass, whose weights would never be used.
                features = torch.cat([self.feature_calculator(f) for f in last_features], dim=1)
                current_weights = self.softmax(self.weights_calculator(features))

        return res

In [60]:
# Feedback ensemble with 2 passes, moved to the training device (Module.to returns the module).
model = MyModel(my_models, repeats=2).to(device)

# CyclicLR replaces the SGD lr below with its own schedule, starting at base_lr. With its
# defaults (step_size_up=2000, cycle_momentum=True) it also cycles momentum between 0.8 and 0.9.
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-5,
    momentum=0.9,
    nesterov=True,
)
lr_scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-9, max_lr=1e-5)

In [61]:
# Identifier handed to TrainingManager (printed as the trial name when training starts).
trial_name = "resnet110_multiple_networks_feedback"

In [62]:
# NOTE(review): load=False presumably means no previously saved state is restored — confirm in
# helpers.train. The captured training output still shows an "is_trial" tag even though
# is_trial=False is passed here; worth checking TrainingManager's logging.
tm = TrainingManager(
    trial_name,
    is_trial=False,
    load=False,
)

In [63]:
# Positional groups presumably are (train, test) pairs: data loaders, losses, metrics —
# confirm against TrainingManager.train.
tm.train(
    model,
    optimizer,
    trainloader,
    testloader,
    CrossEntropyLoss(),
    CrossEntropyLoss(),
    accuracy,
    accuracy,
    lr_scheduler=lr_scheduler,
    device=device,
    no_iterations=10000,
)

  0%|          | 0/10000 [00:00<?, ?it/s]

Start training trial: [34mresnet110_multiple_networks_feedback[0m [31mis_trial[0m


{tr_loss: 3.40320, tr_acc: 0.18939, te_loss: 792.56018, te_acc: 0.19697}:   2%|▏         | 170/10000 [01:19<1:16:57,  2.13it/s] 


KeyboardInterrupt: 