In [None]:
import argparse
import importlib
import torch
import torch.nn as nn
from defocus.model import GAN

In [None]:
# Experiment configuration for GAN training; the next cell parses an explicit argv list with it.
# NOTE(review): the path defaults below are absolute, machine-specific locations — override them
# on the command line when running anywhere else.
parser = argparse.ArgumentParser(description='It is time for more... experiments.')
parser.add_argument('--model_name', type=str, default='MSResNet', help='model name')
parser.add_argument('--batch_size', type=int, default=8, help='input batch size for training')
parser.add_argument('--num_gpu', type=int, default=2, help='number of gpus to use')
parser.add_argument('--num_workers', type=int, default=8, help='the number of dataloader workers')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999), help='ADAM betas')
# Loss options are parsed as flat lists of strings alternating name and weight,
# e.g. ['MSELoss', '0.5']; the consumer is responsible for pairing/converting them.
parser.add_argument('--adv_loss', nargs='+', default=('BCEWithLogitsLoss', '1.0'), action='store',
                    help='Adversarial loss function(s) and weighting, e.g. BCEWithLogitsLoss 0.5 MSELoss 0.5')
parser.add_argument('--rec_loss', nargs='+', default=('MSELoss', '1.0'), action='store',
                    help='Reconstruction loss function(s) and weighting, e.g. L1Loss 0.5 MSELoss 0.5')
parser.add_argument('--per_loss', nargs='+', action='store',
                    help='Perceptual loss function(s) and weighting, e.g. L1Loss 0.5 MSELoss 0.5')
parser.add_argument('--fp16', action='store_true',
                    help='Mixed precision')
parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer to use. currently Adam or AdamP')
parser.add_argument('--milestones', type=int, nargs='+', default=[500, 750, 900], help='learning rate decay per N epochs')
parser.add_argument('--root_folder', type=str, default='/storage/ekonuk/projects/all_datasets/GOPRO/train/', help='root folder')
parser.add_argument('--image_pair_list', type=str, default='/storage/ekonuk/projects/all_datasets/GOPRO/train/train_image_pair_list.txt', help='image list')
parser.add_argument('--val_image_pair_list', type=str, default='/storage/ekonuk/projects/all_datasets/GOPRO/train/val_image_pair_list.txt', help='val image list')
parser.add_argument('--stop_loss', type=int, default=None, help='the epoch to start cooperative adversarial training')
# The flooding level b is a loss threshold and is normally fractional (e.g. 0.05), so it must
# parse as float — with type=int such values were rejected. The default 0 presumably disables flooding.
parser.add_argument('--flood_loss', type=float, default=0.0, help='flood loss b threshold')
parser.add_argument('--val_metric', type=str, nargs='+', default=['SSIM', 'PSNR'], help='validation evaluation metrics')
# store_false: uploading is on by default (True); passing --upload turns it off.
parser.add_argument('--upload', action='store_false', help='if selected, will not upload')

In [None]:
# Override a handful of defaults for a quick single-GPU DeblurGANv2 run,
# then build the GAN from the resulting namespace.
argv = [
    '--batch_size', '2',
    '--num_gpu', '1',
    '--model_name', 'DeblurGANv2',
    '--per_loss', 'MSELoss', '0.006',
    '--rec_loss', 'MSELoss', '0.5',
]
args = parser.parse_args(args=argv)
gan_model = GAN(args)

In [None]:
# Inspect the parsed perceptual-loss spec: nargs='+' with no `type` keeps it as a flat list of strings.
args.per_loss

['MSELoss', '0.006']

In [None]:
# Minimal CLI for the trainer cells below.
# NOTE(review): this rebinds `parser` and `args`, discarding the GAN configuration parsed above;
# the trainer cell reads args.batch_size / args.num_workers from this namespace.
argv = ['--batch_size', '4', '--num_workers', '4']

parser = argparse.ArgumentParser(description='It is time for more... experiments.')
parser.add_argument('--batch_size', type=int, default=8, help='input batch size for training')
parser.add_argument('--num_workers', type=int, default=6, help='the number of dataloader workers')
# Help text was the placeholder 'blurb'.
parser.add_argument('--distributed', action='store_true', help='enable distributed training')
parser.add_argument('--world_size', default=1, type=int,
                    help='number of nodes for distributed training')

args = parser.parse_args(argv)

In [None]:
# Select the experiment by name: both the architecture and the trainer module
# are looked up under sub-packages of `defocus` using that same name.
model_name = 'MSResNet'
architecture, training = (
    importlib.import_module(f'defocus.{subpackage}.{model_name}')
    for subpackage in ('architecture', 'trainers')
)

In [None]:
# Instantiate both networks from the selected architecture module (generator first).
G, D = architecture.Generator(), architecture.Discriminator()

In [None]:
# Wire the networks and their optimizers into the model wrapper.
# NOTE(review): `Model` was never imported in this notebook (the first cell imports only `GAN`),
# so this cell raised NameError on a fresh kernel. Importing it from the same package is an
# assumption — TODO confirm that `defocus.model` exposes `Model`.
from defocus.model import Model

model = Model()
model.G = G
model.D = D
# Left disabled in the original run:
# model.use_perceptual()
model.set_G_optimizer('AdamP')
model.set_D_optimizer('AdamP')

In [None]:
# One loss per objective, each with unit weight.
# NOTE(review): the method name spelling must match whatever the model class defines;
# don't "correct" it here without changing the class too.
model.set_resconstruction_loss(
    loss_functions=[nn.L1Loss()],
    weights=[1.0],
)
model.set_adversarial_loss(
    loss_functions=[nn.BCEWithLogitsLoss()],
    weights=[1.0],
)

In [None]:
data = importlib.import_module(f'defocus.data.{model_name}')

# Both splits share one root folder; only the image-pair list differs.
# NOTE(review): this root differs from the argparse defaults in the first cell
# ('/storage/ekonuk/projects/...') — confirm which location is current.
gopro_root = '/storage/projects/all_datasets/GOPRO/train/'
train_dataset = data.Dataset(root_folder=gopro_root,
                             image_pair_list=gopro_root + 'train_image_pair_list.txt')
validation_dataset = data.Dataset(root_folder=gopro_root,
                                  image_pair_list=gopro_root + 'val_image_pair_list.txt')

In [None]:
# Build the trainer for the selected model from both datasets and the loader settings in `args`.
trainer = training.Trainer(
    model,
    train_dataset,
    validation_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)

In [None]:
# Start training with epoch=0 (presumably one epoch pass — confirm against Trainer.train).
trainer.train(epoch=0)

In [None]:
import torch
import cv2
import albumentations
import numpy as np

# Random horizontal flip. `additional_targets` registers `target_image` as an extra
# image target, so albumentations applies the same sampled flip to it as to `image`.
transform = albumentations.Compose(
    [albumentations.HorizontalFlip()],
    additional_targets={'target_image': 'image'},
)
class Dataset(torch.utils.data.Dataset):
    """Synthetic paired dataset for sanity-checking paired augmentations.

    A single random image is drawn once and used as both input and target for
    every item. After a paired transform, the two returned arrays should therefore
    still be identical.

    Args:
        transform: callable invoked as ``transform(image=..., target_image=...)``
            that returns a dict with ``'image'`` and ``'target_image'`` keys
            (e.g. an ``albumentations.Compose`` with
            ``additional_targets={'target_image': 'image'}``). If ``None``,
            items are returned untransformed (as copies of the stored image).
        length: number of items reported by ``__len__``.
        image_shape: shape of the random image drawn with ``np.random.rand``.
    """

    def __init__(self, transform=None, length=100, image_shape=(224, 224, 3)):
        super().__init__()
        self.transform = transform
        self.length = length
        # Drawn once from NumPy's global RNG (floats in [0, 1)); seed it for reproducibility.
        self.input_image = np.random.rand(*image_shape)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(input_image, target_image)``; every index yields the same source image.

        Raises:
            IndexError: if ``idx`` is outside ``[-length, length)``. Without this,
                iterating the dataset via the sequence protocol never terminated.
        """
        if not -self.length <= idx < self.length:
            raise IndexError(f'index {idx} out of range for dataset of length {self.length}')
        if self.transform is None:
            # Previously crashed here: None was called unconditionally.
            # Copies prevent callers from mutating the shared stored image.
            return self.input_image.copy(), self.input_image.copy()
        augmented = self.transform(image=self.input_image,
                                   target_image=self.input_image)
        return augmented['image'], augmented['target_image']

In [None]:
# Synthetic dataset wired to the paired flip transform defined above.
dataset = Dataset(transform=transform)

In [None]:
# Sanity check: input and target come from the same image, so a paired transform
# must leave them identical for every item.
# np.alltrue was deprecated in NumPy 1.25 and removed in 2.0; np.array_equal also checks shapes.
for idx in range(len(dataset)):
    input_image, target_image = dataset[idx]
    assert np.array_equal(input_image, target_image), f'input/target mismatch at index {idx}'