# In [None]:
import argparse
import numpy as np
import pandas as pd
import os
import torch
# from tensorboardX import SummaryWriter
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm

import opt
from evaluation import evaluate
from loss import InpaintingLoss
from net import PConvUNet
from net import VGG16FeatureExtractor
# from places2 import Places2
from celeb import CelebA
from util.io import load_ckpt
from util.io import save_ckpt


class InfiniteSampler(data.sampler.Sampler):
    """Sampler that yields dataset indices forever, reshuffling every pass.

    Each pass over ``range(num_samples)`` is a fresh random permutation, so
    every index appears exactly once per pass. Before each permutation after
    the first, the global NumPy RNG is reseeded from fresh entropy
    (``np.random.seed()`` with no argument). Later passes are therefore not
    reproducible from an earlier ``np.random.seed(...)`` call.

    Args:
        num_samples: number of indices to cycle over (the dataset length).
    """

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        return self.loop()

    def __len__(self):
        # Nominal length standing in for "infinite".
        return 1 << 31

    def loop(self):
        """Generate indices endlessly, one shuffled pass after another."""
        perm = np.random.permutation(self.num_samples)
        cursor = 0
        while True:
            yield perm[cursor]
            cursor += 1
            if cursor < self.num_samples:
                continue
            # Pass exhausted: reseed from entropy and start a new shuffle.
            np.random.seed()
            perm = np.random.permutation(self.num_samples)
            cursor = 0

class Argument:
    """Plain attribute bag holding the hard-coded training settings.

    Args:
        ckpt_name: optional checkpoint path. When truthy, ``resume`` is set to
            it and ``finetune`` is True; otherwise ``resume`` is None and
            ``finetune`` is False.
    """

    def __init__(self, ckpt_name=None):
        # Data and output locations.
        self.root = './data512x512'
        self.mask_root = './irregular_mask/disocclusion_img_mask/'
        self.save_dir = './snapshots/default'
        self.log_dir = './logs/default'
        # Optimisation.
        self.lr = 0.0002
        self.lr_finetune = 0.00005
        self.max_iter = 9999999
        self.batch_size = 6
        self.n_threads = 4
        # Intervals (in iterations) for checkpoints, visualisation and logging.
        self.save_model_interval = 1000
        self.vis_interval = 100
        self.log_interval = 50
        self.image_size = 512
        # A checkpoint name switches on fine-tuning; both flags derive from it.
        self.finetune = bool(ckpt_name)
        self.resume = ckpt_name if ckpt_name else None
            
# Resume (and, per Argument, fine-tune) from this checkpoint; construct
# Argument() with no argument to train from scratch instead.
args = Argument('./snapshots/default/ckpt/74000.pth')

torch.backends.cudnn.benchmark = True
device = torch.device('cuda')

# Create each output directory independently with exist_ok=True. The previous
# "if not os.path.exists(save_dir)" guard skipped images/ and ckpt/ whenever
# save_dir already existed, and checkpoint/visualisation writes into those
# subfolders then failed.
os.makedirs(os.path.join(args.save_dir, 'images'), exist_ok=True)
os.makedirs(os.path.join(args.save_dir, 'ckpt'), exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)

# Loss history, one row per logged iteration; the training loop rewrites
# these to loss.csv / loss_val.csv in log_dir.
loss_df = pd.DataFrame()
loss_df_val = pd.DataFrame()

# Images: resize to the square training size, convert to a tensor, then
# normalise with opt.MEAN / opt.STD. Masks: random square crop, then convert
# to a tensor (no normalisation).
size = (args.image_size, args.image_size)
img_tf = transforms.Compose([
    transforms.Resize(size=size),
    transforms.ToTensor(),
    transforms.Normalize(mean=opt.MEAN, std=opt.STD),
])
mask_tf = transforms.Compose([
    transforms.RandomCrop(size=size),
    transforms.ToTensor(),
])

dataset_train = CelebA(args.root, args.mask_root, img_tf, mask_tf, 'train')
dataset_val = CelebA(args.root, args.mask_root, img_tf, mask_tf, 'valid')


def _endless_batches(dataset):
    """Return an iterator over batches of ``dataset`` that never runs out.

    Indices come from an InfiniteSampler, so ``next()`` on the returned
    iterator can be called indefinitely.
    """
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=InfiniteSampler(len(dataset)),
        num_workers=args.n_threads,
    )
    return iter(loader)


iterator_train = _endless_batches(dataset_train)
iterator_val = _endless_batches(dataset_val)
print(len(dataset_train))
model = PConvUNet().to(device)

# Fine-tuning uses the smaller learning rate and sets the model's
# freeze_enc_bn flag (its effect is defined in PConvUNet).
lr = args.lr_finetune if args.finetune else args.lr
if args.finetune:
    model.freeze_enc_bn = True

start_iter = 0
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=lr)
criterion = InpaintingLoss(VGG16FeatureExtractor()).to(device)

if args.resume:
    start_iter = load_ckpt(
        args.resume, [('model', model)], [('optimizer', optimizer)])
    # The restored optimizer state carries its own lr; force the one chosen above.
    for group in optimizer.param_groups:
        group['lr'] = lr
    print('Starting from iter ', start_iter)

def _weighted_losses(loss_dict):
    """Scale each loss term named in ``opt.LAMBDA_DICT`` by its coefficient.

    Returns ``{key: coef * loss_dict[key]}`` in ``opt.LAMBDA_DICT`` order.
    """
    return {key: coef * loss_dict[key] for key, coef in opt.LAMBDA_DICT.items()}


def _log_losses(frame, weighted, step, csv_name):
    """Append one row of scalar losses indexed by ``step`` and rewrite the CSV.

    Returns the extended frame (``pd.concat`` does not modify ``frame`` in place).
    """
    row = pd.DataFrame({key: [value.item()] for key, value in weighted.items()},
                       index=[step])
    frame = pd.concat([frame, row])
    frame.to_csv(os.path.join(args.log_dir, csv_name))
    return frame


for i in tqdm(range(start_iter, args.max_iter)):
    model.train()

    image, mask, gt = [x.to(device) for x in next(iterator_train)]
    output, _ = model(image, mask)
    weighted = _weighted_losses(criterion(image, mask, output, gt))
    loss = sum(weighted.values())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i + 1) % args.log_interval == 0:
        # Log the weighted training terms already computed above. The old code
        # re-added them (and later the validation terms) to the already
        # backpropagated `loss`, which only extended its autograd graph.
        loss_df = _log_losses(loss_df, weighted, i + 1, 'loss.csv')

        # Validation on one batch. eval() keeps train/eval-dependent layers
        # (e.g. BatchNorm running stats, which freeze_enc_bn suggests exist)
        # from being updated by validation data. The loss is also computed
        # under no_grad because nothing here is backpropagated. The next
        # iteration's model.train() restores training mode.
        model.eval()
        image, mask, gt = [x.to(device) for x in next(iterator_val)]
        with torch.no_grad():
            output, _ = model(image, mask)
            weighted_val = _weighted_losses(criterion(image, mask, output, gt))
        for key, value in weighted_val.items():
            print('loss_val{:s}'.format(key), value.item(), i + 1)
        loss_df_val = _log_losses(loss_df_val, weighted_val, i + 1,
                                  'loss_val.csv')

    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
        save_ckpt('{:s}/ckpt/{:d}.pth'.format(args.save_dir, i + 1),
                  [('model', model)], [('optimizer', optimizer)], i + 1)

    if (i + 1) % args.vis_interval == 0:
        model.eval()
        evaluate(model, dataset_val, device,
                 '{:s}/images/test_{:d}.jpg'.format(args.save_dir, i + 1))

# writer.close()