<a href="https://colab.research.google.com/github/deepeshhada/SABR/blob/master/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os

import numpy as np
import pandas as pd
import scipy.io as io

import torch
import torch.nn as nn

In [0]:
# Mini-batch size used by every DataLoader built in this notebook.
batch_size = 64

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [0]:
# Choose the dataset to work on; one of: CUB, SUN, AWA, AWA2, APY.
_dataset = "CUB"

# Every dataset sits in its own sub-folder of the shared ZSL datasets directory.
data_root = "./drive/My Drive/Deep Learning/datasets/ZSL Datasets/{}/".format(_dataset)

In [0]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (feature, label, class embedding) triples.

    Args:
        features: indexable collection of per-sample feature vectors.
        labels: indexable collection of per-sample class labels; each label
            is used as an index into ``class_embeddings``.
        class_embeddings: indexable collection holding one embedding per
            class, addressed by class label.
    """

    def __init__(self, features, labels, class_embeddings):
        self.features = features
        self.labels = labels
        self.class_embeddings = class_embeddings

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(features[index], label, embedding of that label)``."""
        label = self.labels[index]
        # Use the embeddings handed to this instance rather than a
        # module-level global, so the dataset is self-contained and honours
        # its constructor argument.
        return (self.features[index], label, self.class_embeddings[label])

In [0]:
# Read the two MATLAB files of the selected dataset from data_root.
res101 = io.loadmat("{}res101.mat".format(data_root))
att_splits = io.loadmat("{}att_splits.mat".format(data_root))

# Labels are taken as stored in the file.
class_labels = res101['labels']

# Both matrices are transposed so that their first axis can be indexed
# directly: by sample position for the features, by class label for the
# embeddings.
resnet_features = res101['features'].T
class_embeddings = att_splits['att'].T

In [0]:
def generate_splits(loc, shuffle=False):
    """Build the dataset and dataloader for one split of the loaded data.

    Args:
        loc: key into ``att_splits`` naming the array of sample locations
            that make up the split (e.g. ``'trainval_loc'``).
        shuffle: forwarded to the ``DataLoader``; whether batches are drawn
            in shuffled order.

    Returns:
        A ``(dataset, dataloader)`` tuple for the requested split.
    """
    # Locations and labels are both shifted down by one before use --
    # presumably because the .mat files store 1-based (MATLAB-style) values.
    sample_idx = att_splits[loc].reshape(-1) - 1

    split = Dataset(
        features=resnet_features[sample_idx],
        labels=class_labels[sample_idx].reshape(-1) - 1,
        class_embeddings=class_embeddings,
    )
    loader = torch.utils.data.DataLoader(
        dataset=split, batch_size=batch_size, shuffle=shuffle
    )
    return split, loader


# Training split: batches are drawn in shuffled order.
train_set, trainloader = generate_splits('trainval_loc', shuffle=True)

# Evaluation splits (seen and unseen locations): order is left untouched,
# relying on generate_splits' default of shuffle=False.
seen_test_set, seen_testloader = generate_splits('test_seen_loc')
unseen_test_set, unseen_testloader = generate_splits('test_unseen_loc')

In [0]:
# This is the transformation, Ψ, which operates on 2048-dimensional resnet features.
# May have to remove this.

class LatentTransform(nn.Module):
    """Two-layer MLP mapping 2048-d inputs to a 1024-d latent space.

    Layout: Linear(2048 -> 2048) -> ReLU -> Linear(2048 -> 1024) -> ReLU,
    so every output component is non-negative.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order and stored under ``model`` so
        # the parameter names stay ``model.0.*`` and ``model.2.*``.
        layers = [
            nn.Linear(2048, 2048, bias=True),
            nn.ReLU(),
            nn.Linear(2048, 1024, bias=True),
            nn.ReLU(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the transformation to ``input`` (last dimension 2048)."""
        latent = self.model(input)
        return latent

In [0]:
# use this in sync with the Generator
# Generator class looks similar to the "LatentTransform" class
# the out_features of both the classifier and regressor are hardcoded for now.
# TODO: make the out_features generic.

class Classifier(nn.Module):
    """Linear layer followed by a softmax over the class dimension.

    Args:
        in_features: size of each input vector. Defaults to 1024, the value
            that was previously hard-coded.
        num_classes: number of output classes. Defaults to 200, the value
            that was previously hard-coded; pass a different value when the
            selected dataset has a different number of classes.
    """

    def __init__(self, in_features=1024, num_classes=200):
        super(Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
            # Softmax over dim=1, i.e. across classes for (batch, features) input.
            nn.Softmax(dim=1)
        )

    def forward(self, input):
        """Return class probabilities of shape ``(batch, num_classes)``."""
        return self.model(input)


class Regressor(nn.Module):
    """Project inputs linearly, L2-normalise them and score against ``c_y``.

    Args:
        in_features: size of each input vector. Defaults to 1024, the value
            that was previously hard-coded.
        out_features: size of the projected vector, which must match the
            embedding dimension of ``c_y``. Defaults to 312, the value that
            was previously hard-coded.
    """

    def __init__(self, in_features=1024, out_features=312):
        super(Regressor, self).__init__()
        self.model = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        # Not used by forward(); kept so existing attribute access keeps working.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, c_y):
        """Return the product of the unit-normalised projection with ``c_y``.

        Args:
            x: tensor of shape ``(batch, in_features)``.
            c_y: either ``(batch, out_features)`` -- one embedding per sample,
                giving a ``(batch,)`` tensor of dot products -- or
                ``(batch, out_features, k)``, giving a ``(batch, k)`` tensor.

        Returns:
            The batched product described above, in the dtype of the
            projected ``x``.
        """
        x = self.model(x)
        # keepdim=True keeps the norm as (batch, 1) so it broadcasts across
        # the feature axis; without it the (batch,) norm cannot be expanded
        # to x's shape. The norm stays detached, as in the original code.
        norm = torch.norm(input=x, p=2, dim=1, keepdim=True).detach()
        # Guard against division by zero for an all-zero projection.
        x = x.div(norm.clamp_min(1e-12))

        # Embeddings loaded through NumPy may arrive in a different floating
        # dtype (e.g. float64); bmm needs both operands to share a dtype.
        c_y = c_y.to(dtype=x.dtype)

        # torch.bmm only accepts 3-D operands, so view each row of x as a
        # (1, out_features) matrix.
        x = x.unsqueeze(1)
        if c_y.dim() == 2:
            return torch.bmm(x, c_y.unsqueeze(2)).reshape(-1)
        return torch.bmm(x, c_y).squeeze(1)

In [0]:
# Pull a single batch from the training loader and move it to the selected
# device; the loop stops after the first batch because training is not
# implemented yet.
for i, data in enumerate(trainloader):
    features, labels, embeddings = (tensor.to(device) for tensor in data)
    # TODO: train the model
    break