In [None]:
#-*- coding:utf-8 -*-
import itertools
import os
import random
import time
from os import listdir

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

In [None]:
class CycleGAN(object):
    """Bundle of the networks that make up a CycleGAN plus a domain classifier.

    Holds two generators (A->B and B->A), two discriminators (one per domain),
    a classifier and the device the wrapped networks are moved to, together
    with two flags recording whether the cycle networks / the classifier have
    been trained.
    """

    def __init__(self, genA2B, genB2A, discA, discB, classifier, device):
        """Store the networks and the target device; nothing is trained yet."""
        self.genA2B = genA2B
        self.genB2A = genB2A
        self.discA = discA
        self.discB = discB
        self.classifier = classifier
        self.device = device
        self.cycle_trained = False
        self.classifier_trained = False

    def train(self, param):
        """Run the CycleGAN training loop configured by ``param``.

        Args:
            param: hyper-parameter object. The attributes read here are
                ``channels``, ``epochs``, ``lr``, ``lr_sched``, ``size``,
                ``bs``, ``lambdas`` (``(cycle_weight, identity_weight)``),
                ``size_replay_buffer``, ``resnet_blocks``,
                ``down_upsampling_layers``, ``loss_adv``, ``loss_cyc_ide``
                and ``name``.

        Side effects:
            Re-initialises the networks stored on this object, prints
            progress, writes the loss history with ``np.save`` after every
            epoch and saves the four trained networks through ``save_model``
            after the last epoch.

        This method relies on names that must be defined elsewhere in the
        notebook: ``pathA_Train``, ``pathB_Train``, ``checkpointsFolder``,
        ``save_model``, ``weights_init_normal``, ``transforms`` and ``Image``.

        NOTE(review): the optimisation below runs on freshly built
        ``Small_Generator`` / ``Small_Discriminator`` instances, not on
        ``self.genA2B`` etc., and ``self.cycle_trained`` is never set. The
        networks stored on this object are therefore only re-initialised by
        this call. Confirm whether the loop was meant to train them instead.
        """
        # Re-initialise the networks held by this object and move them to its device.
        self.genA2B.apply(self.init_weights).to(self.device)
        self.genB2A.apply(self.init_weights).to(self.device)
        self.discA.apply(self.init_weights).to(self.device)
        self.discB.apply(self.init_weights).to(self.device)

        input_nc_T = param.channels      # input channels
        output_nc_T = param.channels     # output channels
        epoch_T = 0                      # starting epoch
        n_epochs_T = param.epochs        # number of epochs of training
        # epoch at which the learning rate starts to decay linearly
        decay_epoch_T = np.ceil(n_epochs_T / param.lr_sched)
        lr_T = param.lr                  # initial learning rate
        size_T = param.size              # image size (width or height), square assumed
        batchSize_T = param.bs           # batch size
        lambda_iden = param.lambdas[1]   # identity-loss weight (original note: "actually 5")
        lambda_cyc = param.lambdas[0]    # cycle-loss weight (original note: "actually 10")

        cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if cuda else "cpu")
        print("using cuda: ", cuda)

        # Scalar targets; expanded to the discriminator output shape at each use.
        target_real_T = torch.ones((1), requires_grad=False).to(device)
        target_fake_T = torch.zeros((1), requires_grad=False).to(device)

        # init replay buffers
        fake_A_buffer_T = ReplayBuffer(param.size_replay_buffer)
        fake_B_buffer_T = ReplayBuffer(param.size_replay_buffer)

        # init networks
        netG_A2B_T = Small_Generator(input_nc_T, output_nc_T,
                                     n_residual_blocks=param.resnet_blocks,
                                     down_upsampling_layers=param.down_upsampling_layers)
        netG_B2A_T = Small_Generator(output_nc_T, input_nc_T,
                                     n_residual_blocks=param.resnet_blocks,
                                     down_upsampling_layers=param.down_upsampling_layers)
        netD_A_T = Small_Discriminator(input_nc_T)
        netD_B_T = Small_Discriminator(output_nc_T)

        # loss history; one entry per batch plus one extra entry per epoch (see below)
        losses = {"epoch": [], "adv_G_A2B": [], "adv_G_B2A": [], "adv_D_A": [],
                  "adv_D_B": [], "cycle_loss": [], "identity_loss": []}

        # init weights and move the networks to the device
        netG_A2B_T.apply(weights_init_normal).to(device)
        netG_B2A_T.apply(weights_init_normal).to(device)
        netD_A_T.apply(weights_init_normal).to(device)
        netD_B_T.apply(weights_init_normal).to(device)

        # loss functions
        criterion_GAN_T = param.loss_adv
        criterion_cycle_T = param.loss_cyc_ide
        criterion_identity_T = param.loss_cyc_ide

        # optimizers: both generators share one optimizer
        optimizer_G_T = torch.optim.Adam(
            itertools.chain(netG_A2B_T.parameters(), netG_B2A_T.parameters()),
            lr=lr_T, betas=(0.5, 0.999))
        optimizer_D_A_T = torch.optim.Adam(netD_A_T.parameters(), lr=lr_T, betas=(0.5, 0.999))
        optimizer_D_B_T = torch.optim.Adam(netD_B_T.parameters(), lr=lr_T, betas=(0.5, 0.999))

        # learning rate schedulers (the LambdaLR helper class supplies the decay factor)
        lr_sched_G = torch.optim.lr_scheduler.LambdaLR(
            optimizer_G_T, lr_lambda=LambdaLR(n_epochs_T, epoch_T, decay_epoch_T).step)
        lr_sched_D_A = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_A_T, lr_lambda=LambdaLR(n_epochs_T, epoch_T, decay_epoch_T).step)
        lr_sched_D_B = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_B_T, lr_lambda=LambdaLR(n_epochs_T, epoch_T, decay_epoch_T).step)

        # Data transformations. The two original branches differed only in the
        # normalisation statistics. The original condition tested the
        # undefined name `input_ch`; the channel count is `input_nc_T`.
        if input_nc_T == 1:
            norm_stats = (0.5,)
        else:
            norm_stats = (0.5, 0.5, 0.5)
        transforms_T = [
            transforms.Resize(int(size_T * 1.12), Image.BICUBIC),
            transforms.RandomCrop(size_T),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(norm_stats, norm_stats),
        ]

        # create dataloader
        train_dataset_T = ImageDataset(pathA_Train, pathB_Train, transforms_=transforms_T,
                                       unaligned=True, rgb=False)
        train_loader_T = DataLoader(train_dataset_T, batch_size=batchSize_T, shuffle=True)

        # put all nets into training mode
        netG_A2B_T.train()
        netG_B2A_T.train()
        netD_A_T.train()
        netD_B_T.train()

        for epoch in range(epoch_T, n_epochs_T):
            tic = time.time()
            for i, batch in enumerate(train_loader_T):
                inner_tic = time.time()
                # Model input. Moving the batch directly (instead of copying it
                # into a pre-allocated batch-sized buffer) also works when the
                # last batch of an epoch is smaller than the batch size.
                real_A = batch['A'].to(device=device, dtype=torch.float32)
                real_B = batch['B'].to(device=device, dtype=torch.float32)

                ###### Train Generators A2B and B2A #####
                optimizer_G_T.zero_grad()

                # GAN (adversarial) loss:
                # small if the discriminator thinks the generated sample looks real
                fake_B = netG_A2B_T(real_A)
                pred_fake_B = netD_B_T(fake_B)   # evaluated once and reused for the target shape
                loss_GAN_A2B = criterion_GAN_T(pred_fake_B, target_real_T.expand_as(pred_fake_B))
                fake_A = netG_B2A_T(real_B)
                pred_fake_A = netD_A_T(fake_A)
                loss_GAN_B2A = criterion_GAN_T(pred_fake_A, target_real_T.expand_as(pred_fake_A))

                gan_loss = loss_GAN_A2B + loss_GAN_B2A

                # Cycle loss:
                # small if the recovered image is similar to the original
                recovered_A = netG_B2A_T(fake_B)
                loss_cycle_ABA = criterion_cycle_T(recovered_A, real_A)
                recovered_B = netG_A2B_T(fake_A)
                loss_cycle_BAB = criterion_cycle_T(recovered_B, real_B)

                cycle_loss = loss_cycle_ABA + loss_cycle_BAB

                # Identity loss:
                # G_A2B(B) should equal B if real B is fed (and vice versa)
                same_B = netG_A2B_T(real_B)
                loss_identity_B = criterion_identity_T(same_B, real_B)
                same_A = netG_B2A_T(real_A)
                loss_identity_A = criterion_identity_T(same_A, real_A)

                identity_loss = loss_identity_A + loss_identity_B

                # Total generator loss
                loss_G = gan_loss + (identity_loss * lambda_iden) + (cycle_loss * lambda_cyc)
                loss_G.backward()

                optimizer_G_T.step()

                ###### Train Discriminator A ######
                optimizer_D_A_T.zero_grad()

                # Real loss
                pred_real = netD_A_T(real_A)
                loss_D_real = criterion_GAN_T(pred_real, target_real_T.expand_as(pred_real))

                # Fake loss using an image buffer
                fake_A = fake_A_buffer_T.push_and_pop(fake_A)
                pred_fake = netD_A_T(fake_A.detach())
                loss_D_fake = criterion_GAN_T(pred_fake, target_fake_T.expand_as(pred_fake))

                # Total loss
                loss_D_A = (loss_D_real + loss_D_fake) / 2
                loss_D_A.backward()

                optimizer_D_A_T.step()

                ###### Train Discriminator B #####
                optimizer_D_B_T.zero_grad()

                # Real loss
                pred_real = netD_B_T(real_B)
                loss_D_real = criterion_GAN_T(pred_real, target_real_T.expand_as(pred_real))

                # Fake loss using an image buffer
                fake_B = fake_B_buffer_T.push_and_pop(fake_B)
                pred_fake = netD_B_T(fake_B.detach())
                loss_D_fake = criterion_GAN_T(pred_fake, target_fake_T.expand_as(pred_fake))

                # Total loss
                loss_D_B = (loss_D_real + loss_D_fake) / 2
                loss_D_B.backward()

                optimizer_D_B_T.step()

                inner_tac = time.time()

                losses["epoch"].append(epoch)
                losses["adv_G_A2B"].append(loss_GAN_A2B.detach().cpu().numpy())
                losses["adv_G_B2A"].append(loss_GAN_B2A.detach().cpu().numpy())
                losses["adv_D_A"].append(loss_D_A.detach().cpu().numpy())
                losses["adv_D_B"].append(loss_D_B.detach().cpu().numpy())
                losses["cycle_loss"].append(cycle_loss.detach().cpu().numpy())
                losses["identity_loss"].append(identity_loss.detach().cpu().numpy())

                print("batch {} done in {} seconds, cycle_loss:{}".format(i+1,np.round(inner_tac-inner_tic, decimals = 4),cycle_loss))

            # save the last model
            if epoch == n_epochs_T - 1:
                save_model(netG_A2B_T, checkpointsFolder, "netG_A2B_MNIST", epoch+1, param)
                save_model(netG_B2A_T, checkpointsFolder, "netG_B2A_MNIST", epoch+1, param)
                save_model(netD_A_T, checkpointsFolder, "netD_A_MNIST", epoch+1, param)
                save_model(netD_B_T, checkpointsFolder, "netD_B_MNIST", epoch+1, param)

            # save losses per epoch
            # NOTE(review): this appends the last batch's values a second time
            # (they were already logged inside the batch loop) - confirm intent.
            losses["epoch"].append(epoch)
            losses["adv_G_A2B"].append(loss_GAN_A2B.detach().cpu().numpy())
            losses["adv_G_B2A"].append(loss_GAN_B2A.detach().cpu().numpy())
            losses["adv_D_A"].append(loss_D_A.detach().cpu().numpy())
            losses["adv_D_B"].append(loss_D_B.detach().cpu().numpy())
            losses["cycle_loss"].append(cycle_loss.detach().cpu().numpy())
            losses["identity_loss"].append(identity_loss.detach().cpu().numpy())
            # backup losses
            np.save(checkpointsFolder+"lossesMNIST_{}.npy".format(param.name), losses)
            tac = time.time()
            print("epoch {} of {} finished in {} seconds, cycle_loss: {}".format(epoch+1,n_epochs_T,
                                                                                np.round(tac-tic, decimals = 3), (cycle_loss)))

            # update learning rates
            lr_sched_G.step()
            lr_sched_D_A.step()
            lr_sched_D_B.step()

    def train_classifier(self, param):
        """Placeholder for classifier training.

        Currently it only reports when the classifier is already marked as
        trained; ``param`` is not used yet.
        """
        if self.classifier_trained:
            print("classifier already trained")

    def eval(self, img, target_domain):
        """Map ``img`` into ``target_domain`` and report a certainty measure.

        Args:
            img: input batch handed to the generator and to the classifier.
            target_domain: ``"A"`` (uses ``genB2A``) or ``"B"`` (uses ``genA2B``).

        Returns:
            ``(result, certainty)``: the generator output and the rounded
            softmax output of the classifier indexed with 1 (domain A) or
            0 (domain B).

        Raises:
            ValueError: if ``target_domain`` is neither ``"A"`` nor ``"B"``.

        NOTE(review): the classifier is applied to the *input* image, not to
        the generated one, and the final ``[1]`` / ``[0]`` indexes the first
        axis of the softmax output (the batch axis if the classifier returns
        ``(batch, classes)``). Both are kept from the original - confirm.
        """
        if target_domain == "A":
            generator, index = self.genB2A, 1
        elif target_domain == "B":
            generator, index = self.genA2B, 0
        else:
            # The original printed "ALARM" and then failed on unbound locals.
            raise ValueError(
                "target_domain must be 'A' or 'B', got {!r}".format(target_domain))

        softy = nn.Softmax(dim=1)
        result = generator(img)
        certainty = np.round(softy(self.classifier(img).detach()).cpu().numpy(), decimals=3)[index]
        return result, certainty

    #def continue_train():
        # take the proper parameters (e.g. the current learning rate) into account here
        # maybe use a single, more flexible train function for the CycleGAN
        # instead of separate train and continue_train functions?

    def init_weights(self, m):
        """Weight initialiser meant to be passed to ``nn.Module.apply``.

        Modules whose class name contains ``Conv`` get weights drawn from
        N(0, 0.02); ``BatchNorm2d`` modules get weights from N(1, 0.02) and a
        zero bias. Every other module is left untouched.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif 'BatchNorm2d' in classname:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            # was `costant_`, which does not exist in torch.nn.init
            torch.nn.init.constant_(m.bias.data, 0.0)

    #def save_model():


    #def load_model():
        
        

In [None]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem convolution -> two stride-2 downsampling stages ->
    ``n_residual_blocks`` residual blocks -> two stride-2 transposed-conv
    upsampling stages -> 7x7 output convolution followed by ``tanh``.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        """Build the layer stack.

        Args:
            input_nc: number of channels of the input image.
            output_nc: number of channels of the generated image.
            n_residual_blocks: how many ``ResidualBlock`` modules sit in the
                bottleneck.
        """
        super().__init__()

        base_width = 64

        # Stem: reflection padding keeps the spatial size under the 7x7 kernel.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, base_width, 7),
            nn.InstanceNorm2d(base_width),
            nn.ReLU(inplace=True),
        ]

        # Encoder: every stage halves the resolution and doubles the width.
        width = base_width
        for _ in range(2):
            layers.extend([
                nn.Conv2d(width, width * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2),
                nn.ReLU(inplace=True),
            ])
            width = width * 2

        # Bottleneck of residual blocks at constant width.
        layers.extend(ResidualBlock(width) for _ in range(n_residual_blocks))

        # Decoder: every stage doubles the resolution and halves the width.
        for _ in range(2):
            layers.extend([
                nn.ConvTranspose2d(width, width // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(width // 2),
                nn.ReLU(inplace=True),
            ])
            width = width // 2

        # Head: project back to image channels, squashed by tanh.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential layer stack."""
        return self.model(x)


In [None]:
class Small_Generator(nn.Module):
    """Narrower variant of the ResNet-style generator (32 base channels).

    Layout: 7x7 stem convolution -> ``down_upsampling_layers`` stride-2
    downsampling stages -> ``n_residual_blocks`` residual blocks -> the same
    number of stride-2 upsampling stages -> 7x7 output convolution + ``tanh``.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=3, down_upsampling_layers = 2):
        """Build the layer stack.

        Args:
            input_nc: number of channels of the input image.
            output_nc: number of channels of the generated image.
            n_residual_blocks: number of ``ResidualBlock`` modules in the
                bottleneck.
            down_upsampling_layers: number of downsampling stages (and of
                matching upsampling stages). The original author notes it
                should not be bigger than 4.
        """
        super(Small_Generator, self).__init__()

        # Initial convolution block
        n = 32
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, n, 7),
                    nn.InstanceNorm2d(n),
                    nn.ReLU(inplace=True) ]

        # Downsampling: each stage halves the resolution and doubles the width
        depth = down_upsampling_layers #not bigger than 4!
        in_features = n
        out_features = in_features*2
        for _ in range(depth):
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling: mirrors the downsampling path, ending at n channels again
        out_features = in_features//2
        for _ in range(depth):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(n, output_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the sequential layer stack.

        The method was misspelled ``forwassrd`` before, so calling the module
        fell through to ``nn.Module``'s unimplemented ``forward``.
        """
        return self.model(x)


In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a map of real/fake scores."""

    def __init__(self, input_nc):
        """Build the convolution stack for images with ``input_nc`` channels."""
        super().__init__()

        # First stage has no normalisation layer.
        layers = [
            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # (in_channels, out_channels, stride) of the normalised stages.
        stages = ((64, 128, 2), (128, 256, 2), (256, 512, 1))
        for channels_in, channels_out, stride in stages:
            layers.append(nn.Conv2d(channels_in, channels_out, 4, stride=stride, padding=1))
            layers.append(nn.InstanceNorm2d(channels_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final convolution collapses the features to a single score channel.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map produced by the convolution stack."""
        return self.model(x)

In [None]:
class Small_Discriminator(nn.Module):
    """Shallow fully convolutional discriminator with 32 base channels."""

    def __init__(self, input_nc):
        """Build the convolution stack for images with ``input_nc`` channels."""
        super().__init__()

        width = 32
        self.model = nn.Sequential(
            # stage 1: strided convolution, no normalisation
            nn.Conv2d(input_nc, width, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # stage 2: strided convolution with instance normalisation
            nn.Conv2d(width, 2 * width, 4, stride=2, padding=1),
            nn.InstanceNorm2d(2 * width),
            nn.LeakyReLU(0.2, inplace=True),
            # scoring layer: one output channel
            nn.Conv2d(2 * width, 1, 2, padding=1),
        )

    def forward(self, x):
        """Return the score map produced by the convolution stack."""
        scores = self.model(x)
        return scores

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with instance norm, wrapped in a skip connection."""

    def __init__(self, in_features):
        """Build the block; the channel count ``in_features`` is kept throughout."""
        super().__init__()

        self.conv_block = nn.Sequential(
            # first conv: reflection padding keeps the spatial size
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            # second conv: no activation before the residual sum
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        """Return the input plus the output of the convolution branch."""
        residual = self.conv_block(x)
        return x + residual


In [None]:
class Param:
    """Plain container for the hyper-parameters of one training run."""

    def __init__(self, channels, epochs, size, name="default", down_upsampling_layers=2, lr = 0.0002, lr_sched = 2.0, size_replay_buffer = 50, resnet_blocks =9, loss_adv = torch.nn.MSELoss(), loss_cyc_ide = torch.nn.L1Loss() , lambdas=(10,0.5), bs=1):
        """Store every argument unchanged as an attribute of the same name."""
        # run identification (which parameter and which value is being varied)
        self.name = name

        # data
        self.channels = channels
        self.size = size
        self.bs = bs

        # architecture
        self.resnet_blocks = resnet_blocks
        self.down_upsampling_layers = down_upsampling_layers
        self.size_replay_buffer = size_replay_buffer

        # optimisation schedule
        self.epochs = epochs
        self.lr = lr
        self.lr_sched = lr_sched

        # loss functions and their weights
        self.loss_adv = loss_adv
        self.loss_cyc_ide = loss_cyc_ide
        self.lambdas = lambdas
    

In [None]:
class LambdaLR:
    """Learning-rate multiplier: constant 1.0, then linear decay towards 0."""

    def __init__(self, n_epochs, offset, decay_start_epoch):
        """Remember the schedule; the decay must start before the last epoch."""
        decay_span = n_epochs - decay_start_epoch
        assert decay_span > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplier for ``epoch`` (shifted by ``offset``)."""
        # epochs elapsed since the decay began, never negative
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_span = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay / decay_span


In [None]:
class ReplayBuffer():
    """Fixed-capacity pool of previously generated samples.

    While the pool is not full, incoming samples are stored and returned
    unchanged. Once it is full, each incoming sample is, with probability 0.5,
    swapped with a randomly chosen stored sample (the stored one is returned);
    otherwise the incoming sample is returned and the pool is left as is.
    """

    def __init__(self, max_size=50):
        """Create an empty pool that holds at most ``max_size`` samples."""
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same size.

        Args:
            data: tensor that is iterated along its first dimension; every
                slice is treated as one sample.

        Returns:
            The selected samples concatenated along dimension 0. The result is
            detached from the autograd graph of ``data``.
        """
        to_return = []
        # detach() replaces the legacy `.data` access: samples are stored
        # without autograd history.
        for element in data.detach():
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    # hand out an old sample and store the new one in its slot
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        # torch.cat already returns a plain tensor; the deprecated Variable
        # wrapper the original used is not needed.
        return torch.cat(to_return)

In [None]:
class ImageDataset(Dataset):
    """Dataset yielding one image from folder A and one from folder B per item.

    The file names of both folders are listed once at construction time and
    sorted, so an index always refers to the same file and aligned pairs are
    stable. (The original re-listed both directories on every access and
    relied on the arbitrary order returned by ``listdir``.)
    """

    def __init__(self, pathA, pathB,transforms_ = None, unaligned=False,rgb = True ):
        """Index both image folders.

        Args:
            pathA: directory holding the domain-A images.
            pathB: directory holding the domain-B images.
            transforms_: optional list of transforms, combined with
                ``transforms.Compose``; ``None`` returns the opened images
                untransformed.
            unaligned: if true, the B image of an item is drawn at random
                instead of being selected by the item index.
            rgb: if true, both images are converted to RGB after opening.
        """
        self.pathA = pathA
        self.pathB = pathB
        self.unaligned = unaligned
        self.rgb = rgb
        self.filesA = sorted(os.listdir(pathA))
        self.filesB = sorted(os.listdir(pathB))
        # don't transform if no transforms were given
        self.dontTransform = transforms_ is None
        if not self.dontTransform:
            self.transform = transforms.Compose(transforms_)

    def __len__(self):
        """Return the size of the larger folder; the smaller one wraps around."""
        return max(len(self.filesA), len(self.filesB))

    def __getitem__(self, index):
        """Return ``{'A': imageA, 'B': imageB}`` for ``index``."""
        nameA = self.filesA[index % len(self.filesA)]
        if self.unaligned:
            nameB = self.filesB[random.randint(0, len(self.filesB) - 1)]
        else:
            nameB = self.filesB[index % len(self.filesB)]

        # os.path.join works whether or not the folder paths end with a separator
        sampleA = Image.open(os.path.join(self.pathA, nameA))
        sampleB = Image.open(os.path.join(self.pathB, nameB))

        # convert to RGB to fix the grayscale image dimension problem
        if self.rgb:
            sampleA = sampleA.convert('RGB')
            sampleB = sampleB.convert('RGB')
        # don't transform if no transforms were given
        if not self.dontTransform:
            sampleA = self.transform(sampleA)
            sampleB = self.transform(sampleB)

        return {'A': sampleA, 'B': sampleB}

In [None]:
class Classifier():
    """Placeholder for the domain classifier; not implemented yet.

    The class statement had no body, which made the cell fail with a
    ``SyntaxError``. This docstring is the minimal body that lets the cell run.
    """

SyntaxError: unexpected EOF while parsing (<ipython-input-12-867d0c768c0f>, line 2)

In [None]:
class UnNormalize(object):
    """Callable that rescales a tensor in place: ``t * std + mean`` per slice."""

    def __init__(self, mean, std):
        """Keep the per-slice ``mean`` and ``std`` sequences."""
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Rescale ``tensor`` in place and return that same tensor.

        Slices along the first dimension are paired with the entries of
        ``mean`` and ``std``; pairing stops at the shortest of the three.
        """
        stats = zip(self.mean, self.std)
        for channel, (offset, scale) in zip(tensor, stats):
            channel.mul_(scale)
            channel.add_(offset)
        return tensor