In [1]:
from config import args
from pytorch_lightning import Trainer, loggers
from models import SegmentationModel
from augmentations import get_transforms
from datasets import get_train_val_dataloaders
import warnings 
import cv2
import numpy as np
from pathlib import Path
warnings.filterwarnings(action= 'ignore')

In [2]:
transform_train=get_transforms('rcf')
transform_train_512=get_transforms('rcf512')
transform_val=get_transforms('center_c')
transform_val_512=get_transforms('center_c512')

In [5]:
import torch
import torch.nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl

from torch_enhance.datasets import BSDS300, Set14, Set5
# from torch_enhance.models import SRCNN, SRResNet, VDSR, EDSR
from torch_enhance.models import VDSR
from torch_enhance import metrics


class Module(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")
        
        # metrics
        mae = metrics.mae(sr, hr)
        psnr = metrics.psnr(sr, hr)

        # Logs
        self.log("train_loss", loss)
        self.log("train_mae", mae)
        self.log("train_psnr", psnr)

        return loss

    def validation_step(self, batch, batch_idx):
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")
        
        # metrics
        mae = metrics.mae(sr, hr)
        psnr = metrics.psnr(sr, hr)

        # Logs
        self.log("val_loss", loss)
        self.log("val_mae", mae)
        self.log("val_psnr", psnr)

        return loss

    def test_step(self, batch, batch_idx):
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")
        
        # metrics
        mae = metrics.mae(sr, hr)
        psnr = metrics.psnr(sr, hr)

        # Logs
        self.log("test_loss", loss)
        self.log("test_mae", mae)
        self.log("test_psnr", psnr)

        return loss

scale_factor = 2

# Setup dataloaders
train_dataset = BSDS300(scale_factor=scale_factor)
val_dataset = Set14(scale_factor=scale_factor)
test_dataset = Set5(scale_factor=scale_factor)
train_dataloader = DataLoader(train_dataset, batch_size=8)
val_dataloader = DataLoader(val_dataset, batch_size=1)
test_dataloader = DataLoader(test_dataset, batch_size=1)

# Define model
channels = 3 if train_dataset.color_space == "RGB" else 1
model = VDSR(scale_factor, channels)
# model = SRResNet(scale_factor, channels)
module = Module(model)

trainer = pl.Trainer(max_epochs=60, gpus=1)

trainer.fit(
    module,
    train_dataloader,
    val_dataloader,
) 

trainer.test(module, test_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | VDSR | 668 K 
-------------------------------
668 K     Trainable params
0         Non-trainable params
668 K     Total params
2.673     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss          0.0008328888798132539
        test_mae           0.016294006258249283
        test_psnr           32.053993225097656
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.0008328888798132539,
  'test_mae': 0.016294006258249283,
  'test_psnr': 32.053993225097656}]

In [6]:
module = module.cuda()

# Make directories for SR (large) images and save them

In [7]:
train_dataloader, val_dataloader = get_train_val_dataloaders(
    transform_train=transform_train, transform_val=transform_val,
    include_massachusetts=False,
    num_workers=1,
    batch_size=4,
)


In [8]:
test_images_dir = Path('./data/test/images/')
Path(str(test_images_dir).replace('images','images_800')).mkdir(exist_ok=True)
test_paths = list(test_images_dir.glob('*.png'))

In [9]:
# Path(str(test_images_dir).replace('images','images_800')).mkdir(exist_ok=True)
Path(str(train_dataloader.dataset.img_files[0].parent).replace('images','images_800')).mkdir(exist_ok=True)
Path(str(train_dataloader.dataset.mask_files[0].parent).replace('groundtruth','groundtruth_800')).mkdir(exist_ok=True)

In [10]:
for test_path in test_paths:
    img_np = cv2.imread(str(test_path))
    img_tensor=torch.Tensor(img_np[None,:].transpose(0,3,1,2)).to(module.device)
    img_tensor_big = module(img_tensor)
    fname_out = str(Path(str(test_path).replace('images','images_800')))
    cv2.imwrite(fname_out,img_tensor_big[0].detach().cpu().numpy().transpose(1,2,0))

In [None]:
for train_path in train_dataloader.dataset.img_files+val_dataloader.dataset.img_files:
    img_np = cv2.imread(str(train_path))
    img_tensor=torch.Tensor(img_np[None,:].transpose(0,3,1,2)).to(module.device)
    img_tensor_big = module(img_tensor)
    fname_out = str(Path(str(train_path).replace('images','images_800')))
    cv2.imwrite(fname_out,img_tensor_big[0].detach().cpu().numpy().transpose(1,2,0))

In [None]:
for train_mask_path in train_dataloader.dataset.mask_files+val_dataloader.dataset.mask_files:
    img_np = cv2.imread(str(train_mask_path))
    img_np_big=cv2.resize(img_np, (800,800), interpolation=cv2.INTER_AREA)
    fname_out = str(Path(str(train_mask_path).replace('groundtruth','groundtruth_800')))
    cv2.imwrite(fname_out,img_np_big)