# SENTINEL AI - Full Dataset Training (Google Colab)

This notebook trains the malware detection model on the **complete EMBER dataset** using Google Colab's resources.

## Prerequisites
1. Upload `train.parquet` to your Google Drive (the default paths in Step 4 expect it under `MyDrive/InfoSec_Project/data/`)
2. (Optional) Upload `test.parquet` if you have a separate test set

## Estimated Time: 30-60 minutes

## Step 1: Mount Google Drive

In [None]:
# Attach Google Drive to the Colab VM so the data/model paths used later resolve.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

## Step 2: Install Dependencies

In [None]:
!pip install lightgbm pandas pyarrow scikit-learn joblib -q
print("✓ Dependencies installed")

## Step 3: GPU Configuration Check

In [None]:
# Check GPU availability
!nvidia-smi

import torch
if torch.cuda.is_available():
    print(f"✓ GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠ No GPU detected - Training will use CPU (slower)")

## Step 4: Configure Paths

**IMPORTANT:** Update these paths to match where you uploaded your data in Google Drive

In [None]:
import os

# Update these paths!
TRAIN_DATA_PATH = '/content/drive/MyDrive/InfoSec_Project/data/train.parquet'
TEST_DATA_PATH = '/content/drive/MyDrive/InfoSec_Project/data/test.parquet'  # Optional
MODEL_OUTPUT_PATH = '/content/drive/MyDrive/InfoSec_Project/models/classifier.pkl'
METADATA_OUTPUT_PATH = '/content/drive/MyDrive/InfoSec_Project/models/model_metadata.json'

# Create the parent directory of every output file. The model and metadata paths
# can be edited independently, so each one is handled separately; a bare filename
# has an empty dirname, which os.makedirs would reject, so it is skipped.
for _output_path in (MODEL_OUTPUT_PATH, METADATA_OUTPUT_PATH):
    _output_dir = os.path.dirname(_output_path)
    if _output_dir:
        os.makedirs(_output_dir, exist_ok=True)

# Surface a wrong training path here rather than in the load step.
if not os.path.exists(TRAIN_DATA_PATH):
    print(f"⚠ Training data not found at {TRAIN_DATA_PATH} - update TRAIN_DATA_PATH before continuing")

print("✓ Paths configured")
print(f"  Train data: {TRAIN_DATA_PATH}")
print(f"  Model output: {MODEL_OUTPUT_PATH}")

## Step 5: Load Training Data

In [None]:
import pandas as pd
import numpy as np
import gc

# Read the whole training parquet file into memory.
print("[*] Loading training data...")
df_train = pd.read_parquet(TRAIN_DATA_PATH, engine='pyarrow')

n_rows, n_cols = df_train.shape
print(f"[*] Loaded {n_rows:,} samples")
# One column is the target, so it is not counted as a feature.
print(f"[*] Features: {n_cols - 1}")

# Downcast every float64 column to float32 to halve its footprint.
print("[*] Optimizing memory (float64 → float32)...")
float_cols = df_train.select_dtypes(include=['float64']).columns
df_train[float_cols] = df_train[float_cols].astype('float32')

# Report the in-memory size in GB (deep=True also counts object payloads).
total_bytes = df_train.memory_usage(deep=True).sum()
mem_usage = total_bytes / 1e9
print(f"✓ Dataset loaded: {mem_usage:.2f} GB in memory")

df_train.head()

## Step 6: Prepare Training & Test Sets

In [None]:
# Identify the target column: prefer 'label', then the first column whose name
# contains 'lab', and finally fall back to the last column.
target_col = 'label'
if target_col not in df_train.columns:
    candidates = [c for c in df_train.columns if 'lab' in c.lower()]
    target_col = candidates[0] if candidates else df_train.columns[-1]

print(f"[*] Target column: {target_col}")


def _drop_unlabeled(df, label_col):
    """Return `df` without rows whose label equals -1.

    EMBER marks unlabeled samples with -1. Left in place they would turn the
    binary benign/malware task into a three-class problem, which the binary
    metric and two-class reports used later do not support. When no such rows
    exist the frame is returned unchanged (no copy is made).
    """
    unlabeled = df[label_col] == -1
    n_unlabeled = int(unlabeled.sum())
    if n_unlabeled:
        print(f"[*] Dropping {n_unlabeled:,} unlabeled samples ({label_col} == -1)")
        df = df.loc[~unlabeled]
    return df


# Split features and labels
df_train = _drop_unlabeled(df_train, target_col)
y_train = df_train[target_col].astype('int32')
X_train = df_train.drop(columns=[target_col])

# The combined frame is no longer needed; release it before loading more data.
del df_train
gc.collect()

print(f"\n[*] Training set shape: {X_train.shape}")
print(f"[*] Class distribution:")
print(y_train.value_counts())

# Load or split test data
if os.path.exists(TEST_DATA_PATH):
    print(f"\n[*] Loading test data from {TEST_DATA_PATH}...")
    df_test = pd.read_parquet(TEST_DATA_PATH, engine='pyarrow')

    if target_col not in df_test.columns:
        raise KeyError(
            f"Test data has no '{target_col}' column; it must use the same label column as the training data"
        )

    # Optimize test data
    float_cols = df_test.select_dtypes(include=['float64']).columns
    df_test[float_cols] = df_test[float_cols].astype('float32')

    df_test = _drop_unlabeled(df_test, target_col)
    y_test = df_test[target_col].astype('int32')
    X_test = df_test.drop(columns=[target_col])
    del df_test
    gc.collect()

    # The model must see test features in the same order as the training
    # features. Fail early on missing columns; reorder (and drop extras) only
    # when the layouts actually differ, to avoid a needless copy.
    missing_cols = X_train.columns.difference(X_test.columns)
    if len(missing_cols):
        raise ValueError(
            f"Test data is missing {len(missing_cols)} training feature(s), e.g. {list(missing_cols[:5])}"
        )
    if list(X_test.columns) != list(X_train.columns):
        print("[*] Aligning test columns to the training feature order...")
        X_test = X_test[X_train.columns]
else:
    print("\n[*] Test file not found. Splitting training data (80/20)...")
    from sklearn.model_selection import train_test_split
    # Stratify so both splits keep the same benign/malware ratio.
    X_train, X_test, y_train, y_test = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
    )

print(f"[*] Test set shape: {X_test.shape}")
print("\n✓ Data prepared for training")

## Step 7: Train Model (This will take 30-60 minutes)

In [None]:
import lightgbm as lgb
from datetime import datetime

print("="*70)
print("TRAINING MALWARE DETECTION MODEL")
print("="*70)
print(f"Samples: {len(X_train):,}")
print(f"Features: {X_train.shape[1]}")
print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("="*70 + "\n")

# Configure LightGBM with GPU if available
device = 'gpu' if torch.cuda.is_available() else 'cpu'
print(f"[*] Training device: {device.upper()}")

# Hyper-parameters shared by both devices.
# subsample_freq=1 is required for `subsample` to have any effect: LightGBM only
# performs row bagging when the bagging frequency is non-zero (the default is 0).
common_params = dict(
    boosting_type='gbdt',
    min_child_samples=20,
    subsample=0.8,
    subsample_freq=1,
    colsample_bytree=0.8,
    random_state=42,
    n_jobs=-1,
    verbose=1,
)

# Device-specific settings: a larger, slower-learning model on GPU and a
# smaller one on CPU to keep the run time reasonable.
device_params = {
    'gpu': dict(
        device='gpu',
        gpu_platform_id=0,
        gpu_device_id=0,
        n_estimators=400,
        learning_rate=0.03,
        num_leaves=256,
        max_depth=10,
        reg_alpha=0.1,
        reg_lambda=0.1,
    ),
    'cpu': dict(
        n_estimators=300,  # Slightly fewer for CPU speed
        learning_rate=0.05,
        num_leaves=128,
        max_depth=8,
    ),
}


def _fit_classifier(target_device):
    """Build the LightGBM classifier for `target_device` and fit it with early stopping."""
    model = lgb.LGBMClassifier(**common_params, **device_params[target_device])
    # NOTE(review): the test set doubles as the early-stopping validation set, so
    # the metrics reported in Step 8 are slightly optimistic. Carve a separate
    # validation split out of the training data if memory allows.
    model.fit(
        X_train, y_train,
        eval_set=[(X_test, y_test)],
        eval_metric='binary_logloss',
        callbacks=[
            lgb.log_evaluation(period=50),
            lgb.early_stopping(stopping_rounds=50, verbose=True)
        ]
    )
    return model


# A visible CUDA device does not guarantee that LightGBM can actually train on
# the GPU, so fall back to the CPU configuration if GPU training cannot start.
try:
    clf = _fit_classifier(device)
except lgb.basic.LightGBMError as gpu_error:
    if device != 'gpu':
        raise
    print(f"⚠ GPU training failed ({gpu_error}); retrying on CPU")
    device = 'cpu'  # keep the saved metadata truthful about the device used
    clf = _fit_classifier(device)

print(f"\n[✓] Training completed: {datetime.now().strftime('%H:%M:%S')}")

# Free training data
del X_train, y_train
gc.collect()

## Step 8: Evaluate Model

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

print("="*70)
print("MODEL EVALUATION")
print("="*70 + "\n")

# Label order is pinned so the report names and the confusion-matrix cells keep
# their meaning even if one class is absent from the test set or predictions.
CLASS_LABELS = [0, 1]
CLASS_NAMES = ['Benign', 'Malware']

# Predictions: hard labels plus the probability of the positive class
y_pred = clf.predict(X_test)
y_pred_proba = clf.predict_proba(X_test)[:, 1]

# Accuracy
acc = accuracy_score(y_test, y_pred)
print(f"{'='*70}")
print(f"ACCURACY: {acc * 100:.2f}%")
print(f"{'='*70}\n")

# Classification Report
print("Classification Report:")
print(classification_report(y_test, y_pred, labels=CLASS_LABELS, target_names=CLASS_NAMES))

# Confusion Matrix (rows = true class, columns = predicted class)
print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred, labels=CLASS_LABELS)
print(cm)
tn, fp, fn, tp = cm.ravel()
print(f"\nTrue Negatives: {tn:,}")
print(f"False Positives: {fp:,}")
print(f"False Negatives: {fn:,}")
print(f"True Positives: {tp:,}")

# Additional Metrics (positive class = 1)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

print(f"\nPrecision: {precision*100:.2f}%")
print(f"Recall: {recall*100:.2f}%")
print(f"F1-Score: {f1*100:.2f}%")

# Threshold-independent quality from the predicted probabilities. ROC-AUC is
# undefined when the test labels contain a single class, so guard against it.
if len(set(y_test)) > 1:
    roc_auc = roc_auc_score(y_test, y_pred_proba)
    print(f"ROC-AUC: {roc_auc:.4f}")
else:
    print("ROC-AUC: undefined (test set contains a single class)")

## Step 9: Save Model

In [None]:
import joblib
import json

# Persist the fitted classifier to Google Drive.
print(f"[*] Saving model to {MODEL_OUTPUT_PATH}...")
joblib.dump(clf, MODEL_OUTPUT_PATH)

# Save metadata
# The training data was freed after fitting, so the training-set size is
# estimated from the test-set size: with an 80/20 split the training share is
# 4x the test share (5x would be the whole dataset).
metadata = {
    'accuracy': float(acc),
    'precision': float(precision),
    'recall': float(recall),
    'f1_score': float(f1),
    'training_samples': len(y_test) * 4,  # Approximate (80/20 split)
    'test_samples': len(y_test),
    'features': X_test.shape[1],
    'full_dataset': True,
    'trained_on': 'GoogleColab',
    'device': device,
    'timestamp': datetime.now().isoformat()
}

with open(METADATA_OUTPUT_PATH, 'w') as f:
    json.dump(metadata, f, indent=2)

print(f"[*] Metadata saved to {METADATA_OUTPUT_PATH}")
print("\n" + "="*70)
print("✓ MODEL TRAINING COMPLETE!")
print("="*70)
print(f"\nModel saved to Google Drive: {MODEL_OUTPUT_PATH}")
print("\nNext steps:")
print("1. Download the model from Google Drive")
print("2. Place it in your local models/ directory")
print("3. Test with scanner_engine.py")

## Optional: Feature Importance Analysis

In [None]:
import matplotlib.pyplot as plt

# Pair each feature name with the importance the fitted model assigned to it,
# ordered from most to least important.
feature_importance = (
    pd.DataFrame({
        'feature': X_test.columns,
        'importance': clf.feature_importances_
    })
    .sort_values('importance', ascending=False)
)

# Horizontal bar chart of the 20 strongest features, most important at the top.
top_features = feature_importance.head(20)
fig, ax = plt.subplots(figsize=(10, 8))
ax.barh(top_features['feature'], top_features['importance'])
ax.set_xlabel('Importance')
ax.set_title('Top 20 Most Important Features')
ax.invert_yaxis()
fig.tight_layout()
plt.show()

print("Top 10 Features:")
print(feature_importance.head(10))