In [None]:
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# Residual Block
class ResidualBlock(nn.Module):
    """Identity skip connection around two 3x3 conv + instance-norm stages.

    Both convolutions use stride 1 and padding 1 with ``channels`` in and out,
    so the block's output has the same shape as its input and can be added to it.

    Args:
        channels: number of feature channels entering (and leaving) the block.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

        # Kept as an nn.Sequential attribute named ``block`` so the state_dict keys
        # (block.0.*, block.3.*) stay compatible with saved checkpoints.
        self.block = nn.Sequential(
            conv3x3(), nn.InstanceNorm2d(channels), nn.ReLU(inplace=True),
            conv3x3(), nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

# Generator
class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder image-to-image generator.

    Layout:
      * 7x7 stem conv to ``base_features`` channels (+ InstanceNorm, ReLU),
      * two stride-2 3x3 convs, each halving H/W and doubling channels,
      * ``num_residual_blocks`` ``ResidualBlock``s,
      * two stride-2 3x3 transposed convs, each doubling H/W and halving channels
        back to ``base_features``,
      * a final 7x7 conv to ``out_channels`` followed by ``Tanh`` (outputs in [-1, 1]).

    Because of the two stride-2 stages, the output height/width equal the input's
    only when both are multiples of 4.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        num_residual_blocks: number of residual blocks at the bottleneck.
        base_features: channel width after the stem. The default (64) reproduces
            the original architecture and checkpoint layout.
    """

    def __init__(self, in_channels=3, out_channels=3, num_residual_blocks=9, base_features=64):
        super(Generator, self).__init__()
        model = [
            nn.Conv2d(in_channels, base_features, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(base_features),
            nn.ReLU(inplace=True),
        ]
        in_features = base_features
        # Downsampling: halve H and W, double the channel count.
        for _ in range(2):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(in_features)]
        # Upsampling: double H and W (output_padding=1 makes it exact), halve channels.
        for _ in range(2):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        # Use the tracked channel count (back to base_features here) instead of a
        # hard-coded 64, so the head stays correct if the width is changed.
        model += [nn.Conv2d(in_features, out_channels, kernel_size=7, stride=1, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Load the rain -> clear generator model.
generator_rain_to_clear = Generator()
# The generator is kept on the CPU in this notebook, so map the checkpoint tensors
# to the CPU as well. Mapping to 'mps' only forced a device round-trip and made this
# cell fail on machines without MPS support.
checkpoint_state = torch.load(
    "/Users/matthew/Jupyter/Thesis/DeRain/Epoch-50-Batch-4/generator_rain_to_clear.pth",
    map_location=torch.device('cpu'),
)
# strict=False tolerates key mismatches, but any missing key leaves that layer at its
# random initialisation. Report mismatches instead of silently producing garbage.
load_result = generator_rain_to_clear.load_state_dict(checkpoint_state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint does not fully match Generator: "
        f"missing keys={load_result.missing_keys}, "
        f"unexpected keys={load_result.unexpected_keys}"
    )
generator_rain_to_clear.eval()


# Preprocessing for model input: PIL image -> float tensor -> per-channel
# (x - 0.5) / 0.5, i.e. the same [-1, 1] range as the generator's Tanh output.
# Resizing (e.g. to 256x256) is deliberately left out so frames keep their
# native resolution.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load and transform the rainy image
# rainy_image = Image.open("/Users/matthew/Jupyter/Thesis/DeRain/JRDR/rain_data_test_Light/rain/X2/norain-38x2.png")
# # rainy_image = Image.open("/Users/matthew/Desktop/rainy_image.jpg")
# rainy_image_tensor = transform(rainy_image).unsqueeze(0)  # Add batch dimension

# # Generate the derained image
# with torch.no_grad():
#     derained_image_tensor = generator_rain_to_clear(rainy_image_tensor)

# # Post-process and display the image
# denormalize = transforms.Normalize((-1, -1, -1), (2, 2, 2))  # Reverse normalization
# derained_image_tensor = denormalize(derained_image_tensor.squeeze(0))
# derained_image = transforms.ToPILImage()(derained_image_tensor)

# plt.figure(figsize=(8, 4))
# plt.subplot(1, 2, 1)
# plt.title("Rainy Image")
# plt.imshow(rainy_image)
# plt.axis("off")

# plt.subplot(1, 2, 2)
# plt.title("Derained Image")
# plt.imshow(derained_image)
# plt.axis("off")

# plt.show()


# Function to process a single frame
def process_frame(frame):
    """Derain a single OpenCV frame with ``generator_rain_to_clear``.

    Args:
        frame: image as returned by ``cv2.VideoCapture.read`` (BGR channel order).

    Returns:
        The derained frame as a BGR array with the same height and width as ``frame``.

    The generator downsamples twice by 2 and upsamples twice by 2, so its output
    only matches the input size when both sides are multiples of 4. The frame is
    edge-padded up to the next multiple of 4 before inference and the result is
    cropped back. Without this, ``cv2.VideoWriter`` silently drops frames whose
    size differs from the size it was opened with.
    """
    height, width = frame.shape[:2]
    frame_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # Convert to PIL format
    frame_tensor = transform(frame_image).unsqueeze(0)  # Add batch dimension

    pad_bottom = -height % 4
    pad_right = -width % 4
    if pad_bottom or pad_right:
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        frame_tensor = torch.nn.functional.pad(
            frame_tensor, (0, pad_right, 0, pad_bottom), mode="replicate"
        )

    # Run on whichever device the generator's weights live on.
    device = next(generator_rain_to_clear.parameters()).device
    with torch.no_grad():
        derained_tensor = generator_rain_to_clear(frame_tensor.to(device))  # Run through the model

    # Remove the batch dimension and padding; move back to the CPU for conversion.
    derained_tensor = derained_tensor[0, :, :height, :width].cpu().contiguous()

    denormalize = transforms.Normalize((-1, -1, -1), (2, 2, 2))  # Reverse normalization: [-1, 1] -> [0, 1]
    derained_tensor = denormalize(derained_tensor)
    derained_image = transforms.ToPILImage()(derained_tensor)

    return cv2.cvtColor(np.array(derained_image), cv2.COLOR_RGB2BGR)  # Convert back to OpenCV format

# Define the video processing and saving function
def process_video(input_video_path, output_video_path):
    """Derain every frame of a video with ``process_frame`` and write a new video.

    Args:
        input_video_path: path of the video to read.
        output_video_path: path of the video to write (encoded with the 'mp4v' FourCC).

    Returns:
        Number of frames written.

    Raises:
        OSError: if the input video cannot be opened, or the output writer
            cannot be created.
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open input video: {input_video_path}")

    # Keep the exact, possibly fractional frame rate (e.g. 29.97). Truncating it
    # to an int makes the output play slower than the source.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        fps = 30.0  # the container did not report a rate; fall back to a common default

    out = None
    frames_written = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if out is None:
                # Size the writer from the first decoded frame rather than the
                # CAP_PROP_FRAME_* metadata, which can disagree with decoded frames
                # (e.g. rotated phone footage).
                height, width = frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
                if not out.isOpened():
                    raise OSError(f"Could not open output video for writing: {output_video_path}")

            # Process the frame with the deraining model
            processed_frame = process_frame(frame)
            # VideoWriter silently drops frames whose size differs from the size
            # it was opened with, so force every frame to that size.
            if processed_frame.shape[:2] != (height, width):
                processed_frame = cv2.resize(processed_frame, (width, height))
            out.write(processed_frame)
            frames_written += 1
    finally:
        # Release resources even if processing fails part-way through.
        cap.release()
        if out is not None:
            out.release()

    return frames_written

# Derain one clip end to end. Point these two paths at your own input video
# and at the desired output location.
input_video_path, output_video_path = (
    '/Users/matthew/Jupyter/Thesis/Dataset/HR1.mov',
    '/Users/matthew/Jupyter/Thesis/DeRain/HR1_DeRain.mov',
)
process_video(input_video_path=input_video_path, output_video_path=output_video_path)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# derained_image.save("derained_image.jpg", format="JPEG")