# LODOCV – Leave-One-Day-Out Cross-Validation

4 folds, one per Nezha collection day (20220822, 20220823, 20230129, 20230130). Nezha only, no LEMMA — zero temporal overlap between train and test.

Reports Acc@1, Acc@3, MRR (mean +/- std).


In [1]:
import json
import random
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sentence_transformers import SentenceTransformer

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# === PATHS ===
BASE_DIR = Path("/root/lemm")
MULTIMODAL_DIR = BASE_DIR / "core_multimodal_tmp"
METRICS_DIR = BASE_DIR / "core_metrics_tmp"
LOGS_DIR = BASE_DIR / "core_logs_tmp"

# === HYPERPARAMETERS ===
# Base config (n_metrics will be determined dynamically from data)
CONFIG = {
    "window_size": 22,
    "n_metrics": None,  # Will be set dynamically from manifest.json
    "bin_seconds": 30,
    "log_embed_dim": 384,
    "d_model": 128,
    "n_heads": 4,
    "n_layers": 2,
    "fusion_dim": 256,
    "dropout": 0.35,
    "batch_size": 8,
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "epochs": 50,
    "warmup_epochs": 10,
    "temperature": 0.1,
    "label_smoothing": 0.02,
    "grad_clip_norm": 1.0,
    "jitter": 3,
    "noise_std": 0.1,
    "mask_prob": 0.1,
}

# Load Optuna results if available and merge
OPTUNA_RESULTS_PATH = BASE_DIR / "best_hyperparams.json"
if OPTUNA_RESULTS_PATH.exists():
    with open(OPTUNA_RESULTS_PATH) as f:
        optuna_params = json.load(f)
    # Handle 'lr' vs 'learning_rate' naming convention
    if "lr" in optuna_params:
        optuna_params["learning_rate"] = optuna_params.pop("lr")
    # Don't override n_metrics from Optuna (we'll set it dynamically)
    optuna_params.pop("n_metrics", None)
    CONFIG.update(optuna_params)
    print("Loaded and merged hyperparameters from Optuna")
else:
    print("Using default hyperparameters")

EXCLUDED_SCENARIOS = {
    "20220822_nezha_22", "20220822_nezha_23",
    "20230130_nezha_15", "20230130_nezha_16",
    "20220822_nezha_14", "20220823_nezha_21", "20220823_nezha_24",
}

NEZHA_DAYS = ["20220822", "20220823", "20230129", "20230130"]

print(f"\nConfig: d_model={CONFIG.get('d_model')}, n_layers={CONFIG.get('n_layers')}, dropout={CONFIG.get('dropout'):.3f}")
print(f"        lr={CONFIG.get('learning_rate'):.6f}, temperature={CONFIG.get('temperature'):.3f}")
print(f"        n_metrics: will be determined dynamically from data")
print(f"Excluded scenarios: {len(EXCLUDED_SCENARIOS)}")
print(f"LODOCV folds: {len(NEZHA_DAYS)} (one per Nezha day)")


Loaded and merged hyperparameters from Optuna

Config: d_model=64, n_layers=3, dropout=0.271
        lr=0.000348, temperature=0.112
        n_metrics: will be determined dynamically from data
Excluded scenarios: 7
LODOCV folds: 4 (one per Nezha day)


In [3]:
# === SCENARIO DISCOVERY ===
NEZHA_PATTERN = re.compile(r"^\d{8}_nezha_\d+$")

def get_nezha_day(scenario_id: str) -> Optional[str]:
    if "_nezha_" in scenario_id: return scenario_id.split("_nezha_")[0]
    return None

def discover_scenarios() -> List[Dict]:
    """Discover only Nezha scenarios (LEMMA excluded)."""
    scenarios = []
    for d in sorted(MULTIMODAL_DIR.iterdir()):
        if not d.is_dir() or d.name in EXCLUDED_SCENARIOS: continue
        if not (d / "manifest.json").exists() or not (d / "ground_truth.json").exists(): continue
        # Only include Nezha scenarios
        if not NEZHA_PATTERN.match(d.name): continue
        scenarios.append({"id": d.name, "dataset": "nezha", "day": get_nezha_day(d.name), "path": d})
    return scenarios

NEZHA_SCENARIOS = discover_scenarios()

print(f"Total Nezha scenarios: {len(NEZHA_SCENARIOS)}")
for day in NEZHA_DAYS:
    count = len([s for s in NEZHA_SCENARIOS if s["day"] == day])
    print(f"  {day}: {count} scenarios")


Total Nezha scenarios: 94
  20220822: 21 scenarios
  20220823: 30 scenarios
  20230129: 28 scenarios
  20230130: 15 scenarios


In [4]:
# === DATA LOADING ===
# METRIC_NAMES will be determined dynamically from manifest.json
# This allows us to use ALL available metrics, not just 7 core ones

def load_scenario_data(scenario_id: str) -> Dict:
    """Load scenario data, including ALL available metrics from manifest.json."""
    multimodal_path = MULTIMODAL_DIR / scenario_id
    metrics_path = METRICS_DIR / scenario_id
    logs_path = LOGS_DIR / scenario_id
    
    with open(multimodal_path / "manifest.json") as f: 
        manifest = json.load(f)
    with open(multimodal_path / "ground_truth.json") as f: 
        gt = json.load(f)
    
    # Load ALL metrics listed in manifest.json (not just 7 core ones)
    metrics = {}
    metric_files = manifest.get("metrics_files", [])
    
    for metric_file_path in metric_files:
        # Extract metric name from path: /path/to/pod_metric_name.parquet -> metric_name
        metric_file = Path(metric_file_path).name
        metric_name = metric_file.replace("pod_", "").replace(".parquet", "")
        
        # Try to load from metrics_path
        metric_path = metrics_path / metric_file
        if metric_path.exists():
            metrics[metric_name] = pd.read_parquet(metric_path)
        else:
            # If not found, try to construct from metric_name
            metric_path = metrics_path / f"pod_{metric_name}.parquet"
            if metric_path.exists():
                metrics[metric_name] = pd.read_parquet(metric_path)
    
    logs_file = logs_path / "logs_service_texts.parquet"
    logs = pd.read_parquet(logs_file) if logs_file.exists() else None
    
    return {
        "metrics": metrics, 
        "logs": logs, 
        "manifest": manifest, 
        "ground_truth": gt,
        "pods": manifest["pods"], 
        "services": manifest["services"], 
        "pod_to_service_idx": manifest["pod_to_service_idx"],
        "n_metrics": manifest.get("n_metrics", len(metrics))  # Store n_metrics from manifest
    }

# Determine n_metrics from first scenario (all should have same number)
if NEZHA_SCENARIOS:
    first_scenario_data = load_scenario_data(NEZHA_SCENARIOS[0]["id"])
    CONFIG["n_metrics"] = first_scenario_data["n_metrics"]
    print(f"\nDetected {CONFIG['n_metrics']} metrics from scenario {NEZHA_SCENARIOS[0]['id']}")
    print(f"Available metrics: {sorted(first_scenario_data['metrics'].keys())}")
else:
    print("WARNING: No Nezha scenarios found! n_metrics cannot be determined.")

print("\nData loading functions defined")



Detected 19 metrics from scenario 20220822_nezha_0
Available metrics: ['cpu_usage_m', 'cpu_usage_rate', 'latency_client_p90', 'latency_client_p95', 'latency_client_p99', 'latency_server_p90', 'latency_server_p95', 'latency_server_p99', 'memory_usage', 'memory_usage_rate', 'network_rx_bytes', 'network_tx_bytes', 'node_cpu_usage_rate', 'node_memory_usage_rate', 'node_network_rx_bytes', 'success_rate', 'syscall_read', 'syscall_write', 'workload_ops']

Data loading functions defined


In [5]:
# === DATASET CLASS ===
class RCAScenarioDataset(Dataset):
    def __init__(self, scenario_ids, window_size=22, mode="train", jitter=3, noise_std=0.0, mask_prob=0.0,
                 norm_mode="per_pod", log_token_mode="unique"):
        """
        Args:
            norm_mode: "global" (normalize all pods together), "per_pod" (normalize each pod separately),
                      "hybrid" (per pod + preserve relative magnitude), "adaptive" (detect structure)
            log_token_mode: "common" (same token for all services without logs), "unique" (unique token per service)
        """
        self.scenario_ids, self.window_size, self.mode = scenario_ids, window_size, mode
        self.jitter = jitter if mode == "train" else 0
        self.noise_std = noise_std if mode == "train" else 0.0
        self.mask_prob = mask_prob if mode == "train" else 0.0
        self.norm_mode = norm_mode
        self.log_token_mode = log_token_mode
        self.data = {sid: load_scenario_data(sid) for sid in scenario_ids}
        self.all_services = sorted(set(s for d in self.data.values() for s in d["services"]))
        self.service_to_idx = {s: i for i, s in enumerate(self.all_services)}
    
    def __len__(self): return len(self.scenario_ids)
    
    def _crop_around_fault(self, n_timesteps, fault_idx):
        if n_timesteps <= self.window_size: return 0, n_timesteps, fault_idx
        half = self.window_size // 2
        offset = random.randint(-self.jitter, self.jitter) if self.jitter > 0 else 0
        start = fault_idx - half + offset
        start = max(max(0, fault_idx - self.window_size + 1), min(min(fault_idx, n_timesteps - self.window_size), start))
        return start, start + self.window_size, fault_idx - start
    
    def _get_metric_type(self, metric_name):
        """Classify metric by type for appropriate normalization.
        
        IMPORTANT: Order matters! Check more specific patterns first.
        """
        name_lower = metric_name.lower()
        
        # CPU in milicores (check BEFORE 'usage' to avoid misclassification)
        if 'cpu' in name_lower and ('m' in name_lower or '_m' in name_lower):
            return 'milicores'
        
        # Latency metrics (log-normal distribution) - check BEFORE 'usage'
        if any(x in name_lower for x in ['latency', 'delay']):
            return 'latency'
        
        # Percentage metrics (0-100%) - must have 'rate' or be explicit percentage
        if 'rate' in name_lower or 'success' in name_lower:
            return 'percentage'
        # Memory/CPU usage RATE (percentage)
        if 'usage_rate' in name_lower or 'utilization' in name_lower:
            return 'percentage'
        
        # Memory usage (MiB) - bytes but not percentage
        if 'memory_usage' in name_lower and 'rate' not in name_lower:
            return 'bytes'
        
        # Byte metrics (can be very large)
        if any(x in name_lower for x in ['bytes', 'network']):
            return 'bytes'
        
        # Count/ops metrics (may have many zeros, use log1p)
        if any(x in name_lower for x in ['ops', 'count', 'workload', 'syscall']):
            return 'count'
        
        # Default: z-score
        return 'default'
    
    def _normalize_metrics(self, m, metric_names):
        """
        Normalize metrics with type-aware normalization per metric.
        Returns normalized metrics and 3D mask (time, pods, metrics).
        """
        # Create 3D mask: (time, pods, metrics) - True if value is valid
        valid_mask_3d = ~np.isnan(m)
        
        # Normalize each metric according to its type
        for f, metric_name in enumerate(metric_names):
            metric_type = self._get_metric_type(metric_name)
            metric_data = m[:, :, f]
            valid = valid_mask_3d[:, :, f]
            
            if valid.sum() == 0:
                continue  # Skip if no valid values
            
            if self.norm_mode == "per_pod":
                # Normalize each pod separately
                for p in range(m.shape[1]):
                    pod_data = metric_data[:, p]
                    pod_valid = valid[:, p]
                    if pod_valid.sum() > 0:
                        if metric_type == 'percentage':
                            # Min-Max to [0, 1] for percentages
                            v = pod_data[pod_valid]
                            v_min, v_max = v.min(), v.max()
                            if v_max > v_min:
                                metric_data[:, p] = (pod_data - v_min) / (v_max - v_min)
                            else:
                                metric_data[:, p] = pod_data - v_min
                        elif metric_type == 'bytes':
                            # Byte metrics - check if log is needed
                            v = pod_data[pod_valid]
                            # If values are very large (max/min > 100), use log
                            # Otherwise, use z-score directly
                            if len(v) > 0 and np.max(v) > 0:
                                ratio = np.max(v) / (np.min(v[v > 0]) if np.any(v > 0) else 1)
                                if ratio > 100:
                                    # Large range - use log normalization
                                    v_log = np.log1p(v)
                                    mean, std = np.mean(v_log), np.std(v_log)
                                    if std > 1e-8:
                                        metric_data[:, p] = (np.log1p(pod_data) - mean) / std
                                    else:
                                        metric_data[:, p] = np.log1p(pod_data) - mean
                                else:
                                    # Small range - use z-score directly
                                    mean, std = np.mean(v), np.std(v)
                                    if std > 1e-8:
                                        metric_data[:, p] = (pod_data - mean) / std
                                    else:
                                        metric_data[:, p] = pod_data - mean
                        elif metric_type == 'latency':
                            # Latency metrics - use log1p (handles zeros) + z-score
                            # Many latency metrics have zeros (no requests in that period)
                            v = pod_data[pod_valid]
                            v_log = np.log1p(v)  # log1p handles zeros correctly
                            mean, std = np.mean(v_log), np.std(v_log)
                            if std > 1e-8:
                                metric_data[:, p] = (np.log1p(pod_data) - mean) / std
                            else:
                                metric_data[:, p] = np.log1p(pod_data) - mean
                        elif metric_type == 'count':
                            # Count metrics may have many zeros - use log1p + z-score
                            v = pod_data[pod_valid]
                            # Check if there are many zeros (>10%)
                            zeros_pct = 100 * np.sum(v == 0) / len(v) if len(v) > 0 else 0
                            if zeros_pct > 10:
                                # Use log1p for counts with many zeros
                                v_log = np.log1p(v)
                                mean, std = np.mean(v_log), np.std(v_log)
                                if std > 1e-8:
                                    metric_data[:, p] = (np.log1p(pod_data) - mean) / std
                                else:
                                    metric_data[:, p] = np.log1p(pod_data) - mean
                            else:
                                # Z-score for counts without many zeros
                                mean, std = np.mean(v), np.std(v)
                                if std > 1e-8:
                                    metric_data[:, p] = (pod_data - mean) / std
                                else:
                                    metric_data[:, p] = pod_data - mean
                        elif metric_type == 'milicores':
                            # CPU in milicores - check if log-normal distribution
                            # For now, use z-score (can be changed to log if needed)
                            v = pod_data[pod_valid]
                            mean, std = np.mean(v), np.std(v)
                            if std > 1e-8:
                                metric_data[:, p] = (pod_data - mean) / std
                            else:
                                metric_data[:, p] = pod_data - mean
                        else:
                            # Default: z-score
                            v = pod_data[pod_valid]
                            mean, std = np.mean(v), np.std(v)
                            if std > 1e-8:
                                metric_data[:, p] = (pod_data - mean) / std
                            else:
                                metric_data[:, p] = pod_data - mean
            else:
                # Global normalization (all pods together)
                all_valid = metric_data[valid]
                if len(all_valid) > 0:
                    if metric_type == 'percentage':
                        v_min, v_max = all_valid.min(), all_valid.max()
                        if v_max > v_min:
                            metric_data = (metric_data - v_min) / (v_max - v_min)
                        else:
                            metric_data = metric_data - v_min
                    elif metric_type == 'bytes':
                        # Bytes - check if log is needed
                        if len(all_valid) > 0 and np.max(all_valid) > 0:
                            ratio = np.max(all_valid) / (np.min(all_valid[all_valid > 0]) if np.any(all_valid > 0) else 1)
                            if ratio > 100:
                                v_log = np.log1p(all_valid)
                                mean, std = np.mean(v_log), np.std(v_log)
                                if std > 1e-8:
                                    metric_data = (np.log1p(metric_data) - mean) / std
                                else:
                                    metric_data = np.log1p(metric_data) - mean
                            else:
                                mean, std = np.mean(all_valid), np.std(all_valid)
                                if std > 1e-8:
                                    metric_data = (metric_data - mean) / std
                                else:
                                    metric_data = metric_data - mean
                    elif metric_type == 'latency':
                        # Latency - use log1p + z-score
                        v_log = np.log1p(all_valid)
                        mean, std = np.mean(v_log), np.std(v_log)
                        if std > 1e-8:
                            metric_data = (np.log1p(metric_data) - mean) / std
                        else:
                            metric_data = np.log1p(metric_data) - mean
                    elif metric_type == 'count':
                        # Count metrics - check for zeros
                        zeros_pct = 100 * np.sum(all_valid == 0) / len(all_valid) if len(all_valid) > 0 else 0
                        if zeros_pct > 10:
                            v_log = np.log1p(all_valid)
                            mean, std = np.mean(v_log), np.std(v_log)
                            if std > 1e-8:
                                metric_data = (np.log1p(metric_data) - mean) / std
                            else:
                                metric_data = np.log1p(metric_data) - mean
                        else:
                            mean, std = np.mean(all_valid), np.std(all_valid)
                            if std > 1e-8:
                                metric_data = (metric_data - mean) / std
                            else:
                                metric_data = metric_data - mean
                    elif metric_type == 'bytes':
                        # Bytes - check if log is needed
                        if len(all_valid) > 0 and np.max(all_valid) > 0:
                            ratio = np.max(all_valid) / (np.min(all_valid[all_valid > 0]) if np.any(all_valid > 0) else 1)
                            if ratio > 100:
                                v_log = np.log1p(all_valid)
                                mean, std = np.mean(v_log), np.std(v_log)
                                if std > 1e-8:
                                    metric_data = (np.log1p(metric_data) - mean) / std
                                else:
                                    metric_data = np.log1p(metric_data) - mean
                            else:
                                mean, std = np.mean(all_valid), np.std(all_valid)
                                if std > 1e-8:
                                    metric_data = (metric_data - mean) / std
                                else:
                                    metric_data = metric_data - mean
                    elif metric_type == 'milicores':
                        # Milicores - use z-score
                        mean, std = np.mean(all_valid), np.std(all_valid)
                        if std > 1e-8:
                            metric_data = (metric_data - mean) / std
                        else:
                            metric_data = metric_data - mean
                    else:
                        # Default: z-score
                        mean, std = np.mean(all_valid), np.std(all_valid)
                        if std > 1e-8:
                            metric_data = (metric_data - mean) / std
                        else:
                            metric_data = metric_data - mean
                    m[:, :, f] = metric_data
        
        # Convert NaN to 0 only for invalid values (mask will handle them)
        m = np.nan_to_num(m, nan=0.0)
        
        return m, valid_mask_3d
    
    def _apply_augmentation(self, m, mask_3d):
        """Apply augmentation with 3D mask (time, pods, metrics)."""
        if self.mode != "train": return m, mask_3d
        if self.noise_std > 0: 
            m = m + np.random.normal(0, self.noise_std, m.shape).astype(np.float32) * mask_3d
        if self.mask_prob > 0:
            r = np.random.random(mask_3d.shape) > self.mask_prob
            mask_3d = mask_3d & r
            m = m * mask_3d
        return m, mask_3d
    
    def __getitem__(self, idx):
        sid, data = self.scenario_ids[idx], self.data[self.scenario_ids[idx]]
        gt, fault_idx, rc = data["ground_truth"], data["ground_truth"]["fault_time_idx"], data["ground_truth"]["root_cause_service"]
        first_m = list(data["metrics"].values())[0]
        n_t, pods, n_pods = len(first_m), list(first_m.columns), len(first_m.columns)
        start, end, new_fault = (self._crop_around_fault(n_t, fault_idx) if n_t > self.window_size else (0, n_t, fault_idx))
        
        # Use ALL available metrics (not just hardcoded METRIC_NAMES)
        # Get metric names from the data itself, sorted for consistency
        available_metric_names = sorted(data["metrics"].keys())
        m_list = [data["metrics"][mn].iloc[start:end].values if mn in data["metrics"] else np.full((end-start, n_pods), np.nan) for mn in available_metric_names]
        m = np.stack(m_list, axis=-1).astype(np.float32)
        m, mask_3d = self._normalize_metrics(m, available_metric_names)
        m, mask_3d = self._apply_augmentation(m, mask_3d)
        if m.shape[0] < self.window_size:
            pad = self.window_size - m.shape[0]
            n_metrics = m.shape[2]  # Use actual number of metrics from data
            m = np.concatenate([m, np.zeros((pad, n_pods, n_metrics), dtype=np.float32)])
            mask_3d = np.concatenate([mask_3d, np.zeros((pad, n_pods, n_metrics), dtype=bool)])
        
        # Get services and prepare for potential randomization
        services_orig = data["services"].copy()
        services = services_orig.copy()
        p2s_orig = data["pod_to_service_idx"]
        
        # FIX: Randomize service order during training to prevent idx 0 bias
        # This ensures the model doesn't learn positional bias (e.g., always predicting first service)
        if self.mode == "train":
            import random
            random.shuffle(services)
            # Create permutation mapping: original_idx -> new_idx
            perm_map = {svc: new_idx for new_idx, svc in enumerate(services)}
            # Update pod_to_service_idx to match new service order
            p2s_perm = [perm_map.get(services_orig[old_idx], -1) if 0 <= old_idx < len(services_orig) else -1 for old_idx in p2s_orig]
            p2s_orig = p2s_perm  # Use permuted mapping
        
        # Build log texts (one aggregated text per service)
        # Support different token modes for ablation study
        log_texts = []
        if data["logs"] is not None:
            logs_df = data["logs"].iloc[start:end]
            for service in services_orig:  # Use original order to index logs
                if service in logs_df.columns:
                    service_logs = logs_df[service].fillna("").tolist()
                    combined = " | ".join([l for l in service_logs if l.strip()])
                    if combined.strip():
                        log_texts.append(combined[:512])
                    else:
                        # Empty logs: use token based on mode
                        if self.log_token_mode == "common":
                            log_texts.append("[N_LGS_TKN-LEZSHA]")
                        else:  # unique
                            log_texts.append(f"[N_LGS_{service}]")
                else:
                    # Service not in logs: use token based on mode
                    if self.log_token_mode == "common":
                        log_texts.append("[N_LGS_TKN-LEZSHA]")
                    else:  # unique
                        log_texts.append(f"[N_LGS_{service}]")
        else:
            # No logs at all: use token based on mode
            if self.log_token_mode == "common":
                log_texts = ["[N_LGS_TKN-LEZSHA]"] * len(services_orig)
            else:  # unique
                log_texts = [f"[N_LGS_{svc}]" for svc in services_orig]
        
        # Reorder log_texts to match permuted services (if in train mode)
        if self.mode == "train":
            log_texts = [log_texts[services_orig.index(svc)] for svc in services]
        
        # Label: index of root cause service in current service list (after potential permutation)
        # FIX: Raise error instead of defaulting to 0 to prevent bias towards idx 0
        if rc not in services:
            raise ValueError(f"Root cause service '{rc}' not found in services list for scenario {sid}")
        label = services.index(rc)
        
        return {"scenario_id": sid, "metrics": torch.from_numpy(m), "metrics_mask": torch.from_numpy(mask_3d),
                "log_texts": log_texts, "pods": pods, "services": services,
                "pod_to_service_idx": p2s_orig, "fault_idx": new_fault, "label": label, "rc_service": rc}
print("Dataset class defined")


Dataset class defined


In [6]:
# === MODEL COMPONENTS ===
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=500, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2], pe[:, 1::2] = torch.sin(pos * div), torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x): return self.dropout(x + self.pe[:, :x.size(1), :])

class MetricEncoder(nn.Module):
    def __init__(self, n_metrics=4, d_model=128, n_heads=4, n_layers=2, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.input_proj = nn.Linear(n_metrics, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=d_model*4, dropout=dropout, activation="gelu", batch_first=True)
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.output_proj = nn.Linear(d_model, d_model)
    def forward(self, x, mask=None):
        """
        Forward pass with 3D mask support (time, pods, metrics).
        Args:
            x: (B, W, P, M) - batch, window, pods, metrics
            mask: (B, W, P, M) - 3D mask, True for valid values
        """
        B, W, P, M = x.shape
        x = x.permute(0,2,1,3).reshape(B*P, W, M)
        
        # Handle 3D mask: convert to 2D for transformer (time dimension)
        # A timestep is valid if ANY metric is valid for that pod
        kpm = None
        if mask is not None:
            # mask: (B, W, P, M) -> (B*P, W) - True if timestep has ANY valid metric
            mask_2d = mask.permute(0,2,1,3).reshape(B*P, W, M)  # (B*P, W, M)
            mask_2d = mask_2d.any(dim=-1)  # (B*P, W) - True if any metric is valid
            kpm = ~mask_2d  # Key padding mask: True to mask out
        
        x = self.transformer(self.pos_encoder(self.input_proj(x)), src_key_padding_mask=kpm)
        
        # Weighted average pooling using 3D mask
        if mask is not None:
            # mask: (B, W, P, M) -> (B*P, W) - weight by number of valid metrics per timestep
            mask_2d = mask.permute(0,2,1,3).reshape(B*P, W, M)  # (B*P, W, M)
            weights = mask_2d.float().sum(dim=-1)  # (B*P, W) - number of valid metrics per timestep
            weights = weights.unsqueeze(-1)  # (B*P, W, 1)
            x = (x * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
        else: 
            x = x.mean(dim=1)
        return self.output_proj(x).reshape(B, P, self.d_model)

class LogEncoder(nn.Module):
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        super().__init__()
        self.model = SentenceTransformer(model_name)
        self.embed_dim = 384
        self._no_logs_emb = None
    @property
    def no_logs_embedding(self):
        if self._no_logs_emb is None:
            with torch.no_grad(): self._no_logs_emb = torch.from_numpy(self.model.encode("[N-LGS-DST-TKN-LEZHSA]", convert_to_numpy=True))
        return self._no_logs_emb
    def encode_texts(self, texts): return self.model.encode(texts, convert_to_tensor=True)
    def forward(self, batch):
        embs, max_s = [], max(len(t) for t in batch)
        for texts in batch:
            e = self.encode_texts(texts)
            if len(texts) < max_s: e = torch.cat([e, self.no_logs_embedding.unsqueeze(0).expand(max_s-len(texts),-1).to(e.device)])
            embs.append(e)
        return torch.stack(embs)

class FusionLayer(nn.Module):
    def __init__(self, metric_dim=128, log_dim=384, output_dim=256, dropout=0.1):
        super().__init__()
        self.fusion = nn.Sequential(nn.Linear(2*metric_dim+log_dim, output_dim*2), nn.LayerNorm(output_dim*2), nn.GELU(), nn.Dropout(dropout), nn.Linear(output_dim*2, output_dim), nn.LayerNorm(output_dim))
        self.metric_dim = metric_dim
    def forward(self, me, le, p2s):
        B, S, dev = me.size(0), le.size(1), me.device
        out = []
        for b in range(B):
            sme = []
            for s in range(S):
                pi = [i for i, si in enumerate(p2s[b]) if si == s]
                if pi:
                    pm = me[b, pi]
                    sme.append(torch.cat([pm.max(0).values, pm.mean(0)], -1))
                else: sme.append(torch.zeros(2*self.metric_dim, device=dev))
            out.append(self.fusion(torch.cat([torch.stack(sme), le[b]], -1)))
        return torch.stack(out)

class SimilarityClassifier(nn.Module):
    def __init__(self, input_dim=256, embed_dim=384, temperature=0.1):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(input_dim, embed_dim), nn.LayerNorm(embed_dim))
        self.temperature = temperature
        self._cache = {}
    def get_svc_emb(self, svcs, le):
        k = tuple(svcs)
        if k not in self._cache: self._cache[k] = F.normalize(le.encode_texts([f"microservice: {s}" for s in svcs]), dim=-1)
        return self._cache[k]
    def forward(self, fe, sb, le):
        B, dev, ms = fe.size(0), fe.device, fe.size(1)
        pe = F.normalize(self.proj(fe), dim=-1)
        out = []
        for b in range(B):
            svcs, ns = sb[b], len(sb[b])
            se = self.get_svc_emb(svcs, le).to(dev)
            sim = (pe[b,:ns] * se).sum(-1) / self.temperature
            if ns < ms: sim = torch.cat([sim, torch.full((ms-ns,), -100.0, device=dev)])
            out.append(sim)
        return torch.stack(out)

class MultimodalRCAModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.metric_encoder = MetricEncoder(config["n_metrics"], config["d_model"], config["n_heads"], config["n_layers"], config["dropout"])
        self.log_encoder = LogEncoder()
        for p in self.log_encoder.model.parameters(): p.requires_grad = False
        self.fusion = FusionLayer(config["d_model"], config["log_embed_dim"], config["fusion_dim"], config["dropout"])
        self.classifier = SimilarityClassifier(config["fusion_dim"], config["log_embed_dim"], config["temperature"])
    def forward(self, m, mm, lt, p2s, sb):
        me = self.metric_encoder(m, mm)
        le = self.log_encoder(lt).to(m.device)
        return self.classifier(self.fusion(me, le, p2s), sb, self.log_encoder)

print("Model components defined")


Model components defined


In [7]:
# === TRAINING UTILITIES ===
def rca_collate_fn(batch):
    """Collate function with 3D mask support (time, pods, metrics)."""
    mp, mm, ws, nm = max(s["metrics"].size(1) for s in batch), max(len(s["services"]) for s in batch), batch[0]["metrics"].size(0), batch[0]["metrics"].size(2)
    metrics_p, masks_p, p2s_p = [], [], []
    for s in batch:
        m, mask_3d, p2s, np_ = s["metrics"], s["metrics_mask"], s["pod_to_service_idx"], s["metrics"].size(1)
        if np_ < mp:
            # Pad metrics and mask: (W, P, M) -> (W, max_pods, M)
            m = torch.cat([m, torch.zeros(ws, mp-np_, nm)], 1)
            mask_3d = torch.cat([mask_3d, torch.zeros(ws, mp-np_, nm, dtype=torch.bool)], 1)
            p2s = p2s + [-1]*(mp-np_)
        metrics_p.append(m); masks_p.append(mask_3d); p2s_p.append(p2s)
    return {"scenario_ids": [s["scenario_id"] for s in batch], "metrics": torch.stack(metrics_p), "metrics_mask": torch.stack(masks_p),
            "log_texts": [s["log_texts"] for s in batch], "pods": [s["pods"] for s in batch], "services": [s["services"] for s in batch],
            "pod_to_service_idx": p2s_p, "labels": torch.tensor([s["label"] for s in batch], dtype=torch.long), "rc_services": [s["rc_service"] for s in batch]}

def train_epoch(model, dl, opt, cfg, dev):
    model.train()
    tl, nb = 0, 0
    for b in dl:
        m, mm, l = b["metrics"].to(dev), b["metrics_mask"].to(dev), b["labels"].to(dev)
        logits = model(m, mm, b["log_texts"], b["pod_to_service_idx"], b["services"])
        loss = F.cross_entropy(logits, l, label_smoothing=cfg.get("label_smoothing", 0.0))
        opt.zero_grad(); loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.get("grad_clip_norm", 1.0))
        opt.step(); tl += loss.item(); nb += 1
    return tl / nb

def evaluate(model, dl, dev):
    model.eval()
    preds, labels, scores, sids, rcs = [], [], [], [], []
    with torch.no_grad():
        for b in dl:
            m, mm = b["metrics"].to(dev), b["metrics_mask"].to(dev)
            logits = model(m, mm, b["log_texts"], b["pod_to_service_idx"], b["services"])
            preds.extend(logits.argmax(-1).cpu().tolist()); labels.extend(b["labels"].tolist())
            scores.extend(logits.cpu().tolist()); sids.extend(b["scenario_ids"]); rcs.extend(b["rc_services"])
    acc1 = sum(p==l for p,l in zip(preds, labels)) / len(preds)
    acc3 = sum(l in sorted(range(len(s)), key=lambda i: s[i], reverse=True)[:3] for s,l in zip(scores, labels)) / len(labels)
    mrr = sum(1.0/(sorted(range(len(s)), key=lambda i: s[i], reverse=True).index(l)+1) for s,l in zip(scores, labels)) / len(labels)
    return {"acc@1": acc1, "acc@3": acc3, "mrr": mrr, "predictions": list(zip(sids, preds, labels, rcs))}

print("Training utilities defined")


Training utilities defined


In [8]:
# === LODOCV MAIN LOOP ===
def run_lodocv(config, nezha_days, nezha_scenarios, device):
    """Leave-One-Day-Out Cross-Validation.
    
    Only uses Nezha scenarios (LEMMA excluded).
    Uses type-aware normalization and 3D masking.
    """
    results_per_fold, all_predictions = [], []
    
    print("="*70)
    print("LEAVE-ONE-DAY-OUT CROSS-VALIDATION")
    print("="*70)
    print(f"\nFolds: {len(nezha_days)}")
    print(f"Using only Nezha scenarios (LEMMA excluded)")
    print(f"Metrics: {config.get('n_metrics', 'unknown')}")
    print(f"Normalization: Type-aware (per metric type)")
    print(f"Masking: 3D (time, pods, metrics)")
    print(f"Epochs per fold: {config['epochs']}\n")
    
    for fold_idx, test_day in enumerate(nezha_days):
        print(f"\n{'='*70}")
        print(f"FOLD {fold_idx + 1}/{len(nezha_days)}: Test on {test_day}")
        print(f"{'='*70}")
        
        test_ids = [s["id"] for s in nezha_scenarios if s["day"] == test_day]
        train_ids = [s["id"] for s in nezha_scenarios if s["day"] != test_day]
        print(f"Train: {len(train_ids)} scenarios, Test: {len(test_ids)} scenarios")
        
        train_ds = RCAScenarioDataset(train_ids, window_size=config["window_size"], mode="train",
                                      jitter=config.get("jitter", 3), noise_std=config.get("noise_std", 0.1),
                                      mask_prob=config.get("mask_prob", 0.1),
                                      norm_mode="per_pod", log_token_mode="unique")
        test_ds = RCAScenarioDataset(test_ids, window_size=config["window_size"], mode="eval",
                                     norm_mode="per_pod", log_token_mode="unique")
        
        train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True, collate_fn=rca_collate_fn)
        test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=rca_collate_fn)
        
        model = MultimodalRCAModel(config).to(device)
        optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=config["learning_rate"], weight_decay=config["weight_decay"])
        
        best_mrr, best_state = 0.0, None
        for epoch in range(config["epochs"]):
            if epoch < config.get("warmup_epochs", 5):
                for pg in optimizer.param_groups: pg['lr'] = config["learning_rate"] * (epoch + 1) / config.get("warmup_epochs", 5)
            train_loss = train_epoch(model, train_loader, optimizer, config, device)
            if (epoch + 1) % 10 == 0 or epoch == config["epochs"] - 1:
                results = evaluate(model, test_loader, device)
                if results["mrr"] > best_mrr: best_mrr, best_state = results["mrr"], {k: v.cpu().clone() for k, v in model.state_dict().items()}
                print(f"  Epoch {epoch+1:2d}: loss={train_loss:.4f}, Acc@1={results['acc@1']:.3f}, Acc@3={results['acc@3']:.3f}, MRR={results['mrr']:.3f}")
        
        if best_state: model.load_state_dict(best_state)
        final_results = evaluate(model, test_loader, device)
        print(f"\n  Final: Acc@1={final_results['acc@1']:.3f}, Acc@3={final_results['acc@3']:.3f}, MRR={final_results['mrr']:.3f}")
        
        results_per_fold.append({"fold": fold_idx+1, "test_day": test_day, "train_size": len(train_ids), "test_size": len(test_ids),
                                 "acc@1": final_results["acc@1"], "acc@3": final_results["acc@3"], "mrr": final_results["mrr"]})
        all_predictions.extend(final_results["predictions"])
        
        del model, optimizer, train_ds, test_ds
        if torch.cuda.is_available(): torch.cuda.empty_cache()
    
    return results_per_fold, all_predictions

print("LODOCV function defined")


LODOCV function defined


In [None]:
# === RUN LODOCV ===
start_time = datetime.now()
print(f"Starting LODOCV at {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

fold_results, all_predictions = run_lodocv(
    config=CONFIG, nezha_days=NEZHA_DAYS, nezha_scenarios=NEZHA_SCENARIOS,
    device=DEVICE
)

end_time = datetime.now()
print(f"\nCompleted at {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total duration: {end_time - start_time}")


Starting LODOCV at 2026-02-07 15:08:22
LEAVE-ONE-DAY-OUT CROSS-VALIDATION

Folds: 4
Using only Nezha scenarios (LEMMA excluded)
Metrics: 19
Normalization: Type-aware (per metric type)
Masking: 3D (time, pods, metrics)
Epochs per fold: 50


FOLD 1/4: Test on 20220822
Train: 73 scenarios, Test: 21 scenarios


  output = torch._nested_tensor_from_mask(


  Epoch 10: loss=2.2682, Acc@1=0.190, Acc@3=0.714, MRR=0.487
  Epoch 20: loss=1.4513, Acc@1=0.238, Acc@3=0.810, MRR=0.537
  Epoch 30: loss=0.9374, Acc@1=0.333, Acc@3=0.810, MRR=0.581
  Epoch 40: loss=0.8043, Acc@1=0.286, Acc@3=0.762, MRR=0.557
  Epoch 50: loss=0.7697, Acc@1=0.381, Acc@3=0.762, MRR=0.601

  Final: Acc@1=0.381, Acc@3=0.762, MRR=0.601

FOLD 2/4: Test on 20220823
Train: 64 scenarios, Test: 30 scenarios
  Epoch 10: loss=2.0754, Acc@1=0.633, Acc@3=0.733, MRR=0.728


In [None]:
# === RESULTS SUMMARY ===
print("\n" + "="*70)
print("LODOCV RESULTS SUMMARY")
print("="*70)

print("\nPer-Fold Results:")
print("-"*70)
print(f"{'Fold':<6} {'Test Day':<12} {'Train':<8} {'Test':<8} {'Acc@1':<8} {'Acc@3':<8} {'MRR':<8}")
print("-"*70)
for r in fold_results:
    print(f"{r['fold']:<6} {r['test_day']:<12} {r['train_size']:<8} {r['test_size']:<8} "
          f"{r['acc@1']:<8.3f} {r['acc@3']:<8.3f} {r['mrr']:<8.3f}")

acc1_vals = [r["acc@1"] for r in fold_results]
acc3_vals = [r["acc@3"] for r in fold_results]
mrr_vals = [r["mrr"] for r in fold_results]

print("-"*70)
print(f"{'Mean':<6} {'':<12} {'':<8} {'':<8} {np.mean(acc1_vals):<8.3f} {np.mean(acc3_vals):<8.3f} {np.mean(mrr_vals):<8.3f}")
print(f"{'Std':<6} {'':<12} {'':<8} {'':<8} {np.std(acc1_vals):<8.3f} {np.std(acc3_vals):<8.3f} {np.std(mrr_vals):<8.3f}")
print("="*70)

print("\n" + "="*70)
print("PUBLICATION-READY RESULTS")
print("="*70)
print(f"\nAcc@1: {np.mean(acc1_vals):.3f} +/- {np.std(acc1_vals):.3f}")
print(f"Acc@3: {np.mean(acc3_vals):.3f} +/- {np.std(acc3_vals):.3f}")
print(f"MRR:   {np.mean(mrr_vals):.3f} +/- {np.std(mrr_vals):.3f}")

# Save results
results_to_save = {
    "config": {k: v for k, v in CONFIG.items() if not callable(v)},
    "fold_results": fold_results,
    "summary": {"acc@1_mean": float(np.mean(acc1_vals)), "acc@1_std": float(np.std(acc1_vals)),
                "acc@3_mean": float(np.mean(acc3_vals)), "acc@3_std": float(np.std(acc3_vals)),
                "mrr_mean": float(np.mean(mrr_vals)), "mrr_std": float(np.std(mrr_vals))},
    "timestamp": datetime.now().isoformat(),
}
with open(BASE_DIR / "lodocv_results.json", "w") as f:
    json.dump(results_to_save, f, indent=2, default=str)
print(f"\nResults saved to {BASE_DIR / 'lodocv_results.json'}")



LODOCV RESULTS SUMMARY

Per-Fold Results:
----------------------------------------------------------------------
Fold   Test Day     Train    Test     Acc@1    Acc@3    MRR     
----------------------------------------------------------------------
1      20220822     73       21       0.381    0.762    0.601   
2      20220823     64       30       0.667    0.767    0.746   
3      20230129     66       28       0.357    0.500    0.494   
4      20230130     79       15       0.267    0.400    0.400   
----------------------------------------------------------------------
Mean                                  0.418    0.607    0.560   
Std                                   0.150    0.161    0.129   

PUBLICATION-READY RESULTS

Acc@1: 0.418 +/- 0.150
Acc@3: 0.607 +/- 0.161
MRR:   0.560 +/- 0.129

Results saved to /root/lemm/lodocv_results.json


In [None]:
# === ERROR ANALYSIS ===
print("\n" + "="*70)
print("ERROR ANALYSIS")
print("="*70)

correct = [(sid, pred, label, rc) for sid, pred, label, rc in all_predictions if pred == label]
wrong = [(sid, pred, label, rc) for sid, pred, label, rc in all_predictions if pred != label]

print(f"\nTotal predictions: {len(all_predictions)}")
print(f"Correct: {len(correct)} ({100*len(correct)/len(all_predictions):.1f}%)")
print(f"Wrong:   {len(wrong)} ({100*len(wrong)/len(all_predictions):.1f}%)")

if wrong:
    print("\nWrong predictions (first 20):")
    for sid, pred, label, rc in wrong[:20]:
        print(f"  {sid}: predicted idx {pred}, actual {rc} (idx {label})")



ERROR ANALYSIS

Total predictions: 94
Correct: 42 (44.7%)
Wrong:   52 (55.3%)

Wrong predictions (first 20):
  20220822_nezha_0: predicted idx 0, actual frontend (idx 5)
  20220822_nezha_11: predicted idx 2, actual currencyservice (idx 3)
  20220822_nezha_12: predicted idx 0, actual emailservice (idx 4)
  20220822_nezha_13: predicted idx 0, actual emailservice (idx 4)
  20220822_nezha_16: predicted idx 0, actual productcatalogservice (idx 7)
  20220822_nezha_17: predicted idx 0, actual recommendationservice (idx 8)
  20220822_nezha_19: predicted idx 0, actual recommendationservice (idx 8)
  20220822_nezha_20: predicted idx 0, actual shippingservice (idx 9)
  20220822_nezha_21: predicted idx 0, actual shippingservice (idx 9)
  20220822_nezha_3: predicted idx 0, actual frontend (idx 5)
  20220822_nezha_4: predicted idx 7, actual cartservice (idx 1)
  20220822_nezha_7: predicted idx 0, actual checkoutservice (idx 2)
  20220822_nezha_9: predicted idx 0, actual checkoutservice (idx 2)
  20