In [None]:
!pip install catalyst==20.10.1

In [None]:
!pip install dists-pytorch
# for colab need to change sys.prefix to '/usr/local'

In [None]:
import os
import sys
import random
from typing import Optional, Tuple, Callable

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import functools

import catalyst
from catalyst import dl, utils

from DISTS_pytorch import DISTS

from tqdm.notebook import tqdm

In [None]:
import numpy
import random

# Single seed shared by every RNG below so a run can be reproduced.
SEED = 123

numpy.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)       # seeds the default torch generator
torch.cuda.manual_seed(SEED)  # explicit seed for the current CUDA device
# cuDNN: use deterministic kernels AND disable autotuning; with benchmark enabled,
# cuDNN may pick different (non-deterministic) algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [1]:
from core import SRDataset, InferDataset, ResidualDenseBlock_5C, RRDB, RRDBNet
from utilities import make_layer, init_weights, make_submission

In [None]:
# from google.colab import drive
# drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [None]:
# Root of the cleaned training set. The default is the original Google Drive layout;
# set the SR_DATA_ROOT environment variable to run against a different location.
DATA_ROOT = os.environ.get(
    "SR_DATA_ROOT",
    "/gdrive/MyDrive/Deep_Learning/huawei_cup_final_2020/cleaned_train",
)
lr_dir = os.path.join(DATA_ROOT, "LR")  # path to cleaned train LR images
hr_dir = os.path.join(DATA_ROOT, "HR")  # path to cleaned train HR images

In [None]:
# Collect the file names of LR/HR image pairs (same name in both directories).
# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting makes
# the list identical on every machine so any later seeded shuffle is reproducible.
samples = []
for name in sorted(os.listdir(lr_dir)):
    if not name.endswith(".png"):
        continue  # ignore non-PNG files in the LR directory
    if not os.path.exists(os.path.join(hr_dir, name)):
        # every LR image must have its HR counterpart
        raise RuntimeError(f"File {name} does not exist in {hr_dir}")
    samples.append(name)

In [None]:
# Fail fast if the directory scan found nothing to train on.
assert samples, f"no .png image pairs found in {lr_dir}"
len(samples)

343

In [None]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [None]:
# Fraction of samples used for training. 1 puts every image in the training set
# and leaves the validation split empty.
TRAIN_VALID_SPLIT = 1
SPLIT_SEED = 123

# Shuffle a sorted *copy* with a dedicated RNG. The split then no longer depends on
# os.listdir order, on the global `random` state, or on how many times this cell ran
# (the previous in-place `random.shuffle(samples)` reshuffled on every re-run).
shuffled_samples = sorted(samples)
random.Random(SPLIT_SEED).shuffle(shuffled_samples)

train_valid_split = int(TRAIN_VALID_SPLIT * len(shuffled_samples))
train_samples = shuffled_samples[:train_valid_split]
valid_samples = shuffled_samples[train_valid_split:]
print(len(train_samples), len(valid_samples))

343 0


In [None]:
# Training set: crop_size=64 and length=8000 are forwarded to core.SRDataset
# (presumably patch size and items per epoch — TODO confirm in core.py).
train_dataset = SRDataset(
    hr_dir,
    lr_dir,
    train_samples,
    crop_size=64,
    length=8000,
)
# Validation set: SRDataset's own defaults for every optional argument.
valid_dataset = SRDataset(hr_dir, lr_dir, valid_samples)

In [None]:
# Training: shuffled batches of 12, with the trailing partial batch dropped.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=12,
    shuffle=True,
    drop_last=True,
)
# Validation: one image at a time, in a fixed order.
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

loaders = dict(train=train_dataloader, valid=valid_dataloader)
# NOTE(review): with TRAIN_VALID_SPLIT = 1, valid_samples is empty. For a train-only
# final run the author used `loaders = {"train": train_dataloader}` instead; confirm
# whether catalyst then also needs `valid_loader="train"` passed to runner.train.

In [None]:
# Batch/output key names used by catalyst to route data to the model and the loss.
# They must match the keys yielded by SRDataset (TODO(review): confirm in core.py).
runner = dl.SupervisedRunner(
    input_key="features",
    input_target_key="targets",
    output_key="logits",
)

In [None]:
# RRDB generator (nb=8, presumably the number of RRDB blocks), moved to the device
# before its weights are initialised in place.
model = RRDBNet(nb=8)
model = model.to(device)
init_weights(model)

In [None]:
# DISTS perceptual distance used as the training loss. The bound keyword arguments
# presumably keep the autograd graph (require_grad) and reduce over the batch
# (batch_average) so catalyst can backpropagate a scalar — confirm in DISTS_pytorch.
# NOTE(review): the install cell says sys.prefix must be '/usr/local' on Colab for this
# package; nothing visible sets it before DISTS() is constructed here.
criterion = functools.partial(
    DISTS().to(device),
    require_grad=True,
    batch_average=True,
)

In [None]:
# Adam with initial lr 8e-4. ExponentialLR multiplies the lr by gamma=0.5 on every
# scheduler step; catalyst steps it once per epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=8e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.5,
    last_epoch=-1,
    verbose=False,
)

In [None]:
# Ten-epoch training run driven by two explicit callbacks: one computes the DISTS loss,
# the other backpropagates it and steps the optimizer.
# NOTE(review): no logdir is passed, so (presumably) catalyst keeps no checkpoints during
# this multi-hour run; the weights are only persisted by the torch.save cell below.
runner.train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    loaders=loaders,
    num_epochs=10,
    verbose=True,
    timeit=False,
    callbacks={
        # Apply `criterion` to the model's "logits" and the batch "targets";
        # the result is logged under the "loss" prefix.
        "criterion_dists": dl.CriterionCallback(
            input_key="targets",
            output_key="logits",
            prefix="loss",
        ),
        # Optimise the "loss" metric every batch: no gradient accumulation, no clipping.
        "optimizer": dl.OptimizerCallback(
            metric_key="loss",
            accumulation_steps=1,
            grad_clip_params=None,
        ),
    },
)

1/10 * Epoch (train):   0% 0/727 [00:00<?, ?it/s]


The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. 



1/10 * Epoch (train): 100% 727/727 [49:28<00:00,  4.08s/it, loss=0.099]
1/10 * Epoch (valid): 100% 727/727 [15:52<00:00,  1.31s/it, loss=0.096]
[2020-12-18 02:39:14,054] 
1/10 * Epoch 1 (_base): lr=0.0002 | momentum=0.9000
1/10 * Epoch 1 (train): loss=0.1125
1/10 * Epoch 1 (valid): loss=0.1007
2/10 * Epoch (train):   0% 0/727 [00:00<?, ?it/s]


To get the last learning rate computed by the scheduler, please use `get_last_lr()`.



2/10 * Epoch (train): 100% 727/727 [43:20<00:00,  3.58s/it, loss=0.104]
2/10 * Epoch (valid): 100% 727/727 [16:06<00:00,  1.33s/it, loss=0.103]
[2020-12-18 03:38:41,063] 
2/10 * Epoch 2 (_base): lr=0.0001 | momentum=0.9000
2/10 * Epoch 2 (train): loss=0.0973
2/10 * Epoch 2 (valid): loss=0.0948
3/10 * Epoch (train): 100% 727/727 [43:40<00:00,  3.60s/it, loss=0.088]
3/10 * Epoch (valid): 100% 727/727 [16:11<00:00,  1.34s/it, loss=0.086]
[2020-12-18 04:38:33,408] 
3/10 * Epoch 3 (_base): lr=5.000e-05 | momentum=0.9000
3/10 * Epoch 3 (train): loss=0.0923
3/10 * Epoch 3 (valid): loss=0.0908
4/10 * Epoch (train): 100% 727/727 [43:42<00:00,  3.61s/it, loss=0.098]
4/10 * Epoch (valid): 100% 727/727 [16:14<00:00,  1.34s/it, loss=0.097]
[2020-12-18 05:38:30,579] 
4/10 * Epoch 4 (_base): lr=2.500e-05 | momentum=0.9000
4/10 * Epoch 4 (train): loss=0.0893
4/10 * Epoch 4 (valid): loss=0.0884
5/10 * Epoch (train): 100% 727/727 [43:41<00:00,  3.61s/it, loss=0.101]
5/10 * Epoch (valid): 100% 727/727 [1

In [None]:
# File the trained weights (state_dict) are written to, relative to the working directory.
MODEL_CHECKPOINT_PATH = "my_model.pth"

In [None]:
# Persist only the parameters/buffers (state_dict), not the pickled module object.
torch.save(
    obj=model.state_dict(),
    f=MODEL_CHECKPOINT_PATH,
)