# JERS (Joint Embedding for Radiology and Surgery) Tutorial

This notebook demonstrates how to use the JERS modules in the `medicalproject2024` project for brain imaging analysis with joint embedding techniques.

## Overview

The JERS module provides functionality for:
1. **Brain Image Processing**: Specialized processing for brain imaging data
2. **Joint Embedding**: Links radiology and surgical information through shared representations
3. **BraTS Integration**: Works with the BraTS brain tumor segmentation dataset
4. **Pre-trained Models**: Ready-to-use model weights and templates

## Table of Contents
1. [Setup and Environment Check](#setup)
2. [BraTS Dataset Loading](#brats-loading)
3. [JERS Preprocessing](#jers-preprocessing)
4. [Model Usage](#model-usage)

## 1. Setup and Environment Check {#setup}

In [2]:
# Check if the JERS modules exist and are accessible
import os
import sys
import warnings
from pathlib import Path
import numpy as np
import torch
import pandas as pd

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Get the current notebook directory
notebook_dir = Path.cwd()
project_root = notebook_dir.parent.parent  # assumes the notebook lives in medicalproject2024/tutorial/
jers_path = project_root / "medicalproject2024" / "preprocess" / "JERS"

print(f"Notebook directory: {notebook_dir}")
print(f"Project root: {project_root}")
print(f"JERS path: {jers_path}")
print(f"JERS path exists: {jers_path.exists()}")

if jers_path.exists():
    print("JERS modules found!")
    
    # Check available files
    files = {
        "inference.py": jers_path / "inference.py",
        "model.py": jers_path / "model.py", 
        "utils.py": jers_path / "utils.py",
        "checkpoints/": jers_path / "checkpoints"
    }
    
    print("\nAvailable files:")
    for file_name, file_path in files.items():
        status = "Exists" if file_path.exists() else "Not Found"
        print(f"  {status} {file_name}")
        
else:
    print("JERS modules not found!")

# Add project root to Python path
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
    print(f"\nPython path updated with: {project_root}")



Notebook directory: /home/tjl20001104/workspace/Projects/USC/biobank/hugging-health/medicalproject2024/tutorial
Project root: /home/tjl20001104/workspace/Projects/USC/biobank/hugging-health
JERS path: /home/tjl20001104/workspace/Projects/USC/biobank/hugging-health/medicalproject2024/preprocess/JERS
JERS path exists: True
JERS modules found!

Available files:
  Exists inference.py
  Exists model.py
  Exists utils.py
  Exists checkpoints/

Python path updated with: /home/tjl20001104/workspace/Projects/USC/biobank/hugging-health

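The file check in the setup cell can be factored into a reusable helper. A minimal sketch; `check_jers_layout` is a hypothetical name, not part of the JERS package, and it only tests for the four entries listed above (a trailing `/` marks a directory).

```python
import tempfile
from pathlib import Path

# Entries the setup cell looks for; a trailing "/" marks a directory.
EXPECTED_ENTRIES = ("inference.py", "model.py", "utils.py", "checkpoints/")


def check_jers_layout(jers_dir, expected=EXPECTED_ENTRIES):
    """Return {entry: bool} telling whether each expected entry exists.

    Hypothetical helper (not part of the JERS package). Entries ending in
    "/" must be directories; all other entries must be regular files.
    """
    jers_dir = Path(jers_dir)
    status = {}
    for entry in expected:
        if entry.endswith("/"):
            status[entry] = (jers_dir / entry.rstrip("/")).is_dir()
        else:
            status[entry] = (jers_dir / entry).is_file()
    return status


if __name__ == "__main__":
    # Demo on a throwaway directory that holds only part of the layout.
    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "model.py").write_text("# placeholder\n")
        (Path(tmp) / "checkpoints").mkdir()
        for name, ok in check_jers_layout(tmp).items():
            print(f"  {'Exists' if ok else 'Not Found'} {name}")
```

Checking `is_dir()` / `is_file()` rather than `exists()` also catches the case where, say, `checkpoints` is accidentally a file.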

In [3]:
# Test importing JERS modules
def test_jers_imports():
    """Test importing JERS modules"""
    print("Testing JERS Module Imports:")
    print("-" * 40)
    
    modules_to_test = [
        ("medicalproject2024.preprocess.JERS.model", "JERS Model"),
        ("medicalproject2024.preprocess.JERS", "JERS Main"),
        ("medicalproject2024.dataLoader.segmentation.BraTS", "BraTS Loader"),
    ]
    
    import_results = {}
    
    for module_path, module_name in modules_to_test:
        try:
            __import__(module_path, fromlist=[''])
            import_results[module_name] = True
            print(f"{module_name:15}: Successfully imported")
        except ImportError as e:
            import_results[module_name] = False
            print(f"{module_name:15}: Import failed - {str(e)[:50]}...")
        except Exception as e:
            import_results[module_name] = False
            print(f"{module_name:15}: Warning - {str(e)[:50]}...")
    
    return import_results

# Run import tests
import_results = test_jers_imports()

Testing JERS Module Imports:
----------------------------------------
JERS Model     : Successfully imported
JERS Main      : Successfully imported
BraTS Loader   : Successfully imported

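`__import__(path, fromlist=[''])` returns the leaf module; the standard-library `importlib.import_module` does the same by dotted name and reads more clearly. A sketch of the same check that also keeps each error message for later display; `try_imports` is a hypothetical helper, demonstrated on a stdlib module and a deliberately missing one.

```python
import importlib


def try_imports(modules):
    """Import each dotted module path and report success per display name.

    `modules` is a list of (module_path, display_name) pairs, as in
    test_jers_imports(). Returns {display_name: (ok, message)}.
    """
    results = {}
    for module_path, name in modules:
        try:
            importlib.import_module(module_path)
            results[name] = (True, "imported")
        except ImportError as exc:
            # Also covers ModuleNotFoundError, which subclasses ImportError.
            results[name] = (False, f"ImportError: {exc}")
        except Exception as exc:
            # The module was found but raised while executing its body.
            results[name] = (False, f"{type(exc).__name__}: {exc}")
    return results


if __name__ == "__main__":
    demo = [("json", "stdlib json"), ("no_such_pkg_xyz", "missing pkg")]
    for name, (ok, msg) in try_imports(demo).items():
        print(f"{name:15}: {'OK' if ok else 'FAILED'} - {msg[:50]}")
```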

## 2. BraTS Dataset Loading {#brats-loading}

Load the BraTS dataset with `getBraTS`, keep only the T1 image paths, and select the first three samples for the demo.

In [None]:
def load_brats_dataset():
    """Load BraTS dataset as in test_jers.py"""
    print("Loading BraTS Dataset:")
    print("=" * 30)
    
    try:
        from medicalproject2024.dataLoader.segmentation.BraTS import getBraTS
        print("Successfully imported BraTS loader")
        
        # Load BraTS data
        print("\nLoading BraTS data...")
        brats_data = getBraTS("data")
        
        if brats_data:
            dataset = pd.DataFrame(brats_data)
            dataset["dataset_name"] = "BraTS"
            dataset = dataset.rename(columns={"t1": "raw_img_path"})
            dataset = dataset[["raw_img_path", "dataset_name"]]
            
            print(f"BraTS dataset loaded: {len(dataset)} samples")
            print(f"   Columns: {list(dataset.columns)}")
            
            dataset = dataset.iloc[:3]
            print(f"\nUsing first {len(dataset)} samples for processing:")
            for i, row in dataset.iterrows():
                print(f"   Sample {i+1}: {Path(row['raw_img_path']).name}")
            
            return dataset
        else:
            print("No BraTS data returned")
            return None
            
    except ImportError as e:
        print(f"Failed to import BraTS loader: {e}")
        return None
    except Exception as e:
        print(f"Error loading BraTS data: {e}")
        return None

# Load BraTS dataset
if import_results.get("BraTS Loader", False):
    brats_dataset = load_brats_dataset()
else:
    print("Skipping BraTS loading due to import failure")
    brats_dataset = None

Found local copy...


Loading BraTS Dataset:
Successfully imported BraTS loader

Loading BraTS data...
BraTS dataset loaded: 1251 samples
   Columns: ['raw_img_path', 'dataset_name']

Using first 3 samples for processing:
   Sample 1: BraTS2021_00402_t1.nii.gz
   Sample 2: BraTS2021_01003_t1.nii.gz
   Sample 3: BraTS2021_00757_t1.nii.gz

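The loading cell reduces the loader output to a two-column table. The sketch below repeats that reshaping on fabricated records so it runs without the dataset. It assumes, as the `rename` call implies, that `getBraTS` yields dict-like records with a `"t1"` path; the `"seg"` key and the file names are made up for illustration, and `to_jers_table` is hypothetical.

```python
import pandas as pd


def to_jers_table(records, dataset_name="BraTS", n_samples=3):
    """Turn loader records into the (raw_img_path, dataset_name) table.

    Assumes each record is a dict with a "t1" image path, as implied by
    load_brats_dataset(); any other keys are dropped.
    """
    df = pd.DataFrame(records)
    df["dataset_name"] = dataset_name
    df = df.rename(columns={"t1": "raw_img_path"})
    df = df[["raw_img_path", "dataset_name"]]
    # Keep the first n rows and renumber them from 0.
    return df.iloc[:n_samples].reset_index(drop=True)


if __name__ == "__main__":
    fake = [
        {"t1": f"/data/BraTS/case_{i:03d}_t1.nii.gz",
         "seg": f"/data/BraTS/case_{i:03d}_seg.nii.gz"}
        for i in range(5)
    ]
    print(to_jers_table(fake))
```

`reset_index(drop=True)` makes the `i+1` sample numbering start at 1 even if the loader's index is not a plain range.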

## 3. JERS Preprocessing {#jers-preprocessing}

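Run `JERS_preprocess` exactly as in `test_jers.py`. It takes a dict mapping each dataset name to its DataFrame (here `{"BraTS": dataset}`) and an output directory (`data`).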
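Because the error branch of the preprocessing cell points to missing input files as a likely cause, it can help to validate the input dict before calling `JERS_preprocess`. A sketch of such a pre-flight check; `validate_data_dict` is hypothetical and only checks the two columns this notebook builds. Whether `JERS_preprocess` reads any other columns is not shown here.

```python
from pathlib import Path

import pandas as pd

REQUIRED_COLUMNS = ("raw_img_path", "dataset_name")


def validate_data_dict(data_dict, check_files=False):
    """Check a {dataset_name: DataFrame} mapping before preprocessing.

    Hypothetical pre-flight check; column names mirror the table built in
    load_brats_dataset(). Returns a list of problems (empty if none).
    """
    problems = []
    for name, df in data_dict.items():
        if not isinstance(df, pd.DataFrame):
            problems.append(f"{name}: expected DataFrame, got {type(df).__name__}")
            continue
        missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
        if missing:
            problems.append(f"{name}: missing columns {missing}")
            continue
        if df.empty:
            problems.append(f"{name}: no rows")
        if check_files:
            # Optional and slower: touch the filesystem for every path.
            absent = [p for p in df["raw_img_path"] if not Path(p).exists()]
            if absent:
                problems.append(f"{name}: {len(absent)} image path(s) not found")
    return problems
```

For example, `validate_data_dict({"BraTS": brats_dataset}, check_files=True)` returns an empty list when every listed image exists.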

In [None]:
# Run JERS preprocessing (exactly as in test_jers.py)
def run_jers_preprocessing(dataset):
    """Run JERS preprocessing as in test_jers.py"""
    if dataset is None or len(dataset) == 0:
        print("No dataset available for preprocessing")
        return None
    
    print("Running JERS Preprocessing:")
    print("=" * 40)
    
    try:
        from medicalproject2024.preprocess.JERS import JERS_preprocess
        print("Successfully imported JERS_preprocess")
        print(f"\nProcessing {len(dataset)} samples...")
        data_dict = {"BraTS": dataset}
        output_dir = "data"
        
        print(f"   Input: BraTS dataset with {len(dataset)} samples")
        print(f"   Output directory: {output_dir}")
        res = JERS_preprocess(data_dict, output_dir)
        
        print("JERS preprocessing completed!")
        print(f"   Result type: {type(res)}")
        
        if hasattr(res, 'shape'):
            print(f"   Result shape: {res.shape}")
        elif isinstance(res, (list, tuple)):
            print(f"   Result length: {len(res)}")
        elif isinstance(res, dict):
            print(f"   Result keys: {list(res.keys())}")
        
        return res
        
    except ImportError as e:
        print(f"Failed to import JERS_preprocess: {e}")
        return None
    except Exception as e:
        print(f"Preprocessing failed: {e}")
        print("   This might be due to missing input files or insufficient resources")
        return None

# Run preprocessing
if brats_dataset is not None and import_results.get("JERS Main", False):
    preprocessing_result = run_jers_preprocessing(brats_dataset)
else:
    print("Skipping preprocessing due to missing dataset or JERS module")
    preprocessing_result = None

Running JERS Preprocessing:
Successfully imported JERS_preprocess

Processing 2 samples...
   Input: BraTS dataset with 2 samples
   Output directory: data


Processing images: 100%|██████████| 2/2 [00:09<00:00,  4.69s/it]

JERS preprocessing completed!
   Result type: <class 'dict'>
   Result keys: ['BraTS']

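The output shows only that the result is a dict keyed by dataset name; what each value holds is not printed. A hypothetical `describe_result` helper that generalises the `if`/`elif` chain in `run_jers_preprocessing` and recurses into dict values, so the per-dataset entries are summarised too:

```python
def describe_result(res):
    """Return a one-line summary of a preprocessing result.

    Mirrors the branches in run_jers_preprocessing(): array-like objects
    report .shape, lists/tuples their length, dicts their keys.
    """
    if hasattr(res, "shape"):
        return f"shape: {tuple(res.shape)}"
    if isinstance(res, (list, tuple)):
        return f"length: {len(res)}"
    if isinstance(res, dict):
        # Summarise each value recursively, e.g. the per-dataset entry.
        inner = ", ".join(f"{k} -> {describe_result(v)}" for k, v in res.items())
        return f"keys: [{inner}]"
    return f"type: {type(res).__name__}"
```

Calling `print(describe_result(preprocessing_result))` after the cell would show what sits under the `BraTS` key.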




## 4. Model Usage {#model-usage}

Create the JERS model, load the pre-trained weights if they are available, and test a forward pass with dummy input.

In [None]:
# Create and test JERS model (as in test_jers.py)
def create_jers_model():
    """Create JERS model with parameters from test_jers.py"""
    print("Creating JERS Model:")
    print("=" * 30)
    
    try:
        from medicalproject2024.preprocess.JERS.model import JERS
        print("Successfully imported JERS model")
        print("\nCreating model with parameters (96, 5, 5, 10)...")
        model = JERS(96, 5, 5, 10)
        
        # Model info
        total_params = sum(p.numel() for p in model.parameters())
        print("JERS model created successfully")
        print(f"   Total parameters: {total_params:,}")
        
        # Check for pre-trained weights
        model_weights_path = jers_path / "checkpoints" / "model_state.pt"
        if model_weights_path.exists():
            print(f"\nPre-trained weights found: {model_weights_path}")
            try:
                state_dict = torch.load(model_weights_path, map_location='cpu')
                model.load_state_dict(state_dict)
                print("Pre-trained weights loaded successfully")
            except Exception as e:
                print(f"Could not load weights: {e}")
        else:
            print("\nNo pre-trained weights found")
            print("To save model state: torch.save(model.state_dict(), 'model_state.pt')")
        
        model.eval()
        return model
        
    except ImportError as e:
        print(f"Failed to import JERS model: {e}")
        return None
    except Exception as e:
        print(f"Error creating model: {e}")
        return None

# Create JERS model
if import_results.get("JERS Model", False):
    jers_model = create_jers_model()
else:
    print("Skipping model creation due to import failure")
    jers_model = None

In [None]:
# Test model with dummy data
def test_model_inference(model):
    """Test model inference with dummy data"""
    if model is None:
        print("No model available for testing")
        return
    
    print("\nTesting Model Inference:")
    print("-" * 30)
    
    try:
        # Create dummy input (based on template size 96x96x96)
        dummy_input = torch.randn(1, 1, 96, 96, 96)
        print(f"Testing with input shape: {dummy_input.shape}")
        
        with torch.no_grad():
            output = model(dummy_input)
            
        print("Model inference successful!")
        print(f"   Output shape: {output.shape if hasattr(output, 'shape') else type(output)}")
        
        if hasattr(output, 'shape'):
            print(f"   Output range: [{output.min():.4f}, {output.max():.4f}]")
            
    except Exception as e:
        print(f"Inference failed: {e}")
        print("   This might be due to incorrect input dimensions")

# Test model inference
if jers_model is not None:
    test_model_inference(jers_model)
else:
    print("Skipping inference test due to model creation failure")

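The inference test feeds random noise shaped `(1, 1, 96, 96, 96)`. Real volumes differ (BraTS images are 240x240x155), so they have to be brought onto the 96x96x96 grid first. A minimal NumPy sketch that center-crops or zero-pads each axis and adds batch and channel axes; `crop_or_pad` and `to_model_input` are hypothetical, and this is only a shape adapter: the real JERS preprocessing may resample or register images to the template instead of cropping.

```python
import numpy as np

TARGET = (96, 96, 96)  # template size used for the dummy input


def crop_or_pad(volume, target=TARGET):
    """Center-crop and/or zero-pad a 3D array to `target` shape.

    Illustrative shape adapter only; it does not resample or register.
    """
    volume = np.asarray(volume)
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D volume, got shape {volume.shape}")
    out = np.zeros(target, dtype=volume.dtype)
    src, dst = [], []
    for size, tgt in zip(volume.shape, target):
        if size >= tgt:
            start = (size - tgt) // 2  # crop: keep the central window
            src.append(slice(start, start + tgt))
            dst.append(slice(0, tgt))
        else:
            start = (tgt - size) // 2  # pad: center the data inside zeros
            src.append(slice(0, size))
            dst.append(slice(start, start + size))
    out[tuple(dst)] = volume[tuple(src)]
    return out


def to_model_input(volume):
    """(D, H, W) -> float32 array of shape (1, 1, 96, 96, 96)."""
    return crop_or_pad(volume).astype(np.float32)[np.newaxis, np.newaxis]
```

`torch.from_numpy(to_model_input(volume))` then yields a tensor with the same shape as `dummy_input`.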
## Summary

This notebook demonstrates the JERS workflow following the exact structure from `test_jers.py`:

### Key Steps:
1. **Environment Setup**: Import required modules and check availability
2. **BraTS Dataset Loading**: Load and prepare BraTS brain tumor dataset
3. **JERS Preprocessing**: Process the dataset using `JERS_preprocess`
4. **Model Creation**: Create JERS model with parameters (96, 5, 5, 10)
5. **Inference Test**: Run a forward pass on a dummy `(1, 1, 96, 96, 96)` tensor

### Files Structure:
- **model.py**: Contains JERS neural network architecture
- **inference.py**: Inference pipeline for processing
- **utils.py**: Utility functions
- **checkpoints/**: Pre-trained weights and templates
  - `model_state.pt`: Model weights
  - `template_img_orig_96.npy`: Brain template (96x96x96)
  - `template_img_gm_mask_orig_96.npy`: Gray matter mask

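The two template files in `checkpoints/` are `.npy` arrays, so they can be inspected with `np.load`. A sketch, assuming both share the 96x96x96 grid stated for the template (the mask shape is an assumption); `load_templates` is hypothetical.

```python
import tempfile
from pathlib import Path

import numpy as np

TEMPLATE_FILES = {
    "template": "template_img_orig_96.npy",
    "gm_mask": "template_img_gm_mask_orig_96.npy",
}


def load_templates(checkpoints_dir, expected_shape=(96, 96, 96)):
    """Load the brain template and gray matter mask from checkpoints/.

    Raises ValueError if an array does not have `expected_shape`
    (assumed to be the same grid for both files).
    """
    checkpoints_dir = Path(checkpoints_dir)
    arrays = {}
    for key, fname in TEMPLATE_FILES.items():
        arr = np.load(checkpoints_dir / fname)
        if arr.shape != tuple(expected_shape):
            raise ValueError(f"{fname}: shape {arr.shape}, expected {tuple(expected_shape)}")
        arrays[key] = arr
    return arrays


if __name__ == "__main__":
    # Demo with synthetic stand-ins for the real checkpoint files.
    with tempfile.TemporaryDirectory() as tmp:
        np.save(Path(tmp) / TEMPLATE_FILES["template"], np.zeros((96, 96, 96)))
        np.save(Path(tmp) / TEMPLATE_FILES["gm_mask"], np.ones((96, 96, 96)))
        t = load_templates(tmp)
        print({k: v.shape for k, v in t.items()})
```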
### Usage Notes:
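- The loader keeps only the first 3 BraTS samples (`dataset.iloc[:3]`)
- Preprocessed outputs are written to the `data/` directory
- The model expects 96x96x96 volumes; the inference test uses an input tensor of shape `(1, 1, 96, 96, 96)`
- Pre-trained weights are loaded automatically if `checkpoints/model_state.pt` exists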

### Next Steps:
- Use your own brain imaging data with the preprocessing pipeline
- Experiment with different model parameters
- Explore the joint embedding capabilities for radiology-surgery applications
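To run the pipeline on your own data, the input only has to match the `{dataset_name: DataFrame}` layout used above. A sketch that globs a directory for T1 images; `build_data_dict`, the `*_t1.nii.gz` pattern and the directory layout are assumptions about your data, and whether `JERS_preprocess` accepts dataset names other than `BraTS` is not covered by this notebook.

```python
import tempfile
from pathlib import Path

import pandas as pd


def build_data_dict(root, dataset_name, pattern="*_t1.nii.gz"):
    """Collect matching images under `root` into {dataset_name: DataFrame}.

    Produces the same two columns as load_brats_dataset(); the glob
    pattern is an assumption about how your files are named.
    """
    paths = sorted(str(p) for p in Path(root).rglob(pattern))
    df = pd.DataFrame({"raw_img_path": paths})
    df["dataset_name"] = dataset_name
    return {dataset_name: df}


if __name__ == "__main__":
    # Demo on an empty-file stand-in for a subject folder.
    with tempfile.TemporaryDirectory() as tmp:
        case = Path(tmp) / "subject01"
        case.mkdir()
        (case / "subject01_t1.nii.gz").write_bytes(b"")
        print(build_data_dict(tmp, "MyData")["MyData"])
```

With real data this would be passed on as, e.g., `JERS_preprocess(build_data_dict("/path/to/scans", "MyData"), "data")`, where the path is a placeholder.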