### Transfer learning models:

This notebook loads the transfer learning models and sets up the reinforcement learning agent that decides whether to use the transfer learner or an active learning policy.

In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import src.active_learning as al
import src.viz as viz
import src.reinforcement as rl
import src.data as d 
from src.models import logreg, CNN, AgentRL
# import active_learning as al
# import viz
# import reinforcement as rl
# import data as d 
# from models import logreg, CNN, AgentRL

import importlib as imp
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm_notebook as tqdm
import torchvision.models as tmodels
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.distributions import Categorical

%matplotlib inline

In [2]:
# Reload the project data module so that recent edits to it are picked up.
d = imp.reload(d)

# --- MNIST -------------------------------------------------------------
# Images are converted to tensors. The training files must already be
# present under ./data, since downloading is switched off.
train_set = dset.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=False,
)
test_set = dset.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)
# A single, unshuffled batch that contains the whole MNIST test split.
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=len(test_set),
    shuffle=False,
)

# --- USPS --------------------------------------------------------------
percent_test = 0.3
usps_batch = 64
usps_set = d.get_usps('usps/usps_all.mat', size=(28,28))

# 30% of the USPS samples; passed as the second argument of the split helper
# (presumably the held-out size — TODO confirm against al.get_dataset_split).
usps_holdout_size = int(len(usps_set) * percent_test)
usps_x, usps_y, usps_test_x, usps_test_y = al.get_dataset_split(usps_set, usps_holdout_size)

# Test loader: one unshuffled batch covering every held-out USPS sample.
usps_test_dataset = torch.utils.data.TensorDataset(usps_test_x, usps_test_y)
usps_test_loader = torch.utils.data.DataLoader(
    dataset=usps_test_dataset,
    batch_size=len(usps_test_y),
    shuffle=False,
)

# Train loader: shuffled mini-batches of `usps_batch` samples.
usps_train_dataset = torch.utils.data.TensorDataset(usps_x, usps_y)
usps_train_loader = torch.utils.data.DataLoader(
    dataset=usps_train_dataset,
    batch_size=usps_batch,
    shuffle=True,
)


In [3]:
# Split the MNIST training set into train and validation x/y parts, then
# extract the x/y parts of the test set from its loader.
# NOTE(review): no split size is passed here, so the train/validation sizes
# come from al.get_dataset_split's default — confirm it is the intended one.
train_x, train_y, val_x, val_y = al.get_dataset_split(train_set)
test_x,test_y = al.get_xy_split(test_loader)

In [4]:
# Load torchvision's ResNet-18 with its pretrained weights (trained on ImageNet)
model_in = tmodels.resnet18(pretrained=True)
# model_in  # uncomment to display the network architecture

In [5]:
# Load the pretrained USPS handwritten-digit classifier.
# Any warnings shown here are caused by a PyTorch version mismatch
# (0.3.1 on conda GPU vs 0.3.0 on conda CPU).
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model_usps = torch.load('paul_models/usps_model.pt')

# Report accuracy on the USPS train split, then on the held-out test split.
usps_train_acc = al.accuracy(model_usps, usps_x, usps_y)
print('USPS training accuracy:', usps_train_acc)
usps_test_acc = al.accuracy(model_usps, usps_test_x, usps_test_y)
print('USPS test accuracy:', usps_test_acc)



USPS training accuracy: 0.966493506494
USPS test accuracy: 0.953939393939


### Now that the transfer learners are loaded we are ready to reinforcement learn:

For a given state of the model (that is, all potential training points, labeled and unlabeled, together with our predictions on them), is it better to active learn with policy $y$ and continue training our original model, or is it better to transfer a different model and retrain the final layer?

## Questions for Pavlos:

How should we try to integrate the different transfer agent? Do we just check its performance on points we miss with the other? How should we evaluate its contribution (provided it has any)?

In [6]:
# Make the RL agent to interact with the environment
# NOTE(review): this definition shadows the AgentRL imported from src.models
# in the imports cell — confirm which of the two is meant to be used.
class AgentRL(nn.Module):
    """Two-layer policy network that maps a flattened state to policy probabilities.

    Args:
        inpt_dim: number of values in the flattened state vector.
        hidden_dim: width of the hidden layer.
        num_policies: number of selectable policies (size of the output).

    Attributes:
        num_policies: the value passed to the constructor.
        inner_layer: linear layer from the state to the hidden representation.
        outer_layer: linear layer from the hidden representation to policy scores.
        rewards: list, initialised empty (presumably filled with per-step
            rewards by the training loop — TODO confirm).
        saved_log_probs: list, initialised empty (presumably filled with the
            log-probabilities of sampled actions — TODO confirm).
    """

    def __init__(self, inpt_dim, hidden_dim, num_policies):
        super(AgentRL, self).__init__()
        self.num_policies = num_policies
        self.inner_layer = nn.Linear(inpt_dim, hidden_dim)
        self.outer_layer = nn.Linear(hidden_dim, num_policies)
        self.rewards = []
        self.saved_log_probs = []

    def forward(self, x):
        """Return a (1, num_policies) tensor of probabilities for state `x`.

        The whole input is flattened into a single row, so `x` may have any
        shape as long as it holds exactly `inpt_dim` values in total.
        """
        # .view() requires contiguous memory; .contiguous() is a no-op for
        # tensors that already are, and makes transposed/sliced states work
        # instead of raising a RuntimeError.
        x = x.contiguous().view(1, -1)
        x = F.relu(self.inner_layer(x))
        x = self.outer_layer(x)
        # Normalise the policy scores across the policy dimension.
        return F.softmax(x, dim=1)

In [7]:
# Policy agent: the input holds 10 values per training point, and the 6 outputs
# are the 5 AL policies plus one for the TL policy.
agent = AgentRL(
    inpt_dim=len(train_x) * 10,
    hidden_dim=128,
    num_policies=6,
)
# Adam over all of the agent's parameters.
optimizer_rl = optim.Adam(agent.parameters(), lr=1e-2)
