In [1]:
import sys
sys.path.append("../")

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import pandas as pd
from data_sampler import data_sampler
from data_sampler import SequentialSampler, BatchSampler
import networks as net
import numpy as np

In [3]:
from inception_model import InceptionSham

In [4]:
input_nc = 1
output_nc = 1
discr_filters = 8
max_power = 8
n_layers = 2
norm_lay = nn.BatchNorm2d
start_size = 28
gen_filters = 8
dropout = 0.5
n_blocks = 2
batch_size = 5
alpha = 10

In [5]:
gener_a = net.ResnetGenerator(
    input_nc = input_nc,
    output_nc = output_nc,
    gen_filters = gen_filters,
    norm_lay = norm_lay,
    dropout = dropout,
    n_blocks = n_blocks
)

gener_b = net.ResnetGenerator(
    input_nc = input_nc,
    output_nc = output_nc,
    gen_filters = gen_filters,
    norm_lay = norm_lay,
    dropout = dropout,
    n_blocks = n_blocks
)

gener_a_state = torch.load('gener_a_tmp.pth', map_location='cpu')
gener_a.load_state_dict(gener_a_state)

gener_b_state = torch.load('gener_b_tmp.pth', map_location='cpu')
gener_b.load_state_dict(gener_b_state)

In [6]:
inception_model = InceptionSham(num_classes=10, input_nc=1, dropout=0.5)
inception_state = torch.load("../../../inception_sham_state.pth", map_location="cpu")
inception_model.load_state_dict(inception_state)

In [7]:
test_pull = pd.read_csv("../../../data/fashion_mnisit/test_pull.csv", header=None).values
test_top = pd.read_csv("../../../data/fashion_mnisit/test_top.csv", header=None).values

In [64]:
kek = inception_model(gener_a(data_sampler(1, test_pull, test_top)[0].view(-1, 1, 28, 28)))

In [85]:
kek = F.softmax(kek, dim=1)

In [86]:
lel = F.softmax(lel, dim=1)

In [87]:
F.kl_div(Variable(kek.data), Variable(lel.data)).data[0]

-0.13242989778518677

In [97]:
def inception_score(gener_a, gener_b, inception, data_a, data_b, batch_size=10, splits=10):
    n_images = min(data_a.shape[0], data_b.shape[0])
    
    sampler = SequentialSampler(n_images)
    batch_sampler = BatchSampler(sampler, batch_size)
    
    pred_a = None
    pred_b = None
    
    for i in batch_sampler:
        batch_a, batch_b = data_sampler(5, data_a[i], data_b[i])
    
        batch_a = batch_a.view(-1, 1, 28, 28).float()
        batch_b = batch_b.view(-1, 1, 28, 28).float()
        
        if pred_a is None:
            pred_a = F.softmax(inception(gener_a(batch_a)), dim=1)
            pred_b = F.softmax(inception(gener_b(batch_b)), dim=1)
        else:
            pred_a = torch.cat((pred_a,
                               F.softmax(inception(gener_a(batch_a)), dim=1)),
                               dim=0)
            pred_b = torch.cat((pred_b,
                   F.softmax(inception(gener_b(batch_b)), dim=1)),
                   dim=0)

    obs_kl_a = []
    obs_kl_b = []
    
    for k in range(splits):
        part_a = pred_a[k * (n_images // splits): (k+1) * (n_images // splits), :]
        part_b = pred_b[k * (n_images // splits): (k+1) * (n_images // splits), :]
        
        py_a = torch.mean(pred_a, dim=0)
        py_b = torch.mean(pred_b, dim=0)
        
        scores_a = []
        scores_b = []
        for row in range(part_a.shape[0]):
            pyx_a = part_a[row]
            pyx_b = part_b[row]
            
            scores_a.append(F.kl_div(Variable(pyx_a.data), Variable(py_a.data)).data[0])
            scores_b.append(F.kl_div(Variable(pyx_b.data), Variable(py_b.data)).data[0])
            
        obs_kl_a.append(np.exp(np.mean(scores_a)))
        obs_kl_b.append(np.exp(np.mean(scores_b)))

    return np.mean(obs_kl_a), np.mean(obs_kl_b)

In [98]:
inception_score(gener_a, gener_b, inception_model, test_pull[:10], test_top[:10], batch_size=5, splits=5)

(0.79197831842393329, 0.80116869130256951)