In [None]:
!git clone https://github.com/imabeastdrew/Martydepth.git
%cd Martydepth

# Install the package in development mode
%pip install -e .


In [None]:
# Install dependencies
%pip install torch wandb tqdm pyyaml


In [None]:
# Import required libraries
import torch
import wandb
import tempfile
import json
import numpy as np
from pathlib import Path
from tqdm.notebook import tqdm
import os
import sys

# Add project root to Python path
sys.path.append('.')

print(os.getcwd())

# Import project modules
from src.data.dataset import create_dataloader
from src.models.online_transformer import OnlineTransformer
from src.models.offline_teacher import OfflineTeacherModel
from src.evaluation.metrics import (
    calculate_harmony_metrics,
    calculate_emd_metrics,
)
from src.evaluation.evaluate import load_model_from_wandb as load_online_model
from src.evaluation.evaluate_offline import load_model_from_wandb as load_offline_model
from src.evaluation.evaluate import generate_online
from src.evaluation.evaluate_offline import generate_offline
from src.config.tokenization_config import PAD_TOKEN


In [None]:
# Configuration: every knob a reader might tune for this evaluation run.
config = {
    'data_dir': 'data/interim',
    'split': 'test',
    'batch_size': 32,
    'num_workers': 4,
    'temperature': 1.0,
    'top_k': 50,
    # temperature/top_k imply sampled generation, so fix the RNGs to make the
    # reported metrics reproducible across Restart & Run All.
    'seed': 42,
}

# Model artifact paths (replace with your actual artifact paths)
online_artifact_path = "marty1ai/martydepth/celestial-field-49-epoch-34:v0"
offline_artifact_path = "marty1ai/martydepth/offline_teacher_epoch_19_loss_0.23:v0"

# Seed torch (all devices, including CUDA) and numpy before any generation runs.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

# Set device: prefer CUDA, then Apple-silicon MPS, then fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else 
                     "mps" if torch.backends.mps.is_available() else 
                     "cpu")
print(f"Using device: {device}")

# Initialize wandb. Record the artifact paths next to the hyperparameters so
# the run says exactly which checkpoints were evaluated.
wandb.init(
    project="martydepth",
    name="model_evaluation",
    config={
        **config,
        'online_artifact_path': online_artifact_path,
        'offline_artifact_path': offline_artifact_path,
    },
    job_type="evaluation"
)


In [None]:
# Evaluate Online Model
print("\n=== Evaluating Online Model ===")
print(f"Loading model from artifact: {online_artifact_path}")

# Load model; this loader's result is unpacked as (model, tokenizer_info, model_config)
model, tokenizer_info, model_config = load_online_model(online_artifact_path, device)
model.eval()

# Create dataloader. The sequence length is read from either key name the
# model config may use, falling back to 512 when neither is set (or is falsy).
max_seq_length = model_config.get('max_seq_length') or model_config.get('max_sequence_length') or 512
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='online',
    shuffle=False
)

# Generate sequences. Inference only, so disable autograd tracking to save memory.
print("\nGenerating sequences...")
with torch.no_grad():
    generated_sequences, ground_truth_sequences = generate_online(
        model=model,
        dataloader=dataloader,
        tokenizer_info=tokenizer_info,
        device=device,
        temperature=config['temperature'],
        top_k=config['top_k']
    )

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

# Merge both metric dicts; on a key collision the EMD value wins.
online_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Online Model Results ===")
for metric, value in online_metrics.items():
    print(f"{metric}: {value:.4f}")

# Log all metrics in a single call: each wandb.log call advances the step
# counter, so per-metric calls would scatter them across separate history rows.
wandb.log({f"online/{metric}": value for metric, value in online_metrics.items()})


In [None]:
# Evaluate Offline Model
print("\n=== Evaluating Offline Model ===")
print(f"Loading model from artifact: {offline_artifact_path}")

# Load model.
# NOTE(review): this result is unpacked as (model, model_config, tokenizer_info),
# a different order from the online loader's (model, tokenizer_info, model_config)
# unpacking. Confirm against evaluate_offline.load_model_from_wandb.
model, model_config, tokenizer_info = load_offline_model(offline_artifact_path, device)
model.eval()

# Create dataloader. The sequence length is read from either key name the
# model config may use, falling back to 512 when neither is set (or is falsy).
max_seq_length = model_config.get('max_seq_length') or model_config.get('max_sequence_length') or 512
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='offline',
    shuffle=False
)

# Generate sequences. Inference only, so disable autograd tracking to save memory.
# NOTE(review): unlike the online run, config['temperature'] / config['top_k'] are
# not passed here. Confirm generate_offline's decoding settings are intended.
print("\nGenerating sequences...")
with torch.no_grad():
    generated_sequences, ground_truth_sequences = generate_offline(
        model=model,
        dataloader=dataloader,
        tokenizer_info=tokenizer_info,
        device=device
    )

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

# Merge both metric dicts; on a key collision the EMD value wins.
offline_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Offline Model Results ===")
for metric, value in offline_metrics.items():
    print(f"{metric}: {value:.4f}")

# Log all metrics in a single call: each wandb.log call advances the step
# counter, so per-metric calls would scatter them across separate history rows.
wandb.log({f"offline/{metric}": value for metric, value in offline_metrics.items()})


In [None]:
# Compare Results: side-by-side table of online vs. offline metrics plus their
# difference, logged to W&B as comparison/<metric>_diff and _ratio.
print("\n=== Model Comparison ===")
# Right-align the numeric column headers to match the right-aligned values.
header = f"{'Metric':<30} {'Online':>10} {'Offline':>10} {'Difference':>10}"
print(header)
print("-" * len(header))

comparison_log = {}
for metric, online_value in online_metrics.items():
    if metric not in offline_metrics:
        # The two evaluations may report different metric sets. Show the gap
        # instead of crashing with a KeyError.
        print(f"{metric:<30} {online_value:>10.4f} {'n/a':>10} {'n/a':>10}")
        continue
    offline_value = offline_metrics[metric]
    diff = online_value - offline_value
    print(f"{metric:<30} {online_value:>10.4f} {offline_value:>10.4f} {diff:>10.4f}")
    comparison_log[f"comparison/{metric}_diff"] = diff
    # A zero denominator has no meaningful ratio. Log NaN rather than 0, which
    # would be indistinguishable from a genuine online value of 0.
    comparison_log[f"comparison/{metric}_ratio"] = (
        online_value / offline_value if offline_value != 0 else float("nan")
    )

offline_only = [metric for metric in offline_metrics if metric not in online_metrics]
if offline_only:
    print(f"Metrics reported only by the offline model: {offline_only}")

# One wandb.log call so all comparison values share a single step.
wandb.log(comparison_log)

# Finish wandb run
wandb.finish()
