<a href="https://colab.research.google.com/github/colincockburn/CISC_473_project/blob/main/image_restoration_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Prepare Environment
#
# Load machine-specific paths from a .env file, make the repo importable,
# and read the default training/export config.

import os
import sys
import yaml
from pathlib import Path
from dotenv import load_dotenv

# Location of the .env file. The default is the original author's machine;
# set the ENV_FILE environment variable to point elsewhere without editing the notebook.
ENV_FILE = os.getenv("ENV_FILE", "/home/colin/projects/CISC_473_project/.env")
load_dotenv(dotenv_path=ENV_FILE)

REPO_DIR = os.getenv("REPO_DIR")
DATA_ROOT = os.getenv("DATA_ROOT")
SAVE_DIR = os.getenv("SAVE_DIR")

# Fail fast with a clear message rather than building paths like "None/base_best.pth"
# or appending None to sys.path. (DATA_ROOT is only needed by the training cell,
# which checks it there.)
_missing_env = [name for name, value in (("REPO_DIR", REPO_DIR), ("SAVE_DIR", SAVE_DIR)) if not value]
if _missing_env:
    raise RuntimeError(f"Missing environment variable(s) {_missing_env}; check {ENV_FILE}")

CKPT_PATH = f"{SAVE_DIR}/base_best.pth"

# Guard so re-running this cell does not keep appending the same entry.
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

with open(Path(REPO_DIR) / "configs" / "default.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

In [2]:
# Check GPU-cuda functionality
#
# Report the PyTorch build and CUDA availability; when a GPU is present,
# run a small 1000x1000 matmul on it as a smoke test.

import torch

print("PyTorch version:", torch.__version__)
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)

if not cuda_ok:
    print("GPU NOT detected.")
else:
    print("Device:", torch.cuda.get_device_name(0))
    x = torch.rand(1000, 1000).cuda()  # sampled on CPU, then copied to the GPU
    y = x @ x
    print("GPU test OK, result:", y.sum().item())

PyTorch version: 2.5.1+cu121
CUDA available: True
Device: NVIDIA GeForce RTX 3070
GPU test OK, result: 250472336.0


In [None]:
# Train all 4 models
#
# src.train.main() is driven through sys.argv here, so each run temporarily
# replaces sys.argv. The original value is restored afterwards (even on error)
# so later cells, e.g. the ONNX export, don't inherit the last run's flags.

import sys
from pathlib import Path
from src.train import main as train_main  # aliased: the export cell imports a different `main`

if not DATA_ROOT:
    raise RuntimeError("DATA_ROOT is not set; check the .env file loaded in the first cell.")

Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)

# CLI arguments shared by every training run.
common_args = [
    "--data_root", DATA_ROOT,
    "--save_dir", SAVE_DIR,
    "--epochs", str(cfg["train"]["epochs"]),
    "--batch_size", str(cfg["data"]["batch_size"]),
    "--lr", str(cfg["train"]["lr"]),
    "--patch_size", str(cfg["data"]["patch_size"]),
]

# (model tag, variant-specific CLI flags) for each model to train.
# NOTE(review): "prune" enables channel + unstructured pruning while "qat_prune"
# enables channel pruning only — confirm this asymmetry is intended.
train_setups = [
    (cfg["saves"]["base_model_name"],      []),
    (cfg["saves"]["prune_model_name"],     ["--use_channel_prune", "--use_unstructured_prune"]),
    (cfg["saves"]["qat_model_name"],       ["--use_qat"]),
    (cfg["saves"]["qat_prune_model_name"], ["--use_qat", "--use_channel_prune"]),
]

_saved_argv = list(sys.argv)
try:
    for model_tag, extra_flags in train_setups:
        sys.argv = [
            "train.py",
            *common_args,
            "--model_tag", model_tag,
            *extra_flags,
        ]
        train_main()
finally:
    sys.argv = _saved_argv


In [4]:
# Export our trained models to ONNX format.
#
# The exporter runs with a clean sys.argv in case it parses command-line
# arguments. Otherwise it would see whatever the training cell or the kernel
# launcher left in sys.argv. The previous value is restored afterwards.

import sys
from src.onnx_exports import main as export_main  # aliased to avoid shadowing the training `main`

_saved_argv = list(sys.argv)
sys.argv = ["onnx_exports.py"]
try:
    export_main()
finally:
    sys.argv = _saved_argv


===== Exporting BASE =====
CKPT: /home/colin/projects/CISC_473_project/checkpoints/run1/base_best.pth


  ckpt = torch.load(ckpt_path, map_location="cpu")


[ONNX] Exported → /home/colin/projects/CISC_473_project/exported_models/base.onnx

===== Exporting PRUNE =====
CKPT: /home/colin/projects/CISC_473_project/checkpoints/run1/prune_best.pth
[Eval] Re-applying channel pruning in eval: ch_sparsity=0.2, steps=1




[ONNX] Exported → /home/colin/projects/CISC_473_project/exported_models/prune.onnx

===== Exporting QAT =====
CKPT: /home/colin/projects/CISC_473_project/checkpoints/run1/qat_best.pth


  ckpt = torch.load(ckpt_path, map_location="cpu")
  prepared = prepare(


[ONNX] Exported → /home/colin/projects/CISC_473_project/exported_models/qat.onnx

===== Exporting QAT_PRUNE =====
CKPT: /home/colin/projects/CISC_473_project/checkpoints/run1/qat_prune_best.pth
[Eval] Re-applying channel pruning in eval: ch_sparsity=0.2, steps=1


  ckpt = torch.load(ckpt_path, map_location="cpu")
  prepared = prepare(


[Eval] Re-applying channel pruning in eval: ch_sparsity=0.2, steps=1




[ONNX] Exported → /home/colin/projects/CISC_473_project/exported_models/qat_prune.onnx

All ONNX exports completed successfully.
