# ConvKAN Super-Resolution Training Notebook
This is the Jupyter Notebook version of `train_convkan_superres.py`.
Includes:
- Environment and dependency notes
- Dataset class `SuperResolutionDataset`
- Model definition `ConvKAN_SR` and residual blocks
- Training loop and model saving

In [None]:
# Import required libraries
import os
# Pick the GPU before torch is imported. setdefault only takes effect when
# CUDA_VISIBLE_DEVICES is not already set, so a value exported in the shell
# wins; change '1' to '0' here if you want GPU0 by default.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
# Allocator option intended to reduce fragmentation. It is assigned
# unconditionally (overwriting any existing value) and before torch is
# imported, so it is in place before any CUDA allocation happens.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import glob
from tqdm import tqdm

# Try to import convkan. On failure only a warning is printed: ConvKAN and
# LayerNorm2D stay undefined, so any later use of them raises NameError until
# the package is installed and this cell is re-run.
try:
    from convkan import ConvKAN, LayerNorm2D
except ImportError:
    print("Warning: convkan is not installed. To install, run: !pip install convkan")

## Dataset
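Define a PyTorch Dataset that loads low-resolution and high-resolution image pairs. Images are paired by sorted filename order, so `dataset/low_resolution` and `dataset/high_resolution` must contain the same number of PNG files, and corresponding images must sort into the same position in each directory (for example, by using identical filenames).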

In [None]:
class SuperResolutionDataset(Dataset):
    """Paired low-/high-resolution PNG dataset for super-resolution training.

    Both directories are globbed for ``*.png`` files and each list is sorted
    by path. Item ``i`` pairs the ``i``-th low-resolution path with the
    ``i``-th high-resolution path, so the two directories must sort into the
    same order.

    Args:
        lr_dir: Directory containing the low-resolution PNG images.
        hr_dir: Directory containing the high-resolution PNG images.
        transform: Optional callable applied to each loaded PIL image. The
            same callable is applied to both images of a pair.

    Raises:
        IOError: If either directory contains no ``*.png`` files.
        AssertionError: If the directories hold different numbers of images.
    """

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.transform = transform

        self.lr_images = sorted(glob.glob(os.path.join(self.lr_dir, '*.png')))
        self.hr_images = sorted(glob.glob(os.path.join(self.hr_dir, '*.png')))

        if not self.lr_images or not self.hr_images:
            raise IOError(f"Error: No image files found in '{lr_dir}' or '{hr_dir}'. Please check the paths.")

        print(f"Found {len(self.lr_images)} low-resolution images and {len(self.hr_images)} high-resolution images.")
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept for existing callers
        # that catch AssertionError.
        if len(self.lr_images) != len(self.hr_images):
            raise AssertionError("The number of low-resolution and high-resolution images does not match!")

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.lr_images)

    @staticmethod
    def _load_rgb(path):
        """Load the image at *path* as RGB and close the underlying file."""
        # convert() returns a new, fully loaded image, so the result remains
        # usable after the context manager has closed the file handle.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return the ``(low_resolution, high_resolution)`` pair at *idx*."""
        lr_image = self._load_rgb(self.lr_images[idx])
        hr_image = self._load_rgb(self.hr_images[idx])

        if self.transform is not None:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        return lr_image, hr_image

## Model definition
Includes a residual block `ConvKANResBlock` and the main model `ConvKAN_SR`.

In [None]:
class ConvKANResBlock(nn.Module):
    """Residual block: two (ConvKAN 3x3 -> LayerNorm2D) stages plus a skip connection.

    Args:
        channels: Channel count used as both input and output width of every
            layer in the block.
    """

    def __init__(self, channels):
        super().__init__()
        # Layers are created in execution order so the sub-modules of
        # `block` keep the indices 0..3.
        stages = []
        for _ in range(2):
            stages.append(ConvKAN(channels, channels, kernel_size=3, padding=1))
            stages.append(LayerNorm2D(channels))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the stacked stages applied to ``x``."""
        transformed = self.block(x)
        return x + transformed

class ConvKAN_SR(nn.Module):
    """ConvKAN super-resolution network: head -> residual body -> upsample -> tail.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced image.
        base_filters: Feature width used by the head, body and tail.
        n_res_blocks: Number of ``ConvKANResBlock`` modules in the body.
        upscale_factor: Factor passed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels=3, out_channels=3, base_filters=64, n_res_blocks=8, upscale_factor=2):
        super().__init__()

        # Every convolution in the network is 3x3 with padding 1.
        conv_kwargs = dict(kernel_size=3, padding=1)

        self.head = ConvKAN(in_channels, base_filters, **conv_kwargs)
        self.body = nn.Sequential(*(ConvKANResBlock(base_filters) for _ in range(n_res_blocks)))

        # The convolution widens to base_filters * r^2 channels, which
        # PixelShuffle(r) folds back to base_filters channels.
        expanded_channels = base_filters * (upscale_factor ** 2)
        self.upsample = nn.Sequential(
            ConvKAN(base_filters, expanded_channels, **conv_kwargs),
            nn.PixelShuffle(upscale_factor),
            LayerNorm2D(base_filters),
        )

        self.tail = nn.Conv2d(base_filters, out_channels, **conv_kwargs)

    def forward(self, x):
        """Run the network on ``x`` and return the upsampled output."""
        shallow = self.head(x)
        # Global skip connection around the whole residual body.
        deep = self.body(shallow) + shallow
        return self.tail(self.upsample(deep))

## Training setup and main routine
In this notebook the training section is split into parameter setup, data loading, model initialization and the training loop for easier step-by-step execution and debugging.

In [None]:
# Hyperparameters and path setup

# Data locations
LR_DIR = "dataset/low_resolution"
HR_DIR = "dataset/high_resolution"

# Optimisation settings
NUM_EPOCHS = 50
BATCH_SIZE = 1  # kept at 1 because of GPU memory constraints
LEARNING_RATE = 1e-4
UPSCALE_FACTOR = 2

# Use the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

In [None]:
# Dataset and DataLoader
# ToTensor is applied to both images of every pair by the dataset.
transform = transforms.ToTensor()
try:
    train_dataset = SuperResolutionDataset(lr_dir=LR_DIR, hr_dir=HR_DIR, transform=transform)
except (IOError, AssertionError) as e:
    # Show the dataset's own message, then re-raise so the cell still fails.
    print(e)
    raise

# num_workers=0 keeps data loading in the main process, which avoids
# multiprocessing issues when running inside a Jupyter notebook.
# Pinned host memory only helps host-to-GPU copies, so it is requested only
# when CUDA is available; this avoids a pointless request on CPU-only machines.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
    num_workers=0,
)

In [None]:
# Initialize model, loss and optimizer
from torch.cuda.amp import autocast, GradScaler

# Clear GPU cache before model initialization
torch.cuda.empty_cache()

# Reduced model size: base_filters=32 (was 64), n_res_blocks=4 (was 8)
model = ConvKAN_SR(in_channels=3, out_channels=3, base_filters=32, n_res_blocks=4, upscale_factor=UPSCALE_FACTOR).to(DEVICE)
criterion = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Gradient scaling only applies to CUDA mixed precision. When DEVICE is "cpu"
# the scaler is created disabled, so scale()/step()/update() become plain
# pass-throughs to the loss and optimizer.
scaler = GradScaler(enabled=(DEVICE == "cuda"))

print("Model and data ready. Starting training...")

In [None]:
# Training loop (mixed precision, can be interrupted and executed step-by-step)
# Autocast targets CUDA, so it is enabled only when training runs on the GPU.
use_amp = DEVICE == "cuda"

for epoch in range(NUM_EPOCHS):
    # clear any cached memory to reduce OOM risk between epochs
    torch.cuda.empty_cache()
    model.train()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    total_loss = 0.0

    for lr_images, hr_images in progress_bar:
        # Use non_blocking for async data transfer with pin_memory
        lr_images = lr_images.to(DEVICE, non_blocking=True)
        hr_images = hr_images.to(DEVICE, non_blocking=True)

        # Drop the previous step's gradients instead of zero-filling them
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=use_amp):
            sr_images = model(lr_images)
            loss = criterion(sr_images, hr_images)

        # Scaled backward pass, optimizer step, then scale-factor update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once: each .item() call forces a device-to-host sync
        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix(loss=f"{batch_loss:.4f}")

        # Explicitly delete tensors to help garbage collection
        del lr_images, hr_images, sr_images, loss

    torch.cuda.empty_cache()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} finished, average loss: {avg_loss:.4f}")

print("Training complete!")

In [None]:
# Save model
# Only the weights (state_dict) are written, not the module object itself.
trained_weights = model.state_dict()
torch.save(trained_weights, 'convkan_sr_model.pth')
print("Model saved to convkan_sr_model.pth")

### Notes
- If you want to monitor training curves in the notebook, record loss values during the training loop and plot them with matplotlib.
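- On systems without a GPU, reduce `NUM_EPOCHS` (and, if needed, the model size via `base_filters` / `n_res_blocks`) to speed up testing; `BATCH_SIZE` is already 1 and cannot be reduced further.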