In [7]:
import os
from PIL import Image
from torchvision.transforms import ToTensor, Compose
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

def generate_lr_images(hr_dir, lr_dir, scale_factor):
    """Write a bicubic-downscaled copy of every image file in ``hr_dir`` to ``lr_dir``.

    Each output keeps the file name of its source image, so the two
    directories can be paired by name afterwards.

    Args:
        hr_dir: Directory containing the high-resolution source images.
        lr_dir: Destination directory; created if it does not exist.
        scale_factor: Integer divisor applied to both width and height.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists().
    os.makedirs(lr_dir, exist_ok=True)

    for filename in sorted(os.listdir(hr_dir)):
        hr_image_path = os.path.join(hr_dir, filename)
        # os.listdir also returns sub-directories; Image.open would fail on them.
        if not os.path.isfile(hr_image_path):
            continue
        lr_image_path = os.path.join(lr_dir, filename)

        # The context manager releases the source file handle deterministically.
        with Image.open(hr_image_path) as hr_image:
            # Clamp to 1 px so images smaller than scale_factor do not yield
            # a zero-sized (invalid) target.
            lr_size = (
                max(1, hr_image.width // scale_factor),
                max(1, hr_image.height // scale_factor),
            )
            lr_image = hr_image.resize(lr_size, Image.BICUBIC)
        lr_image.save(lr_image_path)

In [29]:
class DIV2KDataset(Dataset):
    """Paired low-/high-resolution luminance images read from two directories.

    The regular files of ``hr_dir`` and ``lr_dir`` are sorted by name and paired
    by position. Every sample is the Y channel of the YCbCr-converted image,
    resized to ``target_size``. Both the LR and the HR image are resized to the
    same ``target_size``, so the two items of a pair have identical spatial
    dimensions.

    Args:
        hr_dir: Directory with the high-resolution images.
        lr_dir: Directory with the matching low-resolution images.
        transform: Optional callable applied to both Y-channel images.
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Raises:
        ValueError: If the two directories hold a different number of files,
            since index-based pairing would then be wrong.
    """

    def __init__(self, hr_dir, lr_dir, transform=None, target_size=(256, 256)):
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.transform = transform
        self.target_size = target_size
        self.hr_images = self._list_files(hr_dir)
        self.lr_images = self._list_files(lr_dir)
        if len(self.hr_images) != len(self.lr_images):
            raise ValueError(
                f'HR/LR file count mismatch: {len(self.hr_images)} in {hr_dir!r} '
                f'vs {len(self.lr_images)} in {lr_dir!r}'
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files directly inside ``directory``."""
        # Sub-directories (e.g. a nested split folder) must not be paired as images.
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def _load_luma(self, path):
        """Open ``path``, extract the Y channel and resize it to ``target_size``."""
        # The context manager closes the underlying file once the data is converted.
        with Image.open(path) as image:
            luma = image.convert('YCbCr').split()[0]
        return luma.resize(self.target_size, Image.BICUBIC)

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        """Return the ``(y_lr, y_hr)`` pair at position ``idx``."""
        y_hr = self._load_luma(os.path.join(self.hr_dir, self.hr_images[idx]))
        y_lr = self._load_luma(os.path.join(self.lr_dir, self.lr_images[idx]))

        if self.transform:
            y_hr = self.transform(y_hr)
            y_lr = self.transform(y_lr)

        return y_lr, y_hr

# Shared preprocessing pipeline: tensor conversion is the only step.
transform = Compose([ToTensor()])

In [30]:
hr_dir = './DIV2K_train_HR/'

# Directories to store low-resolution images.
lr_train_dir = './low_resolution_train/'
lr_val_dir = './low_resolution_valid/'
# The test split gets its own top-level directory. Nesting it inside
# lr_train_dir would make it show up when the training LR directory is listed.
# Re-run the LR generation cell after changing this path.
lr_test_dir = './low_resolution_test/'

# Split high-resolution images into training (70%), validation (10%) and
# testing (remainder) sets. Only regular files are counted, so the
# train/val/test sub-directories created below are not mistaken for images
# when this cell is executed again.
hr_images = sorted(
    name for name in os.listdir(hr_dir)
    if os.path.isfile(os.path.join(hr_dir, name))
)
train_split = int(0.7 * len(hr_images))
val_split = int(0.1 * len(hr_images))

hr_train_dir = os.path.join(hr_dir, 'train')
hr_val_dir = os.path.join(hr_dir, 'val')
hr_test_dir = os.path.join(hr_dir, 'test')

os.makedirs(hr_train_dir, exist_ok=True)
os.makedirs(hr_val_dir, exist_ok=True)
os.makedirs(hr_test_dir, exist_ok=True)

In [4]:
# Move each high-resolution image into its split directory.
for i, img_name in enumerate(hr_images):
    src_path = os.path.join(hr_dir, img_name)
    # Skip entries that were already moved or that are directories, so that
    # re-running this cell does not raise from os.rename.
    if not os.path.isfile(src_path):
        continue

    if i < train_split:
        dest_dir = hr_train_dir
    elif i < train_split + val_split:
        dest_dir = hr_val_dir
    else:
        dest_dir = hr_test_dir
    os.rename(src_path, os.path.join(dest_dir, img_name))

# Generate low-resolution images for training, validation, and testing sets
scale_factor = 2
generate_lr_images(hr_train_dir, lr_train_dir, scale_factor)
generate_lr_images(hr_val_dir, lr_val_dir, scale_factor)
generate_lr_images(hr_test_dir, lr_test_dir, scale_factor)

In [31]:
# Build one dataset per split by pairing each HR directory with its LR counterpart.
train_dataset, val_dataset, test_dataset = [
    DIV2KDataset(hr_split_dir, lr_split_dir, transform=transform)
    for hr_split_dir, lr_split_dir in (
        (hr_train_dir, lr_train_dir),
        (hr_val_dir, lr_val_dir),
        (hr_test_dir, lr_test_dir),
    )
]

# Create data loaders; only the training split is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=16)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=16)

(tensor([[[0.4275, 0.4275, 0.4275,  ..., 0.2039, 0.2039, 0.2039],
         [0.4314, 0.4275, 0.4275,  ..., 0.2078, 0.2078, 0.2078],
         [0.4275, 0.4275, 0.4314,  ..., 0.2078, 0.2078, 0.2118],
         ...,
         [0.0784, 0.0784, 0.0824,  ..., 0.0667, 0.0706, 0.0745],
         [0.0824, 0.0784, 0.0784,  ..., 0.0627, 0.0745, 0.0667],
         [0.0784, 0.0824, 0.0745,  ..., 0.0588, 0.0706, 0.0706]]]), tensor([[[0.4275, 0.4275, 0.4275,  ..., 0.2078, 0.2039, 0.2039],
         [0.4314, 0.4275, 0.4275,  ..., 0.2078, 0.2078, 0.2078],
         [0.4275, 0.4275, 0.4314,  ..., 0.2078, 0.2078, 0.2078],
         ...,
         [0.0784, 0.0745, 0.0824,  ..., 0.0667, 0.0745, 0.0745],
         [0.0824, 0.0824, 0.0784,  ..., 0.0627, 0.0706, 0.0667],
         [0.0784, 0.0824, 0.0745,  ..., 0.0549, 0.0706, 0.0706]]]))
tensor([[[[0.5922, 0.6196, 0.5882,  ..., 0.5216, 0.4980, 0.5608],
          [0.6431, 0.7098, 0.6824,  ..., 0.5451, 0.4784, 0.4745],
          [0.6588, 0.6353, 0.6000,  ..., 0.5529, 0.54

KeyboardInterrupt: 

In [39]:
class ESPCN(nn.Module):
    """Single-channel super-resolution network.

    Three convolutions (1 -> 64 -> 32 -> ``upscale_factor ** 2`` channels) are
    followed by a ``PixelShuffle`` rearrangement and an adaptive average pool
    that forces the output to ``target_size`` regardless of the input size.

    Args:
        upscale_factor: Factor used by the pixel-shuffle stage.
        target_size: ``(height, width)`` of the pooled output.
    """

    def __init__(self, upscale_factor, target_size=(256, 256)):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        # The last convolution emits upscale_factor**2 channels, which the
        # pixel shuffle folds back into a single, spatially larger channel.
        self.conv3 = nn.Conv2d(
            in_channels=32,
            out_channels=self.upscale_factor ** 2,
            kernel_size=3,
            padding=1,
        )
        self.pixel_shuffle = nn.PixelShuffle(self.upscale_factor)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((target_size[0], target_size[1]))

    def forward(self, x):
        """Run the network on ``x`` and return the pooled, upscaled tensor."""
        features = self.conv1(x).relu()
        features = self.conv2(features).relu()
        upscaled = self.pixel_shuffle(self.conv3(features))
        return self.adaptive_pool(upscaled)


In [44]:
# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
# torch.backends.mps is looked up defensively because older torch builds
# do not provide it.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

from tqdm import tqdm

upscale_factor = 2
model = ESPCN(upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop: one optimizer step per batch; the printed loss is the mean
# of the per-batch MSE values of the epoch.
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for lr, hr in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
        lr, hr = lr.to(device), hr.to(device)
        optimizer.zero_grad()
        output = model(lr)
        loss = criterion(output, hr)
        loss.backward()
        optimizer.step()
        # .item() detaches the scalar so the graph is not kept alive across batches.
        epoch_loss += loss.item()
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(train_loader):.4f}')


Epoch 1/100: 100%|███████████████████████████| 35/35 [01:46<00:00,  3.04s/batch]


Epoch 1/100, Loss: 0.0362


Epoch 2/100: 100%|███████████████████████████| 35/35 [01:46<00:00,  3.04s/batch]


Epoch 2/100, Loss: 0.0040


Epoch 3/100: 100%|███████████████████████████| 35/35 [01:46<00:00,  3.04s/batch]


Epoch 3/100, Loss: 0.0015


Epoch 4/100: 100%|███████████████████████████| 35/35 [01:46<00:00,  3.03s/batch]


Epoch 4/100, Loss: 0.0006


Epoch 5/100: 100%|███████████████████████████| 35/35 [01:46<00:00,  3.04s/batch]


Epoch 5/100, Loss: 0.0003


Epoch 6/100: 100%|███████████████████████████| 35/35 [01:46<00:00,  3.04s/batch]


Epoch 6/100, Loss: 0.0002


Epoch 7/100: 100%|███████████████████████████| 35/35 [01:46<00:00,  3.04s/batch]


Epoch 7/100, Loss: 0.0001


Epoch 8/100: 100%|███████████████████████████| 35/35 [01:46<00:00,  3.04s/batch]


Epoch 8/100, Loss: 0.0002


Epoch 9/100: 100%|███████████████████████████| 35/35 [01:46<00:00,  3.04s/batch]


Epoch 9/100, Loss: 0.0001


Epoch 10/100: 100%|██████████████████████████| 35/35 [01:46<00:00,  3.04s/batch]


Epoch 10/100, Loss: 0.0001


Epoch 11/100: 100%|██████████████████████████| 35/35 [01:46<00:00,  3.03s/batch]


Epoch 11/100, Loss: 0.0001


Epoch 12/100: 100%|██████████████████████████| 35/35 [01:47<00:00,  3.07s/batch]


Epoch 12/100, Loss: 0.0001


Epoch 13/100:  97%|█████████████████████████▎| 34/35 [01:44<00:03,  3.09s/batch]


KeyboardInterrupt: 

In [45]:
# Mean per-batch loss over the validation split, computed without gradients.
model.eval()
val_loss = 0
with torch.no_grad():
    for lr, hr in tqdm(val_loader, desc='Validation', unit='batch'):
        lr = lr.to(device)
        hr = hr.to(device)
        output = model(lr)
        loss = criterion(output, hr)
        val_loss += loss.item()
print(f'Validation Loss: {val_loss/len(val_loader):.4f}')


Validation: 100%|██████████████████████████████| 5/5 [00:15<00:00,  3.00s/batch]

Validation Loss: 0.0001





In [46]:
# Mean per-batch loss over the held-out test split, computed without gradients.
model.eval()
test_loss = 0
with torch.no_grad():
    for lr, hr in tqdm(test_loader, desc='Testing', unit='batch'):
        lr = lr.to(device)
        hr = hr.to(device)
        output = model(lr)
        loss = criterion(output, hr)
        test_loss += loss.item()
print(f'Test Loss: {test_loss/len(test_loader):.4f}')

Testing: 100%|███████████████████████████████| 10/10 [00:30<00:00,  3.03s/batch]

Test Loss: 0.0001



