<a href="https://colab.research.google.com/github/kalyanchakri02/ml-latest/blob/main/feature_training_catboost_gpu_repair.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install catboost

import pandas as pd
import numpy as np
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# ==========================================
# STEP 1: GENERATE MULTI-CLASS REPAIR DATA
# ==========================================
# Fixed sample count and seed so the synthetic telemetry is reproducible.
# NOTE: the columns below are drawn in a fixed order from the global NumPy
# RNG; reordering the draws changes every generated value.
n_samples = 5000
np.random.seed(42)

# Categorical levels and their sampling weights.
pcie_width_levels = [16, 8, 4, 1]
pcie_width_weights = [0.90, 0.05, 0.03, 0.02]
pcie_gen_levels = [5, 4, 3, 2, 1]
pcie_gen_weights = [0.85, 0.05, 0.04, 0.03, 0.03]
gpu_model_levels = ['H100', 'A100', 'V100', 'H200', 'GB200']
xid_code_levels = ['NONE', 'XID_31', 'XID_43', 'XID_61', 'XID_79']
xid_code_weights = [0.7, 0.1, 0.05, 0.1, 0.05]

data = {}

# Continuous / count sensor readings.
data['temp_c'] = np.random.normal(loc=70, scale=15, size=n_samples)
data['fan_rpm'] = np.random.randint(low=0, high=5000, size=n_samples)
data['voltage'] = np.random.normal(loc=1.2, scale=0.1, size=n_samples)
data['error_count_24h'] = np.random.poisson(lam=0.5, size=n_samples)

# PCIe link state: weighted towards the full-width, newest-generation link.
data['pcie_width'] = np.random.choice(
    pcie_width_levels, size=n_samples, p=pcie_width_weights
)
data['pcie_gen'] = np.random.choice(
    pcie_gen_levels, size=n_samples, p=pcie_gen_weights
)
data['retimer_errors'] = np.random.poisson(lam=0.2, size=n_samples)

# Categorical descriptors: GPU model is uniform, XID code is mostly 'NONE'.
data['gpu_model'] = np.random.choice(gpu_model_levels, size=n_samples)
data['xid_code'] = np.random.choice(
    xid_code_levels, size=n_samples, p=xid_code_weights
)

df = pd.DataFrame(data)

def _is_hard_failure(row):
    """Return True when any level-3 (RMA) rule matches *row*."""
    # Memory/silicon fault: any uncorrectable ECC error (optional field,
    # treated as 0 when absent) or more than 5 errors in the last 24h.
    if row.get('unfixable_ecc_errors', 0) > 0:
        return True
    if row['error_count_24h'] > 5:
        return True

    # Integrated fan failure: only for actively cooled cards (optional flag,
    # treated as False when absent) whose fan is below 500 RPM while hot.
    if row.get('is_active_cooling', False):
        if row['fan_rpm'] < 500 and row['temp_c'] > 85:
            return True

    # Fallen off the bus (XID 79) with the link stuck at x1 or narrower.
    if row['xid_code'] == 'XID_79' and row['pcie_width'] <= 1:
        return True
    return False


def _needs_physical_fix(row):
    """Return a truthy value when any level-2 (reseat/component) rule matches."""
    return (
        # Chassis fan problem: card is not actively cooled and RPM < 500.
        (not row.get('is_active_cooling', False) and row['fan_rpm'] < 500)
        # Degraded PCIe link: anything narrower than x16.
        or row['pcie_width'] < 16
        # Signal-integrity problem: more than 5 retimer errors.
        or row.get('retimer_errors', 0) > 5
        # Intermittent bus loss: XID 79 while the link is wider than x1.
        or (row['xid_code'] == 'XID_79' and row['pcie_width'] > 1)
    )


def _is_recoverable(row):
    """Return True when any level-1 (reboot / power cycle) rule matches."""
    # Thermal throttling: hotter than 92C although the fan is spinning.
    if row['temp_c'] > 92 and row['fan_rpm'] >= 500:
        return True
    # Driver/firmware hang reported through XID 43, XID 31 or XID 61.
    if row['xid_code'] in ('XID_43', 'XID_31', 'XID_61'):
        return True
    return False


def recommend_repair_action(row):
    """
    Map one telemetry record to a repair-action code.

    *row* is any mapping-like record supporting ``row[key]`` and
    ``row.get(key, default)`` (for example a dict or a pandas Series).
    Rule groups are checked from most to least severe and the first group
    that matches decides the result:

    3: RMA / Unit Replacement (Hard Failure)
    2: Reseat / Component Replacement (Mechanical/Signal Fix)
    1: Reboot / Power Cycle (Soft Fix)
    0: Healthy
    """
    if _is_hard_failure(row):
        return 3
    if _needs_physical_fix(row):
        return 2
    if _is_recoverable(row):
        return 1
    return 0

# Label every telemetry row with the rule-based repair action (row-wise
# apply), then persist the labelled dataset without the index column.
action_labels = df.apply(recommend_repair_action, axis=1)
df['action_label'] = action_labels
df.to_csv(path_or_buf='gpu_rma_logic_data.csv', index=False)

# ==========================================
# STEP 2: TRAIN MULTI-CLASS MODEL
# ==========================================
# Features are every column except the label; the two string-valued columns
# are declared to CatBoost as categorical features.
X = df.drop('action_label', axis=1)
y = df['action_label']
cat_features = ['gpu_model', 'xid_code']

# The RMA class is very rare in the generated data, so a purely random split
# can leave it out of the train or the test side. Stratify on the label when
# every class has at least the two members that stratification requires;
# otherwise fall back to the plain random split.
stratify_labels = y if y.value_counts().min() >= 2 else None

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=stratify_labels,
)

# Use MultiClass loss function.
# use_best_model=False: the eval_set below is the held-out test split and is
# passed only so that progress is logged against it. Without this flag
# CatBoost would keep the iteration that scores best on eval_set, i.e. the
# test data would influence the final model and inflate the reported metrics.
model = CatBoostClassifier(
    iterations=500,
    loss_function='MultiClass',
    eval_metric='Accuracy',
    random_seed=42,
    verbose=100,
    use_best_model=False,
)

model.fit(X_train, y_train, cat_features=cat_features, eval_set=(X_test, y_test))

# ==========================================
# STEP 3: EVALUATE RECOMMENDATIONS
# ==========================================
# CatBoost may return MultiClass predictions as an (n, 1) column vector;
# flatten to 1-D so scikit-learn and the plots receive a plain label array.
preds = np.asarray(model.predict(X_test)).ravel()
target_names = ['No Action', 'Reboot', 'Reseat', 'RMA']

# Explicit label list (0..3, one per entry of target_names). Passing it to
# every metric keeps the report, the confusion matrix and the per-class scores
# at four classes even when a rare class (e.g. RMA) is absent from both
# y_test and preds; without it classification_report raises a ValueError and
# the per-class indexing further down goes out of range.
class_labels = list(range(len(target_names)))

print("\nRepair Recommendation Performance:")
print(classification_report(
    y_test,
    preds,
    labels=class_labels,
    target_names=target_names,
    zero_division=0,  # absent classes score 0 instead of emitting warnings
))


# --- 1. Confusion Matrix (The 'Reasoning' Map) ---
plt.figure(figsize=(10, 8))
# labels= fixes the matrix at 4x4 so the tick labels always line up.
cm = confusion_matrix(y_test, preds, labels=class_labels)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.title('Confusion Matrix: Predicted vs. Actual Repair Actions', fontsize=15)
plt.xlabel('AI Predicted Action', fontsize=12)
plt.ylabel('True Optimal Action', fontsize=12)
plt.savefig('confusion_matrix.png', dpi=300)
plt.show()

# --- 2. Feature Importance (The 'Sensor' Map) ---
# This shows which sensors (telemetry) the AI values most
plt.figure(figsize=(12, 6))
feat_importances = pd.Series(model.get_feature_importance(), index=X.columns)
feat_importances.nlargest(10).sort_values().plot(kind='barh', color='#2ca02c')
plt.title('Telemetry Feature Importance (Root Cause Drivers)', fontsize=15)
plt.xlabel('CatBoost Importance Score', fontsize=12)
plt.grid(axis='x', linestyle='--', alpha=0.7)
plt.savefig('feature_importance.png', dpi=300)
plt.show()

# --- 3. Performance Metrics Bar Chart (Scholarly Summary) ---
from sklearn.metrics import precision_recall_fscore_support

# Tuple of per-class arrays: (precision, recall, f1, support), each ordered
# like class_labels and therefore like target_names.
metrics = precision_recall_fscore_support(
    y_test, preds, labels=class_labels, zero_division=0
)
precision, recall, f1_score, _support = metrics

# One row per repair category, one column per metric.
metrics_df = pd.DataFrame(
    {'Precision': precision, 'Recall': recall, 'F1-Score': f1_score},
    index=target_names,
).rename_axis(columns='Metric')

metrics_df.plot(kind='bar', figsize=(12, 6), colormap='viridis')
plt.title('Precision, Recall, and F1-Score by Repair Category', fontsize=15)
plt.ylabel('Score (0.0 - 1.0)')
plt.ylim(0, 1.1)
plt.legend(loc='lower right')
plt.xticks(rotation=0)
plt.savefig('performance_metrics.png', dpi=300)
plt.show()