In [None]:
import torch

import utility
import data
import model
import loss
from trainer import Trainer
from torch.utils.data import DataLoader

In [None]:
class Args:
    """Attribute bag holding every option for this training run.

    Stands in for an ``argparse.Namespace``: the training cell reads options as
    plain attributes (``args.seed``, ``args.scale[0]``, ``args.patch_size``, ...)
    and passes the whole object on to ``utility.checkpoint``, ``model.Model``,
    ``loss.Loss`` and ``Trainer``.

    Keyword overrides let a cell change a few options without editing the
    defaults, e.g. ``Args(epochs=10, batch_size=8)``.
    """

    def __init__(self, **overrides):
        """Set the default options, then apply ``overrides``.

        Args:
            **overrides: option name -> value. Every name must match one of
                the defaults set below.

        Raises:
            TypeError: if an override names an option that does not exist.
                This catches typos such as ``bath_size=8``, which would
                otherwise be silently ignored downstream.
        """
        self.debug = False
        # NOTE(review): '.' appears to be used as an "unset" sentinel for several
        # string options (template, noise, pre_train, extend, load) — confirm
        # against the consuming modules.
        self.template = '.'

        # Hardware specifications
        self.n_threads = 3
        self.cpu = False
        self.n_GPUs = 1
        self.seed = 1                                # passed to torch.manual_seed

        # Data specifications
        self.dir_train = ["/kaggle/input/df2k-ost/train/DIV2K/DIV2K_train_HR"]
        self.dir_test = ["/kaggle/input/df2k-ost/test/DIV2K_valid"]
        self.benchmark_noise = False
        self.scale = [4]                             # scale[0] is used as the dataset upscale factor
        self.patch_size = 192                        # used as the dataset crop_size
        self.rgb_range = 255
        self.n_colors = 3
        self.noise = '.'
        self.chop = False

        # Model specifications
        self.model = 'RCAN'
        self.act = 'relu'
        self.pre_train = '.'
        self.extend = '.'
        self.n_resblocks = 20
        self.n_feats = 64
        self.res_scale = 1
        self.shift_mean = True
        self.precision = 'single'

        # Training specifications
        self.reset = False
        self.test_every = 1000
        self.epochs = 1000
        self.batch_size = 16                         # batch size of both DataLoaders
        self.split_batch = 1
        self.self_ensemble = False
        self.test_only = False                       # when True, no loss object is built
        self.gan_k = 1

        # Optimization specifications
        self.lr = 1e-4
        self.lr_decay = 200
        self.decay_type = 'step'
        self.gamma = 0.5
        self.optimizer = 'ADAM'
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.epsilon = 1e-8
        self.weight_decay = 0

        # Loss specifications
        self.loss = '1*L1'
        self.skip_threshold = 1e6

        # Log specifications
        self.save = 'test'                           # file name to save
        self.load = '.'                              # file name to load
        self.resume = 0                              # resume from specific checkpoint
        self.print_model = False
        self.save_models = False                     # save all intermediate models
        self.print_every = 100                       # how many batches to wait before logging training status
        self.save_results = False                    # save output results

        # Options for residual groups and feature-channel reduction
        self.n_resgroups = 10                        # number of residual groups
        self.reduction = 16                          # number of feature maps reduction

        # Options for test
        self.testpath = '../test/DIV2K_val_LR_our'   # dataset directory for testing
        self.testset = 'Set5'                        # dataset name for testing

        # Reject unknown option names before touching any attribute.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Args() got unexpected option(s): {', '.join(unknown)}")
        for name, value in overrides.items():
            setattr(self, name, value)

        # epochs == 0 is mapped to a very large value (1e8), i.e. effectively no
        # epoch limit. Checked after the overrides so Args(epochs=0) takes effect.
        if self.epochs == 0:
            self.epochs = 1e8


args = Args()

In [None]:
import torch

import utility
import data
import model
import loss
from trainer import Trainer
from torch.utils.data import DataLoader


# Training driver: seed torch, create the checkpoint, build datasets, loaders,
# model, loss and trainer, then alternate train/test passes until
# Trainer.terminate() returns True.
torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)

if checkpoint.ok:
    train_dataset = data.DS(args.dir_train, crop_size=args.patch_size, upscale_factor=args.scale[0])
    test_dataset = data.DS(args.dir_test, crop_size=args.patch_size, upscale_factor=args.scale[0])

    # Settings shared by both loaders. Pinned host memory only speeds up
    # host->GPU copies, so it is disabled on CPU-only runs.
    # NOTE(review): num_workers is fixed at 2 while args.n_threads is 3 —
    # confirm which value is intended.
    loader_kwargs = dict(num_workers=2, batch_size=args.batch_size, pin_memory=not args.cpu)
    loader_train = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
    # Evaluation gains nothing from shuffling. A fixed order keeps test logs
    # and any saved results comparable across epochs.
    loader_test = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

    # NOTE: these assignments rebind the imported `model` and `loss` module
    # names to instances. Re-running this cell works only because the cell's
    # own `import model` / `import loss` lines re-bind the modules first.
    model = model.Model(args, checkpoint)
    loss = loss.Loss(args, checkpoint) if not args.test_only else None
    t = Trainer(args, loader_train, loader_test, model, loss, checkpoint)
    while not t.terminate():
        t.train()
        t.test()

    checkpoint.done()