In [11]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from skimage.io import imshow

from interpolation import two_point_interpolation, analogies,gaussian_interpolation,\
    n_point_interpolation,vicinity_sampling
%matplotlib inline

In [12]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent code to an ``nc`` x 64 x 64 image.

    Args:
        ngpu: number of GPUs; stored on the module but not used by ``forward``.
        nz: length of the latent vector (input channels of the first layer).
        ngf: base feature-map width; the layers use 8x, 4x, 2x and 1x of it.
        nc: number of output image channels.

    The defaults equal this notebook's configuration values (nz=100, ngf=64,
    nc=3), which is also the architecture printed after loading the checkpoint.
    So ``Generator(ngpu=...)`` builds the same network as before without
    depending on module-level globals defined in another cell.
    """

    def __init__(self, ngpu, nz=100, ngf=64, nc=3):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        # Layer order and the attribute name `main` are unchanged, so the
        # state_dict keys ("main.0.weight", ...) still match saved checkpoints.
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Run the generator stack; outputs lie in [-1, 1] because of the final Tanh.

        ``input`` is presumably shaped (batch, nz, 1, 1), as in the sampling
        code that feeds ``torch.randn(1, 100, 1, 1)``.
        """
        return self.main(input)

In [13]:
# Generator configuration used throughout this notebook.
ngpu = 1  # GPU count handed to Generator (stored on the module only)
# nz: latent vector length; ngf: generator base feature width;
# ndf: not used by any visible cell (presumably the discriminator width);
# nc: number of image channels produced by the generator.
nz, ngf, ndf, nc = 100, 64, 64, 3

In [14]:
# Use the first CUDA device when present; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [15]:
# Build the generator, then move its parameters onto the selected device
# (Module.to returns the same module object).
netG = Generator(ngpu=ngpu)
netG = netG.to(device)

In [16]:
# Pretrained generator weights, relative to this notebook's working directory.
CHECKPOINT_PATH = '../bmml-ot/DCGAN_icon_meta/2/netG_epoch_24.pth'
# map_location remaps the stored tensors onto `device`. Without it, a checkpoint
# saved from a CUDA session cannot be loaded on a CPU-only machine, even though
# `device` deliberately falls back to the CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
netG.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

In [17]:
# Switch to inference mode so BatchNorm uses its running statistics;
# train(False) is what eval() does and likewise returns the module for display.
netG.train(False)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [18]:
from inception_score import inception_score

In [27]:
from IPython.core.debugger import set_trace

In [36]:
def _decode_and_score(latents, inception_batch_size):
    """Decode latent codes with ``netG.main`` and return their Inception Score.

    Args:
        latents: non-empty list of latent tensors, each fed to ``netG.main``
            on its own and concatenated along the batch dimension.
        inception_batch_size: ``batch_size`` forwarded to ``inception_score``.

    Returns:
        Whatever ``inception_score`` returns for the decoded images.

    Raises:
        ValueError: if ``latents`` is empty (``torch.cat`` would fail anyway).
    """
    if not latents:
        raise ValueError("need at least one latent code to score")
    # Pure inference: without no_grad every generator pass keeps its autograd
    # graph alive until the concatenated image batch is released.
    with torch.no_grad():
        images = torch.cat([netG.main(z) for z in latents])
    # Tie inception_score's CUDA flag to where the images actually live; a
    # hard-coded cuda=True breaks the CPU fallback chosen for `device`.
    return inception_score(images, cuda=(device.type == "cuda"),
                           batch_size=inception_batch_size, resize=True, splits=1)


def _two_point_latents(N, t, **interp_kwargs):
    """Interpolate N random (src, dst) latent pairs at weight ``t``.

    Extra keyword arguments (e.g. ``do_scale``) go to ``two_point_interpolation``.
    """
    src_set = [torch.randn(1, nz, 1, 1, device=device) for _ in range(N)]
    dst_set = [torch.randn(1, nz, 1, 1, device=device) for _ in range(N)]
    return [two_point_interpolation(src, dst, t=t, **interp_kwargs)
            for src, dst in zip(src_set, dst_set)]


def _four_point_latents(N, do_scale):
    """Interpolate N sets of four random latent vectors with equal weights 1/4."""
    weights = torch.tensor([1/4, 1/4, 1/4, 1/4], device=device)
    # All N first corners are drawn, then all N second corners, and so on
    # (same sampling order as the original per-corner lists).
    corner_sets = [[torch.randn(nz, device=device) for _ in range(N)] for _ in range(4)]
    return [n_point_interpolation(torch.stack(corners), weights, do_scale=do_scale)
            for corners in zip(*corner_sets)]


def IS_2point_interpolation_ot(N=1000, t=0.5):
    """Inception Score of N two-point interpolations using the helper's default scaling."""
    return _decode_and_score(_two_point_latents(N, t), inception_batch_size=2)


def IS_2point_interpolation_linear(N=1000, t=0.5):
    """Inception Score of N plain linear two-point interpolations (do_scale=False)."""
    return _decode_and_score(_two_point_latents(N, t, do_scale=False), inception_batch_size=2)


def IS_4point_interpolation_ot(N=1000):
    """Inception Score of N equal-weight four-point interpolations with do_scale=True."""
    return _decode_and_score(_four_point_latents(N, do_scale=True), inception_batch_size=1)


def IS_4point_interpolation_linear(N=1000):
    """Inception Score of N equal-weight four-point interpolations with do_scale=False."""
    return _decode_and_score(_four_point_latents(N, do_scale=False), inception_batch_size=1)

In [22]:
# NOTE(review): 5 samples with splits=1 is far too few for a meaningful
# Inception Score; the reported std is 0.0 (presumably a single split).
for caption, score_fn in [("2 point interpolation (matched) : ", IS_2point_interpolation_ot),
                          ("2 point interpolation (linear ) : ", IS_2point_interpolation_linear)]:
    print(caption, score_fn(5))

  "See the documentation of nn.Upsample for details.".format(mode))
  return F.softmax(x).data.cpu().numpy()


2 point interpolation (matched) :  (1.9473327521067487, 0.0)
2 point interpolation (linear ) :  (1.8282965813802927, 0.0)


In [37]:
# NOTE(review): as with the 2-point comparison, N=5 is too small for the
# matched-vs-linear difference to be distinguishable from sampling noise.
for caption, score_fn in [("4 point interpolation (matched) : ", IS_4point_interpolation_ot),
                          ("4 point interpolation (linear ) : ", IS_4point_interpolation_linear)]:
    print(caption, score_fn(5))

  "See the documentation of nn.Upsample for details.".format(mode))
  return F.softmax(x).data.cpu().numpy()


4 point interpolation (matched) :  (1.6925921782796927, 0.0)
4 point interpolation (linear ) :  (1.758471531175132, 0.0)
