In [1]:
import torch
import torchvision
import torch.utils.data as data
import os
from os.path import join
from tqdm import tqdm
from torchvision.utils import save_image

**utils**

In [2]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

def check_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)

class Img_to_zero_center(object):
    def __int__(self):
        pass
    def __call__(self, t_img):
        '''
        :param img:tensor be 0-1
        :return:
        '''
        t_img=(t_img-0.5)*2
        return t_img

class Reverse_zero_center(object):
    def __init__(self):
        pass
    def __call__(self,t_img):
        t_img=t_img/2+0.5
        return t_img

**DataLoader**

In [3]:
from torchvision.datasets.folder import pil_loader
from random import shuffle

class CACD(data.Dataset):
    def __init__(self,split="train",transforms=None, label_transforms=None):

        self.split=split

        #define label 128*128 for condition generate image
        list_root = "/kaggle/input/utkpreprocess/lists"
        data_root = "/kaggle/input/utkpreprocess/UTKFaceCrop"

        self.condition128=[]
        full_one=np.ones((128,128),dtype=np.float32)
        for i in range(5):
            full_zero=np.zeros((128,128,5),dtype=np.float32)
            full_zero[:,:,i]=full_one
            self.condition128.append(full_zero)

        # define label 64*64 for condition discriminate image
        self.condition64 = []
        full_one = np.ones((64, 64),dtype=np.float32)
        for i in range(5):
            full_zero = np.zeros((64, 64, 5),dtype=np.float32)
            full_zero[:, :, i] = full_one
            self.condition64.append(full_zero)

        #define label_pairs
        label_pair_root="/kaggle/input/utkpreprocess/train_label_pair.txt"
        with open(label_pair_root,'r') as f:
            lines=f.readlines()
        lines=[line.strip() for line in lines]
        shuffle(lines)
        self.label_pairs=[]
        for line in lines:
            label_pair=[]
            items=line.split()
            label_pair.append(int(items[0]))
            label_pair.append(int(items[1]))
            self.label_pairs.append(label_pair)

        #define group_images
        group_lists = [
            os.path.join(list_root, 'train_age_group_0.txt'),
            os.path.join(list_root, 'train_age_group_1.txt'),
            os.path.join(list_root, 'train_age_group_2.txt'),
            os.path.join(list_root, 'train_age_group_3.txt'),
            os.path.join(list_root, 'train_age_group_4.txt')
        ]

        self.label_group_images = []
        for i in range(len(group_lists)):
            with open(group_lists[i], 'r') as f:
                lines = f.readlines()
                lines = [line.strip() for line in lines]
            group_images = []
            for l in lines:
                items = l.split()
                group_images.append(os.path.join(data_root, items[0]))
            self.label_group_images.append(group_images)

        #define train.txt
        if self.split == "train":
            self.source_images = []#which use to aging transfer
            with open(os.path.join(list_root, 'train.txt'), 'r') as f:
                lines = f.readlines()
                lines = [line.strip() for line in lines]
            shuffle(lines)
            for l in lines:
                items = l.split()
                self.source_images.append(os.path.join(data_root, items[0]))
        else:
            self.source_images = []  # which use to aging transfer
            with open(os.path.join(list_root, 'test.txt'), 'r') as f:
                lines = f.readlines()
                lines = [line.strip() for line in lines]
            shuffle(lines)
            for l in lines:
                items = l.split()
                self.source_images.append(os.path.join(data_root, items[0]))

        #define pointer
        self.train_group_pointer=[0,0,0,0,0]
        self.source_pointer=0
        self.batch_size=32
        self.transforms=transforms
        self.label_transforms=label_transforms

    def __getitem__(self, idx):
        if self.split == "train":
            pair_idx=idx//self.batch_size #a batch train the same pair
            true_label=int(self.label_pairs[pair_idx][0])
            fake_label=int(self.label_pairs[pair_idx][1])

            true_label_128=self.condition128[true_label]
            true_label_64=self.condition64[true_label]
            fake_label_64=self.condition64[fake_label]

            true_label_img=pil_loader(self.label_group_images[true_label][self.train_group_pointer[true_label]]).resize((128,128))
            source_img=pil_loader(self.source_images[self.source_pointer])

            source_img_227=source_img.resize((227,227))
            source_img_128=source_img.resize((128,128))

            if self.train_group_pointer[true_label]<len(self.label_group_images[true_label])-1:
                self.train_group_pointer[true_label]+=1
            else:
                self.train_group_pointer[true_label]=0

            if self.source_pointer<len(self.source_images)-1:
                self.source_pointer+=1
            else:
                self.source_pointer=0

            if self.transforms != None:
                true_label_img=self.transforms(true_label_img)
                source_img_227=self.transforms(source_img_227)
                source_img_128=self.transforms(source_img_128)

            if self.label_transforms != None:
                true_label_128=self.label_transforms(true_label_128)
                true_label_64=self.label_transforms(true_label_64)
                fake_label_64=self.label_transforms(fake_label_64)
            #source img 227 : use it to extract face feature
            #source img 128 : use it to generate different age face -> then resize to (227,227) to extract feature, compile with source img 227
            #ture_label_img : img in target age group -> use to train discriminator
            #true_label_128 : use this condition to generate
            #true_label_64 and fake_label_64 : use this condition to discrimination
            #true_label : label

            return source_img_227,source_img_128,true_label_img,true_label_128,true_label_64,fake_label_64, true_label
        else:
            source_img_128=pil_loader(self.source_images[idx]).resize((128,128))
            if self.transforms != None:
                source_img_128=self.transforms(source_img_128)
            condition_128_tensor_li=[]
            if self.label_transforms != None:
                for condition in self.condition128:
                    condition_128_tensor_li.append(self.label_transforms(condition).cuda())
            return source_img_128.cuda(),condition_128_tensor_li

    def __len__(self):
        if self.split == "train":
            return len(self.label_pairs)
        else:
            return len(self.source_images)


transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
label_transforms=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
CACD_dataset=CACD("train", transforms , label_transforms)
train_loader = torch.utils.data.DataLoader(
    dataset=CACD_dataset,
    batch_size=32,
    shuffle=True
)
for idx,(source_img_227,source_img_128,true_label_img,true_label_128,true_label_64,fake_label_64, true_label) in enumerate(train_loader):
    print(true_label)
    break

tensor([2, 3, 3, 4, 4, 4, 2, 3, 2, 3, 0, 2, 3, 1, 1, 0, 2, 4, 0, 2, 1, 1, 0, 4,
        2, 1, 0, 0, 2, 1, 4, 1])


**other architecture**

In [4]:
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        return out

In [5]:
from torch.nn import functional as F
import math
from torch.nn.parameter import Parameter
from torch.nn.functional import pad
from torch.nn.modules import Module
from torch.nn.modules.utils import _single, _pair, _triple

class _ConvNd(Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

class Conv2d(_ConvNd):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

    def forward(self, input):
        return _conv2d_same_padding(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

# custom conv2d, because pytorch don't have "padding='same'" option.
def _conv2d_same_padding(input, weight, bias=None, stride=(1,1), padding=1, dilation=(1,1), groups=1):

    input_rows = input.size(2)
    filter_rows = weight.size(2)
    effective_filter_size_rows = (filter_rows - 1) * dilation[0] + 1
    out_rows = (input_rows + stride[0] - 1) // stride[0]
    padding_needed = max(0, (out_rows - 1) * stride[0] + effective_filter_size_rows -
                  input_rows)
    padding_rows = max(0, (out_rows - 1) * stride[0] +
                        (filter_rows - 1) * dilation[0] + 1 - input_rows)
    rows_odd = (padding_rows % 2 != 0)
    padding_cols = max(0, (out_rows - 1) * stride[0] +
                        (filter_rows - 1) * dilation[0] + 1 - input_rows)
    cols_odd = (padding_rows % 2 != 0)

    if rows_odd or cols_odd:
        input = pad(input, [0, int(cols_odd), 0, int(rows_odd)])

    return F.conv2d(input, weight, bias, stride,
                  padding=(padding_rows // 2, padding_cols // 2),
                  dilation=dilation, groups=groups)


In [6]:
import torch.nn as nn
import torch
import os
import torch.nn.functional as F

class AgeAlexNet(nn.Module):
    def __init__(self,pretrainded=False,modelpath=None):
        super(AgeAlexNet, self).__init__()
        assert pretrainded is False or modelpath is not None,"pretrain model need to be specified"
        self.features = nn.Sequential(
            Conv2d(3, 96, kernel_size=11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(2,2e-5,0.75),

            Conv2d(96, 256, kernel_size=5, stride=1,groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(2, 2e-5, 0.75),

            Conv2d(256, 384, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),

            Conv2d(384, 384, kernel_size=3,stride=1,groups=2),
            nn.ReLU(inplace=True),

            Conv2d(384, 256, kernel_size=3,stride=1,groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.age_classifier=nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 5),
        )
        if pretrainded is True:
            self.load_pretrained_params(modelpath)

        self.Conv3_feature_module=nn.Sequential()
        self.Conv4_feature_module=nn.Sequential()
        self.Conv5_feature_module=nn.Sequential()
        self.Pool5_feature_module=nn.Sequential()
        for x in range(10):
            self.Conv3_feature_module.add_module(str(x), self.features[x])
        for x in range(10,12):
            self.Conv4_feature_module.add_module(str(x),self.features[x])
        for x in range(12,14):
            self.Conv5_feature_module.add_module(str(x),self.features[x])
        for x in range(14,15):
            self.Pool5_feature_module.add_module(str(x),self.features[x])


    def forward(self, x):
        self.conv3_feature=self.Conv3_feature_module(x)
        self.conv4_feature=self.Conv4_feature_module(self.conv3_feature)
        self.conv5_feature=self.Conv5_feature_module(self.conv4_feature)
        pool5_feature=self.Pool5_feature_module(self.conv5_feature)
        self.pool5_feature=pool5_feature
        flattened = pool5_feature.view(pool5_feature.size(0), -1)
        age_logit = self.age_classifier(flattened)
        return age_logit

    def load_pretrained_params(self,path):
        # step1: load pretrained model
        pretrained_dict = torch.load(path)
        # step2: get model state_dict
        model_dict = self.state_dict()
        # step3: remove pretrained_dict params which is not in model_dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # step4: update model_dict using pretrained_dict
        model_dict.update(pretrained_dict)
        # step5: update model using model_dict
        self.load_state_dict(model_dict)


class AgeClassify:
    def __init__(self):
        #step 1:define model
        self.model=AgeAlexNet(pretrainded=False).cuda()
        #step 2:define optimizer
        self.optim=torch.optim.Adam(self.model.parameters(),lr=1e-4,betas=(0.5, 0.999))
        #step 3:define loss
        self.criterion=nn.CrossEntropyLoss().cuda()

    def train(self,input,label):
        self.model.train()
        output=self.model(input)
        self.loss=self.criterion(output,label)

    def val(self,input):
        self.model.eval()
        output=F.softmax(self.model(input),dim=1).max(1)[1]
        return output

    def save_model(self,dir,filename):
        torch.save(self.model.state_dict(),os.path.join(dir,filename))

**IPCGANs**

In [7]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import os

class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.lrelu = nn.LeakyReLU(0.2)
        self.conv1 = Conv2d(3, 64, kernel_size=4, stride=2)
        self.conv2 = Conv2d(69, 128, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(128, eps=0.001, track_running_stats=True)
        self.conv3 = Conv2d(128, 256, kernel_size=4, stride=2)
        self.bn3 = nn.BatchNorm2d(256, eps=0.001, track_running_stats=True)
        self.conv4 = Conv2d(256, 512, kernel_size=4, stride=2)
        self.bn4 = nn.BatchNorm2d(512, eps=0.001, track_running_stats=True)
        self.conv5 = Conv2d(512, 512, kernel_size=4, stride=2)

    def forward(self, x,condition):
        x = self.lrelu(self.conv1(x))
        x=torch.cat((x,condition),1)
        x = self.lrelu(self.bn2(self.conv2(x)))
        x = self.lrelu(self.bn3(self.conv3(x)))
        x = self.lrelu(self.bn4(self.conv4(x)))
        x = self.conv5(x)
        return x

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = Conv2d(8, 32, kernel_size=7, stride=1)
        self.conv2 = Conv2d(32, 64, kernel_size=3, stride=2)
        self.conv3 = Conv2d(64, 128, kernel_size=3, stride=2)
        self.conv4 = Conv2d(32, 3, kernel_size=7, stride=1)
        self.bn1 = nn.BatchNorm2d(32, eps=0.001, track_running_stats=True)
        self.bn2 = nn.BatchNorm2d(64, eps=0.001, track_running_stats=True)
        self.bn3 = nn.BatchNorm2d(128, eps=0.001, track_running_stats=True)
        self.bn4 = nn.BatchNorm2d(64, eps=0.001, track_running_stats=True)
        self.bn5 = nn.BatchNorm2d(32, eps=0.001, track_running_stats=True)
        self.repeat_blocks=self._make_repeat_blocks(BasicBlock(128,128),6)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0,output_padding=1)
        self.relu=nn.ReLU()
        self.tanh=nn.Tanh()

    def _make_repeat_blocks(self,block,repeat_times):
        layers=[]
        for i in range(repeat_times):
            layers.append(block)
        return nn.Sequential(*layers)

    def forward(self, x,condition=None):
        if condition is not None:
            x=torch.cat((x,condition),1)

        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.repeat_blocks(x)
        x = self.deconv1(x)
        x = self.relu(self.bn4(x))
        x = self.deconv2(x)
        x = self.relu(self.bn5(x))
        x = self.tanh(self.conv4(x))
        return x

class IPCGANs:
    def __init__(self,lr=0.01,age_classifier_path=None,gan_loss_weight=75,feature_loss_weight=0.5e-4,age_loss_weight=30):

        self.d_lr=lr
        self.g_lr=lr

        self.generator=Generator().cuda()
        self.discriminator=PatchDiscriminator().cuda()
        if age_classifier_path != None:
            self.age_classifier=AgeAlexNet(pretrainded=True,modelpath=age_classifier_path).cuda()
        else:
            self.age_classifier = AgeAlexNet(pretrainded=False).cuda()
        self.MSEloss=nn.MSELoss().cuda()
        self.CrossEntropyLoss=nn.CrossEntropyLoss().cuda()

        self.gan_loss_weight=gan_loss_weight
        self.feature_loss_weight = feature_loss_weight
        self.age_loss_weight=age_loss_weight

        self.d_optim = torch.optim.Adam(self.discriminator.parameters(),self.d_lr,betas=(0.5,0.99))
        self.g_optim = torch.optim.Adam(self.generator.parameters(), self.g_lr, betas=(0.5, 0.99))

    def save_model(self,dir,filename):
        torch.save(self.generator.state_dict(),os.path.join(dir,"g"+filename))
        torch.save(self.discriminator.state_dict(),os.path.join(dir,"d"+filename))

    def load_generator_state_dict(self,state_dict):
        pretrained_dict = state_dict
        # step2: get model state_dict
        model_dict = self.generator.state_dict()
        # step3: remove pretrained_dict params which is not in model_dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # step4: update model_dict using pretrained_dict
        model_dict.update(pretrained_dict)
        # step5: update model using model_dict
        self.generator.load_state_dict(model_dict)

    def test_generate(self,source_img_128,condition):
        self.generator.eval()
        with torch.no_grad():
            generate_image=self.generator(source_img_128,condition)
        return generate_image

    def cuda(self):
        self.generator=self.generator.cuda()

    def train(self,source_img_227,source_img_128,true_label_img,true_label_128,true_label_64,fake_label_64, age_label):
        '''

        :param source_img_227: use this img to extract conv5 feature
        :param source_img_128: use this img to generate face in target age
        :param true_label_img:
        :param true_label_128:
        :param true_label_64:
        :param fake_label_64:
        :param age_label:
        :return:
        '''

        ###################################gan_loss###############################
        self.g_source=self.generator(source_img_128,condition=true_label_128)

        #real img, right age label
        #logit means prob which hasn't been normalized

        #d1 logit ,discriminator 1 means true,0 means false.
        d1_logit=self.discriminator(true_label_img,condition=true_label_64)

        d1_real_loss=self.MSEloss(d1_logit,torch.ones((d1_logit.size())).cuda())

        #real img, false label
        d2_logit=self.discriminator(true_label_img,condition=fake_label_64)
        d2_fake_loss=self.MSEloss(d2_logit,torch.zeros((d1_logit.size())).cuda())

        #fake img,real label
        d3_logit=self.discriminator(self.g_source,condition=true_label_64)
        d3_fake_loss=self.MSEloss(d3_logit,torch.zeros((d1_logit.size())).cuda())#use this for discriminator
        d3_real_loss=self.MSEloss(d3_logit,torch.ones((d1_logit.size())).cuda())#use this for genrator

        self.d_loss=(1./2 * (d1_real_loss + 1. / 2 * (d2_fake_loss + d3_fake_loss))) * self.gan_loss_weight
        g_loss=(1./2*d3_real_loss)*self.gan_loss_weight


        ################################feature_loss#############################

        self.age_classifier(source_img_227)
        source_feature=self.age_classifier.conv5_feature

        generate_img_227 = F.interpolate(self.g_source, (227, 227), mode="bilinear", align_corners=True)
        generate_img_227 = Img_to_zero_center()(generate_img_227)

        self.age_classifier(generate_img_227)
        generate_feature =self.age_classifier.conv5_feature
        self.feature_loss=self.MSEloss(source_feature,generate_feature)

        ################################age_cls_loss##############################



        age_logit=self.age_classifier(generate_img_227)
        self.age_loss=self.CrossEntropyLoss(age_logit,age_label)

        self.g_loss=self.age_loss+g_loss+self.feature_loss

**trainning**

In [8]:
# Optimizer
learning_rate = 1e-4
batch_size = 32
max_epoches = 20
val_interval = 1400 #Number of steps to validate
save_interval = 1400 #Number of batches to save model

d_iters = 1 
g_iters = 1 

#model
gan_loss_weight = 75
feature_loss_weight = 0.5e-4
age_loss_weight = 30
age_groups = 5
age_classifier_path = "/kaggle/input/facealexnet_pretrain/keras/default/1/kaggle/working/model/epoch_0_iter_0.pth"

#data, io
checkpoint = "/kaggle/working/checkpoint/"
saved_model_folder = "/kaggle/working/checkpoint/saved_parameters/"
saved_validation_folder = "/kaggle/working/checkpoint/validation/"

#check_dir
check_dir(checkpoint)
check_dir(saved_model_folder)
check_dir(saved_validation_folder)

def main():
    print("Start to train:\n")

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        Img_to_zero_center()
    ])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    train_dataset = CACD("train",transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False
    )

    model=IPCGANs(lr=learning_rate,age_classifier_path=age_classifier_path,gan_loss_weight=gan_loss_weight,feature_loss_weight=feature_loss_weight,age_loss_weight=age_loss_weight)
    d_optim=model.d_optim
    g_optim=model.g_optim

    for epoch in range(max_epoches):
        epochp = "epoch: " + str(epoch)
        print(epochp)
        for idx, (source_img_227,source_img_128,true_label_img,true_label_128,true_label_64,fake_label_64, true_label) in enumerate(train_loader,1):

            running_d_loss=None
            running_g_loss=None
            n_iter = epoch * len(train_loader) + idx


            #mv to gpu
            source_img_227=source_img_227.cuda()
            source_img_128=source_img_128.cuda()
            true_label_img=true_label_img.cuda()
            true_label_128=true_label_128.cuda()
            true_label_64=true_label_64.cuda()
            fake_label_64=fake_label_64.cuda()
            true_label=true_label.cuda()

            #train discriminator
            for d_iter in range(d_iters):
                #d_lr_scheduler.step()
                d_optim.zero_grad()
                model.train(
                    source_img_227=source_img_227,
                    source_img_128=source_img_128,
                    true_label_img=true_label_img,
                    true_label_128=true_label_128,
                    true_label_64=true_label_64,
                    fake_label_64=fake_label_64,
                    age_label=true_label
                )
                d_loss=model.d_loss
                running_d_loss=d_loss
                d_loss.backward()
                d_optim.step()

            #train generator
            for g_iter in range(g_iters):
                #g_lr_scheduler.step()
                g_optim.zero_grad()
                model.train(
                    source_img_227=source_img_227,
                    source_img_128=source_img_128,
                    true_label_img=true_label_img,
                    true_label_128=true_label_128,
                    true_label_64=true_label_64,
                    fake_label_64=fake_label_64,
                    age_label=true_label
                )
                g_loss = model.g_loss
                running_g_loss=g_loss
                g_loss.backward()
                g_optim.step()
            if idx % 100 == 0:
                print('step %d/%d, g_loss = %.3f, d_loss = %.3f' %(idx, len(train_loader),running_g_loss,running_d_loss))

            # save the parameters at the end of each save interval
            if idx % save_interval == 0:
                model.save_model(dir=saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth'%(epoch, idx))
                print('checkpoint has been created!')

            #val step
            if idx % val_interval == 0:
                save_dir = os.path.join(saved_validation_folder, "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                for val_idx, (source_img_128, true_label_128) in enumerate(tqdm(test_loader)):
                    save_image(Reverse_zero_center()(source_img_128),os.path.join(save_dir,"batch_%d_source.jpg"%(val_idx)))

                    pic_list = []
                    pic_list.append(source_img_128)
                    for age in range(age_groups):
                        img = model.test_generate(source_img_128, true_label_128[age])
                        save_image(Reverse_zero_center()(img),os.path.join(save_dir,"batch_%d_age_group_%d.jpg"%(val_idx,age)))
                print('validation image has been created!')


main()

Start to train:



  pretrained_dict = torch.load(path)


epoch: 0
step 100/1407, g_loss = 15.421, d_loss = 17.011
step 200/1407, g_loss = 14.279, d_loss = 18.689
step 300/1407, g_loss = 11.595, d_loss = 17.283
step 400/1407, g_loss = 10.350, d_loss = 17.699
step 500/1407, g_loss = 12.125, d_loss = 17.330
step 600/1407, g_loss = 13.319, d_loss = 18.475
step 700/1407, g_loss = 13.852, d_loss = 19.154
step 800/1407, g_loss = 11.492, d_loss = 17.120
step 900/1407, g_loss = 9.970, d_loss = 19.055
step 1000/1407, g_loss = 11.952, d_loss = 18.694
step 1100/1407, g_loss = 10.165, d_loss = 15.307
step 1200/1407, g_loss = 9.335, d_loss = 19.308
step 1300/1407, g_loss = 8.432, d_loss = 15.095
step 1400/1407, g_loss = 10.907, d_loss = 19.557
checkpoint has been created!


100%|██████████| 72/72 [00:35<00:00,  2.01it/s]


validation image has been created!
epoch: 1
step 100/1407, g_loss = 10.390, d_loss = 19.706
step 200/1407, g_loss = 9.776, d_loss = 16.697
step 300/1407, g_loss = 11.405, d_loss = 17.330
step 400/1407, g_loss = 15.206, d_loss = 17.407
step 500/1407, g_loss = 13.828, d_loss = 19.402
step 600/1407, g_loss = 11.974, d_loss = 16.681
step 700/1407, g_loss = 12.047, d_loss = 16.756
step 800/1407, g_loss = 12.411, d_loss = 16.735
step 900/1407, g_loss = 10.715, d_loss = 16.128
step 1000/1407, g_loss = 12.108, d_loss = 16.160
step 1100/1407, g_loss = 8.998, d_loss = 22.302
step 1200/1407, g_loss = 11.538, d_loss = 15.147
step 1300/1407, g_loss = 11.761, d_loss = 20.950
step 1400/1407, g_loss = 10.127, d_loss = 18.323
checkpoint has been created!


100%|██████████| 72/72 [00:26<00:00,  2.74it/s]


validation image has been created!
epoch: 2
step 100/1407, g_loss = 10.071, d_loss = 17.254
step 200/1407, g_loss = 14.330, d_loss = 16.662
step 300/1407, g_loss = 9.416, d_loss = 14.982
step 400/1407, g_loss = 13.302, d_loss = 15.000
step 500/1407, g_loss = 10.519, d_loss = 17.946
step 600/1407, g_loss = 11.946, d_loss = 16.065
step 700/1407, g_loss = 9.310, d_loss = 15.696
step 800/1407, g_loss = 12.223, d_loss = 13.844
step 900/1407, g_loss = 10.138, d_loss = 14.261
step 1000/1407, g_loss = 11.499, d_loss = 18.078
step 1100/1407, g_loss = 10.862, d_loss = 21.132
step 1200/1407, g_loss = 10.375, d_loss = 15.579
step 1300/1407, g_loss = 8.229, d_loss = 17.082
step 1400/1407, g_loss = 12.129, d_loss = 15.323
checkpoint has been created!


100%|██████████| 72/72 [00:26<00:00,  2.76it/s]


validation image has been created!
epoch: 3
step 100/1407, g_loss = 11.724, d_loss = 14.685
step 200/1407, g_loss = 8.462, d_loss = 12.257
step 300/1407, g_loss = 10.803, d_loss = 12.432
step 400/1407, g_loss = 7.801, d_loss = 14.155
step 500/1407, g_loss = 15.013, d_loss = 17.820
step 600/1407, g_loss = 10.624, d_loss = 16.729
step 700/1407, g_loss = 8.575, d_loss = 14.014
step 800/1407, g_loss = 10.545, d_loss = 12.875
step 900/1407, g_loss = 8.926, d_loss = 22.399
step 1000/1407, g_loss = 10.346, d_loss = 13.980
step 1100/1407, g_loss = 12.085, d_loss = 20.405
step 1200/1407, g_loss = 14.759, d_loss = 29.185
step 1300/1407, g_loss = 12.151, d_loss = 14.574
step 1400/1407, g_loss = 15.142, d_loss = 18.678
checkpoint has been created!


100%|██████████| 72/72 [00:26<00:00,  2.70it/s]


validation image has been created!
epoch: 4
step 100/1407, g_loss = 7.883, d_loss = 16.441
step 200/1407, g_loss = 11.538, d_loss = 16.298
step 300/1407, g_loss = 9.985, d_loss = 15.583
step 400/1407, g_loss = 9.935, d_loss = 13.981
step 500/1407, g_loss = 11.104, d_loss = 16.165
step 600/1407, g_loss = 7.841, d_loss = 10.486
step 700/1407, g_loss = 13.071, d_loss = 16.000
step 800/1407, g_loss = 10.826, d_loss = 14.768
step 900/1407, g_loss = 11.364, d_loss = 13.721
step 1000/1407, g_loss = 12.091, d_loss = 16.931
step 1100/1407, g_loss = 10.811, d_loss = 14.692
step 1200/1407, g_loss = 13.772, d_loss = 14.231
step 1300/1407, g_loss = 14.663, d_loss = 18.596
step 1400/1407, g_loss = 9.475, d_loss = 15.838
checkpoint has been created!


100%|██████████| 72/72 [00:26<00:00,  2.71it/s]


validation image has been created!
epoch: 5
step 100/1407, g_loss = 13.721, d_loss = 13.638
step 200/1407, g_loss = 12.799, d_loss = 12.385
step 300/1407, g_loss = 13.750, d_loss = 15.095
step 400/1407, g_loss = 15.216, d_loss = 18.729
step 500/1407, g_loss = 12.681, d_loss = 11.275
step 600/1407, g_loss = 14.965, d_loss = 13.888
step 700/1407, g_loss = 9.097, d_loss = 13.773
step 800/1407, g_loss = 11.816, d_loss = 14.379
step 900/1407, g_loss = 9.919, d_loss = 14.130
step 1000/1407, g_loss = 15.447, d_loss = 19.721
step 1100/1407, g_loss = 10.080, d_loss = 11.358
step 1200/1407, g_loss = 8.876, d_loss = 17.289
step 1300/1407, g_loss = 11.732, d_loss = 18.914
step 1400/1407, g_loss = 11.693, d_loss = 17.396
checkpoint has been created!


100%|██████████| 72/72 [00:26<00:00,  2.70it/s]


validation image has been created!
epoch: 6
step 100/1407, g_loss = 14.093, d_loss = 15.363
step 200/1407, g_loss = 10.390, d_loss = 15.661
step 300/1407, g_loss = 15.765, d_loss = 17.471
step 400/1407, g_loss = 12.433, d_loss = 12.070
step 500/1407, g_loss = 7.549, d_loss = 13.009
step 600/1407, g_loss = 8.039, d_loss = 13.151
step 700/1407, g_loss = 7.859, d_loss = 16.933
step 800/1407, g_loss = 12.710, d_loss = 16.320
step 900/1407, g_loss = 13.543, d_loss = 19.145
step 1000/1407, g_loss = 11.374, d_loss = 11.409
step 1100/1407, g_loss = 7.069, d_loss = 23.243
step 1200/1407, g_loss = 15.961, d_loss = 13.574
step 1300/1407, g_loss = 11.811, d_loss = 12.102
step 1400/1407, g_loss = 8.375, d_loss = 16.803
checkpoint has been created!


100%|██████████| 72/72 [00:27<00:00,  2.60it/s]


validation image has been created!
epoch: 7
step 100/1407, g_loss = 13.301, d_loss = 15.706
step 200/1407, g_loss = 14.482, d_loss = 15.168
step 300/1407, g_loss = 14.265, d_loss = 13.076
step 400/1407, g_loss = 11.725, d_loss = 12.026
step 500/1407, g_loss = 8.790, d_loss = 13.610
step 600/1407, g_loss = 16.603, d_loss = 15.332
step 700/1407, g_loss = 13.792, d_loss = 14.157
step 800/1407, g_loss = 5.883, d_loss = 25.882
step 900/1407, g_loss = 11.744, d_loss = 11.282
step 1000/1407, g_loss = 11.329, d_loss = 10.739
step 1100/1407, g_loss = 14.727, d_loss = 11.576
step 1200/1407, g_loss = 14.892, d_loss = 11.425
step 1300/1407, g_loss = 11.011, d_loss = 12.703
step 1400/1407, g_loss = 5.728, d_loss = 21.659
checkpoint has been created!


100%|██████████| 72/72 [00:27<00:00,  2.59it/s]


validation image has been created!
epoch: 8
step 100/1407, g_loss = 14.785, d_loss = 13.615
step 200/1407, g_loss = 11.008, d_loss = 11.293
step 300/1407, g_loss = 10.082, d_loss = 15.208
step 400/1407, g_loss = 11.978, d_loss = 12.547
step 500/1407, g_loss = 8.976, d_loss = 12.705
step 600/1407, g_loss = 14.613, d_loss = 13.245
step 700/1407, g_loss = 15.796, d_loss = 15.051
step 800/1407, g_loss = 8.914, d_loss = 11.290
step 900/1407, g_loss = 10.376, d_loss = 11.181
step 1000/1407, g_loss = 10.615, d_loss = 13.474
step 1100/1407, g_loss = 16.767, d_loss = 13.077
step 1200/1407, g_loss = 5.890, d_loss = 17.375
step 1300/1407, g_loss = 7.745, d_loss = 14.630
step 1400/1407, g_loss = 10.687, d_loss = 11.617
checkpoint has been created!


100%|██████████| 72/72 [00:27<00:00,  2.63it/s]


validation image has been created!
epoch: 9
step 100/1407, g_loss = 10.800, d_loss = 13.838
step 200/1407, g_loss = 14.051, d_loss = 15.924
step 300/1407, g_loss = 13.830, d_loss = 14.640
step 400/1407, g_loss = 16.468, d_loss = 14.588
step 500/1407, g_loss = 16.385, d_loss = 13.727
step 600/1407, g_loss = 13.294, d_loss = 18.374
step 700/1407, g_loss = 14.214, d_loss = 11.101
step 800/1407, g_loss = 9.393, d_loss = 11.533
step 900/1407, g_loss = 7.410, d_loss = 12.245
step 1000/1407, g_loss = 13.783, d_loss = 13.943
step 1100/1407, g_loss = 9.677, d_loss = 16.307
step 1200/1407, g_loss = 9.976, d_loss = 14.838
step 1300/1407, g_loss = 10.422, d_loss = 12.421
step 1400/1407, g_loss = 17.135, d_loss = 11.203
checkpoint has been created!


100%|██████████| 72/72 [00:26<00:00,  2.67it/s]


validation image has been created!
epoch: 10
step 100/1407, g_loss = 18.651, d_loss = 15.566
step 200/1407, g_loss = 12.014, d_loss = 10.228
step 300/1407, g_loss = 15.189, d_loss = 12.446
step 400/1407, g_loss = 12.313, d_loss = 14.761
step 500/1407, g_loss = 11.845, d_loss = 14.894
step 600/1407, g_loss = 15.921, d_loss = 11.638
step 700/1407, g_loss = 12.955, d_loss = 8.982
step 800/1407, g_loss = 19.922, d_loss = 10.194
step 900/1407, g_loss = 11.927, d_loss = 12.476
step 1000/1407, g_loss = 12.748, d_loss = 13.232
step 1100/1407, g_loss = 13.862, d_loss = 13.556
step 1200/1407, g_loss = 14.069, d_loss = 14.311
step 1300/1407, g_loss = 18.673, d_loss = 12.472
step 1400/1407, g_loss = 20.301, d_loss = 14.745
checkpoint has been created!


100%|██████████| 72/72 [00:28<00:00,  2.55it/s]


validation image has been created!
epoch: 11
step 100/1407, g_loss = 9.854, d_loss = 11.603
step 200/1407, g_loss = 12.200, d_loss = 12.400
step 300/1407, g_loss = 13.317, d_loss = 9.719
step 400/1407, g_loss = 16.241, d_loss = 14.188
step 500/1407, g_loss = 10.091, d_loss = 10.560
step 600/1407, g_loss = 12.979, d_loss = 9.646
step 700/1407, g_loss = 17.510, d_loss = 11.342
step 800/1407, g_loss = 10.583, d_loss = 11.394
step 900/1407, g_loss = 14.952, d_loss = 13.950
step 1000/1407, g_loss = 10.664, d_loss = 14.396
step 1100/1407, g_loss = 18.616, d_loss = 14.248
step 1200/1407, g_loss = 13.320, d_loss = 10.298
step 1300/1407, g_loss = 17.246, d_loss = 21.109
step 1400/1407, g_loss = 11.837, d_loss = 10.989
checkpoint has been created!


100%|██████████| 72/72 [00:25<00:00,  2.81it/s]


validation image has been created!
epoch: 12
step 100/1407, g_loss = 11.356, d_loss = 12.802
step 200/1407, g_loss = 12.898, d_loss = 13.805
step 300/1407, g_loss = 25.796, d_loss = 13.806
step 400/1407, g_loss = 16.726, d_loss = 11.077
step 500/1407, g_loss = 10.557, d_loss = 9.770
step 600/1407, g_loss = 7.455, d_loss = 12.161
step 700/1407, g_loss = 11.011, d_loss = 12.759
step 800/1407, g_loss = 6.389, d_loss = 16.152
step 900/1407, g_loss = 13.932, d_loss = 11.944
step 1000/1407, g_loss = 15.042, d_loss = 14.169
step 1100/1407, g_loss = 18.317, d_loss = 11.427
step 1200/1407, g_loss = 14.333, d_loss = 18.604
step 1300/1407, g_loss = 12.221, d_loss = 12.671
step 1400/1407, g_loss = 11.865, d_loss = 10.735
checkpoint has been created!


100%|██████████| 72/72 [00:25<00:00,  2.78it/s]


validation image has been created!
epoch: 13
step 100/1407, g_loss = 9.626, d_loss = 14.296
step 200/1407, g_loss = 17.268, d_loss = 14.252
step 300/1407, g_loss = 14.735, d_loss = 12.151
step 400/1407, g_loss = 17.709, d_loss = 12.634
step 500/1407, g_loss = 9.694, d_loss = 17.175
step 600/1407, g_loss = 14.797, d_loss = 12.527
step 700/1407, g_loss = 22.227, d_loss = 15.377
step 800/1407, g_loss = 19.208, d_loss = 8.168
step 900/1407, g_loss = 18.063, d_loss = 14.652
step 1000/1407, g_loss = 18.601, d_loss = 9.514
step 1100/1407, g_loss = 11.116, d_loss = 16.445
step 1200/1407, g_loss = 15.116, d_loss = 10.577
step 1300/1407, g_loss = 17.142, d_loss = 11.535
step 1400/1407, g_loss = 9.662, d_loss = 13.423
checkpoint has been created!


100%|██████████| 72/72 [00:26<00:00,  2.69it/s]


validation image has been created!
epoch: 14
step 100/1407, g_loss = 19.155, d_loss = 12.916
step 200/1407, g_loss = 12.626, d_loss = 13.052
step 300/1407, g_loss = 15.633, d_loss = 10.091
step 400/1407, g_loss = 11.730, d_loss = 12.658
step 500/1407, g_loss = 13.404, d_loss = 10.015
step 600/1407, g_loss = 14.560, d_loss = 10.300
step 700/1407, g_loss = 22.460, d_loss = 12.571
step 800/1407, g_loss = 14.118, d_loss = 13.522
step 900/1407, g_loss = 15.662, d_loss = 10.904
step 1000/1407, g_loss = 14.244, d_loss = 8.760
step 1100/1407, g_loss = 14.673, d_loss = 11.276
step 1200/1407, g_loss = 12.331, d_loss = 11.053
step 1300/1407, g_loss = 11.457, d_loss = 11.443
step 1400/1407, g_loss = 10.696, d_loss = 15.243
checkpoint has been created!


100%|██████████| 72/72 [00:25<00:00,  2.80it/s]


validation image has been created!
epoch: 15
step 100/1407, g_loss = 13.699, d_loss = 11.471
step 200/1407, g_loss = 13.632, d_loss = 10.149
step 300/1407, g_loss = 17.049, d_loss = 10.358
step 400/1407, g_loss = 12.892, d_loss = 11.377
step 500/1407, g_loss = 15.933, d_loss = 11.183
step 600/1407, g_loss = 12.797, d_loss = 10.678
step 700/1407, g_loss = 22.279, d_loss = 12.988
step 800/1407, g_loss = 11.852, d_loss = 9.356
step 900/1407, g_loss = 17.937, d_loss = 9.067
step 1000/1407, g_loss = 20.069, d_loss = 8.432
step 1100/1407, g_loss = 13.065, d_loss = 10.570
step 1200/1407, g_loss = 16.932, d_loss = 12.287
step 1300/1407, g_loss = 18.501, d_loss = 15.588
step 1400/1407, g_loss = 10.067, d_loss = 13.109
checkpoint has been created!


100%|██████████| 72/72 [00:24<00:00,  2.99it/s]


validation image has been created!
epoch: 16
step 100/1407, g_loss = 18.828, d_loss = 10.495
step 200/1407, g_loss = 19.532, d_loss = 7.987
step 300/1407, g_loss = 25.752, d_loss = 11.667
step 400/1407, g_loss = 17.815, d_loss = 9.741
step 500/1407, g_loss = 15.573, d_loss = 11.338
step 600/1407, g_loss = 21.497, d_loss = 12.655
step 700/1407, g_loss = 14.340, d_loss = 9.241
step 800/1407, g_loss = 15.521, d_loss = 10.298
step 900/1407, g_loss = 16.479, d_loss = 9.631
step 1000/1407, g_loss = 13.977, d_loss = 10.202
step 1100/1407, g_loss = 15.330, d_loss = 11.972
step 1200/1407, g_loss = 14.168, d_loss = 10.292
step 1300/1407, g_loss = 13.233, d_loss = 9.563
step 1400/1407, g_loss = 14.872, d_loss = 8.020
checkpoint has been created!


100%|██████████| 72/72 [00:24<00:00,  2.90it/s]


validation image has been created!
epoch: 17
step 100/1407, g_loss = 13.606, d_loss = 14.405
step 200/1407, g_loss = 18.735, d_loss = 9.554
step 300/1407, g_loss = 19.719, d_loss = 10.366
step 400/1407, g_loss = 18.955, d_loss = 12.066
step 500/1407, g_loss = 21.387, d_loss = 11.524
step 600/1407, g_loss = 16.503, d_loss = 7.802
step 700/1407, g_loss = 14.420, d_loss = 8.248
step 800/1407, g_loss = 18.225, d_loss = 11.100
step 900/1407, g_loss = 21.672, d_loss = 8.482
step 1000/1407, g_loss = 18.864, d_loss = 9.539
step 1100/1407, g_loss = 17.626, d_loss = 9.248
step 1200/1407, g_loss = 19.339, d_loss = 11.457
step 1300/1407, g_loss = 17.032, d_loss = 9.080
step 1400/1407, g_loss = 12.139, d_loss = 12.199
checkpoint has been created!


100%|██████████| 72/72 [00:24<00:00,  2.89it/s]


validation image has been created!
epoch: 18
step 100/1407, g_loss = 13.376, d_loss = 9.252
step 200/1407, g_loss = 19.683, d_loss = 6.986
step 300/1407, g_loss = 28.176, d_loss = 10.861
step 400/1407, g_loss = 20.609, d_loss = 10.694
step 500/1407, g_loss = 15.453, d_loss = 9.269
step 600/1407, g_loss = 18.445, d_loss = 9.373
step 700/1407, g_loss = 16.729, d_loss = 8.823
step 800/1407, g_loss = 18.143, d_loss = 10.748
step 900/1407, g_loss = 17.009, d_loss = 9.062
step 1000/1407, g_loss = 25.223, d_loss = 10.530
step 1100/1407, g_loss = 13.965, d_loss = 12.040
step 1200/1407, g_loss = 23.851, d_loss = 9.269
step 1300/1407, g_loss = 13.036, d_loss = 14.161
step 1400/1407, g_loss = 17.594, d_loss = 8.579
checkpoint has been created!


100%|██████████| 72/72 [00:24<00:00,  2.98it/s]


validation image has been created!
epoch: 19
step 100/1407, g_loss = 19.235, d_loss = 8.897
step 200/1407, g_loss = 21.278, d_loss = 8.269
step 300/1407, g_loss = 14.012, d_loss = 9.887
step 400/1407, g_loss = 10.904, d_loss = 15.430
step 500/1407, g_loss = 21.048, d_loss = 7.989
step 600/1407, g_loss = 14.162, d_loss = 9.386
step 700/1407, g_loss = 14.878, d_loss = 11.111
step 800/1407, g_loss = 22.417, d_loss = 8.229
step 900/1407, g_loss = 17.444, d_loss = 9.688
step 1000/1407, g_loss = 21.553, d_loss = 9.354
step 1100/1407, g_loss = 23.137, d_loss = 6.763
step 1200/1407, g_loss = 22.637, d_loss = 12.971
step 1300/1407, g_loss = 25.696, d_loss = 12.855
step 1400/1407, g_loss = 17.709, d_loss = 9.208
checkpoint has been created!


100%|██████████| 72/72 [00:24<00:00,  2.95it/s]


validation image has been created!


In [9]:
"""
import os
import subprocess
from IPython.display import FileLink, display

def download_file(path, download_file_name):
    os.chdir('/kaggle/working/')
    zip_name = f"/kaggle/working/{download_file_name}.zip"
    command = f"zip {zip_name} {path} -r"
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        print("Unable to run zip command!")
        print(result.stderr)
        return
    display(FileLink(f'{download_file_name}.zip'))

download_file('/kaggle/working/checkpoint', 'out')
"""

'\nimport os\nimport subprocess\nfrom IPython.display import FileLink, display\n\ndef download_file(path, download_file_name):\n    os.chdir(\'/kaggle/working/\')\n    zip_name = f"/kaggle/working/{download_file_name}.zip"\n    command = f"zip {zip_name} {path} -r"\n    result = subprocess.run(command, shell=True, capture_output=True, text=True)\n    if result.returncode != 0:\n        print("Unable to run zip command!")\n        print(result.stderr)\n        return\n    display(FileLink(f\'{download_file_name}.zip\'))\n\ndownload_file(\'/kaggle/working/checkpoint\', \'out\')\n'