# 1: Import required libraries

In [1]:
import itertools
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from tqdm.notebook import tqdm

# 2: Define the Generator Network

In [2]:
class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN branch added onto an identity skip connection.

    Both convolutions are 3x3 with padding 1 and map ``channels`` -> ``channels``,
    so the output has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        # Attribute names become state_dict keys; keep them stable so saved
        # checkpoints keep loading.
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return branch + x

class Generator(nn.Module):
    """4x super-resolution generator with an SRGAN-style layout.

    Pipeline: 9x9 conv + PReLU -> ``num_residual_blocks`` ResidualBlocks ->
    3x3 conv + BatchNorm, added back onto the first conv's output (long skip) ->
    two [3x3 conv to 256 channels, PixelShuffle(2), PReLU] stages (4x upscaling
    in total) -> 9x9 conv + Tanh, so outputs lie in [-1, 1].

    Expects 3-channel input, e.g. (N, 3, H, W), and returns (N, 3, 4H, 4W).
    Submodule attribute names and ``nn.Sequential`` orderings define the
    state_dict keys, so changing them breaks existing checkpoints.
    """

    def __init__(self, num_residual_blocks=16):
        super(Generator, self).__init__()

        # First conv layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )

        # Residual trunk; with 0 blocks the empty nn.Sequential is an identity.
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_residual_blocks)]
        )

        # Second conv layer (its output is summed with conv1's output in forward)
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )

        # Upsampling: each conv expands to 256 = 64 * 2**2 channels, which
        # PixelShuffle(2) rearranges into 64 channels at twice the resolution.
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )

        # Final output layer; Tanh matches inputs normalized to [-1, 1]
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=9, padding=4),
            nn.Tanh()
        )

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)  # long skip around the residual trunk
        out = self.upsampling(out)
        out = self.conv3(out)
        return out

    def save_checkpoint(self, epoch, optimizer, loss, filename="generator_checkpoint.pth"):
        """Write model/optimizer state plus ``epoch`` and ``loss`` to ``filename``."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }
        torch.save(checkpoint, filename)
        print(f"Checkpoint saved: {filename}")

    def load_checkpoint(self, filename="generator_checkpoint.pth", optimizer=None):
        """Restore weights (and optionally optimizer state) from ``filename``.

        Args:
            filename: checkpoint written by ``save_checkpoint``.
            optimizer: if given, its state is restored from the checkpoint too.

        Returns:
            ``(epoch, loss)`` stored in the checkpoint, or ``(0, None)`` when
            ``filename`` does not exist.
        """
        if not os.path.exists(filename):
            print(f"No checkpoint found at {filename}")
            return 0, None

        # Map tensors onto this model's current device so a checkpoint saved on
        # GPU can be restored on a CPU-only machine (and vice versa).
        # torch.load unpickles the file: only load checkpoints you trust.
        device = next(self.parameters()).device
        checkpoint = torch.load(filename, map_location=device)
        self.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Checkpoint loaded: {filename}")
        return epoch, loss

# 3: Define the Discriminator Network

In [3]:
class Discriminator(nn.Module):
    """Real/fake image classifier returning one probability per image.

    Six 3x3 conv stages (stride 2 on every second one, BatchNorm on all but the
    first) widen 3 -> 64 -> 128 -> 256 channels. Global average pooling and two
    1x1 convs (256 -> 1024 -> 1) then reduce each image to a single value that
    Sigmoid maps into (0, 1).

    The layer order inside ``self.net`` fixes the state_dict keys
    (``net.<index>.*``), so it is built in exactly this sequence.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) for each conv + BN + LeakyReLU stage
        bn_stages = (
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
        )

        layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
        for in_ch, out_ch, stride in bn_stages:
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ])
        layers.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(256, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
            nn.Sigmoid(),
        ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # For 4-D input the net yields (N, 1, 1, 1); flatten to (N,) for BCELoss.
        return self.net(x).view(-1)

# 4: Custom Dataset Class

In [9]:
class OldPhotoDataset(Dataset):
    """(low-res, high-res) image pairs synthesized on the fly from a folder of HR images.

    Each image is resized to ``hr_size`` x ``hr_size`` (aspect ratio is not
    preserved), then bicubically downscaled by 4 to create the LR input, which
    matches the Generator's 4x upsampling.

    Args:
        hr_dir: directory containing the high-resolution images.
        hr_size: side length of the square HR output.
        transform: optional callable applied to both the LR and HR PIL images.
    """

    # Lower-case suffixes recognized as images when scanning ``hr_dir``.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')

    def __init__(self, hr_dir, hr_size=256, transform=None):
        self.hr_dir = hr_dir
        self.hr_size = hr_size
        self.transform = transform
        # Keep only regular image files and sort them: os.listdir order is
        # arbitrary and may include sub-directories or stray non-image files.
        self.image_files = sorted(
            name for name in os.listdir(hr_dir)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(hr_dir, name))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.hr_dir, self.image_files[idx])
        # convert('RGB') guarantees 3 channels for grayscale / RGBA / palette
        # files; the context manager closes the file handle promptly.
        with Image.open(img_path) as img:
            hr_image = img.convert('RGB')

        # Resize HR image to fixed size
        hr_image = hr_image.resize((self.hr_size, self.hr_size), Image.BICUBIC)

        # Create low-res version (4x smaller per side)
        lr_image = hr_image.resize((self.hr_size // 4, self.hr_size // 4), Image.BICUBIC)

        if self.transform:
            hr_image = self.transform(hr_image)
            lr_image = self.transform(lr_image)

        return lr_image, hr_image

# 5: Training Function

In [10]:
def train_model(generator, discriminator, train_loader, num_epochs, device, checkpoint_interval=5,
                resume_path="generator_checkpoint.pth"):
    """Adversarially train ``generator`` against ``discriminator``.

    Each step first updates the discriminator with BCE on real HR images vs.
    detached generator outputs. It then updates the generator with an MSE
    content loss plus a 0.001-weighted adversarial BCE loss.

    Args:
        generator: model mapping LR batches to SR batches. It must provide
            ``load_checkpoint(filename, optimizer=...)`` and
            ``save_checkpoint(epoch, optimizer, loss, filename)``.
        discriminator: model returning one probability per image as a flat tensor.
        train_loader: sized iterable of ``(lr_images, hr_images)`` batches.
        num_epochs: total number of epochs; training resumes from the epoch
            stored in ``resume_path`` if that file exists.
        device: device the batches and labels are moved to.
        checkpoint_interval: save a checkpoint every this many epochs.
        resume_path: checkpoint loaded at start-up and overwritten at every
            checkpoint, so an interrupted run can continue from it.
    """
    criterion_GAN = nn.BCELoss()
    criterion_content = nn.MSELoss()

    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

    # Either model may have been left in eval mode (e.g. by an evaluation pass);
    # BatchNorm must use batch statistics while training.
    generator.train()
    discriminator.train()

    # Resume from the same file that is rewritten at every checkpoint below.
    # NOTE(review): the discriminator and optimizer_D are not checkpointed, so
    # a resumed run restarts the discriminator from its current weights.
    start_epoch, _ = generator.load_checkpoint(resume_path, optimizer=optimizer_G)
    num_batches = len(train_loader)

    for epoch in range(start_epoch, num_epochs):
        running_g_loss = 0.0
        for i, (lr_images, hr_images) in enumerate(tqdm(train_loader)):
            batch_size = lr_images.size(0)
            real_label = torch.ones(batch_size, device=device)
            fake_label = torch.zeros(batch_size, device=device)

            lr_images = lr_images.to(device)
            hr_images = hr_images.to(device)

            # Train Discriminator (detach so no gradient reaches the generator)
            optimizer_D.zero_grad()
            sr_images = generator(lr_images)

            real_output = discriminator(hr_images)
            fake_output = discriminator(sr_images.detach())

            d_loss_real = criterion_GAN(real_output, real_label)
            d_loss_fake = criterion_GAN(fake_output, fake_label)
            d_loss = d_loss_real + d_loss_fake

            d_loss.backward()
            optimizer_D.step()

            # Train Generator: try to make the discriminator call SR images "real".
            # Gradients also land in D's params here; optimizer_D.zero_grad()
            # clears them at the start of the next step.
            optimizer_G.zero_grad()

            fake_output = discriminator(sr_images)
            content_loss = criterion_content(sr_images, hr_images)
            adversarial_loss = criterion_GAN(fake_output, real_label)

            g_loss = content_loss + 0.001 * adversarial_loss
            running_g_loss += g_loss.item()

            g_loss.backward()
            optimizer_G.step()

            if i % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(train_loader)}], '
                      f'D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}')

        # Save checkpoint at interval: a per-epoch snapshot plus the resume file.
        if (epoch + 1) % checkpoint_interval == 0:
            avg_g_loss = running_g_loss / max(num_batches, 1)
            generator.save_checkpoint(
                epoch + 1,
                optimizer_G,
                avg_g_loss,
                f"generator_checkpoint_epoch_{epoch+1}.pth"
            )
            generator.save_checkpoint(epoch + 1, optimizer_G, avg_g_loss, resume_path)

# 6: Setup and Training

In [12]:
# Initialize device, data and models
SEED = 42
HR_DIR = '/kaggle/input/photos/DIV2K_train_HR'  # Kaggle dataset mount
HR_SIZE = 256     # side length of the square HR crops; LR inputs are HR_SIZE // 4
BATCH_SIZE = 4

torch.manual_seed(SEED)  # reproducible weight init and DataLoader shuffling

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], the same
# range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Create the dataset and dataloader
dataset = OldPhotoDataset(
    hr_dir=HR_DIR,
    hr_size=HR_SIZE,
    transform=transform
)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# Create the models here so later cells do not depend on variables left over
# from deleted cells (keeps Restart & Run All working).
generator = Generator().to(device)
discriminator = Discriminator().to(device)

Using device: cuda


# 7: Evaluation Function

In [13]:
def evaluate_model(generator, val_loader, device, num_batches=6, output_dir="."):
    """Write super-resolved and ground-truth images for the first few batches.

    For batch ``i`` this saves ``super_resolved_image_{i}.png`` and
    ``high_res_image_{i}.png`` (one grid per batch) into ``output_dir``.
    Inputs and outputs are assumed to lie in [-1, 1]: the dataset's
    Normalize(0.5, 0.5) and the generator's Tanh both produce that range. They
    are mapped back to [0, 1] before saving, which is the range ``save_image``
    expects.

    Args:
        generator: model mapping LR batches to SR batches.
        val_loader: iterable of ``(lr_img, hr_img)`` batches.
        device: device the generator lives on.
        num_batches: how many batches to save (default 6).
        output_dir: directory for the PNG files.

    The generator's train/eval mode is restored afterwards.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # islice stops before pulling an extra batch from the loader.
            batches = itertools.islice(val_loader, max(num_batches, 0))
            for i, (lr_img, hr_img) in enumerate(batches):
                sr_img = generator(lr_img.to(device))

                # [-1, 1] -> [0, 1] so colours are not clipped by save_image
                save_image((sr_img + 1) / 2,
                           os.path.join(output_dir, f'super_resolved_image_{i}.png'))
                save_image((hr_img + 1) / 2,
                           os.path.join(output_dir, f'high_res_image_{i}.png'))
    finally:
        generator.train(was_training)
    print("Evaluation complete. Check the saved images.")

# Evaluate after training.
# NOTE(review): no held-out split exists, so this reuses the shuffled training
# loader; the saved images are a sanity check, not a validation result.
evaluate_model(generator=generator, val_loader=train_loader, device=device)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a40c72c5bd0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionErrorException ignored in: : <function _MultiProcessingDataLoaderIter.__del__ at 0x7a40c72c5bd0>can only test a child process

Traceback (most recent call last):
Exception ignored in:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    <function _MultiProcessingDataLoaderIter.__del__ at 0x7a40c72c5bd0>
self._shutdown_workers()Traceback (most recent call last):

  File "/usr/local/lib/pyt

Evaluation complete. Check the saved images.
