In [21]:
import os

import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from tqdm import tqdm

In [22]:
class Dehazer(nn.Module):
    """Plain encoder-decoder CNN without skip connections.

    The encoder applies six 3x3 conv + BatchNorm + ReLU stages interleaved
    with three 2x2 max-pools (spatial size divided by 8, channels grown to
    512).  The decoder mirrors it with six conv stages and three bilinear
    x2 upsamplings.  For inputs whose height and width are multiples of 8
    the output has the same spatial size as the input.  The last stage is
    also BatchNorm + ReLU, so outputs are non-negative.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the produced tensor.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        def conv_stage(c_in, c_out):
            # 3x3 conv with stride 1 / padding 1 preserves H x W.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        def downsample():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        def upsample():
            return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Layers are created in the same order and land on the same
        # Sequential indices as before, so saved state_dict keys still match.
        encoder_layers = [
            *conv_stage(in_channels, 64),
            *conv_stage(64, 64),
            downsample(),
            *conv_stage(64, 128),
            *conv_stage(128, 128),
            downsample(),
            *conv_stage(128, 256),
            *conv_stage(256, 512),
            downsample(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            *conv_stage(512, 256),
            upsample(),
            *conv_stage(256, 128),
            *conv_stage(128, 128),
            upsample(),
            *conv_stage(128, 64),
            *conv_stage(64, 64),
            upsample(),
            *conv_stage(64, out_channels),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run the encoder followed by the decoder on a batch ``x``."""
        return self.decoder(self.encoder(x))

In [23]:
# load the checkpoint
CKPT_PATH = "weights/lr_0.0002_dhaze_epochs50_test_run/checkpoint_epoch_50.pth"
# map_location="cpu" makes the load succeed even on a machine without the
# device the tensors were saved from; the model is moved to its target
# device later, after the weights have been copied in.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")

In [24]:
# The network is built with a single output channel to match the checkpoint.
model = Dehazer(in_channels=3, out_channels=1)
# This notebook only runs inference: put BatchNorm into evaluation mode so it
# uses the running statistics stored in the checkpoint instead of per-image
# batch statistics (and stops updating them on every forward pass).
model.eval()
# Kept as the last expression so the cell still shows the key-matching report.
model.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [25]:
def collect_files(src_dir, src_file_list):
    """Return the paths of the input files to process.

    Args:
        src_dir: Directory that is scanned when ``src_file_list`` is None.
        src_file_list: Path to a text file holding one path per line, or
            None to take every regular file found directly in ``src_dir``.

    Returns:
        A list of path strings.  Directory scans are returned in sorted
        order; list-file entries keep the order of the file, stripped of
        surrounding whitespace.
    """
    if src_file_list is None:
        # os.listdir order is arbitrary, so sort for reproducible runs, and
        # skip sub-directories (e.g. .ipynb_checkpoints) that cannot be
        # opened as images.
        candidates = (
            os.path.join(src_dir, name) for name in sorted(os.listdir(src_dir))
        )
        return [path for path in candidates if os.path.isfile(path)]

    with open(src_file_list, "r") as listing:
        stripped = (line.strip() for line in listing)
        # Blank lines (such as a trailing newline) would otherwise become
        # empty paths.
        return [entry for entry in stripped if entry]

In [26]:
def reshape_source(src_dir, dst_dir, factor=2):
    """Write a downscaled copy of every image in ``src_dir`` to ``dst_dir``.

    Args:
        src_dir: Directory containing the source images.  Entries that are
            not regular files are skipped.
        dst_dir: Directory receiving the resized images under the same file
            names.  It is created if it does not exist.
        factor: Integer divisor applied to both width and height.
    """
    # Make the function usable on its own instead of relying on the caller
    # having created the destination beforehand.
    os.makedirs(dst_dir, exist_ok=True)

    for filename in tqdm(os.listdir(src_dir)):

        src_path = os.path.join(src_dir, filename)
        if not os.path.isfile(src_path):
            continue

        # The context manager closes the underlying file handle promptly;
        # without it, long directory listings can exhaust open descriptors.
        with Image.open(src_path) as image:
            width, height = image.size

            # Integer division can reach 0 for small images, which resize()
            # rejects, so clamp each side to at least one pixel.
            new_width = max(1, width // factor)
            new_height = max(1, height // factor)

            resized_image = image.resize((new_width, new_height))

        # Save the resized image to the destination directory
        dst_path = os.path.join(dst_dir, filename)
        resized_image.save(dst_path)

In [27]:
# Directory holding the hazy input images.
src_dir = "dhazy/NYU_Hazy"
# Optional text file listing the inputs, one path per line.
# collect_files treats None as "list every file in src_dir".
src_file_list = None
# src_file_list = "dhazy/NYU_split/test_Hazy.txt"

# Directory the predicted maps are written to; created up front so that
# saving cannot fail on a missing folder.
dst_dir = "dhazy/NYU_predictedDepthMap"
os.makedirs(dst_dir, exist_ok=True)

In [17]:
# Destination folder for the downscaled copies of the source images.
# NOTE(review): the folder is named for Middlebury while src_dir points at
# "dhazy/NYU_Hazy" - presumably left over from an earlier run; confirm which
# dataset is intended before re-running.
reshaped_src = "dhazy/Middlebury_Hazy_reshaped"
os.makedirs(reshaped_src, exist_ok=True)

In [22]:
# Downscale every image in src_dir by a factor of 4 into reshaped_src.
# NOTE(review): the recorded output (23 files) and the out-of-sequence
# execution count suggest this cell last ran with a different src_dir than
# the one configured now - re-run from a fresh kernel to confirm.
reshape_source(src_dir, reshaped_src, factor=4)

100%|██████████| 23/23 [00:01<00:00, 12.39it/s]


In [28]:
# Per-channel normalisation statistics (these match the commonly used
# ImageNet mean / std values).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing for each input image: PIL image -> float tensor -> per-channel
# normalisation.  The resize / centre-crop steps are currently commented out,
# so images are fed at their original resolution.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.Resize(400),
        # transforms.CenterCrop(400),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

In [29]:
# Honour the configured list file instead of a hard-coded None, so switching
# src_file_list in the configuration cell actually changes the inputs.
files = collect_files(src_dir, src_file_list)

In [30]:
# get predictions

# Fall back to the CPU when no CUDA device is present instead of failing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference only: BatchNorm must use its stored running statistics rather
# than the statistics of each single-image batch.
model.eval()

for f in tqdm(files):

    # read in the image; force three channels so the 3-element mean/std of
    # Normalize applies to grayscale / RGBA files too, and close the handle
    with Image.open(f) as src_img:
        hazy_img = transform(src_img.convert("RGB")).unsqueeze(0)

    # get the result without recording an autograd graph
    with torch.no_grad():
        output = model(hazy_img.to(device)).cpu().numpy().squeeze()

    # scale so the maximum maps to 255; skip the division for an all-zero
    # map, which would otherwise produce NaNs
    peak = output.max()
    if peak > 0:
        output = output / peak
    output = (output * 255).astype("uint8")

    # store the result
    pil_output = Image.fromarray(output)
    dst_fname = os.path.join(dst_dir, os.path.split(f)[1].replace("Hazy", "predDMap"))
    pil_output.save(dst_fname)

100%|██████████| 1449/1449 [00:46<00:00, 31.43it/s]
