In [None]:
#!/usr/bin/env python3
"""
Knowledge Distillation for IoT Intrusion Detection - Full File Processing
- Reduced model sizes for memory efficiency
- Process entire files at once (no chunking within files)
- Handle all 169 files properly
- Stream one file at a time to avoid RAM overflow
"""

import os
import gc
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import IncrementalPCA
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import kagglehub
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import warnings
warnings.filterwarnings('ignore')

# ==========================================================
# 🎮 GPU CONFIGURATION
# ==========================================================

def setup_gpu():
    """Choose the torch device for this run and print a short configuration banner.

    When CUDA is available the banner lists the first device's name, the CUDA
    version reported by ``torch.version.cuda`` and the device's total memory, and
    the cuDNN autotuner is switched on (``torch.backends.cudnn.benchmark = True``).
    Otherwise a CPU fallback notice is printed.

    Returns:
        torch.device: ``cuda`` if ``torch.cuda.is_available()``, else ``cpu``.
    """
    separator = "=" * 80
    print(separator)
    print("üéÆ GPU Configuration")
    print(separator)

    if not torch.cuda.is_available():
        selected = torch.device('cpu')
        print("‚ö†Ô∏è  No GPU detected, running on CPU")
    else:
        selected = torch.device('cuda')
        print(f"‚úÖ GPU detected: {torch.cuda.get_device_name(0)}")
        print(f"‚úÖ CUDA Version: {torch.version.cuda}")
        mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"‚úÖ GPU Memory: {mem_gb:.2f} GB")
        torch.backends.cudnn.benchmark = True
        print("‚úÖ cuDNN autotuner enabled")

    print(separator + "\n")
    return selected

# Module-level device chosen once when the script runs; the training stages pass
# it to train_on_file / evaluate_on_files and evaluate_model_detailed reads it
# directly.
device = setup_gpu()

# ==========================================================
# 🧹 MEMORY MANAGEMENT
# ==========================================================

def clear_memory():
    """Run a full garbage-collection pass and drop PyTorch's cached GPU blocks.

    ``torch.cuda.empty_cache()`` is only called when CUDA is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def print_memory_stats():
    """Print one line with process RSS (via psutil) and allocated CUDA memory.

    RAM is reported only when ``psutil`` can be imported and queried; this is a
    best-effort diagnostic, so any failure there just omits the RAM figure. The
    GPU figure is added when CUDA is available. Parts are joined with " | ", so
    no stray separator is printed when one of them is missing.
    """
    parts = []
    try:
        import psutil
        ram_gb = psutil.Process().memory_info().rss / 1e9
        parts.append(f"üíæ RAM Usage: {ram_gb:.2f} GB")
    except Exception:
        # psutil is optional; never let a diagnostics failure abort training.
        # (Exception, not a bare except, so Ctrl-C still interrupts.)
        pass

    if torch.cuda.is_available():
        gpu_gb = torch.cuda.memory_allocated() / 1e9
        parts.append(f"GPU: {gpu_gb:.2f} GB")

    print(" | ".join(parts))

# ==========================================================
# 🧹 HELPER FUNCTIONS
# ==========================================================

def load_and_clean(path, label_col=None):
    """Read a CSV, drop rows with missing values and duplicate rows, split off labels.

    Args:
        path: CSV file to read.
        label_col: name of the label column. When ``None``, ``"Label"`` is used if
            that column exists, otherwise the last column.

    Returns:
        (X, y): the feature DataFrame (all columns except the label column) and
        the label Series, both restricted to the cleaned rows.
    """
    frame = pd.read_csv(path, low_memory=False).dropna().drop_duplicates()

    if label_col is None:
        columns = frame.columns
        label_col = columns[-1] if "Label" not in columns else "Label"

    features = frame.drop(columns=[label_col])
    labels = frame[label_col]

    # Drop the full frame before returning to keep peak memory down.
    del frame
    gc.collect()
    return features, labels

def encode_objects(X):
    """Integer-encode object (string) columns and return a float32 feature matrix.

    Each object column is label-encoded with a fresh ``LabelEncoder``; all other
    columns pass through unchanged. The caller's DataFrame is NOT modified: a
    copy is made only when there is at least one object column to rewrite.

    NOTE(review): encoders are fitted per call, so the same category string can
    map to different integers in different files. That is harmless only while
    the feature columns are all numeric; otherwise persist one encoder per column.

    Args:
        X: pandas DataFrame of features.

    Returns:
        np.ndarray of dtype float32 with the same shape as ``X``.
    """
    object_cols = X.select_dtypes(include=["object"]).columns
    if len(object_cols):
        # Work on a copy so callers never see their frame silently re-typed.
        X = X.copy()
        for col in object_cols:
            try:
                X[col] = LabelEncoder().fit_transform(X[col].astype(str))
            except Exception:
                # Best effort (as before): an unencodable column becomes constant.
                # Exception rather than a bare except so Ctrl-C is not swallowed.
                X[col] = 0
    return X.values.astype(np.float32)

def load_and_process_file(filepath, scaler, pca, label_encoder):
    """Load one CSV and run it through the fitted preprocessing pipeline.

    Steps: ``load_and_clean`` -> ``encode_objects`` -> ``scaler.transform`` ->
    ``pca.transform``; labels go through ``label_encoder.transform``.

    Rows whose label is not in ``label_encoder.classes_`` are dropped (with a
    printed notice). Previously a single unseen label made ``transform`` raise,
    and the whole file was silently discarded.

    Args:
        filepath: path of the CSV file.
        scaler: fitted scaler exposing ``transform``.
        pca: fitted projection exposing ``transform``.
        label_encoder: fitted ``LabelEncoder`` (classes are strings).

    Returns:
        ``(X_reduced, y_encoded)`` arrays, or ``(None, None)`` if the file could
        not be processed or has no usable rows left.
    """
    name = os.path.basename(filepath)
    try:
        X, y = load_and_clean(filepath)
        y = y.astype(str)

        # Keep only rows whose class the encoder knows about.
        known = y.isin(label_encoder.classes_).to_numpy()
        n_unknown = int((~known).sum())
        if n_unknown:
            print(f"  Skipping {n_unknown} rows with labels unseen during fitting in {name}")
            X = X.loc[known]
            y = y.loc[known]

        if len(y) == 0:
            print(f"  No usable rows in {name}")
            return None, None

        X_scaled = scaler.transform(encode_objects(X))
        del X
        X_reduced = pca.transform(X_scaled)
        y_encoded = label_encoder.transform(y)

        del y, X_scaled
        gc.collect()

        return X_reduced, y_encoded
    except Exception as e:
        # Best effort: one unreadable file must not abort the whole run.
        print(f"‚ùå Error processing {name}: {e}")
        return None, None

# ==========================================================
# 📦 FULL FILE DATASET
# ==========================================================

class FullFileDataset(Dataset):
    """Map-style dataset holding one preprocessed file entirely in memory.

    Features are kept as a float32 tensor and labels as an int64 tensor.
    Indexing returns the ``(features, label)`` pair at that row.
    """

    def __init__(self, X, y):
        # as_tensor mirrors the legacy FloatTensor/LongTensor constructors:
        # convert the dtype when needed, otherwise reuse the numpy buffer.
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.int64)

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        features = self.X[idx]
        label = self.y[idx]
        return features, label

# ==========================================================
# 🎓 REDUCED PYTORCH MODELS
# ==========================================================

class TeacherLSTM(nn.Module):
    """Teacher classifier: two stacked LSTMs followed by a three-layer dense head.

    Pipeline: LSTM(hidden_sizes[0]) -> dropout -> LSTM(hidden_sizes[1]) -> dropout
    -> last time step -> Linear(64) -> ReLU -> Linear(32) -> ReLU
    -> Linear(num_classes).

    Args:
        input_size: features per time step.
        hidden_sizes: hidden sizes of the two LSTM layers (only the first two
            entries are used).
        num_classes: number of output logits.
        dropout: dropout probability applied after each LSTM layer.
    """

    def __init__(self, input_size, hidden_sizes, num_classes, dropout=0.3):
        super().__init__()
        first, second = hidden_sizes[0], hidden_sizes[1]

        # Attribute names are the state_dict keys; keep them stable so saved
        # checkpoints remain loadable.
        self.lstm1 = nn.LSTM(input_size, first, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.lstm2 = nn.LSTM(first, second, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)

        self.fc1 = nn.Linear(second, 64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        Args:
            x: (batch, seq_len, input_size) tensor (the LSTMs use batch_first=True).
        """
        seq, _ = self.lstm1(x)
        seq, _ = self.lstm2(self.dropout1(seq))
        seq = self.dropout2(seq)

        final_step = seq[:, -1, :]  # only the last time step feeds the head
        hidden = self.relu1(self.fc1(final_step))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)


class StudentLSTM(nn.Module):
    """Student classifier: one LSTM layer followed by a two-layer dense head.

    Pipeline: LSTM(hidden_size) -> dropout -> last time step -> Linear(32)
    -> ReLU -> Linear(num_classes).

    Args:
        input_size: features per time step.
        hidden_size: LSTM hidden size.
        num_classes: number of output logits.
        dropout: dropout probability applied to the LSTM output.
    """

    def __init__(self, input_size, hidden_size, num_classes, dropout=0.2):
        super().__init__()
        # Attribute names are the state_dict keys; do not rename.
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a (batch, seq_len, input_size) input."""
        seq = self.dropout(self.lstm(x)[0])
        final_step = seq[:, -1, :]
        return self.fc2(self.relu(self.fc1(final_step)))

# ==========================================================
# 🎓 KNOWLEDGE DISTILLATION LOSS
# ==========================================================

class DistillationLoss(nn.Module):
    """Distillation objective mixing soft teacher targets with hard labels.

    loss = alpha * T**2 * KL(p_teacher || p_student) + (1 - alpha) * CE(student, labels)

    Here p_teacher = softmax(teacher_logits / T) and p_student =
    softmax(student_logits / T). KL uses ``reduction='batchmean'``.

    Args:
        temperature: softening temperature T applied to both sets of logits.
        alpha: weight of the soft (teacher) term; ``1 - alpha`` weights the
            cross-entropy term.
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        """Return the scalar loss for one batch.

        Args:
            student_logits: (batch, classes) raw student outputs.
            teacher_logits: teacher outputs of the same shape.
            labels: (batch,) integer class indices.
        """
        t = self.temperature
        student_log_probs = torch.log_softmax(student_logits / t, dim=1)
        teacher_probs = torch.softmax(teacher_logits / t, dim=1)
        # Conventional T**2 factor keeps the soft term's gradient scale
        # comparable across temperatures.
        soft = self.kl_loss(student_log_probs, teacher_probs) * (t ** 2)
        hard = self.ce_loss(student_logits, labels)
        return (1 - self.alpha) * hard + self.alpha * soft

# ==========================================================
# 🏋️ TRAINING FUNCTIONS (FULL FILE AT ONCE)
# ==========================================================

def train_on_file(model, filepath, scaler, pca, label_encoder, optimizer,
                  criterion, device, batch_size=512, is_distillation=False,
                  teacher_model=None):
    """Run one training pass of ``model`` over every row of a single file.

    The file is loaded and preprocessed in full by ``load_and_process_file`` and
    then iterated in shuffled mini-batches. Each batch gets a length-1 sequence
    axis (``unsqueeze(1)``) so it matches the models' ``(batch, seq, features)``
    input. ``model`` is left in train mode; ``teacher_model`` (if given) is put
    in eval mode.

    Args:
        model: network being optimized.
        filepath: CSV file to train on.
        scaler, pca, label_encoder: fitted preprocessing objects.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss. Called as ``criterion(logits, teacher_logits, labels)``
            when ``is_distillation`` is True and ``teacher_model`` is given,
            otherwise as ``criterion(logits, labels)``.
        device: device batches are moved to.
        batch_size: mini-batch size.
        is_distillation: enable the teacher-guided loss call.
        teacher_model: teacher whose logits are computed under ``torch.no_grad()``.

    Returns:
        Mean per-sample loss over the file, or 0 if the file could not be loaded
        or yielded no batches.
    """
    X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)
    if X_file is None:
        return 0

    dataset = FullFileDataset(X_file, y_file)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    model.train()
    if teacher_model is not None:
        teacher_model.eval()
    use_teacher = is_distillation and teacher_model is not None

    total_loss = 0.0
    total_samples = 0

    for X_batch, y_batch in dataloader:
        X_batch = X_batch.unsqueeze(1).to(device)  # add sequence dimension
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)

        if use_teacher:
            with torch.no_grad():
                teacher_outputs = teacher_model(X_batch)
            loss = criterion(outputs, teacher_outputs, y_batch)
        else:
            loss = criterion(outputs, y_batch)

        loss.backward()
        optimizer.step()

        n = y_batch.size(0)
        total_loss += loss.item() * n
        total_samples += n
        # No clear_memory() here: a full gc.collect() + empty_cache() on every
        # mini-batch dominates runtime and makes the CUDA caching allocator
        # re-allocate the same buffers each step. Batch tensors are released
        # when the loop variables are rebound.

    # Release the whole file's data once, before the next file is loaded.
    del X_file, y_file, dataset, dataloader
    clear_memory()

    return total_loss / total_samples if total_samples > 0 else 0


def evaluate_on_files(model, file_list, scaler, pca, label_encoder,
                      criterion, device, batch_size=512):
    """Compute per-sample mean loss and accuracy of ``model`` over several files.

    Each file is loaded and preprocessed with ``load_and_process_file``; files it
    cannot process are skipped. ``model`` is switched to eval mode and left there.

    Args:
        model: classifier returning logits for ``(batch, 1, features)`` input.
        file_list: CSV paths to evaluate.
        scaler, pca, label_encoder: fitted preprocessing objects.
        criterion: loss called as ``criterion(logits, labels)``; assumed to return
            a batch mean (it is re-weighted by batch size below).
        device: device batches are moved to.
        batch_size: inference batch size.

    Returns:
        ``(avg_loss, accuracy)`` over all evaluated samples, or ``(0, 0)`` when
        nothing could be evaluated.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total_samples = 0

    for filepath in file_list:
        X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)
        if X_file is None:
            continue

        dataset = FullFileDataset(X_file, y_file)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        with torch.no_grad():
            for X_batch, y_batch in dataloader:
                X_batch = X_batch.unsqueeze(1).to(device)
                y_batch = y_batch.to(device)

                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)

                n = y_batch.size(0)
                correct += (outputs.argmax(dim=1) == y_batch).sum().item()
                total_loss += loss.item() * n
                total_samples += n

        # Free the file's data once per file; per-batch gc/empty_cache is slow.
        del X_file, y_file, dataset, dataloader
        clear_memory()

    if total_samples == 0:
        return 0, 0
    return total_loss / total_samples, correct / total_samples

# ==========================================================
# 📂 DOWNLOAD & SPLIT DATASET (169 FILES)
# ==========================================================

print("=" * 80)
print("üì• Downloading CIC-IoT-2023 Dataset from Kaggle...")
print("=" * 80)

dataset_dir = kagglehub.dataset_download("akashdogra/cic-iot-2023")
print(f"‚úÖ Dataset downloaded to: {dataset_dir}")

# Sorted so the file order, and therefore the split below, is reproducible.
csv_files = sorted([
    os.path.join(dataset_dir, f)
    for f in os.listdir(dataset_dir)
    if f.endswith(".csv")
])

print(f"üìÇ Found {len(csv_files)} CSV files.")

# 60-20-20 split at FILE level: contiguous slices of the sorted list, so every
# file belongs to exactly one split.
n_files = len(csv_files)
train_idx = int(n_files * 0.60)
val_idx = int(n_files * 0.80)

train_files = csv_files[:train_idx]
val_files = csv_files[train_idx:val_idx]
test_files = csv_files[val_idx:]

print(f"\nüìä Dataset Split (from {n_files} files):")
print(f"   Training:   {len(train_files)} files")
print(f"   Validation: {len(val_files)} files")
print(f"   Testing:    {len(test_files)} files")

# Fail fast: with fewer than 3 files at least one split is empty, and later
# stages would crash far from the cause (IndexError, empty metrics).
if not (train_files and val_files and test_files):
    raise ValueError(
        f"Need at least 3 CSV files in {dataset_dir} for a train/val/test split, "
        f"found {n_files}"
    )

# ==========================================================
# 🏷️ FIT PREPROCESSING (SCAN ALL TRAINING FILES FOR LABELS)
# ==========================================================

print("\n" + "=" * 80)
print("üè∑Ô∏è  Fitting Preprocessing - Scanning ALL Training Files...")
print("=" * 80)

# Scan ALL training files so the LabelEncoder knows every class, including ones
# that are missing from the first few files.
all_labels = set()
sample_data = []        # small float32 feature samples used to fit scaler + PCA
n_sample_files = 10     # number of leading training files that contribute samples
sample_rows = 1000      # rows read from each of those files

print(f"Scanning {len(train_files)} training files for all unique labels...")
for i, filepath in enumerate(train_files):
    try:
        # Parse only the header to locate the label column (same rule as load_and_clean).
        columns = pd.read_csv(filepath, nrows=0).columns
        label_col = "Label" if "Label" in columns else columns[-1]

        # Read ONLY the label column: collecting class names does not need the
        # feature columns, so this avoids parsing every full training file.
        labels = pd.read_csv(filepath, usecols=[label_col], low_memory=False)[label_col]
        unique_labels = labels.dropna().astype(str).unique()
        all_labels.update(unique_labels)
        del labels

        print(f"  File {i+1}/{len(train_files)}: {os.path.basename(filepath)} - Found {len(unique_labels)} unique labels (Total: {len(all_labels)})")

        # Feature samples for scaler/PCA come from the first rows of the first
        # n_sample_files files, encoded exactly like the training data.
        if i < n_sample_files:
            df_sample = pd.read_csv(filepath, nrows=sample_rows, low_memory=False).dropna()
            sample_data.append(encode_objects(df_sample.drop(columns=[label_col])))
            del df_sample

        gc.collect()

    except Exception as e:
        print(f"  ‚ö†Ô∏è  Error reading {os.path.basename(filepath)}: {e}")
        continue

# Without labels or feature samples nothing below can be fitted; fail clearly
# instead of with an IndexError on sample_data[0].
if not all_labels or not sample_data:
    raise RuntimeError("Could not read labels and feature samples from any training file")

# Sorted list gives a stable label -> index mapping across runs.
all_labels = sorted(all_labels)

# Fit label encoder with ALL labels
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)
n_classes = len(label_encoder.classes_)

print(f"\n‚úÖ LabelEncoder fitted with {n_classes} classes")
print(f"   Classes found: {', '.join(label_encoder.classes_[:10])}{'...' if n_classes > 10 else ''}")

# Fit scaler incrementally, one file sample at a time
scaler = StandardScaler()
for data in sample_data:
    scaler.partial_fit(data)

print(f"‚úÖ Scaler fitted on {len(sample_data)} file samples")

# Fit PCA on the standardized samples (raw feature count -> at most 30 components)
n_features = sample_data[0].shape[1]
n_components = min(30, n_features)

pca = IncrementalPCA(n_components=n_components)
for data in sample_data:
    X_scaled = scaler.transform(data)
    pca.partial_fit(X_scaled)

print(f"‚úÖ PCA fitted with {n_components} components (from {n_features} features)")

del all_labels, sample_data
clear_memory()
print_memory_stats()

# ==========================================================
# 🎓 STAGE 1: TRAIN TEACHER MODEL
# ==========================================================

print("\n" + "=" * 80)
print("üéì STAGE 1: Training Teacher Model")
print("=" * 80)

# Initialize teacher model with REDUCED sizes
teacher_model = TeacherLSTM(
    input_size=n_components,
    hidden_sizes=[128, 64],  # Reduced from [256, 128, 64]
    num_classes=n_classes,
    dropout=0.3
).to(device)

teacher_params = sum(p.numel() for p in teacher_model.parameters())
print(f"\nüèóÔ∏è  Teacher Model: {teacher_params:,} parameters")
print(f"   Architecture: Input({n_components}) ‚Üí LSTM(128) ‚Üí LSTM(64) ‚Üí FC(64) ‚Üí FC(32) ‚Üí Output({n_classes})")

# Optimizer and criterion
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
teacher_criterion = nn.CrossEntropyLoss()

# Training settings
epochs_teacher = 3  # number of rotating file windows (see files_per_epoch)
batch_size = 512  # Large batch size allowed
# Each epoch trains on the next window of files_per_epoch training files, so the
# teacher sees only min(epochs_teacher * files_per_epoch, len(train_files))
# distinct files (3 x 20 = 60 here), not the whole training split.
files_per_epoch = 20

best_teacher_acc = 0
patience_counter = 0
patience = 5  # NOTE: >= epochs_teacher, so early stopping cannot trigger here
teacher_checkpoint = 'teacher_model.pth'
teacher_saved = False  # becomes True only when THIS run writes the checkpoint

print("\nüöÄ Training Teacher Model...")
print(f"   Batch Size: {batch_size}")
print(f"   Files per Epoch Cycle: {files_per_epoch}")
print(f"   Total Training Files: {len(train_files)}")
print(f"   Epochs: {epochs_teacher}")

for epoch in range(epochs_teacher):
    print(f"\n{'='*80}")
    print(f"TEACHER EPOCH {epoch+1}/{epochs_teacher}")
    print(f"{'='*80}")

    # Select rotating window of files, wrapping around to the start if needed
    start_idx = (epoch * files_per_epoch) % len(train_files)
    end_idx = min(start_idx + files_per_epoch, len(train_files))
    selected_files = train_files[start_idx:end_idx]

    if len(selected_files) < files_per_epoch and len(train_files) > files_per_epoch:
        remaining = files_per_epoch - len(selected_files)
        selected_files += train_files[:remaining]

    print(f"Training on {len(selected_files)} files (indices {start_idx} to {end_idx})")

    # Train on each file
    epoch_losses = []
    for i, filepath in enumerate(selected_files):
        print(f"\n  üìÇ File {i+1}/{len(selected_files)}: {os.path.basename(filepath)}")

        train_loss = train_on_file(
            teacher_model, filepath, scaler, pca, label_encoder,
            teacher_optimizer, teacher_criterion, device, batch_size=batch_size
        )

        epoch_losses.append(train_loss)
        print(f"     Loss: {train_loss:.4f}")
        print_memory_stats()

    avg_train_loss = np.mean(epoch_losses)

    # Validate on subset of validation files
    print(f"\n  üìä Validating...")
    val_loss, val_acc = evaluate_on_files(
        teacher_model, val_files[:5], scaler, pca, label_encoder,
        teacher_criterion, device, batch_size=batch_size
    )

    print(f"\n  üìà Epoch Summary:")
    print(f"     Avg Train Loss: {avg_train_loss:.4f}")
    print(f"     Val Loss: {val_loss:.4f}")
    print(f"     Val Accuracy: {val_acc:.4f}")

    # Save best model
    if val_acc > best_teacher_acc:
        best_teacher_acc = val_acc
        torch.save(teacher_model.state_dict(), teacher_checkpoint)
        teacher_saved = True
        print(f"  ‚úÖ Best teacher model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    clear_memory()

print("\n‚úÖ Teacher Model Training Complete!")
print(f"   Best Validation Accuracy: {best_teacher_acc:.4f}")

# Reload the best weights only if THIS run saved them. A teacher_model.pth left
# over from an earlier run (possibly with a different class count) must not be
# picked up just because the file exists.
if teacher_saved:
    teacher_model.load_state_dict(torch.load(teacher_checkpoint, map_location=device))
    print("   Loaded best teacher model from disk")
else:
    print("   ‚ö†Ô∏è  No saved model found, using final epoch weights")

# ==========================================================
# 🎒 STAGE 2: KNOWLEDGE DISTILLATION - TRAIN STUDENT
# ==========================================================

print("\n" + "=" * 80)
print("üéí STAGE 2: Knowledge Distillation - Training Student Model")
print("=" * 80)

# Initialize student model with REDUCED size
student_model = StudentLSTM(
    input_size=n_components,
    hidden_size=32,  # Single layer, reduced from [32, 16]
    num_classes=n_classes,
    dropout=0.2
).to(device)

student_params = sum(p.numel() for p in student_model.parameters())
reduction_ratio = teacher_params / student_params

print(f"\nüèóÔ∏è  Student Model: {student_params:,} parameters")
print(f"   Architecture: Input({n_components}) ‚Üí LSTM(32) ‚Üí FC(32) ‚Üí Output({n_classes})")
print(f"\nüìä Model Comparison:")
print(f"   Teacher Parameters: {teacher_params:,}")
print(f"   Student Parameters: {student_params:,}")
print(f"   Size Reduction:     {reduction_ratio:.1f}x smaller")

# Optimizer and distillation loss
student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)
distillation_criterion = DistillationLoss(temperature=4.0, alpha=0.7)
# Validation scores the student against hard labels only.
val_criterion = nn.CrossEntropyLoss()

# Uses the same rotating windows as the teacher: 4 x files_per_epoch files.
epochs_student = 4
best_student_acc = 0
patience_counter = 0
student_checkpoint = 'student_model.pth'
student_saved = False  # becomes True only when THIS run writes the checkpoint

print(f"\nüöÄ Training Student with Knowledge Distillation...")
print(f"   Temperature: {distillation_criterion.temperature}")
print(f"   Alpha (soft target weight): {distillation_criterion.alpha}")
print(f"   Batch Size: {batch_size}")
print(f"   Files per Epoch: {files_per_epoch}")

for epoch in range(epochs_student):
    print(f"\n{'='*80}")
    print(f"STUDENT EPOCH {epoch+1}/{epochs_student}")
    print(f"{'='*80}")

    # Select rotating window of files, wrapping around to the start if needed
    start_idx = (epoch * files_per_epoch) % len(train_files)
    end_idx = min(start_idx + files_per_epoch, len(train_files))
    selected_files = train_files[start_idx:end_idx]

    if len(selected_files) < files_per_epoch and len(train_files) > files_per_epoch:
        remaining = files_per_epoch - len(selected_files)
        selected_files += train_files[:remaining]

    print(f"Training on {len(selected_files)} files (indices {start_idx} to {end_idx})")

    # Train with distillation (teacher logits computed without gradients)
    epoch_losses = []
    for i, filepath in enumerate(selected_files):
        print(f"\n  üìÇ File {i+1}/{len(selected_files)}: {os.path.basename(filepath)}")

        train_loss = train_on_file(
            student_model, filepath, scaler, pca, label_encoder,
            student_optimizer, distillation_criterion, device,
            batch_size=batch_size, is_distillation=True, teacher_model=teacher_model
        )

        epoch_losses.append(train_loss)
        print(f"     Loss: {train_loss:.4f}")
        print_memory_stats()

    avg_train_loss = np.mean(epoch_losses)

    # Validate
    print(f"\n  üìä Validating...")
    val_loss, val_acc = evaluate_on_files(
        student_model, val_files[:5], scaler, pca, label_encoder,
        val_criterion, device, batch_size=batch_size
    )

    print(f"\n  üìà Epoch Summary:")
    print(f"     Avg Train Loss: {avg_train_loss:.4f}")
    print(f"     Val Loss: {val_loss:.4f}")
    print(f"     Val Accuracy: {val_acc:.4f}")

    # Save best model
    if val_acc > best_student_acc:
        best_student_acc = val_acc
        torch.save(student_model.state_dict(), student_checkpoint)
        student_saved = True
        print(f"  ‚úÖ Best student model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    clear_memory()

print("\n‚úÖ Student Model Training Complete!")
print(f"   Best Validation Accuracy: {best_student_acc:.4f}")

# Reload the best weights only if THIS run saved them; never pick up a stale
# student_model.pth from an earlier run.
if student_saved:
    student_model.load_state_dict(torch.load(student_checkpoint, map_location=device))
    print("   Loaded best student model from disk")
else:
    print("   ‚ö†Ô∏è  No saved model found, using final epoch weights")

# ==========================================================
# 📈 STAGE 3: FINAL EVALUATION
# ==========================================================

# Banner for the held-out test-set evaluation stage.
print("\n" + "=" * 80)
print("üìà STAGE 3: Final Evaluation on Test Set")
print("=" * 80)

def evaluate_model_detailed(model, model_name, file_list, batch_size=512):
    """Evaluate ``model`` on every file in ``file_list`` and print weighted metrics.

    Reads the module-level ``scaler``, ``pca``, ``label_encoder`` and ``device``.
    Files that ``load_and_process_file`` cannot process are skipped.

    Args:
        model: classifier returning logits for ``(batch, 1, features)`` input.
        model_name: label used in the printed report.
        file_list: CSV paths to evaluate.
        batch_size: inference batch size (default equals the former hard-coded 512).

    Returns:
        ``(y_true, y_pred, accuracy, precision, recall, f1)``. The first two are
        1-D integer numpy arrays; precision/recall/F1 use ``average='weighted'``.

    Raises:
        RuntimeError: if none of the files could be evaluated.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating {model_name}...")
    print(f"{'='*60}")

    model.eval()
    # Keep one numpy array per batch and concatenate once at the end. Extending
    # a Python list with per-row numpy scalars costs far more RAM over millions
    # of test rows.
    true_chunks = []
    pred_chunks = []

    for i, filepath in enumerate(file_list):
        print(f"Processing file {i+1}/{len(file_list)}: {os.path.basename(filepath)}")

        X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)
        if X_file is None:
            continue

        dataset = FullFileDataset(X_file, y_file)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        with torch.no_grad():
            for X_batch, y_batch in dataloader:
                outputs = model(X_batch.unsqueeze(1).to(device))
                true_chunks.append(y_batch.numpy())
                pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())

        # Free per file rather than per batch (per-batch gc/empty_cache is slow).
        del X_file, y_file, dataset, dataloader
        clear_memory()

    if not true_chunks:
        raise RuntimeError(f"No files could be evaluated for {model_name}")

    y_true_all = np.concatenate(true_chunks)
    y_pred_all = np.concatenate(pred_chunks)

    accuracy = accuracy_score(y_true_all, y_pred_all)
    precision = precision_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    recall = recall_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    f1 = f1_score(y_true_all, y_pred_all, average='weighted', zero_division=0)

    print(f"\nüìä {model_name} Performance:")
    print(f"   Accuracy:  {accuracy:.4f}")
    print(f"   Precision: {precision:.4f}")
    print(f"   Recall:    {recall:.4f}")
    print(f"   F1-Score:  {f1:.4f}")

    return y_true_all, y_pred_all, accuracy, precision, recall, f1

# Evaluate both models on the same held-out test files. Each result is the tuple
# (y_true, y_pred, accuracy, precision, recall, f1) returned by
# evaluate_model_detailed. Each call re-reads and re-preprocesses every test file.
teacher_results = evaluate_model_detailed(teacher_model, "TEACHER MODEL", test_files)
student_results = evaluate_model_detailed(student_model, "STUDENT MODEL (Distilled)", test_files)

# ==========================================================
# 📊 GENERATE REPORTS
# ==========================================================

print("\n" + "=" * 80)
print("üìä Generating Final Report...")
print("=" * 80)

y_true, y_pred, s_acc, s_prec, s_rec, s_f1 = student_results
_, _, t_acc, t_prec, t_rec, t_f1 = teacher_results

# Confusion matrix over ALL encoder classes. Without `labels=`, classes absent
# from both y_true and y_pred are dropped, so the matrix would shrink and no
# longer line up with the label_encoder.classes_ tick labels below.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))

plt.figure(figsize=(20, 16))
sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_,
            cbar_kws={'label': 'Count'})
plt.title('Student Model Confusion Matrix (Knowledge Distillation)', fontsize=16, pad=20)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right', fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.savefig('student_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.close()  # release the large figure; it is only needed on disk
print("‚úÖ Confusion matrix saved as 'student_confusion_matrix.png'")

# Performance comparison: student accuracy as a percentage of teacher accuracy
performance_retention = (s_acc / t_acc) * 100 if t_acc > 0 else 0

print("\n" + "=" * 80)
print("üìä FINAL COMPARISON: TEACHER vs STUDENT")
print("=" * 80)
print(f"\n{'Metric':<15} {'Teacher':<15} {'Student':<15} {'Difference':<15}")
print("=" * 80)
print(f"{'Accuracy':<15} {t_acc:<15.4f} {s_acc:<15.4f} {(s_acc-t_acc):<15.4f}")
print(f"{'Precision':<15} {t_prec:<15.4f} {s_prec:<15.4f} {(s_prec-t_prec):<15.4f}")
print(f"{'Recall':<15} {t_rec:<15.4f} {s_rec:<15.4f} {(s_rec-t_rec):<15.4f}")
print(f"{'F1-Score':<15} {t_f1:<15.4f} {s_f1:<15.4f} {(s_f1-t_f1):<15.4f}")
print(f"{'Parameters':<15} {teacher_params:<15,} {student_params:<15,} {'-':<15}")
print(f"{'Model Size':<15} {'1.0x':<15} {f'{1/reduction_ratio:.2f}x':<15} {f'{reduction_ratio:.1f}x smaller':<15}")
print("=" * 80)

print(f"\nüéØ Performance Retention: {performance_retention:.2f}%")
print(f"üéØ Model Size Reduction: {reduction_ratio:.1f}x smaller")
print(f"üéØ Parameter Reduction: {((teacher_params - student_params) / teacher_params * 100):.1f}% fewer parameters")

# ==========================================================
# 💾 SAVE MODELS AND PREPROCESSING OBJECTS
# ==========================================================

print("\n" + "=" * 80)
print("üíæ Saving Models and Preprocessing Objects")
print("=" * 80)

# Save PyTorch models.
# Each checkpoint bundles the state_dict with the constructor arguments needed to
# rebuild the model. NOTE(review): 'hidden_sizes' / 'hidden_size' repeat the
# literals used when the models were constructed ([128, 64] and 32); keep them
# in sync if the architectures change.
torch.save({
    'model_state_dict': teacher_model.state_dict(),
    'input_size': n_components,
    'hidden_sizes': [128, 64],
    'num_classes': n_classes,
    'accuracy': t_acc,
    'params': teacher_params
}, 'teacher_model_complete.pth')
print("‚úÖ Saved: teacher_model_complete.pth")

torch.save({
    'model_state_dict': student_model.state_dict(),
    'input_size': n_components,
    'hidden_size': 32,
    'num_classes': n_classes,
    'accuracy': s_acc,
    'params': student_params
}, 'student_model_complete.pth')
print("‚úÖ Saved: student_model_complete.pth")

# Save preprocessing objects (fitted scaler, PCA and label encoder) so inference
# can reproduce the training-time feature pipeline.
# NOTE: this is a pickle; unpickling executes code, so only load trusted copies.
preprocessing_objects = {
    'scaler': scaler,
    'pca': pca,
    'label_encoder': label_encoder
}

with open('preprocessing.pkl', 'wb') as f:
    pickle.dump(preprocessing_objects, f)
print("‚úÖ Saved: preprocessing.pkl")

# Save metadata. Values are cast to builtin int/float so json can serialize them;
# n_features is the raw feature count before PCA, n_components the PCA output size.
metadata = {
    'n_classes': int(n_classes),
    'n_features': int(n_features),
    'n_components': int(n_components),
    'teacher_params': int(teacher_params),
    'student_params': int(student_params),
    'teacher_accuracy': float(t_acc),
    'student_accuracy': float(s_acc),
    'size_reduction': float(reduction_ratio),
    'performance_retention': float(performance_retention),
    'total_files': len(csv_files),
    'train_files': len(train_files),
    'val_files': len(val_files),
    'test_files': len(test_files),
    'classes': label_encoder.classes_.tolist()
}

with open('model_metadata.json', 'w') as f:
    json.dump(metadata, f, indent=4)
print("‚úÖ Saved: model_metadata.json")

# Create a human-readable summary of the run: dataset split sizes, per-model
# test metrics, compression figures and the list of generated artifacts.
# NOTE(review): the "Architecture" lines are hard-coded strings; update them if
# the TeacherLSTM / StudentLSTM constructor arguments change.
with open('model_summary.txt', 'w') as f:
    f.write("=" * 80 + "\n")
    f.write("KNOWLEDGE DISTILLATION - PYTORCH MODEL SUMMARY\n")
    f.write("=" * 80 + "\n\n")

    f.write("DATASET INFORMATION:\n")
    f.write(f"  Total Files: {len(csv_files)}\n")
    f.write(f"  Training Files: {len(train_files)}\n")
    f.write(f"  Validation Files: {len(val_files)}\n")
    f.write(f"  Test Files: {len(test_files)}\n\n")

    f.write("TEACHER MODEL:\n")
    f.write(f"  Architecture: LSTM [128, 64]\n")
    f.write(f"  Parameters: {teacher_params:,}\n")
    f.write(f"  Accuracy: {t_acc:.4f}\n")
    f.write(f"  Precision: {t_prec:.4f}\n")
    f.write(f"  Recall: {t_rec:.4f}\n")
    f.write(f"  F1-Score: {t_f1:.4f}\n\n")

    f.write("STUDENT MODEL (DISTILLED):\n")
    f.write(f"  Architecture: LSTM [32]\n")
    f.write(f"  Parameters: {student_params:,}\n")
    f.write(f"  Accuracy: {s_acc:.4f}\n")
    f.write(f"  Precision: {s_prec:.4f}\n")
    f.write(f"  Recall: {s_rec:.4f}\n")
    f.write(f"  F1-Score: {s_f1:.4f}\n\n")

    f.write("COMPRESSION METRICS:\n")
    f.write(f"  Size Reduction: {reduction_ratio:.1f}x smaller\n")
    f.write(f"  Performance Retention: {performance_retention:.2f}%\n")
    f.write(f"  Parameter Reduction: {((teacher_params - student_params) / teacher_params * 100):.1f}%\n\n")

    f.write("FILES GENERATED:\n")
    f.write("  - teacher_model_complete.pth (Teacher model with metadata)\n")
    f.write("  - student_model_complete.pth (Student model with metadata)\n")
    f.write("  - preprocessing.pkl (Scaler, PCA, Label Encoder)\n")
    f.write("  - model_metadata.json (Model specifications)\n")
    f.write("  - student_confusion_matrix.png (Confusion matrix visualization)\n")

print("‚úÖ Saved: model_summary.txt")

# Final console banner.
# NOTE(review): "processed all N files" overstates coverage. Each training stage
# visits only epochs * files_per_epoch training files, and per-epoch validation
# uses val_files[:5].
print("\n" + "=" * 80)
print("üéâ KNOWLEDGE DISTILLATION COMPLETE!")
print("=" * 80)
print(f"\n‚ú® Successfully processed all {len(csv_files)} files!")
print(f"‚ú® Teacher Model: {teacher_params:,} parameters ‚Üí Accuracy: {t_acc:.4f}")
print(f"‚ú® Student Model: {student_params:,} parameters ‚Üí Accuracy: {s_acc:.4f}")
print(f"‚ú® Compression: {reduction_ratio:.1f}x smaller with {performance_retention:.1f}% performance retention")
print("\nüì¶ All models saved and ready for deployment!")
print("=" * 80)

🎮 GPU Configuration
✅ GPU detected: Tesla T4
✅ CUDA Version: 12.6
✅ GPU Memory: 15.83 GB
✅ cuDNN autotuner enabled

📥 Downloading CIC-IoT-2023 Dataset from Kaggle...
Downloading from https://www.kaggle.com/api/v1/datasets/download/akashdogra/cic-iot-2023?dataset_version_number=1...


100%|██████████| 2.77G/2.77G [00:38<00:00, 77.7MB/s]

Extracting files...





✅ Dataset downloaded to: /root/.cache/kagglehub/datasets/akashdogra/cic-iot-2023/versions/1
📂 Found 169 CSV files.

📊 Dataset Split (from 169 files):
   Training:   101 files
   Validation: 34 files
   Testing:    34 files

🏷️  Fitting Preprocessing - Scanning ALL Training Files...
Scanning 101 training files for all unique labels...
  File 1/101: part-00000-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 2/101: part-00001-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 3/101: part-00002-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 4/101: part-00003-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 5/101: part-00004-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 6/101: part-00005-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  F