# Import

In [1]:
import os
import functools
import math
import numpy as np
from tqdm import tqdm, trange


import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
import torchvision

# Import stuff
import utils
import losses
import layers as layer
#import train_fns
import train_fns
from sync_batchnorm import patch_replication_callback

import os
import io

# Get config

In [2]:
weights_name   = 'weights'
logs_name      = 'logs'
samples_name   = ''

model_name  = 'BigGAN_MPCC'

model_path  = '%s/%s'%(weights_name, model_name)
logs_path   = '%s/%s'%(logs_name,    model_name)

# The training run's metalog.txt contains the config dict in printed form.
config_path = '%s/metalog.txt'%logs_path

device = 'cuda'

import ast

# Read the log once and close the handle right away (the original left it open).
with open(config_path, 'r') as config_file:
    all_file = config_file.read()

# Extract the text between the first '{' and the first '}'.
# NOTE(review): find('}') stops at the first closing brace, so a nested dict
# inside the logged config would be truncated — confirm the config is flat.
fs1 = all_file.find('{')
fs2 = all_file.find('}')
if fs1 == -1 or fs2 < fs1:
    raise ValueError('No config dict found in %s' % config_path)
config = all_file[fs1:fs2+1]

# The logged dict contains activation-module reprs ("ReLU()") that are not
# Python literals; drop those entries so ast.literal_eval can parse the rest.
config = config.replace(", 'G_activation': ReLU()" , "")
config = config.replace(", 'D_activation': ReLU()" , "")
config = ast.literal_eval(config)

# Overrides specific to this evaluation notebook.
config['samples_root'] = 'samples_test'
config['weights_root'] = weights_name
config['concat']       = True
config['model']        = 'BigGAN_MPCC'

# Load models and weights

In [3]:
# Import the model-definition module named in the config ('BigGAN_MPCC' above).
model = __import__(config['model'])
utils.seed_rng(config['seed'])
# Prepare root folders if necessary
utils.prepare_root(config)

G = model.Generator(**config).to(device)
D = model.Discriminator(**config).to(device)
# The encoder is optional. Define E in both cases so that building G_E (and
# load_weights) never fails with a NameError when config['is_encoder'] is False.
# NOTE(review): confirm model.G_E tolerates E=None for encoder-less configs.
E = model.Encoder(**{**config, 'D': D}).to(device) if config['is_encoder'] else None
Prior = layer.Prior(**config).to(device)
GE = model.G_E(G, E, Prior)

# Restore saved weights. The positional None / '' presumably stand for D and its
# state dict (i.e. they are not restored here) — NOTE(review): confirm the
# parameter order of utils.load_weights. When config['ema'] is set, G is also
# passed as the EMA target; a Prior is only loaded for non-default prior types.
utils.load_weights(G, None, '',
                   config['weights_root'], model_name,
                   config['load_weights'] or None,
                   G if config['ema'] else None,
                   E=E,
                   Prior=Prior if config['prior_type'] != 'default' else None)

Param count for Gs initialized parameters: 9360131
Param count for Ds initialized parameters: 9456769
16
2
Param count for Ds initialized parameters: 17615872
Loading weights from weights/BigGAN_MPCC...


# Sample

In [4]:
# Presumably re-estimates the generator's batch-norm "standing" statistics
# before sampling — NOTE(review): confirm in utils.accumulate_standing_stats.
accumulate_stats = True
if accumulate_stats:
    utils.accumulate_standing_stats(
        G, Prior, config['n_classes'], config['num_standing_accumulations'])

# Sampling helper with the generator, prior and config already bound.
sample = functools.partial(utils.sample, config=config, G=G, Prior=Prior)

### Sample sheet

In [5]:
# Grid of generated samples: 10 classes per sheet, 7 samples per class.
# samples_root / experiment_name / folder_number presumably pick the output
# location — NOTE(review): confirm in utils.sample_sheet.
# Alternative: classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']]
utils.sample_sheet(
    G,
    Prior,
    classes_per_sheet=10,
    num_classes=config['n_classes'],
    samples_per_class=7,
    parallel=config['parallel'],
    samples_root=config['samples_root'],
    experiment_name=model_name,
    folder_number=4,
    transpose=True,
    num_rep=1,
)

# Reconstruction

In [6]:
# Evaluation loader: augmentation switched off, batch size equal to
# batch_size * num_D_steps * num_D_accumulations from the config.
D_batch_size = config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations']
config_aux = dict(config, augment=False)
dataloader_noaug = utils.get_data_loaders(**dict(config_aux, batch_size=D_batch_size))

Using dataset root location data/cifar
Data will not be augmented...
Files already downloaded and verified
(50000, 3072)
Files already downloaded and verified


### Obtain reconstruction sheet

In [7]:
# Reconstruction sheet from the encoder/generator pair GE, using the
# non-augmented loader: 6 classes per sheet, 4 samples per class.
# Alternatives: classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
#               samples_per_class=10
utils.reconstruction_sheet(
    GE,
    classes_per_sheet=6,
    num_classes=config['n_classes'],
    samples_per_class=4,
    parallel=config['parallel'],
    samples_root=config['samples_root'],
    experiment_name=model_name,
    folder_number=4,
    dataloader=dataloader_noaug,
    device=device,
    D_fp16=config['D_fp16'],
    config=config,
)

# Testing clustering accuracy

In [8]:
# is_not_rec=False presumably enables the reconstruction-related output
# (error_rec) — NOTE(review): confirm in train_fns.test_accuracy.
# This mutates config, so the flag stays set for the cells below.
config['is_not_rec'] = False
test_acc, _, error_rec = train_fns.test_accuracy(
    GE, dataloader_noaug, device, config['D_fp16'], config)
print("Clustering accuracy ", test_acc)

Clustering accuracy  0.68755


# Testing accuracy "as usual" (each predicted cluster mapped to its majority true label)

In [9]:
def obtain_cluster_transformation(Y_pred, Y):
    """Map each predicted cluster id to the true label it co-occurs with most.

    Builds a contingency matrix ``w[pred, true]`` (counts of samples) and
    takes the row-wise argmax. This is a many-to-one, majority-vote
    assignment: several clusters may map to the same label, unlike a
    one-to-one (Hungarian) matching.

    Args:
        Y_pred: 1-D sequence/array of non-negative integer cluster ids
            (float values are truncated to int).
        Y: 1-D sequence/array of non-negative integer labels, same length
            as ``Y_pred``.

    Returns:
        np.ndarray ``transf`` of length ``max(max(Y_pred), max(Y)) + 1``.
        ``transf[c]`` is the label assigned to cluster ``c``. Cluster ids
        that never occur map to 0. Empty inputs give an empty array.

    Raises:
        ValueError: if ``Y_pred`` and ``Y`` have different lengths.
    """
    y_pred = np.asarray(Y_pred).astype(np.int64).ravel()
    y_true = np.asarray(Y).astype(np.int64).ravel()
    if y_pred.size != y_true.size:
        raise ValueError('Y_pred and Y must have the same length (%d != %d)'
                         % (y_pred.size, y_true.size))
    if y_pred.size == 0:
        return np.zeros(0, dtype=np.intp)

    # Square matrix large enough to index every cluster id and every label.
    n = int(max(y_pred.max(), y_true.max())) + 1
    w = np.zeros((n, n))
    # Unbuffered scatter-add so repeated (pred, true) pairs are all counted.
    np.add.at(w, (y_pred, y_true), 1)
    return w.argmax(1)

# Same evaluation, but request the raw labels and predicted clusters, then
# score them here with the majority-vote cluster-to-label mapping.
total_y, total_y_pred, mse_norm = train_fns.test_accuracy(
    GE, dataloader_noaug, device, config['D_fp16'], config, obtain_y=True)
transf = obtain_cluster_transformation(total_y_pred, total_y)
accuracy = np.mean(transf[total_y_pred] == total_y)
print("Clustering accuracy ", accuracy)

Clustering accuracy  0.6875666666666667
