In [None]:
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import os.path
import random
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import ntpath
import sys
from skimage.transform import resize
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import time
import importlib
from collections import OrderedDict
import shutil
from util import util
import imageio
from util import html
from util.visualizer import save_images
import torch.multiprocessing as mp
from util.visualizer import Visualizer

# Use the 'spawn' start method for torch.multiprocessing.
# get_start_method(allow_none=True) only *reports* the method; the no-argument
# form fixes the default context as a side effect, which made the following
# set_start_method('spawn') raise "context has already been set" wherever the
# platform default is not spawn. force=True also lets this cell be re-run, or
# run after another library already chose a method.
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn', force=True)



# Data

In [None]:
class BaseDataset(data.Dataset):
    """Minimal dataset interface shared by the datasets in this notebook.

    Subclasses override ``initialize`` to read the parsed options,
    ``__len__``/``__getitem__`` to expose samples, and ``name`` to identify
    themselves. The base implementation is empty (length 0).
    """

    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        """Human-readable dataset name."""
        return "BaseDataset"

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for adding dataset-specific CLI options; returns ``parser`` unchanged."""
        return parser

    def initialize(self, opt):
        """Configure the dataset from ``opt``; nothing to do in the base class."""

    def __len__(self):
        return 0

In [None]:
def get_transform(opt):
    """Build the torchvision preprocessing pipeline described by ``opt``.

    ``opt.resize_or_crop`` selects the geometric step(s):
      * "resize_and_crop"      - bicubic resize to loadSize x loadSize, then random crop to fineSize
      * "crop"                 - random crop to fineSize
      * "scale_width"          - scale width to fineSize (height rounded up to a multiple of 4)
      * "scale_width_and_crop" - scale width to loadSize, then random crop to fineSize
      * "none"                 - only round width/height up to multiples of 4
    If ``opt.isTrain`` is set and ``opt.no_flip`` is not, a random horizontal
    flip follows. Every pipeline ends with ToTensor and a 3-channel Normalize
    with mean 0.5 and std 0.5.

    Raises:
        ValueError: if ``opt.resize_or_crop`` is not one of the modes above.
    """
    mode = opt.resize_or_crop
    pipeline = []

    if mode == "resize_and_crop":
        pipeline.extend([
            transforms.Resize([opt.loadSize, opt.loadSize], Image.BICUBIC),
            transforms.RandomCrop(opt.fineSize),
        ])
    elif mode == "crop":
        pipeline.append(transforms.RandomCrop(opt.fineSize))
    elif mode == "scale_width":
        # the size option is read when the transform runs, not when it is built
        pipeline.append(
            transforms.Lambda(lambda img: __scale_width(img, opt.fineSize))
        )
    elif mode == "scale_width_and_crop":
        pipeline.extend([
            transforms.Lambda(lambda img: __scale_width(img, opt.loadSize)),
            transforms.RandomCrop(opt.fineSize),
        ])
    elif mode == "none":
        pipeline.append(transforms.Lambda(lambda img: __adjust(img)))
    else:
        raise ValueError(
            "--resize_or_crop %s is not a valid option." % mode
        )

    if opt.isTrain and not opt.no_flip:
        pipeline.append(transforms.RandomHorizontalFlip())

    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(pipeline)


def __adjust(img):
    """Round a PIL image's width and height up to the next multiple of 4.

    An image whose sides are already multiples of 4 is returned as-is (same
    object). Otherwise a bicubic-resized copy is returned and the adjustment is
    reported through ``__print_size_warning``.
    """
    # Going through the generator may change odd image sizes and cause a size
    # mismatch later, so both sides must be divisible by this factor.
    mult = 4
    ow, oh = img.size
    if not (ow % mult or oh % mult):
        return img

    w = ((ow - 1) // mult + 1) * mult
    h = ((oh - 1) // mult + 1) * mult
    if (w, h) != (ow, oh):
        __print_size_warning(ow, oh, w, h)
    return img.resize((w, h), Image.BICUBIC)


def __scale_width(img, target_width):
    """Resize a PIL image to ``target_width`` pixels wide, keeping its aspect ratio.

    The scaled height is rounded up to a multiple of 4, and ``target_width``
    must itself be a multiple of 4 (AssertionError otherwise). The image is
    returned unchanged when it already has the target width and a height
    divisible by 4.
    """
    ow, oh = img.size

    # Generator resampling requires sizes divisible by this factor.
    mult = 4
    assert target_width % mult == 0, (
        "the target width needs to be multiple of %d." % mult
    )
    if ow == target_width and oh % mult == 0:
        return img

    scaled_h = int(target_width * oh / ow)
    rounded_h = ((scaled_h - 1) // mult + 1) * mult
    if rounded_h != scaled_h:
        __print_size_warning(target_width, scaled_h, target_width, rounded_h)
    return img.resize((target_width, rounded_h), Image.BICUBIC)


def __print_size_warning(ow, oh, w, h):
    """Print, on the first call only, a notice that an image size was adjusted.

    The "already printed" flag is stored as a ``has_printed`` attribute on this
    function object, so every later call returns silently.
    """
    if hasattr(__print_size_warning, "has_printed"):
        return
    message = (
        "The image size needs to be a multiple of 4. "
        "The loaded image size was (%d, %d), so it was adjusted to "
        "(%d, %d). This adjustment will be done to all images "
        "whose sizes are not multiples of 4" % (ow, oh, w, h)
    )
    print(message)
    __print_size_warning.has_printed = True

In [None]:
class BaseDataLoader(object):
    """Skeleton data-loader interface: stores the options, provides no data."""

    def __init__(self):
        """Nothing to set up here; configuration happens in ``initialize``."""

    def initialize(self, opt):
        """Remember the parsed options on ``self.opt``."""
        self.opt = opt

    def load_data(self):
        """Return the loadable data; the base class has none."""
        return None

In [None]:
class AffineGANDataset(BaseDataset):
    """Dataset of per-video frame folders for AffineGAN.

    ``initialize`` reads ``<dataroot>/<phase>/img/<video>/<frame files>`` and,
    unless ``opt.no_patch`` is set, the mask folder
    ``<dataroot>/<phase>/patch/<video>/`` with the same frame file names. Only
    entries of the ``img`` directory whose names contain no "." are treated as
    videos.

    Each item holds the first (sorted) frame of a video as ``"A"`` and the video
    folder path as ``"A_paths"``. In training mode it also holds
    ``opt.train_imagenum`` frames sampled with replacement from the remaining
    frames as ``"B_list"``, plus the matching masks (``"A_patch"``,
    ``"B_patch_list"``) when patches are enabled.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def pre_process_img(self, path, convertRGB, w_offset, h_offset, flip):
        """Load one frame (or mask) and return a resized, cropped, normalized tensor.

        Args:
            path: image file path. A missing file yields an all-zero tensor laid
                out like a loaded one.
            convertRGB: True for colour frames. False loads the image as grayscale,
                replicates it to 3 channels and binarizes it to 0.0 / 1.0.
            w_offset, h_offset: top-left corner of the fineSize x fineSize crop,
                taken after resizing to loadSize x loadSize.
            flip: mirror the crop horizontally.

        Returns:
            Float tensor of shape (3, fineSize, fineSize), or smaller if the crop
            window extends past the resized image.
        """
        if not os.path.exists(path):
            # Return a tensor in the same (C, H, W) layout as the normal path.
            # A NumPy HWC array (or a 2-D one for masks) here had a different type
            # and shape, so batching items where only some files were missing failed.
            return torch.zeros(
                (3, self.opt.fineSize, self.opt.fineSize), dtype=torch.float32
            )

        image = Image.open(path)
        if convertRGB:
            image = image.convert("RGB")
        else:
            image = image.convert("L")
        image = image.resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        image = np.array(image)
        if not convertRGB:
            # replicate the single grayscale channel so the 3-channel Normalize applies
            image = image[..., np.newaxis]
            image = np.tile(image, [1, 1, 3])
        image = transforms.ToTensor()(image)
        image = image[
            :,
            h_offset : h_offset + self.opt.fineSize,
            w_offset : w_offset + self.opt.fineSize,
        ]
        image = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(image)
        if not convertRGB:
            # binarize the mask; the 0.5 threshold is applied to the *normalized* values
            image[image < 0.5] = 0.0
            image[image >= 0.5] = 1.0
        if flip:
            # reverse the width axis (horizontal mirror)
            idx = torch.LongTensor([i for i in range(image.size(2) - 1, -1, -1)])
            image = image.index_select(2, idx)

        return image

    def initialize(self, opt):
        """Index the video folders under ``<dataroot>/<phase>``.

        Only ``opt.resize_or_crop == "resize_and_crop"`` is supported (asserted).
        """
        self.opt = opt
        self.dir_AB = os.path.join(opt.dataroot, opt.phase, "img")
        self.AB_paths = []

        if not self.opt.no_patch:
            self.dir_AB_patch = os.path.join(opt.dataroot, opt.phase, "patch")
            self.AB_patch = []

        video_names = sorted(f for f in os.listdir(self.dir_AB) if "." not in f)
        self.sample_num = len(video_names)

        for sample_name in video_names:
            self.AB_paths.append(os.path.join(self.dir_AB, sample_name))
            if not self.opt.no_patch:
                self.AB_patch.append(os.path.join(self.dir_AB_patch, sample_name))

        assert opt.resize_or_crop == "resize_and_crop"

    def __getitem__(self, index):
        # One crop offset and one flip decision per item, shared by every frame and
        # mask of the item so that all returned images stay spatially aligned.
        w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
        h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))

        AB_path = self.AB_paths[index]
        img_names = sorted(os.listdir(AB_path))

        flip = (not self.opt.no_flip) and random.random() < 0.5

        A_name = os.path.join(AB_path, img_names[0])
        A = self.pre_process_img(A_name, True, w_offset, h_offset, flip)
        ret_dict = {"A": A, "A_paths": AB_path}

        # Sampled frames and masks are only loaded in training mode; at test time
        # the item carries just the first frame.
        if self.opt.isTrain:
            use_patch = not self.opt.no_patch
            if use_patch:
                A_patch_name = os.path.join(self.AB_patch[index], img_names[0])
                ret_dict["A_patch"] = self.pre_process_img(
                    A_patch_name, False, w_offset, h_offset, flip
                )
                B_patch_list = []

            B_list = []
            # Reseed NumPy from OS entropy on every call (presumably so that
            # different processes / calls draw different frame sets).
            np.random.seed()
            img_sample = np.random.choice(
                range(1, len(img_names)), self.opt.train_imagenum, replace=True
            )
            for sample_image_idx in img_sample:
                frame_name = img_names[sample_image_idx]

                B_name = os.path.join(AB_path, frame_name)
                B_list.append(
                    self.pre_process_img(B_name, True, w_offset, h_offset, flip)
                )
                if use_patch:
                    B_patch_name = os.path.join(self.AB_patch[index], frame_name)
                    B_patch_list.append(
                        self.pre_process_img(B_patch_name, False, w_offset, h_offset, flip)
                    )

            ret_dict["B_list"] = B_list
            if use_patch:
                ret_dict["B_patch_list"] = B_patch_list

        return ret_dict

    def __len__(self):
        return self.sample_num

    def name(self):
        return "AffineGANDataset"

In [None]:
def find_dataset_using_name(dataset_name):
    """Map a ``--dataset_mode`` value (case-insensitive) to its dataset class.

    Only "affinegan" is registered in this notebook.

    Raises:
        NotImplementedError: for any other name. This replaces the former
            print-and-``exit(0)``, which terminated the process or kernel while
            reporting success. The exception type matches the unknown-name
            errors of ``define_G`` / ``define_D``.
    """
    if dataset_name.lower() == "affinegan":
        return AffineGANDataset

    raise NotImplementedError(
        "There should be a subclass of BaseDataset with class name that matches %s in lowercase."
        % dataset_name
    )

def get_option_setter(dataset_name):
    """Return the ``modify_commandline_options`` hook of the dataset called ``dataset_name``."""
    return find_dataset_using_name(dataset_name).modify_commandline_options


def create_dataset(opt):
    """Instantiate, initialize and return the dataset selected by ``opt.dataset_mode``."""
    dataset_cls = find_dataset_using_name(opt.dataset_mode)
    dataset = dataset_cls()
    dataset.initialize(opt)
    print("dataset [%s] was created" % dataset.name())
    return dataset


def CreateDataLoader(opt):
    """Create a ``CustomDatasetDataLoader``, print its name, and initialize it from ``opt``."""
    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt)
    return loader


class CustomDatasetDataLoader(BaseDataLoader):
    """Wrap the configured dataset in a ``torch.utils.data.DataLoader``.

    ``len()`` is the number of samples (not batches), capped at
    ``opt.max_dataset_size``. Iteration stops before any batch whose first
    sample index would reach that cap.
    """

    def name(self):
        return "CustomDatasetDataLoader"

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = create_dataset(opt)
        # Batches are assembled in the calling process (no worker processes).
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=0,
        )

    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for batch_idx, batch in enumerate(self.dataloader):
            first_sample = batch_idx * self.opt.batch_size
            if first_sample >= self.opt.max_dataset_size:
                return
            yield batch

# Models

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss against a constant "real" or "fake" target.

    Uses mean-squared error (LSGAN) when ``use_lsgan`` is True. Otherwise it
    uses binary cross-entropy, which expects ``input`` to already be
    probabilities in [0, 1], e.g. from a discriminator built with
    ``use_sigmoid=True``.

    Args:
        use_lsgan: select MSELoss (True) or BCELoss (False).
        target_real_label: target value for real samples.
        target_fake_label: target value for fake samples.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        # The labels are buffers so they move with .to(device). They are cast to
        # float so that integer labels (e.g. 1) still match float predictions.
        self.register_buffer('real_label', torch.tensor(float(target_real_label)))
        self.register_buffer('fake_label', torch.tensor(float(target_fake_label)))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label broadcast (as a view) to ``input``'s shape."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(input)

    def forward(self, input, target_is_real):
        """Loss between ``input`` and an all-real / all-fake target.

        Implemented as ``forward`` (not by overriding ``__call__``) so that
        ``nn.Module.__call__`` still runs registered hooks; calling
        ``criterion(pred, True)`` works exactly as before.
        """
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

In [None]:
class ResnetGenerator(nn.Module):
    """ResNet-style generator with two parallel encoders and one shared decoder.

    ``down_base`` and ``down_res`` have identical architectures: reflect-pad +
    7x7 conv, ``n_downsampling`` (= 2) stride-2 3x3 convs that double the
    channel count, then ``n_blocks`` ``ResnetBlock``s at ``ngf * 4`` channels.
    ``up_all`` mirrors the downsampling with transposed convs, then applies
    reflect-pad + 7x7 conv to ``output_nc`` channels and a tanh.
    ``forward`` blends the two encoder outputs before decoding.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6,
                 padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # Conv biases are enabled only when the norm layer is InstanceNorm2d
        # (either the class itself or a functools.partial wrapping it).
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        n_downsampling = 2

        # Stem of each encoder: 7x7 conv on a reflect-padded input keeps H x W.
        down_base = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        down_res = [nn.ReflectionPad2d(3),
                      nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                                bias=use_bias),
                      norm_layer(ngf),
                      nn.ReLU(True)]
        # Downsampling: each step halves H and W and doubles the channels.
        # Layers of the two encoders are created interleaved; the order matters
        # for the Sequential indices (checkpoint keys).
        for i in range(n_downsampling):
            mult = 2 ** i
            down_base += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
            down_res += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                     stride=2, padding=1, bias=use_bias),
                           norm_layer(ngf * mult * 2),
                           nn.ReLU(True)]

        # Residual blocks at the bottleneck width (ngf * 2 ** n_downsampling).
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            down_base += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
                                  use_bias=use_bias)]
            down_res += [
                ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
                            use_bias=use_bias)]
        # Shared decoder: each transposed conv doubles H and W and halves the channels.
        up_all = []
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            up_all += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        up_all += [nn.ReflectionPad2d(3)]
        up_all += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        up_all += [nn.Tanh()]

        self.down_base = nn.Sequential(*down_base)
        self.down_res = nn.Sequential(*down_res)
        self.up_all = nn.Sequential(*up_all)

    def forward(self, input, base_stage, temporal_stage, isTrain):
        """Encode ``input`` with both encoders, blend, and decode.

        Args:
            input: image batch fed to both encoders.
            base_stage, temporal_stage: blend weights multiplied elementwise with
                the two encoder outputs (scalars or broadcastable tensors --
                TODO confirm with the caller).
            isTrain: accepted for interface parity with UnetGenerator; unused here.

        Returns:
            (tanh output image, blended bottleneck feature).
        """
        down_base = self.down_base(input)
        down_res = self.down_res(input)
        feature = base_stage * down_base + temporal_stage * down_res
        return self.up_all(feature), feature


class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    F is: [pad] conv3x3 -> norm -> ReLU -> [Dropout(0.5)] -> [pad] conv3x3 -> norm.
    ``padding_type`` controls the 1-pixel border: 'reflect' and 'replicate' add an
    explicit padding layer before each conv, while 'zero' uses the convolution's
    own zero padding. Any other value raises NotImplementedError.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Return F as an ``nn.Sequential`` (channel count ``dim`` is preserved)."""

        def padded_conv():
            # [optional padding layer, 3x3 conv] for one convolution of F
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1),
                        nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1),
                        nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
            if padding_type == 'zero':
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        layers = padded_conv() + [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers += padded_conv() + [norm_layer(dim)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)

In [None]:
class UnetGenerator(nn.Module):
    """U-Net generator built from nested ``UnetSkipConnectionBlock`` levels.

    Levels are created from the innermost outward: an innermost ngf*8 level,
    ``num_downs - 5`` extra ngf*8 levels (the only ones that get ``use_dropout``),
    then levels with (outer, inner) widths (4ngf, 8ngf), (2ngf, 4ngf) and
    (ngf, 2ngf), and finally the outermost level mapping ``input_nc`` to
    ``output_nc`` channels. Every level halves H and W on the way down.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
        super(UnetGenerator, self).__init__()

        # bottleneck
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
                                        innermost=True, gpu_ids=gpu_ids)
        # extra full-width levels
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
        # progressively narrower levels towards the image
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, gpu_ids=gpu_ids)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True,
                                             norm_layer=norm_layer, gpu_ids=gpu_ids)

    def forward(self, input, base_stage, temporal_stage, isTrain):
        """Run the U-Net with ``input`` as both streams; returns (image, innermost feature)."""
        return self.model(input, input, base_stage, temporal_stage, isTrain)

class UnetSkipConnectionBlock(nn.Module):
    """One level of the two-stream U-Net used by ``UnetGenerator``.

    Each level owns two parallel downsampling paths (``down_base`` for the base
    stream, ``down_res`` for the residual/temporal stream) and one upsampling
    path ``up_all``. The streams are blended as
    ``base_stage * base + temporal_stage * res`` at every skip connection and at
    the innermost level. ``sub`` is the next inner level (None for the innermost).

    Args:
        outer_nc: channels produced by ``up_all`` (also the incoming channels
            unless ``input_nc`` is given).
        inner_nc: channels produced by the down convolutions.
        input_nc: channels of the incoming maps, if different from ``outer_nc``.
        submodule: the nested inner level.
        outermost / innermost: position of this level in the U-Net.
        norm_layer: normalization factory; conv biases are used only with InstanceNorm2d.
        use_dropout: append Dropout(0.5) to ``up_all`` (middle levels only).
        gpu_ids: only used to set ``self.device``.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 gpu_ids=[]):
        super(UnetSkipConnectionBlock, self).__init__()
        # Kept for compatibility; forward() follows the input tensors' device instead.
        self.device = torch.device('cuda:{}'.format(gpu_ids[0])) if gpu_ids else torch.device('cpu')
        self.outermost = outermost
        self.innermost = innermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        # 4x4 stride-2 convs halve H and W. The LeakyReLUs are in-place.
        downconv_base = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                                  stride=2, padding=1, bias=use_bias)
        downrelu_base = nn.LeakyReLU(0.2, True)
        downnorm_base = norm_layer(inner_nc)
        downconv_res = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                                 stride=2, padding=1, bias=use_bias)
        downrelu_res = nn.LeakyReLU(0.2, True)
        downnorm_res = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # input comes from the sub-level's concatenation, hence inner_nc * 2
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)

            down_base = [downconv_base]
            down_res = [downconv_res]
            up_all = [uprelu, upconv, nn.Tanh()]

        elif innermost:
            # the bottleneck has no sub-level, so upsampling starts from inner_nc
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)

            down_base = [downrelu_base, downconv_base]
            down_res = [downrelu_res, downconv_res]
            up_all = [uprelu, upconv, upnorm]

        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)

            down_base = [downrelu_base, downconv_base, downnorm_base]
            down_res = [downrelu_res, downconv_res, downnorm_res]
            up = [uprelu, upconv, upnorm]
            up_all = up + [nn.Dropout(0.5)] if use_dropout else up

        self.down_base = nn.Sequential(*down_base)
        self.down_res = nn.Sequential(*down_res)
        self.up_all = nn.Sequential(*up_all)
        self.sub = submodule

    @staticmethod
    def _perturb(tensor, isTrain):
        """Add N(0, 0.01**2) noise to ``tensor`` when training; identity otherwise."""
        if not isTrain:
            return tensor
        # randn_like allocates on tensor's own device and dtype. The former
        # torch.cuda.FloatTensor(shape) always used the current CUDA device
        # whenever CUDA was available (breaking CPU / non-default-GPU inputs) and
        # was always float32.
        return tensor + 0.01 * torch.randn_like(tensor)

    def forward(self, x_base, x_res, base_stage, temporal_stage, isTrain):
        """Run this level on the two streams.

        Args:
            x_base, x_res: incoming feature maps of the base and residual streams.
                For non-outermost levels the leading in-place LeakyReLU of
                ``down_base`` / ``down_res`` overwrites these tensors.
            base_stage, temporal_stage: blend weights multiplied elementwise with
                the feature maps (scalars or broadcastable tensors -- TODO confirm
                with the caller).
            isTrain: when true, small Gaussian noise is added to the blended maps.

        Returns:
            (output, feature). For the outermost level ``output`` is the tanh
            image. Otherwise it is the blended skip map concatenated (dim 1) with
            the upsampled inner result. ``feature`` is the innermost blended map.
        """
        if self.outermost:
            down_base = self.down_base(x_base)
            down_res = self.down_res(x_res)
            sub, feature = self.sub(down_base, down_res, base_stage, temporal_stage, isTrain)
            return self.up_all(sub), feature

        # The skip map must be computed BEFORE down_base / down_res run, because
        # their first layer is an in-place LeakyReLU that overwrites x_base / x_res.
        skip = self._perturb(base_stage * x_base + temporal_stage * x_res, isTrain)
        down_base = self.down_base(x_base)
        down_res = self.down_res(x_res)

        if self.innermost:
            down = self._perturb(base_stage * down_base + temporal_stage * down_res, isTrain)
            return torch.cat([skip, self.up_all(down)], 1), down

        sub, feature = self.sub(down_base, down_res, base_stage, temporal_stage, isTrain)
        return torch.cat([skip, self.up_all(sub)], 1), feature

In [None]:
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator producing a map of per-patch scores.

    Layout: 4x4 stride-2 conv + LeakyReLU; ``n_layers - 1`` blocks of 4x4 stride-2
    conv / norm / LeakyReLU whose width doubles (capped at ``8 * ndf``); one
    4x4 stride-1 conv / norm / LeakyReLU block; a final 4x4 stride-1 conv to one
    channel, optionally followed by a sigmoid.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        # biases only when the (possibly partial-wrapped) norm is InstanceNorm2d
        base_norm = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        kernel, pad = 4, 1

        def conv_norm_act(in_mult, out_mult, stride):
            # conv -> norm -> LeakyReLU(0.2) between widths ndf*in_mult and ndf*out_mult
            return [
                nn.Conv2d(ndf * in_mult, ndf * out_mult,
                          kernel_size=kernel, stride=stride, padding=pad, bias=use_bias),
                norm_layer(ndf * out_mult),
                nn.LeakyReLU(0.2, True),
            ]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]
        prev_mult = 1
        for depth in range(1, n_layers):
            cur_mult = min(2 ** depth, 8)
            layers += conv_norm_act(prev_mult, cur_mult, stride=2)
            prev_mult = cur_mult

        final_mult = min(2 ** n_layers, 8)
        layers += conv_norm_act(prev_mult, final_mult, stride=1)
        layers.append(nn.Conv2d(ndf * final_mult, 1, kernel_size=kernel, stride=1, padding=pad))
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)


In [None]:
class PixelDiscriminator(nn.Module):
    """Per-pixel (1x1 convolution) discriminator.

    conv1x1(input_nc -> ndf) -> LeakyReLU -> conv1x1(ndf -> 2ndf) -> norm ->
    LeakyReLU -> conv1x1(2ndf -> 1), optionally followed by a sigmoid.
    """

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(PixelDiscriminator, self).__init__()
        # biases on the normalized convs only when the norm is InstanceNorm2d
        base_norm = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
        ]
        layers += [
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
        ]
        layers.append(nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias))
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.net = nn.Sequential(*layers)

    def forward(self, input):
        return self.net(input)

In [None]:
class AlphaDiscriminator(nn.Module):
    """Three-layer MLP critic over ``input_nc``-dimensional vectors.

    Linear(input_nc, input_nc) -> LeakyReLU -> Linear(input_nc, input_nc) ->
    LeakyReLU -> Linear(input_nc, 1), optionally followed by a sigmoid.
    ``norm_layer`` is accepted but not used by this fully connected network
    (presumably kept for signature parity with the other discriminators).
    """

    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(AlphaDiscriminator, self).__init__()
        layers = []
        for _ in range(2):
            layers += [nn.Linear(input_nc, input_nc, bias=True), nn.LeakyReLU(0.2, True)]
        layers.append(nn.Linear(input_nc, 1, bias=True))
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        return self.net(input)

In [None]:
def get_norm_layer(norm_type='instance'):
    """Return a factory ``f(num_features) -> nn.Module`` for the requested normalization.

    Args:
        norm_type: 'batch' (BatchNorm2d with affine parameters), 'instance'
            (InstanceNorm2d without affine parameters or running stats),
            'group' (GroupNorm with 64 groups) or 'none' (identity).

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'group':
        # Callers pass the channel count positionally (norm_layer(C)). nn.GroupNorm's
        # first positional parameter is num_groups, so the former
        # partial(nn.GroupNorm, num_groups=64)(C) raised "got multiple values for
        # argument 'num_groups'". Bind num_channels explicitly instead.
        def norm_layer(num_channels):
            return nn.GroupNorm(num_groups=64, num_channels=num_channels)
    elif norm_type == 'none':
        # Callers always call norm_layer(C); returning None made that a TypeError.
        def norm_layer(num_features):
            return nn.Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def init_weights(net, init_type='normal', gain=0.02):
    """Re-initialize conv/linear and BatchNorm2d parameters of every submodule of ``net``.

    Modules whose class name contains 'Conv' or 'Linear' (and that have a
    ``weight``) get their weight drawn according to ``init_type``:
      'normal'     - normal, mean 0, std ``gain``
      'xavier'     - Xavier-normal with ``gain``
      'kaiming'    - Kaiming-normal (a=0, fan_in); ``gain`` is ignored
      'orthogonal' - orthogonal with ``gain``
    and their bias (if any) set to 0. BatchNorm2d weights are drawn from a
    normal with mean 1 and std ``gain``, and their biases set to 0. Other
    modules are untouched.

    Raises:
        NotImplementedError: if ``init_type`` is unknown (raised when the first
            conv/linear module is visited).
    """
    def _init_module(module):
        kind = module.__class__.__name__
        is_conv_or_linear = 'Conv' in kind or 'Linear' in kind
        if hasattr(module, 'weight') and is_conv_or_linear:
            weight = module.weight.data
            if init_type == 'normal':
                init.normal_(weight, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(module, 'bias', None) is not None:
                init.constant_(module.bias.data, 0.0)
        elif 'BatchNorm2d' in kind:
            init.normal_(module.weight.data, 1.0, gain)
            init.constant_(module.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(_init_module)
    
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Optionally place ``net`` on GPU(s), then initialize its weights.

    With a non-empty ``gpu_ids`` (CUDA must be available), the network is moved
    to ``gpu_ids[0]`` and wrapped in ``torch.nn.DataParallel`` over those devices.
    Returns the (possibly wrapped) network.
    """
    if len(gpu_ids) != 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type, gain=init_gain)
    return net


def get_scheduler(optimizer, opt):
    """Create the learning-rate scheduler selected by ``opt.lr_policy``.

    Policies:
      'lambda'  - LR factor 1.0 while ``epoch + 1 + opt.epoch_count <= opt.niter``,
                  then decreasing by ``1 / (opt.niter_decay + 1)`` per epoch;
      'step'    - multiply by 0.1 every ``opt.lr_decay_iters`` epochs;
      'plateau' - ReduceLROnPlateau(mode='min', factor=0.2, threshold=0.01, patience=5);
      'cosine'  - cosine annealing over ``opt.niter`` epochs down to 0.

    Raises:
        NotImplementedError: for an unknown policy. It was previously *returned*
            rather than raised, so callers received an exception object in place
            of a scheduler.
    """
    policy = opt.lr_policy
    if policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)
    return scheduler

In [None]:
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02,
             gpu_ids=[]):
    """Build, place and initialize the generator named by ``netG``.

    ``netG`` is one of 'resnet_9blocks', 'resnet_6blocks' (ResnetGenerator with
    9 / 6 residual blocks) or 'unet_64', 'unet_128', 'unet_256' (UnetGenerator
    with 6 / 7 / 8 downsamplings). The network is returned by ``init_net``.

    Raises:
        NotImplementedError: for an unknown ``netG`` (or, via get_norm_layer,
            an unknown ``norm``).
    """
    norm_layer = get_norm_layer(norm_type=norm)

    resnet_variants = (('resnet_9blocks', 9), ('resnet_6blocks', 6))
    unet_variants = (('unet_64', 6), ('unet_128', 7), ('unet_256', 8))

    for variant, n_blocks in resnet_variants:
        if netG == variant:
            generator = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer,
                                        use_dropout=use_dropout, n_blocks=n_blocks)
            return init_net(generator, init_type, init_gain, gpu_ids)

    for variant, num_downs in unet_variants:
        if netG == variant:
            generator = UnetGenerator(input_nc, output_nc, num_downs, ngf, norm_layer=norm_layer,
                                      use_dropout=use_dropout, gpu_ids=gpu_ids)
            return init_net(generator, init_type, init_gain, gpu_ids)

    raise NotImplementedError('Generator model name [%s] is not recognized' % netG)


def define_D(input_nc, ndf, netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Build, place and initialize the discriminator named by ``netD``.

    ``netD`` is 'basic' (3-layer NLayerDiscriminator), 'n_layers'
    (NLayerDiscriminator with ``n_layers_D`` layers) or 'pixel'
    (PixelDiscriminator). The network is returned by ``init_net``.

    Raises:
        NotImplementedError: for an unknown ``netD`` (or, via get_norm_layer,
            an unknown ``norm``).
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    elif netD == 'n_layers':
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    elif netD == 'pixel':
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    else:
        # Report the requested name; the message previously formatted `net`,
        # which is always None at this point.
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)


def define_D_alpha(input_nc, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Build, place and initialize an ``AlphaDiscriminator`` over ``input_nc`` features."""
    return init_net(
        AlphaDiscriminator(input_nc, norm_layer=get_norm_layer(norm_type=norm), use_sigmoid=use_sigmoid),
        init_type,
        init_gain,
        gpu_ids,
    )

def print_network(net):
    """Print ``net``'s module tree followed by its total number of parameters."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)

In [None]:
class BaseModel():
    """Shared plumbing for the models in this notebook.

    Subclasses populate ``loss_names``, ``model_names``, ``visual_names`` and
    ``optimizers`` in ``initialize`` and store each network as an attribute
    named ``'net' + name`` (e.g. ``netG``). The helpers below iterate over
    those lists to save/load, switch train/eval mode, and report losses and
    visuals.
    """

    # modify parser to add command line options,
    # and also change the default values if needed
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Return ``parser`` unchanged; subclasses override to add options."""
        return parser

    def name(self):
        return 'BaseModel'

    def test_all_frame(self):
        pass

    def initialize(self, opt):
        """Store options and reset the bookkeeping lists subclasses fill in."""
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # First listed GPU if any, else CPU.
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # cudnn autotuning only helps when input sizes are fixed; 'scale_width'
        # presumably yields variable sizes, so benchmarking is skipped there.
        if opt.resize_or_crop != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []
        self.optimizers = []

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # load and print networks; create schedulers
    def setup(self, opt, parser=None):
        """Create LR schedulers (training), load checkpoints if needed, print nets.

        Checkpoints for ``opt.epoch`` are loaded at test time or when
        ``opt.continue_train`` is set. ``parser`` is accepted but unused.
        """
        if self.isTrain:
            self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.epoch)
        self.print_networks(opt.verbose)

    # make models eval mode during test time
    def eval(self):
        """Put every network listed in ``model_names`` into eval mode."""
        for name in self.model_names:
            if isinstance(name, str):
                getattr(self, 'net' + name).eval()

    def train(self):
        """Put every network listed in ``model_names`` into train mode."""
        for name in self.model_names:
            if isinstance(name, str):
                getattr(self, 'net' + name).train()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self, metric=None):
        """Step every scheduler once and return the first optimizer's new lr.

        ``ReduceLROnPlateau.step`` requires a metric argument; calling it with
        none raises ``TypeError``. For those schedulers the value used is
        ``metric`` if given, otherwise ``self.metric`` if a subclass set it,
        otherwise 0 (a constant metric makes the plateau scheduler decay on a
        fixed ``patience`` cadence). Pass a real validation metric for genuine
        plateau detection.

        Returns:
            The learning rate of the first parameter group of the first
            optimizer after stepping.
        """
        for scheduler in self.schedulers:
            if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
                value = metric if metric is not None else getattr(self, 'metric', 0)
                scheduler.step(value)
            else:
                scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        return lr

    # return visualization images. train.py will display these images, and save the images to a html
    def get_current_visuals(self):
        """Collect ``visual_names`` attributes into an OrderedDict.

        List-valued attributes are flattened into ``name0``, ``name1``, ...
        """
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                attr = getattr(self, name)
                if isinstance(attr, list):
                    for i, item in enumerate(attr):
                        visual_ret[name + str(i)] = item
                else:
                    visual_ret[name] = attr
        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        """Collect ``self.loss_<name>`` for each entry of ``loss_names``.

        Lists are returned as-is; anything else is converted with ``float``,
        which works for both scalar tensors and Python numbers.
        """
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                value = getattr(self, 'loss_' + name)
                errors_ret[name] = value if isinstance(value, list) else float(value)
        return errors_ret

    # save models to the disk
    def save_networks(self, epoch):
        """Save each network's state_dict to ``<save_dir>/<epoch>_net_<name>.pth``.

        ``DataParallel`` wrappers are always unwrapped before saving so the
        saved keys match what ``load_networks`` (which also unwraps) expects.
        The network is moved to CPU for saving and back to the first GPU
        afterwards when GPUs are in use.
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                module = net.module if isinstance(net, torch.nn.DataParallel) else net

                torch.save(module.cpu().state_dict(), save_path)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    net.cuda(self.gpu_ids[0])

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Drop InstanceNorm buffers from old checkpoints that the module lacks.

        Walks ``keys`` (a dotted key split into parts) down the module tree;
        at the leaf, removes ``running_mean``/``running_var`` entries when the
        InstanceNorm module has them set to None, and always removes
        ``num_batches_tracked`` for InstanceNorm modules.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
               (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # load models from the disk
    def load_networks(self, epoch):
        """Load ``<save_dir>/<epoch>_net_<name>.pth`` into each listed network."""
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # copy keys: the loop mutates state_dict
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    # print network information
    def print_networks(self, verbose):
        """Print each network's parameter count (in millions), plus its repr if verbose."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = sum(param.numel() for param in net.parameters())
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    # set requires_grad=False to avoid computation
    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on all parameters of one network or a list of them.

        ``None`` entries are skipped.
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

In [None]:
class AffineGANModel(BaseModel):
    """AffineGAN: animate one input image by scaling a generator feature offset.

    The generator is called as ``netG(image, a, b, is_train)`` and returns
    ``(image, feature)``. In this class ``(1.0, 0.0)`` is used to obtain a
    reconstruction and its feature, and ``(0.0, 1.0)`` to obtain the residual
    feature used as the interpolation direction. NOTE(review): the precise
    meaning of ``a``/``b`` lives in the generator definition, outside this
    cell; confirm there.

    Training uses up to three discriminators: ``netD`` on (input, frame)
    pairs, ``netD_alpha`` on vectors of per-frame alphas, and optionally
    ``netD_Patch`` on patch-multiplied images.
    """

    def name(self):
        return "AffineGANModel"

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Set AffineGAN defaults on ``parser`` and add ``--lambda_L1`` for training."""
        # AffineGAN uses instance norm
        parser.set_defaults(pool_size=0, no_lsgan=False, norm="instance")
        parser.set_defaults(dataset_mode="affineGAN")
        parser.set_defaults(netG="unet_256")
        if is_train:
            parser.add_argument(
                "--lambda_L1", type=float, default=100.0, help="weight for L1 loss"
            )

        return parser

    def initialize(self, opt):
        """Build networks, losses and optimizers according to ``opt``.

        At test time only the generator is created and listed for loading.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # Losses reported by BaseModel.get_current_losses (read as self.loss_<name>).
        self.loss_names = [
            "G_GAN_D1",
            "G_L1",
            "D_real",
            "D_fake",
            "G_GAN_D_alpha",
            "D_alpha",
            "img_recons",
        ]
        if not opt.no_patch:
            self.loss_names += [
                "G_GAN_patch",
                "D_real_patch",
                "D_fake_patch",
                "D_patch",
            ]

        if self.isTrain:
            self.visual_names = ["input_A", "fake_B", "real_B"]
            self.model_names = ["G", "D", "D_alpha"]
            if not opt.no_patch:
                self.model_names.append("D_Patch")

        else:  # during test time, only load G
            self.visual_names = ["input_A"] + ["fake_B_list"]
            self.model_names = ["G"]
        # load/define networks
        self.netG = define_G(
            opt.input_nc,
            opt.input_nc,
            opt.ngf,
            opt.netG,
            opt.norm,
            not opt.no_dropout,
            opt.init_type,
            opt.init_gain,
            self.gpu_ids,
        )

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = define_D(
                opt.input_nc + opt.output_nc,
                opt.ndf,
                opt.netD,
                opt.n_layers_D,
                opt.norm,
                use_sigmoid,
                opt.init_type,
                opt.init_gain,
                self.gpu_ids,
            )
            # Input width is the number of alphas per sample (one per target frame).
            self.netD_alpha = define_D_alpha(
                opt.train_imagenum,
                opt.norm,
                use_sigmoid,
                opt.init_type,
                opt.init_gain,
                self.gpu_ids,
            )

            # define loss functions
            self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(
                self.device
            )
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(
                self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
            )
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
            )
            self.optimizer_D_Alpha = torch.optim.Adam(
                self.netD_alpha.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
            )

            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.optimizers.append(self.optimizer_D_Alpha)

            if not opt.no_patch:
                self.netD_Patch = define_D(
                    opt.input_nc + opt.output_nc,
                    opt.ndf,
                    opt.netD,
                    opt.n_layers_D,
                    opt.norm,
                    use_sigmoid,
                    opt.init_type,
                    opt.init_gain,
                    self.gpu_ids,
                )
                self.optimizer_D_Patch = torch.optim.Adam(
                    self.netD_Patch.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
                )
                self.optimizer_D_Patch.param_groups  # noqa: B018 (no-op access kept out)
                self.optimizers.append(self.optimizer_D_Patch)

    def get_alpha(self, f_t0, f_t, f_t0_res):
        """Return ``|<f_t - f_t0, f_t0_res>| / (||f_t0_res|| + 1e-6)``.

        This is the magnitude of the feature change projected onto the
        residual direction (divided by the direction's norm once; the 1e-6
        guards against a zero-norm direction).
        """
        return torch.abs(torch.sum((f_t - f_t0) * f_t0_res)) / (
            f_t0_res.norm() + 1e-6
        )

    def set_input(self, input):
        """Move a data-loader batch to the device.

        Reads ``input["A"]`` and ``input["A_paths"]`` always; during training
        also ``input["B_list"]`` and, unless ``no_patch``, ``input["A_patch"]``
        and ``input["B_patch_list"]`` (first ``train_imagenum`` entries).
        """
        self.input_A = input["A"].to(self.device)
        self.image_paths = input["A_paths"]
        if self.isTrain:
            self.input_B_list = []
            self.input_B_patch_list = []
            if not self.opt.no_patch:
                # One A-side patch tensor shared by every target frame; it was
                # previously re-assigned on each loop iteration.
                self.input_A_patch = input["A_patch"].to(self.device)
            for img_idx in range(self.opt.train_imagenum):
                self.input_B_list.append(input["B_list"][img_idx].to(self.device))
                if not self.opt.no_patch:
                    self.input_B_patch_list.append(
                        input["B_patch_list"][img_idx].to(self.device)
                    )

    def forward(self):
        """Measure an alpha per target frame and synthesise the matching fake frame.

        Produces ``fake_B_list``/``real_B_list``, the reconstructions in
        ``B_reconstruct_img_list``, ``alpha_list_torch`` of shape
        ``(1, train_imagenum)`` and a uniform random ``alpha_list_sample`` of
        the same shape used as "real" input for ``netD_alpha``.
        """
        if not self.opt.no_patch and self.isTrain:
            self.input_A_img_patch = self.input_A * self.input_A_patch

        self.t0_reconstruct, f_t0 = self.netG(self.input_A, 1.0, 0.0, self.isTrain)

        self.real_B_list = []
        self.fake_B_list = []
        self.B_reconstruct_img_list = []
        if not self.opt.no_patch:
            self.fake_B_img_patch_list = []
            self.real_B_img_patch_list = []

        alpha_list_torch = []

        _, f_t0_res = self.netG(self.input_A, 0.0, 1.0, self.isTrain)
        f_t0_res = torch.squeeze(f_t0_res)
        f_t0 = torch.squeeze(f_t0)

        for img_idx in range(self.opt.train_imagenum):
            real_B = self.input_B_list[img_idx]

            t_reconstruct, f_t = self.netG(real_B, 1.0, 0.0, self.isTrain)
            self.B_reconstruct_img_list.append(t_reconstruct)
            f_t = torch.squeeze(f_t)
            alpha = self.get_alpha(f_t0, f_t, f_t0_res)
            alpha_list_torch.append(alpha.view(1))

            # float() turns alpha into a Python number, so no gradient from
            # fake_B flows back through alpha.
            fake_B, _ = self.netG(self.input_A, 1.0, float(alpha), self.isTrain)

            self.real_B_list.append(real_B)
            self.fake_B_list.append(fake_B)

            if not self.opt.no_patch:
                real_B_patch = self.input_B_patch_list[img_idx]
                real_B_img_patch = real_B * real_B_patch
                fake_B_img_patch = fake_B * real_B_patch
                self.fake_B_img_patch_list.append(fake_B_img_patch)
                self.real_B_img_patch_list.append(real_B_img_patch)

        # Each alpha is shape (1,); stacking on dim=1 gives (1, train_imagenum).
        self.alpha_list_torch = torch.stack(alpha_list_torch, dim=1)

        self.alpha_list_sample = torch.rand(1, self.opt.train_imagenum).to(self.device)

        self.fake_B = self.fake_B_list[0]
        self.real_B = self.real_B_list[0]

    def backward_D(self):
        """Backprop the pair discriminator loss, averaged over target frames.

        ``loss_D_fake``/``loss_D_real`` keep the values of the last frame only.
        """
        loss_D = 0
        for img_idx in range(self.opt.train_imagenum):
            # Fake; detach so no gradient reaches the generator.
            fake_AB = torch.cat((self.input_A, self.fake_B_list[img_idx]), 1)
            pred_fake = self.netD(fake_AB.detach())
            self.loss_D_fake = self.criterionGAN(pred_fake, False)

            # Real
            real_AB = torch.cat((self.input_A, self.real_B_list[img_idx]), 1)
            pred_real = self.netD(real_AB.detach())
            self.loss_D_real = self.criterionGAN(pred_real, True)

            # Combined loss
            loss_D += (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D = loss_D / (self.opt.train_imagenum + 0.0)
        self.loss_D.backward()

    def backward_D_patch(self):
        """Backprop the patch discriminator loss, averaged over target frames.

        ``loss_D_fake_patch``/``loss_D_real_patch`` keep the last frame's values.
        """
        loss_D_patch = 0
        for img_idx in range(self.opt.train_imagenum):
            # Fake; detach so no gradient reaches the generator.
            fake_AB_patch = torch.cat(
                (self.input_A_img_patch, self.fake_B_img_patch_list[img_idx]), 1
            )
            pred_fake_patch = self.netD_Patch(fake_AB_patch.detach())
            self.loss_D_fake_patch = self.criterionGAN(pred_fake_patch, False)

            # Real
            real_AB_patch = torch.cat(
                (self.input_A_img_patch, self.real_B_img_patch_list[img_idx]), 1
            )
            pred_real_patch = self.netD_Patch(real_AB_patch.detach())
            self.loss_D_real_patch = self.criterionGAN(pred_real_patch, True)

            # Combined loss
            loss_D_patch += (self.loss_D_fake_patch + self.loss_D_real_patch) * 0.5
        self.loss_D_patch = loss_D_patch / (self.opt.train_imagenum + 0.0)
        self.loss_D_patch.backward()

    def backward_D_alpha(self):
        """Backprop the alpha discriminator loss, weighted by ``opt.lambda_A``.

        Measured alphas are the "fake" samples and uniform random vectors
        (``alpha_list_sample``) the "real" ones; both are detached.
        """
        pred_fake_alpha = self.netD_alpha(self.alpha_list_torch.detach())
        pred_true_alpha = self.netD_alpha(self.alpha_list_sample.detach())

        self.loss_D_fake_alpha = self.criterionGAN(pred_fake_alpha, False)
        self.loss_D_real_alpha = self.criterionGAN(pred_true_alpha, True)
        # Combined loss
        self.loss_D_alpha = (
            (self.loss_D_fake_alpha + self.loss_D_real_alpha) * 0.5 * self.opt.lambda_A
        )
        self.loss_D_alpha.backward()

    def backward_G(self):
        """Backprop the generator loss.

        Per frame: adversarial loss against ``netD`` (and ``netD_Patch``
        unless ``no_patch``), ``lambda_L1``-weighted L1 to the real frame, and
        a 10x-weighted L1 reconstruction loss. The summed per-frame terms plus
        the alpha adversarial term are divided by ``train_imagenum``. The
        ``loss_*`` attributes hold per-frame averages for reporting.
        """
        loss_G = 0
        loss_G_GAN_D1 = 0
        loss_G_GAN_patch = 0
        loss_G_L1 = 0
        img_recons_loss = 0
        num_frames = self.opt.train_imagenum + 0.0

        pred_fake_alpha = self.netD_alpha(self.alpha_list_torch)
        loss_G_GAN_D_alpha = (
            self.criterionGAN(pred_fake_alpha, True) * self.opt.lambda_A
        )

        for img_idx in range(self.opt.train_imagenum):
            # First, G(A) should fool the discriminator
            fake_AB = torch.cat((self.input_A, self.fake_B_list[img_idx]), 1)
            pred_fake = self.netD(fake_AB)
            current_loss_G_GAN_D1 = self.criterionGAN(pred_fake, True)
            loss_G_GAN_D1 += current_loss_G_GAN_D1 / num_frames

            # First_2, G(A) should fool the patch discriminator
            if not self.opt.no_patch:
                fake_AB_patch = torch.cat(
                    (self.input_A_img_patch, self.fake_B_img_patch_list[img_idx]), 1
                )
                pred_fake_patch = self.netD_Patch(fake_AB_patch)
                current_loss_G_GAN_patch = self.criterionGAN(pred_fake_patch, True)
                loss_G_GAN_patch += current_loss_G_GAN_patch / num_frames
                loss_G += current_loss_G_GAN_patch

            # Second, G(A) = B
            current_loss_G_L1 = (
                self.criterionL1(self.fake_B_list[img_idx], self.real_B_list[img_idx])
                * self.opt.lambda_L1
            )
            loss_G_L1 += current_loss_G_L1 / num_frames
            current_img_recons_loss = (
                self.criterionL1(
                    self.B_reconstruct_img_list[img_idx], self.real_B_list[img_idx]
                )
                * 10.0
            )
            # Accumulate like the other reported losses; this previously used
            # `=` and so reported only the last frame's value divided by N.
            img_recons_loss += current_img_recons_loss / num_frames
            loss_G += (
                current_loss_G_GAN_D1 + current_loss_G_L1 + current_img_recons_loss
            )

        loss_G += loss_G_GAN_D_alpha
        loss_G /= num_frames
        self.loss_G = loss_G
        self.loss_G_GAN_D1 = loss_G_GAN_D1
        self.loss_G_GAN_patch = loss_G_GAN_patch
        self.loss_G_L1 = loss_G_L1
        self.loss_G_GAN_D_alpha = loss_G_GAN_D_alpha
        self.loss_img_recons = img_recons_loss
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: forward, update each discriminator, then update G.

        Discriminators are frozen (``requires_grad=False``) before the
        generator step so G's backward pass does not accumulate their grads.
        """
        self.forward()

        self.set_requires_grad(self.netD_alpha, True)
        self.optimizer_D_Alpha.zero_grad()
        self.backward_D_alpha()
        self.optimizer_D_Alpha.step()

        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        if not self.opt.no_patch:
            self.set_requires_grad(self.netD_Patch, True)
            self.optimizer_D_Patch.zero_grad()
            self.backward_D_patch()
            self.optimizer_D_Patch.step()
            self.set_requires_grad(self.netD_Patch, False)

        self.set_requires_grad(self.netD, False)
        self.set_requires_grad(self.netD_alpha, False)

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    # no backprop gradients
    def test(self):
        """Generate ``int(1 / opt.interval)`` frames with alpha = interval * i.

        Frame 0 uses alpha 0. Results are stored in ``fake_B_list``.
        """
        with torch.no_grad():
            self.fake_B_list = []
            for i in range(int(1.0 / self.opt.interval)):
                self.fake_B_list.append(
                    self.netG(self.input_A, 1.0, self.opt.interval * i, self.isTrain)[0]
                )

In [None]:
def find_model_using_name(model_name):
    """Return the model class registered under ``model_name`` (case-insensitive).

    Only ``"affinegan"`` is currently recognised.

    Raises:
        NotImplementedError: if no model matches. Previously this called
            ``exit(0)``, which terminates the interpreter (or notebook
            session) while reporting success.
    """
    model = None
    if model_name.lower() == "affinegan":
        model = AffineGANModel

    if model is None:
        message = (
            "There should be a subclass of BaseModel with class name that matches %s in lowercase."
            % model_name
        )
        print(message)
        raise NotImplementedError(message)

    return model

def get_option_setter(model_name):
    """Return the ``modify_commandline_options`` hook of the named model class."""
    return find_model_using_name(model_name).modify_commandline_options


def create_model(opt):
    """Instantiate the model class named by ``opt.model`` and initialise it with ``opt``."""
    model_cls = find_model_using_name(opt.model)
    model_instance = model_cls()
    model_instance.initialize(opt)
    print("model [%s] was created" % (model_instance.name()))
    return model_instance

# Options

In [None]:
class Params:
    """Hard-coded option set shared by training and testing.

    Attributes mirror command-line options; ``gpu_ids`` starts as a
    comma-separated string and is converted to a list of ints by ``parse``.
    """

    def __init__(self):
        self.initialized = False
        self.isTrain = True
        self.dataroot = 'D:/AffineGAN-master/dataset/happy'
        self.batch_size = 1
        self.loadSize = 286
        self.display_winsize = 256
        self.fineSize = 256
        self.input_nc = 3
        self.output_nc = 3
        self.ngf = 64
        self.ndf = 64
        self.netD = 'basic'
        self.n_layers_D = 3
        self.netG = 'unet_256'
        self.gpu_ids = '0'
        self.name = 'happy'
        self.dataset_mode = 'affineGAN'
        self.model = 'affineGAN'
        self.epoch = 'best'
        self.num_threads = 1
        self.checkpoints_dir = 'D:/AffineGAN-master/check_param'
        self.norm = 'instance'
        self.serial_batches = False
        self.no_dropout = False
        self.max_dataset_size = float("inf")
        self.resize_or_crop = 'resize_and_crop'
        self.no_flip = False
        self.init_type = 'normal'
        self.init_gain = 0.02
        self.verbose = False
        self.suffix = ''
        self.no_patch = False

    def initialize(self, parser):
        """No-op hook so subclasses can call ``Params.initialize(self, parser)``.

        Returns ``parser`` unchanged.
        """
        return parser

    def print_options(self):
        """Print all options sorted by name and write them to ``<checkpoints_dir>/<name>/opt.txt``."""
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(self).items()):
            comment = ''
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(self.checkpoints_dir, self.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')

    def parse(self):
        """Log options, convert ``gpu_ids`` to a list of non-negative ints, select the GPU.

        Safe to call more than once: an already-converted ``gpu_ids`` list is
        left as-is (a second call previously failed on ``list.split``).
        Empty tokens are ignored, so ``''`` means CPU just like ``'-1'``.

        Returns:
            ``self``.
        """
        self.print_options()
        if isinstance(self.gpu_ids, str):
            requested = (int(token) for token in self.gpu_ids.split(',') if token.strip())
            # Negative ids (e.g. -1) request CPU and are dropped.
            self.gpu_ids = [gpu_id for gpu_id in requested if gpu_id >= 0]
        if len(self.gpu_ids) > 0:
            torch.cuda.set_device(self.gpu_ids[0])
        self.initialized = True
        return self

In [None]:
class TrainParams(Params):
    """Training options layered on top of ``Params``."""

    def __init__(self):
        super(TrainParams, self).__init__()
        self._set_train_defaults(display_id=1)
        self.isTrain = True

    def _set_train_defaults(self, display_id):
        """Assign every training option; only ``display_id`` differs between callers."""
        self.display_freq = 100
        self.display_ncols = 0
        self.display_id = display_id
        self.display_server = "http://localhost"
        self.display_env = "main"
        self.display_port = 8097
        self.update_html_freq = 1000
        self.print_freq = 50
        self.save_latest_freq = 1000
        self.save_epoch_freq = 200
        self.continue_train = False
        self.epoch_count = 1
        self.phase = "train"
        self.niter = 10000
        self.niter_decay = 100
        self.beta1 = 0.5
        self.lr = 0.0002
        self.no_lsgan = False
        self.lambda_A = 100.0
        self.lambda_L1 = 100.0
        self.no_html = False
        self.lr_policy = "lambda"
        self.lr_decay_iters = 50
        self.w_pa = 1.0
        self.w_la = 1.0
        self.w_co = 1.0
        self.train_imagenum = 5

    def initialize(self, parser):
        """Reset training options (with visdom display disabled) and return ``parser``.

        The base hook is called only if ``Params`` defines one; calling a
        missing ``Params.initialize`` would raise ``AttributeError``.
        """
        base_initialize = getattr(Params, "initialize", None)
        if base_initialize is not None:
            parser = base_initialize(self, parser)
        # display_id = -1 here (vs 1 in __init__) turns off the visdom display.
        self._set_train_defaults(display_id=-1)
        return parser


In [None]:
class TestParams(Params):
    """Test-time options layered on top of ``Params``."""

    def __init__(self):
        super(TestParams, self).__init__()
        self._set_test_defaults()
        self.isTrain = False

    def _set_test_defaults(self):
        """Assign every test-time option (shared by ``__init__`` and ``initialize``)."""
        self.ntest = float("inf")
        self.results_dir = "D:/AffineGAN-master/results_best"
        self.aspect_ratio = 1.0
        self.phase = "test"
        self.num_test = 100
        self.interval = 0.05
        self.eval = True
        self.loadSize = 256
        self.w_pa = 1.0
        self.w_la = 1.0
        self.w_co = 1.0

    def initialize(self, parser):
        """Reset test options and return ``parser``.

        The base hook is called only if ``Params`` defines one; calling a
        missing ``Params.initialize`` would raise ``AttributeError``.
        """
        base_initialize = getattr(Params, "initialize", None)
        if base_initialize is not None:
            parser = base_initialize(self, parser)
        self._set_test_defaults()
        return parser

# Train

In [None]:
import itertools
import json

def train():
    """Grid-search AffineGAN hyper-parameters, training one model per combination.

    Every combination trains epochs ``opt.epoch_count`` .. ``opt.niter +
    opt.niter_decay``. After each epoch the generator L1 loss of that epoch's
    final iteration is compared with the best seen so far (across all
    combinations). On improvement the combination is written to
    ``best_params.txt`` and the networks are saved under the ``"best"`` tag.

    Returns:
        The best combination as a dict (``None`` if no epoch produced losses).
    """
    best_loss = float('inf')  # lowest G_L1 seen across all combinations
    best_params = None

    grid_netD = ['basic', 'n_layers', 'pixel']
    grid_nLayerD = [1, 3, 5]
    grid_initGain = [0.01, 0.02, 0.05, 0.07, 0.1]
    grid_noPatch = [True, False]
    grid_lrPolicy = ['lambda', 'step', 'plateau', 'cosine']
    grid_noLsgan = [True, False]
    # Was `0,5` (two entries, 0 and 5); Adam rejects beta1 = 5.
    grid_beta1 = [0.1, 0.2, 0.4, 0.3, 0.5, 0.6, 0.7]
    grid_lr = [0.0001, 0.0002, 0.0025]

    param_combinations = itertools.product(grid_netD, grid_nLayerD, grid_initGain,
                                           grid_noPatch, grid_lrPolicy, grid_noLsgan,
                                           grid_beta1, grid_lr)
    for combo in param_combinations:
        netD, n_layers_D, init_gain, no_patch, lr_policy, no_lsgan, beta1, lr = combo
        opt = TrainParams()
        opt.netD = netD
        opt.n_layers_D = n_layers_D
        opt.init_gain = init_gain
        opt.no_patch = no_patch
        opt.lr_policy = lr_policy
        opt.no_lsgan = no_lsgan
        opt.beta1 = beta1
        opt.lr = lr
        # Parse after applying overrides so the saved opt.txt records the
        # values this run actually uses.
        opt = opt.parse()

        data_loader = CreateDataLoader(opt)
        dataset = data_loader.load_data()
        dataset_size = len(data_loader)
        print("#training images = %d" % dataset_size)

        model = create_model(opt)
        model.setup(opt)
        visualizer = Visualizer(opt)
        total_steps = 0

        for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
            epoch_start_time = time.time()
            iter_data_time = time.time()
            epoch_iter = 0

            for i, data in enumerate(dataset):
                iter_start_time = time.time()
                visualizer.reset()
                total_steps += opt.batch_size
                epoch_iter += opt.batch_size
                model.set_input(data)
                model.optimize_parameters()

                if total_steps % opt.display_freq == 0:
                    save_result = total_steps % opt.update_html_freq == 0
                    visualizer.display_current_results(
                        model.get_current_visuals(), epoch, save_result
                    )

                if total_steps % opt.print_freq == 0:
                    losses = model.get_current_losses()
                    t = (time.time() - iter_start_time) / opt.batch_size
                    visualizer.print_current_losses(
                        epoch, epoch_iter, losses, t, iter_start_time - iter_data_time
                    )
                    if opt.display_id > 0:
                        visualizer.plot_current_losses(
                            epoch, float(epoch_iter) / dataset_size, opt, losses
                        )

                if total_steps % opt.save_latest_freq == 0:
                    print(
                        "saving the latest model (epoch %d, total_steps %d)"
                        % (epoch, total_steps)
                    )
                    model.save_networks("latest")

                iter_data_time = time.time()
            if epoch % opt.save_epoch_freq == 0:
                print(
                    "saving the model at the end of epoch %d, iters %d"
                    % (epoch, total_steps)
                )
                model.save_networks("latest")
                model.save_networks(epoch)

            model.update_learning_rate()

            # Track the best model using this epoch's final-iteration losses.
            # The previous version reused `losses` from the last print step,
            # which was stale, or undefined (NameError) before print_freq steps.
            if epoch_iter == 0:
                continue  # empty dataset: no losses computed this epoch
            epoch_losses = model.get_current_losses()
            if epoch_losses['G_L1'] < best_loss:
                print("Saving the best model (epoch %d, total_steps %d)" % (epoch, total_steps))
                best_loss = epoch_losses['G_L1']
                best_params = {'params_netD': netD,
                               'params_nLayerD': n_layers_D,
                               'params_initGain': init_gain,
                               'params_noPatch': no_patch,
                               'params_lrPolicy': lr_policy,
                               'params_noLsgan': no_lsgan,
                               'params_beta1': beta1,
                               'params_lr': lr}
                # Save the best hyper-parameters as JSON in a .txt file.
                with open('best_params.txt', 'w') as f:
                    json.dump(best_params, f, indent=4)
                model.save_networks("best")

    return best_params

In [None]:
# Launch the hyper-parameter grid search defined in train(). Every combination
# trains for epochs opt.epoch_count .. opt.niter + opt.niter_decay, so this
# cell is extremely long-running.
train()

# Test

In [None]:
def generate():
    """Run the trained generator over the test split and write an HTML results page.

    Loads the checkpoint for ``opt.epoch``, processes at most
    ``opt.num_test`` images, and saves visuals under
    ``<results_dir>/<name>/<phase>_<epoch>``.
    """
    opt = TestParams()
    # Hard-code test-only settings *before* parse() so the options that
    # parse() logs to opt.txt match what is actually used.
    opt.num_threads = 1  # test code only supports num_threads = 1
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display
    opt = opt.parse()

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    model.setup(opt)
    # create a website
    web_dir = os.path.join(opt.results_dir, opt.name, "%s_%s" % (opt.phase, opt.epoch))
    webpage = html.HTML(
        web_dir,
        "Experiment = %s, Phase = %s, Epoch = %s" % (opt.name, opt.phase, opt.epoch),
    )

    if opt.eval:
        model.eval()
    for i, data in enumerate(dataset):
        if i >= opt.num_test:
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        img_path = model.get_image_paths()
        if i % 5 == 0:
            print("processing (%04d)-th image... %s" % (i, img_path))
        save_images(
            webpage,
            visuals,
            img_path,
            aspect_ratio=opt.aspect_ratio,
            width=opt.display_winsize,
        )
    # save the website
    webpage.save()


In [None]:
# Run the trained generator on the test split and save the frames/HTML page.
generate()

In [None]:
class ImageToGifConverter:
    """Assemble per-frame generator outputs into one animated GIF per test sample.

    For each experiment in ``exp_names`` (comma-separated) and each entry of
    ``<dataroot>/test/img``, it reads ``int(1 / interval)`` PNG frames named
    ``<sample>_fake_B_list<i>.png`` from
    ``<results_dir>/<exp>/<phase>_<epoch>/images``. It then writes
    ``<results_dir>/gifs/<exp>/<sample>_<epoch>.gif``.
    """

    def __init__(self, exp_names, results_dir="./results/", epoch="best", phase="test", dataroot=None, interval=0.05):
        self.base_output_dir = "gifs"
        self.exp_list = exp_names.split(",")
        self.results_dir = results_dir
        self.epoch = epoch
        self.phase = phase
        self.dataroot = dataroot
        self.interval = interval

    def img2gif(self):
        """Write one GIF per (experiment, test sample).

        Raises:
            ValueError: if ``dataroot`` was not provided. Previously this
                surfaced as a ``TypeError`` from ``os.path.join(None, ...)``.
        """
        if self.dataroot is None:
            raise ValueError("dataroot must point to the dataset root containing test/img")

        # Same frame count as the model's test loop: int(1 / interval).
        num_frames = int(1 / self.interval)
        sample_dir = os.path.join(self.dataroot, "test", "img")

        for exp_name in self.exp_list:
            current_output_dir = os.path.join(self.results_dir, self.base_output_dir, exp_name)
            os.makedirs(current_output_dir, exist_ok=True)
            frame_dir = os.path.join(
                self.results_dir, exp_name, "%s_%s" % (self.phase, self.epoch), "images"
            )

            for sample_idx in os.listdir(sample_dir):
                frames = [
                    np.array(imageio.imread(
                        os.path.join(frame_dir, "%s_fake_B_list%d.png" % (sample_idx, i))
                    ))
                    for i in range(num_frames)
                ]
                gif_path = os.path.join(
                    current_output_dir, sample_idx + "_" + str(self.epoch) + ".gif"
                )
                imageio.mimsave(gif_path, frames)


In [None]:
# Turn the per-frame PNGs (presumably those written by generate() under the
# same results_dir) into one GIF per test sample.
# NOTE(review): dataroot here is dataset/test_star while the test options use
# dataset/happy; the sample names must match the generated frame files —
# confirm this is intended. Paths are absolute Windows paths; adjust locally.
converter = ImageToGifConverter(exp_names="happy", results_dir="D:/AffineGAN-master/results_best", 
                                epoch="best", phase="test", dataroot="D:/AffineGAN-master/dataset/test_star", interval=0.05)
converter.img2gif()
