In [1]:
import torch
import torchvision
import torch.utils.data as data
import os
from os.path import join
from tqdm import tqdm
from torchvision.utils import save_image

**utils**

In [2]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

def check_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)

class Img_to_zero_center(object):
    def __int__(self):
        pass
    def __call__(self, t_img):
        '''
        :param img:tensor be 0-1
        :return:
        '''
        t_img=(t_img-0.5)*2
        return t_img

class Reverse_zero_center(object):
    def __init__(self):
        pass
    def __call__(self,t_img):
        t_img=t_img/2+0.5
        return t_img

**DataLoader**

In [3]:
from torchvision.datasets.folder import pil_loader
from random import shuffle

class CACD(data.Dataset):
    def __init__(self,split="train",transforms=None, label_transforms=None):

        self.split=split

        #define label 128*128 for condition generate image
        list_root = "/kaggle/input/facial-expression-recognition/FER/lists"
        data_root = "/kaggle/input/facial-expression-recognition/FER/FER"
        list_emotion = ["anger", "contempt", "disgust", "fear", "happy", "neutral", "sad", "surprise"]
        
        self.condition128=[]
        full_one=np.ones((128,128),dtype=np.float32)
        for i in range(8):
            full_zero=np.zeros((128,128,8),dtype=np.float32)
            full_zero[:,:,i]=full_one
            self.condition128.append(full_zero)

        # define label 64*64 for condition discriminate image
        self.condition64 = []
        full_one = np.ones((64, 64),dtype=np.float32)
        for i in range(8):
            full_zero = np.zeros((64, 64, 8),dtype=np.float32)
            full_zero[:, :, i] = full_one
            self.condition64.append(full_zero)

        #define label_pairs
        label_pair_root="/kaggle/input/facial-expression-recognition/train_label_pair.txt"
        with open(label_pair_root,'r') as f:
            lines=f.readlines()
        lines=[line.strip() for line in lines]
        shuffle(lines)
        self.label_pairs=[]
        for line in lines:
            label_pair=[]
            items=line.split()
            label_pair.append(int(items[0]))
            label_pair.append(int(items[1]))
            self.label_pairs.append(label_pair)

        #define group_images
        group_lists = [
            os.path.join(list_root, 'train_emo_group_0.txt'),
            os.path.join(list_root, 'train_emo_group_1.txt'),
            os.path.join(list_root, 'train_emo_group_2.txt'),
            os.path.join(list_root, 'train_emo_group_3.txt'),
            os.path.join(list_root, 'train_emo_group_4.txt'),
            os.path.join(list_root, 'train_emo_group_5.txt'),
            os.path.join(list_root, 'train_emo_group_6.txt'),
            os.path.join(list_root, 'train_emo_group_7.txt')
        ]

        self.label_group_images = []
        for i in range(len(group_lists)):
            with open(group_lists[i], 'r') as f:
                lines = f.readlines()
                lines = [line.strip() for line in lines]
            group_images = []
            for l in lines:
                items = l.split()
                group_images.append(os.path.join(data_root,list_emotion[int(items[1])],items[0]))
            self.label_group_images.append(group_images)

        #define train.txt
        if self.split == "train":
            self.source_images = []#which use to aging transfer
            with open(os.path.join(list_root, 'train.txt'), 'r') as f:
                lines = f.readlines()
                lines = [line.strip() for line in lines]
            shuffle(lines)
            for l in lines:
                items = l.split()
                self.source_images.append(os.path.join(data_root,list_emotion[int(items[1])],items[0]))
        else:
            self.source_images = []  # which use to aging transfer
            with open(os.path.join(list_root, 'test.txt'), 'r') as f:
                lines = f.readlines()
                lines = [line.strip() for line in lines]
            shuffle(lines)
            for l in lines:
                items = l.split()
                self.source_images.append(os.path.join(data_root,list_emotion[int(items[1])],items[0]))

        #define pointer
        self.train_group_pointer=[0,0,0,0,0,0,0,0]
        self.source_pointer=0
        self.batch_size=32
        self.transforms=transforms
        self.label_transforms=label_transforms

    def __getitem__(self, idx):
        if self.split == "train":
            pair_idx=idx//self.batch_size #a batch train the same pair
            true_label=int(self.label_pairs[pair_idx][0])
            fake_label=int(self.label_pairs[pair_idx][1])

            true_label_128=self.condition128[true_label]
            true_label_64=self.condition64[true_label]
            fake_label_64=self.condition64[fake_label]

            true_label_img=pil_loader(self.label_group_images[true_label][self.train_group_pointer[true_label]]).resize((128,128))
            source_img=pil_loader(self.source_images[self.source_pointer])

            source_img_227=source_img.resize((227,227))
            source_img_128=source_img.resize((128,128))

            if self.train_group_pointer[true_label]<len(self.label_group_images[true_label])-1:
                self.train_group_pointer[true_label]+=1
            else:
                self.train_group_pointer[true_label]=0

            if self.source_pointer<len(self.source_images)-1:
                self.source_pointer+=1
            else:
                self.source_pointer=0

            if self.transforms != None:
                true_label_img=self.transforms(true_label_img)
                source_img_227=self.transforms(source_img_227)
                source_img_128=self.transforms(source_img_128)

            if self.label_transforms != None:
                true_label_128=self.label_transforms(true_label_128)
                true_label_64=self.label_transforms(true_label_64)
                fake_label_64=self.label_transforms(fake_label_64)
            #source img 227 : use it to extract face feature
            #source img 128 : use it to generate different age face -> then resize to (227,227) to extract feature, compile with source img 227
            #ture_label_img : img in target age group -> use to train discriminator
            #true_label_128 : use this condition to generate
            #true_label_64 and fake_label_64 : use this condition to discrimination
            #true_label : label

            return source_img_227,source_img_128,true_label_img,true_label_128,true_label_64,fake_label_64, true_label
        else:
            source_img_128=pil_loader(self.source_images[idx]).resize((128,128))
            if self.transforms != None:
                source_img_128=self.transforms(source_img_128)
            condition_128_tensor_li=[]
            if self.label_transforms != None:
                for condition in self.condition128:
                    condition_128_tensor_li.append(self.label_transforms(condition).cuda())
            return source_img_128.cuda(),condition_128_tensor_li

    def __len__(self):
        if self.split == "train":
            return len(self.label_pairs)
        else:
            return len(self.source_images)


transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
label_transforms=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
CACD_dataset=CACD("train", transforms , label_transforms)
train_loader = torch.utils.data.DataLoader(
    dataset=CACD_dataset,
    batch_size=32,
    shuffle=True
)
for idx,(source_img_227,source_img_128,true_label_img,true_label_128,true_label_64,fake_label_64, true_label) in enumerate(train_loader):
    print(true_label)
    break

tensor([5, 1, 3, 6, 4, 6, 2, 6, 2, 2, 0, 4, 0, 5, 1, 5, 0, 7, 0, 5, 7, 4, 0, 5,
        7, 5, 2, 0, 5, 6, 0, 4])


**other architecture**

In [4]:
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        return out

In [5]:
from torch.nn import functional as F
import math
from torch.nn.parameter import Parameter
from torch.nn.functional import pad
from torch.nn.modules import Module
from torch.nn.modules.utils import _single, _pair, _triple

class _ConvNd(Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

class Conv2d(_ConvNd):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

    def forward(self, input):
        return _conv2d_same_padding(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

# custom conv2d, because pytorch don't have "padding='same'" option.
def _conv2d_same_padding(input, weight, bias=None, stride=(1,1), padding=1, dilation=(1,1), groups=1):

    input_rows = input.size(2)
    filter_rows = weight.size(2)
    effective_filter_size_rows = (filter_rows - 1) * dilation[0] + 1
    out_rows = (input_rows + stride[0] - 1) // stride[0]
    padding_needed = max(0, (out_rows - 1) * stride[0] + effective_filter_size_rows -
                  input_rows)
    padding_rows = max(0, (out_rows - 1) * stride[0] +
                        (filter_rows - 1) * dilation[0] + 1 - input_rows)
    rows_odd = (padding_rows % 2 != 0)
    padding_cols = max(0, (out_rows - 1) * stride[0] +
                        (filter_rows - 1) * dilation[0] + 1 - input_rows)
    cols_odd = (padding_rows % 2 != 0)

    if rows_odd or cols_odd:
        input = pad(input, [0, int(cols_odd), 0, int(rows_odd)])

    return F.conv2d(input, weight, bias, stride,
                  padding=(padding_rows // 2, padding_cols // 2),
                  dilation=dilation, groups=groups)


In [6]:
import torch.nn as nn
import torch
import os
import torch.nn.functional as F

class AgeAlexNet(nn.Module):
    def __init__(self,pretrainded=False,modelpath=None):
        super(AgeAlexNet, self).__init__()
        assert pretrainded is False or modelpath is not None,"pretrain model need to be specified"
        self.features = nn.Sequential(
            Conv2d(3, 96, kernel_size=11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(2,2e-5,0.75),

            Conv2d(96, 256, kernel_size=5, stride=1,groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(2, 2e-5, 0.75),

            Conv2d(256, 384, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),

            Conv2d(384, 384, kernel_size=3,stride=1,groups=2),
            nn.ReLU(inplace=True),

            Conv2d(384, 256, kernel_size=3,stride=1,groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.age_classifier=nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 8),
        )
        if pretrainded is True:
            self.load_pretrained_params(modelpath)

        self.Conv3_feature_module=nn.Sequential()
        self.Conv4_feature_module=nn.Sequential()
        self.Conv5_feature_module=nn.Sequential()
        self.Pool5_feature_module=nn.Sequential()
        for x in range(10):
            self.Conv3_feature_module.add_module(str(x), self.features[x])
        for x in range(10,12):
            self.Conv4_feature_module.add_module(str(x),self.features[x])
        for x in range(12,14):
            self.Conv5_feature_module.add_module(str(x),self.features[x])
        for x in range(14,15):
            self.Pool5_feature_module.add_module(str(x),self.features[x])


    def forward(self, x):
        self.conv3_feature=self.Conv3_feature_module(x)
        self.conv4_feature=self.Conv4_feature_module(self.conv3_feature)
        self.conv5_feature=self.Conv5_feature_module(self.conv4_feature)
        pool5_feature=self.Pool5_feature_module(self.conv5_feature)
        self.pool5_feature=pool5_feature
        flattened = pool5_feature.view(pool5_feature.size(0), -1)
        age_logit = self.age_classifier(flattened)
        return age_logit

    def load_pretrained_params(self,path):
        # step1: load pretrained model
        pretrained_dict = torch.load(path)
        # step2: get model state_dict
        model_dict = self.state_dict()
        # step3: remove pretrained_dict params which is not in model_dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # step4: update model_dict using pretrained_dict
        model_dict.update(pretrained_dict)
        # step5: update model using model_dict
        self.load_state_dict(model_dict)


class AgeClassify:
    def __init__(self):
        #step 1:define model
        self.model=AgeAlexNet(pretrainded=False).cuda()
        #step 2:define optimizer
        self.optim=torch.optim.Adam(self.model.parameters(),lr=1e-4,betas=(0.5, 0.999))
        #step 3:define loss
        self.criterion=nn.CrossEntropyLoss().cuda()

    def train(self,input,label):
        self.model.train()
        output=self.model(input)
        self.loss=self.criterion(output,label)

    def val(self,input):
        self.model.eval()
        output=F.softmax(self.model(input),dim=1).max(1)[1]
        return output

    def save_model(self,dir,filename):
        torch.save(self.model.state_dict(),os.path.join(dir,filename))

**IPCGANs**

In [7]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import os

class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.lrelu = nn.LeakyReLU(0.2)
        self.conv1 = Conv2d(3, 64, kernel_size=4, stride=2)
        self.conv2 = Conv2d(72, 128, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(128, eps=0.001, track_running_stats=True)
        self.conv3 = Conv2d(128, 256, kernel_size=4, stride=2)
        self.bn3 = nn.BatchNorm2d(256, eps=0.001, track_running_stats=True)
        self.conv4 = Conv2d(256, 512, kernel_size=4, stride=2)
        self.bn4 = nn.BatchNorm2d(512, eps=0.001, track_running_stats=True)
        self.conv5 = Conv2d(512, 512, kernel_size=4, stride=2)

    def forward(self, x,condition):
        x = self.lrelu(self.conv1(x))
        x=torch.cat((x,condition),1)
        x = self.lrelu(self.bn2(self.conv2(x)))
        x = self.lrelu(self.bn3(self.conv3(x)))
        x = self.lrelu(self.bn4(self.conv4(x)))
        x = self.conv5(x)
        return x

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = Conv2d(11, 32, kernel_size=7, stride=1)
        self.conv2 = Conv2d(32, 64, kernel_size=3, stride=2)
        self.conv3 = Conv2d(64, 128, kernel_size=3, stride=2)
        self.conv4 = Conv2d(32, 3, kernel_size=7, stride=1)
        self.bn1 = nn.BatchNorm2d(32, eps=0.001, track_running_stats=True)
        self.bn2 = nn.BatchNorm2d(64, eps=0.001, track_running_stats=True)
        self.bn3 = nn.BatchNorm2d(128, eps=0.001, track_running_stats=True)
        self.bn4 = nn.BatchNorm2d(64, eps=0.001, track_running_stats=True)
        self.bn5 = nn.BatchNorm2d(32, eps=0.001, track_running_stats=True)
        self.repeat_blocks=self._make_repeat_blocks(BasicBlock(128,128),6)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0,output_padding=1)
        self.relu=nn.ReLU()
        self.tanh=nn.Tanh()

    def _make_repeat_blocks(self,block,repeat_times):
        layers=[]
        for i in range(repeat_times):
            layers.append(block)
        return nn.Sequential(*layers)

    def forward(self, x,condition=None):
        if condition is not None:
            x=torch.cat((x,condition),1)

        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.repeat_blocks(x)
        x = self.deconv1(x)
        x = self.relu(self.bn4(x))
        x = self.deconv2(x)
        x = self.relu(self.bn5(x))
        x = self.tanh(self.conv4(x))
        return x

class IPCGANs:
    def __init__(self,lr=0.01,age_classifier_path=None,gan_loss_weight=75,feature_loss_weight=0.5e-4,age_loss_weight=30):

        self.d_lr=lr
        self.g_lr=lr

        self.generator=Generator().cuda()
        self.discriminator=PatchDiscriminator().cuda()
        if age_classifier_path != None:
            self.age_classifier=AgeAlexNet(pretrainded=True,modelpath=age_classifier_path).cuda()
        else:
            self.age_classifier = AgeAlexNet(pretrainded=False).cuda()
        self.MSEloss=nn.MSELoss().cuda()
        self.CrossEntropyLoss=nn.CrossEntropyLoss().cuda()

        self.gan_loss_weight=gan_loss_weight
        self.feature_loss_weight = feature_loss_weight
        self.age_loss_weight=age_loss_weight

        self.d_optim = torch.optim.Adam(self.discriminator.parameters(),self.d_lr,betas=(0.5,0.99))
        self.g_optim = torch.optim.Adam(self.generator.parameters(), self.g_lr, betas=(0.5, 0.99))

    def save_model(self,dir,filename):
        torch.save(self.generator.state_dict(),os.path.join(dir,"g"+filename))
        torch.save(self.discriminator.state_dict(),os.path.join(dir,"d"+filename))

    def load_generator_state_dict(self,state_dict):
        pretrained_dict = state_dict
        # step2: get model state_dict
        model_dict = self.generator.state_dict()
        # step3: remove pretrained_dict params which is not in model_dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # step4: update model_dict using pretrained_dict
        model_dict.update(pretrained_dict)
        # step5: update model using model_dict
        self.generator.load_state_dict(model_dict)

    def test_generate(self,source_img_128,condition):
        self.generator.eval()
        with torch.no_grad():
            generate_image=self.generator(source_img_128,condition)
        return generate_image

    def cuda(self):
        self.generator=self.generator.cuda()

    def train(self,source_img_227,source_img_128,true_label_img,true_label_128,true_label_64,fake_label_64, age_label):
        '''

        :param source_img_227: use this img to extract conv5 feature
        :param source_img_128: use this img to generate face in target age
        :param true_label_img:
        :param true_label_128:
        :param true_label_64:
        :param fake_label_64:
        :param age_label:
        :return:
        '''

        ###################################gan_loss###############################
        self.g_source=self.generator(source_img_128,condition=true_label_128)

        #real img, right age label
        #logit means prob which hasn't been normalized

        #d1 logit ,discriminator 1 means true,0 means false.
        d1_logit=self.discriminator(true_label_img,condition=true_label_64)

        d1_real_loss=self.MSEloss(d1_logit,torch.ones((d1_logit.size())).cuda())

        #real img, false label
        d2_logit=self.discriminator(true_label_img,condition=fake_label_64)
        d2_fake_loss=self.MSEloss(d2_logit,torch.zeros((d1_logit.size())).cuda())

        #fake img,real label
        d3_logit=self.discriminator(self.g_source,condition=true_label_64)
        d3_fake_loss=self.MSEloss(d3_logit,torch.zeros((d1_logit.size())).cuda())#use this for discriminator
        d3_real_loss=self.MSEloss(d3_logit,torch.ones((d1_logit.size())).cuda())#use this for genrator

        self.d_loss=(1./2 * (d1_real_loss + 1. / 2 * (d2_fake_loss + d3_fake_loss))) * self.gan_loss_weight
        g_loss=(1./2*d3_real_loss)*self.gan_loss_weight


        ################################feature_loss#############################

        self.age_classifier(source_img_227)
        source_feature=self.age_classifier.conv5_feature

        generate_img_227 = F.interpolate(self.g_source, (227, 227), mode="bilinear", align_corners=True)
        generate_img_227 = Img_to_zero_center()(generate_img_227)

        self.age_classifier(generate_img_227)
        generate_feature =self.age_classifier.conv5_feature
        self.feature_loss=self.MSEloss(source_feature,generate_feature)

        ################################age_cls_loss##############################



        age_logit=self.age_classifier(generate_img_227)
        self.age_loss=self.CrossEntropyLoss(age_logit,age_label)

        self.g_loss=self.age_loss+g_loss+self.feature_loss

**trainning**

In [8]:
# Optimizer
learning_rate = 1e-4
batch_size = 32
max_epoches = 30
val_interval = 1400 #Number of steps to validate
save_interval = 1400 #Number of batches to save model

d_iters = 1 
g_iters = 1 

#model
gan_loss_weight = 75
feature_loss_weight = 0.5e-4
age_loss_weight = 30
age_groups = 5
age_classifier_path = "/kaggle/input/facealexnet_pretrain4emotion/keras/default/1/kaggle/working/model/epoch_9_iter_399.pth"

#data, io
checkpoint = "/kaggle/working/checkpoint/"
saved_model_folder = "/kaggle/working/checkpoint/saved_parameters/"
saved_validation_folder = "/kaggle/working/checkpoint/validation/"

#check_dir
check_dir(checkpoint)
check_dir(saved_model_folder)
check_dir(saved_validation_folder)

def main():
    print("Start to train:\n")

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        Img_to_zero_center()
    ])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    train_dataset = CACD("train",transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False
    )

    model=IPCGANs(lr=learning_rate,age_classifier_path=age_classifier_path,gan_loss_weight=gan_loss_weight,feature_loss_weight=feature_loss_weight,age_loss_weight=age_loss_weight)
    d_optim=model.d_optim
    g_optim=model.g_optim

    for epoch in range(max_epoches):
        epochp = "epoch: " + str(epoch)
        print(epochp)
        for idx, (source_img_227,source_img_128,true_label_img,true_label_128,true_label_64,fake_label_64, true_label) in enumerate(train_loader,1):

            running_d_loss=None
            running_g_loss=None
            n_iter = epoch * len(train_loader) + idx


            #mv to gpu
            source_img_227=source_img_227.cuda()
            source_img_128=source_img_128.cuda()
            true_label_img=true_label_img.cuda()
            true_label_128=true_label_128.cuda()
            true_label_64=true_label_64.cuda()
            fake_label_64=fake_label_64.cuda()
            true_label=true_label.cuda()

            #train discriminator
            for d_iter in range(d_iters):
                #d_lr_scheduler.step()
                d_optim.zero_grad()
                model.train(
                    source_img_227=source_img_227,
                    source_img_128=source_img_128,
                    true_label_img=true_label_img,
                    true_label_128=true_label_128,
                    true_label_64=true_label_64,
                    fake_label_64=fake_label_64,
                    age_label=true_label
                )
                d_loss=model.d_loss
                running_d_loss=d_loss
                d_loss.backward()
                d_optim.step()

            #train generator
            for g_iter in range(g_iters):
                #g_lr_scheduler.step()
                g_optim.zero_grad()
                model.train(
                    source_img_227=source_img_227,
                    source_img_128=source_img_128,
                    true_label_img=true_label_img,
                    true_label_128=true_label_128,
                    true_label_64=true_label_64,
                    fake_label_64=fake_label_64,
                    age_label=true_label
                )
                g_loss = model.g_loss
                running_g_loss=g_loss
                g_loss.backward()
                g_optim.step()
            if idx % 350 == 0:
                print('step %d/%d, g_loss = %.3f, d_loss = %.3f' %(idx, len(train_loader),running_g_loss,running_d_loss))

            # save the parameters at the end of each save interval
            if idx % save_interval == 0 and epoch > 5:
                model.save_model(dir=saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth'%(epoch, idx))
                print('checkpoint has been created!')

            #val step
            if idx % val_interval == 0 and epoch > 5:
                save_dir = os.path.join(saved_validation_folder, "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                for val_idx, (source_img_128, true_label_128) in enumerate(tqdm(test_loader)):
                    save_image(Reverse_zero_center()(source_img_128),os.path.join(save_dir,"batch_%d_source.jpg"%(val_idx)))

                    pic_list = []
                    pic_list.append(source_img_128)
                    for age in range(age_groups):
                        img = model.test_generate(source_img_128, true_label_128[age])
                        save_image(Reverse_zero_center()(img),os.path.join(save_dir,"batch_%d_age_group_%d.jpg"%(val_idx,age)))
                print('validation image has been created!')


main()

Start to train:



  pretrained_dict = torch.load(path)


epoch: 0
step 350/1407, g_loss = 11.927, d_loss = 21.102
step 700/1407, g_loss = 15.136, d_loss = 20.022
step 1050/1407, g_loss = 13.517, d_loss = 15.830
step 1400/1407, g_loss = 9.753, d_loss = 14.713
epoch: 1
step 350/1407, g_loss = 18.512, d_loss = 15.340
step 700/1407, g_loss = 10.170, d_loss = 15.167
step 1050/1407, g_loss = 15.369, d_loss = 17.617
step 1400/1407, g_loss = 12.964, d_loss = 14.253
epoch: 2
step 350/1407, g_loss = 15.737, d_loss = 11.771
step 700/1407, g_loss = 12.794, d_loss = 10.630
step 1050/1407, g_loss = 12.577, d_loss = 11.604
step 1400/1407, g_loss = 15.261, d_loss = 9.269
epoch: 3
step 350/1407, g_loss = 14.145, d_loss = 14.470
step 700/1407, g_loss = 11.182, d_loss = 12.252
step 1050/1407, g_loss = 11.640, d_loss = 11.066
step 1400/1407, g_loss = 18.102, d_loss = 13.842
epoch: 4
step 350/1407, g_loss = 21.631, d_loss = 10.184
step 700/1407, g_loss = 19.375, d_loss = 16.977
step 1050/1407, g_loss = 12.398, d_loss = 12.048
step 1400/1407, g_loss = 15.415, d_l

100%|██████████| 87/87 [00:54<00:00,  1.59it/s]


validation image has been created!
epoch: 7
step 350/1407, g_loss = 30.858, d_loss = 9.376
step 700/1407, g_loss = 16.465, d_loss = 9.584
step 1050/1407, g_loss = 17.304, d_loss = 8.596
step 1400/1407, g_loss = 15.926, d_loss = 9.467
checkpoint has been created!


100%|██████████| 87/87 [00:40<00:00,  2.16it/s]


validation image has been created!
epoch: 8
step 350/1407, g_loss = 15.651, d_loss = 13.310
step 700/1407, g_loss = 19.073, d_loss = 17.230
step 1050/1407, g_loss = 21.279, d_loss = 8.423
step 1400/1407, g_loss = 20.034, d_loss = 7.843
checkpoint has been created!


100%|██████████| 87/87 [00:40<00:00,  2.15it/s]


validation image has been created!
epoch: 9
step 350/1407, g_loss = 31.032, d_loss = 9.744
step 700/1407, g_loss = 21.404, d_loss = 6.829
step 1050/1407, g_loss = 17.472, d_loss = 7.795
step 1400/1407, g_loss = 16.462, d_loss = 8.851
checkpoint has been created!


100%|██████████| 87/87 [00:39<00:00,  2.20it/s]


validation image has been created!
epoch: 10
step 350/1407, g_loss = 20.683, d_loss = 8.764
step 700/1407, g_loss = 20.633, d_loss = 7.593
step 1050/1407, g_loss = 24.099, d_loss = 9.901
step 1400/1407, g_loss = 33.028, d_loss = 9.536
checkpoint has been created!


100%|██████████| 87/87 [00:41<00:00,  2.12it/s]


validation image has been created!
epoch: 11
step 350/1407, g_loss = 25.890, d_loss = 9.720
step 700/1407, g_loss = 14.686, d_loss = 8.403
step 1050/1407, g_loss = 21.705, d_loss = 9.130
step 1400/1407, g_loss = 20.768, d_loss = 8.203
checkpoint has been created!


100%|██████████| 87/87 [00:40<00:00,  2.15it/s]


validation image has been created!
epoch: 12
step 350/1407, g_loss = 22.250, d_loss = 8.826
step 700/1407, g_loss = 12.469, d_loss = 15.089
step 1050/1407, g_loss = 22.218, d_loss = 6.979
step 1400/1407, g_loss = 28.212, d_loss = 6.735
checkpoint has been created!


100%|██████████| 87/87 [00:39<00:00,  2.19it/s]


validation image has been created!
epoch: 13
step 350/1407, g_loss = 21.539, d_loss = 6.586
step 700/1407, g_loss = 16.102, d_loss = 7.263
step 1050/1407, g_loss = 16.392, d_loss = 10.250
step 1400/1407, g_loss = 32.822, d_loss = 5.230
checkpoint has been created!


100%|██████████| 87/87 [00:39<00:00,  2.19it/s]


validation image has been created!
epoch: 14
step 350/1407, g_loss = 31.939, d_loss = 16.523
step 700/1407, g_loss = 22.804, d_loss = 8.114
step 1050/1407, g_loss = 22.154, d_loss = 9.183
step 1400/1407, g_loss = 17.727, d_loss = 7.500
checkpoint has been created!


100%|██████████| 87/87 [00:39<00:00,  2.18it/s]


validation image has been created!
epoch: 15
step 350/1407, g_loss = 23.350, d_loss = 7.073
step 700/1407, g_loss = 21.878, d_loss = 6.460
step 1050/1407, g_loss = 28.662, d_loss = 5.096
step 1400/1407, g_loss = 19.550, d_loss = 7.541
checkpoint has been created!


100%|██████████| 87/87 [00:40<00:00,  2.17it/s]


validation image has been created!
epoch: 16
step 350/1407, g_loss = 50.890, d_loss = 11.416
step 700/1407, g_loss = 39.072, d_loss = 10.063
step 1050/1407, g_loss = 21.419, d_loss = 11.600
step 1400/1407, g_loss = 20.263, d_loss = 7.143
checkpoint has been created!


100%|██████████| 87/87 [00:39<00:00,  2.18it/s]


validation image has been created!
epoch: 17
step 350/1407, g_loss = 20.124, d_loss = 6.830
step 700/1407, g_loss = 21.149, d_loss = 7.344
step 1050/1407, g_loss = 28.719, d_loss = 5.049
step 1400/1407, g_loss = 20.414, d_loss = 8.315
checkpoint has been created!


100%|██████████| 87/87 [00:40<00:00,  2.16it/s]


validation image has been created!
epoch: 18
step 350/1407, g_loss = 28.317, d_loss = 8.610
step 700/1407, g_loss = 32.049, d_loss = 4.434
step 1050/1407, g_loss = 30.435, d_loss = 9.917
step 1400/1407, g_loss = 23.238, d_loss = 6.532
checkpoint has been created!


100%|██████████| 87/87 [00:39<00:00,  2.18it/s]


validation image has been created!
epoch: 19
step 350/1407, g_loss = 29.137, d_loss = 5.061
step 700/1407, g_loss = 24.830, d_loss = 5.546
step 1050/1407, g_loss = 27.038, d_loss = 4.281
step 1400/1407, g_loss = 23.960, d_loss = 6.780
checkpoint has been created!


100%|██████████| 87/87 [00:37<00:00,  2.30it/s]


validation image has been created!
epoch: 20
step 350/1407, g_loss = 12.442, d_loss = 9.546
step 700/1407, g_loss = 24.891, d_loss = 8.440
step 1050/1407, g_loss = 18.206, d_loss = 9.037
step 1400/1407, g_loss = 23.893, d_loss = 7.164
checkpoint has been created!


100%|██████████| 87/87 [00:39<00:00,  2.18it/s]


validation image has been created!
epoch: 21
step 350/1407, g_loss = 32.354, d_loss = 4.370
step 700/1407, g_loss = 34.829, d_loss = 6.161
step 1050/1407, g_loss = 28.194, d_loss = 5.846
step 1400/1407, g_loss = 35.226, d_loss = 7.520
checkpoint has been created!


100%|██████████| 87/87 [00:39<00:00,  2.23it/s]


validation image has been created!
epoch: 22
step 350/1407, g_loss = 29.655, d_loss = 4.624
step 700/1407, g_loss = 33.771, d_loss = 8.674
step 1050/1407, g_loss = 27.979, d_loss = 4.897
step 1400/1407, g_loss = 32.377, d_loss = 12.838
checkpoint has been created!


100%|██████████| 87/87 [00:39<00:00,  2.22it/s]


validation image has been created!
epoch: 23
step 350/1407, g_loss = 20.877, d_loss = 6.266
step 700/1407, g_loss = 30.473, d_loss = 22.114
step 1050/1407, g_loss = 23.724, d_loss = 6.810
step 1400/1407, g_loss = 36.563, d_loss = 3.128
checkpoint has been created!


100%|██████████| 87/87 [00:37<00:00,  2.31it/s]


validation image has been created!
epoch: 24
step 350/1407, g_loss = 21.429, d_loss = 18.655
step 700/1407, g_loss = 25.539, d_loss = 8.960
step 1050/1407, g_loss = 29.870, d_loss = 5.894
step 1400/1407, g_loss = 32.518, d_loss = 4.460
checkpoint has been created!


100%|██████████| 87/87 [00:37<00:00,  2.30it/s]


validation image has been created!
epoch: 25
step 350/1407, g_loss = 23.391, d_loss = 6.122
step 700/1407, g_loss = 18.150, d_loss = 7.880
step 1050/1407, g_loss = 22.412, d_loss = 9.252
step 1400/1407, g_loss = 36.385, d_loss = 3.575
checkpoint has been created!


100%|██████████| 87/87 [00:37<00:00,  2.33it/s]


validation image has been created!
epoch: 26
step 350/1407, g_loss = 25.795, d_loss = 6.416
step 700/1407, g_loss = 32.743, d_loss = 2.920
step 1050/1407, g_loss = 35.279, d_loss = 5.039
step 1400/1407, g_loss = 23.462, d_loss = 7.246
checkpoint has been created!


100%|██████████| 87/87 [00:37<00:00,  2.31it/s]


validation image has been created!
epoch: 27
step 350/1407, g_loss = 36.630, d_loss = 11.488
step 700/1407, g_loss = 34.978, d_loss = 5.205
step 1050/1407, g_loss = 29.551, d_loss = 19.648
step 1400/1407, g_loss = 25.473, d_loss = 10.301
checkpoint has been created!


100%|██████████| 87/87 [00:37<00:00,  2.34it/s]


validation image has been created!
epoch: 28
step 350/1407, g_loss = 35.475, d_loss = 2.765
step 700/1407, g_loss = 22.663, d_loss = 6.516
step 1050/1407, g_loss = 23.561, d_loss = 6.155
step 1400/1407, g_loss = 24.737, d_loss = 8.114
checkpoint has been created!


100%|██████████| 87/87 [00:37<00:00,  2.31it/s]


validation image has been created!
epoch: 29
step 350/1407, g_loss = 38.263, d_loss = 10.591
step 700/1407, g_loss = 39.916, d_loss = 5.043
step 1050/1407, g_loss = 35.397, d_loss = 4.774
step 1400/1407, g_loss = 32.916, d_loss = 4.749
checkpoint has been created!


100%|██████████| 87/87 [00:38<00:00,  2.28it/s]


validation image has been created!


In [9]:
"""
import os
import subprocess
from IPython.display import FileLink, display

def download_file(path, download_file_name):
    os.chdir('/kaggle/working/')
    zip_name = f"/kaggle/working/{download_file_name}.zip"
    command = f"zip {zip_name} {path} -r"
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        print("Unable to run zip command!")
        print(result.stderr)
        return
    display(FileLink(f'{download_file_name}.zip'))

download_file('/kaggle/working/checkpoint', 'out')
"""

'\nimport os\nimport subprocess\nfrom IPython.display import FileLink, display\n\ndef download_file(path, download_file_name):\n    os.chdir(\'/kaggle/working/\')\n    zip_name = f"/kaggle/working/{download_file_name}.zip"\n    command = f"zip {zip_name} {path} -r"\n    result = subprocess.run(command, shell=True, capture_output=True, text=True)\n    if result.returncode != 0:\n        print("Unable to run zip command!")\n        print(result.stderr)\n        return\n    display(FileLink(f\'{download_file_name}.zip\'))\n\ndownload_file(\'/kaggle/working/checkpoint\', \'out\')\n'