In [1]:
from os import path
from typing import Dict

import torch
from torch import nn, optim
import torch.utils.data as tdata

import torchvision
import torchvision.transforms as tforms
import torchvision.utils as vutils
from torchvision.datasets import ImageFolder

from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

from PIL import Image
import numpy as np

In [2]:
batch_size = 256
img_size = 256
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [3]:
def load_data(path: str, batch_size=64, num_workers=2):
    data_transform = tforms.Compose([
        tforms.Resize(img_size),
        tforms.CenterCrop(img_size),
        tforms.ToTensor(),
        tforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = ImageFolder(root=path, transform=data_transform)
    return tdata.DataLoader(
        dataset,
        shuffle=False,
        drop_last=False,
        batch_size=batch_size,
        num_workers=num_workers
    )

In [4]:
prerec_data = load_data('./prerec_test', batch_size=128)
live_data = load_data('./live_test', batch_size=128)
real_data = load_data('./real_test', batch_size=128)

# Architecture Definition and helper functions

The following definitions are copied over from the training code to allow us to load the models

In [5]:
def set_device(model: nn.Module, device, ngpu):
    """Transfers the models onto the specified device.
        
    Params:
        model -- which model to transfer
        device -- which device to use
        ngpus -- how many gpus to use, if using cuda device
    Returns:
        The transferred model
    """
    model.to(device)
    if device.type == 'cuda' and ngpu > 1:
        model = nn.DataParallel(model, list(range(ngpu)))
    return model

In [6]:
def load_network(model: nn.Module, dir_name, name):
    """Loads a state dictionary from the files in given path.
    
    Params:
        model -- the model to load
        dir_name -- directory from which to load the state
        name -- name that is prefixed on the files
    """
    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(torch.load(path.join(dir_name, f'{name}.pth')))
    else:
        model.load_state_dict(torch.load(path.join(dir_name, f'{name}.pth')))

In [7]:
def conv_norm_relu(n_in, n_out, kernel_size, stride, padding=0, padding_mode='reflect', transpose=False, **kwargs):
    """Standard convolution -> instance norm -> relu block.

    Params:
        n_in -- number of input channels
        n_out -- number of filters/output channels
        kernel_size -- passed to Conv2d
        stride -- passed to Conv2d
        padding -- passed to Conv2d
        padding_mode -- passed to Conv2d
        transpose -- whether to use a regular or transposed convolution layer
        kwargs -- other args passed to Conv2d
    Returns:
        A list containing a convolution, instance norm, and ReLU activation
    """
    if transpose:
        conv = nn.ConvTranspose2d(n_in, n_out, kernel_size, stride, padding, padding_mode=padding_mode, bias=True, **kwargs)
    else:
        conv = nn.Conv2d(n_in, n_out, kernel_size, stride, padding, padding_mode=padding_mode, bias=True, **kwargs)
    return [conv, nn.InstanceNorm2d(n_out), nn.ReLU(True)]

In [8]:
def conv_norm_leakyrelu(n_in, n_out, slope=0.2, **kwargs):
    """Standard convolution -> instance norm -> leaky relu block.
    
    Params:
        n_in -- number of input channels
        n_out -- number of filters/output channels
        slope -- slope of the leaky ReLU layer
        kwargs -- other args passed to the convolution layer
    Returns:
        A list containing a convolution, instance norm, and LeakyReLU activation
    """
    conv = nn.Conv2d(n_in, n_out, **kwargs)
    return [conv, nn.InstanceNorm2d(n_out), nn.LeakyReLU(slope, True)]

In [9]:
class ResidualBlock(nn.Module):
    """Defines a residual block with 2 3x3 conv-norm-relu layers."""
    
    def __init__(self, k, p=None):
        """Initialize a residual block.
        
        Params:
            k -- number of input and output channels
            p -- dropout rate (optional)
        """
        super().__init__()
        model = conv_norm_relu(k, k, 3, 1, 1)
        model.append(nn.Conv2d(k, k, 3, 1, 1, padding_mode='reflect', bias=True))
        model.append(nn.InstanceNorm2d(k))
        if p is not None:
            model.append(nn.Dropout(p, inplace=True))
        self.block = nn.Sequential(*model)
    
    def forward(self, input):
        residual = self.block(input)
        residual += input  # apply skip-connection
        return residual

In [10]:
class Encoder(nn.Module):
    """Convolutional-Resnet style encoder."""

    def __init__(self, n_head, n_res, in_channel, n_filter):
        """Initialize an encoder.
        
        Params:
            n_head -- number of downsampling convolution blocks at the head
            n_res -- number of residual blocks in the middle
            in_channel -- number of channels in the input
            n_filter -- number of filters to start with; doubles for each block in the head
        """
        super().__init__()
        # initial convolution
        front = conv_norm_relu(in_channel, n_filter, 7, 1, 3)
        # downsampling convolution blocks
        for _ in range(n_head):
            front += conv_norm_relu(n_filter, 2 * n_filter, 4, 2, 1)
            n_filter *= 2
        # middle residual blocks
        front += [ResidualBlock(n_filter) for _ in range(n_res)]
        self.model = nn.Sequential(*front)
        self.out_channel = n_filter  # record the number of filters before the adjustment
    
    def forward(self, input):
        return self.model(input)

In [11]:
class LatentAE(nn.Module):
    """Shared latent space VAE. Contains both encoder and decoder."""

    def __init__(self, n_res, in_channel, p=None):
        """Initialize a VAE.

        Params:
            n_res -- number of residual blocks for both the encoder and decoder
            n_channels -- number of channels in the input
            p -- dropout probability used (optional)
        """
        super().__init__()
        self.enc = nn.Sequential(*[ResidualBlock(in_channel) for _ in range(n_res)])
        self.dec = nn.Sequential(*[ResidualBlock(in_channel, p) for _ in range(n_res)])
        self.is_dist = False
    
    def forward(self, input):
        latent_mean = self.enc(input)
        sample = latent_mean + torch.randn(latent_mean.size(), device=latent_mean.device)
        return latent_mean, self.dec(sample)
    

In [12]:
class Decoder(nn.Module):
    """Convolutional-Resnet style decoder."""

    def __init__(self, n_tail, n_res, in_channel, out_channel, p=None):
        """Initialize a decoder.

        Params:
            n_tail -- number of upsampling convolution blocks at the tail
            n_res -- number of residual blocks in the middle
            in_channel -- number of channels in the input
            out_channel -- desired number of channels in the output
            p -- dropout probability used (optional)
        """
        super().__init__()
        # residual blocks in the middle
        model = [ResidualBlock(in_channel, p) for _ in range(n_res)]
        # upsampling transposed convolution blocks
        for _ in range(n_tail):
            model += conv_norm_relu(in_channel, in_channel // 2, 4, 2, 1, padding_mode='zeros', transpose=True)
            in_channel //= 2
        # final convolution (use tanh)
        model += [nn.Conv2d(in_channel, out_channel, 7, 1, 3, padding_mode='reflect', bias=True), nn.Tanh()]
        self.model = nn.Sequential(*model)
    
    def forward(self, input):
        return self.model(input)

In [13]:
class Translator(nn.Module):
    """Wraps the models necessary to perform translation between two domains."""

    def __init__(self, d1: str, d2: str, n_channel, n_conv, n_res, n_shared, n_filter, p=None):
        """Initializes two VAEs with shared inner weights.

        Params:
            d1 -- name of first domain
            d2 -- name of second domain
            n_channel -- number of input channels of an image
            n_conv -- number of outermost conv/conv-tranpose blocks in the VAE
            n_res -- number of residual blocks in the middle layers of the VAE
            n_shared -- number of residual blocks that are shared
            n_filter -- number of filters to start with in the encoder
            p -- dropout probability in the decoders (optional)
        """
        super().__init__()
        d1_encoder = Encoder(n_conv, n_res, n_channel, n_filter)
        d2_encoder = Encoder(n_conv, n_res, n_channel, n_filter)
        d1_decoder = Decoder(n_conv, n_res, d1_encoder.out_channel, n_channel, p)
        d2_decoder = Decoder(n_conv, n_res, d2_encoder.out_channel, n_channel, p)
        self.encoders = nn.ModuleDict({d1: d1_encoder, d2: d2_encoder})
        self.decoders = nn.ModuleDict({d1: d1_decoder, d2: d2_decoder})
        self.shared = LatentAE(n_shared, d1_encoder.out_channel, p)

    def translate(self, input, source: str, target: str, keep_mean=True, requires_grad=True):
        """Translates a batch of images from the source domain to the target domain.

        Params:
            input -- input image (batch)
            source -- source domain
            target -- target domain (of translation)
            keep_mean -- whether to also return the latent space mean
            requires_grad -- whether to track the computation graph
        Returns:
            The translated image
        """
        if requires_grad:
            l_mean, encoding = self.shared(self.encoders[source](input))
            output = self.decoders[target](encoding)
        else:
            with torch.no_grad():
                l_mean, encoding = self.shared(self.encoders[source](input))
                output = self.decoders[target](encoding)
        if keep_mean:
            return l_mean, output
        else:
            return output

# Metric Evaluation Functions

In [14]:
def compute_fid(model: nn.Module, source_data: tdata.DataLoader, target_data: tdata.DataLoader):
    """Computes the FID for a model.
    
    Params:
        model -- a network translating the source domain to target domain
        target_data -- loads data from the target domain ('real' data)
        source_data -- loads data from the source domain (is translated into 'fake' data)
    Returns:
        FID value
    """
    fid = FrechetInceptionDistance()
    
    # feed fake/translated data in
    for source_batch in source_data:
        source_batch = source_batch[0].to(device)
        with torch.no_grad():
            fake_batch = model(source_batch)
            # need to convert image to correct format
            fid.update(fake_batch.mul(0.5).add_(0.5).mul(255).add_(0.5).clamp_(0, 255).to('cpu', dtype=torch.uint8), real=False)
    
    # feed real data in
    for target_batch in target_data:
        with torch.no_grad():
            # need to convert image to correct format
            img = target_batch[0].mul(0.5).add_(0.5).mul(255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8)
            fid.update(img, real=True)
    
    return fid.compute().item()

In [15]:
def compute_is(model: nn.Module, source_data: tdata.DataLoader):
    """Computes the IS for a model.
    
    Params:
        model -- a network translating the source domain to target domain
        source_data -- loads data from the source domain (is translated into 'fake' data)
    Returns:
        IS mean and stddev
    """
    inception = InceptionScore()
    
    # feed generated images
    for source_batch in source_data:
        source_batch = source_batch[0].to(device)
        with torch.no_grad():
            fake_batch = model(source_batch)
            # need to convert image to correct format
            inception.update(fake_batch.mul(0.5).add_(0.5).mul(255).add_(0.5).clamp_(0, 255).to('cpu', dtype=torch.uint8))
    
    is_mean, is_stddev = inception.compute()
    return is_mean.item(), is_stddev.item()

# Load Models

In [16]:
# architecture parameters (see DomainModels constructor for details)
n_channel = 3
n_conv = 3
n_res = 3
n_shared = 1  # number of shared layers, in both encoder and decoder
n_gen_filter = 64
n_scale = 1  # number of scales to apply MS-disc at
n_layer = 5  # number of layers for the discriminator
n_disc_filter = 64
p = 0.25

In [17]:
# setup translators between domains
AB_gen = Translator('A', 'B', n_channel, n_conv, n_res, n_shared, n_gen_filter, p)
BC_gen = Translator('B', 'C', n_channel, n_conv, n_res, n_shared, n_gen_filter, p)

load_network(AB_gen, 'checkpoints', 'AB_20')
load_network(BC_gen, 'checkpoints', 'BC_20')
AB_gen = set_device(AB_gen, device, 1)
BC_gen = set_device(BC_gen, device, 1)

# Evaluate models between prerecorded pizza and real pizza

### prerec2real

In [18]:
class Source2Target(nn.Module):
    """Wrapper for translating inputs from source to target, via the bridge."""
    
    def __init__(self, s2b: Translator, b2t: Translator, source: str, bridge: str, target: str):
        """Initialize the Module.
        
        Params:
            s2b -- translator to convert between source and bridge domains
            b2t -- translator to convert between bridge and target domains
            source -- source domain name
            bridge -- bridge domain name
            target -- target domain name
        """
        super().__init__()
        self.s2b = s2b
        self.b2t = b2t
        self.s = source
        self.b = bridge
        self.t = target
    
    def forward(self, input):
        bridge_img = self.s2b.translate(input, self.s, self.b, keep_mean=False, requires_grad=False)
        return self.b2t.translate(bridge_img, self.b, self.t, keep_mean=False, requires_grad=False)

In [19]:
prerec2real = Source2Target(AB_gen, BC_gen, 'A', 'B', 'C')
real2prerec = Source2Target(BC_gen, AB_gen, 'C', 'B', 'A')

In [20]:
# compute scores for prerecorded -> real
fid = compute_fid(prerec2real, prerec_data, real_data)
i_score = compute_is(prerec2real, prerec_data)

print(f'fid: {fid}')
print(f'is: {i_score[0]} \u00B1 {i_score[1]}')



fid: 158.96975708007812
is: 2.973461389541626 ± 0.2130606323480606


In [21]:
# compute scores for real -> prerecorded
fid = compute_fid(real2prerec, real_data, prerec_data)
i_score = compute_is(real2prerec, real_data)

print(f'fid: {fid}')
print(f'is: {i_score[0]} \u00B1 {i_score[1]}')

fid: 194.5352325439453
is: 2.964608669281006 ± 0.1492490917444229


### live2real

In [22]:
class Source2Target(nn.Module):
    """Wrapper for translating inputs directly from source to target (no bridge)."""
    
    def __init__(self, model: Translator, source: str, target: str):
        """Initialize the Module.
        
        Params:
            model -- translator to convert between source and target domains
            source -- source domain name
            bridge -- bridge domain name
            target -- target domain name
        """
        super().__init__()
        self.model = model
        self.s = source
        self.t = target
    
    def forward(self, input):
        return self.model.translate(input, self.s, self.t, keep_mean=False, requires_grad=False)

In [23]:
live2real = Source2Target(BC_gen, 'B', 'C')
real2live = Source2Target(BC_gen, 'C', 'B')

In [24]:
# compute scores for live -> real
fid = compute_fid(live2real, live_data, real_data)
i_score = compute_is(live2real, live_data)

print(f'fid: {fid}')
print(f'is: {i_score[0]} \u00B1 {i_score[1]}')

fid: 132.95069885253906
is: 2.061117172241211 ± 0.07366422563791275


In [25]:
# compute scores for real -> prerecorded
fid = compute_fid(real2live, real_data, live_data)
i_score = compute_is(real2live, real_data)

print(f'fid: {fid}')
print(f'is: {i_score[0]} \u00B1 {i_score[1]}')

fid: 89.31855773925781
is: 2.04569149017334 ± 0.07454477250576019
