In [1]:
# Dataset download: https://drive.google.com/file/d/1CPMBmRj-jBDO2ax4CxkBs9iczIFrs8VA/view

In [2]:
# Pretrained model download: https://drive.google.com/file/d/1YQl7DEbUzSDOBHtB8QOXDc3-6ZGOaUN5/view?usp=sharing

In [1]:
import os
from tqdm import tqdm
from torchvision.transforms.functional import to_pil_image
import torch
import json
import model.model as module_arch
from base.trainer import Trainer
from utils.logger import Logger
from utils.util import get_lr_scheduler
from base import data_loader as module_data
from model import loss as module_loss
from model import metric as module_metric
from pathlib import Path
from utils.util import denormalize
from base.data_loader import CustomDataLoader

# Path of the JSON file that holds the experiment configuration.
config_path = 'config.json'

# When torch reports a CUDA device, export CUDA_VISIBLE_DEVICES=0 into this
# process's environment.
cuda_present = torch.cuda.is_available()
if cuda_present:
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Parse the configuration file into `config`, which the following cells index
# by key ('trainer', 'data_loader', 'generator', ...).
with open(config_path) as config_file:
    config = json.load(config_file)



In [2]:
# Join the trainer's save directory with the experiment name.
# NOTE(review): `path` is not referenced again in the visible cells.
save_root = config['trainer']['save_dir']
path = os.path.join(save_root, config['name'])

train_logger = Logger()

# Look up the data-loader class named in the config on `module_data` and
# instantiate it with the configured keyword arguments.
loader_spec = config['data_loader']
data_loader_class = getattr(module_data, loader_spec['type'])
data_loader = data_loader_class(**loader_spec['args'])

# Validation loader obtained from the training loader's split_validation().
valid_data_loader = data_loader.split_validation()

In [3]:
# Resolve each network class by name on `module_arch` and build it from the
# keyword arguments stored in the config.
generator_spec = config['generator']
generator_class = getattr(module_arch, generator_spec['type'])
generator = generator_class(**generator_spec['args'])

discriminator_spec = config['discriminator']
discriminator_class = getattr(module_arch, discriminator_spec['type'])
discriminator = discriminator_class(**discriminator_spec['args'])

# Losses: config key -> attribute of `module_loss` with the configured name.
loss = {}
for loss_key, loss_name in config['loss'].items():
    loss[loss_key] = getattr(module_loss, loss_name)

# Metrics: attributes of `module_metric`, in the order listed in the config.
metrics = []
for metric_name in config['metrics']:
    metrics.append(getattr(module_metric, metric_name))

# Single-use iterators over only the parameters that have requires_grad set.
gen_train_params = (p for p in generator.parameters() if p.requires_grad)
dis_train_params = (p for p in discriminator.parameters() if p.requires_grad)

# One optimizer per network; both use the same class and keyword arguments.
optimizer_class = getattr(torch.optim, config['optimizer']['type'])
optimizer_args = config['optimizer']['args']
optimizer = {
    'generator': optimizer_class(gen_train_params, **optimizer_args),
    'discriminator': optimizer_class(dis_train_params, **optimizer_args),
}

# One learning-rate scheduler per optimizer, built from the same config entry.
scheduler_spec = config['lr_scheduler']
lr_scheduler = {
    'generator': get_lr_scheduler(scheduler_spec, optimizer['generator']),
    'discriminator': get_lr_scheduler(scheduler_spec,
                                      optimizer['discriminator']),
}

In [4]:
# Hand everything built in the previous cells to the trainer and start training.
# NOTE(review): `valid_data_loader` is created earlier but is not passed here;
# confirm against Trainer's signature whether validation is expected.
trainer = Trainer(
    config,
    generator,
    discriminator,
    loss,
    metrics,
    optimizer,
    lr_scheduler,
    data_loader,
    train_logger,
)
trainer.train()



KeyboardInterrupt: 

In [5]:
# Deblurring test code: run a saved generator over every image found in
# `blurred_path` and write the outputs to `save_path` as 'deblurred_<name>'.
blurred_path = 'test_img'
save_path = 'save'
model_path = 'checkpoint/G_latest.pth'

if torch.cuda.is_available():
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Choose the device before loading so the checkpoint can be mapped onto it.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved on a GPU fails to load on a CPU-only machine, even though this cell
# otherwise supports the CPU fallback.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(model_path, map_location=device)

# NOTE: this rebinds the notebook-level `config` and `data_loader` names to the
# checkpoint's config and to a loader over the test images.
config = checkpoint['config']

data_loader = CustomDataLoader(data_dir=blurred_path)

# Rebuild the generator architecture described in the checkpoint's config and
# restore its weights.
generator_class = getattr(module_arch, config['generator']['type'])
generator = generator_class(**config['generator']['args'])
generator.to(device)

generator.load_state_dict(checkpoint['generator'])
generator.eval()

Path(save_path).mkdir(exist_ok=True, parents=True)
with torch.no_grad():
    for sample in tqdm(data_loader, ascii=True):
        blurred = sample['blurred'].to(device)

        result = generator(blurred)
        # Walk the batch dimension instead of squeeze() + image_name[0], so
        # every image of a batch is saved under its own name; for a batch of
        # one this produces the same single file as before.
        # Assumes sample['image_name'] holds one name per batch item -- TODO confirm.
        deblurred_batch = denormalize(result).cpu()
        for deblurred, image_name in zip(deblurred_batch, sample['image_name']):
            to_pil_image(deblurred).save(
                os.path.join(save_path, 'deblurred_' + image_name))


100%|####################################################################################| 2/2 [00:00<00:00,  2.82it/s]
