In [None]:
# Clone the upstream Raindrop-Removal repository. The next cell appends it to sys.path;
# project modules imported further down (e.g. `model`, `utils`) presumably come from it.
!git clone https://github.com/Hyukju/Raindrop-Removal

In [1]:
import sys
import os

# Make the cloned repository importable (project modules such as `model` and
# `utils` used in later cells are presumably loaded from it).
REPO_DIR = 'Raindrop-Removal'
sys.path.append(REPO_DIR)

Source code: https://github.com/Hyukju/Raindrop-Removal

Reference paper: https://www.mdpi.com/2227-7390/11/15/3318

In [2]:
import sys

# Jupyter starts the kernel with extra command-line flags (e.g. '-f <connection file>').
# Keep only the program name so that argparse's parse_args() in the inference cell
# does not fail on them.
sys.argv = sys.argv[0:1]


In [3]:
import os
import cv2 
import numpy as np 
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms

In [4]:
class ToTensor():
    """Convert every HWC numpy array in a sample dict to a CHW float32 torch tensor.

    2-D (H, W) arrays first get a trailing channel axis, so they come out as
    (1, H, W). The dict is updated in place and also returned.
    """

    def __call__(self, data):
        for key in list(data):
            array = data[key]
            if array.ndim == 2:
                array = np.expand_dims(array, axis=2)
            chw = array.transpose(2, 0, 1).astype('float32')
            data[key] = torch.from_numpy(chw)
        return data

class Normalization():
    """Standardize every value in a sample dict as (value - mean) / std.

    With the defaults (0.5, 0.5) an input of 0 maps to -1 and 1 maps to 1.
    The dict is updated in place and also returned.
    """

    def __init__(self, mean=0.5, std=0.5):
        self.mean, self.std = mean, std

    def __call__(self, data):
        data.update({key: (value - self.mean) / self.std for key, value in data.items()})
        return data


class RandomFlip():
    """Randomly mirror every array in a sample dict, applying the same flips to all keys.

    Two coin flips are drawn from numpy's global RNG: the first decides a
    left-right flip (np.fliplr), the second an up-down flip (np.flipud).
    The dict is updated in place and also returned.
    """

    def __call__(self, data):
        # Draw order matters for reproducibility under np.random.seed: LR first, then UD.
        mirror_lr = np.random.rand() > 0.5
        mirror_ud = np.random.rand() > 0.5

        for key, array in list(data.items()):
            if mirror_lr:
                array = np.fliplr(array)
            if mirror_ud:
                array = np.flipud(array)
            data[key] = array
        return data

class RandomCrop():
    """Crop the same random window out of every array in a sample dict.

    Args:
        shape: (height, width) of the crop.

    The window position is drawn from numpy's global RNG using the size of
    ``data['input']``. Every entry (including 2-D ones such as a mask) is cropped
    at the same offsets along its first two axes. The dict is updated in place
    and also returned.

    Raises:
        ValueError: if ``data['input']`` is smaller than ``shape``.
    """

    def __init__(self, shape):
        self.shape = shape

    def __call__(self, data):
        height, width = data['input'].shape[:2]
        crop_h, crop_w = self.shape

        if height < crop_h or width < crop_w:
            raise ValueError(
                f'RandomCrop: image of size {height}x{width} is smaller than crop {crop_h}x{crop_w}')

        # randint's upper bound is exclusive, so +1 makes the bottom/right-most
        # window reachable and allows an image exactly the size of the crop.
        top = np.random.randint(0, height - crop_h + 1)
        left = np.random.randint(0, width - crop_w + 1)

        for key, array in list(data.items()):
            # .copy() keeps the previous semantics of returning an independent
            # array (the former fancy-indexing implementation always copied).
            data[key] = array[top:top + crop_h, left:left + crop_w].copy()

        return data
    

In [5]:
class ModelDataset(Dataset):
    """Paired rainy / clean image dataset.

    Expected layout (one level of sub-folders under each side)::

        data_dir/<dataA_dir>/<folder>/*.jpg|*.png   # inputs  ('Drop' by default)
        data_dir/<dataB_dir>/<folder>/*.jpg|*.png   # targets ('Clear' by default)

    Images are paired purely by position after sorting each side's full-path
    list, so both trees should hold the same number of images in corresponding
    order.

    Each item is a dict with keys:
        'input': RGB rain image, scaled to [0, 1] when read as uint8.
        'label': RGB clean image, scaled the same way.
        'mask':  (gray(input) - gray(label)) * 0.5 + 0.5.
    When ``use_transform`` is True, the sample goes through
    RandomCrop(256x256) -> RandomFlip -> Normalization(0.5, 0.5) -> ToTensor.
    """

    def __init__(self, data_dir, dataA_dir='Drop', dataB_dir='Clear', use_transform=False):
        self.data_dir = data_dir
        self.use_transform = use_transform
        self.transform = transforms.Compose([
            RandomCrop(shape=(256, 256)),
            RandomFlip(),
            Normalization(mean=0.5, std=0.5),
            ToTensor(),
        ])
        self.dataA_dir = dataA_dir
        self.dataB_dir = dataB_dir

        self.lst_dataA = self._list_images(os.path.join(self.data_dir, self.dataA_dir))
        self.lst_dataB = self._list_images(os.path.join(self.data_dir, self.dataB_dir))

    @staticmethod
    def _list_images(root):
        """Return sorted full paths of every .jpg/.png file exactly one folder below ``root``."""
        paths = []
        for folder in os.listdir(root):
            folder_path = os.path.join(root, folder)
            if not os.path.isdir(folder_path):  # loose files directly under root are ignored
                continue
            paths.extend(
                os.path.join(folder_path, name)
                for name in os.listdir(folder_path)
                if name.lower().endswith(('jpg', 'png'))
            )
        # A single global sort gives the same order as sorting per folder and then overall.
        paths.sort()
        return paths

    def __len__(self):
        return len(self.lst_dataA)

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it as RGB; raise if it cannot be read."""
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None, not by raising
            raise FileNotFoundError(f'Could not read image: {path}')
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        # lst_dataA / lst_dataB already hold full paths (built from data_dir), so
        # they must not be joined with data_dir again: doing so produced a doubled,
        # non-existent path whenever data_dir was relative.
        imgA = self._read_rgb(self.lst_dataA[index])
        imgB = self._read_rgb(self.lst_dataB[index])

        if imgA.dtype == np.uint8:
            imgA = imgA / 255.0
        if imgB.dtype == np.uint8:
            imgB = imgB / 255.0

        grayA = cv2.cvtColor(imgA.astype('float32'), cv2.COLOR_RGB2GRAY)
        grayB = cv2.cvtColor(imgB.astype('float32'), cv2.COLOR_RGB2GRAY)

        # Signed brightness difference, mapped from [-1, 1] to [0, 1] for inputs in [0, 1].
        mask = (grayA - grayB) * 0.5 + 0.5

        data = {'input': imgA, 'label': imgB, 'mask': mask}

        if self.use_transform:
            data = self.transform(data)

        return data

In [6]:
import importlib

def find_model2(model_name, phase='train', learning_rate=0.002,
                in_channels=3, out_channels=3, nker=64):
    """
    Import ``model.<model_name>_model`` and instantiate its ``Model`` class.

    The module name is lower-cased. The class is matched case-insensitively
    against 'Model'; if several names match, the last one in the module
    namespace wins.

    Parameters:
        model_name (str): Name of the model (e.g. 'proposed' -> model.proposed_model).
        phase (str): Either 'train' or 'test'; forwarded to the constructor.
        learning_rate (float): Learning rate, forwarded to the constructor as ``lr``.
        in_channels (int): Forwarded to the constructor (default 3).
        out_channels (int): Forwarded to the constructor (default 3).
        nker (int): Forwarded to the constructor (default 64).

    Returns:
        model: The instantiated model, or None if the module or class is not found.
    """
    model_module_name = f'model.{model_name.lower()}_model'
    print(f'Searching for model: {model_module_name}')

    try:
        model_module = importlib.import_module(model_module_name)
    except ModuleNotFoundError as err:
        # Only report "module not found" when the requested module (or its parent
        # package) is what is missing. A missing dependency imported *inside* the
        # model file raises the same exception type and was previously misreported.
        if err.name in (model_module_name, 'model'):
            print(f'Error: {model_module_name} module not found.')
        else:
            print(f'Error: {model_module_name} could not be imported: '
                  f'missing dependency {err.name!r}.')
        return None

    target_model_name = 'Model'
    model = None
    for name, cls in vars(model_module).items():
        if name.lower() == target_model_name.lower() and isinstance(cls, type):
            model = cls

    if model is None:
        print(f'Error: {target_model_name} class not found in {model_module_name}.')
        return None

    instance = model(phase=phase, in_channels=in_channels, out_channels=out_channels,
                     nker=nker, lr=learning_rate)
    # Report success only after construction actually succeeded.
    print(f'{model_module_name}.{target_model_name} successfully created.')
    return instance


## Inference

In [7]:
import os 
import cv2
import argparse
import utils 
import numpy as np 
from model.find_model import find_model
import time
_RESIZE_MODES = ('square', 'expand', 'original')


def _run_model(model, img, resize):
    """Run ``model.test_one_image`` on an RGB image prepared according to ``resize``; return its output dict."""
    if resize == 'square':
        return model.test_one_image(utils.numpy2tensor(cv2.resize(img, (512, 512))))

    if resize == 'expand':
        rows, cols = img.shape[:2]
        output = model.test_one_image(utils.numpy2tensor(utils.expand_size(img, 256)))
        # Undo the expansion so every output has the input's original size.
        return {title: utils.restore_size(out_img, rows, cols) for title, out_img in output.items()}

    if resize == 'original':
        return model.test_one_image(utils.numpy2tensor(img))

    raise ValueError(f'Unknown resize mode {resize!r}; expected one of {_RESIZE_MODES}')


def _to_bgr_uint8(image):
    """Convert a model output image (RGB, presumably float in [0, 1]) to 8-bit BGR for cv2.imwrite."""
    # Explicit clip/round/cast instead of relying on imwrite's implicit float -> uint8 conversion.
    rgb = np.clip(image.squeeze() * 255.0, 0, 255).round().astype(np.uint8)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def test(args):
    """Restore every image in ``args.dataset_dir`` with a trained checkpoint and save the results.

    Attributes read from ``args``:
        model: model name passed to ``find_model`` (e.g. 'proposed').
        ckpt_dir: checkpoint directory passed to ``model.load``.
        ckpt_epoch: requested epoch; the epoch returned by ``model.load`` is the one
            reported and used in the output path.
        dataset_dir: directory of input images; unreadable files are skipped.
        save_dir: output root; images go to ``<save_dir>/<model>/<epoch>_<resize>/<stem>.png``.
        resize: 'square' (resize to 512x512), 'expand' (``utils.expand_size(img, 256)``,
            then restore to the original size), or 'original' (native size).

    Raises:
        ValueError: if ``args.resize`` is not a supported mode. Previously an unknown
            mode left ``output`` unbound and failed with NameError on the first image.
    """
    if args.resize not in _RESIZE_MODES:
        raise ValueError(f'Unknown resize mode {args.resize!r}; expected one of {_RESIZE_MODES}')

    start_time = time.perf_counter()

    model, _ = find_model(args.model, 'test')
    epoch = model.load(args.ckpt_dir, epoch=args.ckpt_epoch)
    print(f'Loading {args.model} at EPOCH {epoch}!!')

    img_dir = args.dataset_dir
    img_file_list = sorted(os.listdir(img_dir))

    # The output directory depends only on model / epoch / mode, so create it once.
    save_dir = os.path.join(args.save_dir, args.model, f'{epoch}_{args.resize}')
    os.makedirs(save_dir, exist_ok=True)

    for i, filename in enumerate(img_file_list, 1):
        img = cv2.imread(os.path.join(img_dir, filename))
        if img is None:
            # cv2.imread returns None for unreadable / non-image files instead of raising.
            print(f'{i}/{len(img_file_list)}:{filename} skipped (not a readable image)')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        output = _run_model(model, img, args.resize)

        # splitext handles extensions of any length (e.g. '.jpeg'), unlike filename[:-4].
        stem = os.path.splitext(filename)[0]
        cv2.imwrite(os.path.join(save_dir, f'{stem}.png'), _to_bgr_uint8(output['output']))

        print(f'{i}/{len(img_file_list)}:{filename}')

    print('Test Finished!!')
    print(time.perf_counter() - start_time)

    
if __name__ == '__main__':
    # Command-line entry point for inference. Inside this notebook, sys.argv was
    # trimmed to the program name in an earlier cell, so every option takes its default.
    parser = argparse.ArgumentParser(prog='DeRainDrop')

    parser.add_argument('--model', dest='model', type=str, default='proposed')
    parser.add_argument('--ckpt_dir', dest='ckpt_dir', type=str, default='model/GAN_model')
    parser.add_argument('--ckpt_epoch', dest='ckpt_epoch', type=int, default=800)
    parser.add_argument('--dataset_dir', dest='dataset_dir', type=str, default='./private_dataset')
    parser.add_argument('--save_dir', dest='save_dir', type=str, default='GAN_inference')
    parser.add_argument('--resize', dest='resize', type=str, default='original')

    args = parser.parse_args()
    test(args)






searching...  model.proposed_model
model.proposed_model is created
initialize network with normal
initialize network with normal
initialize network with normal


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 193MB/s]
  model_weights = torch.load(os.path.join(ckpt_dir, f'model_epoch{epoch}.pth'))


Loading proposed at EPOCH 700!!
1/731:00001.png
2/731:00002.png
3/731:00003.png
4/731:00004.png
5/731:00005.png
6/731:00006.png
7/731:00007.png
8/731:00008.png
9/731:00009.png
10/731:00010.png
11/731:00011.png
12/731:00012.png
13/731:00013.png
14/731:00014.png
15/731:00015.png
16/731:00016.png
17/731:00017.png
18/731:00018.png
19/731:00019.png
20/731:00020.png
21/731:00021.png
22/731:00022.png
23/731:00023.png
24/731:00024.png
25/731:00025.png
26/731:00026.png
27/731:00027.png
28/731:00028.png
29/731:00029.png
30/731:00030.png
31/731:00031.png
32/731:00032.png
33/731:00033.png
34/731:00034.png
35/731:00035.png
36/731:00036.png
37/731:00037.png
38/731:00038.png
39/731:00039.png
40/731:00040.png
41/731:00041.png
42/731:00042.png
43/731:00043.png
44/731:00044.png
45/731:00045.png
46/731:00046.png
47/731:00047.png
48/731:00048.png
49/731:00049.png
50/731:00050.png
51/731:00051.png
52/731:00052.png
53/731:00053.png
54/731:00054.png
55/731:00055.png
56/731:00056.png
57/731:00057.png
58/731:0