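### Evaluate the denoising model on the validation set

Load a trained RedNet checkpoint, denoise every validation image, and report the mean per-image PSNR and MSE loss.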

In [32]:
# Automatically re-import edited project modules (e.g. the `src` / `models` packages
# imported below) so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:
import os
import sys
import torch
from torch.utils.data import Dataset
from tqdm.notebook import tqdm

In [34]:
# Project root = parent of the current working directory, added to sys.path so the
# project packages (`src`, `models`) can be imported.
# NOTE(review): assumes the kernel is started from a direct subdirectory of the
# project root (e.g. a notebooks/ folder) — confirm.
ROOT_DIR = os.path.dirname(os.getcwd())
# Guard the append so re-running this cell does not keep adding duplicate entries.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

In [35]:
from src.path import DATA_VALID, MODELS_DIR
from src.metrics import calculate_psnr
from models.rednet import RedNet

In [36]:
class ImgDataset(Dataset):
    """Paired (noisy, clean) image dataset read from one ``torch.load``-able file.

    The file must unpack into exactly two tensors, ``(noisy_imgs, clean_imgs)``,
    indexed by sample along their first axis. Every item is returned divided by 255.
    """

    def __init__(self, data_dir: str) -> None:
        # NOTE(review): despite its name, ``data_dir`` is handed directly to
        # ``torch.load``, so it is a file path rather than a directory.
        self.data_dir = data_dir
        noisy, clean = torch.load(data_dir)
        self.noisy_imgs = noisy
        self.clean_imgs = clean

    def __getitem__(self, index):
        """Return the ``(noisy, clean)`` pair stored at ``index``, scaled by 1/255."""
        noisy_item = self.noisy_imgs[index, ...]
        clean_item = self.clean_imgs[index, ...]
        return noisy_item / 255.0, clean_item / 255.0

    def __len__(self):
        """Number of samples, i.e. the leading dimension of ``noisy_imgs``."""
        num_items = self.noisy_imgs.shape[0]
        return num_items

# Validation data, one image per batch (the evaluation loop squeezes the batch
# dimension, which assumes batch_size=1).
# shuffle=False keeps the evaluation order fixed across runs; the averaged metrics do
# not depend on order (up to float rounding), so shuffling only costs reproducibility.
data = ImgDataset(DATA_VALID)
validDataLoader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)

In [37]:
# Network under evaluation and the loss reported alongside PSNR.
model = RedNet()
criterion = torch.nn.MSELoss()
# Prefer the GPU when torch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
model.to(device)

# Checkpoint to evaluate (the name is echoed again in the summary cell).
name = "model-20220420.pt"
# map_location=device remaps every stored tensor onto the selected device, so a single
# call covers both the CUDA and the CPU-only case (without it, a checkpoint saved from a
# GPU cannot be loaded on a machine without one). The loaded dict is not kept in a
# variable, so its copy of the weights can be freed once it is copied into the model.
model.load_state_dict(torch.load(os.path.join(MODELS_DIR, name), map_location=device))

print(device)

cpu


In [38]:
# Collect one MSE loss and one PSNR value per validation image (batch_size=1).
valid_loss = []
valid_pnsr = []  # PSNR values; the "pnsr" spelling is kept because the summary cell reads this name
# Evaluation only: torch.no_grad() stops autograd from recording a graph for every
# forward pass, which saves memory and time (model.eval() alone does not disable it).
with torch.no_grad(), tqdm(validDataLoader, desc='denoised validation set', unit='img',
                           leave=False) as t1:
    for x_noised, x_clean in t1:
        x_noised = x_noised.to(device)
        x_clean = x_clean.to(device)

        x_denoised = model(x_noised)
        loss = criterion(x_denoised, x_clean)
        # Drop the leading batch dimension (size 1 here) before computing PSNR.
        x_denoised = torch.squeeze(x_denoised, dim=0)
        x_clean = torch.squeeze(x_clean, dim=0)
        # NOTE(review): assumes calculate_psnr returns a 0-d tensor (it is .item()'d) — confirm.
        psnr = calculate_psnr(x_denoised, x_clean)

        valid_loss.append(loss.item())
        valid_pnsr.append(psnr.item())

denoised validation set:   0%|          | 0/1000 [00:00<?, ?img/s]

In [39]:
# Mean of the per-image metrics collected by the validation loop.
avg_loss = sum(valid_loss) / len(valid_loss)
avg_pnsr = sum(valid_pnsr) / len(valid_pnsr)

# One summary line per entry, printed in order.
print(
    f'for model {name}',
    f'Hit psnr = {avg_pnsr} dB',
    f'Hit loss = {avg_loss}',
    sep='\n',
)

for model model-20220420.pt
Hit psnr = 24.91799765396118 dB
Hit loss = 0.00394204310758505
