In [None]:
"""Main script for Maximal Correlation
"""

# Dataset selector. Accepted values: "cifar", "dogs", "tiny_imagenet".
mode = "tiny_imagenet"
# Source-sample count passed to the dataset generator
# (recommended: 500 for Cifar, 50 for Dogs, 500 for tiny_imagenet).
num_source_samps = 250
# Target-sample count passed to the dataset generator.
num_target_samps = 5

# Reject unknown dataset names before anything else depends on them.
if mode not in ("cifar", "dogs", "tiny_imagenet"):
    raise Exception('Invalid dataset type')

# "cifar" is configured with 2 classes; "dogs" and "tiny_imagenet" with 5.
num_classes = 2 if mode == "cifar" else 5


import torch
import os
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
import tqdm
import nets
import datasets

print("Extracting datasets...")
trainloader_source, trainloader_target, testloader = datasets.generate_dataset(mode, num_source_samps, num_target_samps)

# Load one pre-trained network per entry of trainloader_source and freeze it.
# Each checkpoint is read from "<num_target_samps>-shot/net_<i>.pt".
all_nets = []
for i in range(len(trainloader_source)):
    net = nets.generate_net(mode)
    checkpoint_path = "{}-shot/net_{}.pt".format(num_target_samps, i)
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # machines; nothing in this notebook moves the networks off the CPU.
    net.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    # The loaded networks are used as fixed feature extractors: no gradients...
    for param in net.parameters():
        param.requires_grad = False
    # ...and inference behaviour for any BatchNorm/Dropout layers they contain.
    net.eval()
    all_nets.append(net) #net[0] is just the first part up to the penultimate layer

In [None]:
class Moe_Gating(nn.Module):
    """Gated mixture of expert feature extractors followed by a linear classifier.

    Every expert ``net`` in ``all_nets`` is called as ``net[0](inputs)`` to obtain
    a feature tensor. The concatenated (detached) features feed a linear gate
    whose sigmoid outputs weight each expert; the weighted sum of the expert
    features is then classified by a linear layer.

    Args:
        all_nets: non-empty sequence of expert networks; ``net[0]`` must be
            callable on the input batch and return a ``(batch, feature_dim)``
            tensor.
        feature_dim: width of each expert's feature vector (default 84).
        num_classes: number of output logits (default 5).

    Raises:
        ValueError: if ``all_nets`` is empty.
    """

    def __init__(self, all_nets, feature_dim=84, num_classes=5):
        super(Moe_Gating, self).__init__()
        num_experts = len(all_nets)
        if num_experts == 0:
            raise ValueError("Moe_Gating requires at least one expert network")
        # Deliberately kept as a plain Python sequence (not nn.ModuleList): the
        # experts are NOT registered as sub-modules, so parameters() and
        # state_dict() of this module only contain the gate and the classifier.
        self.all_nets = all_nets
        # One gate score per expert, computed from all experts' features.
        # With the defaults and 10 experts this is nn.Linear(840, 10).
        self.gate = nn.Linear(feature_dim * num_experts, num_experts)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(
        self,
        inputs
    ):
        """Return classification logits of shape ``(batch, num_classes)``."""
        # Features from every expert, each (batch, feature_dim).
        expert_features = [net[0](inputs) for net in self.all_nets]

        # The gate sees all features side by side; detach() keeps gradients of
        # the gate path from flowing back into the experts.
        gate_inputs = torch.cat(expert_features, dim=1).detach()
        # Independent sigmoid weight per expert (not normalised across experts).
        gate_weights = torch.sigmoid(self.gate(gate_inputs.float()))

        # (batch, feature_dim, num_experts) weighted by (batch, 1, num_experts),
        # then summed over the expert axis -> (batch, feature_dim).
        stacked_features = torch.stack(expert_features, dim=-1)
        mixed_features = torch.sum(gate_weights.unsqueeze(1) * stacked_features, dim=-1)

        return self.classifier(mixed_features)

In [None]:
# Cross-entropy over the class logits produced by the gated model.
criterion = nn.CrossEntropyLoss()
# Gated mixture model built on top of the previously loaded networks.
lifa = Moe_Gating(all_nets)
# Adam only receives the parameters that still require gradients.
optimizer = optim.Adam(
    (param for param in lifa.parameters() if param.requires_grad),
    lr=5e-3,
)

In [None]:
# Best accuracy observed on testloader across all epochs so far.
best_acc = 0.0

for epoch in tqdm.tqdm(range(300)):
    # --- one optimisation pass over the target training data ---
    for inputs, labels in trainloader_target:
        optimizer.zero_grad()

        outputs = lifa(inputs)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()

    # --- evaluation over the WHOLE test loader ---
    # Correct predictions are accumulated across every batch; computing the
    # accuracy after the loop from the loop variables alone would only
    # reflect the final batch.
    num_correct = 0
    num_seen = 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for inputs, labels in testloader:
            preds = torch.argmax(lifa(inputs), dim=1)
            num_correct += (labels == preds).sum().item()
            num_seen += len(labels)

    # Guard against an empty test loader instead of dividing by zero.
    acc = num_correct / num_seen if num_seen else 0.0
    if acc > best_acc:
        best_acc = acc
        print(round(best_acc, 3))

print('Finished Training')