<a href="https://colab.research.google.com/github/deepeshhada/AGSN/blob/master/train_jayant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os

import math
import statistics 

import numpy as np
import pandas as pd
import scipy.io as io

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 64

In [0]:
# set dataset from: CUB, SUN, AWA, AWA2, APY
_dataset = "CUB"
data_root = "./drive/My Drive/ZSL Datasets/" + _dataset + "/"

In [0]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, features, labels, class_embeddings):
        self.features = features
        self.labels = labels
        self.class_embeddings = class_embeddings
    
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        label = self.labels[index]
        return (self.features[index], label, class_embeddings[label])

In [0]:
# load mat files
res101 = io.loadmat(data_root + "res101.mat")
att_splits = io.loadmat(data_root + "att_splits.mat")

resnet_features = res101['features'].T
class_labels = res101['labels']
class_embeddings = att_splits['att'].T

# print((att_splits['trainval_loc']).reshape(-1))

In [7]:
def generate_splits(loc, shuffle=False):
    indices = att_splits[loc].reshape(-1) - 1
    features = resnet_features[indices]
    labels = class_labels[indices].reshape(-1) - 1

    split = Dataset(
        features=features,
        labels=labels,
        class_embeddings=class_embeddings
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=split,
        batch_size=batch_size,
        shuffle=shuffle
    )

    return split, dataloader


train_set, trainloader = generate_splits(loc='trainval_loc', shuffle=True)
seen_test_set, seen_testloader = generate_splits(loc='test_seen_loc', shuffle=False)
unseen_test_set, unseen_testloader = generate_splits(loc='test_unseen_loc', shuffle=False)
print(train_set.labels[0])

196


In [8]:
unseen_labels = np.unique(unseen_test_set.labels)
seen_labels = np.unique(train_set.labels)

unseen = class_embeddings[unseen_labels]
seen = class_embeddings[seen_labels]

unseen_cy = torch.tensor(unseen, device=device).float()
seen_cy = torch.tensor(seen, device=device).float()
unseen_y = torch.tensor(unseen_labels, device=device).long()
seen_y = torch.tensor(seen_labels, device=device).long()
all_cy = torch.tensor(class_embeddings, device=device).float()
print(all_cy.shape)


torch.Size([200, 312])


In [0]:
def normal_initialize(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(0.0, 0.02)
        module.bias.data.normal_(0.0, 0.02)
        pass

In [0]:
# use this in sync with the Generator
# Generator class looks similar to the "LatentTransform" class
# the out_features of both the classifier and regressor are hardcoded for now.
# TODO: make the out_features generic.

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=2048, out_features=200, bias=True), # Earlier out_features set to 200 but changed to 150
            nn.Softmax(dim=1)
        )

    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])

    def loss(self, true, pred):
        return F.cross_entropy(pred, true)

    def forward(self, input):
        return self.model(input)


class Regressor(nn.Module):
    def __init__(self):
        super(Regressor, self).__init__()
        self.model = nn.Linear(in_features=2048, out_features=312, bias=True)
        self.softmax = nn.Softmax(dim=1)

    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])


    def loss(self, true, pred):
        similarity_scores = torch.mm(pred, all_cy.T) # batch * 200
        loss = F.cross_entropy(similarity_scores, true)
        return loss

    def forward(self, x):  # removed c_y from signature
        x = self.model(x)
        return x

        # norm = torch.norm(input=x, p=2, dim=1).detach()
        # x = x.div(norm.expand_as(x))
        # return torch.bmm(x, c_y)

In [0]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=624, out_features=2048, bias=True),
            nn.BatchNorm1d(num_features=2048),
            nn.LeakyReLU(negative_slope=0.01, inplace=False),
            nn.Linear(in_features=2048, out_features=2048),
            nn.ReLU(inplace=False)
        )

    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])

    def forward(self, x):
        output = self.model(x)
        return output

In [0]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=2360, out_features=4096, bias=True),
            nn.BatchNorm1d(num_features=4096),
            nn.LeakyReLU(negative_slope=0.01, inplace=False),
            nn.Linear(in_features=4096, out_features=1),
        )

    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])

    def forward(self, x):
        output = self.model(x)
        return output

In [0]:
generator_model = Generator()
discriminator_model = Discriminator()
classifier_model = Classifier()
regressor_model = Regressor()

generator_model = generator_model.to(device)
discriminator_model = discriminator_model.to(device)
classifier_model = classifier_model.to(device)
regressor_model = regressor_model.to(device)

generator_model.weights_init()
discriminator_model.weights_init()
classifier_model.weights_init()
regressor_model.weights_init()

learning_rate = 0.001
gamma = 0.01
beeta = 0.1
lamda = 10
num_epochs = 20

train_params = list(generator_model.parameters()) + list(classifier_model.parameters()) + list(regressor_model.parameters())
G_optimizer = optim.Adam(train_params, lr=learning_rate, betas=(0.5,0.999))
D_optimizer = optim.Adam(discriminator_model.parameters(), lr = learning_rate, betas=(0.5,0.999))


In [14]:
for ep in range(50):
    for i, data in enumerate(trainloader, 0):
        f, l, e = data
        features, labels, embeddings = f.to(device).float(), l.to(device).long(), e.to(device).float()
        
        b_size = embeddings.shape[0]
        noise = torch.randn(b_size, 312, device=device)
        gen_input = torch.cat((noise, embeddings), dim=1)
        fake_features = generator_model(gen_input)
        align_cls = classifier_model(fake_features)
        cls_loss = classifier_model.loss(labels, align_cls)  # Computing classifier loss
        # print(cls_loss)
        align_reg = regressor_model(fake_features)
        reg_loss = regressor_model.loss(labels, align_reg) # Computing Regressor loss
        # print(reg_loss)

        # Discriminator Loss

        for k in range(2):
            
            for parameter in discriminator_model.parameters(): #Weight Clip
                parameter.data.clamp_(-0.01, 0.01)

            discriminator_model.zero_grad()
            # train with real features
            disc_input_real = torch.cat((features, embeddings), dim=1)
            disc_score_real = discriminator_model(disc_input_real)
            disc_loss_real = torch.mean(disc_score_real)

            # train with fake features
            noise = torch.randn(b_size, 312, device=device)
            gen_input = torch.cat((noise, embeddings), dim=1)
            fake_features = generator_model(gen_input)
            disc_input_fake = torch.cat((fake_features, embeddings), dim=1)
            disc_score_fake = discriminator_model(disc_input_fake)
            disc_loss_fake = torch.mean(disc_score_fake)
            
            disc_loss_total = torch.mean(disc_score_fake - disc_score_real)  # Still have to implement grad penalty
            # print(disc_loss_total)
            disc_loss_total.backward(retain_graph=True)
            D_optimizer.step()
        

        # Wasserstein Generator loss
        with torch.autograd.set_detect_anomaly(True):
            generator_model.zero_grad()
            noise = torch.randn(b_size, 312, device=device)
            gen_input = torch.cat((noise, embeddings), dim=1)
            fake_features = generator_model(gen_input)
            disc_input_fake = torch.cat((fake_features, embeddings), dim=1)
            disc_score_fake = discriminator_model(disc_input_fake)
            disc_loss_fake = torch.mean(disc_score_fake)
            gen_loss1 = -disc_loss_fake
            gen_loss = gen_loss1 + beeta *(cls_loss + (gamma * reg_loss))
            # print(gen_loss)
            gen_loss.backward()
            G_optimizer.step()

    print("discriminator loss after", ep+1, " iterations", disc_loss_total.item())
    print("generator loss after", ep+1, " iterations", gen_loss.item())
    print()


discriminator loss after 47  iterations -0.04899667948484421
generator loss after 47  iterations 0.5490357279777527

discriminator loss after 48  iterations -0.04561380669474602
generator loss after 48  iterations 0.5345721244812012

discriminator loss after 49  iterations -0.03150445967912674
generator loss after 49  iterations 0.526042640209198

discriminator loss after 50  iterations -0.0401001051068306
generator loss after 50  iterations 0.5421980023384094



In [15]:
# generating features from the unseen classes using trained generator

x_train = torch.tensor(train_set.features, device=device).float()
y_train = torch.tensor(train_set.labels, device=device).long()

k = -1 # current index
for c_y in unseen_cy:
    k += 1
    embed = c_y.repeat(1, 100).view(100, -1)  # 100 *312
    lab = (unseen_y[k]).repeat(1, 100).view(100) # 100 labels
    rand_noise = torch.randn(100, 312, device=device)  # generate 100 features
    gen_inp = torch.cat((rand_noise, embed), dim=1)
    generated = generator_model(gen_inp)  # 100 * 2048 : 100 features generated 
    x_train = torch.cat((x_train, generated), dim=0)
    y_train = torch.cat((y_train, lab), dim=0)

print(x_train.shape)
print(y_train.shape)
# print(torch.unique(y_train))

torch.Size([12057, 2048])
torch.Size([12057])


In [0]:
class Final_Classifier(nn.Module):
    def __init__(self):
        super(Final_Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(2048, 2048),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(2048, 200),
            nn.Softmax(dim=1)
        )
    
    def weights_init(self):
        for m in self._modules:
            normal_initialize(self._modules[m])

    # computes 200*200 confusion matrix for relevant classes
    
    def compute_confusion_matrix(self, inputs, classes):
        per_class_acc = 0.0
        nb_classes = 200
        confusion_matrix = torch.zeros(nb_classes, nb_classes)
        with torch.no_grad():
            outputs = self.model(inputs)
            _, preds = torch.max(outputs, 1)
            for t, p in zip(classes.view(-1), preds.view(-1)):
                    confusion_matrix[t.long(), p.long()] += 1

        
        return(confusion_matrix)

    def forward(self, x):
        output = self.model(x)
        return output


In [17]:
softmax_cls = Final_Classifier()
softmax_cls = softmax_cls.to(device)
softmax_cls.weights_init()

num_iters = 100
lr = 0.0009
cls_optimizer = optim.Adam(softmax_cls.parameters(), lr=lr, betas=(0.5,0.999))

for ep in range(num_iters):
    final_preds = softmax_cls(x_train)
    loss = F.cross_entropy(final_preds, y_train)
    cls_optimizer.zero_grad()
    loss.backward(retain_graph=True)
    cls_optimizer.step()
    
    print("loss after ", ep + 1, "iters ", loss.item())


loss after  85 iters  5.137166976928711
loss after  86 iters  5.137012004852295
loss after  87 iters  5.136969089508057
loss after  88 iters  5.136739253997803
loss after  89 iters  5.1366963386535645
loss after  90 iters  5.1365647315979
loss after  91 iters  5.136486053466797
loss after  92 iters  5.138314723968506
loss after  93 iters  5.137510299682617
loss after  94 iters  5.137487888336182
loss after  95 iters  5.137416362762451
loss after  96 iters  5.137307167053223
loss after  97 iters  5.1372175216674805
loss after  98 iters  5.137146472930908
loss after  99 iters  5.137106418609619
loss after  100 iters  5.137054920196533


In [20]:
# Final Test and accuracy computation

# test_model = Final_Classifier()
# test_model = test_model.to(device)
# test_model.weights_init()

# compute confusion matrix for seen and unseen classes separately
x_test_s = torch.tensor(seen_test_set.features, device=device).float()
y_test_s = torch.tensor(seen_test_set.labels, device=device).long()
x_test_u = torch.tensor(unseen_test_set.features, device=device).float()
y_test_u = torch.tensor(unseen_test_set.labels, device=device).long()

cm_unseen = softmax_cls.compute_confusion_matrix(x_test_s, y_test_s)
cm_seen = softmax_cls.compute_confusion_matrix(x_test_u, y_test_u)

# compute per class accuracy matrix

acc_mat_us = ((cm_unseen.diag()/cm_unseen.sum(1)))
acc_mat_s = ((cm_seen.diag()/cm_seen.sum(1)))

unseen_acc = []
seen_acc = []

# Remove Nan's from irrelevant classes

for acc in acc_mat_us:
    if not math.isnan(acc):
        unseen_acc.append(acc)

for acc in acc_mat_s:
    if not math.isnan(acc):
        seen_acc.append(acc)

# Compute per class accuracy

per_class_seen = np.mean(seen_acc)
per_class_unseen = np.mean(unseen_acc)
l =[per_class_seen, per_class_unseen]

print("unseen classs accuray for zsl is :", per_class_unseen)
print("seen classs accuray is :", per_class_seen)
print("harmonic mean for generalized zsl is :", statistics.harmonic_mean(l))


unseen classs accuray for zsl is : 0.23546585
seen classs accuray is : 0.00033333336
harmonic mean for generalized zsl is : 0.0006657242960983318
