<a href="https://colab.research.google.com/github/deepeshhada/SABR/blob/master/train_jayant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os

import numpy as np
import pandas as pd
import scipy.io as io

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 64

In [0]:
# set dataset from: CUB, SUN, AWA, AWA2, APY
_dataset = "CUB"
data_root = "./drive/My Drive/ZSL Datasets/" + _dataset + "/"

In [0]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, features, labels, class_embeddings):
        self.features = features
        self.labels = labels
        self.class_embeddings = class_embeddings
    
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        label = self.labels[index]
        return (self.features[index], label, class_embeddings[label])

In [0]:
# load mat files
res101 = io.loadmat(data_root + "res101.mat")
att_splits = io.loadmat(data_root + "att_splits.mat")

resnet_features = res101['features'].T
class_labels = res101['labels']
class_embeddings = att_splits['att'].T

# print((att_splits['trainval_loc']).reshape(-1))

In [7]:
def generate_splits(loc, shuffle=False):
    indices = att_splits[loc].reshape(-1) - 1
    features = resnet_features[indices]
    labels = class_labels[indices].reshape(-1) - 1

    split = Dataset(
        features=features,
        labels=labels,
        class_embeddings=class_embeddings
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=split,
        batch_size=batch_size,
        shuffle=shuffle
    )

    return split, dataloader


train_set, trainloader = generate_splits(loc='trainval_loc', shuffle=True)
seen_test_set, seen_testloader = generate_splits(loc='test_seen_loc', shuffle=False)
unseen_test_set, unseen_testloader = generate_splits(loc='test_unseen_loc', shuffle=False)
print(train_set.labels[0])

196


In [0]:
unseen_labels = np.unique(unseen_test_set.labels)
seen_labels = np.unique(train_set.labels)

unseen = class_embeddings[unseen_labels]
seen = class_embeddings[seen_labels]

unseen_cy = torch.tensor(unseen, device=device).float()
seen_cy = torch.tensor(seen, device=device).float()
unseen_y = torch.tensor(unseen_labels, device=device).long()
seen_y = torch.tensor(seen_labels, device=device).long()


In [0]:
def normal_initialize(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(0.0, 0.02)
        module.bias.data.normal_(0.0, 0.02)
        pass

In [0]:
# use this in sync with the Generator
# Generator class looks similar to the "LatentTransform" class
# the out_features of both the classifier and regressor are hardcoded for now.
# TODO: make the out_features generic.

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=2048, out_features=150, bias=True), # Earlier out_features set to 200 but changed to 150
            nn.Softmax(dim=1)
        )

    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])

    def loss(self, true, pred):
        return F.cross_entropy(pred, true)

    def forward(self, input):
        return self.model(input)


class Regressor(nn.Module):
    def __init__(self):
        super(Regressor, self).__init__()
        self.model = nn.Linear(in_features=2048, out_features=312, bias=True)
        self.softmax = nn.Softmax(dim=1)

    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])

    # def loss(self, true, pred, b_size):   # to be checked for possible errors..
    #     # true, pred --> batch_size * 312
    #     a = F.normalize(true, p=2, dim=1, eps=1e-8).to(device)
    #     b = F.normalize(pred, p=2, dim=1, eps=1e-8).to(device)
    #     similarity_scores = torch.mm(a, b.T)  # batch * batch
    #     arr = np.arange(0,b_size)
    #     target = torch.tensor(arr).to(device).long()
    #     loss = F.cross_entropy(similarity_scores, target)
    #     return loss

    def loss(self, true, pred):
        similarity_scores = torch.mm(pred, seen_cy.T) # batch * 150
        loss = F.cross_entropy(similarity_scores, true)
        return loss

    def forward(self, x):  # removed c_y from signature
        x = self.model(x)
        return x

        # norm = torch.norm(input=x, p=2, dim=1).detach()
        # x = x.div(norm.expand_as(x))
        # return torch.bmm(x, c_y)

In [0]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=624, out_features=2048, bias=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=False),
            nn.Linear(in_features=2048, out_features=2048),
            nn.ReLU(inplace=False)
        )

    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])

    def forward(self, x):
        output = self.model(x)
        return output

In [0]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=2360, out_features=4096, bias=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=False),
            nn.Linear(in_features=4096, out_features=1),
        )

    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])

    def forward(self, x):
        output = self.model(x)
        return output

In [0]:
generator_model = Generator()
discriminator_model = Discriminator()
classifier_model = Classifier()
regressor_model = Regressor()

generator_model = generator_model.to(device)
discriminator_model = discriminator_model.to(device)
classifier_model = classifier_model.to(device)
regressor_model = regressor_model.to(device)

generator_model.weights_init()
discriminator_model.weights_init()
classifier_model.weights_init()
regressor_model.weights_init()

learning_rate = 0.001
gamma = 0.01
beeta = 0.1
lamda = 10
num_epochs = 20

train_params = list(generator_model.parameters()) + list(classifier_model.parameters()) + list(regressor_model.parameters())
G_optimizer = optim.Adam(train_params, lr=learning_rate, betas=(0.5,0.999))
D_optimizer = optim.Adam(discriminator_model.parameters(), lr = learning_rate, betas=(0.5,0.999))


In [14]:
for ep in range(10):
    for i, data in enumerate(trainloader, 0):
        f, l, e = data
        features, labels, embeddings = f.to(device).float(), l.to(device).long(), e.to(device).float()
        discriminator_model.zero_grad()
        b_size = embeddings.shape[0]
        noise = torch.randn(b_size, 312, device=device)
        gen_input = torch.cat((noise, embeddings), dim=1)
        fake_features = generator_model(gen_input)
        align_cls = classifier_model(fake_features)
        cls_loss = classifier_model.loss(labels, align_cls)  # Computing classifier loss
        # print(cls_loss)
        align_reg = regressor_model(fake_features)
        reg_loss = regressor_model.loss(labels, align_reg) # Computing Regressor loss
        # print(reg_loss)

        # Discriminator Loss

        for k in range(2):
            # alpha = torch.rand((batch_size, 1)).to(device) # Random [0,1) from uniform dist.
            # interpolate = alpha*features + (1 - alpha)*fake_features   # b_size * 2048

            # train with real features
            disc_input_real = torch.cat((features, embeddings), dim=1)
            disc_score_real = discriminator_model(disc_input_real)
            disc_loss_real = torch.mean(disc_score_real)

            # train with fake features
            noise = torch.randn(b_size, 312, device=device)
            gen_input = torch.cat((noise, embeddings), dim=1)
            fake_features = generator_model(gen_input)
            disc_input_fake = torch.cat((fake_features, embeddings), dim=1)
            disc_score_fake = discriminator_model(disc_input_fake)
            disc_loss_fake = torch.mean(disc_score_fake)
            
            disc_loss_total = torch.mean(disc_score_fake - disc_score_real)  # Still have to implement grad penalty
            # print(disc_loss_total)
            disc_loss_total.backward(retain_graph=True)
            D_optimizer.step()
        

        # Wasserstein Generator loss
        with torch.autograd.set_detect_anomaly(True):
            generator_model.zero_grad()
            noise = torch.randn(b_size, 312, device=device)
            gen_input = torch.cat((noise, embeddings), dim=1)
            fake_features = generator_model(gen_input)
            disc_input_fake = torch.cat((fake_features, embeddings), dim=1)
            disc_score_fake = discriminator_model(disc_input_fake)
            disc_loss_fake = torch.mean(disc_score_fake)
            gen_loss1 = -disc_loss_fake
            gen_loss = gen_loss1 + beeta *(cls_loss + (gamma * reg_loss))
            # print(gen_loss)
            gen_loss.backward()
            G_optimizer.step()

    print("discriminator loss after", ep+1, " iterations", disc_loss_total.item())
    print("generator loss after", ep+1, " iterations", gen_loss.item())
    print()


discriminator loss after 1  iterations -4236.0498046875
generator loss after 1  iterations -711.643310546875

discriminator loss after 2  iterations -3390.333251953125
generator loss after 2  iterations 4916.939453125

discriminator loss after 3  iterations -1322.494873046875
generator loss after 3  iterations 1124.317138671875

discriminator loss after 4  iterations 725.8442993164062
generator loss after 4  iterations -62.28472900390625

discriminator loss after 5  iterations -1125.94921875
generator loss after 5  iterations 632.861328125

discriminator loss after 6  iterations -76.3913803100586
generator loss after 6  iterations 1744.808837890625

discriminator loss after 7  iterations -334.569580078125
generator loss after 7  iterations 516.9773559570312

discriminator loss after 8  iterations -996.8807373046875
generator loss after 8  iterations 3550.105224609375

discriminator loss after 9  iterations -1357.279541015625
generator loss after 9  iterations 1484.5313720703125

discri

In [15]:
# generating features from the unseen classes using trained generator

x_train = torch.tensor(train_set.features, device=device).float()
y_train = torch.tensor(train_set.labels, device=device).long()

k = -1 # current index
for c_y in unseen_cy:
    k += 1
    embed = c_y.repeat(1, 100).view(100, -1)  # 100 *312
    lab = (unseen_y[k]).repeat(1, 100).view(100) # 100 labels
    rand_noise = torch.randn(100, 312, device=device)  # generate 100 features
    gen_inp = torch.cat((rand_noise, embed), dim=1)
    generated = generator_model(gen_inp)  # 100 * 2048 : 100 features generated 
    x_train = torch.cat((x_train, generated), dim=0)
    y_train = torch.cat((y_train, lab), dim=0)

print(x_train.shape)
print(y_train.shape)
# print(torch.unique(y_train))

torch.Size([12057, 2048])
torch.Size([12057])


In [0]:
class Final_Classifier(nn.Module):
    def __init__(self):
        super(Final_Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(2048, 2048),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(2048, 200),
            nn.Softmax(dim=1)
        )
    
    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])

    # to be written
    def compute_per_class_accuracy_gzsl(self, true_labels, pred_labels):
        per_class_acc = 0.0
        return per_class_acc


    def forward(self, x):
        output = self.model(x)
        return output


In [18]:
softmax_cls = Final_Classifier()
softmax_cls = softmax_cls.to(device)
softmax_cls.weights_init()

num_iters = 50
lr = 0.001
cls_optimizer = optim.Adam(softmax_cls.parameters(), lr=lr, betas=(0.5,0.999))

for ep in range(num_iters):
    final_preds = softmax_cls(x_train)
    loss = F.cross_entropy(final_preds, y_train)
    cls_optimizer.zero_grad()
    loss.backward(retain_graph=True)
    cls_optimizer.step()
    
    print("loss after ", ep + 1, "iters ", loss.item())


loss after  1 iters  5.298337459564209
loss after  2 iters  5.295351028442383
loss after  3 iters  5.285797595977783
loss after  4 iters  5.2744832038879395
loss after  5 iters  5.263491153717041
loss after  6 iters  5.260190963745117
loss after  7 iters  5.246638774871826
loss after  8 iters  5.228035926818848
loss after  9 iters  5.226659297943115
loss after  10 iters  5.2174072265625
loss after  11 iters  5.2051920890808105
loss after  12 iters  5.203485012054443
loss after  13 iters  5.199697494506836
loss after  14 iters  5.189092636108398
loss after  15 iters  5.185328483581543
loss after  16 iters  5.172701358795166
loss after  17 iters  5.1678972244262695
loss after  18 iters  5.161052227020264
loss after  19 iters  5.160949230194092
loss after  20 iters  5.159963130950928
loss after  21 iters  5.155332565307617
loss after  22 iters  5.147940635681152
loss after  23 iters  5.141739845275879
loss after  24 iters  5.148510932922363
loss after  25 iters  5.1422624588012695
loss af