<a href="https://colab.research.google.com/github/colincockburn/CISC_473_project/blob/main/image_restoration_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Prepare Environment
#
# Loads the project's .env file, resolves the directories used by the rest of
# the notebook (REPO_DIR, DATA_ROOT, SAVE_DIR), makes the repository importable,
# and parses configs/default.yaml into `cfg`.

import sys
import yaml
import os
from dotenv import find_dotenv, load_dotenv

# Where to read the .env file from:
#   1. DOTENV_PATH, if already set in the process environment;
#   2. the original developer path, when that file exists on this machine;
#   3. otherwise, the nearest .env found by searching upward from the CWD.
DEFAULT_DOTENV_PATH = "/home/colin/projects/CISC_473_project/.env"
dotenv_path = os.getenv("DOTENV_PATH")
if not dotenv_path:
    if os.path.isfile(DEFAULT_DOTENV_PATH):
        dotenv_path = DEFAULT_DOTENV_PATH
    else:
        dotenv_path = find_dotenv(usecwd=True)
load_dotenv(dotenv_path=dotenv_path)

REPO_DIR = os.getenv("REPO_DIR")
DATA_ROOT = os.getenv("DATA_ROOT")
SAVE_DIR = os.getenv("SAVE_DIR")

# Fail early with a clear message: without this check a missing variable is
# silently turned into the string "None" inside the paths built below.
missing_vars = [
    name
    for name, value in (
        ("REPO_DIR", REPO_DIR),
        ("DATA_ROOT", DATA_ROOT),
        ("SAVE_DIR", SAVE_DIR),
    )
    if not value
]
if missing_vars:
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(missing_vars)
        + f" (looked for a .env file at {dotenv_path!r})"
    )

CKPT_PATH = f"{SAVE_DIR}/best.pth"

# Guard against duplicate entries when this cell is re-run in the same kernel.
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

with open(f"{REPO_DIR}/configs/default.yaml", "r") as f:
    cfg = yaml.safe_load(f)

In [2]:
# Check GPU-cuda functionality
#
# Prints the installed PyTorch version and whether CUDA is usable. When a CUDA
# device is available, a 1000x1000 matrix product is computed on it as a smoke
# test and the sum of the result is printed.

import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

if not torch.cuda.is_available():
    print("GPU NOT detected.")
else:
    print("Device:", torch.cuda.get_device_name(0))
    # Sample on the CPU first, then move to the GPU (same as the `.cuda()` call).
    x = torch.rand(1000, 1000).to("cuda")
    y = x @ x
    print("GPU test OK, result:", y.sum().item())

PyTorch version: 2.5.1+cu121
CUDA available: True
Device: NVIDIA GeForce RTX 3070
GPU test OK, result: 249901920.0


In [3]:
# Train all 4 models
#
# Calls the training entry point (src.train.main) once per configuration in
# `train_setups`. Arguments are passed by rewriting sys.argv before each call
# (presumably main() parses sys.argv with argparse; confirm in src/train.py).
# Relies on SAVE_DIR, DATA_ROOT and cfg from the environment-setup cell.

import sys
from pathlib import Path
from src.train import main

Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)

# Flags shared by every training run; values come from the YAML config.
common_args = [
    "--data_root", DATA_ROOT,
    "--save_dir", SAVE_DIR,
    "--epochs", str(cfg["train"]["epochs"]),
    "--batch_size", str(cfg["data"]["batch_size"]),
    "--lr", str(cfg["train"]["lr"]),
]


# (model tag, extra CLI flags) for each variant to train.
train_setups = [
    (cfg["saves"]["base_model_name"],      []),
    (cfg["saves"]["prune_model_name"],     ["--use_channel_prune", "--use_unstructured_prune"]),
    (cfg["saves"]["qat_model_name"],       ["--use_qat"]),
    (cfg["saves"]["qat_prune_model_name"], ["--use_qat", "--use_channel_prune"]),
]


# sys.argv is process-wide state shared by every cell in this kernel. Save it
# and restore it afterwards (even if a run raises) so the training flags do not
# leak into later cells.
original_argv = sys.argv
try:
    for model_tag, extra_flags in train_setups:
        sys.argv = [
            "train.py",
            *common_args,
            "--patch_size", str(cfg["data"]["patch_size"]),
            "--model_tag", model_tag,
            *extra_flags,
        ]
        main()
finally:
    sys.argv = original_argv


===== training base =====

starting loop


                                                                                   

[001/1] train: loss=0.043335, psnr=14.29 | val: loss=0.011210, psnr=19.70 | time=15.9s
  ↳ new best (19.70 dB) saved to /home/colin/projects/CISC_473_project/checkpoints/run1/base_best.pth
===== training prune =====





[ChannelPruning] Starting channel pruning with ch_sparsity=0.1, iterative_steps=1
[ChannelPruning] Pruning complete. Model now has fewer channels.
starting loop


                                                                                   

[001/1] train: loss=0.043392, psnr=14.32 | val: loss=0.013061, psnr=19.09 | time=15.7s
[UnstructuredPruning] Applying L1 unstructured pruning with amount=0.5
  ↳ new best (19.09 dB) saved to /home/colin/projects/CISC_473_project/checkpoints/run1/prune_best.pth
===== training qat =====



  prepared = prepare(


starting loop


                                                                                   

[001/1] train: loss=0.066672, psnr=12.24 | val: loss=0.013834, psnr=18.79 | time=14.4s
  ↳ new best (18.79 dB) saved to /home/colin/projects/CISC_473_project/checkpoints/run1/qat_best.pth
[info] quantized INT8 model saved to /home/colin/projects/CISC_473_project/checkpoints/run1/qat_int8_final.pth
===== training qat_prune =====

[ChannelPruning] Starting channel pruning with ch_sparsity=0.1, iterative_steps=1
[ChannelPruning] Pruning complete. Model now has fewer channels.


  prepared = prepare(


starting loop


                                                                                   

[001/1] train: loss=0.071106, psnr=11.85 | val: loss=0.014349, psnr=18.72 | time=13.7s
  ↳ new best (18.72 dB) saved to /home/colin/projects/CISC_473_project/checkpoints/run1/qat_prune_best.pth
[info] quantized INT8 model saved to /home/colin/projects/CISC_473_project/checkpoints/run1/qat_prune_int8_final.pth


In [4]:
# Export our models to ONNX format.
#
# The entry point is imported under an alias so this cell does not rebind the
# name `main`, which the training cell imports from src.train.
#
# NOTE(review): the saved output of this cell ends with
#   TypeError: UNetDenoise.__init__() got an unexpected keyword argument 'use_pruning'
# while exporting the base model. Presumably the export code (or the model
# kwargs it passes) is out of sync with UNetDenoise's constructor. Confirm and
# fix in the project sources before relying on this cell.
from src.onnx_exports import main as export_to_onnx

export_to_onnx()


===== Exporting base from /home/colin/projects/CISC_473_project/checkpoints/run1/base_best.pth =====


  ckpt = torch.load(ckpt_path, map_location="cpu")


TypeError: UNetDenoise.__init__() got an unexpected keyword argument 'use_pruning'