# ConvKAN-Based Image Super-Resolution Training

This notebook implements a Convolutional Kolmogorov-Arnold Network (ConvKAN) architecture for image super-resolution. ConvKAN replaces the fixed linear weights of a standard convolution with learnable spline-based functions, which may improve feature learning.

## Project Overview
- **Task**: Image Super-Resolution - enhance bicubic-upsampled images
- **Architecture**: ConvKAN with residual blocks
- **Input**: Low-resolution images (64×64) upsampled to 256×256 via bicubic interpolation
- **Output**: High-resolution images at 256×256 with restored details
- **Loss Function**: L1 Loss (Mean Absolute Error)

## Key Differences from U-Net
- Uses ConvKAN layers instead of standard Conv2d
- Employs residual connections for deeper feature learning
- No downsampling - operates at constant 256×256 resolution
- Focuses on detail restoration rather than spatial transformation

---

## 1. Import Dependencies and Environment Setup

### GPU Configuration
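- `CUDA_VISIBLE_DEVICES`: Select which GPU to use (defaults to GPU 1 unless the variable is already set; on a single-GPU machine use `'0'`)
- `PYTORCH_CUDA_ALLOC_CONF`: Reduce memory fragmentation with expandable segments (must be set before CUDA is initialized)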

### Key Libraries
- **convkan**: The ConvKAN layer implementation
- **torch**: PyTorch deep learning framework
- **PIL**: Image loading and preprocessing

In [None]:
import os
# Select GPU (change to '0' for GPU0, or remove line to use default)
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
# Reduce GPU memory fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import glob
import numpy as np
from tqdm import tqdm
from pathlib import Path

# Import ConvKAN - install via: pip install convkan
try:
    from convkan import ConvKAN, LayerNorm2D
    print("✓ ConvKAN imported successfully")
except ImportError:
    print("ERROR: convkan is not installed. Run: pip install convkan")
    raise

## 2. Device Configuration

Check CUDA availability and display GPU information.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB')

## 3. Dataset Class

### SRDataset: Super-Resolution Dataset

This dataset follows the **same data processing strategy as the U-Net model**:

**Data Flow:**
1. Load LR image from disk (64×64)
2. Load corresponding HR image (256×256)
3. **Upsample LR to 256×256 using bicubic interpolation** ← Key step!
4. Convert both to tensors
5. Return (LR_upsampled, HR) pair

**Why Bicubic Upsampling?**
- The model's job is to **refine** the bicubic result, not perform raw upscaling
- This is a more realistic task: starting from a decent baseline and adding details
- Matches the U-Net approach for fair comparison

**Important Notes:**
- Both LR and HR are at 256×256 when fed to the model
- The model learns to transform blurry bicubic → sharp HR
- No PixelShuffle needed - we work at constant resolution
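
The dataset class below pairs LR and HR files by sorted order and only checks that the two counts match. A minimal sketch of a stricter check follows, assuming each LR/HR pair shares the same filename in both folders (an assumption; the notebook does not state the naming scheme). `find_unpaired` is a hypothetical helper, not part of the notebook.

```python
from pathlib import Path

def find_unpaired(hr_dir, lr_dir, pattern="*.png"):
    """Return filenames that exist in only one of the two folders.

    Assumes (hypothetically) that an LR/HR pair shares the same filename.
    """
    hr_names = {p.name for p in Path(hr_dir).glob(pattern)}
    lr_names = {p.name for p in Path(lr_dir).glob(pattern)}
    # Symmetric difference: names missing from either side
    return sorted(hr_names ^ lr_names)
```

An empty list means every file has a counterpart; any returned name points at a pair that sorted-order matching would silently misalign.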

In [None]:
class SRDataset(Dataset):
    """Super-Resolution Dataset - matches UNet data processing"""
    def __init__(self, hr_dir, lr_dir, hr_size=256, transform=None):
        self.hr_dir = Path(hr_dir)
        self.lr_dir = Path(lr_dir)
        self.hr_size = hr_size
        self.transform = transform
        
        # Load all image file paths
        self.hr_images = sorted(list(self.hr_dir.glob('*.png')))
        self.lr_images = sorted(list(self.lr_dir.glob('*.png')))
        
        # Verify dataset integrity
        if not self.hr_images or not self.lr_images:
            raise IOError(f"No images found in {hr_dir} or {lr_dir}")
        
        assert len(self.hr_images) == len(self.lr_images), \
            f"Mismatch: {len(self.hr_images)} HR vs {len(self.lr_images)} LR images"
        
        print(f"Dataset loaded: {len(self.hr_images)} image pairs")
        
        # Check actual image sizes
        sample_lr = Image.open(self.lr_images[0])
        sample_hr = Image.open(self.hr_images[0])
        print(f"Original LR size: {sample_lr.size}")
        print(f"Original HR size: {sample_hr.size}")
        print(f"LR will be upsampled to: {hr_size}×{hr_size} (bicubic)")
        
    def __len__(self):
        return len(self.hr_images)
    
    def __getitem__(self, idx):
        # Load images
        hr_img = Image.open(self.hr_images[idx]).convert('RGB')
        lr_img = Image.open(self.lr_images[idx]).convert('RGB')
        
        # CRITICAL: Upsample LR to HR size using bicubic (same as UNet)
        lr_img = lr_img.resize((self.hr_size, self.hr_size), Image.BICUBIC)
        
        # Apply transforms (to tensor)
        if self.transform:
            hr_img = self.transform(hr_img)
            lr_img = self.transform(lr_img)
        
        return lr_img, hr_img

## 4. ConvKAN Model Architecture

### What is ConvKAN?

ConvKAN (Convolutional Kolmogorov-Arnold Network) replaces the fixed linear weights of a convolution kernel with learnable spline-based functions. Motivated by the Kolmogorov-Arnold representation theorem, it can potentially learn more complex feature transformations.

**Key Differences from Standard CNNs:**
- Uses B-spline basis functions for activation
- Learnable activation curves (vs fixed ReLU/sigmoid)
- More parameters but potentially better expressiveness

**Memory Challenge at 256×256:**
ConvKAN's spline computations are **memory-intensive** at high resolutions. As rough, unmeasured order-of-magnitude estimates for 256×256 images:
- Standard Conv2d: ~4 MB per layer
- ConvKAN: ~400 MB per layer (about 100× more)
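
For a sense of scale, the size of a single dense activation tensor can be computed directly. The sketch below covers only the stored output feature map in float32; `activation_mib` is a hypothetical helper, and the spline overhead itself is not modelled because it depends on the convkan implementation.

```python
def activation_mib(channels, height, width, bytes_per_value=4, batch=1):
    """Size in MiB of one dense activation tensor [batch, channels, height, width]."""
    return batch * channels * height * width * bytes_per_value / 1024**2

print(activation_mib(3, 256, 256))    # RGB input, float32 -> 0.75
print(activation_mib(8, 256, 256))    # 8 feature maps, float32 -> 2.0
print(activation_mib(64, 256, 256))   # 64 feature maps, float32 -> 16.0
```

Spline-based layers typically hold additional intermediates on top of this (for example, basis-function values for each unfolded input element), which is where the extra memory goes during training.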

**Our Solution: Ultra-Lightweight Architecture**

To make ConvKAN work at 256×256, we use an **extremely simplified** architecture:

```
Input: RGB 256×256 (3 channels)
    ↓
[Head: ConvKAN 3→8 channels]  ← Minimal feature extraction
    ↓
[Body: 1× ResBlock (8 ch)]    ← Only ONE residual block
    ↓
[Tail: Conv2d 8→3]            ← Standard conv (fast)
    ↓
Output: RGB 256×256 (3 channels)
```

**Key Optimizations:**
1. **Minimal channels**: Only 8 base filters (vs typical 64)
2. **Single residual block**: Just 1 block (vs typical 8-16)
3. **Mixed Conv types**: ConvKAN only in critical paths, standard Conv2d for output
4. **No upsampling**: Constant 256×256 resolution (no PixelShuffle overhead)
5. **Gradient checkpointing** (optional): Trades compute for memory; the model code in this notebook does not enable it
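
The model code in this notebook does not call a checkpointing API. A minimal sketch of how gradient checkpointing could be wired into a residual block is shown here, using `torch.utils.checkpoint.checkpoint`. `CheckpointedResBlock` is a hypothetical name, and `nn.Conv2d` stands in for ConvKAN so the example runs without the convkan package.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class CheckpointedResBlock(nn.Module):
    """Residual block whose inner activations are recomputed during backward."""
    def __init__(self, channels):
        super().__init__()
        # nn.Conv2d is a stand-in; the notebook's block uses ConvKAN + LayerNorm2D
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        if self.training:
            # Skip storing intermediate activations; recompute them in backward
            return x + checkpoint(self.block, x, use_reentrant=False)
        return x + self.block(x)

block = CheckpointedResBlock(8).train()
x = torch.randn(1, 8, 32, 32, requires_grad=True)
out = block(x)
out.mean().backward()  # triggers the recomputation of self.block
```

The cost is one extra forward pass through the block per training step, which is usually acceptable when memory rather than compute is the bottleneck.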

**Trade-offs:**
- ✅ Fits in GPU memory
- ✅ Still uses ConvKAN's learnable activations
- ⚠️ Reduced capacity vs full model
- ⚠️ May need more epochs to converge

---

In [None]:
class ConvKANResBlock(nn.Module):
    """
    Residual block with ConvKAN layers
    
    Architecture:
        Input (C channels, 256×256)
            ↓
        ConvKAN(C→C, 3×3) + LayerNorm
            ↓
        ConvKAN(C→C, 3×3) + LayerNorm
            ↓
        Add residual connection
            ↓
        Output (C channels, 256×256)
    
    Memory notes:
        - Gradient checkpointing is not applied in this block
        - Single path (no branching) for minimal overhead
    """
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvKAN(channels, channels, kernel_size=3, padding=1),
            LayerNorm2D(channels),
            ConvKAN(channels, channels, kernel_size=3, padding=1),
            LayerNorm2D(channels)
        )

    def forward(self, x):
        # Simple residual: output = input + transform(input)
        return x + self.block(x)


class ConvKAN_SR(nn.Module):
    """
    Ultra-Lightweight ConvKAN for Super-Resolution
    
    Full Architecture Breakdown:
    ============================
    
    Layer               Output Shape        Parameters    Memory (256×256)
    ─────────────────────────────────────────────────────────────────────
    Input               [B, 3, 256, 256]    -             0.75 MB
    
    HEAD (Feature Extraction):
    ConvKAN(3→8)        [B, 8, 256, 256]    ~50K          2 MB
    
    BODY (Feature Refinement):
    ResBlock1:
      - ConvKAN(8→8)    [B, 8, 256, 256]    ~120K         2 MB
      - LayerNorm       [B, 8, 256, 256]    16            0 MB
      - ConvKAN(8→8)    [B, 8, 256, 256]    ~120K         2 MB
      - LayerNorm       [B, 8, 256, 256]    16            0 MB
      + Residual        [B, 8, 256, 256]    -             0 MB
    
    TAIL (RGB Reconstruction):
    Conv2d(8→3)         [B, 3, 256, 256]    219           0.01 MB
    
    ─────────────────────────────────────────────────────────────────────
    Total Parameters:   ~290K (0.29M) (rough estimate; the constructor prints the exact count)
    Peak Memory (FP16): ~8-10 GB (estimate, with batch_size=1, including gradients)
    ─────────────────────────────────────────────────────────────────────
    
    Design Rationale:
    ─────────────────
    1. **Minimal channels (8)**: Cuts activation memory by 8× and per-layer weights by ~64× vs the standard 64 channels
    2. **Single ResBlock**: Only 1 block to minimize depth
    3. **Standard Conv2d tail**: Faster than ConvKAN for final layer
    4. **No downsampling**: Avoids memory spikes from pooling
    5. **No upsampling**: Works at constant 256×256 (input already upsampled)
    
    Comparison to Standard SR Models:
    ─────────────────────────────────
    - EDSR (baseline): ~1.5M params, 16 ResBlocks, 64 filters
    - Our ConvKAN: ~0.3M params, 1 ResBlock, 8 filters
    - Size reduction: ~5×
    - Memory reduction: ~100× (due to ConvKAN overhead per channel)
    """
    def __init__(self, in_channels=3, out_channels=3, base_filters=8, n_res_blocks=1):
        super().__init__()
        
        print(f"\n{'='*60}")
        print(f"Initializing Ultra-Lightweight ConvKAN_SR")
        print(f"{'='*60}")
        print(f"  Input channels:    {in_channels}")
        print(f"  Output channels:   {out_channels}")
        print(f"  Base filters:      {base_filters}")
        print(f"  Residual blocks:   {n_res_blocks}")
        print(f"  Resolution:        256×256 (constant)")
        print(f"{'='*60}\n")
        
        # Head: Initial feature extraction
        # ConvKAN for learnable nonlinear feature extraction
        self.head = ConvKAN(in_channels, base_filters, kernel_size=3, padding=1)
        
        # Body: Deep feature learning with residual blocks
        # Minimal depth to save memory
        body = [ConvKANResBlock(base_filters) for _ in range(n_res_blocks)]
        self.body = nn.Sequential(*body)
        
        # Tail: Map features back to RGB
        # Use standard Conv2d (much faster and lighter than ConvKAN)
        self.tail = nn.Conv2d(base_filters, out_channels, kernel_size=3, padding=1)
        
        # Count parameters
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
        
    def forward(self, x):
        """
        Forward pass with memory-efficient processing
        
        Args:
            x: Input tensor [B, 3, 256, 256]
        
        Returns:
            Output tensor [B, 3, 256, 256]
        """
        # Extract initial features
        x = self.head(x)  # [B, 3, 256, 256] → [B, 8, 256, 256]
        
        # Deep feature extraction with long skip connection
        res = x
        x = self.body(x)  # [B, 8, 256, 256] → [B, 8, 256, 256]
        x = x + res       # Long residual connection
        
        # Reconstruct RGB image
        x = self.tail(x)  # [B, 8, 256, 256] → [B, 3, 256, 256]
        
        return x

## 5. Training Configuration

### Hyperparameters

**Memory-Optimized Settings:**
- **BATCH_SIZE = 1**: Smallest possible batch so ConvKAN fits in GPU memory
- **BASE_FILTERS = 8**: Reduced from the typical 64 to save memory
- **N_RES_BLOCKS = 1**: A single block to minimize memory use
- **LEARNING_RATE = 1e-4**: Conservative rate for stable training
- **NUM_EPOCHS = 50**: A starting point; the reduced-capacity model may need more epochs to converge

**Data Split:**
- 90% training, 10% validation (same as U-Net)

**Note:** ConvKAN is more memory-intensive than standard Conv2d due to spline computations. `BATCH_SIZE` is already at its minimum of 1, so if you still encounter OOM errors, reduce `BASE_FILTERS`.

In [None]:
# Hyperparameters - EXTREME memory optimization for 256x256 ConvKAN
BATCH_SIZE = 1
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
TRAIN_SPLIT = 0.9

# Model architecture - Minimal viable model
BASE_FILTERS = 8   # Very small
N_RES_BLOCKS = 1   # Only 1 residual block!

# Data paths
HR_DIR = './dataset/high_resolution'
LR_DIR = './dataset/low_resolution'
CHECKPOINT_DIR = Path('./checkpoints_convkan')
CHECKPOINT_DIR.mkdir(exist_ok=True)

print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Model: base_filters={BASE_FILTERS}, n_res_blocks={N_RES_BLOCKS}")
print(f"⚠️  Minimal model - ConvKAN at 256x256 is very memory intensive")
print(f"Strategy: Bicubic upsampling 64x64→256x256, model refines details")

## 6. Data Loading and Preprocessing

### Data Pipeline:
1. **Transform**: Convert PIL images to tensors (range [0, 1])
2. **Dataset**: Load with bicubic upsampling (LR 64×64 → 256×256)
3. **Split**: 90% train, 10% validation
4. **DataLoader**: 
   - `num_workers=0` to avoid multiprocessing issues in notebooks
   - `pin_memory=True` for faster GPU transfer

**Expected Data:**
- 1000 image pairs (typical)
- LR: 64×64 PNG files
- HR: 256×256 PNG files
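
The split arithmetic in step 3 can be checked in isolation. This small sketch mirrors the notebook's `int(TRAIN_SPLIT * len(full_dataset))` rule; `split_sizes` is a hypothetical helper name.

```python
def split_sizes(n_items, train_fraction=0.9):
    """Mirror the notebook's split: floor for train, remainder for validation."""
    train_size = int(train_fraction * n_items)
    val_size = n_items - train_size
    return train_size, val_size

print(split_sizes(1000))  # (900, 100)
print(split_sizes(7))     # (6, 1)
```

Because the training size is floored and validation takes the remainder, the two always sum to the dataset size, which is what `random_split` requires when given integer lengths.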

In [None]:
# Transform: to tensor only
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load dataset (bicubic upsampling happens inside dataset)
full_dataset = SRDataset(HR_DIR, LR_DIR, hr_size=256, transform=transform)

# Split train/val
train_size = int(TRAIN_SPLIT * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)  # Reproducible split
)

# Create data loaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=0,  # Avoid multiprocessing issues
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=0,
    pin_memory=True
)

print(f"\nDataset split:")
print(f"  Training: {train_size} images ({len(train_loader)} batches)")
print(f"  Validation: {val_size} images ({len(val_loader)} batches)")

### Verify Dataset Loading

Let's visualize 2 random samples to verify the dataset is loaded correctly.

In [None]:
import matplotlib.pyplot as plt
import random

# Select 2 random samples
sample_indices = random.sample(range(len(full_dataset)), 2)

fig, axes = plt.subplots(2, 2, figsize=(10, 10))

for i, idx in enumerate(sample_indices):
    lr_img, hr_img = full_dataset[idx]
    
    # Convert to numpy for display
    lr_np = lr_img.numpy().transpose(1, 2, 0)
    hr_np = hr_img.numpy().transpose(1, 2, 0)
    
    # Display LR
    axes[i, 0].imshow(lr_np)
    axes[i, 0].set_title(f'Sample {i+1}: LR (Bicubic 256x256)\nShape: {lr_img.shape}')
    axes[i, 0].axis('off')
    
    # Display HR
    axes[i, 1].imshow(hr_np)
    axes[i, 1].set_title(f'Sample {i+1}: HR (Ground Truth)\nShape: {hr_img.shape}')
    axes[i, 1].axis('off')

plt.tight_layout()
plt.show()

print(f"✓ Dataset verification complete")
print(f"  LR images: Bicubic upsampled from 64x64 to 256x256")
print(f"  HR images: Original 256x256 ground truth")
print(f"  Total samples: {len(full_dataset)}")

## 7. Model Initialization

### Components:
- **Model**: ConvKAN_SR with reduced filters for memory efficiency
- **Loss Function**: L1 Loss (MAE) - preserves sharp edges better than MSE
- **Optimizer**: AdamW with weight decay for regularization
- **Mixed Precision**: GradScaler for faster training with float16
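
One common way to motivate L1 over MSE for sharp edges: when several target values are plausible for a pixel, the MSE-optimal prediction is their mean (a blurred in-between value), while the L1-optimal prediction is their median. The sketch below checks this by grid search on made-up values; the numbers are illustrative, not measurements from this dataset, and `best_constant` is a hypothetical helper.

```python
from statistics import mean, median

# Hypothetical plausible values for one edge pixel across similar patches
plausible = [0.0, 0.0, 1.0]

def best_constant(loss, targets, steps=100):
    """Grid-search the constant prediction in [0, 1] that minimises the summed loss."""
    candidates = [i / steps for i in range(steps + 1)]
    return min(candidates, key=lambda c: sum(loss(c, t) for t in targets))

l1_best = best_constant(lambda c, t: abs(c - t), plausible)
mse_best = best_constant(lambda c, t: (c - t) ** 2, plausible)

print(l1_best, median(plausible))   # L1 optimum coincides with the median
print(mse_best, mean(plausible))    # MSE optimum lies near the mean
```

Here the L1 optimum stays on one side of the edge, whereas the MSE optimum settles on a gray value between the two.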

**Parameter Count:**
ConvKAN has more parameters than standard Conv2d due to learnable spline coefficients. The model will display total trainable parameters.
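
To make the comparison concrete, the sketch below counts parameters for a standard Conv2d exactly and gives a rough estimate for a KAN-style convolution. The estimate assumes one base weight plus `grid_size + spline_order` spline coefficients per input-output edge, with `grid_size=5` and `spline_order=3` as assumed defaults; the convkan package may add or omit terms, so treat the printed model count as authoritative. Both function names are hypothetical.

```python
def conv2d_params(in_ch, out_ch, k, bias=True):
    """Exact parameter count of a standard k×k Conv2d layer."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)

def kan_conv_params_estimate(in_ch, out_ch, k, grid_size=5, spline_order=3):
    """Rough estimate only (assumed formulation, not the convkan API):
    one base weight plus (grid_size + spline_order) spline coefficients
    per input-output edge."""
    edges = in_ch * k * k * out_ch
    return edges * (1 + grid_size + spline_order)

print(conv2d_params(8, 3, 3))             # 219, the Conv2d tail of this model
print(conv2d_params(8, 8, 3))             # 584 for a standard 8→8 layer
print(kan_conv_params_estimate(8, 8, 3))  # 5184 under the stated assumptions
```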

In [None]:
# Initialize model
model = ConvKAN_SR(
    in_channels=3, 
    out_channels=3, 
    base_filters=BASE_FILTERS, 
    n_res_blocks=N_RES_BLOCKS
).to(device)

# Loss and optimizer
criterion = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

# Mixed precision training
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))  # no-op on CPU

# Enable memory optimizations
if device.type == 'cuda':
    # Enable TF32 for faster training on Ampere+ GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    
    # Enable cudnn benchmarking for faster convolutions
    torch.backends.cudnn.benchmark = True
    
    # PYTORCH_CUDA_ALLOC_CONF is read when the CUDA allocator initializes, so it
    # must be set before the first CUDA call (done in Section 1). Changing it
    # here would have no effect.
    
    print("✓ GPU optimizations enabled:")
    print("  - TF32 precision")
    print("  - CuDNN benchmarking")
    print("  - Expandable-segment allocator (configured in Section 1)")

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{total_params * 4 / 1024**2:.2f} MB (float32)")
print(f"  Model size: ~{total_params * 2 / 1024**2:.2f} MB (float16/mixed precision)")

## 8. Training and Validation Functions

### train_epoch()
Trains the model for one complete epoch with mixed precision:
1. Set model to training mode
2. For each batch:
   - Forward pass with autocast (float16)
   - Calculate L1 loss
   - Backward pass with gradient scaling
   - Update weights
3. Return average loss

### validate()
Evaluates model on validation set:
1. Set model to eval mode
2. Disable gradients for faster inference
3. Calculate average loss
4. Return validation loss

**Memory Management:**
- The GPU cache is cleared between epochs (in the training loop, Section 9)
- Uses `zero_grad(set_to_none=True)` for efficient memory cleanup

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, scaler, device):
    """Train for one epoch with mixed precision"""
    model.train()
    running_loss = 0.0
    
    pbar = tqdm(train_loader, desc='Training')
    for lr_imgs, hr_imgs in pbar:
        lr_imgs = lr_imgs.to(device, non_blocking=True)
        hr_imgs = hr_imgs.to(device, non_blocking=True)
        
        optimizer.zero_grad(set_to_none=True)
        
        # Mixed precision forward pass
        with torch.amp.autocast(device.type, enabled=(device.type == 'cuda')):
            outputs = model(lr_imgs)
            loss = criterion(outputs, hr_imgs)
        
        # Scaled backward pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.6f}'})
    
    return running_loss / len(train_loader)


def validate(model, val_loader, criterion, device):
    """Validate model performance"""
    model.eval()
    running_loss = 0.0
    
    with torch.no_grad():
        for lr_imgs, hr_imgs in val_loader:
            lr_imgs = lr_imgs.to(device, non_blocking=True)
            hr_imgs = hr_imgs.to(device, non_blocking=True)
            
            with torch.amp.autocast(device.type, enabled=(device.type == 'cuda')):
                outputs = model(lr_imgs)
                loss = criterion(outputs, hr_imgs)
            
            running_loss += loss.item()
    
    return running_loss / len(val_loader)

## 9. Training Loop

Main training loop that runs for all epochs.

### Process:
1. **Clear GPU cache** at start of each epoch
2. **Train** on training set
3. **Validate** on validation set
4. **Record** losses to history
5. **Save checkpoints**:
   - Best model (lowest validation loss)
   - Periodic checkpoints every 10 epochs

### What to Monitor:
- Training loss should decrease steadily
- Validation loss should track training loss
- Large gap indicates overfitting
- GPU memory usage (printed after each epoch)

**Training Time:**
- ConvKAN is slower than standard CNN due to spline computations
- Expect roughly 30-60 seconds per epoch (a rough, GPU-dependent estimate)
- Total training: roughly 25-50 minutes for 50 epochs

In [None]:
# Training history
history = {
    'train_loss': [],
    'val_loss': []
}

best_val_loss = float('inf')

print("\n" + "="*60)
print("Starting ConvKAN Training")
print("="*60 + "\n")

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 60)
    
    # Clear GPU cache
    torch.cuda.empty_cache()
    
    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device)
    
    # Validate
    val_loss = validate(model, val_loader, criterion, device)
    
    # Record
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    
    # Print results
    print(f"\nTrain Loss: {train_loss:.6f}")
    print(f"Val Loss:   {val_loss:.6f}")
    
    # GPU memory stats
    if torch.cuda.is_available():
        print(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB peak")
        torch.cuda.reset_peak_memory_stats()
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'train_loss': train_loss,
        }, CHECKPOINT_DIR / 'best_model.pth')
        print(f"✓ Saved best model (val_loss: {val_loss:.6f})")
    
    # Periodic checkpoint
    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, CHECKPOINT_DIR / f'checkpoint_epoch_{epoch+1}.pth')
        print(f"✓ Saved checkpoint: epoch_{epoch+1}")

print("\n" + "="*60)
print("Training Complete!")
print(f"Best validation loss: {best_val_loss:.6f}")
print("="*60)

## 10. Training Curve Visualization

Plot training and validation loss curves over all epochs.

**Analysis:**
- **Decreasing trend**: Both curves should trend downward
- **Convergence**: Curves should flatten near the end
- **Overfitting**: Val loss increasing while train loss decreases
- **Underfitting**: Both losses remain high
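
The checks above can also be applied programmatically to the `history` dictionary. The following is a minimal sketch under a simple heuristic: it flags possible overfitting when validation loss has not improved for `patience` epochs while training loss kept falling. `summarize_history` and the `patience` threshold are hypothetical, not part of the notebook.

```python
def summarize_history(history, patience=5):
    """Return the best epoch (1-based), its val loss, and a simple overfitting flag.

    The flag is True when validation loss has not improved during the last
    `patience` epochs while training loss decreased over that same window.
    """
    val = history["val_loss"]
    train = history["train_loss"]
    best_idx = min(range(len(val)), key=val.__getitem__)
    stalled = (len(val) - 1 - best_idx) >= patience
    train_still_falling = len(train) > patience and train[-1] < train[-1 - patience]
    return {
        "best_epoch": best_idx + 1,
        "best_val_loss": val[best_idx],
        "possible_overfitting": stalled and train_still_falling,
    }

# Illustrative, made-up loss values
demo = {
    "train_loss": [0.5, 0.4, 0.3, 0.25, 0.2, 0.15, 0.1, 0.05],
    "val_loss":   [0.5, 0.4, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4],
}
print(summarize_history(demo, patience=3))
```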

The plot is saved for documentation and model comparison.

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(history['train_loss'], label='Training Loss', marker='o', alpha=0.7)
plt.plot(history['val_loss'], label='Validation Loss', marker='s', alpha=0.7)
plt.xlabel('Epoch')
plt.ylabel('Loss (L1)')
plt.title('ConvKAN Training Curve')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(CHECKPOINT_DIR / 'training_curve.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Training curve saved: {CHECKPOINT_DIR / 'training_curve.png'}")

## 11. Training Summary

Generate and save a comprehensive training summary including:
- Model architecture details
- Hyperparameters
- Final performance metrics
- File paths to saved models

This summary is saved as a text file for experiment tracking and comparison with other models (e.g., U-Net).

In [None]:
# Load best model for final summary
checkpoint = torch.load(CHECKPOINT_DIR / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

summary = f"""
ConvKAN Super-Resolution Training Summary
{'='*60}

Model Architecture:
  Type: ConvKAN with Residual Blocks
  Base Filters: {BASE_FILTERS}
  Residual Blocks: {N_RES_BLOCKS}
  Total Parameters: {total_params:,}

Training Configuration:
  Epochs: {NUM_EPOCHS}
  Batch Size: {BATCH_SIZE}
  Learning Rate: {LEARNING_RATE}
  Loss Function: L1 Loss (MAE)
  Optimizer: AdamW with weight decay
  Mixed Precision: Enabled

Dataset:
  Total Images: {len(full_dataset)}
  Training: {train_size} images
  Validation: {val_size} images
  Input: 64×64 LR → bicubic 256×256
  Output: 256×256 HR

Final Results:
  Best Validation Loss: {best_val_loss:.6f}
  Best Epoch: {checkpoint['epoch'] + 1}
  Train Loss at Best Epoch: {checkpoint['train_loss']:.6f}

Saved Files:
  Best Model: {CHECKPOINT_DIR / 'best_model.pth'}
  Training Curve: {CHECKPOINT_DIR / 'training_curve.png'}
  
Notes:
  - Model processes bicubic-upsampled images (same as UNet)
  - No PixelShuffle - constant 256×256 resolution
  - Memory-optimized for GPU training
  - Ready for evaluation with evaluate_convkan.ipynb
"""

print(summary)

with open(CHECKPOINT_DIR / 'training_summary.txt', 'w', encoding='utf-8') as f:
    f.write(summary)

print(f"\nSummary saved: {CHECKPOINT_DIR / 'training_summary.txt'}")

## 12. Next Steps

Training is complete! To evaluate your model:

1. **Run the evaluation notebook**: `evaluate_convkan.ipynb`
   - Calculates PSNR and SSIM metrics
   - Generates visual comparisons (LR vs SR vs HR)
   - Side-by-side comparison with U-Net results

2. **Compare with U-Net**:
   - Check if ConvKAN achieves better PSNR/SSIM
   - Analyze visual quality differences
   - Consider training time vs performance trade-off

3. **Further Improvements** (if needed):
   - Increase `BASE_FILTERS` to 32 or 64 (if GPU memory allows)
   - Add more residual blocks (set `N_RES_BLOCKS` to 6-8)
   - Train for more epochs (100+)
   - Try learning rate scheduling
   - Experiment with perceptual loss instead of L1

**Model saved at:** `checkpoints_convkan/best_model.pth`
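
For reference, PSNR (mentioned in step 1) follows directly from the mean squared error. The sketch below uses the standard definition for images scaled to [0, 1], which matches the `ToTensor()` range used in this notebook. The evaluation notebook's own implementation is not shown here, so this is an illustration rather than its code.

```python
import math

def psnr(mse, max_value=1.0):
    """Peak signal-to-noise ratio in dB for a given mean squared error."""
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)

print(psnr(0.01))   # about 20 dB
print(psnr(0.001))  # about 30 dB
```

Each tenfold reduction in MSE adds 10 dB, which is why small PSNR gains between models can correspond to meaningful error reductions.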