In [None]:
%load_ext autoreload 
%autoreload 2 

import torch
from torch.utils.data import DataLoader

import numpy as np
from numpy import linalg as LA
# Scikit-learn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
# Visualisation
import matplotlib.pyplot as plt
import os 

from train import Trainer
from generator import *
from discriminator import GAN
from dataset import CocoStuffDataSet
from utils import *

# --- Checkpoint location -------------------------------------------------
# Relative path: assumes the notebook is launched from the code/ subfolder.
SAVE_DIR = "../checkpoints"
# Alternative experiment names (left commented out):
#   'gan_animal'
#   'animal-batchnorm-50-nobnend'
experiment_name = 'gan_low_reg'
experiment_dir = os.path.join(SAVE_DIR, experiment_name)

# --- Model / data-loading settings ---------------------------------------
NUM_CLASSES = 11   # passed to the generator and discriminator constructors
use_bn = True      # forwarded to the generator as `use_bn`
batch_size = 64    # batch size used by both data loaders

%matplotlib inline

In [None]:
# Spatial size requested from the dataset (forwarded as height/width).
HEIGHT = 128
WIDTH = 128

# Validation split of the 'animal' supercategory, iterated in dataset order.
val_loader = DataLoader(
    CocoStuffDataSet(mode='val', supercategories=['animal'], height=HEIGHT, width=WIDTH),
    batch_size=batch_size,
    shuffle=False,
)
# Training split with the same settings; shuffling is disabled here as well.
train_loader = DataLoader(
    CocoStuffDataSet(mode='train', supercategories=['animal'], height=HEIGHT, width=WIDTH),
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Shapes handed to the discriminator constructor.
image_shape = (3, HEIGHT, WIDTH)
segmentation_shape = (NUM_CLASSES, HEIGHT, WIDTH)

generator = SegNet16(NUM_CLASSES, use_bn=use_bn)
discriminator = GAN(NUM_CLASSES, segmentation_shape, image_shape)
# discriminator = None   # alternative left commented out by the author

# resume=True with load_iter=None: presumably restores the latest checkpoint
# found in experiment_dir -- TODO confirm against Trainer.
trainer = Trainer(
    generator,
    discriminator,
    train_loader,
    val_loader,
    experiment_dir=experiment_dir,
    resume=True,
    load_iter=None,
    train_gan=True,
)


In [None]:
def retrieve_features(trainer, loader, number, mode='gen'):
    '''
    Retrieves flattened feature embeddings for at least `number` images from
    the loader, using the trainer's generator or discriminator.

    Parameters
    trainer: object exposing `_gen` and `_disc`; the selected one must provide
        `get_feature_embedding(batch)`
    loader: iterable yielding (data, mask_gt, gt_visual) batches; 
        `loader.dataset.numClasses` is also read
    number: minimum number of images to process; None processes the whole
        loader. Iteration stops as soon as this many images were seen.
    mode: 'gen' uses `trainer._gen`; any other value uses `trainer._disc`

    Returns
    features ND-array B x feature_size (None when no batch was processed)
    dominant_classes ND-array size B containing index of dominant class in image
        (empty list when no batch was processed)
    '''
    model = trainer._gen if mode == 'gen' else trainer._disc
    num_classes = loader.dataset.numClasses

    total = 0
    feature_batches = []
    class_batches = []

    # Pure inference: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        for data, mask_gt, gt_visual in loader:
            # Stop once enough images were collected instead of walking
            # (and loading) the remainder of the dataset.
            if number is not None and total >= number:
                break
            data = data.cuda()
            n_images = data.size()[0]
            total += n_images
            # Author's note: embeddings are B x 512 x 4 x 4 for the generator
            # and B x 512 x H x W for the discriminator -- not verifiable here.
            features = model.get_feature_embedding(data).detach().cpu().numpy()
            # Flatten everything but the batch dimension.
            feature_batches.append(np.reshape(features, (n_images, -1)))
            class_batches.append(dominant_class(gt_visual, num_classes))

    if not feature_batches:
        # Same result the original produced when nothing was processed.
        return None, []

    # Concatenate once at the end rather than re-copying on every batch.
    to_return = np.concatenate(feature_batches, axis=0)
    dominant_classes = np.concatenate(class_batches, axis=0)
    return to_return, dominant_classes

In [None]:
def get_embedding_features(trainer, loader, PCA_value, mode='disc', number=None, random_state=None):
    '''
    Computes a 2-D t-SNE embedding of network features for visualisation.

    Pipeline: retrieve_features -> StandardScaler -> PCA -> t-SNE.

    Parameters
    trainer: forwarded to retrieve_features
    loader: forwarded to retrieve_features
    PCA_value: `n_components` passed to sklearn's PCA
    mode: which network to take features from, forwarded to retrieve_features
        (default 'disc', the value that used to be hard-coded)
    number: minimum number of images to use, forwarded to retrieve_features
        (default None = whole loader, the value that used to be hard-coded)
    random_state: seed forwarded to PCA and t-SNE; None (default) keeps the
        previous unseeded behaviour, an int makes the embedding reproducible

    Returns
    embedded_features ND-array B x 2
    classes: dominant class per image as returned by retrieve_features

    Raises
    ValueError if the loader yielded no features
    '''
    features, classes = retrieve_features(trainer, loader, number, mode=mode)
    if features is None:
        # retrieve_features returns None when no batch was processed; fail
        # here with a clear message rather than deep inside scikit-learn.
        raise ValueError("No features were retrieved from the loader")
    print ("Retrieved features")

    # Standardize features
    std_features = StandardScaler().fit_transform(features)

    # Reduce dimensionality with PCA before t-SNE
    pca = PCA(n_components=PCA_value, random_state=random_state)
    transformed_features = pca.fit_transform(std_features)

    print ("Applied PCA")
    # Apply t-SNE to the transformed features for visualisation
    tsne = TSNE(n_components=2, random_state=random_state)
    embedded_features = tsne.fit_transform(transformed_features)
    return embedded_features, classes


In [None]:
def visualize_data(embedded_features, classes, loader=None):
    '''
    Scatter-plots 2-D embedded features, one series per dominant class.

    Parameters
    embedded_features: ND-array B x 2 (columns 0 and 1 are plotted as x and y)
    classes: array-like of size B with the dominant class index of each row
    loader: data loader used to look up category names and the number of
        classes; defaults to the notebook-level `val_loader`
    '''
    if loader is None:
        # Backward-compatible default: the original always used val_loader.
        loader = val_loader

    # Guarantee element-wise comparison below even if a plain list is passed.
    classes = np.asarray(classes)

    plt.figure(1, figsize=(10, 10))
    plt.clf()

    animal_cat_names = get_category_name_array(loader)
    # The last class index is deliberately not plotted (range stops at
    # numClasses - 1) -- presumably a background/unlabeled class; TODO confirm.
    for i in range(loader.dataset.numClasses - 1):
        batch = embedded_features[classes == i]
        plt.scatter(batch[:, 0], batch[:, 1], label=animal_cat_names[i])
    plt.legend(loc='lower left', numpoints=1, ncol=1, fontsize=12)
    plt.show()

In [None]:
# Embed the validation set; PCA is asked for 150 components before t-SNE.
embedded_features, classes = get_embedding_features(trainer, val_loader, PCA_value=150)


In [None]:
# Plot the 2-D embedding computed above, one scatter series per class.
visualize_data(embedded_features=embedded_features, classes=classes)