# SL-BEATs Model Demo

This notebook demonstrates how to use the `representation-learning` package with the `sl-beats` model for:
1. Classification using the original head
2. Adding a new classification head
3. Embedding extraction

## Setup

First, we need to authenticate and install the package from the private PyPI repository.

## Step 1: Authenticate with Google Cloud

**For Google Colab**: You need to authenticate to access the private PyPI package.

**For Local Execution**: If you're running this locally and already have `gcloud` configured, you can skip this step or run the authentication commands in your terminal first.

In [1]:
# Authenticate with Google Cloud
import importlib.util
import subprocess

# Check if we're in Colab
IN_COLAB = importlib.util.find_spec("google.colab") is not None

if IN_COLAB:
    # Colab-specific authentication
    get_ipython().system("gcloud auth login --no-launch-browser")
    get_ipython().system("gcloud auth application-default login --no-launch-browser")
else:
    # Local execution - check if already authenticated
    print("Running locally. Checking gcloud authentication...")
    try:
        result = subprocess.run(["gcloud", "auth", "list"], capture_output=True, text=True, check=False)
        if "ACTIVE" in result.stdout:
            print("✅ Already authenticated with gcloud")
        else:
            print("⚠️  Not authenticated. Please run in terminal:")
            print("   gcloud auth login")
            print("   gcloud auth application-default login")
    except FileNotFoundError:
        print("⚠️  gcloud CLI not found. Please install it or authenticate manually.")


## Step 2: Install UV and Keyring Plugin

Install `uv` package manager and the Google Artifact Registry authentication plugin.

**For Local Execution**: If you already have `uv` installed and the package is available locally, you can skip this step.

In [1]:
# Install uv
import shutil
import subprocess

uv_installed = shutil.which("uv") is not None

if not uv_installed:
    print("Installing uv...")
    try:
        get_ipython().system("pip install uv")
    except NameError:
        subprocess.run(["pip", "install", "uv"], check=False)
else:
    print("✅ uv is already installed")

# Install keyring
try:
    get_ipython().system("uv tool install keyring --with keyrings.google-artifactregistry-auth")
except NameError:
    subprocess.run(["uv", "tool", "install", "keyring", "--with", "keyrings.google-artifactregistry-auth"], check=False)
except Exception:
    print("⚠️  Keyring installation may have failed. If running locally with the package already installed, this is OK.")

✅ uv is already installed
`keyring` is already installed


## Step 3: Configure UV for Private PyPI

Create a `pyproject.toml` configuration file to use the private index.

**For Local Execution**: If you're running from the project root, the cell below appends the index configuration to the existing `pyproject.toml`; otherwise it creates a new file.

In [2]:
from pathlib import Path

# Check if we're in the project root (has existing pyproject.toml with project config)
current_dir = Path.cwd()
existing_pyproject = current_dir / "pyproject.toml"

# Check if existing pyproject.toml has [project] section (indicates it's a real project)
has_project_config = False
needs_fix = False
if existing_pyproject.exists():
    try:
        content = existing_pyproject.read_text()
        if "[project]" in content or "[build-system]" in content:
            has_project_config = True
            print("⚠️  Found existing pyproject.toml with project configuration")
            print("   We'll append our index configuration to it...")
        # Check if it has wrong format (single brackets instead of double)
        if "[tool.uv.index]" in content and "[[tool.uv.index]]" not in content:
            needs_fix = True
            print("⚠️  Found incorrect format in pyproject.toml (single brackets)")
            print("   Fixing format by replacing [tool.uv.index] with [[tool.uv.index]]...")
            # Fix the format
            content = content.replace("[tool.uv.index]", "[[tool.uv.index]]")
            with open("pyproject.toml", "w") as f:
                f.write(content)
            print("   ✅ Fixed format in existing pyproject.toml")
    except Exception as e:
        print(f"   ⚠️  Could not read existing pyproject.toml: {e}")

# Create/update pyproject.toml configuration
# NOTE: Use [[tool.uv.index]] (double brackets) for array of tables - this is REQUIRED by uv
pyproject_content = """[[tool.uv.index]]
name = "esp-pypi"
url = "https://oauth2accesstoken@us-central1-python.pkg.dev/okapi-274503/esp-pypi/simple/"
explicit = true

[tool.uv.sources]
representation-learning = { index = "esp-pypi" }

[tool.uv]
keyring-provider = "subprocess"
"""

# Re-read the file so that any format fix applied above is taken into account
existing_text = existing_pyproject.read_text() if existing_pyproject.exists() else ""

if 'name = "esp-pypi"' in existing_text:
    # Index already configured: appending again would create duplicate TOML tables
    print("✅ Index configuration already present in pyproject.toml")
elif has_project_config:
    # Append to the existing project file
    with open(existing_pyproject, "a") as f:
        f.write("\n")
        f.write(pyproject_content)
    print("✅ Appended index configuration to existing pyproject.toml")
else:
    # No project configuration found: create the file (or overwrite a non-project one)
    with open(existing_pyproject, "w") as f:
        f.write(pyproject_content)
    print("✅ Created pyproject.toml configuration")
print("   Note: Using [[tool.uv.index]] (double brackets) as required by uv")

✅ Created pyproject.toml configuration
   Note: Using [[tool.uv.index]] (double brackets) as required by uv


## Step 4: Install representation-learning Package

Install the package using `uv` with the configured private index.

**Note**: The cell below pins `numpy` first, then installs the package with `uv pip install`, passing the private index directly via `--extra-index-url`.

**For Local Execution**: If the package is already installed in your environment (e.g., via `uv sync`), you can skip this step.
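
A quick way to tell whether this step can be skipped is to ask Python whether the distribution is already installed. The sketch below uses only the standard library. The helper name `installed_version` is made up for illustration, and the distribution name is assumed to match the one used in the install command.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


found = installed_version("representation-learning")
if found is None:
    print("representation-learning is not installed; run the install cell")
else:
    print(f"representation-learning {found} is already installed")
```
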

In [3]:
# Install numpy first to avoid compatibility issues
!pip install numpy==1.26.4
!uv pip install representation-learning --extra-index-url https://oauth2accesstoken@us-central1-python.pkg.dev/okapi-274503/esp-pypi/simple/

Using Python 3.12.12 environment at: /usr
Audited 1 package in 130ms


## Step 5: Import and Verify Installation

Import the package and verify it's working correctly.

In [4]:
import torch

from representation_learning import list_models, load_model

print("✅ Package imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# List available models
print("\n📋 Available models:")
models = list_models()
for name in list(models.keys())[:5]:  # Show first 5
    print(f"  - {name}")
if len(models) > 5:
    print(f"  ... and {len(models) - 5} more")



✅ Package imported successfully!
PyTorch version: 2.5.0+cu124
CUDA available: False

📋 Available models:


## Step 6: Use Case 1 - Classification with Original Head

Load the `sl_beats_animalspeak` model with its original classification head from the checkpoint.
When `num_classes=None`, the model will extract the number of classes from the checkpoint and load the trained classifier weights.

In [None]:
print("🚀 Use Case 1: Classification with Original Head")
print("=" * 60)

# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

try:
    # Load model with original classifier (num_classes=None extracts from checkpoint)
    model = load_model("sl_beats_animalspeak", num_classes=None, device=device)
    model.eval()

    print("\n✅ Model loaded successfully!")
    print(f"   Model type: {type(model).__name__}")

    # Check if classifier exists
    if hasattr(model, "classifier") and model.classifier is not None:
        print("   ✅ Original classifier loaded from checkpoint")
        print(f"   Classifier weight shape: {model.classifier.weight.shape}")
        print(f"   Classifier bias shape: {model.classifier.bias.shape}")
        num_classes = model.classifier.weight.shape[0]
        print(f"   Number of classes: {num_classes}")

        # Check for class mapping
        if hasattr(model, "class_mapping") and model.class_mapping:
            index_to_label = model.class_mapping.get("index_to_label", {})
            print(f"   Class mapping available: {len(index_to_label)} classes")
            if index_to_label:
                print("   Sample classes:")
                for idx in list(index_to_label.keys())[:5]:
                    print(f"     {idx}: {index_to_label[idx]}")
    else:
        print("   ⚠️  No classifier found (model in embedding mode)")

    # Test forward pass
    print("\n🧪 Testing forward pass...")
    dummy_input = torch.randn(1, 16000 * 5, device=device)  # 5 seconds of audio at 16kHz, on the model's device

    with torch.no_grad():
        output = model(dummy_input, padding_mask=None)

    print(f"   Input shape: {dummy_input.shape}")
    print(f"   Output shape: {output.shape}")

    if hasattr(model, "classifier") and model.classifier is not None:
        # Get predictions
        probs = torch.softmax(output, dim=-1)
        top_probs, top_indices = torch.topk(probs, k=min(3, output.shape[-1]), dim=-1)
        print("\n   Top-3 predictions:")
        for i, (prob, idx) in enumerate(zip(top_probs[0], top_indices[0], strict=False)):
            idx_int = idx.item()
            if hasattr(model, "class_mapping") and model.class_mapping:
                index_to_label = model.class_mapping.get("index_to_label", {})
                label = index_to_label.get(idx_int, f"Class {idx_int}")
                print(f"     {i + 1}. {label}: {prob.item():.4f}")
            else:
                print(f"     {i + 1}. Class {idx_int}: {prob.item():.4f}")
    else:
        print("   ✅ Model returns embeddings (not classification logits)")

except Exception as e:
    print(f"❌ Error: {type(e).__name__}: {e}")
    import traceback

    traceback.print_exc()
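
The forward pass above uses random noise shaped `(1, 16000 * 5)`. For a real recording, one option is to bring the waveform into that same shape first. The sketch below assumes the audio has already been resampled to 16 kHz; the helper name `to_fixed_length_batch` is hypothetical, and the fixed 5-second length simply mirrors the dummy input (the model may well accept other lengths).

```python
import torch

SAMPLE_RATE = 16_000  # matches the 16 kHz dummy input used above
CLIP_SECONDS = 5


def to_fixed_length_batch(
    waveform: torch.Tensor, num_samples: int = SAMPLE_RATE * CLIP_SECONDS
) -> torch.Tensor:
    """Convert a (channels, samples) or (samples,) waveform to shape (1, num_samples)."""
    if waveform.dim() == 2:
        waveform = waveform.mean(dim=0)  # average channels down to mono
    if waveform.shape[-1] < num_samples:
        missing = num_samples - waveform.shape[-1]
        waveform = torch.nn.functional.pad(waveform, (0, missing))  # zero-pad the end
    else:
        waveform = waveform[:num_samples]  # crop to the target length
    return waveform.unsqueeze(0)  # add the batch dimension


stereo_clip = torch.randn(2, SAMPLE_RATE * 3)  # stand-in for 3 s of stereo audio
batch = to_fixed_length_batch(stereo_clip)
print(batch.shape)  # torch.Size([1, 80000])
```
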

## Step 7: Use Case 2 - Adding a New Classification Head

Load the model and add a new classification head with a different number of classes.
When `num_classes` is explicitly provided, the classifier weights are randomly initialized (not loaded from checkpoint).

In [None]:
print("🚀 Use Case 2: Adding a New Classification Head")
print("=" * 60)

try:
    # Load model with a new classifier (explicit num_classes)
    new_num_classes = 20
    print(f"Creating new classifier with {new_num_classes} classes...")

    model = load_model("sl_beats_animalspeak", num_classes=new_num_classes, device=device)
    model.eval()

    print("\n✅ Model loaded with new classifier!")

    if hasattr(model, "classifier") and model.classifier is not None:
        print("   ✅ New classifier created")
        print(f"   Classifier weight shape: {model.classifier.weight.shape}")
        print(f"   Classifier bias shape: {model.classifier.bias.shape}")
        print(f"   Number of classes: {new_num_classes}")
        print("   💡 Note: Classifier weights are randomly initialized (not from checkpoint)")
    else:
        print("   ❌ No classifier found")

    # Test forward pass
    print("\n🧪 Testing forward pass...")
    dummy_input = torch.randn(1, 16000 * 5, device=device)  # 5 seconds of audio, on the model's device

    with torch.no_grad():
        output = model(dummy_input, padding_mask=None)

    print(f"   Input shape: {dummy_input.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   ✅ Model outputs classification logits for {new_num_classes} classes")

    print("\n💡 This classifier can be trained for your specific task!")

except Exception as e:
    print(f"❌ Error: {type(e).__name__}: {e}")
    import traceback

    traceback.print_exc()
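
Training the new head while leaving the pretrained weights untouched could look roughly like the sketch below. Because the real model comes from a private package, `StandInModel` is a hypothetical toy module; only the `classifier` attribute name is taken from the cells above. The data is random, and the optimizer and learning rate are arbitrary choices for illustration.

```python
import torch
from torch import nn


class StandInModel(nn.Module):
    """Hypothetical stand-in: a small 'backbone' plus a `classifier` head."""

    def __init__(self, embed_dim: int = 32, num_classes: int = 20) -> None:
        super().__init__()
        self.backbone = nn.Linear(100, embed_dim)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(torch.relu(self.backbone(x)))


torch.manual_seed(0)
toy_model = StandInModel()

# Freeze everything except the classification head
for name, param in toy_model.named_parameters():
    param.requires_grad = name.startswith("classifier.")

trainable = [p for p in toy_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

inputs = torch.randn(8, 100)         # fake batch of 8 examples
labels = torch.randint(0, 20, (8,))  # fake integer class labels

backbone_before = toy_model.backbone.weight.detach().clone()
head_before = toy_model.classifier.weight.detach().clone()

toy_model.train()
for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(toy_model(inputs), labels)
    loss.backward()
    optimizer.step()

print(torch.equal(backbone_before, toy_model.backbone.weight))  # True: backbone untouched
print(torch.equal(head_before, toy_model.classifier.weight))    # False: head was updated
```
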

## Step 8: Use Case 3 - Embedding Extraction

Load the model in embedding extraction mode. This is useful for:
- Transfer learning
- Linear probing
- Feature extraction for downstream tasks

When loading with `return_features_only=True`, the model returns embeddings instead of classification logits.

In [None]:
print("🚀 Use Case 3: Embedding Extraction")
print("=" * 60)

try:
    # Load model for embedding extraction using return_features_only=True
    print("Loading sl_beats_animalspeak in embedding extraction mode...")
    print("(Using return_features_only=True to extract embeddings)")

    model = load_model("sl_beats_animalspeak", num_classes=None, return_features_only=True, device=device)
    model.eval()

    print("\n✅ Model loaded in embedding extraction mode!")

    # Check if classifier exists
    has_classifier = hasattr(model, "classifier") and model.classifier is not None
    if has_classifier:
        print("   ⚠️  Model has a classifier (unexpected for embedding mode)")
    else:
        print("   ✅ Model has no classifier (embedding extraction mode)")
        print(f"   Return features only: {getattr(model, '_return_features_only', 'N/A')}")

    # Test forward pass - should return embeddings
    print("\n🧪 Testing forward pass...")
    dummy_input = torch.randn(1, 16000 * 5, device=device)  # 5 seconds of audio, on the model's device

    with torch.no_grad():
        output = model(dummy_input, padding_mask=None)

    print(f"   Input shape: {dummy_input.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   ✅ Model returns embeddings (dimension: {output.shape[-1]})")

    # Show embedding statistics
    print("\n📊 Embedding statistics:")
    print(f"   Mean: {output.mean().item():.4f}")
    print(f"   Std: {output.std().item():.4f}")
    print(f"   Min: {output.min().item():.4f}")
    print(f"   Max: {output.max().item():.4f}")

    print("\n💡 These embeddings can be used for:")
    print("   - Linear probing (training a simple classifier on top)")
    print("   - Similarity search")
    print("   - Clustering")
    print("   - Transfer learning to new tasks")

except Exception as e:
    print(f"❌ Error: {type(e).__name__}: {e}")
    import traceback

    traceback.print_exc()
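
As one illustration of the similarity-search use mentioned above, the sketch below ranks a small library of embeddings by cosine similarity to a query. It uses random tensors as stand-ins and assumes one embedding vector per clip (if the model returns a sequence of frames, pooling over time would be needed first). The dimension `768` and the helper name `top_k_similar` are arbitrary choices for this example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim = 768                       # arbitrary stand-in dimension
library = torch.randn(10, embed_dim)  # stand-in embeddings for 10 clips
# A query that is a slightly noisy copy of library item 3
query = library[3] + 0.01 * torch.randn(embed_dim)


def top_k_similar(query: torch.Tensor, library: torch.Tensor, k: int = 3):
    """Return (scores, indices) of the k library rows most similar to the query."""
    q = F.normalize(query.unsqueeze(0), dim=-1)  # unit-length query, shape (1, D)
    lib = F.normalize(library, dim=-1)           # unit-length rows, shape (N, D)
    scores = (q @ lib.T).squeeze(0)              # cosine similarities, shape (N,)
    return torch.topk(scores, k=k)


scores, indices = top_k_similar(query, library)
print(indices[0].item())  # 3: the clip the query was derived from
```
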

## Summary

This notebook demonstrated three main use cases for the `sl-beats` model:

1. **Classification with Original Head**: Load the model with `num_classes=None` to use the trained classifier from the checkpoint.

2. **Adding a New Classification Head**: Load the model with an explicit `num_classes` to create a new randomly initialized classifier for your specific task.

3. **Embedding Extraction**: Load the model with `return_features_only=True` to extract features for downstream tasks.

### Key Takeaways:

- `num_classes=None`: Extracts number of classes from checkpoint and loads trained classifier
- `num_classes=<number>`: Creates a new randomly initialized classifier
- `return_features_only=True`: Loads model in embedding extraction mode (no classifier)
- The model backbone can be accessed directly for custom feature extraction

### Next Steps:

- Train the new classifier on your dataset
- Use embeddings for linear probing or similarity search
- Fine-tune the entire model for your specific task