In [None]:
# Colab only: mount Google Drive so files stored there are reachable under /content/gdrive.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


In [1]:
import os
import torch
from torch.utils.data import DataLoader
# from torchsummaryX import summary
from tqdm.notebook import tqdm

# root = '/content/gdrive/MyDrive/Changing Hair Color in Images/'
# os.sys.path.append(root)

from data import HairDataset
from model import HairColorGAN
from utils import parse_args

In [4]:
# setup all the relevant parameters - can look into utils/arg_parse for available options
# Training options as flag -> value pairs, flattened (in this order) into the list of
# strings handed to utils.parse_args; see utils/arg_parse for every available option.
# A value of None marks a flag that is passed on its own, without a value.
train_options = {
    '--dataroot': './dataset',
    '--dataset_type': 'train',
    '--K': '20',
    '--L': '500',
    '--lr': '0.0002',
    '--lr_policy': 'linear',
    '--lambda_cyc': '5.0',
    '--lambda_idt': '0.333',
    '--continue_train': None,
    '--save_dir': './model/checkpoints',
    '--img_pool_size': '30',
}
cli_args = []
for flag, value in train_options.items():
    cli_args.append(flag)
    if value is not None:
        cli_args.append(value)
params = parse_args(cli_args)

In [5]:
train_dataset = HairDataset(params)
# Fixed (unshuffled) order, so batch k is always the same slice of the dataset.
train_data = DataLoader(train_dataset, shuffle=False, batch_size=params.batch_size)

num_images = len(train_dataset)
# Whole batches per epoch; a trailing partial batch (if any) is not counted.
i_max = num_images // params.batch_size
print('num of training images: ', num_images)
print('num of training batches: ', i_max)

num of training images:  12000
num of training batches:  1500


In [8]:
# Build the GAN from the parsed options plus i_max (the number of training batches
# per epoch computed above). With --continue_train set, the captured output shows it
# restoring saved generator/discriminator checkpoints and the epoch/iter to resume from.
model = HairColorGAN(params, i_max)

successfully loaded:  model_gen_iter_1250.pth
successfully loaded:  model_disc_iter_1250.pth
starting from epoch 5, iter 1251


In [None]:
def _epoch_loader(start_iter):
    """Return an iterable over one epoch's batches, beginning at batch ``start_iter``.

    ``train_data`` is built with ``shuffle=False``, so batch k always holds samples
    ``[k * batch_size, (k + 1) * batch_size)``. When resuming mid-epoch, only the
    remaining samples are wrapped in a new DataLoader, instead of loading and
    discarding the first ``start_iter`` batches.
    NOTE(review): assumes HairDataset is map-style (supports indexing) - confirm.
    """
    if start_iter <= 0:
        return train_data
    first_sample = start_iter * params.batch_size
    remaining = torch.utils.data.Subset(train_dataset, range(first_sample, len(train_dataset)))
    return DataLoader(remaining, batch_size=params.batch_size, shuffle=False)


for epoch in range(model.checkpoint, params.n_epochs + params.n_epochs_decay + 1):
    model.refresh(epoch)
    # Iteration this epoch starts from (non-zero right after resuming from a checkpoint).
    # NOTE(review): assumes model.refresh resets model.iter to 0 for later epochs - confirm.
    start_iter = model.iter

    # Build the batch iterator ONCE per epoch. The previous `next(iter(train_data))`
    # created a fresh iterator on every step, and with shuffle=False that always
    # yielded the first batch, so every step trained on the same images.
    # zip() stops after i_max - start_iter steps, exactly like the old range().
    epoch_steps = zip(range(start_iter, i_max), _epoch_loader(start_iter))
    for i, batch in tqdm(epoch_steps, total=max(i_max - start_iter, 0), desc='Epoch %s ' % epoch):
        model.set_inputs(batch)
        model.optimize_parameters()

        # keep adding stats for tracking purposes
        model.update_trackers(i)

        # get losses and print them by display frequency
        if i % params.print_iter_interval == 0 or i == i_max - 1:
            stats_dict = model.get_stats()
            # lr in scientific notation: with a base lr of 0.0002, '%.4f' rounded the
            # value to 0.0002/0.0001/0.0000 as it changed.
            print('iter = %d \t lr = %.2e \t loss_G= %.4f \t loss_D= %.4f' %
                  (i, stats_dict['lr'], stats_dict['loss_G'], stats_dict['loss_D']))

        # save a model checkpoint and sample images at the iterations in model.save_indices
        if i in model.save_indices:
            model.save_model(i)
            model.save_images(i)

    model.save_logs(epoch)
    model.update_learning_rate()
