In [None]:
# %pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# %pip install pandas==1.3.3 pytorch-fid==0.2.1

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torchvision
import pytorch_fid

print(torch.__version__, torchvision.__version__, pytorch_fid.__version__, pd.__version__)

1.10.0+cu111 0.11.1+cu111 0.2.1 1.1.5


In [None]:
import random
def set_seeds(val=23):
    """Seed the PyTorch, NumPy and Python ``random`` generators with ``val``."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(val)

# Seed torch, NumPy and `random` with the default value (23).
set_seeds()

In [None]:
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz

--2021-12-15 19:07:39--  http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
Resolving efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)... 128.32.244.190
Connecting to efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)|128.32.244.190|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30168306 (29M) [application/x-gzip]
Saving to: ‘facades.tar.gz’


2021-12-15 19:07:47 (3.55 MB/s) - ‘facades.tar.gz’ saved [30168306/30168306]



In [None]:
!tar -xf facades.tar.gz

In [None]:
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/cityscapes.tar.gz

--2021-12-04 12:04:25--  http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/cityscapes.tar.gz
Resolving efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)... 128.32.244.190
Connecting to efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)|128.32.244.190|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 103441232 (99M) [application/x-gzip]
Saving to: ‘cityscapes.tar.gz’


2021-12-04 12:08:57 (371 KB/s) - ‘cityscapes.tar.gz’ saved [103441232/103441232]



In [None]:
!tar -xf cityscapes.tar.gz

In [None]:
from PIL import Image
import os
import glob

def split(X):
    """Split side-by-side image pairs into their two halves.

    Parameters
    ----------
    X : np.ndarray
        Channel-last batch of shape (n, h, 2*w, c); every image consists of
        two pictures of equal width w placed next to each other.

    Returns
    -------
    np.ndarray
        Array of shape (n, 2, c, h, w). Along axis 1, index 0 is the RIGHT
        half of each input image and index 1 is the LEFT half.

    Raises
    ------
    ValueError
        If ``X`` is not 4-dimensional or its width is odd (the two halves
        would not have the same shape).
    """
    X = np.asarray(X)
    if X.ndim != 4:
        raise ValueError(
            'expected an array of shape (n, h, 2*w, c), got shape ' + str(X.shape))
    full_width = X.shape[2]
    if full_width % 2 != 0:
        raise ValueError(
            'image width must be even to split it in two, got ' + str(full_width))
    # Split at the middle instead of a hard-coded column 256 so that pairs of
    # any (even) width work; for 512-wide inputs this is the same cut.
    half = full_width // 2

    # (n, h, 2*w, c) -> (n, c, h, 2*w)
    X = np.transpose(X, (0, 3, 1, 2))
    right = X[:, :, :, half:]
    left = X[:, :, :, :half]
    # (n, 2, c, h, w)
    return np.stack((right, left), axis=1)

def load_dataset(name):
    """Load the train/val/test image pairs of a dataset directory.

    Images are read from ``name/train``, ``name/val`` and, when that directory
    exists, ``name/test`` (files are taken in sorted order). If there is no
    ``name/test`` directory, the first 40% of the validation images become the
    test set and the remaining ones stay in the validation set.

    Parameters
    ----------
    name : str
        Path of the dataset root directory.

    Returns
    -------
    tuple
        ``(train, val, test)``; each element is the result of ``split`` applied
        to the stacked images of that subset.

    Raises
    ------
    ValueError
        If one of the directories that is read contains no files.
    """
    def _load_dir(subdir):
        # Read every file of name/subdir into one stacked array.
        pattern = name + '/' + subdir + '/*'
        images = []
        for path in sorted(glob.glob(pattern)):
            # The context manager closes the file handle right after the
            # pixel data has been copied into the array.
            with Image.open(path) as img:
                images.append(np.array(img))
        if not images:
            raise ValueError('no images found for pattern ' + pattern)
        return np.array(images)

    X_train = _load_dir('train')
    X_val = _load_dir('val')

    if os.path.exists(name + '/test'):
        X_test = _load_dir('test')
    else:
        # No dedicated test set: carve the first 40% out of validation.
        sz = int(len(X_val) * 0.4)
        X_test = X_val[:sz]
        X_val = X_val[sz:]

    return split(X_train), split(X_val), split(X_test)

In [None]:
# Load the facades pairs into the module-level arrays that the training and
# FID helpers below read as globals.
X_train, X_val, X_test = load_dataset('facades')

In [None]:
# Visual sanity check: show both halves of the first training pair
# (converted back from channel-first to channel-last for PIL).
Image.fromarray(X_train[0][0].transpose((1, 2, 0))).show()
Image.fromarray(X_train[0][1].transpose((1, 2, 0))).show()

In [None]:
# Report the shape of every subset and the pixel dtype.
print(X_train.shape, X_val.shape, X_test.shape, X_train.dtype)

(400, 2, 3, 256, 256) (100, 2, 3, 256, 256) (106, 2, 3, 256, 256) uint8


In [None]:
# adapted from shw2

from torchvision import models, transforms
import torchvision.transforms.functional_tensor as TF
import torchvision.transforms.functional as TR_F

@torch.no_grad()
def train_transform(X):
    """Randomly augment a batch of image pairs and scale it to [-1, 1].

    Both images of a pair are resized to 286, cropped with the same random
    256x256 window and, with probability 0.5, flipped horizontally together.

    Parameters
    ----------
    X : np.ndarray
        Array of image pairs indexed as ``X[i][0]`` / ``X[i][1]``, each a
        channel-first image. The input array is copied, not modified.
        Assumes 256x256 uint8 images, as produced by ``load_dataset`` in this
        notebook -- TODO confirm for other callers.

    Returns
    -------
    torch.Tensor
        float32 tensor of the same shape as ``X`` with values in [-1, 1].
    """
    UP = 286
    LOW = 256
    X = torch.from_numpy(X.copy())
    for i in range(len(X)):
        # resize(286). The resized images are kept as per-sample temporaries;
        # a full-dataset float32 scratch buffer is not needed for this.
        first_up = TF.resize(X[i][0], UP)
        second_up = TF.resize(X[i][1], UP)

        # randomCrop(256): one window is sampled and applied to both images
        # so the pair stays aligned.
        top, left, h, w = transforms.RandomCrop(LOW).get_params(first_up, (LOW, LOW))
        X[i][0] = TF.crop(first_up, top, left, h, w)
        X[i][1] = TF.crop(second_up, top, left, h, w)

        # Random horizontal flip, again applied to both images together.
        if np.random.rand() < 0.5:
            X[i][0] = TF.hflip(X[i][0])
            X[i][1] = TF.hflip(X[i][1])

    X = TF.convert_image_dtype(X, torch.float32)
    # [0, 1] -> [-1, 1]
    X = X * 2 - 1
    return X

@torch.no_grad()
def test_transform(X):
    """Convert a NumPy image batch to a float32 tensor scaled to [-1, 1].

    No augmentation is applied. The input array is copied before conversion.
    """
    as_float = TF.convert_image_dtype(torch.from_numpy(X.copy()), torch.float32)
    return as_float * 2 - 1

#torch (c, h, w) to pil
@torch.no_grad()
def to_PIL(x):
    """Turn a (c, h, w) tensor with values in [-1, 1] into a PIL image."""
    # [-1, 1] -> [0, 1]
    unit_range = (x + 1) / 2
    as_bytes = TF.convert_image_dtype(unit_range, torch.uint8)
    return TR_F.to_pil_image(as_bytes)

# Maps a split name to its preprocessing function. Every function accepts a
# NumPy array and returns a torch tensor; only 'train' applies augmentation.
data_transforms = {
    'train': train_transform,
    'val': test_transform,
    'test': test_transform
}


In [None]:
#!g1.1
# pix2pix (https://arxiv.org/pdf/1611.07004.pdf): architecture, optimizer and other hyper-parameters are taken from the paper
# and from the authors' GitHub implementation (to match architecture details): https://github.com/phillipi/pix2pix/blob/master/models.lua
# U-Net adapted (heavily modified) from shw5

class DownBlock(nn.Module):
    """4x4 convolution (padding 1) -> optional InstanceNorm2d -> LeakyReLU(0.2).

    Despite its name, ``batch_norm`` switches an ``nn.InstanceNorm2d`` layer on
    or off; the convolution has a bias only when normalisation is disabled.
    """

    def __init__(self, in_ch, out_ch, batch_norm=True, stride=2):
        super().__init__()

        conv = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride,
                         padding=1, bias=not batch_norm)
        norm = [nn.InstanceNorm2d(out_ch)] if batch_norm else []
        self.lay = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        return self.lay(x)


class UpBlock(nn.Module):
    """Transposed 4x4 conv (stride 2) -> InstanceNorm2d -> optional Dropout(0.5) -> ReLU.

    ``forward`` optionally concatenates a skip tensor to the input along the
    channel dimension before applying the layers.
    """

    def __init__(self, in_ch, out_ch, dropout=False):
        super().__init__()

        layers = [
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_ch),
        ]
        if dropout:
            layers += [nn.Dropout(0.5)]
        layers += [nn.ReLU()]
        self.lay = nn.Sequential(*layers)

    def forward(self, x, old_x):
        # x and old_x are (batch, c, h, w); old_x may be None (no skip input).
        merged = x if old_x is None else torch.cat((x, old_x), dim=1)
        return self.lay(merged)

    
class UNet(nn.Module):
    """U-Net generator: 8 DownBlocks, 7 UpBlocks with skip connections, and a
    final transposed convolution followed by tanh.

    Shape comments below assume a 3x256x256 input.
    """

    def __init__(self):
        super(UNet, self).__init__()

        # Blocks are created encoder first, then decoder, then last_conv, and
        # registered as last_conv, tanh, down_blocks, up_blocks, so RNG use at
        # construction and parameter ordering stay as they were.
        encoder = [DownBlock(3, 64, batch_norm=False)]  # 64x128x128
        encoder += [
            DownBlock(c_in, c_out)
            for c_in, c_out in (
                (64, 128),   # 128x64x64
                (128, 256),  # 256x32x32
                (256, 512),  # 512x16x16
                (512, 512),  # 512x8x8
                (512, 512),  # 512x4x4
                (512, 512),  # 512x2x2
            )
        ]
        encoder.append(DownBlock(512, 512, batch_norm=False))  # 512x1x1

        # (in_ch, out_ch, dropout); in_ch of all but the first block counts
        # the concatenated skip tensor.
        decoder = [
            UpBlock(c_in, c_out, dropout=use_dropout)
            for c_in, c_out, use_dropout in (
                (512, 512, True),
                (1024, 512, True),
                (1024, 512, True),
                (1024, 512, False),
                (1024, 256, False),
                (512, 128, False),
                (256, 64, False),
            )
        ]

        self.last_conv = nn.ConvTranspose2d(128, 3, 4, 2, 1)
        self.tanh = nn.Tanh()

        self.down_blocks = nn.ModuleList(encoder)
        self.up_blocks = nn.ModuleList(decoder)

    def forward(self, x):
        skips = []
        for block in self.down_blocks:
            x = block(x)
            skips.append(x)

        # Drop the bottleneck output (it is the decoder input, not a skip)
        # and order the remaining activations deepest first.
        skips = skips[-2::-1]

        x = self.up_blocks[0](x, None)
        for block, skip in zip(self.up_blocks[1:], skips):
            x = block(x, skip)

        # The shallowest skip feeds the final transposed convolution.
        return self.tanh(self.last_conv(torch.cat((x, skips[-1]), dim=1)))

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator returning a map of per-patch logits.

    When ``use_input`` is true the candidate image is concatenated with the
    generator input along the channel dimension (6 input channels), otherwise
    only the candidate image is scored (3 input channels).
    """

    def __init__(self, use_input):
        super().__init__()
        self.use_input = use_input
        in_ch = 6 if use_input else 3
        # Shape comments assume an in_ch x 256 x 256 input.
        self.blocks = nn.Sequential(
            DownBlock(in_ch, 64, batch_norm=False),  # 64x128x128
            DownBlock(64, 128),                      # 128x64x64
            DownBlock(128, 256),                     # 256x32x32
            DownBlock(256, 512, stride=1),           # 512x31x31
            nn.Conv2d(512, 1, 4, 1, 1),              # 1x30x30
        )

    def forward(self, x_in, x):
        # x_in: generator input, x: image to judge (real target or generated).
        disc_in = torch.cat((x, x_in), dim=1) if self.use_input else x
        return self.blocks(disc_in)


In [None]:

@torch.no_grad()
def iterate_minibatches(X, batchsize, mode, todel, shuffle):
    """Yield ``(first, second)`` tensor batches from an array of image pairs.

    Parameters
    ----------
    X : array of pairs, indexed as ``X[i][0]`` / ``X[i][1]``.
    batchsize : number of pairs per batch (the last batch may be smaller).
    mode : key into ``data_transforms`` selecting the preprocessing.
    todel : only the first ``len(X) // todel`` pairs are used.
    shuffle : if true, pairs are visited in a random permutation.
    """
    if todel != 1:
        X = X[:len(X) // todel]

    # The whole subset is transformed up front, then sliced into batches.
    X = data_transforms[mode](X).numpy()
    total = len(X)

    order = np.random.permutation(total) if shuffle else np.arange(total)

    for start in range(0, total, batchsize):
        batch = torch.from_numpy(X[order[start:start + batchsize]])
        yield batch[:, 0], batch[:, 1]

In [None]:
def train_epoch(model_g, model_d, opt_g, opt_d, DEVICE, df_frac, batch_size, use_disc, use_l1):
    """Run one training pass over the global ``X_train``.

    For every shuffled, augmented batch the discriminator is updated first
    (when ``use_disc`` is set) and the generator second.

    Parameters
    ----------
    model_g, model_d : generator and discriminator modules.
    opt_g, opt_d : their optimizers.
    DEVICE : device the batches are moved to.
    df_frac : forwarded to ``iterate_minibatches`` as ``todel``.
    batch_size : number of pairs per batch.
    use_disc : add the adversarial term and train the discriminator.
    use_l1 : add the weighted L1 term to the generator loss.

    Returns
    -------
    (float, float)
        Mean generator loss and mean discriminator loss per batch; the
        discriminator value stays 0.0 when ``use_disc`` is false.
    """
    model_g.train()
    
    if use_disc:
        model_d.train()
    
    total_loss_g = 0.0
    total_loss_d = 0.0

    cnter = 0
    
    for src, tgt in iterate_minibatches(X_train, batch_size, 'train', df_frac, True):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        output = model_g(src)
        
        #discriminator loss
        if use_disc:
            # Detached copy: the discriminator loss must not backpropagate
            # into the generator.
            out_copy = output.clone().detach()
            disc_tgt = model_d(src, tgt)
            disc_output = model_d(src, out_copy)
            # Generated images are labelled 0, real targets 1; the two terms are averaged.
            loss_d = F.binary_cross_entropy_with_logits(disc_output, torch.zeros_like(disc_output)) + \
                     F.binary_cross_entropy_with_logits(disc_tgt, torch.ones_like(disc_tgt))
            loss_d = loss_d / 2

            total_loss_d += loss_d.item()
            
            opt_d.zero_grad()
            loss_d.backward()
            opt_d.step()

        #generator loss

        # L1 term with weight 100; computed even when use_l1 is false.
        l1_loss = 100 * F.l1_loss(output, tgt)

        # NOTE(review): if both use_l1 and use_disc are false, loss_g stays a
        # constant with no graph and loss_g.backward() below will presumably fail.
        loss_g = torch.tensor(0.0).to(DEVICE)

        if use_l1:
            loss_g = loss_g + l1_loss

        if use_disc:
            # Score the non-detached output with the just-updated discriminator
            # so gradients reach the generator. This backward pass also
            # accumulates gradients in model_d, which opt_d.zero_grad() clears
            # before the next discriminator step.
            disc_output = model_d(src, output)
            loss_g = loss_g + F.binary_cross_entropy_with_logits(disc_output, torch.ones_like(disc_output))
        
        total_loss_g += loss_g.item()
        
        opt_g.zero_grad()
        loss_g.backward()
        opt_g.step()

        cnter += 1
    
    return total_loss_g / cnter, total_loss_d / cnter

@torch.no_grad()
def evaluate(model_g, model_d, DEVICE, df_frac, batch_size, use_disc, use_l1):
    """Compute the generator and discriminator losses on the global ``X_val``.

    No gradients are computed and no optimizer is stepped. The loss terms
    mirror those of ``train_epoch``.

    Parameters
    ----------
    model_g, model_d : generator and discriminator modules.
    DEVICE : device the batches are moved to.
    df_frac : forwarded to ``iterate_minibatches`` as ``todel``.
    batch_size : number of pairs per batch.
    use_disc : include the adversarial terms.
    use_l1 : include the weighted L1 term in the generator loss.

    Returns
    -------
    (float, float)
        Mean generator loss and mean discriminator loss per batch.
    """
    #intentional
    # Both models are put in train mode here, not eval mode.
    model_g.train()
    
    if use_disc:
        model_d.train()
    
    total_loss_g = 0.0
    total_loss_d = 0.0

    cnter = 0
    
    # NOTE(review): the 'train' transform is used on X_val, so validation
    # batches receive the same random resize/crop/flip as training batches
    # (unshuffled). Confirm this is intended; 'val' would skip augmentation.
    for src, tgt in iterate_minibatches(X_val, batch_size, 'train', df_frac, False):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        output = model_g(src)

        l1_loss = 100 * F.l1_loss(output, tgt)
        
        loss_g = torch.tensor(0.0).to(DEVICE)
        loss_d = torch.tensor(0.0).to(DEVICE)
        
        if use_l1:
            loss_g = loss_g + l1_loss

        if use_disc:            
            disc_tgt = model_d(src, tgt)
            disc_output = model_d(src, output)
            # Generator wants its output classified as real (label 1) ...
            loss_g = loss_g + F.binary_cross_entropy_with_logits(disc_output, torch.ones_like(disc_output))
            # ... the discriminator wants generated -> 0 and real -> 1.
            loss_d = F.binary_cross_entropy_with_logits(disc_output, torch.zeros_like(disc_output)) + \
                     F.binary_cross_entropy_with_logits(disc_tgt, torch.ones_like(disc_tgt))
            loss_d = loss_d / 2
        
        total_loss_g += loss_g.item()
        total_loss_d += loss_d.item()
        
        cnter += 1
    
    return total_loss_g / cnter, total_loss_d / cnter

In [None]:
def clear_folder(folder):
    """Create ``folder`` if needed and delete the files directly inside it.

    Only regular files and symlinks matched by ``*`` are removed;
    sub-directories are left in place instead of making ``os.remove`` raise.
    The folder name is escaped, so glob metacharacters in it (``[``, ``*``,
    ``?``) are taken literally.

    The ``torch.no_grad`` decorator was dropped: the function performs no
    tensor operations.
    """
    os.makedirs(folder, exist_ok=True)
    for path in glob.glob(os.path.join(glob.escape(folder), '*')):
        if os.path.isfile(path) or os.path.islink(path):
            os.remove(path)


In [None]:
@torch.no_grad()
def save_to_folder(dset_name, X_name):
    """Write the second image of every pair of a split as JPEG files.

    The images of the global ``X_train`` / ``X_val`` / ``X_test`` array
    selected by ``X_name`` are saved as ``1.jpg``, ``2.jpg``, ... into
    ``<dset_name>/fid/<X_name>_tgt/``. The folder is created if necessary and
    emptied first; ``calculate_fid`` reads it as its reference folder.

    Parameters
    ----------
    dset_name : 'facades' or 'cityscapes'.
    X_name : 'train', 'val' or 'test'.
    """
    assert dset_name in {'facades', 'cityscapes'}
    if X_name == 'val':
        X = X_val
    elif X_name == 'train':
        X = X_train
    else:
        assert X_name == 'test'
        X = X_test
    folder = dset_name + '/fid/' + X_name + '_tgt/'
    # Reuse the shared helper instead of repeating makedirs/glob/remove here.
    clear_folder(folder.rstrip('/'))
    batches = iterate_minibatches(X, 1, 'test', 1, False)
    for cnt, (src, tgt) in enumerate(batches, start=1):
        to_PIL(tgt[0]).save(folder + str(cnt) + '.jpg')
    return


In [None]:
@torch.no_grad()
def save_src_to_folder(dset_name, X_name):
    """Write the first image of every pair of a split as JPEG files.

    The images of the global ``X_train`` / ``X_val`` / ``X_test`` array
    selected by ``X_name`` are saved as ``1.jpg``, ``2.jpg``, ... into
    ``<dset_name>/fid/<X_name>_src/``. The folder is created if necessary and
    emptied first.

    Parameters
    ----------
    dset_name : 'facades' or 'cityscapes'.
    X_name : 'train', 'val' or 'test'.
    """
    assert dset_name in {'facades', 'cityscapes'}
    if X_name == 'val':
        X = X_val
    elif X_name == 'train':
        X = X_train
    else:
        assert X_name == 'test'
        X = X_test
    folder = dset_name + '/fid/' + X_name + '_src/'
    # Reuse the shared helper instead of repeating makedirs/glob/remove here.
    clear_folder(folder.rstrip('/'))
    batches = iterate_minibatches(X, 1, 'test', 1, False)
    for cnt, (src, tgt) in enumerate(batches, start=1):
        to_PIL(src[0]).save(folder + str(cnt) + '.jpg')
    return


In [None]:
@torch.no_grad()
def calculate_fid(dset_name, X_name, model, DEVICE='cuda', filename=None):
    """Generate images for a split and return their FID against the targets.

    The generator is run on every pair of the global ``X_train`` / ``X_val`` /
    ``X_test`` array selected by ``X_name``; outputs are written as
    ``1.jpg``, ``2.jpg``, ... to ``<dset_name>/fid/<X_name>_out/`` and compared
    with ``<dset_name>/fid/<X_name>_tgt/``, which must already have been
    filled by ``save_to_folder``.

    Parameters
    ----------
    dset_name : 'facades' or 'cityscapes'.
    X_name : 'train', 'val' or 'test'.
    model : generator; it is called in whatever train/eval mode it is in.
    DEVICE : device used for both the generator and the FID network.
    filename : if not None, the first generated image is also saved as
        ``<dset_name>/fid/<X_name>_sample/<filename>.jpg``.

    Returns
    -------
    The value returned by ``pytorch_fid.fid_score.calculate_fid_given_paths``.
    """
    import pytorch_fid.fid_score
    assert dset_name in {'facades', 'cityscapes'}
    if X_name == 'val':
        X = X_val
    elif X_name == 'train':
        X = X_train
    else:
        assert X_name == 'test'
        X = X_test
    tgt_folder = dset_name + '/fid/' + X_name + '_tgt/'
    out_folder = dset_name + '/fid/' + X_name + '_out/'
    # Start from an empty folder so images left by an earlier run (e.g. with a
    # larger split) cannot leak into the statistics.
    clear_folder(out_folder.rstrip('/'))
    batches = iterate_minibatches(X, 1, 'test', 1, False)
    for cnt, (src, tgt) in enumerate(batches, start=1):
        output = model(src.to(DEVICE))
        x = to_PIL(output[0].cpu())
        if cnt == 1:
            # Preview the first generated image and optionally keep a copy.
            x.show()
            if filename is not None:
                fold = dset_name + '/fid/' + X_name + '_sample/'
                os.makedirs(fold, exist_ok=True)
                x.save(fold + filename + '.jpg')
        x.save(out_folder + str(cnt) + '.jpg')

    # Arguments: paths, batch size 50, device, 2048-dim features, 8 workers.
    # DEVICE is forwarded instead of a hard-coded 'cuda' so the function also
    # works when the caller selected the CPU.
    fid = pytorch_fid.fid_score.calculate_fid_given_paths([tgt_folder, out_folder], 50, DEVICE, 2048, 8)
    return fid


In [None]:
# Write the reference target images for all three facades splits (read later
# by calculate_fid) and the source images of the test split.
save_to_folder('facades', 'train')
save_to_folder('facades', 'val')
save_to_folder('facades', 'test')
save_src_to_folder('facades', 'test')

In [None]:
#!g1.1
# Module-level slots that train() overwrites (via `global`) with deep copies
# of the generator / discriminator state dicts at the best validation FID.
best_model_wts_g = None
best_model_wts_d = None

In [None]:
#!g1.1
import time
import torch.nn as nn
import copy
from IPython.display import clear_output

def train(num_epochs, model_g, model_d, opt_g, opt_d, scheduler, DEVICE, df_frac, batch_size, onlytest, use_disc, use_l1, dset_name):
    """Train the generator/discriminator pair and track the best validation FID.

    Every epoch runs a 'train' phase (``train_epoch``) and a 'val' phase
    (``evaluate``). Every ``BLOCK`` (10) epochs and on the last epoch the FID
    of the phase's split is computed with ``calculate_fid``; whenever the
    validation FID improves, deep copies of both state dicts are stored in the
    globals ``best_model_wts_g`` / ``best_model_wts_d``.

    Parameters
    ----------
    num_epochs : number of epochs to run.
    model_g, model_d : generator and discriminator modules.
    opt_g, opt_d : their optimizers.
    scheduler : accepted but not used anywhere in this function.
    DEVICE : device forwarded to the epoch and FID helpers.
    df_frac, batch_size : forwarded to ``train_epoch`` / ``evaluate``.
    onlytest : if true, the 'train' phase is skipped.
    use_disc, use_l1 : loss switches forwarded to the epoch helpers.
    dset_name : dataset name forwarded to ``calculate_fid``.

    Returns
    -------
    None. The loss and FID logs are collected in local lists that are not
    returned; the models are left with their last-epoch weights (the best
    weights are only stored in the globals, not loaded back).
    """
    start_time = time.time()

    global best_model_wts_g
    global best_model_wts_d
    # Start from the current weights so the globals are never left as None.
    best_model_wts_g = copy.deepcopy(model_g.state_dict())
    best_model_wts_d = copy.deepcopy(model_d.state_dict())
    best_model_epoch = -1
    best_fid = 1e9

    # Phases executed each epoch; onlytest drops 'train'.
    lsttt = ['train', 'val']
    if onlytest:
        lsttt = lsttt[1:]
    
    train_loss_g_log = []
    val_loss_g_log = []
    train_loss_d_log = []
    val_loss_d_log = []
    train_fid_log = []
    val_fid_log = []
    xs_fid_log = []
    
    for epoch in range(num_epochs):
        
        # FID evaluation period (in epochs); the cell output is also cleared
        # at the start of every such block.
        BLOCK = 10
        
        if epoch % BLOCK == 0:
            clear_output()
        print('epoch', epoch, 'out of', num_epochs)
        
        for phase in lsttt:
            cur_time = time.time()
            
            print('doing phase', phase)
            
            if phase == 'train':
                loss_g, loss_d = train_epoch(model_g, model_d, opt_g, opt_d, DEVICE, df_frac, batch_size, use_disc, use_l1)
                train_loss_g_log.append(loss_g)
                train_loss_d_log.append(loss_d)
            else:
                loss_g, loss_d = evaluate(model_g, model_d, DEVICE, df_frac, batch_size, use_disc, use_l1)
                val_loss_g_log.append(loss_g)
                val_loss_d_log.append(loss_d)

            
            print('Generator loss is', loss_g)
            print('Discriminator loss is', loss_d)
            # Per-phase duration; this value is not used before being overwritten below.
            tt = time.time() - cur_time

            if epoch % BLOCK == 0 or epoch + 1 == num_epochs:
                # A sample image is kept only every 50 epochs and on the last epoch.
                sample_filename = None
                if epoch % 50 == 0 or epoch + 1 == num_epochs:
                    sample_filename = 'epoch_' + str(epoch)
                if phase == 'train':
                    train_fid = calculate_fid(dset_name, 'train', model_g, DEVICE, sample_filename)
                    train_fid_log.append(train_fid)
                    print('FID is', train_fid)
                else:
                    val_fid = calculate_fid(dset_name, 'val', model_g, DEVICE, sample_filename)
                    val_fid_log.append(val_fid)
                    xs_fid_log.append(epoch)
                    print('FID is', val_fid)
                    
                    # Model selection is based on validation FID, not on the losses.
                    if best_fid > val_fid:
                        best_fid = val_fid
                        best_model_wts_g = copy.deepcopy(model_g.state_dict())
                        best_model_wts_d = copy.deepcopy(model_d.state_dict())
                        best_model_epoch = epoch
                        print('made improvement', best_fid)
            
                if phase == 'val':
                    tt = time.time() - start_time
                    print('total elapsed time {:.0f}m {:.0f}s'.format(tt // 60, tt % 60))
    
    print('')
    print('best fid is', best_fid)
    print('got such fid at epoch', best_model_epoch)
    return

In [None]:
#!g1.1

# Build a fresh generator and a discriminator that also sees the generator input.
model_g = UNet()
model_d = Discriminator(True)

# Re-initialise every parameter tensor (weights and biases alike) from a
# normal distribution with mean 0 and std 0.02.
for p in model_g.parameters():
    nn.init.normal_(p, 0, 0.02)
for p in model_d.parameters():
    nn.init.normal_(p, 0, 0.02)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_g = model_g.to(DEVICE)
model_d = model_d.to(DEVICE)


opt_g = torch.optim.Adam(model_g.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(model_d.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Positional arguments: 200 epochs, no scheduler, df_frac=1 (full data),
# batch_size=1, onlytest=False.
train(200, model_g, model_d, opt_g, opt_d, None, DEVICE, 1, 1, False, use_disc=True, use_l1=True, dset_name='facades')


In [None]:
#!g1.1
# FID on the facades test split, using model_g as train() left it
# (train() does not load the best-FID weights back into the model).
calculate_fid('facades', 'test', model_g, DEVICE, None)

In [None]:
#!g1.1
# Show the generator output for the last three test pairs.
examples = data_transforms['test'](X_test[-3:])
# Inference only: skip autograd bookkeeping, and use DEVICE (not a hard-coded
# .cuda()) so the cell also runs when DEVICE fell back to the CPU.
with torch.no_grad():
    results = model_g(examples[:, 0].to(DEVICE)).cpu()
for i in range(3):
    to_PIL(results[i]).show()

In [None]:
#!g1.1
# Keep a handle on the facades models' weights before model_g / model_d are
# rebound to new networks further down.
# NOTE(review): state_dict() is not deep-copied here, and these are the
# last-epoch weights, not best_model_wts_g / best_model_wts_d -- confirm.
model_facades_g = model_g.state_dict()
model_facades_d = model_d.state_dict()

In [None]:
# Rebind the global arrays to the cityscapes pairs; from here on every helper
# that reads X_train / X_val / X_test works on cityscapes.
X_train, X_val, X_test = load_dataset('cityscapes')

In [None]:
# Visual sanity check: show both halves of the first validation pair.
Image.fromarray(X_val[0][0].transpose((1, 2, 0))).show()
Image.fromarray(X_val[0][1].transpose((1, 2, 0))).show()

In [None]:
# Write the reference target images for all three cityscapes splits (read
# later by calculate_fid) and the source images of the test split.
save_to_folder('cityscapes', 'train')
save_to_folder('cityscapes', 'val')
save_to_folder('cityscapes', 'test')
save_src_to_folder('cityscapes', 'test')

In [None]:
#!g1.1

# Same setup as for facades, with freshly constructed networks.
model_g = UNet()
model_d = Discriminator(True)

# Re-initialise every parameter tensor (weights and biases alike) from a
# normal distribution with mean 0 and std 0.02.
for p in model_g.parameters():
    nn.init.normal_(p, 0, 0.02)
for p in model_d.parameters():
    nn.init.normal_(p, 0, 0.02)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_g = model_g.to(DEVICE)
model_d = model_d.to(DEVICE)

opt_g = torch.optim.Adam(model_g.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(model_d.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Positional arguments: 200 epochs, no scheduler, df_frac=1 (full data),
# batch_size=1, onlytest=False.
train(200, model_g, model_d, opt_g, opt_d, None, DEVICE, 1, 1, False, use_disc=True, use_l1=True, dset_name='cityscapes')


In [None]:
#!g1.1
# FID on the cityscapes test split, using model_g as train() left it
# (train() does not load the best-FID weights back into the model).
calculate_fid('cityscapes', 'test', model_g, DEVICE, None)

In [None]:
#!g1.1
# Keep a handle on the cityscapes models' weights.
# NOTE(review): state_dict() is not deep-copied here, and these are the
# last-epoch weights, not best_model_wts_g / best_model_wts_d -- confirm.
model_cityscapes_g = model_g.state_dict()
model_cityscapes_d = model_d.state_dict()