<a href="https://colab.research.google.com/github/chirana07/KAIR/blob/master/SwinIR_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **SwinIR Fine-Tuning on LOL Dataset**
This notebook acts as the training environment for fine-tuning SwinIR on the Low-Light (LOL) dataset.

**Prerequisites:**
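1.  Upload your `lol_dataset.zip` to the top level of your Google Drive ("My Drive"). The zip should contain the `our485` and `eval15` folders directly at its root.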


In [None]:
# @title 1. Mount Google Drive
from google.colab import drive

# Later cells read the dataset zip from underneath this mount point.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
# @title 2. Setup Environment & Install Dependencies
!git clone https://github.com/cszn/KAIR.git
%cd /content/KAIR

# Install dependencies
!pip install --upgrade pip
!pip install matplotlib scikit-image opencv-python tensorboard einops timm hdf5storage ninja

In [None]:
# @title 3. Load Dataset from Drive
import zipfile
import os

# Adjust this path if you saved the zip elsewhere in your Drive
zip_path = '/content/drive/MyDrive/lol_dataset.zip'

# Fail loudly when the zip is missing: merely printing would let "Run all"
# continue into training with no dataset in place.
if not os.path.exists(zip_path):
    raise FileNotFoundError(f"ERROR: Could not find {zip_path}. Please check the path.")

print("Found dataset zip! Extracting...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    # Extract to /content/KAIR/dataset
    # Since the zip contains direct folders (our485, eval15), they will appear there.
    zip_ref.extractall('/content/KAIR/dataset')
print("Extraction complete.")

# The training configuration cell points at these folders; warn early if the zip
# had a different layout (for example an extra top-level wrapper folder).
for expected_dir in ('our485/high', 'our485/low', 'eval15/high', 'eval15/low'):
    expected_path = os.path.join('/content/KAIR/dataset', expected_dir)
    if not os.path.isdir(expected_path):
        print(f"WARNING: expected folder {expected_path} was not found after extraction.")

In [None]:
# @title 4. Download Pre-trained SwinIR Model
import os
import requests

model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth"
model_path = "/content/KAIR/model_zoo/005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth"

os.makedirs("/content/KAIR/model_zoo", exist_ok=True)

if not os.path.exists(model_path):
    print("Downloading pre-trained model...")
    # Write to a temporary name and rename only after a complete, successful
    # download. Otherwise an HTTP error page or a truncated transfer would be
    # saved as the .pth file and later runs would report "Model already exists."
    tmp_path = model_path + ".part"
    try:
        with requests.get(model_url, allow_redirects=True, stream=True, timeout=60) as r:
            # Raise on 4xx/5xx instead of saving the error body as the checkpoint.
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                # Stream in 1 MiB chunks rather than holding the whole file in memory.
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, model_path)
    finally:
        # Clean up the partial file if anything above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download complete.")
else:
    print("Model already exists.")

In [None]:
# @title 5. Create Training Configuration (JSON)
import json

# The options are assembled section by section and then written out as the JSON
# file that the training cell passes to main_train_psnr.py via --opt.

# Output root and pretrained weights (the checkpoint fetched by the download cell).
path_options = {
  "root": "denoising",
  "pretrained_netG": "/content/KAIR/model_zoo/005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth",
  "pretrained_netE": None,
}

# Training pairs: the our485 high/low folders extracted from the dataset zip.
train_dataset_options = {
  "name": "train_dataset",
  "dataset_type": "plain",
  "dataroot_H": "/content/KAIR/dataset/our485/high",
  "dataroot_L": "/content/KAIR/dataset/our485/low",
  "H_size": 128,
  "sigma": 15,
  "sigma_test": 15,
  "dataloader_shuffle": True,
  "dataloader_num_workers": 2,
  "dataloader_batch_size": 2,
}

# Evaluation pairs: the eval15 high/low folders extracted from the dataset zip.
test_dataset_options = {
  "name": "test_dataset",
  "dataset_type": "plain",
  "dataroot_H": "/content/KAIR/dataset/eval15/high",
  "dataroot_L": "/content/KAIR/dataset/eval15/low",
  "sigma": 15,
  "sigma_test": 15,
}

# Network definition; presumably this has to match the architecture of the
# pretrained checkpoint named above -- verify before changing any value.
generator_options = {
  "net_type": "swinir",
  "upscale": 1,
  "in_chans": 3,
  "img_size": 128,
  "window_size": 8,
  "img_range": 1.0,
  "depths": [6, 6, 6, 6, 6, 6],
  "embed_dim": 180,
  "num_heads": [6, 6, 6, 6, 6, 6],
  "mlp_ratio": 2,
  "upsampler": None,
  "resi_connection": "1conv",
  "init_type": "default",
}

# Loss, optimizer, LR schedule and checkpoint intervals.
training_options = {
  "G_lossfn_type": "charbonnier",
  "G_lossfn_weight": 1.0,
  "G_charbonnier_eps": 1e-9,
  "E_decay": 0.999,
  "G_optimizer_type": "adam",
  "G_optimizer_lr": 2e-5,
  "G_optimizer_wd": 0,
  "G_optimizer_clipgrad": None,
  "G_optimizer_reuse": True,
  "G_scheduler_type": "MultiStepLR",
  "G_scheduler_milestones": [200000, 400000, 600000],
  "G_scheduler_gamma": 0.5,
  "G_param_strict": True,
  "E_param_strict": True,
  "checkpoint_test": 500,
  "checkpoint_save": 500,
  "checkpoint_print": 100,
}

# Top-level options; key order is kept so the written JSON is unchanged.
config = {
  "task": "swinir_fine_tune_lol",
  "model": "plain",
  "gpu_ids": [0],
  "dist": False,
  "n_channels": 3,
  "path": path_options,
  "datasets": {
    "train": train_dataset_options,
    "test": test_dataset_options,
  },
  "netG": generator_options,
  "train": training_options,
}

# Serialize first, then write the text in one go.
config_text = json.dumps(config, indent=2)
with open('/content/KAIR/options/train_swinir_colab.json', 'w') as config_file:
    config_file.write(config_text)

print("Config file created!")

In [None]:
# @title 6. Run Training
!python main_train_psnr.py --opt options/train_swinir_colab.json