





# GETA Compressed Model Extractor

This notebook extracts compressed models from GETA checkpoints. The compression is performed with the Only-Train-Once (OTO) library, which integrates the GETA compression technique.

## Overview

1. Import required libraries
2. Define extraction function
3. Extract compressed model from checkpoint
4. Report model statistics (parameter counts, file sizes and, when supported, MACs)

In [None]:
# Import required libraries
import os
import sys
import torch
import argparse
from pathlib import Path

# Make the current working directory (expected to be the OpenGait repo root)
# importable so that `opengait` resolves. os.getcwd() is already absolute.
# The membership guard keeps re-runs of this cell from piling up duplicates.
_repo_root = os.getcwd()
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)

# Import OpenGait and Only-Train-Once modules. Every later cell depends on them,
# so stop here with a clear hint instead of failing later with NameErrors.
try:
    from opengait.modeling import models
    from only_train_once import OTO
    print("✓ Successfully imported OpenGait and OTO modules")
except ImportError as e:
    print(f"Error importing required modules: {e}")
    print("Make sure you are running this notebook from the OpenGait directory")
    raise

## Define Extraction Function

The following cell defines the `extract_compressed_model` function that:
1. Loads a GETA checkpoint
2. Creates the model based on configuration
3. Traces the model with a dummy input
4. Initializes OTO and constructs the compressed model
5. Compares the original and compressed models (parameter counts, file sizes and MACs)

In [None]:
def extract_compressed_model(checkpoint_path, output_dir='./compressed_models', 
                           model_name='GaitGLGeta', cfg_path='./configs/gaitgl/gaitgl_geta.yaml',
                           visualize=True, dummy_input=None):
    """
    Extract compressed model from a GETA checkpoint.

    Builds a fresh ``model_name`` instance from ``cfg_path`` and loads the
    checkpoint weights into it. It then attaches an OTO instance traced with a
    dummy input and calls ``model.construct_compressed_model`` to write the
    compressed model into ``output_dir``.

    Args:
        checkpoint_path: Path to the checkpoint file. Either a raw state dict or
            a dict holding the state dict under the ``'model'`` key.
        output_dir: Directory to save the compressed model (created if missing).
        model_name: Name of the model class in ``opengait.modeling.models``.
        cfg_path: Path to the model config file.
        visualize: Print a parameter / file-size / MACs comparison.
        dummy_input: Optional tracing input handed to OTO unchanged (it must
            already live on the right device). When ``None``, a random batch
            ``[sils, labs, typs, vies, seqL]`` is built on the selected device.

    Returns:
        Path to the compressed model if successful, None otherwise.

    Raises:
        Any exception raised while loading, tracing or compressing is printed
        and then re-raised.

    Security:
        Checkpoint and compressed-model files are fully unpickled
        (``weights_only=False``). Only use this function on files you trust.
    """
    print(f"Loading checkpoint from {checkpoint_path}")

    # Make sure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    try:
        # Training checkpoints may carry optimizer/scheduler objects next to the
        # weights, so load them fully rather than with weights_only=True.
        checkpoint = _torch_load_trusted(checkpoint_path, map_location='cpu')

        # Get model configuration
        from opengait.utils import config_loader
        cfgs = config_loader(cfg_path)

        # Create model instance
        print(f"Creating {model_name} instance")
        ModelClass = getattr(models, model_name)
        model = ModelClass(cfgs, training=False)

        # Load weights (either wrapped under 'model' or a bare state dict)
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        model.load_state_dict(state_dict)

        # Move model to GPU if available
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

        if dummy_input is None:
            print("Creating dummy input for model tracing")
            dummy_input = _make_dummy_input(device)

        # Create OTO instance and construct compressed model
        print("Initializing OTO and constructing compressed model")
        model.eval()  # Set to evaluation mode before tracing
        model.oto = OTO(model=model, dummy_input=dummy_input)
        compressed_model_path = model.construct_compressed_model(out_dir=output_dir)

        if not compressed_model_path:
            print("⚠ No compressed model was created, please check if GETA/HESSO was used during training")
            return None
        print(f"✓ Compressed model saved to: {compressed_model_path}")

        if visualize:
            _print_model_statistics(model, checkpoint_path, compressed_model_path,
                                    dummy_input, device)

        return compressed_model_path

    except Exception as e:
        print(f"Error extracting compressed model: {str(e)}")
        raise


def _torch_load_trusted(path, map_location):
    """Load a trusted file with torch.load, allowing arbitrary pickled objects.

    PyTorch 2.6 changed the default to ``weights_only=True``. That default
    rejects pickled ``nn.Module`` objects and non-allowlisted objects inside
    training checkpoints. This helper requests full unpickling explicitly
    whenever the installed torch accepts the argument. Never use it on
    untrusted files: unpickling can execute arbitrary code.
    """
    import inspect
    load_kwargs = {'map_location': map_location}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


def _make_dummy_input(device, batch_size=4, seq_len=30, height=64, width=44):
    """Build the random tracing input ``[sils, labs, typs, vies, seqL]`` on ``device``.

    NOTE(review): this layout is the one this notebook has always used. The
    silhouettes are (batch, seq, 1, H, W) and labels/types/views are zero
    longs. seqL is filled with ``seq_len``. Confirm it matches the target
    model's forward() input format.
    """
    torch.manual_seed(42)  # For reproducibility
    sils = torch.rand(batch_size, seq_len, 1, height, width).to(device)
    labs = torch.zeros(batch_size).long().to(device)
    typs = torch.zeros(batch_size).long().to(device)
    vies = torch.zeros(batch_size).long().to(device)
    seqL = torch.full((batch_size,), seq_len).long().to(device)
    return [sils, labs, typs, vies, seqL]


def _percent_reduction(before, after):
    """Percentage by which ``after`` is smaller than ``before``; NaN when ``before`` is 0."""
    return (1 - after / before) * 100 if before else float('nan')


def _print_model_statistics(model, checkpoint_path, compressed_model_path, dummy_input, device):
    """Print parameter, file-size and (when OTO supports it) MACs comparisons."""
    # map_location keeps the compressed model on the same device as dummy_input
    # so the MACs trace below does not hit a device mismatch.
    compressed_model = _torch_load_trusted(compressed_model_path, map_location=device)
    if not isinstance(compressed_model, torch.nn.Module):
        print(f"⚠ {compressed_model_path} does not contain a full nn.Module; skipping statistics")
        return
    compressed_model.eval()  # trace in the same mode as the original model

    # Count parameters
    original_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    compressed_params = sum(p.numel() for p in compressed_model.parameters() if p.requires_grad)
    ratio = compressed_params / original_params if original_params else float('nan')

    print("\n======= MODEL STATISTICS =======")
    print(f"Original model parameters: {original_params:,}")
    print(f"Compressed model parameters: {compressed_params:,}")
    print(f"Compression ratio: {ratio:.4f}")
    print(f"Parameter reduction: {_percent_reduction(original_params, compressed_params):.2f}%")

    # Calculate file sizes.
    # NOTE(review): the checkpoint may also hold optimizer/scheduler state, which
    # would inflate this "reduction" relative to the weights alone.
    checkpoint_size_mb = os.path.getsize(checkpoint_path) / (1024*1024)
    compressed_size_mb = os.path.getsize(compressed_model_path) / (1024*1024)

    print(f"\nOriginal checkpoint size: {checkpoint_size_mb:.2f} MB")
    print(f"Compressed model size: {compressed_size_mb:.2f} MB")
    print(f"File size reduction: {_percent_reduction(checkpoint_size_mb, compressed_size_mb):.2f}%")

    # Try to compute MACs/FLOPs if supported by OTO
    try:
        oto_original = OTO(model=model, dummy_input=dummy_input)
        oto_compressed = OTO(model=compressed_model, dummy_input=dummy_input)

        # Calculate MACs (Multiply-Accumulate Operations)
        original_macs = oto_original.compute_macs(in_million=True)['total']
        compressed_macs = oto_compressed.compute_macs(in_million=True)['total']

        print(f"\nOriginal model MACs: {original_macs:.2f}M")
        print(f"Compressed model MACs: {compressed_macs:.2f}M")
        print(f"MACs reduction: {_percent_reduction(original_macs, compressed_macs):.2f}%")
        print("==================================")
    except Exception as e:
        print(f"Could not compute MACs: {e}")

## Parse Command Line Arguments

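This notebook can be run directly from the Kaggle terminal using the `papermill` library, which overrides the defaults in the next cell with values passed as `-p NAME VALUE`. For this to work, the next code cell must carry the `parameters` cell tag. Papermill then injects its values in a new cell right after it. Without the tag, papermill inserts the injected cell at the top of the notebook, and the defaults in the next cell overwrite those values. Outside papermill, the same values can be passed as command-line options such as `--checkpoint PATH`.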

In [None]:
# Default notebook parameters (can be overridden with papermill or the CLI).
# NOTE: if this cell carries the papermill "parameters" tag, papermill injects
# its `-p` values in a new cell right AFTER this one. The checks printed at the
# end of this cell therefore reflect only the defaults and CLI values.
checkpoint_path = None  # Path to the checkpoint file
output_dir = "./compressed_models"
model_name = "GaitGLGeta"
cfg_path = "./configs/gaitgl/gaitgl_geta.yaml"
visualize = True


def parse_cli_overrides(argv=None):
    """Return only the parameters explicitly given on the command line.

    Every option defaults to ``argparse.SUPPRESS``, so options that were not
    passed are absent from the result and never overwrite the notebook defaults.
    Unknown arguments, such as the ``-f <connection file>`` a Jupyter kernel
    receives, are ignored via ``parse_known_args``. ``allow_abbrev=False``
    keeps kernel flags from being mistaken for abbreviated options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        dict mapping notebook variable names to the values supplied.

    Raises:
        SystemExit: raised by argparse on malformed input or ``-h``.
    """
    parser = argparse.ArgumentParser(
        description='Extract compressed model from GETA checkpoint',
        argument_default=argparse.SUPPRESS,
        allow_abbrev=False,
    )
    parser.add_argument('--checkpoint', dest='checkpoint_path', help='Path to the checkpoint file')
    parser.add_argument('--output-dir', dest='output_dir', help='Output directory for compressed model')
    parser.add_argument('--model-name', dest='model_name', help='Model class name')
    parser.add_argument('--cfg-path', dest='cfg_path', help='Path to config file')
    parser.add_argument('--visualize', dest='visualize', action='store_true', help='Print model statistics')
    parser.add_argument('--no-visualize', dest='visualize', action='store_false', help='Skip model statistics')
    args, _unknown = parser.parse_known_args(argv)
    return vars(args)


# Apply command-line overrides. Inside a Jupyter/papermill kernel sys.argv only
# holds kernel arguments, which are ignored, so the defaults above are kept.
try:
    cli_overrides = parse_cli_overrides()
except (Exception, SystemExit) as e:  # argparse signals bad input via SystemExit
    print(f"Failed to parse command line arguments: {e}")
    cli_overrides = {}

checkpoint_path = cli_overrides.get('checkpoint_path', checkpoint_path)
output_dir = cli_overrides.get('output_dir', output_dir)
model_name = cli_overrides.get('model_name', model_name)
cfg_path = cli_overrides.get('cfg_path', cfg_path)
visualize = cli_overrides.get('visualize', visualize)

# If no checkpoint path is provided, show a meaningful error
if not checkpoint_path:
    print("⚠️ Error: No checkpoint path provided!")
    print("Please provide a checkpoint path using: --checkpoint PATH_TO_CHECKPOINT")
    print("Example: python -m papermill extract_compressed.ipynb output.ipynb -p checkpoint_path /path/to/checkpoint.pt")

print(f"Checkpoint path: {checkpoint_path}")
print(f"Output directory: {output_dir}")
print(f"Model name: {model_name}")
print(f"Config path: {cfg_path}")
print(f"Visualize: {visualize}")

## Extract Compressed Model

Now let's run the extraction function with the provided parameters.

In [None]:
# Only run if a checkpoint path is provided. compressed_model_path is bound in
# every branch so later cells can inspect it without a NameError.
compressed_model_path = None
if checkpoint_path:
    try:
        # Run extraction function
        compressed_model_path = extract_compressed_model(
            checkpoint_path=checkpoint_path,
            output_dir=output_dir,
            model_name=model_name,
            cfg_path=cfg_path,
            visualize=visualize,
        )
    except Exception as e:
        print(f"\n❌ Error during model extraction: {str(e)}")
        # Re-raise so Run All stops and papermill reports the failure
        # (non-zero exit) instead of finishing as if extraction succeeded.
        raise

    if compressed_model_path:
        print(f"\n✅ Success! Compressed model extracted to: {compressed_model_path}")
    else:
        print("\n❌ Failed to extract compressed model.")
else:
    print("\n⚠️ Skipping extraction because no checkpoint path was provided.")

## How to Run

This notebook can be run in different ways:

### 1. From Jupyter notebook interface
Simply fill in the `checkpoint_path` and other parameters in the "Parse Command Line Arguments" cell and run all cells.

### 2. From terminal using papermill (recommended for Kaggle)
```bash
papermill extract_compressed.ipynb output.ipynb \
  -p checkpoint_path "/path/to/checkpoint.pt" \
  -p output_dir "/output/directory" \
  -p model_name "GaitGLGeta" \
  -p cfg_path "/path/to/config.yaml" \
  -p visualize True
```

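### 3. From Python, after exporting the notebook as a script
Python cannot import an `.ipynb` file directly. First export it with `jupyter nbconvert --to script extract_compressed.ipynb`, which writes `extract_compressed.py` next to the notebook. Note that importing that module also executes the notebook's top-level cells.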
```bash
python -c "import sys; sys.path.append('/kaggle/working/geta_gaitGL/OpenGait'); \
  from extract_compressed import extract_compressed_model; \
  extract_compressed_model('/path/to/checkpoint.pt', '/output/directory', 'GaitGLGeta', '/path/to/config.yaml', True)"
```

### Example for Kaggle

```bash
cd /kaggle/working/geta_gaitGL/OpenGait
papermill extract_compressed.ipynb output.ipynb \
  -p checkpoint_path "/kaggle/working/geta_gaitGL/OpenGait/output/CASIA-B/GaitGLGeta/GaitGL_GETA/checkpoints/GaitGL_GETA-80000.pt" \
  -p output_dir "/kaggle/working/compressed_models" \
  -p visualize True
```