In [16]:
import torch
from torch import nn 
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import Image 
#%matplotlib inline
import numpy as np
from utils import ops
import dataset_utils

import torch.nn.functional as F

# Import and split dataset

In [17]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size =  16

In [18]:
dataloaders,dataset_sizes,described_species_labels,n_classes = dataset_utils.get_dataset(image_path='image_dataset/',csv_path='final_dataset.csv',batch_size=batch_size,shuffle_loaders=True)
imsize =64
final_evaluation = False 
if final_evaluation:
    dataloaders['train'] = dataloaders['train_val']
    dataloaders['val'] = dataloaders['test']
    dataset_sizes['train'] = dataset_sizes['train_val']
    dataset_sizes['val'] = dataset_sizes['test']

In [19]:
print(dataset_sizes)

{'train': 13039, 'val': 6955, 'test': 12430, 'train_val': 19994}


# Initialize and train the models

In [20]:
import types
def define_modules(gan_config):
    layers = types.SimpleNamespace()
    if gan_config.apply_g_sn:
        layers.g_conv2d = ops.snconv2d
        layers.g_deconv2d = ops.sndeconv2d
        layers.g_linear = ops.snlinear
        layers.g_embedding = ops.sn_embedding
    else:
        layers.g_conv2d = ops.conv2d
        layers.g_deconv2d = ops.deconv2d
        layers.g_linear = ops.linear
        layers.g_embedding = ops.embedding

    if gan_config.apply_d_sn:
        layers.d_conv2d = ops.snconv2d
        layers.d_deconv2d = ops.sndeconv2d
        layers.d_linear = ops.snlinear
        layers.d_embedding = ops.sn_embedding
    else:
        layers.d_conv2d = ops.conv2d
        layers.d_deconv2d = ops.deconv2d
        layers.d_linear = ops.linear
        layers.d_embedding = ops.embedding

    if gan_config.g_cond_mtd == "cBN":
        layers.g_bn = ops.ConditionalBatchNorm2d
    elif gan_config.g_cond_mtd == "W/O":
        layers.g_bn = ops.batchnorm_2d
    else:
        raise NotImplementedError

    if not gan_config.apply_d_sn:
        layers.d_bn = ops.batchnorm_2d

    if gan_config.g_act_fn == "ReLU":
        layers.g_act_fn = nn.ReLU(inplace=True)
    elif gan_config.g_act_fn == "Leaky_ReLU":
        layers.g_act_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True)
    elif gan_config.g_act_fn == "ELU":
        layers.g_act_fn = nn.ELU(alpha=1.0, inplace=True)
    elif gan_config.g_act_fn == "GELU":
        layers.g_act_fn = nn.GELU()
    elif gan_config.g_act_fn == "Auto":
        pass
    else:
        raise NotImplementedError

    if gan_config.d_act_fn == "ReLU":
        layers.d_act_fn = nn.ReLU(inplace=True)
    elif gan_config.d_act_fn == "Leaky_ReLU":
        layers.d_act_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True)
    elif gan_config.d_act_fn == "ELU":
        layers.d_act_fn = nn.ELU(alpha=1.0, inplace=True)
    elif gan_config.d_act_fn == "GELU":
        layers.d_act_fn = nn.GELU()
    elif gan_config.g_act_fn == "Auto":
        pass
    else:
        raise NotImplementedError
    return layers
    
config = types.SimpleNamespace()
config.d_act_fn = "ReLU"
config.g_act_fn = "ReLU"
config.apply_d_sn= True
config.apply_g_sn= True
config.g_cond_mtd= "cBN"

l=define_modules(config)

In [21]:
import importlib
import modelReACGAN as m

importlib.reload(m)


discriminator = m.Discriminator(imsize,128,True,True,[1],"D2DCE",2048,True,num_classes=1050,d_init="ortho",d_depth=2,mixed_precision=True,MODULES=l)
discriminator.cuda()

generator = m.Generator(100,128,imsize,128,True,[4],"cBN",num_classes=1050,g_init="ortho",g_depth=2,mixed_precision=True,MODULES=l)
generator.cuda()
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),lr=2e-4,betas=(0.0, 0.999))
#discriminator_optimizer = torch.optim.RMSprop(discriminator.parameters(),lr=1e-5)
#generator_optimizer = torch.optim.RMSprop(generator.parameters(),lr=1e-5)
generator_optimizer = torch.optim.Adam(generator.parameters(),lr=2e-5,betas=(0.0, 0.999))


n_params = dataset_utils.count_trainable_parameters(discriminator);
print(n_params)
n_params = dataset_utils.count_trainable_parameters(generator);
print(n_params)


writer = SummaryWriter()

84520322
55822084


In [22]:
class Data2DataCrossEntropyLoss(torch.nn.Module):
    def __init__(self, num_classes, temperature, m_p, device):
        super(Data2DataCrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.temperature = temperature
        self.m_p = m_p
        self.device = device
        self.calculate_similarity_matrix = self._calculate_similarity_matrix()
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)

    def _calculate_similarity_matrix(self):
        return self._cosine_simililarity_matrix

    def _cosine_simililarity_matrix(self, x, y):
        v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def make_index_matrix(self, labels):
        labels = labels.detach().cpu().numpy()
        num_samples = labels.shape[0]
        mask_multi, target = np.ones([self.num_classes, num_samples]), 0.0

        for c in range(self.num_classes):
            c_indices = np.where(labels==c)
            mask_multi[c, c_indices] = target
        return torch.tensor(mask_multi).type(torch.long).to(self.device)

    def remove_diag(self, M):
        h, w = M.shape
        assert h==w, "h and w should be same"
        mask = np.ones((h, w)) - np.eye(h)
        mask = torch.from_numpy(mask)
        mask = (mask).type(torch.bool).to(self.device)
        return M[mask].view(h, -1)

    def forward(self, embed, proxy, label, **_):
        # calculate similarities between sample embeddings
        sim_matrix = self.calculate_similarity_matrix(embed, embed) + self.m_p - 1
        # remove diagonal terms
        sim_matrix = self.remove_diag(sim_matrix/self.temperature)
        # for numerical stability
        sim_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
        sim_matrix = F.relu(sim_matrix) - sim_max.detach()

        # calculate similarities between sample embeddings and the corresponding proxies
        smp2proxy = self.cosine_similarity(embed, proxy)
        # make false negative removal
        removal_fn = self.remove_diag(self.make_index_matrix(label)[label])
        # apply the negative removal to the similarity matrix
        improved_sim_matrix = removal_fn*torch.exp(sim_matrix)

        # compute positive attraction term
        pos_attr = F.relu((self.m_p - smp2proxy)/self.temperature)
        # compute negative repulsion term
        neg_repul = torch.log(torch.exp(-pos_attr) + improved_sim_matrix.sum(dim=1))
        # compute data to data cross-entropy criterion
        criterion = pos_attr + neg_repul
        return criterion.mean()
def d_hinge(d_logit_real, d_logit_fake):
    return torch.mean(F.relu(1. - d_logit_real)) + torch.mean(F.relu(1. + d_logit_fake))
def g_hinge(d_logit_fake):
    return -torch.mean(d_logit_fake)

In [23]:
cond_loss = Data2DataCrossEntropyLoss(n_classes,0.5,0.98,device)
cond_lambda = 1 

In [24]:
#suffix = "PretrainedReACGAN18"
suffix = "FinetunedReACGAN12"
g_dict = torch.load('checkpoints/generator'+suffix)
generator.load_state_dict(g_dict['model_state_dict'])
d_dict = torch.load('checkpoints/discriminator'+suffix)
discriminator.load_state_dict(d_dict['model_state_dict'])
generator_optimizer.load_state_dict(g_dict['optimizer_state_dict'])
discriminator_optimizer.load_state_dict(d_dict['optimizer_state_dict'])

In [25]:
d_dict

{'epoch': 12,
 'model_state_dict': OrderedDict([('blocks.0.0.conv2d0.bias',
               tensor([-0.0594, -0.0433,  0.0100,  0.0202,  0.0045, -0.0182, -0.0379,  0.0060,
                        0.0098, -0.0048,  0.0568,  0.0145, -0.0211,  0.0132, -0.0091, -0.0296,
                        0.0225,  0.0321, -0.0035,  0.0198, -0.0166,  0.0004, -0.0726,  0.0312,
                       -0.0149, -0.0008,  0.0499,  0.0202, -0.1389, -0.0454, -0.0124, -0.0187,
                       -0.0163,  0.0092,  0.0194, -0.1027, -0.0028, -0.0067, -0.0007, -0.0024,
                       -0.0043,  0.0248,  0.0383, -0.0490, -0.0116, -0.0070, -0.0336, -0.0025,
                       -0.0116,  0.0225, -0.0254, -0.0042, -0.0628,  0.0224, -0.0938,  0.0085,
                       -0.0889,  0.0580,  0.0030, -0.0073,  0.0028,  0.0006,  0.0052, -0.0891,
                        0.0059, -0.0175, -0.0391, -0.0059,  0.0124, -0.0402,  0.0081, -0.0214,
                       -0.0044, -0.0042,  0.0071,  0.0052, -0.0129,  

In [26]:
torch.backends.cudnn.benchmark = True

In [28]:
fixed_latent = torch.randn(100,100).to(device)
import torchvision
from tqdm.notebook import tqdm
discriminator.train()
generator.train()
for epoch in range(1,13):
    for real_images, real_classes in tqdm(dataloaders['train']):
        real_images = real_images.to(device)
        real_classes = real_classes.to(device)
        #TRAIN DISCRIMINATOR
        for k in range(2):
            discriminator_optimizer.zero_grad()
            #use discriminator on real images
            real_dict = discriminator(real_images,real_classes)
            #use discriminator on fake images
            with torch.no_grad():
                random_classes = torch.tensor(described_species_labels[np.random.randint(0, len(described_species_labels), batch_size)],device=device)
                t = generator(torch.randn(batch_size,100).to(device),random_classes,eval = True)
            fake_dict = discriminator(t,random_classes)
            #Compute the two losses
            dis_acml_loss = d_hinge(real_dict["adv_output"], fake_dict["adv_output"])
            real_cond_loss = cond_loss(**real_dict)
            dis_acml_loss += cond_lambda * real_cond_loss
            dis_acml_loss.backward()
            discriminator_optimizer.step()


        
        #TRAIN GENERATOR
        generator_optimizer.zero_grad()
        random_classes = torch.tensor(described_species_labels[np.random.randint(0, len(described_species_labels), batch_size)],device=device)
        t = generator(torch.randn(batch_size,100).to(device),random_classes,eval = True)
        fake_dict = discriminator(t,random_classes)
        gen_acml_loss = g_hinge(fake_dict["adv_output"])
        fake_cond_loss = cond_loss(**fake_dict)
        gen_acml_loss += cond_lambda * fake_cond_loss
        gen_acml_loss.backward()
        generator_optimizer.step()
    
    print(f"disc loss={dis_acml_loss.item()}",end=',')
    print(f"gen loss={gen_acml_loss.item()}")
    with torch.no_grad():
        t = generator(fixed_latent,torch.tensor(np.arange(100)).to(device),eval = True)
    t = dataset_utils.denorm(t)
    p = torchvision.transforms.functional.to_pil_image(torchvision.utils.make_grid(t))
    p.save(f"generated/finetune18ex{epoch}.jpg")
        #torch.cuda.empty_cache()
        #loss_d, real_score, fake_score, class_accuracy_real, class_accuracy_fake

  0%|          | 0/1250 [00:00<?, ?it/s]

disc loss=2.15625,gen loss=5.328125


  0%|          | 0/1250 [00:00<?, ?it/s]

disc loss=2.419921875,gen loss=5.37109375


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.56 GiB. GPU 0 has a total capacity of 11.71 GiB of which 1.11 GiB is free. Including non-PyTorch memory, this process has 9.72 GiB memory in use. Of the allocated memory 7.47 GiB is allocated by PyTorch, and 2.01 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [28]:
from extract_features import extract_image_features
(tf,tl),(vf,vl),(_,_) = extract_image_features(discriminator,device,save_to_disk=True,)#save_name_prefix="final_eval")

In [None]:
raise Exception
torch.save({
            'epoch':21,
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': generator_optimizer.state_dict(),
            }, "checkpoints/generatorFinetunedReACGAN21")
torch.save({
            'epoch': 21,
            'model_state_dict': discriminator.state_dict(),
            'optimizer_state_dict': discriminator_optimizer.state_dict(),
            }, "checkpoints/discriminatorFinetunedReACGAN21")