In [2]:
import argparse
import numpy as np
import os
import torch
from tensorboardX import SummaryWriter
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm

import opt
from evaluation import evaluate
from loss import InpaintingLoss
from net import PConvUNet
from net import VGG16FeatureExtractor
from my_dataset import MyDataset
from util.io import load_ckpt
from util.io import save_ckpt


In [2]:
# Quick check that PyTorch can see a CUDA-capable GPU in this environment.
torch.cuda.is_available()

True

In [4]:
class InfiniteSampler(data.sampler.Sampler):
    """Sampler that yields dataset indices forever, one shuffled pass at a time.

    Every pass is a fresh random permutation of ``range(num_samples)``, so each
    index is produced exactly once before any index repeats.

    Args:
        num_samples: Number of items in the dataset being sampled. Must be
            positive.

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """

    def __init__(self, num_samples):
        # Fail early with a clear message: with zero samples the generator
        # below would otherwise die with an IndexError on its first draw.
        if num_samples <= 0:
            raise ValueError(
                'num_samples must be positive, got {}'.format(num_samples))
        self.num_samples = num_samples

    def __iter__(self):
        return iter(self.loop())

    def __len__(self):
        # The index stream never ends; this is only a large nominal length.
        return 2 ** 31

    def loop(self):
        """Generate indices endlessly, reshuffling after each full pass."""
        i = 0
        order = np.random.permutation(self.num_samples)
        while True:
            yield order[i]
            i += 1
            if i >= self.num_samples:
                # NOTE(review): this re-seeds NumPy's *global* RNG with fresh
                # entropy, which discards any seed set elsewhere. Kept to
                # preserve existing behaviour -- confirm it is intentional.
                np.random.seed()
                order = np.random.permutation(self.num_samples)
                i = 0



# training options
root = './my_train'          # image root handed to MyDataset
mask_root = './masks'        # mask root handed to MyDataset
save_dir = './snapshots/default/'
log_dir = './logs/default'
lr = 2e-4                    # learning rate used when not fine-tuning
lr_finetune = 5e-5           # learning rate used when fine-tuning
max_iter = 2500              # iterations to run (counted from the resumed iteration, if any)
batch_size = 2
n_threads = 4                # DataLoader worker processes
save_model_interval = 100
vis_interval = 500
log_interval = 100
image_size = 256
# Plain boolean switch (set to False for regular training). The previous
# value was the argparse action string 'store_true', which only behaved as
# "on" because non-empty strings are truthy.
finetune = True

# Checkpoint to continue from; a falsy value means "start from scratch".
resume = './snapshots/default/ckpt/1007800.pth'

torch.backends.cudnn.benchmark = True
device = torch.device('cuda')

# exist_ok=True makes this safe to re-run and also repairs a save_dir that
# already exists but is missing one of its sub-folders.
os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)
os.makedirs(os.path.join(save_dir, 'ckpt'), exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

writer = SummaryWriter(log_dir=log_dir)

size = (image_size, image_size)
# Images: resize -> tensor -> normalise with the statistics defined in opt.
img_tf = transforms.Compose(
    [transforms.Resize(size=size), transforms.ToTensor(),
     transforms.Normalize(mean=opt.MEAN, std=opt.STD)])
# Masks: resize -> tensor only (no normalisation).
mask_tf = transforms.Compose(
    [transforms.Resize(size=size), transforms.ToTensor()])

dataset_train = MyDataset(root, mask_root, img_tf, mask_tf, 'train')
dataset_val = MyDataset(root, mask_root, img_tf, mask_tf, 'val')

# InfiniteSampler already shuffles and never terminates, so `shuffle` must
# not be passed as well: DataLoader rejects `sampler` combined with
# `shuffle=True` with a ValueError.
iterator_train = iter(data.DataLoader(
    dataset_train, batch_size=batch_size,
    sampler=InfiniteSampler(len(dataset_train)),
    num_workers=n_threads))

print(len(dataset_train))

3889


In [6]:
model = PConvUNet().to(device)

if finetune:
    # Fine-tuning uses the smaller learning rate.
    lr = lr_finetune
    # NOTE(review): presumably read by PConvUNet to keep the encoder's
    # batch-norm layers frozen -- confirm in net.py.
    model.freeze_enc_bn = True
    print('fine_tuning: ', lr)

start_iter = 0
# Only parameters that still require gradients are handed to the optimizer.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
criterion = InpaintingLoss(VGG16FeatureExtractor()).to(device)

# Last iteration (exclusive) of this run. Kept in its own variable so the
# configured `max_iter` is never modified and re-running this cell does not
# keep growing the iteration budget.
end_iter = max_iter
if resume:
    # load_ckpt returns the iteration to continue from.
    start_iter = load_ckpt(
        resume, [('model', model)], [('optimizer', optimizer)])
    # Force the learning rate chosen above onto every parameter group
    # (presumably the restored optimizer state carries the checkpoint's
    # rate -- confirm in util.io).
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    end_iter = start_iter + max_iter
    print('Starting from iter ', start_iter)

try:
    for i in tqdm(range(start_iter, end_iter)):
        # Switch back to training mode; the periodic evaluation below puts
        # the model in eval mode.
        model.train()

        image, mask, gt = [x.to(device) for x in next(iterator_train)]
        output, _ = model(image, mask)
        loss_dict = criterion(image, mask, output, gt)

        # Total loss is the weighted sum of the terms listed in
        # opt.LAMBDA_DICT; each weighted term is also logged on its own.
        loss = 0.0
        for key, coef in opt.LAMBDA_DICT.items():
            value = coef * loss_dict[key]
            loss += value
            if (i + 1) % log_interval == 0:
                writer.add_scalar('loss_{:s}'.format(key), value.item(), i + 1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Checkpoint periodically and always on the final iteration.
        if (i + 1) % save_model_interval == 0 or (i + 1) == end_iter:
            save_ckpt('{:s}/ckpt/{:d}.pth'.format(save_dir, i + 1),
                      [('model', model)], [('optimizer', optimizer)], i + 1)

        if (i + 1) % vis_interval == 0:
            model.eval()
            evaluate(model, dataset_val, device,
                     '{:s}/images/test_{:d}.jpg'.format(save_dir, i + 1))
finally:
    # Close the writer even when training raises or is interrupted, so the
    # scalars logged so far are not left unflushed.
    writer.close()