# ConvKAN Super-Resolution Training Notebook
This is the Jupyter Notebook version of `train_convkan_superres.py`.
Includes:
- Environment and dependency notes
- Dataset class `SuperResolutionDataset`
- Model definition `ConvKAN_SR` and residual blocks
- Training loop and model saving

In [None]:
# Import required libraries
import os
# Choose which GPU is visible to this process. This must be set before torch
# is imported. setdefault() keeps any value already present in the
# environment; change '1' to '0' if you want GPU0.
# NOTE(review): on a machine with a single GPU, index '1' presumably hides
# every device and the notebook would fall back to the CPU — confirm.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
# Reduce allocator fragmentation: set before CUDA allocations if possible.
# Unlike the setdefault() above, this assignment overwrites any existing value.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import glob
from tqdm import tqdm

# Try to import convkan. If it is unavailable only a warning is printed here;
# the model classes reference ConvKAN and LayerNorm2D, so install the package
# and re-run this cell before building the model.
try:
    from convkan import ConvKAN, LayerNorm2D
except ImportError:
    print("Warning: convkan is not installed. To install, run: !pip install convkan")

## Dataset
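Define a PyTorch Dataset that loads pairs of low-resolution and high-resolution images. Make sure `dataset/low_resolution` and `dataset/high_resolution` contain the same number of PNG files: pairs are matched by sorted filename order, so corresponding images should share the same name in both folders.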

In [None]:
class SuperResolutionDataset(Dataset):
    """Dataset of paired low-/high-resolution PNG images.

    The ``*.png`` files of both directories are listed and sorted; the i-th
    low-resolution file is paired with the i-th high-resolution file, so the
    two folders are expected to use matching file names.

    Args:
        lr_dir: Directory containing the low-resolution PNG files.
        hr_dir: Directory containing the high-resolution PNG files.
        transform: Optional callable applied independently to BOTH images of
            a pair after they have been converted to RGB.

    Raises:
        IOError: If either directory contains no ``*.png`` file.
        AssertionError: If the two directories hold a different number of
            PNG files.
    """

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.transform = transform

        # Sorting makes the index-based LR/HR pairing deterministic.
        self.lr_images = sorted(glob.glob(os.path.join(self.lr_dir, '*.png')))
        self.hr_images = sorted(glob.glob(os.path.join(self.hr_dir, '*.png')))

        if not self.lr_images or not self.hr_images:
            raise IOError(f"Error: No image files found in '{lr_dir}' or '{hr_dir}'. Please check the paths.")

        print(f"Found {len(self.lr_images)} low-resolution images and {len(self.hr_images)} high-resolution images.")
        # Raise explicitly instead of using `assert`: assert statements are
        # removed when Python runs with -O, which would silently disable this
        # validation. The exception type is kept for existing callers.
        if len(self.lr_images) != len(self.hr_images):
            raise AssertionError("The number of low-resolution and high-resolution images does not match!")

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.lr_images)

    def __getitem__(self, idx):
        """Return the ``(low_res, high_res)`` pair at position ``idx``.

        Both images are converted to RGB; ``transform`` is applied to each of
        them when one was supplied.
        """
        # convert() returns a new image, so the file handles can be closed
        # as soon as the pixel data has been read.
        with Image.open(self.lr_images[idx]) as lr_file:
            lr_image = lr_file.convert("RGB")
        with Image.open(self.hr_images[idx]) as hr_file:
            hr_image = hr_file.convert("RGB")

        if self.transform is not None:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        return lr_image, hr_image

## Model definition
Includes a residual block `ConvKANResBlock` and the main model `ConvKAN_SR`.

In [None]:
class ConvKANResBlock(nn.Module):
    """Residual block built from two ConvKAN + LayerNorm2D stages.

    The block output is the input plus the result of passing the input
    through both stages; the channel count is unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        # Two identical stages: a 3x3 ConvKAN (padding=1) followed by a
        # LayerNorm2D over the same number of channels.
        stages = []
        for _ in range(2):
            stages.append(ConvKAN(channels, channels, kernel_size=3, padding=1))
            stages.append(LayerNorm2D(channels))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the skip connection around the two stages."""
        transformed = self.block(x)
        return x + transformed

class ConvKAN_SR(nn.Module):
    """Super-resolution network assembled from ConvKAN layers.

    Pipeline: ConvKAN head -> stack of ConvKANResBlock with a global skip
    connection -> ConvKAN + PixelShuffle upsampling -> plain Conv2d tail.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced image.
        base_filters: Feature channels used by the head, body and tail.
        n_res_blocks: Number of residual blocks in the body.
        upscale_factor: Factor handed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels=3, out_channels=3, base_filters=64, n_res_blocks=8, upscale_factor=2):
        super().__init__()

        # Shallow feature extraction.
        self.head = ConvKAN(in_channels, base_filters, kernel_size=3, padding=1)

        # Deep feature extraction.
        self.body = nn.Sequential(
            *(ConvKANResBlock(base_filters) for _ in range(n_res_blocks))
        )

        # PixelShuffle(r) folds r*r channel groups into space, so the
        # preceding layer has to emit base_filters * r**2 channels.
        shuffle_channels = base_filters * (upscale_factor ** 2)
        self.upsample = nn.Sequential(
            ConvKAN(base_filters, shuffle_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor),
            LayerNorm2D(base_filters),
        )

        # Projection from feature channels to the output image channels.
        self.tail = nn.Conv2d(base_filters, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the head, residual body, upsampler and tail on ``x``."""
        shallow = self.head(x)
        deep = self.body(shallow)
        # Global skip connection around the residual body.
        upscaled = self.upsample(deep + shallow)
        return self.tail(upscaled)

## Training setup and main routine
In this notebook the training section is split into parameter setup, data loading, model initialization and the training loop for easier step-by-step execution and debugging.

In [None]:
# Hyperparameters and path setup
# Directories holding the paired low-/high-resolution PNG training images.
LR_DIR = "dataset/low_resolution"
HR_DIR = "dataset/high_resolution"
LEARNING_RATE = 1e-4  # learning rate handed to the optimizer
BATCH_SIZE = 1  # reduced to 1 due to GPU memory constraints
NUM_EPOCHS = 50  # number of passes over the training DataLoader
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
UPSCALE_FACTOR = 2  # upscale_factor passed to ConvKAN_SR

print(f"Using device: {DEVICE}")

In [None]:
# Dataset and DataLoader
# Add image resizing to reduce memory usage
# ConvKAN works best with smaller images (MNIST example uses 28x28)
MAX_IMAGE_SIZE = 128  # Reduce to 128x128 to save memory
# The model enlarges its input by UPSCALE_FACTOR, so the training targets
# must be that much larger than the inputs. Resizing both images to the same
# size makes the loss fail on a shape mismatch between output and target.
HR_IMAGE_SIZE = MAX_IMAGE_SIZE * UPSCALE_FACTOR

lr_transform = transforms.Compose([
    transforms.Resize((MAX_IMAGE_SIZE, MAX_IMAGE_SIZE)),  # Resize to fixed size
    transforms.ToTensor()
])
hr_transform = transforms.Compose([
    transforms.Resize((HR_IMAGE_SIZE, HR_IMAGE_SIZE)),  # Target size = input size * scale
    transforms.ToTensor()
])
# Kept under the previous name for any code that still refers to `transform`.
transform = lr_transform


class _PairedResizeDataset(SuperResolutionDataset):
    """SuperResolutionDataset variant with separate LR and HR transforms."""

    def __init__(self, lr_dir, hr_dir, lr_transform, hr_transform):
        # The parent only supports one shared transform, so disable it and
        # apply the two size-specific transforms in __getitem__ instead.
        super().__init__(lr_dir, hr_dir, transform=None)
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform

    def __getitem__(self, idx):
        lr_image, hr_image = super().__getitem__(idx)
        return self.lr_transform(lr_image), self.hr_transform(hr_image)


print(
    f"LR images will be resized to {MAX_IMAGE_SIZE}x{MAX_IMAGE_SIZE}, "
    f"HR images to {HR_IMAGE_SIZE}x{HR_IMAGE_SIZE}"
)

try:
    train_dataset = _PairedResizeDataset(
        lr_dir=LR_DIR,
        hr_dir=HR_DIR,
        lr_transform=lr_transform,
        hr_transform=hr_transform,
    )
except (IOError, AssertionError) as e:
    # Show the dataset problem prominently, then let the cell fail.
    print(e)
    raise

# IMPORTANT: In Jupyter notebooks, num_workers > 0 can cause memory leaks
# Set num_workers=0 to avoid multiprocessing issues
# Pinned host memory only helps host-to-GPU copies, so enable it on CUDA only.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=(DEVICE == "cuda"),
    num_workers=0,
)

In [None]:
# Initialize model, loss and optimizer
# NOTE(review): recent PyTorch releases reportedly deprecate torch.cuda.amp in
# favor of torch.amp — confirm against the installed version.
from torch.cuda.amp import autocast, GradScaler

# Clear GPU cache before model initialization
torch.cuda.empty_cache()

# Test: load one sample to check image size
sample_lr, sample_hr = train_dataset[0]
print(f"DEBUG: LR image shape: {sample_lr.shape}")
print(f"DEBUG: HR image shape: {sample_hr.shape}")
# Product of dimensions 1 and 2 of the sample; presumably height * width for a
# (channels, height, width) tensor — confirm against the transform in use.
print(f"DEBUG: LR image size (pixels): {sample_lr.shape[1] * sample_lr.shape[2]}")
# Calculate expected memory usage
def estimate_memory(batch_size, channels, height, width, base_filters, n_res_blocks, upscale_factor=2):
    """Return a rough activation-memory estimate in MiB for one batch.

    This is a back-of-the-envelope heuristic that assumes float32 values; it
    is not a measurement.

    Args:
        batch_size: Number of images per batch.
        channels: Channels of the input image.
        height: Input height in pixels.
        width: Input width in pixels.
        base_filters: Feature channels used inside the network.
        n_res_blocks: Number of residual blocks in the body.
        upscale_factor: Spatial upscaling factor of the network. Defaults to
            2, which reproduces the previously hard-coded behaviour.

    Returns:
        The estimated size in MiB (bytes / 1024**2) as a float.
    """
    bytes_per_value = 4  # float32
    # Input
    input_mem = batch_size * channels * height * width * bytes_per_value
    # After head: base_filters channels
    head_mem = batch_size * base_filters * height * width * bytes_per_value
    # Each res block is budgeted at three head-sized feature maps
    # (conservative estimate).
    body_mem = head_mem * n_res_blocks * 3
    # Upsampled feature map: both spatial dimensions grow by upscale_factor.
    upsample_mem = (
        batch_size * base_filters
        * (height * upscale_factor) * (width * upscale_factor)
        * bytes_per_value
    )
    total_mb = (input_mem + head_mem + body_mem + upsample_mem) / (1024 ** 2)
    return total_mb

# Model size shared by the memory estimate and the real model, so the
# estimate describes the network that is actually trained.
# Further reduce model size: base_filters=16, n_res_blocks=2
BASE_FILTERS = 16
N_RES_BLOCKS = 2

estimated_mem = estimate_memory(
    BATCH_SIZE, 3, sample_lr.shape[1], sample_lr.shape[2], BASE_FILTERS, N_RES_BLOCKS
)
print(f"DEBUG: Estimated forward pass memory: {estimated_mem:.2f} MB")
print(f"DEBUG: With gradients (x3): {estimated_mem * 3:.2f} MB")

# This is only a warning: the model is still created below.
if estimated_mem * 3 > 1000:  # More than ~1 GB for one batch
    print(f"\nWARNING: Image size {sample_lr.shape} is TOO LARGE!")
    print("Recommended: Resize images to max 256x256 or 128x128")
    print(f"Current image will use ~{estimated_mem * 3:.0f} MB per batch")

print(f"\nInitializing smaller model: base_filters={BASE_FILTERS}, n_res_blocks={N_RES_BLOCKS}")
model = ConvKAN_SR(
    in_channels=3,
    out_channels=3,
    base_filters=BASE_FILTERS,
    n_res_blocks=N_RES_BLOCKS,
    upscale_factor=UPSCALE_FACTOR,
).to(DEVICE)
criterion = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Gradient scaling is only meaningful for CUDA mixed precision; with
# enabled=False the scaler's scale()/step()/update() calls become pass-throughs.
scaler = GradScaler(enabled=(DEVICE == "cuda"))

print("Model and data ready. Starting training...")

In [None]:
# CUDA-only calls (cache clearing, memory statistics, mixed precision) are
# skipped when running on the CPU.
use_cuda = DEVICE == "cuda"

# Check GPU memory before training
if use_cuda:
    print("GPU Memory Status:")
    print(f"  Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"  Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"  Max Allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
    torch.cuda.reset_peak_memory_stats()

# Training loop (mixed precision, can be interrupted and executed step-by-step)
for epoch in range(NUM_EPOCHS):
    # clear any cached memory to reduce OOM risk between epochs
    if use_cuda:
        torch.cuda.empty_cache()
    model.train()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    total_loss = 0.0

    for batch_idx, (lr_images, hr_images) in enumerate(progress_bar):
        # Use non_blocking for async data transfer with pin_memory
        lr_images = lr_images.to(DEVICE, non_blocking=True)
        hr_images = hr_images.to(DEVICE, non_blocking=True)

        # Extra diagnostics are printed for the very first batch only.
        first_batch = epoch == 0 and batch_idx == 0
        if first_batch:
            print("\nDEBUG First batch:")
            print(f"  LR shape: {lr_images.shape}, HR shape: {hr_images.shape}")
            if use_cuda:
                print(f"  GPU memory before forward: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        optimizer.zero_grad(set_to_none=True)  # More efficient than default zero_grad()

        try:
            # Autocast is enabled only on CUDA; on CPU the forward pass runs
            # in full precision.
            with torch.amp.autocast('cuda', enabled=use_cuda):
                sr_images = model(lr_images)

                if first_batch:
                    if use_cuda:
                        print(f"  GPU memory after forward: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                    print(f"  SR shape: {sr_images.shape}")

                # Fail with a readable message instead of letting the loss
                # broadcast or raise an opaque size error.
                if sr_images.shape != hr_images.shape:
                    raise ValueError(
                        f"Model output shape {tuple(sr_images.shape)} does not match "
                        f"target shape {tuple(hr_images.shape)}; HR images must be "
                        f"{UPSCALE_FACTOR}x the size of the LR images."
                    )

                loss = criterion(sr_images, hr_images)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # .item() synchronizes with the device, so read the value once.
            batch_loss = loss.item()
            total_loss += batch_loss
            progress_bar.set_postfix(loss=f"{batch_loss:.4f}")

        except RuntimeError:
            print(f"\nERROR at batch {batch_idx}:")
            print(f"  LR shape: {lr_images.shape}")
            if use_cuda:
                print(f"  GPU memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
            # Bare raise keeps the original traceback intact.
            raise

        # Explicitly delete tensors to help garbage collection
        del lr_images, hr_images, sr_images, loss

    if use_cuda:
        torch.cuda.empty_cache()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} finished, average loss: {avg_loss:.4f}")
    if use_cuda:
        print(f"  Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
        torch.cuda.reset_peak_memory_stats()

print("Training complete!")

In [None]:
# Persist only the learned weights (the state dict), not the model object.
trained_weights = model.state_dict()
torch.save(trained_weights, 'convkan_sr_model.pth')
print("Model saved to convkan_sr_model.pth")

### Notes
- If you want to monitor training curves in the notebook, record loss values during the training loop and plot them with matplotlib.
- On systems without a GPU, training is slow: `BATCH_SIZE` is already 1, so reduce `NUM_EPOCHS` (and, if needed, `MAX_IMAGE_SIZE`) to speed up testing.