In [None]:
"""
HRNet Training Script - Standalone Version
Can be run directly on Colab or locally with GPU
"""

import os
import sys
import torch

print("=" * 70, "HRNet Training Pipeline", "=" * 70, sep="\n")

# ============================================================================
# CONFIGURATION
# ============================================================================
# Single dictionary of run settings read by every step of the pipeline below.
CONFIG = dict(
    # --- Training parameters -------------------------------------------------
    # All of these are written into config.yaml by setup_config(); most are
    # also passed to train_hrnet_pipeline() by train_model().
    TRAIN_SAMPLES=500,
    VAL_SAMPLES=100,
    BATCH_SIZE=8,
    LEARNING_RATE=1e-3,
    WEIGHT_DECAY=1e-4,
    DROPOUT_RATE=0.1,
    EPOCHS=200,
    PATIENCE=5,
    SEED=42,

    # --- Model parameters ----------------------------------------------------
    BASE_CHANNELS=32,              # 32 or 48

    # --- Repository settings (used by clone_repository) ------------------------
    REPO_URL='https://github.com/veselm73/SU2.git',
    REPO_DIR='/content/SU2',

    # --- Google Drive (for Colab) --------------------------------------------
    # setup_colab() overwrites both values when google.colab cannot be imported.
    USE_GDRIVE=True,
    SAVE_DIR='/content/drive/MyDrive/SU2_Project_HRNet',
)



In [None]:
# ============================================================================
# SETUP ENVIRONMENT
# ============================================================================

def setup_colab():
    """Prepare the output directory, mounting Google Drive when on Colab.

    When ``google.colab`` can be imported, Drive is mounted at
    ``/content/drive`` and ``CONFIG['SAVE_DIR']`` is created as configured.
    Otherwise ``CONFIG['USE_GDRIVE']`` is set to False and
    ``CONFIG['SAVE_DIR']`` is redirected to a ``results`` folder inside the
    current working directory.

    Returns:
        bool: True when running on Colab, False otherwise.
    """
    try:
        from google.colab import drive
        print("\n[1/5] Mounting Google Drive...")
        drive.mount('/content/drive')
        os.makedirs(CONFIG['SAVE_DIR'], exist_ok=True)
        print(f" Results will be saved to: {CONFIG['SAVE_DIR']}")
        return True
    except ImportError:
        print("\n[1/5] Not running on Colab, skipping Drive mount")
        CONFIG['USE_GDRIVE'] = False
        # Store an absolute path: clone_repository() later changes the working
        # directory, and a relative './results' would then point at a folder
        # that was never created, breaking save_model().
        CONFIG['SAVE_DIR'] = os.path.abspath('./results')
        os.makedirs(CONFIG['SAVE_DIR'], exist_ok=True)
        return False

def install_dependencies():
    """Install the pinned third-party packages needed by the pipeline.

    pip is invoked as ``<current interpreter> -m pip`` so the packages land in
    the environment that is actually running this code, and the exit status is
    checked so a failed install is reported instead of being announced as a
    success. A failure is only printed, not raised; later imports will surface
    any package that is really missing.
    """
    import subprocess

    print("\n[2/5] Installing dependencies...")
    pip_command = [
        sys.executable, '-m', 'pip', 'install', '-q',
        'btrack==0.6.5', 'pydantic<2', 'pyyaml',
    ]
    # Argument list (no shell): 'pydantic<2' needs no quoting and cannot be
    # misread as a redirection.
    completed = subprocess.run(pip_command)
    if completed.returncode == 0:
        print("Dependencies installed")
    else:
        print(f"Dependency installation failed (pip exit code {completed.returncode})")

def clone_repository():
    """Clone the project repository if it is missing, then switch into it.

    Reads ``CONFIG['REPO_URL']`` and ``CONFIG['REPO_DIR']``. An existing
    checkout is reused as-is (it is not pulled/updated). Afterwards the
    repository directory is on ``sys.path`` and is the current working
    directory.

    Raises:
        RuntimeError: if ``git clone`` exits with a non-zero status.
        OSError: if ``git`` cannot be started or the directory cannot be
            entered.
    """
    import subprocess

    print("\n[3/5] Setting up repository...")

    repo_url = CONFIG['REPO_URL']
    repo_dir = CONFIG['REPO_DIR']

    if not os.path.exists(repo_dir):
        print(f"Cloning from {repo_url}...")
        # Argument list instead of a shell string: no quoting problems with
        # spaces or shell metacharacters in the URL or target path.
        completed = subprocess.run(['git', 'clone', repo_url, repo_dir])
        if completed.returncode != 0:
            raise RuntimeError(
                f"git clone of {repo_url} into {repo_dir} failed "
                f"(exit code {completed.returncode})"
            )
    else:
        print(f"Repository already exists at {repo_dir}")

    # Add to path only once so re-running the cell does not pile up duplicates.
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)
    os.chdir(repo_dir)
    print(f"Working directory: {os.getcwd()}")

def setup_config():
    """Write ``config.yaml`` into the current working directory.

    The tunable training values are taken from ``CONFIG``; the cell-count,
    patch-size and ``SIM_CONFIG`` entries are fixed text. An existing
    ``config.yaml`` is overwritten.
    """
    print("\n[4/5] Creating configuration...")

    # Plain template filled from CONFIG by key name.
    # NOTE(review): values are rendered with their default string form, so a
    # very small float would appear in exponent notation (e.g. 1e-05), which
    # some YAML loaders read as a string - confirm how the file is parsed.
    template = """
TRAIN_SAMPLES: {TRAIN_SAMPLES}
VAL_SAMPLES: {VAL_SAMPLES}
BATCH_SIZE: {BATCH_SIZE}
LEARNING_RATE: {LEARNING_RATE}
WEIGHT_DECAY: {WEIGHT_DECAY}
DROPOUT_RATE: {DROPOUT_RATE}
EPOCHS: {EPOCHS}
PATIENCE: {PATIENCE}
SEED: {SEED}

MIN_CELLS: 5
MAX_CELLS: 15
PATCH_SIZE: 128
SIM_CONFIG:
  na: 1.49
  wavelength: 512
  px_size: 0.07
  wiener_parameter: 0.1
  apo_cutoff: 2.0
  apo_bend: 0.9
"""
    yaml_text = template.format_map(CONFIG)

    with open('config.yaml', 'w') as handle:
        handle.write(yaml_text)

    print(" Configuration saved to config.yaml")



In [None]:
# ============================================================================
# TRAINING
# ============================================================================

def train_model():
    """Train the HRNet model with the settings in ``CONFIG``.

    Seeds the run, calls ``train_hrnet_pipeline`` with the training and model
    parameters from ``CONFIG`` and plots the resulting history.

    Returns:
        tuple: ``(model, history)`` exactly as returned by
        ``train_hrnet_pipeline``.
    """
    print("\n[5/5] Starting training...")
    print("="*70)

    # Import modules here rather than at the top of the file: presumably the
    # `modules` package lives in the repository that clone_repository() puts
    # on sys.path, so it is not importable any earlier.
    # DEVICE is imported by name: `from ... import *` is a SyntaxError inside
    # a function in Python 3.
    from modules.config import DEVICE
    from modules.utils import set_seed, plot_training_history
    from modules.training_hrnet import train_hrnet_pipeline

    # Set seed
    set_seed(CONFIG['SEED'])

    # Train
    print(f"Device: {DEVICE}")
    print(f"Samples: {CONFIG['TRAIN_SAMPLES']} train, {CONFIG['VAL_SAMPLES']} val")
    print(f"Epochs: {CONFIG['EPOCHS']}, Batch size: {CONFIG['BATCH_SIZE']}")
    print("="*70)

    model, history = train_hrnet_pipeline(
        train_samples=CONFIG['TRAIN_SAMPLES'],
        val_samples=CONFIG['VAL_SAMPLES'],
        epochs=CONFIG['EPOCHS'],
        batch_size=CONFIG['BATCH_SIZE'],
        learning_rate=CONFIG['LEARNING_RATE'],
        weight_decay=CONFIG['WEIGHT_DECAY'],
        patience=CONFIG['PATIENCE'],
        device=DEVICE,
        base_channels=CONFIG['BASE_CHANNELS'],
        dropout_rate=CONFIG['DROPOUT_RATE']
    )

    # Plot history
    print("\nPlotting training history...")
    plot_training_history(history)

    return model, history

def save_model(model):
    """Save the trained weights into ``CONFIG['SAVE_DIR']``.

    Writes the model's ``state_dict()`` as ``hrnet_final.pth``. If a
    ``best_model.pth`` file exists in the current working directory it is
    additionally copied to ``hrnet_best.pth``.

    Args:
        model: object exposing ``state_dict()``; its result is passed to
            ``torch.save``.
    """
    print("\nSaving model...")

    model_name = "hrnet"
    save_dir = CONFIG['SAVE_DIR']

    # Make sure the target exists right now: the working directory may have
    # changed since the directory was first created, and a missing folder
    # would otherwise lose the freshly trained weights.
    os.makedirs(save_dir, exist_ok=True)

    # Save final model
    final_path = os.path.join(save_dir, f"{model_name}_final.pth")
    torch.save(model.state_dict(), final_path)
    print(f" Final model saved to: {final_path}")

    # Save best model if exists
    if os.path.exists('best_model.pth'):
        import shutil
        best_path = os.path.join(save_dir, f"{model_name}_best.pth")
        shutil.copy('best_model.pth', best_path)
        print(f" Best model saved to: {best_path}")

def download_validation_data():
    """Download and extract the validation dataset (best effort).

    First fetches a certificate chain file, then hands it together with the
    dataset URL and the extraction directory to ``download_and_unzip``. Both
    steps only print a message on failure; nothing is raised to the caller.
    """
    print("\nDownloading validation data...")

    from modules.utils import download_and_unzip
    import requests

    # Download certificate
    chain_path = "/content/chain-harica-cross.pem"
    cert_url = "https://pki.cesnet.cz/_media/certs/chain-harica-rsa-ov-crosssigned-root.pem"

    try:
        # No stream=True: the whole body is read via .content anyway, and a
        # non-streamed response releases its connection by itself.
        r = requests.get(cert_url, timeout=10)
        r.raise_for_status()
        with open(chain_path, 'wb') as f:
            f.write(r.content)
        print(" Certificate downloaded")
    except Exception as e:
        print(f" Certificate download failed: {e}")
        # Continue without the chain file; download_and_unzip receives None.
        # NOTE(review): presumably it then falls back to default certificate
        # handling - confirm in modules.utils.
        chain_path = None

    # Download data
    zip_url = "https://su2.utia.cas.cz/files/labs/final2025/val_and_sota.zip"
    extract_dir = "/content/val_data"

    try:
        download_and_unzip(zip_url, extract_dir, chain_path)
        print(f" Validation data extracted to: {extract_dir}")
    except Exception as e:
        # Message fixed: it used to start with a stray 'm' ("mData ...").
        print(f" Data download failed: {e}")

def run_tracking(model):
    """Run the detection/tracking parameter sweep and write a GIF (best effort).

    Calls ``sweep_and_save_gif`` with a small grid of detection parameters and
    a grid of btrack parameters; the GIF goes to
    ``<CONFIG['SAVE_DIR']>/hrnet_tracking.gif``. Any exception raised during
    the sweep is printed instead of propagated.

    Args:
        model: passed through unchanged as the first argument of
            ``sweep_and_save_gif``.
    """
    print("\nRunning tracking parameter sweep...")

    from modules.sweep import sweep_and_save_gif

    # Detection grid: only the threshold is actually varied.
    detector_grid = dict(
        threshold=[0.25, 0.3],
        min_area=[4],
        nms_min_dist=[3.0],
    )

    # Tracker grid: time_thresh and min_track_len are varied, the rest is fixed.
    tracker_grid = dict(
        do_optimize=[False],
        max_search_radius=[20.0],
        dist_thresh=[15.0],
        time_thresh=[4, 6],
        min_track_len=[10, 15],
        segmentation_miss_rate=[0.1],
        apoptosis_rate=[0.001],
        allow_divisions=[False],
    )

    output_gif = os.path.join(CONFIG['SAVE_DIR'], "hrnet_tracking.gif")

    try:
        # The three returned values are unpacked but not used afterwards.
        _best_det, _best_bt, _best_tracks = sweep_and_save_gif(
            model, detector_grid, tracker_grid, gif_output=output_gif
        )
        print(f"✓ Tracking GIF saved to: {output_gif}")
    except Exception as err:
        print(f"⚠ Tracking failed: {err}")

def print_model_stats(model):
    """Print a short summary of the model's trainable size.

    Only parameters whose ``requires_grad`` is true are counted. The size
    estimate assumes 4 bytes per parameter (float32).

    Args:
        model: object whose ``parameters()`` yields items exposing
            ``requires_grad`` and ``numel()``.
    """
    trainable_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()

    size_mb = trainable_count * 4 / 1024 / 1024  # bytes -> MiB at 4 bytes each
    rule = "=" * 70

    print("\n" + rule)
    print("MODEL STATISTICS")
    print(rule)
    print(f"Base channels: {CONFIG['BASE_CHANNELS']}")
    print(f"Total parameters: {trainable_count:,}")
    print(f"Model size: ~{size_mb:.2f} MB (float32)")
    print(rule)



In [None]:
# ============================================================================
# MAIN
# ============================================================================

def main():
    """Run the full pipeline: setup, training, saving and optional tracking.

    Steps, in order: environment setup (``setup_colab``,
    ``install_dependencies``, ``clone_repository``, ``setup_config``),
    training, model statistics, saving, then the optional validation-data
    download and tracking sweep. Failures in the optional part are printed
    and do not abort the run. On Colab the runtime is released at the end.
    """
    print("\n" + "="*70)
    print("STARTING HRNET TRAINING PIPELINE")
    print("="*70)

    # Setup
    is_colab = setup_colab()
    install_dependencies()
    clone_repository()
    setup_config()

    # Train
    model, history = train_model()

    # Stats
    print_model_stats(model)

    # Save
    save_model(model)

    # Optional: Download validation data and run tracking
    try:
        download_validation_data()
        run_tracking(model)
    except Exception as e:
        print(f"\n Optional steps failed: {e}")
        print("Model training completed successfully anyway!")

    # Final message
    print("\n" + "="*70)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("="*70)
    print(f" Models saved to: {CONFIG['SAVE_DIR']}")
    # Guard against an empty history so the summary line cannot crash a run
    # whose model has already been trained and saved.
    val_losses = history['val_loss']
    if len(val_losses) > 0:
        print(f" Best validation loss: {min(val_losses):.4f}")
    else:
        print(" Best validation loss: n/a (no validation results recorded)")

    # Auto-disconnect on Colab
    if is_colab:
        try:
            from google.colab import runtime
            print("\nDisconnecting runtime to save compute units...")
            runtime.unassign()
        except Exception as e:
            # Best effort only, but report the reason instead of hiding it;
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print(f"Could not disconnect runtime automatically: {e}")

# Entry point: run the whole pipeline only when this code is executed directly
# (its module name is "__main__"), not when it is imported from elsewhere.
if __name__ == "__main__":
    main()