In [1]:
!pip install torch torchvision matplotlib
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np



In [2]:
import cv2
import numpy as np

def normalize_image(image):
    image = (image - np.min(image)) / (np.max(image) - np.min(image)) * 255
    return image.astype(np.uint8)

image = cv2.imread('/content/drive/MyDrive/AstroSR-main/dataset/HSC_train_HR/0.46565214_1.5316281.png')
normalized_image = normalize_image(image)


In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

class EDSR(nn.Module):
    def __init__(self):
        super(EDSR, self).__init__()
        self.residual_layer = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
        )
    def forward(self, x):
        return self.residual_layer(x)

model = EDSR()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

In [12]:
import os
from glob import glob

lr_image_path = '/content/drive/MyDrive/AstroSR-main/dataset/SDSS_train_LR/**/*'
hr_image_path = '/content/drive/MyDrive/AstroSR-main/dataset/HSC_train_HR/**/*'

lr_images = sorted(glob(lr_image_path, recursive=True))
hr_images = sorted(glob(hr_image_path, recursive=True))

print(len(lr_images))
print(len(hr_images))

print(f"Found {len(lr_images)} low-res images and {len(hr_images)} high-res images")

2001
2000
Found 2001 low-res images and 2000 high-res images


In [36]:
from glob import glob
import os
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

class GalaxyDataset(Dataset):
    def __init__(self, lr_image_paths, hr_image_paths, transform=None):
        self.lr_image_paths = lr_image_paths
        self.hr_image_paths = hr_image_paths
        self.transform = transform

    def __len__(self):
        return len(self.lr_image_paths)

    def __getitem__(self, idx):
        lr_image = Image.open(self.lr_image_paths[idx]).convert('RGB')
        hr_image = Image.open(self.hr_image_paths[idx]).convert('RGB')

        if self.transform:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        return lr_image, hr_image

lr_image_path = '/content/drive/MyDrive/AstroSR-main/dataset/SDSS_train_LR/**/*'
hr_image_path = '/content/drive/MyDrive/AstroSR-main/dataset/HSC_train_HR/**/*'

lr_images = sorted(glob(lr_image_path, recursive=True))
hr_images = sorted(glob(hr_image_path, recursive=True))

if len(lr_images) == 0 or len(hr_images) == 0:
    raise ValueError("No images found. Check the image paths.")

transform = transforms.Compose([transforms.ToTensor()])
dataset = GalaxyDataset(lr_images, hr_images, transform)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)



test_lr_image_path = '/content/drive/MyDrive/AstroSR-main/dataset/SDSS_test_LR/**/*'
test_hr_image_path = '/content/drive/MyDrive/AstroSR-main/dataset/HSC_test_HR/**/*'

test_lr_images = sorted(glob(test_lr_image_path, recursive=True))
test_hr_images = sorted(glob(test_hr_image_path, recursive=True))
test_dataset = GalaxyDataset(test_lr_images, test_hr_images, transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [17]:
import math
from skimage.metrics import structural_similarity as ssim

def calculate_psnr(hr_image, sr_image):
    mse = np.mean((hr_image - sr_image) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))

def calculate_ssim(hr_image, sr_image):
    return ssim(hr_image, sr_image, multichannel=True)

In [26]:
import os
from glob import glob
from PIL import Image
from torch.utils.data import Dataset

class GalaxyDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_images = sorted(glob(os.path.join(lr_dir, '*')))
        self.hr_images = sorted(glob(os.path.join(hr_dir, '*')))

        if len(self.lr_images) != len(self.hr_images):
            raise ValueError("Mismatch: The number of low-res and high-res images must be the same.")

        self.transform = transform

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        lr_image = Image.open(self.lr_images[idx])
        hr_image = Image.open(self.hr_images[idx])

        if self.transform:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        return lr_image, hr_image


lr_test_dir = '/content/drive/MyDrive/AstroSR-main/dataset/SDSS_test_LR/X2'
hr_test_dir = '/content/drive/MyDrive/AstroSR-main/dataset/HSC_test_HR'
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = GalaxyDataset(lr_test_dir, hr_test_dir, transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


In [37]:
def calculate_ssim(hr_image, sr_image):
    sr_image_upsampled = F.interpolate(torch.from_numpy(sr_image), size=hr_image.shape[2:], mode='bilinear', align_corners=False)

    sr_image_upsampled_np = sr_image_upsampled.numpy()

    batch_size = hr_image.shape[0]
    ssim_values = []

    for i in range(batch_size):
        hr_img = hr_image[i].transpose(1, 2, 0)
        sr_img = sr_image_upsampled_np[i].transpose(1, 2, 0)

        ssim_value = ssim(hr_img, sr_img, win_size=3, channel_axis=-1)
        ssim_values.append(ssim_value)

    return np.mean(ssim_values)


In [None]:
model.eval()
test_psnr = []
test_ssim = []

with torch.no_grad():
    for batch in test_loader:
        inputs, targets = batch
        outputs = model(inputs)

        targets_np = targets.numpy()
        outputs_np = outputs.numpy()

        psnr_value = calculate_psnr(targets_np, outputs_np)
        ssim_value = calculate_ssim(targets_np, outputs_np)

        test_psnr.append(psnr_value)
        test_ssim.append(ssim_value)

print(f"Average PSNR: {np.mean(test_psnr)}")
print(f"Average SSIM: {np.mean(test_ssim)}")
