In [1]:
import torch
from torch.utils.data import DataLoader
from torchsummaryX import summary
from tqdm import tqdm_notebook as tqdm

from data import HairDataset
from model import HairColorGAN
from utils import parse_args

In [2]:
# Training options, handed to the project's argument parser exactly as if they
# had been typed on the command line (see the parser behind `utils.parse_args`
# for every available option). Insertion order is preserved when flattening.
# The bare flag '--continue_train' can be added to the argv to (presumably)
# resume from a saved checkpoint.
cli_options = {
    '--dataset_type': 'train',
    '--K': '20',
    '--L': '1000',
    '--lr': '0.0002',
    '--lr_policy': 'linear',
    '--lambda_cyc': '5.0',
    '--lambda_idt': '0.333',
    '--save_interval': '10',
}
params = parse_args([token for option in cli_options.items() for token in option])

In [3]:
train_dataset = HairDataset(params)
# NOTE(review): shuffle=False feeds the images in the same order every epoch;
# confirm this is intentional (training loaders are usually shuffled).
train_data = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=False)

# Number of batches the training loop will actually iterate. With the default
# drop_last=False, len(DataLoader) rounds up and counts the final partial batch;
# int(len(dataset) / batch_size) would round down and under-count it.
i_max = len(train_data)
print('num of training images: ', len(train_dataset))
print('num of training batches: ', i_max)

num of training images:  12000
num of training batches:  1500


In [None]:
model = HairColorGAN(params)

# Toggle to print torchsummaryX layer summaries of model.gen and model.desc
# (presumably the generator and discriminator). Each net is fed a dummy
# 6-channel 256x256 input of shape (batch_size, c, h, w).
# NOTE(review): the dummy tensors are created on the CPU; if HairColorGAN places
# its networks on a GPU, move the inputs to the same device before enabling this.
SHOW_MODEL_SUMMARY = False
if SHOW_MODEL_SUMMARY:
    summary(model.gen, torch.rand(params.batch_size, 6, 256, 256))
    summary(model.desc, torch.rand(params.batch_size, 6, 256, 256))

In [None]:
# Main training loop: runs from the model's starting checkpoint epoch up to and
# including n_epochs + n_epochs_decay.
last_epoch = params.n_epochs + params.n_epochs_decay
for epoch in range(model.checkpoint, last_epoch + 1):
    model.refresh(epoch, i_max)

    for i, data in enumerate(train_data):
        model.set_inputs(data)
        model.optimize_parameters()

        # keep adding stats for tracking purposes
        model.update_trackers(i)

        # report losses every `print_iter_interval` iterations
        if i % params.print_iter_interval == 0:
            stats_dict = model.get_stats()
            # lr is shown in scientific notation: with '%.4f' a decaying lr
            # below ~5e-5 would print as 0.0000. The epoch is included so
            # iterations from different epochs can be told apart.
            print('epoch = %d \t iter = %d \t lr = %.2e \t loss_G= %.4f \t loss_D= %.4f' %
                  (epoch, i, stats_dict['lr'], stats_dict['loss_G'], stats_dict['loss_D']))

        if i in model.save_indices:
            model.save_images(i)

    model.save_logs(epoch)
    model.update_learning_rate()
