<a href="https://colab.research.google.com/github/khrishwanth/ML-based-project/blob/main/low_light_enhancement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import io

class ResidualDenseBlock(nn.Module):
    """Two-convolution dense block wrapped in an identity skip connection.

    The first convolution produces ``growth_channels`` new feature maps. These
    are concatenated with the block input along dim 1 (the channel axis for
    batched NCHW input) and fed to the second convolution, which projects back
    to ``in_channels``. Both convolutions are 3x3 with padding 1 and stride 1,
    so the spatial size is preserved and the result can be summed with the
    input.

    Args:
        in_channels: Number of channels of the input (and of the output).
        growth_channels: Number of feature maps produced by the first conv.
    """

    def __init__(self, in_channels, growth_channels):
        super().__init__()
        # Attribute names are part of the state_dict keys -- keep them stable.
        self.layer1 = nn.Conv2d(
            in_channels, growth_channels, kernel_size=3, padding=1
        )
        self.layer2 = nn.Conv2d(
            in_channels + growth_channels, in_channels, kernel_size=3, padding=1
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` plus the ReLU-activated dense-branch output.

        Because the branch ends in a ReLU, the correction added to ``x`` is
        element-wise non-negative.
        """
        grown = self.relu(self.layer1(x))
        dense_input = torch.cat((x, grown), dim=1)
        residual = self.relu(self.layer2(dense_input))
        return residual + x  # identity skip connection

class SimplifiedESRGAN(nn.Module):
    """Small ESRGAN-style network mapping a 3-channel image to a 3-channel image at 2x size.

    Stages, applied in order:
      1. ``initial``  -- 3x3 conv, 3 -> 64 channels.
      2. ``rdb1``, ``rdb2`` -- residual dense blocks (64 channels, growth 32).
      3. ``upsample`` -- transposed conv (kernel 4, stride 2, padding 1),
         doubling height and width.
      4. ``final``    -- 3x3 conv, 64 -> 3 channels.

    No activation or clamp follows the last conv, so output values are unbounded.
    """

    def __init__(self):
        super().__init__()
        feats = 64
        # Construction order matters: it fixes parameter order and the
        # sequence in which default random initialisation is drawn.
        self.initial = nn.Conv2d(3, feats, kernel_size=3, padding=1)
        self.rdb1 = ResidualDenseBlock(feats, 32)
        self.rdb2 = ResidualDenseBlock(feats, 32)
        # H_out = (H - 1) * 2 - 2 * 1 + (4 - 1) + 1 = 2 * H  (same for width).
        self.upsample = nn.ConvTranspose2d(
            feats, feats, kernel_size=4, stride=2, padding=1
        )
        self.final = nn.Conv2d(feats, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Run ``x`` through every stage in sequence and return the result."""
        stages = (self.initial, self.rdb1, self.rdb2, self.upsample, self.final)
        for stage in stages:
            x = stage(x)
        return x

# Instantiate the model for inference.
# NOTE(review): no weights are loaded here (there is no load_state_dict call),
# so the parameters are PyTorch's default random initialisation and the
# "enhanced" output is not meaningful until trained weights are loaded.
# .eval() switches to inference mode so any future BatchNorm/Dropout layers
# behave deterministically. It is chained so the notebook does not echo
# the module repr as the cell's last expression.
model = SimplifiedESRGAN().eval()

In [None]:
from google.colab import files
import matplotlib.pyplot as plt

def upload_image():
    """Prompt for a file upload in Colab and return it as a PIL image.

    Only the first uploaded file is used; any additional files are ignored.

    Returns:
        The first uploaded file, opened with ``PIL.Image.open``.

    Raises:
        ValueError: if the upload returned no files. Previously the function
            silently returned ``None`` and failed later in the transform.
    """
    uploaded = files.upload()  # mapping: filename -> raw file bytes
    if not uploaded:
        raise ValueError("No file was uploaded.")
    first_name = next(iter(uploaded))
    return Image.open(io.BytesIO(uploaded[first_name]))

# Upload an image. Convert to RGB because the model's first conv expects
# exactly 3 input channels; RGBA PNGs, grayscale (L) or palette (P) images
# would otherwise produce a 4- or 1-channel tensor and crash the model.
img = upload_image().convert("RGB")

# Resize the PIL image first, then convert to a float tensor in [0, 1].
# Resizing in PIL uses antialiased resampling and avoids the tensor-resize
# antialias warning that some torchvision releases emit.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

img_tensor = transform(img).unsqueeze(0)  # Add batch dimension -> (1, 3, 256, 256)

# Pass through the model (inference only, so no autograd graph is needed).
with torch.no_grad():
    enhanced_img_tensor = model(img_tensor)

# Convert back to a PIL image. The model output is unbounded, and ToPILImage
# scales float tensors by 255 and casts to uint8 without clamping, which
# would wrap out-of-range values into garbage pixels. Clamp to [0, 1] first.
enhanced_img = transforms.ToPILImage()(enhanced_img_tensor.squeeze(0).clamp(0, 1))

# Display the original and enhanced images side by side.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(img)
plt.axis("off")
plt.subplot(1, 2, 2)
plt.title("Enhanced Image")
plt.imshow(enhanced_img)
plt.axis("off")
plt.show()