### Model Definition

In [1]:
import torch
import torch.nn as nn

# GENERATOR
class UNetGenerator(nn.Module):
    """U-Net image-to-image generator with four down/up-sampling levels.

    The encoder halves the spatial resolution four times with 2x2 max pooling.
    The decoder upsamples back one level at a time. Before each decoder block it
    concatenates the matching encoder feature map (skip connection). The output
    goes through ``tanh``, so values lie in [-1, 1].

    Upsampling targets the exact spatial size of the skip tensor. This means H
    and W need not be multiples of 16, but they must be at least 16 so that the
    four poolings leave a non-empty map.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        num_filters: Width of the first encoder block; each level doubles it.
    """

    def __init__(self, in_channels=3, out_channels=3, num_filters=64):
        super(UNetGenerator, self).__init__()

        # Encoder: channel width doubles at each level.
        self.encoder1 = self._block(in_channels, num_filters)
        self.encoder2 = self._block(num_filters, num_filters * 2)
        self.encoder3 = self._block(num_filters * 2, num_filters * 4)
        self.encoder4 = self._block(num_filters * 4, num_filters * 8)

        # Bottleneck at the lowest resolution, using dilated convolutions.
        self.bottleneck = self._block(num_filters * 8, num_filters * 16, dilation=2)

        # Decoder: each input is the upsampled previous stage concatenated with the
        # encoder output of the same resolution, hence the summed channel counts.
        self.decoder1 = self._block(num_filters * 16 + num_filters * 8, num_filters * 8, dilation=2)
        self.decoder2 = self._block(num_filters * 8 + num_filters * 4, num_filters * 4)
        self.decoder3 = self._block(num_filters * 4 + num_filters * 2, num_filters * 2)
        self.decoder4 = self._block(num_filters * 2 + num_filters, num_filters)

        # 1x1 projection to the output channels.
        self.final = nn.Conv2d(num_filters, out_channels, kernel_size=1)

    def _block(self, in_channels, out_channels, dilation=1):
        """Return two (3x3 conv -> BatchNorm -> ReLU) stages.

        ``padding=dilation`` keeps H and W unchanged for a 3x3 kernel.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _up_and_concat(x, skip):
        """Nearest-upsample ``x`` to ``skip``'s spatial size, then concatenate on channels.

        For even sizes this equals the previous ``scale_factor=2`` nearest
        upsampling. For odd sizes (which pooling floors) it still lines up with
        the skip tensor instead of failing in ``torch.cat``.
        """
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="nearest")
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Map a (N, in_channels, H, W) batch to (N, out_channels, H, W) in [-1, 1]."""
        # Functional pooling: the previous code built new nn.MaxPool2d/nn.Upsample
        # modules on every call. These ops have no parameters, so the state_dict
        # layout (and checkpoint compatibility) is unchanged.
        e1 = self.encoder1(x)
        e2 = self.encoder2(nn.functional.max_pool2d(e1, 2))
        e3 = self.encoder3(nn.functional.max_pool2d(e2, 2))
        e4 = self.encoder4(nn.functional.max_pool2d(e3, 2))

        bottleneck = self.bottleneck(nn.functional.max_pool2d(e4, 2))

        d1 = self.decoder1(self._up_and_concat(bottleneck, e4))
        d2 = self.decoder2(self._up_and_concat(d1, e3))
        d3 = self.decoder3(self._up_and_concat(d2, e2))
        d4 = self.decoder4(self._up_and_concat(d3, e1))

        return torch.tanh(self.final(d4))
    
# DISCRIMINATOR
class Discriminator(nn.Module):
    """Convolutional critic that outputs a small grid of probabilities in [0, 1].

    It applies four strided 4x4 down-sampling stages: LeakyReLU(0.2) after each,
    and BatchNorm on all but the first. The fourth stage uses dilation 2. A final
    4x4 convolution to one channel follows, then a sigmoid.

    Args:
        in_channels: Channels of the image being judged.
        num_filters: Width of the first stage; each later stage doubles it.
    """

    def __init__(self, in_channels=3, num_filters=64):
        super(Discriminator, self).__init__()

        def down(c_in, c_out, with_norm, **conv_kwargs):
            # Conv -> [BatchNorm] -> LeakyReLU, kept in this order so the
            # nn.Sequential indices (and hence state_dict keys) stay stable.
            stage = [nn.Conv2d(c_in, c_out, kernel_size=4, **conv_kwargs)]
            if with_norm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        width = num_filters
        self.model = nn.Sequential(
            *down(in_channels, width, False, stride=2, padding=1),
            *down(width, width * 2, True, stride=2, padding=1),
            *down(width * 2, width * 4, True, stride=2, padding=1),
            *down(width * 4, width * 8, True, stride=2, padding=2, dilation=2),
            nn.Conv2d(width * 8, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images; returns shape (N, 1, h, w)."""
        scores = self.model(x)
        return scores

class PatchGANDiscriminator(nn.Module):
    """Patch-level critic producing a grid of per-patch probabilities in [0, 1].

    There are three stride-2 stages (the second uses dilation 2), then a stride-1
    stage, all 4x4 convolutions followed by LeakyReLU(0.2). All but the first
    use BatchNorm. A final stride-1 4x4 convolution maps to one channel, then a
    sigmoid.

    Args:
        in_channels: Channels of the image being judged.
        num_filters: Width of the first stage; each later stage doubles it.
    """

    def __init__(self, in_channels=3, num_filters=64):
        super(PatchGANDiscriminator, self).__init__()

        nf = num_filters
        # (in_ch, out_ch, stride, padding, dilation, batchnorm)
        stage_specs = [
            (in_channels, nf, 2, 1, 1, False),
            (nf, nf * 2, 2, 2, 2, True),
            (nf * 2, nf * 4, 2, 1, 1, True),
            (nf * 4, nf * 8, 1, 1, 1, True),  # no down-sampling in this stage
        ]

        layers = []
        for c_in, c_out, stride, pad, dil, use_bn in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=pad, dilation=dil))
            if use_bn:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # One output channel per spatial location: a matrix of patch scores.
        layers.append(nn.Conv2d(nf * 8, 1, kernel_size=4, stride=1, padding=1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch of images; returns shape (N, 1, h, w)."""
        patch_scores = self.model(x)
        return patch_scores

### Data Preparation

In [2]:
import h5py
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class ArtRestorationDataset(Dataset):
    """Paired (damaged, ground-truth) images read from an HDF5 file.

    The file is expected to hold two parallel datasets, ``"damaged_paintings"``
    and ``"initial"``, indexed by sample. The file is opened for each read and
    closed again, so the instance stores only the path, the transform and a
    cached length, never an open h5py handle.

    Args:
        hdf5_file (str): Path to the HDF5 file.
        transform (callable, optional): Transform applied to each PIL image.
            It is applied separately to the damaged and the ground-truth image,
            so random transforms would not be synchronised between the pair.
    """

    def __init__(self, hdf5_file, transform=None):
        self.hdf5_file = hdf5_file
        self.transform = transform
        # Sample count, read on the first __len__ call and cached. DataLoader and
        # its sampler call len() repeatedly, and each uncached call reopened the file.
        self._length = None

    def __len__(self):
        if self._length is None:
            with h5py.File(self.hdf5_file, "r") as hdf:
                self._length = len(hdf["damaged_paintings"])
        return self._length

    def __getitem__(self, idx):
        """Return the transformed ``(damaged_img, ground_truth_img)`` pair at ``idx``."""
        with h5py.File(self.hdf5_file, "r") as hdf:
            damaged_img = hdf["damaged_paintings"][idx]
            ground_truth_img = hdf["initial"][idx]

        # NOTE(review): ToPILImage treats numpy input as HWC; this assumes the
        # HDF5 arrays are stored channel-last (e.g. uint8 RGB) — confirm.
        to_pil = transforms.ToPILImage()
        damaged_img = to_pil(damaged_img)
        ground_truth_img = to_pil(ground_truth_img)

        if self.transform is not None:
            damaged_img = self.transform(damaged_img)
            ground_truth_img = self.transform(ground_truth_img)

        return damaged_img, ground_truth_img

In [None]:
def convert_bytes(size_bytes):
    """Format a byte count as a human-readable string using 1024-based units.

    The value is divided by 1024 until it drops below 1024 or the largest unit
    ("YB") is reached. Two decimal places are always shown, e.g. ``1536`` ->
    ``"1.50 KB"``.
    """
    *smaller_units, largest_unit = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

    value = size_bytes
    for unit in smaller_units:
        # Stop as soon as the value is not >= 1024 (NaN also stops here).
        if not value >= 1024:
            return f"{value:.2f} {unit}"
        value /= 1024
    return f"{value:.2f} {largest_unit}"

def estimate_vram_usage(batch_size=16, scale=128, channels=3, bytes=4):
    """Return a rough estimate, in bytes, of the device memory used for training.

    The estimate covers the three networks that ``training`` builds: a
    ``UNetGenerator``, a ``Discriminator`` as the global critic and a
    ``PatchGANDiscriminator`` as the local critic. For each it counts the
    weights, their gradients and Adam's two moment buffers. It adds the input
    batches and a heuristic activation term.

    Args:
        batch_size: Images per training batch.
        scale: Side length of the square training images.
        channels: Channels per image.
        bytes: Bytes per stored element (4 for float32). Note: shadows the
            builtin ``bytes`` inside this function; kept for compatibility.

    Returns:
        The estimated total in bytes.
    """
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters())

    total_params = (
        count_parameters(UNetGenerator())
        + count_parameters(Discriminator())            # global critic
        + count_parameters(PatchGANDiscriminator())    # local critic, as used in training
    )

    memory_params = total_params * bytes
    memory_grad = total_params * bytes
    memory_opt = 2 * total_params * bytes  # Adam keeps two moment buffers per parameter
    # NOTE(review): heuristic activation footprint carried over unchanged; the
    # origin of the 2_560 * 10 factor is not derivable from this code — document it.
    memory_act = batch_size * scale**2 * 2_560 * 10
    # Each training step moves two image batches to the device: damaged + ground truth.
    memory_data = 2 * batch_size * channels * scale**2 * bytes

    return memory_params + memory_act + memory_grad + memory_opt + memory_data

print(f"Total VRAM: {convert_bytes(estimate_vram_usage())}")

### Training Setup

In [4]:
from tqdm import tqdm
import numpy as np
import torch_directml
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

def _list_checkpoints(checkpoint_dir):
    """Return ``(epoch, filename)`` pairs for ``checkpoint_epoch_<N>.pth`` files, sorted by epoch."""
    prefix, suffix = "checkpoint_epoch_", ".pth"
    found = []
    for name in os.listdir(checkpoint_dir):
        if name.startswith(prefix) and name.endswith(suffix):
            epoch_str = name[len(prefix):-len(suffix)]
            if epoch_str.isdigit():  # skip malformed names instead of crashing in int()
                found.append((int(epoch_str), name))
    return sorted(found)


def _find_latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-epoch checkpoint in ``checkpoint_dir``, or None."""
    checkpoints = _list_checkpoints(checkpoint_dir)
    if not checkpoints:
        return None
    return os.path.join(checkpoint_dir, checkpoints[-1][1])


def _prune_old_checkpoints(checkpoint_dir, keep_path):
    """Delete every checkpoint file in ``checkpoint_dir`` except ``keep_path``."""
    keep_name = os.path.basename(keep_path)
    for _, name in _list_checkpoints(checkpoint_dir):
        if name == keep_name:
            continue
        try:
            os.remove(os.path.join(checkpoint_dir, name))
        except FileNotFoundError:
            pass  # already gone; nothing to prune


def _warmup_loss_flags(global_step, steps):
    """Return ``(use_rec, use_global, use_local)`` for the warm-up phase of ``global_step``.

    Each phase lasts ``steps`` iterations:
    0 = reconstruction only, 1 = global adversarial only,
    2 = local adversarial only, 3+ = all three losses.
    ``steps <= 0`` skips the warm-up and uses all losses from the start.
    """
    phase = 3 if steps <= 0 else min(global_step // steps, 3)
    return phase in (0, 3), phase in (1, 3), phase in (2, 3)


def _probe_output_shape(model, scale, device):
    """Return ``model``'s spatial output shape for one ``scale`` x ``scale`` RGB input.

    The probe runs in eval mode so it does not update BatchNorm running stats.
    """
    was_training = model.training
    model.eval()
    with torch.no_grad():
        shape = model(torch.zeros(1, 3, scale, scale, device=device)).shape[2:]
    model.train(was_training)
    return shape


def training(scale=128,
             batch_size=16,
             num_epochs=100,
             dataset_path="./dataset_2.hdf5",
             checkpoint_dir="./checkpoints",
             device=None,
             steps=50,
             gan_lr=0.0002,
             gan_betas=(0.5, 0.999),
             dg_lr=0.0002,
             dg_betas=(0.5, 0.999),
             dl_lr=0.0002,
             dl_betas=(0.5, 0.999),
             save_old_checkpoints=False,
             max_grad_norm=1.0,
             ):
    """Train the U-Net generator against a global and a local (PatchGAN) discriminator.

    Training resumes from the newest ``checkpoint_epoch_<N>.pth`` in
    ``checkpoint_dir`` if present, and saves a checkpoint after every epoch.

    Args:
        scale: Images are resized to ``scale`` x ``scale``.
        batch_size: Batch size of the DataLoader.
        num_epochs: Train until this (1-based) epoch number is reached.
        dataset_path: HDF5 file read by ``ArtRestorationDataset``.
        checkpoint_dir: Directory for checkpoints (created if missing).
        device: Torch device; ``None`` means ``torch_directml.device()``,
            resolved per call rather than once at definition time.
        steps: Length, in generator iterations counted across epochs, of each
            warm-up phase. The phases are: reconstruction only, then global
            adversarial only, then local adversarial only, then all three losses.
        gan_lr, gan_betas: Adam settings for the generator.
        dg_lr, dg_betas: Adam settings for the global discriminator.
        dl_lr, dl_betas: Adam settings for the local discriminator.
        save_old_checkpoints: If False, older checkpoints are deleted after each save.
        max_grad_norm: Gradient-norm clip applied to every network before its
            optimizer step; ``None`` disables clipping.
    """
    if device is None:
        device = torch_directml.device()

    adversarial_loss = nn.BCELoss().to(device)       # discriminators end in Sigmoid
    reconstruction_loss = nn.L1Loss().to(device)

    generator = UNetGenerator().to(device)
    global_discriminator = Discriminator().to(device)
    local_discriminator = PatchGANDiscriminator().to(device)

    def weights_init(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_normal_(m.weight)

    for model in (generator, global_discriminator, local_discriminator):
        model.apply(weights_init)

    # Spatial size of each discriminator's output grid; the real/fake targets must match it.
    dg_output_shape = _probe_output_shape(global_discriminator, scale, device)
    dl_output_shape = _probe_output_shape(local_discriminator, scale, device)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=gan_lr, betas=gan_betas)
    optimizer_D_global = torch.optim.Adam(global_discriminator.parameters(), lr=dg_lr, betas=dg_betas)
    optimizer_D_local = torch.optim.Adam(local_discriminator.parameters(), lr=dl_lr, betas=dl_betas)

    transform = transforms.Compose([
        transforms.Resize((scale, scale)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [0, 1] -> [-1, 1]
    ])
    dataset = ArtRestorationDataset(hdf5_file=dataset_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    os.makedirs(checkpoint_dir, exist_ok=True)

    start_epoch = 0
    latest_checkpoint = _find_latest_checkpoint(checkpoint_dir)
    if latest_checkpoint:
        print(f"Loading checkpoint: {latest_checkpoint}")
        # Load to CPU first; load_state_dict copies tensors onto each module's/optimizer's device.
        checkpoint = torch.load(latest_checkpoint, map_location="cpu")
        generator.load_state_dict(checkpoint["generator_state_dict"])
        global_discriminator.load_state_dict(checkpoint["global_discriminator_state_dict"])
        local_discriminator.load_state_dict(checkpoint["local_discriminator_state_dict"])
        optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
        optimizer_D_global.load_state_dict(checkpoint["optimizer_D_global_state_dict"])
        optimizer_D_local.load_state_dict(checkpoint["optimizer_D_local_state_dict"])
        start_epoch = checkpoint["epoch"]
        print(f"Resuming training from epoch {start_epoch}")

    def clip_gradients(model):
        # Must run after backward() and before step(): only then do gradients exist.
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

    batches_per_epoch = len(dataloader)
    for epoch in range(start_epoch, num_epochs):
        progress_bar = tqdm(enumerate(dataloader), total=batches_per_epoch, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for i, (damaged_imgs, ground_truth_imgs) in progress_bar:
            damaged_imgs = damaged_imgs.to(device)
            ground_truth_imgs = ground_truth_imgs.to(device)
            n = damaged_imgs.size(0)

            # Adversarial targets shaped like each discriminator's output grid.
            valid = torch.ones(n, 1, *dg_output_shape, device=device)
            fake = torch.zeros(n, 1, *dg_output_shape, device=device)
            valid_local = torch.ones(n, 1, *dl_output_shape, device=device)
            fake_local = torch.zeros(n, 1, *dl_output_shape, device=device)

            # ---- Discriminator update ----
            optimizer_D_global.zero_grad()
            optimizer_D_local.zero_grad()

            restored_fake = generator(damaged_imgs)
            fake_detached = restored_fake.detach()  # no generator gradients from D losses

            real_loss_global = adversarial_loss(global_discriminator(ground_truth_imgs), valid)
            fake_loss_global = adversarial_loss(global_discriminator(fake_detached), fake)
            d_loss_global = (real_loss_global + fake_loss_global) / 2

            real_loss_local = adversarial_loss(local_discriminator(ground_truth_imgs), valid_local)
            fake_loss_local = adversarial_loss(local_discriminator(fake_detached), fake_local)
            d_loss_local = (real_loss_local + fake_loss_local) / 2

            d_loss_global.backward()
            d_loss_local.backward()
            clip_gradients(global_discriminator)
            clip_gradients(local_discriminator)
            optimizer_D_global.step()
            optimizer_D_local.step()

            # ---- Generator update (warm-up curriculum over the global step) ----
            optimizer_G.zero_grad()
            global_step = epoch * batches_per_epoch + i
            use_rec, use_global, use_local = _warmup_loss_flags(global_step, steps)

            g_terms = []
            if use_global:
                g_terms.append(adversarial_loss(global_discriminator(restored_fake), valid))
            if use_local:
                g_terms.append(adversarial_loss(local_discriminator(restored_fake), valid_local))
            if use_rec:
                g_terms.append(reconstruction_loss(restored_fake, ground_truth_imgs))
            g_loss = sum(g_terms)  # every phase enables at least one term

            g_loss.backward()
            clip_gradients(generator)
            optimizer_G.step()

            progress_bar.set_postfix({
                "DG Loss": d_loss_global.item(),
                "DL Loss": d_loss_local.item(),
                "G  Loss": g_loss.item()
            })

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pth")
        torch.save({
            "epoch": epoch + 1,
            "generator_state_dict": generator.state_dict(),
            "global_discriminator_state_dict": global_discriminator.state_dict(),
            "local_discriminator_state_dict": local_discriminator.state_dict(),
            "optimizer_G_state_dict": optimizer_G.state_dict(),
            "optimizer_D_global_state_dict": optimizer_D_global.state_dict(),
            "optimizer_D_local_state_dict": optimizer_D_local.state_dict(),
        }, checkpoint_path)

        print(f"Checkpoint saved at {checkpoint_path}")

        if not save_old_checkpoints:
            _prune_old_checkpoints(checkpoint_dir, checkpoint_path)

### Training

In [5]:
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from PIL import Image

# Sweep over warm-up lengths; every trial trains into its own checkpoint directory.
# The names bound here (scale, num_epochs, dataset_scale, dataset_path,
# checkpoint_dir, device) are read by the evaluation cell below. That cell
# therefore inspects the last trial of the sweep.
for steps in [30, 40, 50]:
    trial          = f"{steps}-step-warm-up"
    scale          = 64
    batch_size     = 32
    gan_lr         = 0.0002
    gan_betas      = (0.5, 0.9)
    dl_lr          = 0.0002
    dl_betas       = (0.5, 0.9)
    dg_lr          = 0.0002
    dg_betas       = (0.5, 0.9)
    num_epochs     = 400
    dataset_scale  = scale
    dataset_path   = f"./dataset_{dataset_scale}.hdf5"
    checkpoint_dir = f"trial_scale_{scale}_{trial}/"
    device = torch_directml.device()

    # Pass the optimizer settings explicitly; without them training() silently
    # used its defaults (betas=(0.5, 0.999)) instead of the values set above.
    training(
        scale=scale,
        batch_size=batch_size,
        num_epochs=num_epochs,
        dataset_path=dataset_path,
        checkpoint_dir=checkpoint_dir,
        device=device,
        steps=steps,
        gan_lr=gan_lr,
        gan_betas=gan_betas,
        dg_lr=dg_lr,
        dg_betas=dg_betas,
        dl_lr=dl_lr,
        dl_betas=dl_betas,
        save_old_checkpoints=False
    )
    

In [None]:


# Load the trained generator from the final checkpoint of the last trial above.
# map_location="cpu" avoids depending on the device the checkpoint was saved from;
# load_state_dict copies the weights onto `device`.
checkpoint = torch.load(f"./{checkpoint_dir}checkpoint_epoch_{num_epochs}.pth", map_location="cpu")
generator = UNetGenerator().to(device)
generator.load_state_dict(checkpoint["generator_state_dict"])
generator.eval()  # use BatchNorm running statistics for inference


# Resize only when the stored images differ from the model's working resolution.
transform = transforms.Compose(
    ([transforms.Resize((scale, scale))] if dataset_scale != scale else [])
    + [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ]
)

dataset = ArtRestorationDataset(hdf5_file=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Pick a random sample. To compare runs on the same painting, set sample_index
# to a fixed value (e.g. 1495); the printed index always matches the sample shown.
random_index = np.random.randint(0, len(dataset))
sample_index = random_index
damaged_img, ground_truth_img = dataset[sample_index]

# Add a batch dimension and move to the model's device.
damaged_img = damaged_img.unsqueeze(0).to(device)

with torch.no_grad():
    restored_img = generator(damaged_img)


def _to_display(img):
    """Convert a (C, H, W) tensor normalised to [-1, 1] into an (H, W, C) array in [0, 1]."""
    return np.clip((img.cpu().numpy().transpose(1, 2, 0) + 1) / 2, 0, 1)


def _upscale(img, size=(512, 512)):
    """Bilinearly resize an (H, W, C) float image in [0, 1] to ``size`` via PIL."""
    as_uint8 = (img * 255).astype(np.uint8)
    return np.array(Image.fromarray(as_uint8).resize(size, Image.BILINEAR)) / 255


ground_truth_img = _to_display(ground_truth_img)
damaged_img = _to_display(damaged_img.squeeze(0))
restored_img = _to_display(restored_img.squeeze(0))

with_upscaling = True

damaged_img_upscaled = _upscale(damaged_img)
ground_truth_img_upscaled = _upscale(ground_truth_img)
restored_img_upscaled = _upscale(restored_img)

# Build the panel list first so the subplot grid always matches what is drawn
# (previously 5 columns were reserved but only 3 panels were filled).
panels = [
    (f"Damaged Image ({scale}x{scale})", damaged_img),
    (f"Restored Image ({scale}x{scale})", restored_img),
    (f"Ground Truth Image ({scale}x{scale})", ground_truth_img),
]
if with_upscaling:
    panels += [
        ("Damaged Image (512x512)", damaged_img_upscaled),
        ("Restored Image (512x512)", restored_img_upscaled),
        ("Ground Truth Image (512x512)", ground_truth_img_upscaled),
    ]

num_imgs = len(panels)
print(f"Model: {checkpoint_dir}checkpoint_epoch_{num_epochs} | Dataset: {dataset_path} | Sample Index: {sample_index}")
plt.figure(figsize=(3 * num_imgs, 5))

for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, num_imgs, position)
    plt.title(title)
    plt.imshow(image)
    plt.axis("off")

plt.show()