# Load Models

This notebook demonstrates different ways to load models in PerturbLab.

## Features
- URL-style model loading with the `Model()` function (e.g. `Model("UCE/default")`)
- Registry-based access via `MODELS` (dot notation or dictionary-style indexing)
- Intelligent dependency management
- Lazy loading for fast startup


In [None]:
from perturblab.models import Model, MODELS

# Collect every registered key (nested variants included) and show only a
# short preview so the cell output stays readable.
print("Available models:")
models = MODELS.list_keys(recursive=True)
preview_keys = models[:10]
for key in preview_keys:
    print(f"  - {key}")

total_variants = len(models)
print(f"\nTotal: {total_variants} model variants")


## Method 1: Using `Model()` Function (Recommended)


In [None]:
from perturblab.models.uce import UCEConfig
import torch

# Model() takes a URL-style path ("<family>/<variant>"); matching is case-insensitive.
# The config is built with UCEConfig.from_dict. strict=False presumably lets
# fields not listed here keep their defaults — TODO confirm against UCEConfig.
config = UCEConfig.from_dict({
    'token_dim': 512,
    'd_model': 1280,
    'nlayers': 4,
}, strict=False)

model = Model("UCE/default")(config)

print(f"Model loaded: {type(model).__name__}")
print(f"Model config: d_model={config.d_model}, nlayers={config.nlayers}")

# A lower-case path should resolve to the very same class object, so compare
# class identity rather than class names (names can collide across modules).
model2 = Model("uce/default")(config)
print(f"\nCase-insensitive works: {type(model) is type(model2)}")


## Method 2: Using MODELS Registry


In [None]:
# Dot notation (IDE-friendly with autocomplete)
from perturblab.models.scfoundation import scFoundationConfig

# Use scFoundation-specific names so this cell does not rebind the global
# `config` / `model` created for UCE in Method 1; code that later reuses
# those names would otherwise silently receive scFoundation objects.
scf_config = scFoundationConfig(num_tokens=5000, embed_dim=512, depth=6)
scf_model = MODELS.scFoundation.scFoundationModel(scf_config)

print(f"Model loaded: {type(scf_model).__name__}")

# Dictionary-style access (dynamic) should resolve to the same class object.
scf_model_by_key = MODELS['scFoundation']['scFoundationModel'](scf_config)
print(f"Dictionary access works: {type(scf_model) is type(scf_model_by_key)}")


## Access Model Components and Model Classes


In [None]:
from perturblab.models.uce import UCEConfig

# Access nested components using Model() function
encoder = Model("scGPT/components/GeneEncoder")(vocab_size=5000, dim=512)
print(f"Component loaded: {type(encoder).__name__}")

# Or using MODELS registry; both paths should yield the same class object.
encoder2 = MODELS.scGPT.components.GeneEncoder(vocab_size=5000, dim=512)
print(f"Registry access works: {type(encoder) is type(encoder2)}")

# Get model class directly
model_class = Model("UCE/default").class_
print(f"\nModel class: {model_class.__name__}")

# Build a UCE config locally instead of reusing the global `config`: an earlier
# cell may have rebound it to a different model's config (e.g. scFoundation),
# which would make this instantiation use the wrong configuration type.
uce_config = UCEConfig.from_dict({
    'token_dim': 512,
    'd_model': 1280,
    'nlayers': 4,
}, strict=False)
model3 = model_class(uce_config)
# Direct instantiation should produce exactly the class returned by `.class_`.
print(f"Direct instantiation works: {type(model3) is model_class}")
