In [None]:
!unzip data.zip

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import cv2
import os
import numpy as np

class LOLDataset(Dataset):
    """Paired low-light / reference grayscale image dataset.

    Images are paired by position in the sorted listings of ``low_dir`` and
    ``high_dir``: the i-th low image is matched with the i-th high image.

    Args:
        low_dir: Directory containing the low-light input images.
        high_dir: Directory containing the reference (target) images.
        transform: Callable applied to each PIL image. When omitted, images
            are resized to 256x256 and converted with ``ToTensor``.

    Raises:
        ValueError: If the two directories hold a different number of
            entries, because positional pairing would then silently match
            the wrong images or fail with an IndexError mid-epoch.
    """

    def __init__(self, low_dir, high_dir, transform=None):
        self.low_dir = low_dir
        self.high_dir = high_dir
        self.low_images = sorted(os.listdir(low_dir))
        self.high_images = sorted(os.listdir(high_dir))

        # Pairing is purely positional, so a count mismatch means at least
        # one pair is wrong; fail early instead of training on bad pairs.
        if len(self.low_images) != len(self.high_images):
            raise ValueError(
                f"Mismatched dataset: {len(self.low_images)} files in "
                f"{low_dir} but {len(self.high_images)} files in {high_dir}"
            )

        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.low_images)

    def __getitem__(self, idx):
        """Load, convert and transform the ``idx``-th (low, high) pair.

        Raises:
            ValueError: If either image cannot be decoded by OpenCV.
        """
        low_path = os.path.join(self.low_dir, self.low_images[idx])
        high_path = os.path.join(self.high_dir, self.high_images[idx])

        low_img = cv2.imread(low_path, cv2.IMREAD_GRAYSCALE)
        high_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None rather than raising.
        if low_img is None or high_img is None:
            raise ValueError(f"Could not read image: {low_path} or {high_path}")

        to_pil = transforms.ToPILImage()
        low_img = to_pil(low_img)
        high_img = to_pil(high_img)

        if self.transform is not None:
            low_img = self.transform(low_img)
            high_img = self.transform(high_img)

        return low_img, high_img

class PreCEModule(nn.Module):
    """Min-max rescale a tensor towards the [0, 1] range.

    The minimum and maximum are taken over every element of the tensor that
    is passed in (all dimensions at once). A small epsilon in the
    denominator keeps the division finite for constant inputs.
    """

    def forward(self, x):
        lowest = torch.min(x)
        value_range = torch.max(x) - lowest
        return (x - lowest) / (value_range + 1e-6)

class GlobalEnhanceNet(nn.Module):
    """Predict per-pixel cubic polynomial coefficients from the raw input.

    A four-layer convolutional stack maps the single-channel input to four
    coefficient maps, which weight the basis ``[1, x, x^2, x^3]`` of the
    input. The input must have exactly one channel, i.e. shape
    ``(b, 1, h, w)``.

    Returns:
        A 5-D tensor of shape ``(b, 1, 1, h, w)`` holding the weighted sum
        of the four basis terms.
    """

    def __init__(self):
        super().__init__()
        # Layer attribute names are part of the state_dict keys.
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv4 = nn.Conv2d(16, 4, 3, padding=1)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # Polynomial basis of the input, stacked along the channel axis.
        basis = torch.cat(
            (torch.ones_like(x), x, x.pow(2), x.pow(3)),
            dim=1,
        )

        # Coefficient branch: dropout follows the first two layers only.
        hidden = self.dropout(self.relu(self.conv1(x)))
        hidden = self.dropout(self.relu(self.conv2(hidden)))
        hidden = self.relu(self.conv3(hidden))
        coeffs = self.conv4(hidden)

        batch, _, height, width = coeffs.shape
        grid = (batch, 4, 1, height, width)
        weighted = coeffs.view(grid) * basis.view(grid)
        # Summing over the basis axis while keeping it yields (b, 1, 1, h, w).
        return weighted.sum(dim=1, keepdim=True)

class LocalEnhanceNet(nn.Module):
    """Predict per-pixel cubic polynomial coefficients from the basis itself.

    The basis ``[1, x, x^2, x^3]`` of the single-channel input is fed
    through a four-layer convolutional stack that outputs four coefficient
    maps; those maps then weight the same basis. The input must have
    exactly one channel, i.e. shape ``(b, 1, h, w)``.

    Returns:
        A 5-D tensor of shape ``(b, 1, 1, h, w)`` holding the weighted sum
        of the four basis terms.
    """

    def __init__(self):
        super().__init__()
        # Layer attribute names are part of the state_dict keys.
        self.conv1 = nn.Conv2d(4, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv4 = nn.Conv2d(16, 4, 3, padding=1)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # Polynomial basis of the input, stacked along the channel axis.
        basis = torch.cat(
            (torch.ones_like(x), x, x.pow(2), x.pow(3)),
            dim=1,
        )

        # Coefficient branch: dropout follows the first two layers only.
        hidden = self.dropout(self.relu(self.conv1(basis)))
        hidden = self.dropout(self.relu(self.conv2(hidden)))
        hidden = self.relu(self.conv3(hidden))
        coeffs = self.conv4(hidden)

        batch, _, height, width = coeffs.shape
        grid = (batch, 4, 1, height, width)
        weighted = coeffs.view(grid) * basis.view(grid)
        # Summing over the basis axis while keeping it yields (b, 1, 1, h, w).
        return weighted.sum(dim=1, keepdim=True)

class GLCENetwork(nn.Module):
    """Full enhancement network: normalisation plus global and local branches.

    The input is min-max normalised, passed through both enhancement
    branches, and the branch outputs are added to the normalised input as
    residuals. The result is clamped to [0, 1].

    Returns:
        A tensor with the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        self.pre_ce = PreCEModule()
        self.global_net = GlobalEnhanceNet()
        self.local_net = LocalEnhanceNet()

    def forward(self, x):
        x_std = self.pre_ce(x)

        # The branches return (b, 1, 1, h, w). Adding that directly to the
        # (b, 1, h, w) input broadcasts to (b, b, 1, h, w) and mixes samples
        # across the batch, so bring both residuals back to the input shape.
        x_g = self.global_net(x_std).reshape(x_std.shape)
        x_l = self.local_net(x_std).reshape(x_std.shape)

        enhanced = x_std + x_l + x_g
        return torch.clamp(enhanced, 0, 1)

class CustomLoss(nn.Module):
    """Weighted sum of MSE, histogram L1 and image-gradient L1 losses.

    Args:
        mse_weight: Weight of the pixel-wise MSE term.
        hist_weight: Weight of the intensity-histogram term.
        grad_weight: Weight of the spatial-gradient term.
    """

    def __init__(self, mse_weight=1.0, hist_weight=0.2, grad_weight=0.5):
        super().__init__()
        self.mse = nn.MSELoss()
        self.mse_weight = mse_weight
        self.hist_weight = hist_weight
        self.grad_weight = grad_weight

    def histogram_loss(self, pred, target, bins=256):
        """Return the L1 distance between normalised intensity histograms.

        Both tensors are clamped to [0, 1] and binned over that range.
        """
        # NOTE(review): torch.histc is not expected to propagate gradients,
        # so this term probably does not influence the weights - confirm.
        def compute_histogram(tensor):
            eps = 1e-8
            tensor = torch.clamp(tensor, 0, 1)
            hist = torch.histc(tensor.float(), bins=bins, min=0, max=1)
            return hist / (hist.sum() + eps)

        pred_hist = compute_histogram(pred)
        target_hist = compute_histogram(target)
        return F.l1_loss(pred_hist, target_hist)

    def gradient_loss(self, pred, target):
        """Return the L1 distance between forward-difference image gradients.

        Differences are taken along the last two axes (height, width), so
        the result is well defined for any number of leading axes.
        """
        def get_gradient(x):
            # Ellipsis indexing pins the spatial axes to the last two. Fixed
            # positions 2 and 3 would hit a size-1 axis on a 5-D input and
            # yield an empty tensor whose mean is NaN.
            gradient_x = F.pad(x[..., :, 1:] - x[..., :, :-1], (0, 1, 0, 0))
            gradient_y = F.pad(x[..., 1:, :] - x[..., :-1, :], (0, 0, 0, 1))
            return gradient_x, gradient_y

        pred_gradient_x, pred_gradient_y = get_gradient(pred)
        target_gradient_x, target_gradient_y = get_gradient(target)

        return (
            F.l1_loss(pred_gradient_x, target_gradient_x)
            + F.l1_loss(pred_gradient_y, target_gradient_y)
        )

    def forward(self, pred, target):
        """Return the combined scalar loss for ``pred`` against ``target``."""
        # Kept for callers whose prediction carries extra leading axes; it
        # is a no-op when both tensors already share a shape.
        if pred.size() != target.size():
            target = target.expand_as(pred)

        mse_loss = self.mse(pred, target)
        hist_loss = self.histogram_loss(pred, target)
        grad_loss = self.gradient_loss(pred, target)

        return (
            self.mse_weight * mse_loss
            + self.hist_weight * hist_loss
            + self.grad_weight * grad_loss
        )

def train_model(model, train_loader, num_epochs=240, learning_rate=1e-4,
                checkpoint_path='best_model.pth'):
    """Train ``model`` with ``CustomLoss`` and checkpoint the best epoch.

    Uses Adam with gradient-norm clipping at 1.0 and halves the learning
    rate when the epoch loss has not improved for 10 epochs. Batches whose
    loss is NaN or infinite are skipped and reported, not applied.

    Args:
        model: Network to train; it is moved to CUDA when available.
        train_loader: Iterable yielding ``(low_img, high_img)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Initial Adam learning rate.
        checkpoint_path: Where the state dict of the best epoch is saved.

    Returns:
        The lowest average epoch loss seen, or ``float('inf')`` if no epoch
        produced a finite loss.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = CustomLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    best_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        batch_count = 0
        skipped_count = 0

        for low_img, high_img in train_loader:
            low_img = low_img.to(device)
            high_img = high_img.to(device)

            optimizer.zero_grad()

            enhanced_img = model(low_img)
            loss = criterion(enhanced_img, high_img)

            # A non-finite loss would corrupt the weights; skip the batch,
            # but count it so the failure is visible.
            if not torch.isfinite(loss):
                skipped_count += 1
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()
            batch_count += 1

        if skipped_count:
            print(f'Epoch [{epoch+1}/{num_epochs}], skipped {skipped_count} '
                  f'batch(es) with a non-finite loss')

        if batch_count == 0:
            # Without this message a fully failing run prints nothing at all.
            print(f'Epoch [{epoch+1}/{num_epochs}], no batch produced a finite '
                  f'loss; weights were not updated')
            continue

        avg_loss = epoch_loss / batch_count
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), checkpoint_path)

    return best_loss

def test_model(model, test_image_path, output_path):
    """Enhance one grayscale image with ``model`` and write it to disk.

    The image is read as grayscale, resized to 256x256, run through the
    model in eval mode without gradients, and saved as 8-bit.

    Args:
        model: Trained network; it is moved to CUDA when available.
        test_image_path: Path of the image to enhance.
        output_path: Destination file path, used exactly as given.

    Raises:
        ValueError: If the input image cannot be loaded.
        OSError: If OpenCV reports that the output could not be written.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    img = cv2.imread(test_image_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None on failure; without this check the transform
    # would fail later with an unrelated error.
    if img is None:
        raise ValueError(f"Image at {test_image_path} could not be loaded")

    img = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        enhanced = model(img)

    # Drop all singleton axes to get a plain (h, w) array for OpenCV.
    enhanced = enhanced.cpu().squeeze().numpy()
    enhanced = (enhanced * 255).clip(0, 255).astype(np.uint8)

    # cv2.imwrite reports some failures through its return value only.
    if not cv2.imwrite(output_path, enhanced):
        raise OSError(f"Could not write enhanced image to {output_path}")

# Training entry point: seed the RNGs, build the paired dataset and loader,
# then train a fresh GLCENetwork. The names assigned here (notably `model`)
# are used by the cells further down.
if __name__ == "__main__":

    # Fixed seed so that weight initialisation and shuffling are repeatable.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    # Expected layout after unzipping data.zip: data/low and data/high.
    low_dir = 'data/low'
    high_dir = 'data/high'
    dataset = LOLDataset(low_dir, high_dir)
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)


    model = GLCENetwork()
    # train_model saves the best epoch's weights to 'best_model.pth'.
    train_model(model, train_loader)

In [None]:
!unzip test.zip

In [None]:
def test_model(model, test_image_path, output_path):
    """Enhance one grayscale image with ``model`` and save it as a PNG.

    The image is read as grayscale, resized to 256x256, run through the
    model in eval mode without gradients, and saved as 8-bit.

    Args:
        model: Trained network; it is moved to CUDA when available.
        test_image_path: Path of the image to enhance.
        output_path: Destination path. ``'.png'`` is appended unless the
            path already ends with that extension.

    Raises:
        ValueError: If the input image cannot be loaded.
        OSError: If OpenCV reports that the output could not be written.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    img = cv2.imread(test_image_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None on failure instead of raising.
    if img is None:
        raise ValueError(f"Image at {test_image_path} could not be loaded")

    img = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        enhanced = model(img)

    # Drop all singleton axes to get a plain (h, w) array for OpenCV.
    enhanced = enhanced.cpu().squeeze().numpy()
    enhanced = (enhanced * 255).clip(0, 255).astype(np.uint8)

    # Avoid producing 'name.png.png' when the caller already gave a PNG path.
    if output_path.lower().endswith('.png'):
        output_path_with_extension = output_path
    else:
        output_path_with_extension = output_path + '.png'

    # cv2.imwrite reports some failures through its return value only, so
    # success must not be printed unconditionally.
    if not cv2.imwrite(output_path_with_extension, enhanced):
        raise OSError(
            f"Could not write enhanced image to {output_path_with_extension}"
        )
    print(f"Enhanced image saved to {output_path_with_extension}")


# Enhance a sample image with the `model` trained in the cell above;
# test_model appends '.png', so the result lands in 'test/enhnone.png'.
if __name__ == "__main__":
  test_model(model, 'test/one.jpg', 'test/enhnone')


In [None]:

# Persist the weights currently held by `model` (its state after the final
# epoch). This is separate from 'best_model.pth', which train_model writes
# for the lowest-loss epoch.
torch.save(model.state_dict(), 'model.pth')


In [None]:

# Rebuild the network, restore the saved weights and enhance one image.
model = GLCENetwork()
# map_location='cpu' lets a checkpoint saved from a CUDA run load on a
# CPU-only machine; test_model moves the model to the active device itself.
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
model.eval()
test_model(model, '7.1.03.tiff', 'test/enhancedtank')
