In [1]:
!pip install torch torchvision opencv-python tqdm scikit-image matplotlib gradio --quiet

In [2]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from skimage.metrics import structural_similarity as ssim_metric
import matplotlib.pyplot as plt
import gradio as gr
import shutil
from zipfile import ZipFile

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [4]:
!gdown 1-XVnUaJBYYcH8nGQ-FA4Cg8NXo0yxTUe

Downloading...
From (original): https://drive.google.com/uc?id=1-XVnUaJBYYcH8nGQ-FA4Cg8NXo0yxTUe
From (redirected): https://drive.google.com/uc?id=1-XVnUaJBYYcH8nGQ-FA4Cg8NXo0yxTUe&confirm=t&uuid=dd17c102-b703-4eda-9d85-b0b82a0b8471
To: /content/trained_models.zip
100% 114M/114M [00:00<00:00, 119MB/s]


In [5]:
!unzip /content/trained_models.zip -d /content/trained_models

# Model Path
model_save_dir_local = '/content/trained_models'

Archive:  /content/trained_models.zip
replace /content/trained_models/colorization_model.pth? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

Define U-Net Architecture

In [6]:
class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=[64, 128, 256]):
        super(SimpleUNet, self).__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        for feature in features:
            self.encoder.append(self.conv_block(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = self.conv_block(features[-1], features[-1]*2)

        # Decoder
        reversed_features = features[::-1]
        decoder_features = features[-1]*2  # Start with bottleneck channels

        for feature in reversed_features:
            self.decoder.append(nn.ConvTranspose2d(decoder_features, feature, kernel_size=2, stride=2))
            self.decoder.append(self.conv_block(decoder_features, feature))
            decoder_features = feature  # Update for next layer

        # Final Convolution
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Encoder
        for layer in self.encoder:
            x = layer(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)
            skip_connection = skip_connections[idx//2]
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.decoder[idx+1](concat_skip)

        return self.final_conv(x)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

Load the Fine-Tuned Models

In [7]:
tasks = ['denoising', 'super_resolution', 'colorization', 'inpainting']
models = {}

def load_finetuned_model(task, model_class, model_path):
    """Load a fine-tuned model from the specified path."""
    model = model_class()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    return model

for task in tasks:
    model_path = os.path.join(model_save_dir_local, f'{task}_model.pth')
    if os.path.exists(model_path):
        models[task] = load_finetuned_model(task, SimpleUNet, model_path)
        print(f"{task.capitalize()} model loaded successfully.")
    else:
        print(f"{task.capitalize()} model not found at {model_path}. Please ensure {task}_model.pth is in models.zip.")
        models[task] = None

  model.load_state_dict(torch.load(model_path, map_location=device))


Denoising model loaded successfully.
Super_resolution model loaded successfully.
Colorization model loaded successfully.
Inpainting model loaded successfully.


Define Inference Functions

In [8]:
transform_inference = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 32)),  # CIFAR-100 size; adjust if different
    transforms.ToTensor()
])

def run_inference(model, input_image):
    """
    Run inference on a single input image using the specified model.
    Returns the predicted output image as a numpy array [0,1].
    """
    input_tensor = transform_inference(input_image).unsqueeze(0).to(device) # (1,3,32,32)
    with torch.no_grad():
        output = model(input_tensor)  # (1,3,32,32)
    output = output.squeeze(0).cpu().numpy()  # (3,32,32)
    output = np.transpose(output, (1, 2, 0))  # HWC
    output = np.clip(output, 0, 1)
    return output

def inference_interface(task, input_image):
    """
    Given the task name and user-uploaded image, run inference using the corresponding model.
    Tasks: 'Denoising', 'Super-Resolution', 'Colorization', 'Inpainting'
    """
    task_key = task.lower().replace('-', '_')
    if task_key not in models or models[task_key] is None:
        raise ValueError(f"Model for {task} not available. Please ensure the model is loaded correctly.")
    model = models[task_key]
    output_image = run_inference(model, input_image)
    return output_image

Set up Gradio Interface

In [9]:
def create_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# 🖼️ Image Enhancement App")
        gr.Markdown("Select an enhancement task, upload an image, and see the enhanced result.")

        with gr.Row():
            task = gr.Dropdown(
                choices=["Denoising", "Super-Resolution", "Colorization", "Inpainting"],
                label="Select Enhancement Task",
                value="Denoising"
            )
            image = gr.Image(type="numpy", label="Upload an Image")

        enhanced_image = gr.Image(type="numpy", label="Enhanced Image")

        run_button = gr.Button("Enhance Image")

        run_button.click(
            fn=inference_interface,
            inputs=[task, image],
            outputs=[enhanced_image]
        )

    return demo

Launch Gradio Interface

In [10]:
demo = create_gradio_interface()
demo.launch(debug=True)

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://2bde59bed12518621e.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://2bde59bed12518621e.gradio.live


