In [1]:
!pip install kagglehub



In [2]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("mariaherrerot/aptos2019")

print("Path to dataset files:", path)

Path to dataset files: /teamspace/studios/this_studio/.cache/kagglehub/datasets/mariaherrerot/aptos2019/versions/3


In [3]:
!pip install torchvision --upgrade
# # # or for conda users:
# # conda update torchvision -c pytorch
!pip install opencv-python



In [4]:
import os
import time
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DataParallel
import torchvision.transforms.functional as TF
from PIL import Image
import cv2

In [6]:
# =================================================================
# STEP 1: PATHS TO CONFIGURE (UPDATE THESE!)
# =================================================================
# CSV paths

TRAIN_CSV = "/teamspace/studios/this_studio/.cache/kagglehub/datasets/mariaherrerot/aptos2019/versions/3/train_1.csv"         # CSV with training image names (id_code column)
VAL_CSV = "/teamspace/studios/this_studio/.cache/kagglehub/datasets/mariaherrerot/aptos2019/versions/3/valid.csv"      # CSV with validation image names
TEST_CSV = "/teamspace/studios/this_studio/.cache/kagglehub/datasets/mariaherrerot/aptos2019/versions/3/test.csv"           # CSV with test image names

# Image directories
TRAIN_IMG_DIR = "/teamspace/studios/this_studio/.cache/kagglehub/datasets/mariaherrerot/aptos2019/versions/3/train_images/train_images"         # Folder with training images
VAL_IMG_DIR = "/teamspace/studios/this_studio/.cache/kagglehub/datasets/mariaherrerot/aptos2019/versions/3/val_images/val_images"      # Folder with validation images
TEST_IMG_DIR = "/teamspace/studios/this_studio/.cache/kagglehub/datasets/mariaherrerot/aptos2019/versions/3/test_images/test_images"           # Folder with test images

OUTPUT_DIR = "/teamspace/studios/this_studio/generated_hr_output"     # Where HR images will be saved
# =================================================================

# Hyperparameters
BATCH_SIZE = 32                                        # Increased for multi-GPU
LR_SIZE = 64                                           # Low-res input size
HR_SIZE = 128                                          # High-res target size
EPOCHS = 500                                           # Number of epochs
NUM_GPUS = torch.cuda.device_count()                   # Number of available GPUs

# Initialize multi-GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True  # Optimizes for fixed input size
print(device)
print("CPU:", os.cpu_count())
print("GPU:", NUM_GPUS)

cuda
CPU: 32
GPU: 1


In [7]:
# Custom Dataset Class
# Optimized Custom Dataset Class
class FundusDataset(Dataset):
    def __init__(self, csv_path, img_dir):
        self.df = pd.read_csv(csv_path)
        self.img_paths = [os.path.join(img_dir, f"{img_id}.png") for img_id in self.df["id_code"]]
        print(f"Loaded dataset with {len(self.df)} samples from {img_dir}")

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]

        # Check if file exists before loading
        if not os.path.exists(img_path):
            print(f"Warning: Missing file {img_path}")
            return torch.zeros(3, LR_SIZE, LR_SIZE), torch.zeros(3, HR_SIZE, HR_SIZE)

        # Load image using OpenCV (faster than PIL)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert from BGR to RGB

        # Resize images efficiently
        lr_image = cv2.resize(image, (LR_SIZE, LR_SIZE), interpolation=cv2.INTER_CUBIC)
        hr_image = cv2.resize(image, (HR_SIZE, HR_SIZE), interpolation=cv2.INTER_CUBIC)

        # Convert to tensor & normalize
        lr_tensor = TF.to_tensor(lr_image)
        hr_tensor = TF.to_tensor(hr_image)

        lr_tensor = TF.normalize(lr_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        hr_tensor = TF.normalize(hr_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        return lr_tensor, hr_tensor


In [8]:
# Generator with Multi-GPU Support
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 9, padding=4), nn.PReLU(),
            *[ResidualBlock(64) for _ in range(16)],
            nn.Conv2d(64, 64, 3, padding=1), SyncBatchNorm(64),
            nn.Conv2d(64, 256, 3, padding=1), nn.PixelShuffle(2), nn.PReLU(),
            nn.Conv2d(64, 3, 9, padding=4), nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)


In [9]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        return x + self.conv(x)

In [10]:
# Optimized Discriminator with SyncBatchNorm
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.SyncBatchNorm(64), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.SyncBatchNorm(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.SyncBatchNorm(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.SyncBatchNorm(256), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.SyncBatchNorm(256), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Conv2d(256, 512, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, 1), nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(-1)

In [11]:
def train_gan():
    # Initialize datasets
    train_dataset = FundusDataset(TRAIN_CSV, TRAIN_IMG_DIR)
    val_dataset = FundusDataset(VAL_CSV, VAL_IMG_DIR)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=os.cpu_count(), pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            num_workers=os.cpu_count(), pin_memory=True)

    # Initialize models with DataParallel
    G = DataParallel(Generator().to(device), device_ids=range(NUM_GPUS))
    D = DataParallel(Discriminator().to(device), device_ids=range(NUM_GPUS))

    opt_G = optim.Adam(G.parameters(), lr=1e-4)
    opt_D = optim.Adam(D.parameters(), lr=1e-4)

    # Loss functions
    criterion_mse = nn.MSELoss()
    criterion_bce = nn.BCEWithLogitsLoss()
    vgg = nn.Sequential(*list(torch.hub.load('pytorch/vision', 'vgg19', pretrained=True).features)[:35])
    vgg = DataParallel(vgg.to(device), device_ids=range(NUM_GPUS)).eval()

    scaler = torch.amp.GradScaler()  # Mixed Precision for speedup

    best_val_loss = float('inf')
    print(f"\nStarting training on {NUM_GPUS} GPUs with batch size {BATCH_SIZE}")

    for epoch in range(EPOCHS):
        start_time = time.time()
        G.train()
        D.train()

        # Training loop
        for batch_idx, (lr, hr) in enumerate(train_loader):
            lr, hr = lr.to(device, non_blocking=True), hr.to(device, non_blocking=True)

            # Train Discriminator
            with torch.amp.autocast(device_type="cuda"):
                real_preds = D(hr)
                real_loss = criterion_bce(real_preds, torch.ones_like(real_preds))

                fake_hr = G(lr)
                fake_preds = D(fake_hr.detach())
                fake_loss = criterion_bce(fake_preds, torch.zeros_like(fake_preds))

                loss_D = (real_loss + fake_loss) / 2

            scaler.scale(loss_D).backward()
            scaler.step(opt_D)  # Step before zero_grad
            opt_D.zero_grad(set_to_none=True)

            # Train Generator
            with torch.amp.autocast(device_type="cuda"):
                fake_preds = D(fake_hr)
                loss_G_adv = criterion_bce(fake_preds, torch.ones_like(fake_preds))
                loss_G_mse = criterion_mse(fake_hr, hr)
                loss_G_percep = criterion_mse(vgg(fake_hr), vgg(hr))
                loss_G = loss_G_adv + 0.001 * loss_G_percep + 0.006 * loss_G_mse

            scaler.scale(loss_G).backward()
            scaler.step(opt_G)
            opt_G.zero_grad(set_to_none=True)
            scaler.update()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1} | Batch {batch_idx}/{len(train_loader)} | "
                      f"D Loss: {loss_D.item():.4f} | G Loss: {loss_G.item():.4f}")

        # Validation
        G.eval()
        val_loss = 0.0
        with torch.no_grad():
            for lr_val, hr_val in val_loader:
                lr_val = lr_val.to(device, non_blocking=True)
                hr_val = hr_val.to(device, non_blocking=True)
                fake_hr_val = G(lr_val)
                val_loss += criterion_mse(fake_hr_val, hr_val).item()

        val_loss /= len(val_loader)
        epoch_time = time.time() - start_time

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            traced_G = torch.jit.trace(G.module, torch.randn(1, 3, LR_SIZE, LR_SIZE, device=device))
            traced_G.save("best_generator.pth")
            print(f"New best model saved with val loss: {val_loss:.4f}")

        print(f"Epoch {epoch+1}/{EPOCHS} completed in {epoch_time:.2f}s | "
              f"Val Loss: {val_loss:.4f} | GPU Mem: {torch.cuda.memory_allocated()/1e9:.2f}GB")


In [12]:
def generate_hr_images():
    G = Generator().to(device)
    
    # Load checkpoint
    checkpoint = torch.load("best_generator.pth", map_location=device, weights_only=False)
    
    # Diagnostic print to understand checkpoint structure
    print(f"Checkpoint type: {type(checkpoint)}")
    
    # Attempt multiple extraction strategies
    try:
        # Strategy 1: Direct state dict extraction if it's a scripted module
        if hasattr(checkpoint, '_actual_script_module'):
            state_dict = checkpoint._actual_script_module.state_dict()
        # Strategy 2: Try getting state dict directly
        elif hasattr(checkpoint, 'state_dict'):
            state_dict = checkpoint.state_dict()
        # Strategy 3: Assume dictionary-like structure
        else:
            state_dict = checkpoint.get('state_dict', checkpoint)
        
        # Clean state dict keys
        new_state_dict = {}
        for k, v in state_dict.items():
            clean_key = k[7:] if k.startswith('module.') else k
            new_state_dict[clean_key] = v
        
        # Load cleaned state dict
        G.load_state_dict(new_state_dict)
        G.eval()
        
    except Exception as e:
        print(f"Error extracting state dict: {e}")
        # If all else fails, you might need to manually inspect the checkpoint
        raise

    # Load test dataset
    test_dataset = FundusDataset(TEST_CSV, TEST_IMG_DIR)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE*2, 
                            num_workers=os.cpu_count(), pin_memory=True)
    
    print(f"\nGenerating HR images for {len(test_dataset)} test samples...")
    start_time = time.time()
    
    with torch.no_grad():
        for batch_idx, (lr, _) in enumerate(test_loader):
            lr = lr.to(device, non_blocking=True)
            fake_hr = G(lr)
            fake_hr = (fake_hr * 0.5 + 0.5).clamp(0, 1)  # Denormalize
            
            for img_idx in range(fake_hr.size(0)):
                img_name = test_dataset.df.iloc[batch_idx*BATCH_SIZE + img_idx]['id_code'] + "_hr.png"
                img_path = os.path.join(OUTPUT_DIR, img_name)
                transforms.ToPILImage()(fake_hr[img_idx].cpu()).save(img_path)
            
            print(f"Generated batch {batch_idx+1}/{len(test_loader)}")
    
    total_time = time.time() - start_time
    print(f"Generation completed in {total_time:.2f}s | Output: {OUTPUT_DIR}")

In [13]:
if __name__ == "__main__":
    with torch.no_grad():
        torch.cuda.memory_allocated()
        torch.cuda.memory_reserved()

    train_gan()
    generate_hr_images()

Loaded dataset with 2930 samples from /teamspace/studios/this_studio/.cache/kagglehub/datasets/mariaherrerot/aptos2019/versions/3/train_images/train_images
Loaded dataset with 366 samples from /teamspace/studios/this_studio/.cache/kagglehub/datasets/mariaherrerot/aptos2019/versions/3/val_images/val_images


Downloading: "https://github.com/pytorch/vision/zipball/main" to /home/zeus/.cache/torch/hub/main.zip
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home/zeus/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:01<00:00, 317MB/s] 



Starting training on 1 GPUs with batch size 32
Epoch 1 | Batch 0/92 | D Loss: 0.7237 | G Loss: 0.4960
New best model saved with val loss: 0.0764
Epoch 1/500 completed in 32.30s | Val Loss: 0.0764 | GPU Mem: 0.22GB
Epoch 2 | Batch 0/92 | D Loss: 0.8052 | G Loss: 0.3260
New best model saved with val loss: 0.0459
Epoch 2/500 completed in 28.05s | Val Loss: 0.0459 | GPU Mem: 0.22GB
Epoch 3 | Batch 0/92 | D Loss: 0.8096 | G Loss: 0.3193
New best model saved with val loss: 0.0156
Epoch 3/500 completed in 27.50s | Val Loss: 0.0156 | GPU Mem: 0.22GB
Epoch 4 | Batch 0/92 | D Loss: 0.8117 | G Loss: 0.3164
New best model saved with val loss: 0.0127
Epoch 4/500 completed in 27.72s | Val Loss: 0.0127 | GPU Mem: 0.22GB
Epoch 5 | Batch 0/92 | D Loss: 0.8123 | G Loss: 0.3155
New best model saved with val loss: 0.0123
Epoch 5/500 completed in 27.61s | Val Loss: 0.0123 | GPU Mem: 0.22GB
Epoch 6 | Batch 0/92 | D Loss: 0.8126 | G Loss: 0.3150
New best model saved with val loss: 0.0111
Epoch 6/500 complet

KeyboardInterrupt: 