In [26]:
import os
import os.path as osp

import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from tqdm import tqdm

from model import Dehazer
from data import reshape_source, get_transform, collect_paths

In [27]:
# Automatically reload edited project modules (e.g. `model`, `data`) before each cell executes,
# so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
def extract_dmaps(src_dir, paths, dehazer=None, input_transform=None, device=None):
    """Run the model on every image in ``paths`` and save each prediction as an 8-bit image.

    Outputs go to a ``dmap_pred`` directory created next to ``src_dir``
    (e.g. ``datasets/dh/NYU/hazy`` -> ``datasets/dh/NYU/dmap_pred``). Each output keeps
    the input filename with every ``"hazy"`` replaced by ``"dmap"``. Each prediction is
    scaled by its own maximum to [0, 255] and truncated to ``uint8``.

    Args:
        src_dir: Directory of the hazy inputs; only used to locate the output directory.
        paths: Iterable of image file paths to process.
        dehazer: Module to run. Defaults to the notebook-global ``model``.
        input_transform: Callable mapping a PIL image to a tensor (a batch dimension is
            added here). Defaults to the notebook-global ``transform``.
        device: ``torch.device`` (or device string) to run on. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        None. Predictions are written to disk; the destination directory is printed.
    """
    # Fall back to the notebook globals only when no explicit object is passed.
    net = model if dehazer is None else dehazer
    tfm = transform if input_transform is None else input_transform
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # normpath strips a trailing separator, so ".../hazy/" still maps to the sibling ".../dmap_pred".
    dst_dir = osp.join(osp.dirname(osp.normpath(src_dir)), "dmap_pred")
    os.makedirs(dst_dir, exist_ok=True)

    print(f"Destination directory of dmaps: {dst_dir}")

    net.to(device)
    was_training = net.training
    # Inference: disable dropout / use running batch-norm statistics; restored in `finally`.
    net.eval()
    try:
        # no_grad avoids building an autograd graph for every image.
        with torch.no_grad():
            for p in tqdm(paths):
                fname_out = osp.basename(p).replace("hazy", "dmap")

                # Context manager closes the file handle; convert("RGB") guarantees the
                # 3 channels the model expects (e.g. for RGBA / palette inputs).
                with Image.open(p) as img:
                    hazy_img = tfm(img.convert("RGB")).unsqueeze(0)

                output = net(hazy_img.to(device)).cpu().numpy().squeeze()

                # Per-image max normalization; guard against division by zero and
                # clip so negative predictions do not wrap around in the uint8 cast.
                peak = output.max()
                if peak > 0:
                    output = output / peak
                output = (output * 255).clip(0, 255).astype("uint8")

                Image.fromarray(output).save(osp.join(dst_dir, fname_out))
    finally:
        net.train(was_training)

In [30]:
# Load the trained weights into a fresh Dehazer.
# map_location="cpu" lets a checkpoint saved from a GPU run load on any machine;
# extract_dmaps moves the model to the inference device itself.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
ckpt_path = "logs/dh/NYU/lr0.0002_epochs30/weights/030.pth"
model = Dehazer(in_channels=3, out_channels=1)
# Inference only: eval mode disables dropout / uses running batch-norm statistics.
# Set before loading so the cell still displays the load_state_dict report.
model.eval()
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [38]:
# Select the dataset, collect the hazy input images and build the matching input transform.

dataset = "dh/NYU"
src_dir = f"datasets/{dataset}/hazy"
# Optional list file restricting which images are processed; None presumably means
# "no restriction" -- confirm against data.collect_paths.
src_list_file = None
# src_list_file = "dhazy/NYU_split/test_Hazy.txt"  # uncomment to use a split list (presumably the test split)

paths = collect_paths(src_dir, src_list_file)
transform = get_transform(dataset)

In [40]:
# Show which input transform the dataset uses (header line, then the transform's repr).
print("Input transformation being used:", transform, sep="\n")

Input transformation being used:
Compose(
    ToTensor()
)


In [None]:
extract_dmaps(src_dir=src_dir, paths=paths)