In [23]:
#1
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return residual + out

class Generator(nn.Module):
    def __init__(self, num_res_blocks=8, num_channels=64):
        super(Generator, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(3, num_channels, kernel_size=9, stride=1, padding=4),
            nn.PReLU()
        )
        self.res_blocks = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_res_blocks)])
        self.middle = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_channels)
        )
        self.upsample_blocks = nn.Sequential(
            nn.Conv2d(num_channels, num_channels * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(num_channels, num_channels * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )
        self.final = nn.Conv2d(num_channels, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        x = self.initial(x)
        residual = x
        x = self.res_blocks(x)
        x = self.middle(x) + residual
        x = self.upsample_blocks(x)
        x = self.final(x)
        return x

class DiscriminatorGlobal(nn.Module):
    def __init__(self):
        super(DiscriminatorGlobal, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        features = [64, 128, 256, 512]
        layers = []
        for i in range(1, len(features)):
            layers.append(nn.Conv2d(features[i-1], features[i], kernel_size=3, stride=1, padding=1))
            layers.append(nn.BatchNorm2d(features[i]))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(features[i], features[i], kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(features[i]))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.middle = nn.Sequential(*layers)
        self.final = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.middle(x)
        x = self.final(x)
        return x

class DiscriminatorLocal(nn.Module):
    def __init__(self):
        super(DiscriminatorLocal, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.final = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.final(x)
        return x


In [24]:
#2
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cpu


In [25]:

!nvidia-smi

Tue May  7 18:37:51 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 512.74       Driver Version: 512.74       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0 Off |                  N/A |
| N/A   32C    P0    14W /  N/A |      0MiB /  4096MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [26]:
#3
from datasets import load_dataset

div2k_dataset = load_dataset("eugenesiow/Div2k")
print(div2k_dataset)

train_data = div2k_dataset["train"]
validation_data = div2k_dataset["validation"]

for sample in train_data[:5]:
    print(sample)


DatasetDict({
    train: Dataset({
        features: ['lr', 'hr'],
        num_rows: 800
    })
    validation: Dataset({
        features: ['lr', 'hr'],
        num_rows: 100
    })
})
lr
hr


In [27]:
# 4
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
from datasets import load_dataset

class Div2kDataset(Dataset):
    def __init__(self, div2k_data, transform=None, target_size=(256, 256), patch_size=24):
        self.div2k_data = div2k_data
        self.transform = transform
        self.target_size = target_size
        self.patch_size = patch_size

    def __len__(self):
        return len(self.div2k_data)

    def __getitem__(self, idx):
        lr_image_path = self.div2k_data[idx]['lr']
        hr_image_path = self.div2k_data[idx]['hr']
        
        # Load images
        lr_image = Image.open(lr_image_path).convert('RGB')
        hr_image = Image.open(hr_image_path).convert('RGB')
        
        # Resize images to target size
        lr_image = lr_image.resize(self.target_size, Image.BICUBIC)
        hr_image = hr_image.resize(self.target_size, Image.BICUBIC)
        
        # Apply transforms if any
        if self.transform:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        # Extract patches for local discriminator
        patches = self.extract_patches(hr_image)

        return lr_image, hr_image, patches

    def extract_patches(self, image):
        patches = []
        _, height, width = image.size()
        step = self.patch_size
        for i in range(0, height - step + 1, step):
            for j in range(0, width - step + 1, step):
                patch = image[:, i:i+step, j:j+step]
                patches.append(patch)
        patches = torch.stack(patches)  # Stack patches into a single tensor
        return patches

# Load dataset
div2k_dataset = load_dataset("eugenesiow/Div2k", split='train[:200]')

# Transformations
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to tensor
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize images
])

# Create dataset
train_dataset = Div2kDataset(div2k_dataset, transform=transform, target_size=(96, 96), patch_size=24)
BATCH_SIZE = 1 # Adjusted batch size for better training dynamics
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [28]:
len(train_loader)

200

In [29]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
from datasets import load_dataset

class Div2kDataset(Dataset):
    def __init__(self, div2k_data, lr_transform=None, hr_transform=None, patch_size=24):
        self.div2k_data = div2k_data
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform
        self.patch_size = patch_size

    def __len__(self):
        return len(self.div2k_data)

    def __getitem__(self, idx):
        lr_image_path = self.div2k_data[idx]['lr']
        hr_image_path = self.div2k_data[idx]['hr']

        lr_image = Image.open(lr_image_path).convert('RGB')
        hr_image = Image.open(hr_image_path).convert('RGB')

        if self.lr_transform:
            lr_image = self.lr_transform(lr_image)
        if self.hr_transform:
            hr_image = self.hr_transform(hr_image)

        patches = self.extract_patches(hr_image)

        return lr_image, hr_image, patches

    def extract_patches(self, image):
        patches = []
        c, height, width = image.shape
        step = self.patch_size
        for i in range(0, height - step + 1, step):
            for j in range(0, width - step + 1, step):
                patch = image[:, i:i+step, j:j+step]
                patches.append(patch)
        return torch.stack(patches)

# Load dataset
div2k_dataset = load_dataset("eugenesiow/Div2k", split='train[:200]')

# Define transformations
lr_transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

hr_transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create dataset instance with transformations
train_dataset = Div2kDataset(div2k_dataset, lr_transform=lr_transform, hr_transform=hr_transform, patch_size=24)
BATCH_SIZE = 1
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Checking dimensions
lr_example, hr_example, patches = next(iter(train_loader))
print("LR image dimensions:", lr_example.shape)
print("HR image dimensions:", hr_example.shape)
print("Patch dimensions:", patches.shape[1:], "Number of patches:", patches.shape[0])


LR image dimensions: torch.Size([1, 3, 96, 96])
HR image dimensions: torch.Size([1, 3, 384, 384])
Patch dimensions: torch.Size([256, 3, 24, 24]) Number of patches: 1


In [30]:
len(train_loader)

200

In [31]:
# 6
import torch
import torch.nn as nn
import torchvision.models as models

class VGGFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False):
        super(VGGFeatureExtractor, self).__init__()
        # Load pretrained VGG19 model
        vgg = models.vgg19(pretrained=True)
        # Use features up to the selected layer to act as our feature extractor
        self.feature_extractor = nn.Sequential(*list(vgg.features.children())[:feature_layer + 1])
        self.use_bn = use_bn
        # If batch normalization is used, initialize mean and std
        if use_bn:
            self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def forward(self, x):
        # Normalize the image batch using the same normalization used in training VGG if batch normalization is enabled
        if self.use_bn:
            x = (x - self.mean.to(x.device)) / self.std.to(x.device)
        # Extract features
        return self.feature_extractor(x)

# Ensure that all components are sent to the same device as the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = VGGFeatureExtractor(feature_layer=34, use_bn=True).to(device)
feature_extractor.eval()  # Set to eval mode to freeze batch norm layers and dropout layers




VGGFeatureExtractor(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models
generator = Generator().to(device)
discriminator_global = DiscriminatorGlobal().to(device)
discriminator_local = DiscriminatorLocal().to(device)
feature_extractor = VGGFeatureExtractor(feature_layer=34, use_bn=False).to(device)
feature_extractor.eval()  # Ensure the feature extractor is not in training mode

# Define loss functions
criterion_bce = nn.BCEWithLogitsLoss()
criterion_pixel = nn.L1Loss()
criterion_feature = nn.MSELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_global = optim.Adam(discriminator_global.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_local = optim.Adam(discriminator_local.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training parameters
num_epochs = 100
adversarial_loss_weight = 1e-3
pixel_loss_weight = 1
feature_loss_weight = 1e-3
log_interval = 100
save_interval = 2
checkpoint_path = r'C:\Users\juver\Untitled Folder 1\model_checkpoint3.pth'

def extract_patches(images, patch_size=24):
    patches = []
    batch_size, channels, height, width = images.size()
    for i in range(0, height - patch_size + 1, patch_size):
        for j in range(0, width - patch_size + 1, patch_size):
            patches.append(images[:, :, i:i+patch_size, j:j+patch_size])
    return torch.stack(patches, dim=0)

for epoch in range(num_epochs):
    generator.train()
    discriminator_global.train()
    discriminator_local.train()

    for batch_idx, (lr_images, hr_images, _) in enumerate(train_loader):
        lr_images, hr_images = lr_images.to(device), hr_images.to(device)
        
        # Generate high-resolution images from low-resolution images
        fake_hr_images = generator(lr_images)

        # Global Discriminator Update
        optimizer_D_global.zero_grad()
        real_output_global = discriminator_global(hr_images)
        fake_output_global = discriminator_global(fake_hr_images.detach())
        loss_D_global = (criterion_bce(real_output_global, torch.ones_like(real_output_global)) +
                         criterion_bce(fake_output_global, torch.zeros_like(fake_output_global))) / 2
        loss_D_global.backward()
        optimizer_D_global.step()

        # Local Discriminator Update
        hr_patches = extract_patches(hr_images)
        fake_hr_patches = extract_patches(fake_hr_images.detach())
        optimizer_D_local.zero_grad()
        loss_D_local = 0
        num_patches = hr_patches.shape[0]
        for i in range(num_patches):
            real_output_local = discriminator_local(hr_patches[i])
            fake_output_local = discriminator_local(fake_hr_patches[i])
            loss_D_local += (criterion_bce(real_output_local, torch.ones_like(real_output_local)) +
                             criterion_bce(fake_output_local, torch.zeros_like(fake_output_local))) / num_patches
        loss_D_local.backward()
        optimizer_D_local.step()

        # Generator Update
        optimizer_G.zero_grad()
        fake_output_global = discriminator_global(fake_hr_images)
        adversarial_loss = criterion_bce(fake_output_global, torch.ones_like(fake_output_global))
        pixel_loss = criterion_pixel(fake_hr_images, hr_images)

        # Compute feature loss using VGG
        real_features = feature_extractor(hr_images)
        fake_features = feature_extractor(fake_hr_images)
        feature_loss = criterion_feature(fake_features, real_features)

        total_loss = adversarial_loss_weight * adversarial_loss + pixel_loss_weight * pixel_loss + feature_loss_weight * feature_loss
        total_loss.backward()
        optimizer_G.step()

        if batch_idx % log_interval == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], "
                  f"Generator Loss: {total_loss.item():.4f}, Discriminator Global Loss: {loss_D_global.item():.4f}, "
                  f"Discriminator Local Loss: {loss_D_local.item():.4f}")

    # Save model checkpoint
    if (epoch + 1) % save_interval == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_global_state_dict': discriminator_global.state_dict(),
            'discriminator_local_state_dict': discriminator_local.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_global_state_dict': optimizer_D_global.state_dict(),
            'optimizer_D_local_state_dict': optimizer_D_local.state_dict(),
            'total_loss': total_loss
        }, checkpoint_path)


Epoch [1/100], Batch [1/200], Generator Loss: 0.5371, Discriminator Global Loss: 0.7200, Discriminator Local Loss: 1.4484
Epoch [1/100], Batch [101/200], Generator Loss: 0.1205, Discriminator Global Loss: 0.5893, Discriminator Local Loss: 1.3082
Epoch [2/100], Batch [1/200], Generator Loss: 0.1756, Discriminator Global Loss: 0.5034, Discriminator Local Loss: 1.2778
Epoch [2/100], Batch [101/200], Generator Loss: 0.1471, Discriminator Global Loss: 0.5034, Discriminator Local Loss: 1.2119
Epoch [3/100], Batch [1/200], Generator Loss: 0.0620, Discriminator Global Loss: 0.5032, Discriminator Local Loss: 1.1855
Epoch [3/100], Batch [101/200], Generator Loss: 0.0824, Discriminator Global Loss: 0.5032, Discriminator Local Loss: 1.1514
Epoch [4/100], Batch [1/200], Generator Loss: 0.0926, Discriminator Global Loss: 0.6410, Discriminator Local Loss: 1.0683
Epoch [4/100], Batch [101/200], Generator Loss: 0.0638, Discriminator Global Loss: 0.5032, Discriminator Local Loss: 1.0825
Epoch [5/100], B