# PyTorch Pre-trained Model Analysis for EchoGAINS (Colab Version)

This Colab notebook will:
- Set up the environment with all necessary dependencies
- Analyze the pre-trained models:
  - `CAMUS_diffusion_model.pt` (nnU-Net segmentation model)
  - `checkpoint_best.pth` (Diffusion model)
- Print model architecture, layer details, parameter counts, input/output shapes, and checkpoint metadata

You will need to upload the `.pt` and `.pth` files to the Colab session, or mount Google Drive.

In [None]:
# 1. ENVIRONMENT SETUP
!pip install torch torchvision --quiet
!pip install torchinfo --quiet
# Optional: install nnU-Net v2 (the package that provides nnUNetPredictor) and guided-diffusion to load the actual architectures
# !pip install nnunetv2 --quiet
# !pip install git+https://github.com/openai/guided-diffusion.git --quiet

import torch
from torchinfo import summary
import os
import pprint

In [None]:
# 2. UPLOAD MODEL FILES
from google.colab import files

print("Please upload 'CAMUS_diffusion_model.pt' and 'checkpoint_best.pth' from your computer, or skip if using Google Drive.")
uploaded = files.upload()

In [None]:
# OPTIONAL: MOUNT GOOGLE DRIVE IF FILES ARE THERE
# from google.colab import drive
# drive.mount('/content/drive')
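
Whether the files were uploaded or already sit on Drive, the later cells need a valid path to each one. The following is a minimal sketch of a lookup helper. `find_checkpoint` and its default search directories are hypothetical names introduced here, and `/content/drive/MyDrive` assumes Drive was mounted at `/content/drive` as in the commented cell above.

```python
import os

def find_checkpoint(filename, search_dirs=('.', '/content/drive/MyDrive')):
    """Return the first existing path to `filename` in `search_dirs`, else None."""
    for directory in search_dirs:
        candidate = os.path.join(directory, filename)
        if os.path.isfile(candidate):
            return candidate
    return None

# Report where each expected file was found (if anywhere).
for name in ('CAMUS_diffusion_model.pt', 'checkpoint_best.pth'):
    path = find_checkpoint(name)
    print(f"{name}: {path if path else 'not found (upload it or mount Drive)'}")
```

If a file is reported as not found, adjust `search_dirs` to the folder that actually holds it before running the analysis cells.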

## 3. Model Class Placeholders (must be filled in with real code)
To load the weights into actual models, fill in the real nnU-Net and diffusion model class definitions here.

- For nnU-Net: see [nnU-Net repo](https://github.com/MIC-DKFZ/nnUNet).
- For guided-diffusion: see the UNetModel in [guided-diffusion](https://github.com/openai/guided-diffusion).

_Without these definitions you can still inspect the checkpoint files (keys, metadata, and the names and shapes of the stored tensors), but you cannot instantiate the models, so `torchinfo` cannot produce a layer-by-layer architecture summary._

In [None]:
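# Placeholders: leave these as None to inspect the checkpoint files only.
nnUNetPredictor = None
UNetModel = None
# TODO: to fully load the models, import (or paste) the real classes below, for example:
# from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
# from guided_diffusion.unet import UNetModel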

## 4. Checkpoint Analysis Helper
`analyze_checkpoint` loads a checkpoint onto the CPU, prints its top-level keys, and prints a short view of every entry other than the model weights.

In [None]:
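def analyze_checkpoint(path):
    """Load a checkpoint on the CPU and print its top-level keys and metadata."""
    print(f"\nLoading checkpoint: {path}")
    # weights_only=False is needed for checkpoints that store more than tensors
    # (PyTorch >= 2.6 defaults to True). Only load files from a source you trust.
    checkpoint = torch.load(path, map_location='cpu', weights_only=False)
    if isinstance(checkpoint, dict):
        print("Checkpoint keys:", list(checkpoint.keys()))
        for key, val in checkpoint.items():
            if key in ('state_dict', 'model_state_dict', 'network_weights'):
                continue  # common weight containers; too large to print
            if torch.is_tensor(val):
                # A checkpoint that is itself a raw state dict lands here.
                print(f"{key}: tensor {tuple(val.shape)} {val.dtype}")
            elif isinstance(val, dict):
                print(f"{key}: [dict with {len(val)} entries]")
            else:
                print(f"{key}: {val}")
    else:
        print("Checkpoint is not a dictionary! Type:", type(checkpoint))
    return checkpoint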
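
Beyond listing keys, the stored weights alone are enough to recover parameter names, shapes, and a total element count, with no model class required. The sketch below shows one way to do that. `extract_state_dict`, `summarize_state_dict`, and the `WEIGHT_KEYS` tuple are hypothetical helpers introduced here, and the container key names are assumptions about how a given checkpoint was saved.

```python
import torch

# Assumed container keys; a given checkpoint may use a different one.
WEIGHT_KEYS = ('state_dict', 'model_state_dict', 'network_weights')

def extract_state_dict(checkpoint):
    """Return the name -> tensor mapping stored in a checkpoint, or None."""
    if not isinstance(checkpoint, dict):
        return None
    for key in WEIGHT_KEYS:
        if isinstance(checkpoint.get(key), dict):
            return checkpoint[key]
    # Some checkpoints are the state dict itself (every value is a tensor).
    if checkpoint and all(torch.is_tensor(v) for v in checkpoint.values()):
        return checkpoint
    return None

def summarize_state_dict(state_dict, max_rows=10):
    """Print the first few tensors and return the total element count."""
    total = sum(t.numel() for t in state_dict.values() if torch.is_tensor(t))
    for i, (name, t) in enumerate(state_dict.items()):
        if i >= max_rows:
            print(f"... ({len(state_dict) - max_rows} more tensors)")
            break
        print(f"{name:50s} {tuple(t.shape)}")
    print(f"Total elements: {total:,}")
    return total

# Self-contained demo on a toy checkpoint (not one of the real files).
toy = {'epoch': 3, 'model_state_dict': torch.nn.Linear(4, 2).state_dict()}
total = summarize_state_dict(extract_state_dict(toy))
```

On a real file, pass the dictionary returned by `analyze_checkpoint` to `extract_state_dict`. The total includes buffers such as normalization running statistics, so it can slightly exceed the trainable parameter count.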

## 5. Analyze nnU-Net Segmentation Model (`CAMUS_diffusion_model.pt`)
If you don't have the model class, you can still analyze the checkpoint keys.

In [None]:
nnunet_ckpt_path = 'CAMUS_diffusion_model.pt'

nnunet_checkpoint = analyze_checkpoint(nnunet_ckpt_path)

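# If nnU-Net v2 is installed, you could do something like the following.
# nnUNetPredictor loads from a trained-model folder (with plans.json and dataset.json),
# not from a single file; `model_folder` is a placeholder for that folder.
# if nnUNetPredictor is not None:
#     predictor = nnUNetPredictor(device=torch.device('cpu'))
#     predictor.initialize_from_trained_model_folder(
#         model_folder, use_folds=(0,), checkpoint_name=os.path.basename(nnunet_ckpt_path))
#     model = predictor.network
#     summary(model, input_size=(1, 1, 256, 256))  # input size is an assumption
if nnUNetPredictor is None:
    print("[INFO] nnU-Net class not set up. Add code to load and analyze the model architecture.")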

## 6. Analyze Diffusion Model (`checkpoint_best.pth`)
If you don't have the model class, you can still analyze the checkpoint keys.

In [None]:
diffusion_ckpt_path = 'checkpoint_best.pth'

diffusion_checkpoint = analyze_checkpoint(diffusion_ckpt_path)

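# If you have UNetModel code, you could do something like:
# if UNetModel is not None:
#     model = UNetModel(image_size=256, in_channels=1, out_channels=1, ...)
#     # Some checkpoints wrap the weights under a key; others are the state dict itself.
#     state = diffusion_checkpoint.get('model_state_dict', diffusion_checkpoint)
#     model.load_state_dict(state)
#     # The guided-diffusion UNet forward takes (x, timesteps), so pass example inputs.
#     summary(model, input_data=(torch.randn(1, 1, 256, 256), torch.zeros(1, dtype=torch.long)))
if UNetModel is None:
    print("[INFO] Diffusion UNetModel class not set up. Add code to load and analyze the model architecture.")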
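
Even without the model class, the stored convolution kernels hint at the input and output channel counts, because a standard PyTorch convolution weight has shape `(out_channels, in_channels, *kernel_size)`. The sketch below reads those hints from a state dict. `conv_channel_hints` is a hypothetical helper, and it is only a heuristic: it assumes the first and last convolution weights in the state dict belong to the actual input and output layers, and it will mislead for transposed or grouped convolutions.

```python
import torch
from torch import nn

def conv_channel_hints(state_dict):
    """Guess (in_channels, out_channels) from the first and last conv-like weights.

    Heuristic: treats any tensor named '*weight' with 3 or more dimensions as a
    standard convolution kernel of shape (out_channels, in_channels, *kernel).
    """
    convs = [(n, t) for n, t in state_dict.items()
             if n.endswith('weight') and torch.is_tensor(t) and t.dim() >= 3]
    if not convs:
        return None
    (first_name, first_w), (last_name, last_w) = convs[0], convs[-1]
    print(f"first conv: {first_name} {tuple(first_w.shape)}")
    print(f"last conv:  {last_name} {tuple(last_w.shape)}")
    return first_w.shape[1], last_w.shape[0]

# Toy network standing in for a real checkpoint's state dict.
toy = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 4, 1))
hints = conv_channel_hints(toy.state_dict())
print("in/out channel guess:", hints)
```

Treat the result as a starting point for choosing constructor arguments such as `in_channels` and `out_channels`, and confirm it by loading the weights into the real class.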

## 7. Tips
- If you want to use actual model class code, copy/paste or import the class definitions for nnU-Net and UNetModel.
- Use `summary(model, input_size=...)` for a detailed breakdown of layers, shapes, and parameter counts.
- For visualization, you can export PyTorch models to ONNX and view with [Netron](https://netron.app/).

For further analysis, refer to the [torchinfo](https://github.com/TylerYep/torchinfo) or [pytorch_model_summary](https://github.com/amarczew/pytorch_model_summary) documentation.