# VAE: MNIST

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
import torchvision.datasets as dst
from torchvision.utils import save_image
import torchvision.datasets as sets
class VAE(nn.Module):
    """Fully-connected variational autoencoder for flattened images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over a ``z_dim``-dimensional latent code; the decoder maps a
    latent code back to per-pixel values in [0, 1] (sigmoid output).

    Args:
        img_dim: Number of pixels per image (default 28*28).
        z_dim: Dimensionality of the latent code (default 100).
    """

    def __init__(self, img_dim=28*28, z_dim=100):
        super(VAE, self).__init__()
        self.img_dim = img_dim
        self.z_dim = z_dim

        # Encoder: img_dim -> 280 -> 450 -> 650 -> 2*z_dim (mu and logvar).
        self.encoder1 = nn.Linear(self.img_dim, 280)
        self.encoder2 = nn.Linear(280, 450)
        self.encoder3 = nn.Linear(450, 650)
        self.encoder4 = nn.Linear(650, self.z_dim*2)

        # Decoder: z_dim -> 650 -> 450 -> 280 -> img_dim.
        self.decoder1 = nn.Linear(self.z_dim, 650)
        self.decoder2 = nn.Linear(650, 450)
        self.decoder3 = nn.Linear(450, 280)
        self.decoder4 = nn.Linear(280, self.img_dim)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(logvar / 2) with eps ~ N(0, I).

        ``randn_like`` creates the noise on the same device and with the
        same dtype as ``mu``, so this works on CPU as well as GPU.
        """
        eps = torch.randn_like(mu)
        return mu + eps * torch.exp(logvar / 2)

    def decode(self, z):
        """Map latent codes of shape (N, z_dim) to pixel values of shape (N, img_dim)."""
        x_hat = F.leaky_relu(self.decoder1(z), 0.2)
        x_hat = F.leaky_relu(self.decoder2(x_hat), 0.2)
        x_hat = F.leaky_relu(self.decoder3(x_hat), 0.2)
        return torch.sigmoid(self.decoder4(x_hat))

    def forward(self, x):
        """Encode, sample a latent code and decode.

        Args:
            x: Batch of images; any trailing shape whose size is a multiple
                of ``img_dim`` (e.g. (N, img_dim) or (N, 1, 28, 28)) is
                flattened to (N, img_dim).

        Returns:
            Tuple ``(x_hat, mu, logvar)`` with shapes (N, img_dim),
            (N, z_dim) and (N, z_dim).
        """
        # Accept image-shaped batches as well as already-flattened ones.
        x = x.reshape(-1, self.img_dim)
        # F.relu's second positional argument is ``inplace``, not a slope, so
        # the former F.relu(..., 0.2) silently meant "in-place ReLU". Use the
        # leaky ReLU with slope 0.2, matching the decoder.
        x = F.leaky_relu(self.encoder1(x), 0.2)
        x = F.leaky_relu(self.encoder2(x), 0.2)
        x = F.leaky_relu(self.encoder3(x), 0.2)
        # Split the 2*z_dim output into mean and log-variance halves.
        x = self.encoder4(x).view(-1, 2, self.z_dim)
        mu = x[:, 0, :]
        logvar = x[:, 1, :]
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)

        return x_hat, mu, logvar


def loss_func(x_hat, x, mu, logvar):
    """Return the negative ELBO summed over the batch.

    The reconstruction term is the summed binary cross-entropy between the
    decoder output ``x_hat`` and the target ``x``; the regularizer is the
    closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    """
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    recon_term = F.binary_cross_entropy(x_hat, x, reduction='sum')
    return recon_term + kl_term

def train(epoch):
    """Run one training epoch of the global ``vae`` over ``train_loader``.

    Reads the module-level ``vae``, ``optimizer`` and ``train_loader`` and
    scores batches with ``loss_func``. Every 100 batches it prints the
    current per-sample loss; at the end it prints the average loss per
    training example.

    Args:
        epoch: Epoch index, used only in the log messages.
    """
    vae.train()
    # Send batches to wherever the model lives instead of hard-coding
    # .cuda(), so the loop also runs on CPU-only machines.
    model_device = next(vae.parameters()).device
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(model_device)
        # Flatten each image to a vector for the linear encoder.
        data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_func(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            # loss_func sums over the batch; divide for a per-sample value.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

def test():
    """Evaluate the global ``vae`` on ``test_loader`` and print the mean loss.

    Runs without gradients and reports the summed ``loss_func`` value
    divided by the number of examples in the test dataset.
    """
    vae.eval()
    # Use the model's device rather than hard-coding .cuda().
    model_device = next(vae.parameters()).device
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(model_device)
            data = data.view(data.size(0), -1)
            recon_batch, mu, logvar = vae(data)
            test_loss += loss_func(recon_batch, data, mu, logvar).item()
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

def run(Epoch):
    """Alternate one training epoch and one test pass, ``Epoch`` times."""
    for epoch_index in range(Epoch):
        train(epoch_index)  # one pass over train_loader
        test()              # held-out loss after each epoch

In [None]:
# Hyperparameters for the VAE experiment.
EPOCH, BATCH_SIZE = 15, 100
n = 2  # DataLoader num_workers
LATENT_CODE_NUM = 32  # NOTE(review): not referenced below; VAE() uses its default z_dim=100
log_interval = 10  # NOTE(review): not referenced below; train() logs every 100 batches
# ToTensor yields floats in [0, 1], as binary_cross_entropy in loss_func requires.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Load and transform data
train_set = sets.MNIST('MNIST_data/', train=True, download=True, transform=transform)
test_set = sets.MNIST('MNIST_data/', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_set, num_workers=n, batch_size=BATCH_SIZE, drop_last=True, shuffle=True)
# Evaluate every test image in a fixed order: test() divides the summed loss
# by len(test_loader.dataset), so dropping a partial last batch would bias the
# reported average whenever BATCH_SIZE does not divide the test-set size.
test_loader = torch.utils.data.DataLoader(dataset=test_set, num_workers=n, batch_size=BATCH_SIZE, drop_last=False, shuffle=False)
# Place the model on the GPU when one is available, otherwise on the CPU.
vae = VAE().to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
optimizer = optim.Adam(vae.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
run(10)  # NOTE(review): EPOCH (15) is defined above, but 10 epochs are run here

====> Epoch: 0 Average loss: 183.5557
====> Test set loss: 153.1385
====> Epoch: 1 Average loss: 140.3206
====> Test set loss: 132.6157
====> Epoch: 2 Average loss: 127.9984
====> Test set loss: 123.4396
====> Epoch: 3 Average loss: 122.4233
====> Test set loss: 120.2433
====> Epoch: 4 Average loss: 119.6780
====> Test set loss: 117.4957
====> Epoch: 5 Average loss: 117.0864
====> Test set loss: 116.0838
====> Epoch: 6 Average loss: 115.3432
====> Test set loss: 114.6357
====> Epoch: 7 Average loss: 113.9937
====> Test set loss: 112.8071
====> Epoch: 8 Average loss: 112.8924
====> Test set loss: 112.1502
====> Epoch: 9 Average loss: 111.9243
====> Test set loss: 111.4149


In [None]:
# Decode 100 latent codes drawn from the N(0, I) prior and save them as a grid.
with torch.no_grad():
    # Match the model's latent size and device instead of hard-coding 100/.cuda().
    z = torch.randn(100, vae.z_dim, device=next(vae.parameters()).device)
    x_hat = F.leaky_relu(vae.decoder1(z), 0.2)
    x_hat = F.leaky_relu(vae.decoder2(x_hat), 0.2)
    x_hat = F.leaky_relu(vae.decoder3(x_hat), 0.2)
    sample = torch.sigmoid(vae.decoder4(x_hat))
    save_image(sample.view(100, 1, 28, 28), './sample_1' + '.png')

In [None]:
# Install the layer-summary helper (IPython shell escape; notebook-only syntax).
!pip install pytorch_model_summary
from pytorch_model_summary import summary
# Summarize a freshly constructed (untrained) VAE, not the trained `vae`;
# the dummy input is a single 784-pixel vector shaped (1, 1, 784).
print(summary(VAE().cuda(), torch.zeros((1, 1, 784)).cuda(), show_input=False))

Collecting pytorch_model_summary
  Downloading https://files.pythonhosted.org/packages/fe/45/01d67be55fe3683a9221ac956ba46d1ca32da7bf96029b8d1c7667b8a55c/pytorch_model_summary-0.1.2-py3-none-any.whl
Installing collected packages: pytorch-model-summary
Successfully installed pytorch-model-summary-0.1.2
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Linear-1         [1, 1, 280]         219,800         219,800
          Linear-2         [1, 1, 450]         126,450         126,450
          Linear-3         [1, 1, 650]         293,150         293,150
          Linear-4         [1, 1, 200]         130,200         130,200
          Linear-5            [1, 650]          65,650          65,650
          Linear-6            [1, 450]         292,950         292,950
          Linear-7            [1, 280]         126,280         126,280
          Linear-8            [1, 784]         220,304   

### Linearly interpolate

In [None]:
def image_gen_vae(x, position = "right"):
    """Decode latent code(s) ``x`` with the global ``vae`` and save a PNG.

    Args:
        x: Latent codes of shape (N, z_dim) on the model's device. With
            N > 1 the images are saved as one grid.
        position: Output file stem; the image is written to ``position + '.png'``.
    """
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        x_hat = F.leaky_relu(vae.decoder1(x), 0.2)
        x_hat = F.leaky_relu(vae.decoder2(x_hat), 0.2)
        x_hat = F.leaky_relu(vae.decoder3(x_hat), 0.2)
        # torch.sigmoid instead of the deprecated F.sigmoid.
        sample = torch.sigmoid(vae.decoder4(x_hat))
    # -1 keeps N == 1 working as before and also allows a batch of codes.
    save_image(sample.view(-1, 1, 28, 28), position + '.png')

In [None]:
# Pick two random latent points p1, p2 and decode 10 evenly spaced points
# strictly between them (the segment is split into 11 equal steps).
latent_device = next(vae.parameters()).device
p1 = torch.randn(1, vae.z_dim, device=latent_device)
p2 = torch.randn(1, vae.z_dim, device=latent_device)
step = (p2 - p1) / 11
for i in range(1, 11):
    # Compute each point directly rather than accumulating p = p + step,
    # which avoids floating-point drift.
    p = p1 + i * step
    # File names are vae_MNIST0.png ... vae_MNIST9.png (no stray space).
    image_gen_vae(p, position='vae_MNIST{}'.format(i - 1))  # decode then save picture



# GAN: MNIST


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Device configuration: use the GPU when one is present, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class Generator(nn.Module):
    """MLP generator mapping a noise vector to a flattened image in [-1, 1].

    Args:
        g_input_dim: Size of the input noise vector (default 100).
        g_output_dim: Number of output pixels (default 784 = 28*28).
    """

    def __init__(self, g_input_dim=100, g_output_dim=784):
        super(Generator, self).__init__()
        # Hidden widths grow 256 -> 512 -> 1024 before projecting to pixels.
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, g_output_dim)

    def forward(self, x):
        """Return tanh-squashed pixels for the noise batch ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        return torch.tanh(self.fc4(hidden))
    
class Discriminator(nn.Module):
    """MLP discriminator returning the probability that each input is real.

    Args:
        d_input_dim: Number of input pixels (default 784 = 28*28).
    """

    def __init__(self, d_input_dim=784):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)

    def forward(self, x):
        """Return one sigmoid score in (0, 1) per input row."""
        # F.dropout defaults to training=True, which kept dropping units even
        # after D.eval(); tie it to the module's train/eval mode instead.
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))


def D_train(x, batch_size=None, z_dim=100, image_dim=784):
    """Run one discriminator update on a real batch ``x`` plus fresh fakes.

    Uses the module-level ``D``, ``G``, ``criterion``, ``D_optimizer`` and
    ``device``.

    Args:
        x: Batch of real images; reshaped to (N, image_dim).
        batch_size: Number of real and fake samples. Defaults to the number
            of rows in ``x``, so a short final batch cannot mismatch the
            label tensors.
        z_dim: Size of the generator's noise vector.
        image_dim: Number of pixels per flattened image.

    Returns:
        The discriminator loss as a Python float.
    """
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # train discriminator on real
    x_real = x.view(-1, image_dim).to(device)
    if batch_size is None:
        batch_size = x_real.size(0)
    y_real = torch.ones(batch_size, 1, device=device)

    D_output = D(x_real)
    D_real_loss = criterion(D_output, y_real)

    # train discriminator on fake
    z = torch.randn(batch_size, z_dim, device=device)
    # detach(): this step only updates D, so don't backpropagate into G.
    x_fake = G(z).detach()
    y_fake = torch.zeros(batch_size, 1, device=device)

    D_output = D(x_fake)
    D_fake_loss = criterion(D_output, y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

def G_train(x, batch_size=100, z_dim=100):
    """Run one generator update.

    Fresh fakes are scored by ``D`` against target 1 ("real") with
    ``criterion``, so minimizing the loss pushes ``G`` to fool ``D``.
    ``x`` is accepted for symmetry with ``D_train`` but is not used.
    Uses the module-level ``G``, ``D``, ``criterion``, ``G_optimizer``
    and ``device``.

    Returns:
        The generator loss as a Python float.
    """
    #=======================Train the generator=======================#
    G.zero_grad()

    noise = torch.randn(batch_size, z_dim).to(device)
    targets = torch.ones(batch_size, 1).to(device)

    generator_loss = criterion(D(G(noise)), targets)

    # Gradients also flow into D here, but only G's optimizer steps;
    # D_train calls D.zero_grad() before its own update.
    generator_loss.backward()
    G_optimizer.step()

    return generator_loss.item()

def run(Epoch=200):
    """Train the GAN for ``Epoch`` epochs, printing mean D/G losses per epoch."""
    for epoch in range(Epoch):
        d_history = []
        g_history = []
        for x, _ in train_loader:
            # One discriminator step, then one generator step, per batch.
            d_history.append(D_train(x))
            g_history.append(G_train(x))

        mean_d = torch.mean(torch.FloatTensor(d_history))
        mean_g = torch.mean(torch.FloatTensor(g_history))
        print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
                (epoch+1), Epoch, mean_d, mean_g))


In [None]:
G = Generator().to(device)
D = Discriminator().to(device)
criterion = nn.BCELoss()

# Both networks use Adam with the same learning rate.
lr = 0.0002
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)

# MNIST Dataset: Normalize(0.5, 0.5) maps ToTensor's [0, 1] pixels to [-1, 1],
# the same range as the generator's tanh output.
# NOTE(review): this rebinds `transform`, `train_loader` and `test_loader`,
# which the VAE cells above also use.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, batch_size=100)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, batch_size=100)

run()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw
Processing...
Done!
[1/200]: loss_d: 1.069, loss_g: 2.795
[2/200]: loss_d: 1.007, loss_g: 1.765
[3/200]: loss_d: 0.850, loss_g: 2.061
[4/200]: loss_d: 0.737, loss_g: 2.309
[5/200]: loss_d: 0.594, loss_g: 2.617
[6/200]: loss_d: 0.581, loss_g: 2.654
[7/200]: loss_d: 0.659, loss_g: 2.339
[8/200]: loss_d: 0.756, loss_g: 2.045
[9/200]: loss_d: 0.656, loss_g: 2.156
[10/200]: loss_d: 0.676, loss_g: 2.220
[11/200]: loss_d: 0.738, loss_g: 1.950
[12/200]: loss_d: 0.748, loss_g: 2.053
[13/200]: loss_d: 0.674, loss_g: 2.329
[14/200]: loss_d: 0.701, loss_g: 2.221
[15/200]: loss_d: 0.706, loss_g: 2.194
[16/200]: loss_d: 0.750, loss_g: 2.037
[17/200]: loss_d: 0.835, loss_g: 1.839
[18/200]: loss_d: 0.879, loss_g: 1.684
[19/200]: loss_d: 0.908, loss_g: 1.622
[20/200]: loss_d: 0.910, loss_g: 1.615
[21/200]: loss_d: 0.943, loss_g: 1.537
[22/200]: loss_d: 0.932, loss_g: 1.557
[23/200]: loss_d: 0.911, loss_g: 1.597
[24/20

In [None]:
import matplotlib.pyplot as plt



In [None]:
# Layer summaries of freshly constructed (untrained) G and D, each fed one
# dummy sample: (1, 1, 100) noise for G and (1, 1, 784) pixels for D.
from pytorch_model_summary import summary
print(summary(Generator().cuda(), torch.zeros((1, 1, 100)).cuda(), show_input=False))
print(summary(Discriminator().cuda(), torch.zeros((1, 1, 784)).cuda(), show_input=False))

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Linear-1         [1, 1, 256]          25,856          25,856
          Linear-2         [1, 1, 512]         131,584         131,584
          Linear-3        [1, 1, 1024]         525,312         525,312
          Linear-4         [1, 1, 784]         803,600         803,600
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
-----------------------------------------------------------------------
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Linear-1        [1, 1, 1024]         803,840         803,840
          Linear-2         [1, 1, 512]         524,800         524,800
          Linear-3         [1, 1, 256]         131,328         131,328
          Linear-4           [1, 1, 1]             257             25

In [None]:
# Sample 100 images from the trained generator and save them as a grid.
with torch.no_grad():
    test_z = torch.randn(100, 100).to(device)
    generated = G(test_z)
    # G ends in tanh, so outputs lie in [-1, 1]; save_image clamps to [0, 1]
    # and would flatten the whole negative half to black. Rescale first.
    generated = (generated + 1) / 2
    save_image(generated.view(generated.size(0), 1, 28, 28), './sample_2' + '.png')

### Linearly interpolate

In [None]:
def image_gen_gans(x, position = "right"):
    """Generate image(s) from noise ``x`` with the global ``G`` and save a PNG.

    Args:
        x: Noise vectors of shape (N, g_input_dim). With N > 1 the images
            are saved as one grid.
        position: Output file stem; the image is written to ``position + '.png'``.
    """
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        generated = G(x.to(device))
    # Map G's tanh output from [-1, 1] to the [0, 1] range save_image expects.
    generated = (generated + 1) / 2
    save_image(generated.view(-1, 1, 28, 28), position + '.png')

In [None]:
# Pick two random noise vectors p1, p2 and generate 10 evenly spaced points
# strictly between them (the segment is split into 11 equal steps).
p1 = torch.randn(1, 100, device=device)
p2 = torch.randn(1, 100, device=device)
step = (p2 - p1) / 11
for i in range(1, 11):
    # Compute each point directly to avoid accumulated floating-point drift.
    p = p1 + i * step
    # File names are Gans_MNIST0.png ... Gans_MNIST9.png (no stray space).
    image_gen_gans(p, position='Gans_MNIST{}'.format(i - 1))  # generate then save picture

In [None]:
# Per-batch FID between real training images and their VAE reconstructions.
# NOTE(review): to_3, inception_network, np, sqrtm, iscomplexobj and trace are
# not defined in this section; they must be set up before running this cell.
with torch.no_grad():
    for batch_idx, (data, _) in enumerate(train_loader):
        print(data.shape)
        # The VAE's first layer is Linear(784, ...): flatten (N, 1, 28, 28)
        # to (N, 784). Passing the 4-D batch raised the RuntimeError above.
        recon, mu, log_var = vae(data.cuda().view(data.size(0), -1))
        data = transforms.Resize(96)(data)
        data = to_3(data.cuda())
        # Restore the image layout so Resize can act on the last two dims.
        img = transforms.Resize(96)(recon.view(-1, 1, 28, 28))
        img = to_3(img)

        activation_f = inception_network(img).cpu().numpy()
        activation_r = inception_network(data).cpu().numpy()

        ## get mean and sigma
        mu_f = np.mean(activation_f, axis=0, keepdims=True)
        print(mu_f.shape)
        sigma_f = np.cov(activation_f, rowvar=False)

        mu_r = np.mean(activation_r, axis=0, keepdims=True)
        sigma_r = np.cov(activation_r, rowvar=False)

        ssdiff = np.sum((mu_f - mu_r) ** 2.0)
        covmean = sqrtm(sigma_f.dot(sigma_r))
        # check and correct imaginary numbers from sqrt
        if iscomplexobj(covmean):
            covmean = covmean.real
        # calculate score
        fid = ssdiff + trace(sigma_f + sigma_r - 2.0 * covmean)
        print(fid)

torch.Size([100, 1, 28, 28])


RuntimeError: ignored

# Adjusted inception v3

In [None]:
from collections import namedtuple
import warnings
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
from typing import Callable, Any, Optional, Tuple, List


# Public names exported by this (vendored, adjusted) torchvision Inception v3 code.
__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']


# Download location of the weights loaded by inception_v3(pretrained=True).
model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}

# Pair returned by Inception3 in training mode when aux_logits is enabled.
# NOTE: in this adjusted copy the 'logits' field holds the pooled 2048-d
# features, because Inception3._forward never applies self.fc.
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
# Explicit field types (aux_logits may be None); presumably needed by TorchScript.
InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs


def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    r"""Build the adjusted Inception v3 network.

    Architecture from `"Rethinking the Inception Architecture for Computer Vision"
    <http://arxiv.org/abs/1512.00567>`_. In this copy, ``Inception3`` returns the
    pooled 2048-d features instead of class logits.

    .. note::
        The pretrained weights were trained on N x 3 x 299 x 299 inputs.

    Args:
        pretrained (bool): If True, load the ImageNet weights from ``model_urls``.
        progress (bool): If True, show a download progress bar on stderr.
        **kwargs: Forwarded to :class:`Inception3` (``aux_logits``,
            ``transform_input``, ``init_weights``, ...). With ``pretrained``,
            ``transform_input`` defaults to True.
    """
    if not pretrained:
        return Inception3(**kwargs)

    kwargs.setdefault('transform_input', True)
    # Build with the auxiliary head so the checkpoint's keys line up (the
    # checkpoint presumably includes AuxLogits weights), then drop the head
    # afterwards if the caller asked for aux_logits=False.
    keep_aux = kwargs.get('aux_logits', True)
    kwargs['aux_logits'] = True
    # The checkpoint overwrites every weight, so skip the slow random init.
    kwargs['init_weights'] = False

    net = Inception3(**kwargs)
    weights = load_state_dict_from_url(model_urls['inception_v3_google'], progress=progress)
    net.load_state_dict(weights)
    if not keep_aux:
        net.aux_logits = False
        net.AuxLogits = None
    return net


class Inception3(nn.Module):
    """Inception v3 backbone, adjusted to return pooled image features.

    ``_forward`` ends with adaptive average pooling, dropout and flatten and
    never applies ``self.fc``, so the main output is an (N, 2048) feature
    tensor rather than class logits. ``self.fc`` is still constructed,
    presumably so pretrained torchvision state dicts load with matching keys.

    Args:
        num_classes: Output size of ``self.fc`` and of the auxiliary head.
        aux_logits: If True, build the ``InceptionAux`` head fed by Mixed_6e.
        transform_input: If True, re-scale inputs in ``_transform_input``.
        inception_blocks: Seven block constructors in the order
            [conv, A, B, C, D, E, aux]; defaults to this module's classes.
        init_weights: If True (or None, which also warns), apply the
            truncated-normal / BatchNorm initialization below.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None
    ) -> None:
        super(Inception3, self).__init__()
        if inception_blocks is None:
            inception_blocks = [
                BasicConv2d, InceptionA, InceptionB, InceptionC,
                InceptionD, InceptionE, InceptionAux
            ]
        if init_weights is None:
            warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True
        assert len(inception_blocks) == 7
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()
        # Not used by _forward in this adjusted copy (see class docstring).
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            # Every Conv2d/Linear weight gets a normal truncated at +/-2 std.
            # The std is the module's own ``stddev`` attribute if it has one,
            # else 0.1. NOTE: InceptionAux sets stddev on its BasicConv2d
            # wrapper ``conv1`` (not itself a Conv2d), so only the 0.001 on
            # its Linear ``fc`` actually takes effect here.
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    import scipy.stats as stats
                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                    X = stats.truncnorm(-2, 2, scale=stddev)
                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                    values = values.view(m.weight.size())
                    with torch.no_grad():
                        m.weight.copy_(values)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        # When enabled, converts each of the 3 channels from ImageNet mean/std
        # normalization (means 0.485/0.456/0.406, stds 0.229/0.224/0.225) to
        # (pixel - 0.5) / 0.5 scaling. Assumes x has at least 3 channels.
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        # Shape comments below assume a 299x299 input; other input sizes
        # change the spatial dims, and AdaptiveAvgPool2d still yields 1x1.
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        # The auxiliary head only runs in training mode.
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        # NOTE: self.fc is not applied; callers receive the pooled features.
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        # Training mode with the aux head: (features, aux) namedtuple;
        # otherwise just the features tensor.
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        # In eager eval mode this returns a plain (N, 2048) tensor despite the
        # annotation; TorchScript always returns the namedtuple.
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)


class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and average-pool towers.

    Output has 64 + 64 + 96 + ``pool_features`` channels and the same
    spatial size as the input.
    """

    def __init__(
        self,
        in_channels: int,
        pool_features: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionA, self).__init__()
        make = BasicConv2d if conv_block is None else conv_block
        # Submodules are created in the original order (affects random init).
        self.branch1x1 = make(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = make(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = make(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = make(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make(96, 96, kernel_size=3, padding=1)

        self.branch_pool = make(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # Every tower reads the same input; forward() concatenates the results.
        tower_1x1 = self.branch1x1(x)
        tower_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        tower_3x3 = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        tower_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [tower_1x1, tower_5x5, tower_3x3, tower_pool]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionB(nn.Module):
    """Grid-reduction block: strided 3x3, strided double-3x3 and max-pool towers.

    Output has 384 + 96 + ``in_channels`` channels; each spatial dim becomes
    (H - 3) // 2 + 1.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionB, self).__init__()
        make = BasicConv2d if conv_block is None else conv_block
        self.branch3x3 = make(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = make(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        strided = self.branch3x3(x)
        doubled = x
        for conv in (self.branch3x3dbl_1, self.branch3x3dbl_2, self.branch3x3dbl_3):
            doubled = conv(doubled)
        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return [strided, doubled, pooled]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionC(nn.Module):
    """Inception block with factorized 7x7 convolutions (1x7 / 7x1 pairs).

    Output has 4 * 192 = 768 channels and the same spatial size as the input.
    ``channels_7x7`` sets the width inside the factorized towers.
    """

    def __init__(
        self,
        in_channels: int,
        channels_7x7: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionC, self).__init__()
        make = BasicConv2d if conv_block is None else conv_block
        self.branch1x1 = make(in_channels, 192, kernel_size=1)

        width = channels_7x7
        self.branch7x7_1 = make(in_channels, width, kernel_size=1)
        self.branch7x7_2 = make(width, width, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = make(width, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = make(in_channels, width, kernel_size=1)
        self.branch7x7dbl_2 = make(width, width, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = make(width, width, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = make(width, width, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = make(width, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = make(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        single = x
        for conv in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
            single = conv(single)

        double = x
        for conv in (self.branch7x7dbl_1, self.branch7x7dbl_2, self.branch7x7dbl_3,
                     self.branch7x7dbl_4, self.branch7x7dbl_5):
            double = conv(double)

        pooled = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [self.branch1x1(x), single, double, pooled]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionD(nn.Module):
    """Grid-reduction block with strided 3x3, strided 7x7x3 and max-pool towers.

    Output has 320 + 192 + ``in_channels`` channels; each spatial dim becomes
    (H - 3) // 2 + 1.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionD, self).__init__()
        make = BasicConv2d if conv_block is None else conv_block
        self.branch3x3_1 = make(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = make(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        tower_3x3 = self.branch3x3_2(self.branch3x3_1(x))
        tower_7x7 = x
        for conv in (self.branch7x7x3_1, self.branch7x7x3_2,
                     self.branch7x7x3_3, self.branch7x7x3_4):
            tower_7x7 = conv(tower_7x7)
        return [tower_3x3, tower_7x7, F.max_pool2d(x, kernel_size=3, stride=2)]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionE(nn.Module):
    """Inception block whose 3x3 towers split into parallel 1x3 / 3x1 convs.

    Output has 320 + 768 + 768 + 192 = 2048 channels and the same spatial
    size as the input.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionE, self).__init__()
        make = BasicConv2d if conv_block is None else conv_block
        self.branch1x1 = make(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = make(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = make(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = make(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = make(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = make(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = make(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = make(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = make(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        stem = self.branch3x3_1(x)
        tower_3x3 = torch.cat([self.branch3x3_2a(stem), self.branch3x3_2b(stem)], 1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        tower_dbl = torch.cat([self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1)

        pooled = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [self.branch1x1(x), tower_3x3, tower_dbl, pooled]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head used inside Inception-v3.

    Pipeline: 5x5 average pool with stride 3 -> 1x1 conv to 128 channels ->
    5x5 conv to 768 channels -> global average pool -> flatten -> linear
    layer producing ``num_classes`` logits.

    The ``stddev`` attributes attached to ``conv1`` and ``fc`` are not read in
    this class; presumably the enclosing model uses them for weight
    initialisation (confirm).

    Args:
        in_channels: channel count of the incoming feature map.
        num_classes: number of output logits.
        conv_block: factory called as ``conv_block(in_ch, out_ch, **kwargs)``;
            ``BasicConv2d`` when omitted.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block
        self.conv0 = make(in_channels, 128, kernel_size=1)
        self.conv1 = make(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        # For the usual N x 768 x 17 x 17 input: pooling gives N x 768 x 5 x 5,
        # conv0 gives N x 128 x 5 x 5, conv1 gives N x 768 x 1 x 1.
        pooled = F.avg_pool2d(x, kernel_size=5, stride=3)
        feats = self.conv1(self.conv0(pooled))
        # Global average pool so larger inputs still collapse to 1 x 1.
        feats = F.adaptive_avg_pool2d(feats, (1, 1))
        # N x 768 -> N x num_classes
        return self.fc(torch.flatten(feats, 1))


class BasicConv2d(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d`` (eps=1e-3) and ReLU.

    Any extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...)
    are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        **kwargs: Any
    ) -> None:
        super().__init__()
        # bias=False: the BatchNorm that follows has its own learnable shift
        # (affine=True by default), so a conv bias would be redundant.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        # ReLU runs in place on the freshly computed BatchNorm output.
        return F.relu(self.bn(self.conv(x)), inplace=True)

# FID test::MNIST

In [None]:
# Mini-batch size for the MNIST training loader.
BATCH_SIZE = 100
# Number of DataLoader worker processes.
n = 2
# NOTE(review): `transform` is not defined in this cell; it must be created in
# an earlier cell before this one runs.
data_train = sets.MNIST(
    'MNIST_data/',
    train=True,
    transform=transform,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=n,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz





HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw
Processing...
Done!


In [None]:
# Imports for computing the Frechet Inception Distance (FID)
import numpy as np
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy.random import random
from scipy.linalg import sqrtm




In [None]:
def FID_test(real, fake, network=None):
    """Compute the Frechet Inception Distance (FID) between two image batches.

    Both batches are passed through ``network``. Each row of the resulting
    activation matrix is one sample and each column one feature. The score is

        ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrt(S_r @ S_f))

    where ``mu`` / ``S`` are the per-feature mean / covariance taken across
    samples.

    Args:
        real: batch of real images accepted by ``network``.
        fake: batch of generated images accepted by ``network``.
        network: callable mapping a batch to a 2-D (samples x features)
            tensor. Defaults to the module-level ``inception_network``.

    Returns:
        The FID as a NumPy float; lower means the two sets are closer.
    """
    if network is None:
        network = inception_network  # model built in another notebook cell

    # Feature extraction needs no gradients; no_grad avoids building a graph.
    with torch.no_grad():
        activation_f = network(fake).detach().cpu().numpy()
        activation_r = network(real).detach().cpu().numpy()

    # Statistics over samples (axis 0 / rowvar=False): one mean and one
    # covariance row per feature, as FID requires. atleast_2d keeps the
    # single-feature case a proper 1x1 matrix for sqrtm.
    mu_f = activation_f.mean(axis=0)
    sigma_f = np.atleast_2d(np.cov(activation_f, rowvar=False))
    mu_r = activation_r.mean(axis=0)
    sigma_r = np.atleast_2d(np.cov(activation_r, rowvar=False))

    ssdiff = np.sum((mu_f - mu_r) ** 2.0)
    covmean = sqrtm(sigma_f.dot(sigma_r))
    # sqrtm may return tiny imaginary parts from numerical error; drop them.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma_f + sigma_r - 2.0 * covmean)
    return fid

### VAE::MNIST


In [None]:
# Frechet Inception Distance of VAE samples vs. the MNIST test set.
# NOTE(review): inception_v3() is built without pretrained weights here; FID is
# only meaningful with a trained Inception network -- confirm what the
# inception_v3 defined earlier in this notebook loads.
# .cuda(): the batches below live on the GPU. .eval(): BatchNorm uses running
# statistics instead of per-batch statistics.
inception_network = inception_v3().cuda().eval()
FID = []
with torch.no_grad():
    for data, _ in test_loader:
        # Real images: upsample to 96x96 and replicate the single gray channel
        # to the 3 channels the network expects (deterministic, unlike a
        # randomly initialised 1x1 conv).
        real = transforms.Resize(96)(data).repeat(1, 3, 1, 1).cuda()

        # Generated images: decode one latent vector per real image so the
        # last, possibly smaller, test batch is matched as well.
        z = torch.randn(data.size(0), vae.z_dim).cuda()
        x_hat = F.leaky_relu(vae.decoder1(z), 0.2)
        x_hat = F.leaky_relu(vae.decoder2(x_hat), 0.2)
        x_hat = F.leaky_relu(vae.decoder3(x_hat), 0.2)
        sample = torch.sigmoid(vae.decoder4(x_hat))
        fake = transforms.Resize(96)(sample.view(-1, 1, 28, 28)).repeat(1, 3, 1, 1)

        FID.append(FID_test(real, fake))
print('====> Avr FID: {}'.format(np.mean(FID)))
print('====> Std FID: {}'.format(np.std(FID)))

23.406281410678833
5.457478430393551


In [None]:
# Inception model shared by the FID cells below. eval() makes BatchNorm use its
# running statistics instead of per-batch statistics.
# NOTE(review): confirm inception_v3() loads trained weights; FID with a
# randomly initialised network is not meaningful.
inception_network = inception_v3().cuda().eval()



### GANs::MNIST


In [None]:
# Frechet Inception Distance of GAN samples vs. the MNIST test set.
# Relies on inception_network, G and device defined in earlier cells.
# eval(): BatchNorm in the feature extractor uses running statistics.
inception_network.eval()
FID = []
with torch.no_grad():
    for data, _ in test_loader:
        # Real images: upsample to 96x96 and replicate the gray channel to 3
        # channels (deterministic, unlike a randomly initialised 1x1 conv).
        real = transforms.Resize(96)(data).repeat(1, 3, 1, 1).cuda()

        # Generated images: one 100-d latent per real image.
        test_z = torch.randn(data.size(0), 100).to(device)
        generated = G(test_z)
        generated = generated.view(generated.size(0), 1, 28, 28)
        fake = transforms.Resize(96)(generated).repeat(1, 3, 1, 1).cuda()

        # FID_test returns a NumPy scalar; it cannot (and need not) go to CUDA.
        FID.append(FID_test(real, fake))
print('====> Avr FID: {}'.format(np.mean(FID)))
print('====> Std FID: {}'.format(np.std(FID)))

torch.Size([100, 1, 28, 28])


RuntimeError: ignored