In [1]:
import sys
import os

In [2]:
# Work from inside the CycleGAN repo so that its `options`, `models`, and `data` packages
# (imported below) can be found, and its relative ./datasets and ./checkpoints paths resolve.
# Guarded so re-running this cell does not try to descend into the repo a second time.
if os.path.basename(os.getcwd()) != 'pytorch-CycleGAN-and-pix2pix':
    os.chdir('pytorch-CycleGAN-and-pix2pix/')
os.getcwd()

'/common/home/jl2362/cs536/step3/pytorch-CycleGAN-and-pix2pix'

In [3]:
# Shell sanity check: shows which python3 the shell's PATH resolves to.
# NOTE(review): this is not necessarily the interpreter running this kernel;
# `sys.executable` reports that directly.
!which python3

/koko/system/anaconda/bin/python3


In [4]:
import torch
from torch import nn, optim
import torch.utils.data as tdata

import torchvision
import torchvision.transforms as tforms
import torchvision.utils as vutils
from torchvision.datasets import ImageFolder

from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

from PIL import Image
import numpy as np

In [5]:
# required cycleGAN packages to load data and models
from options.test_options import TestOptions
from models import create_model
from data import create_dataset

In [6]:
# Evaluation settings shared by the cells below.
img_size = 256    # side length that load_data resizes and center-crops every image to
batch_size = 256  # NOTE(review): not used by the cells shown; the loaders pass batch_size=128 explicitly
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda:0


In [7]:
def load_data(path: str, batch_size=64, num_workers=2):
    """Builds an ordered (non-shuffled) DataLoader over an ImageFolder directory.

    Every image is resized with ``Resize(img_size)`` (``img_size`` is read from the
    notebook's global config at call time), center-cropped to ``img_size`` x ``img_size``,
    converted to a tensor, and normalized with mean 0.5 / std 0.5 per channel, which
    maps ToTensor's [0, 1] range to [-1, 1].

    Params:
        path -- root directory in ImageFolder layout (one sub-folder per class)
        batch_size -- images per batch; the last, possibly smaller, batch is kept
        num_workers -- DataLoader worker processes
    Returns:
        a torch DataLoader yielding (images, labels) batches in a fixed order
    """
    preprocessing = [
        tforms.Resize(img_size),
        tforms.CenterCrop(img_size),
        tforms.ToTensor(),
        tforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    image_folder = ImageFolder(root=path, transform=tforms.Compose(preprocessing))
    # Fixed order and no dropped images so every evaluation sees the full set identically.
    return tdata.DataLoader(
        image_folder,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )

In [8]:
# One ordered loader per evaluation domain (prerecorded, live, real), built in that order.
prerec_data, live_data, real_data = (
    load_data(root, batch_size=128)
    for root in ('./datasets/eval_prerec', './datasets/eval_live', './datasets/eval_real')
)

## Metric Evaluation Functions

In [19]:
def compute_fid(model: nn.Module, source_data: tdata.DataLoader, target_data: tdata.DataLoader):
    """Computes the FID between real target-domain images and translated source images.

    Both loaders must yield batches whose element ``[0]`` is an image tensor scaled to
    [-1, 1] (what ``load_data`` produces); any other batch elements (e.g. labels) are ignored.

    The model is switched to eval mode for the duration of the call, and its previous
    train/eval mode is restored afterwards. This keeps dropout or batch-norm layers
    (if the generator has any) from perturbing the translated images.

    Params:
        model -- a network translating the source domain to target domain
        source_data -- loads data from the source domain (is translated into 'fake' data)
        target_data -- loads data from the target domain ('real' data)
    Returns:
        FID value as a Python float
    """
    def to_uint8(images: torch.Tensor) -> torch.Tensor:
        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1] -> [0, 255], then round to nearest
        # (+0.5 followed by the truncating uint8 cast). The first mul() is out-of-place,
        # so the caller's tensor is never modified. Output goes to the CPU, where the
        # metric lives (it is never moved to `device`).
        return images.mul(0.5).add_(0.5).mul(255).add_(0.5).clamp_(0, 255).to('cpu', dtype=torch.uint8)

    fid = FrechetInceptionDistance()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # feed real data in
            for target_batch in target_data:
                fid.update(to_uint8(target_batch[0]), real=True)

            # feed fake/translated data in
            for source_batch in source_data:
                fake_batch = model(source_batch[0].to(device))
                fid.update(to_uint8(fake_batch), real=False)
    finally:
        # Hand the model back in whatever mode the caller gave it to us.
        model.train(was_training)

    return fid.compute().item()

In [20]:
def compute_is(model: nn.Module, source_data: tdata.DataLoader, seed=0):
    """Computes the Inception Score of a model's translated source-domain images.

    The loader must yield batches whose element ``[0]`` is an image tensor scaled to
    [-1, 1] (what ``load_data`` produces). The model is switched to eval mode for the
    duration of the call, and its previous train/eval mode is restored afterwards.

    Params:
        model -- a network translating the source domain to target domain
        source_data -- loads data from the source domain (is translated into 'fake' data)
        seed -- seed for the CPU RNG while the score is computed. torchmetrics shuffles
                the collected features with torch.randperm before splitting them, so
                unseeded results differ from run to run. The global RNG state is restored
                afterwards. Pass None for the previous unseeded behavior.
    Returns:
        (IS mean, IS stddev) as Python floats
    """
    def to_uint8(images: torch.Tensor) -> torch.Tensor:
        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1] -> [0, 255], then round to nearest
        # (+0.5 followed by the truncating uint8 cast). The first mul() is out-of-place,
        # so the caller's tensor is never modified. Output goes to the CPU, where the
        # metric lives.
        return images.mul(0.5).add_(0.5).mul(255).add_(0.5).clamp_(0, 255).to('cpu', dtype=torch.uint8)

    inception = InceptionScore()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # feed generated images
            for source_batch in source_data:
                fake_batch = model(source_batch[0].to(device))
                inception.update(to_uint8(fake_batch))
    finally:
        model.train(was_training)

    if seed is None:
        is_mean, is_stddev = inception.compute()
    else:
        # Seed only the CPU default generator, inside a forked RNG scope, so the rest of
        # the notebook sees an unchanged global RNG state.
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(seed)
            is_mean, is_stddev = inception.compute()
    return is_mean.item(), is_stddev.item()

## Evaluate models translating prerecorded and live pizza to and from real pizza

### prerec2real and real2prerec

In [11]:
# Reuse the repo's test-time option parser by faking test.py's command line.
sys.argv = "test.py --dataroot ./datasets/pizza --name prerec_pizza --model cycle_gan".split()
opt = TestOptions().parse()
# Hard-coded test-time overrides:
#   num_threads    -- test code only supports num_threads = 0
#   batch_size     -- test code only supports batch_size = 1
#   serial_batches -- disable data shuffling; drop it if results on randomly chosen images are needed
#   no_flip        -- no flip; drop it if results on flipped images are needed
#   display_id     -- set to -1
vars(opt).update(
    num_threads=0,
    batch_size=1,
    serial_batches=True,
    no_flip=True,
    display_id=-1,
)
model = create_model(opt)
model.setup(opt)

----------------- Options ---------------
             aspect_ratio: 1.0                           
               batch_size: 1                             
          checkpoints_dir: ./checkpoints                 
                crop_size: 256                           
                 dataroot: ./datasets/pizza              	[default: None]
             dataset_mode: unaligned                     
                direction: AtoB                          
          display_winsize: 256                           
                    epoch: latest                        
                     eval: False                         
                  gpu_ids: 0                             
                init_gain: 0.02                          
                init_type: normal                        
                 input_nc: 3                             
                  isTrain: False                         	[default: None]
                load_iter: 0                            

In [12]:
# Pull both generators out of the prerec_pizza CycleGAN.
# NOTE(review): assumes netG_A translates domain A (prerecorded) into domain B (real) for
# ./datasets/pizza, and netG_B the reverse. Confirm against the dataset's A/B folders.
prerec2real, real2prerec = model.netG_A, model.netG_B

In [13]:
# Score prerecorded -> real: FID of translated prerecorded images against the real set,
# and IS of the translated images on their own.
fid = compute_fid(model=prerec2real, source_data=prerec_data, target_data=real_data)
i_score = compute_is(model=prerec2real, source_data=prerec_data)

print(f'fid: {fid}')
print(f'is: {i_score[0]} \u00B1 {i_score[1]}')



fid: 200.0346221923828
is: 3.4202301502227783 ± 0.15140429139137268


In [21]:
# Score real -> prerecorded: FID of translated real images against the prerecorded set,
# and IS of the translated images on their own.
fid = compute_fid(model=real2prerec, source_data=real_data, target_data=prerec_data)
i_score = compute_is(model=real2prerec, source_data=real_data)

print(f'fid: {fid}')
print(f'is: {i_score[0]} \u00B1 {i_score[1]}')

fid: 154.47573852539062
is: 3.28467059135437 ± 0.15126541256904602


### live2real and real2live

In [22]:
# Reuse the repo's test-time option parser by faking test.py's command line.
sys.argv = "test.py --dataroot ./datasets/livepizza --name live_pizza --model cycle_gan".split()
opt = TestOptions().parse()
# Hard-coded test-time overrides:
#   num_threads    -- test code only supports num_threads = 0
#   batch_size     -- test code only supports batch_size = 1
#   serial_batches -- disable data shuffling; drop it if results on randomly chosen images are needed
#   no_flip        -- no flip; drop it if results on flipped images are needed
#   display_id     -- set to -1
vars(opt).update(
    num_threads=0,
    batch_size=1,
    serial_batches=True,
    no_flip=True,
    display_id=-1,
)
model = create_model(opt)
model.setup(opt)

----------------- Options ---------------
             aspect_ratio: 1.0                           
               batch_size: 1                             
          checkpoints_dir: ./checkpoints                 
                crop_size: 256                           
                 dataroot: ./datasets/livepizza          	[default: None]
             dataset_mode: unaligned                     
                direction: AtoB                          
          display_winsize: 256                           
                    epoch: latest                        
                     eval: False                         
                  gpu_ids: 0                             
                init_gain: 0.02                          
                init_type: normal                        
                 input_nc: 3                             
                  isTrain: False                         	[default: None]
                load_iter: 0                            

In [23]:
# Pull both generators out of the live_pizza CycleGAN.
# NOTE(review): assumes netG_A translates domain A (live) into domain B (real) for
# ./datasets/livepizza, and netG_B the reverse. Confirm against the dataset's A/B folders.
live2real, real2live = model.netG_A, model.netG_B

In [24]:
# Score live -> real: FID of translated live images against the real set,
# and IS of the translated images on their own.
fid = compute_fid(model=live2real, source_data=live_data, target_data=real_data)
i_score = compute_is(model=live2real, source_data=live_data)

print(f'fid: {fid}')
print(f'is: {i_score[0]} \u00B1 {i_score[1]}')

fid: 36.53064727783203
is: 1.2937095165252686 ± 0.028916632756590843


In [25]:
# compute scores for real -> live
# (FID compares translated real images against the live set; IS scores the translated
# images on their own.)
fid = compute_fid(real2live, real_data, live_data)
i_score = compute_is(real2live, real_data)

print(f'fid: {fid}')
print(f'is: {i_score[0]} \u00B1 {i_score[1]}')

fid: 65.1513900756836
is: 2.5775375366210938 ± 0.16500791907310486
