This implementation includes all six improvements from the MAML++ paper:
1. Multi-Step Loss Optimization (MSL)
2. Derivative-Order Annealing (DA)
3. Per-Step Batch Normalization Running Statistics (BNRS)
4. Per-Step Batch Normalization Weights and Biases (BNWB)
5. Learning Per-Layer Per-Step Learning Rates (LSLR)
6. Cosine Annealing of Meta-Optimizer (CA)

In [70]:
# Import modules
import glob, random
from collections import OrderedDict
import os
import pandas as pd

import numpy as np
from tqdm.auto import tqdm
import json
import math

import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from PIL import Image
from IPython.display import display

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEVICE = {device}")

# Fix random seeds
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

DEVICE = cpu


In [71]:
# Hyperparameters for MAML++
n_way = 3
k_shot = 5
q_query = 5
input_dim = 1280
train_inner_train_step = 5
val_inner_train_step = 5
inner_lr = 0.01  # Will be learned per-layer per-step in MAML++
meta_lr = 0.001
meta_batch_size = 16
max_epoch = 30
eval_batches = 20

# MAML++ specific hyperparameters
use_first_order_epochs = max_epoch // 2  # First 15 epochs use first-order gradients (DA)
step_weights_initial = [1.0] * (train_inner_train_step + 1)  # For Multi-Step Loss Optimization (MSL)

In [72]:
# MAML++ IMPROVEMENT 1: Enhanced Batch Normalization 
class PerStepBatchNorm1d(nn.Module):
    """MAML++ Improvement: Per-Step Batch Normalization (BNRS + BNWB)
    - BNRS: Per-step running statistics
    - BNWB: Per-step weights and biases
    """
    
    def __init__(self, num_features, num_steps, momentum=0.1, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.num_steps = num_steps
        self.momentum = momentum
        self.eps = eps
        
        # BNWB: Per-step weights and biases (including step 0)
        self.weight = nn.Parameter(torch.ones(num_steps + 1, num_features))
        self.bias = nn.Parameter(torch.zeros(num_steps + 1, num_features))
        
        # BNRS: Per-step running statistics (including step 0)
        self.register_buffer('running_mean', torch.zeros(num_steps + 1, num_features))
        self.register_buffer('running_var', torch.ones(num_steps + 1, num_features))
        
    def forward(self, x, step=0):
        step = min(step, self.num_steps)
            
        if self.training:
            # Use batch statistics during training
            batch_mean = x.mean(dim=0, keepdim=False)
            batch_var = x.var(dim=0, unbiased=False, keepdim=False)
            
            # Update running statistics for this step
            with torch.no_grad():
                self.running_mean[step] = (1 - self.momentum) * self.running_mean[step] + self.momentum * batch_mean
                self.running_var[step] = (1 - self.momentum) * self.running_var[step] + self.momentum * batch_var
            
            mean = batch_mean
            var = batch_var
        else:
            # Use running statistics during evaluation
            mean = self.running_mean[step]
            var = self.running_var[step]
        
        # Normalize and apply per-step weights and biases
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[step] * x_norm + self.bias[step]

In [73]:
# MAML++ IMPROVEMENT 2: Enhanced Classifier with Per-Layer Learning Rates
class MalwarePlusPlusClassifier(nn.Module):
    """MAML++ Enhanced Classifier with multiple improvements"""
    
    def __init__(self, input_dim, hidden_dim=256, output_dim=3, num_inner_steps=5):
        super(MalwarePlusPlusClassifier, self).__init__()
        self.num_inner_steps = num_inner_steps
        
        # Original network layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim//2)
        self.fc4 = nn.Linear(hidden_dim//2, output_dim)
        
        # MAML++ IMPROVEMENT: Per-step batch normalization (BNRS + BNWB)
        self.bn1 = PerStepBatchNorm1d(hidden_dim, num_inner_steps)
        self.bn2 = PerStepBatchNorm1d(hidden_dim, num_inner_steps)
        self.bn3 = PerStepBatchNorm1d(hidden_dim//2, num_inner_steps)
        
        # MAML++ IMPROVEMENT: Learnable per-layer per-step learning rates (LSLR)
        # Each layer has its own learning rate for each step
        self.layer_lrs = nn.ParameterDict({
            # FC layers use relatively larger learning rates
            'fc1': nn.Parameter(torch.tensor([0.01, 0.008, 0.006, 0.004, 0.002][:num_inner_steps])),
            'fc2': nn.Parameter(torch.tensor([0.005, 0.004, 0.003, 0.002, 0.001][:num_inner_steps])),
            'fc3': nn.Parameter(torch.tensor([0.003, 0.002, 0.002, 0.001, 0.001][:num_inner_steps])),
            'fc4': nn.Parameter(torch.tensor([0.008, 0.006, 0.004, 0.003, 0.002][:num_inner_steps])),  
            # BN layers use smaller learning rates
            'bn1': nn.Parameter(torch.tensor([0.001, 0.0008, 0.0006, 0.0004, 0.0002][:num_inner_steps])),
            'bn2': nn.Parameter(torch.tensor([0.001, 0.0008, 0.0006, 0.0004, 0.0002][:num_inner_steps])),
            'bn3': nn.Parameter(torch.tensor([0.001, 0.0008, 0.0006, 0.0004, 0.0002][:num_inner_steps]))
        })
        
        # Dropout for regularization
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x, step=0):
        x = self.fc1(x)
        x = self.bn1(x, step)  # Use step-specific BN
        x = F.relu(x)
        x = self.dropout(x)
        
        x = self.fc2(x)
        x = self.bn2(x, step)  # Use step-specific BN
        x = F.relu(x)
        x = self.dropout(x)
        
        x = self.fc3(x)
        x = self.bn3(x, step)  # Use step-specific BN
        x = F.relu(x)
        
        x = self.fc4(x)
        return x
    
    def functional_forward(self, x, params, step=0):
        """Forward pass using custom parameters with step-aware BN"""
        # FC1
        x = F.linear(x, params.get('fc1.weight', self.fc1.weight), 
                    params.get('fc1.bias', self.fc1.bias))
        
        # BN1 - Use step-specific parameters
        step_idx = min(step, self.num_inner_steps)
        if f'bn1.weight' in params:
            bn1_weight = params[f'bn1.weight'][step_idx]
            bn1_bias = params[f'bn1.bias'][step_idx]
        else:
            bn1_weight = self.bn1.weight[step_idx]
            bn1_bias = self.bn1.bias[step_idx]
            
        bn1_running_mean = self.bn1.running_mean[step_idx]
        bn1_running_var = self.bn1.running_var[step_idx]
        
        x = F.batch_norm(x, bn1_running_mean, bn1_running_var, 
                        bn1_weight, bn1_bias, training=self.training, eps=1e-5)
        x = F.relu(x)
        x = F.dropout(x, training=self.training, p=0.3)
        
        # FC2
        x = F.linear(x, params.get('fc2.weight', self.fc2.weight), 
                    params.get('fc2.bias', self.fc2.bias))
        
        # BN2
        if f'bn2.weight' in params:
            bn2_weight = params[f'bn2.weight'][step_idx]
            bn2_bias = params[f'bn2.bias'][step_idx]
        else:
            bn2_weight = self.bn2.weight[step_idx]
            bn2_bias = self.bn2.bias[step_idx]
            
        bn2_running_mean = self.bn2.running_mean[step_idx]
        bn2_running_var = self.bn2.running_var[step_idx]
        
        x = F.batch_norm(x, bn2_running_mean, bn2_running_var,
                        bn2_weight, bn2_bias, training=self.training, eps=1e-5)
        x = F.relu(x)
        x = F.dropout(x, training=self.training, p=0.3)
        
        # FC3
        x = F.linear(x, params.get('fc3.weight', self.fc3.weight), 
                    params.get('fc3.bias', self.fc3.bias))
        
        # BN3
        if f'bn3.weight' in params:
            bn3_weight = params[f'bn3.weight'][step_idx]
            bn3_bias = params[f'bn3.bias'][step_idx]
        else:
            bn3_weight = self.bn3.weight[step_idx]
            bn3_bias = self.bn3.bias[step_idx]
            
        bn3_running_mean = self.bn3.running_mean[step_idx]
        bn3_running_var = self.bn3.running_var[step_idx]
        
        x = F.batch_norm(x, bn3_running_mean, bn3_running_var,
                        bn3_weight, bn3_bias, training=self.training, eps=1e-5)
        x = F.relu(x)
        
        # FC4
        x = F.linear(x, params.get('fc4.weight', self.fc4.weight), 
                    params.get('fc4.bias', self.fc4.bias))
        return x

In [74]:
# Utility functions
def create_malware_label(k_shot, q_query):
    """Create labels for calculating accuracy in test phase."""
    n_way = 3  # 2 abnormal + 1 normal
    labels = []
    for class_idx in range(n_way):
        class_labels = [class_idx] * (k_shot + q_query)
        labels.extend(class_labels)
    return torch.tensor(labels, dtype=torch.long)

def create_label(n_way, k_shot):
    """Create labels for support set and query set."""
    return torch.arange(n_way).repeat_interleave(k_shot).long()

def calculate_accuracy(logits, labels):
    """utility function for accuracy calculation"""
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc

In [75]:
# MAML++ IMPROVEMENT 3: Cosine Annealing Scheduler
class CosineAnnealingLR:
    """MAML++ Improvement: Cosine Annealing for Meta-Optimizer (CA)"""
    
    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.optimizer = optimizer
        self.T_max = T_max
        self.eta_min = eta_min
        self.last_epoch = last_epoch
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        
    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        
        for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
            param_group['lr'] = self.eta_min + (base_lr - self.eta_min) * \
                               (1 + math.cos(math.pi * epoch / self.T_max)) / 2
                               
    def get_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]

In [76]:
# Dataset class (using simplified version for demonstration)
class MalwareDetection(Dataset):
    def __init__(self, data_structure_file, split='train', k_shot=1, q_query=5):
        with open(data_structure_file, 'r') as f:
            self.data_structure = json.load(f)
        
        self.split = split
        self.classes = list(self.data_structure[split].keys())
        self.k_shot = k_shot
        self.q_query = q_query
        self.normal_class = 'benign'
        
        self._validate_data()
    
    def _validate_data(self):
        min_samples = self.k_shot + self.q_query
        for cls, files in self.data_structure[self.split].items():
            if len(files) < min_samples:
                print(f"Warning: only {len(files)} samples in class '{cls}' for split '{self.split}'. Required: {min_samples}. Will sample with replacement.")

    def __getitem__(self, idx):
        np.random.seed(42 + idx)

        fraud_classes = [cls for cls in self.classes if cls != self.normal_class]
        
        if len(fraud_classes) >= 2:
            selected_frauds = np.random.choice(fraud_classes, 2, replace=False)
            task_classes = list(selected_frauds) + [self.normal_class]
        elif len(fraud_classes) == 1:
            if self.split == 'test':
                task_classes = fraud_classes + [self.normal_class]
            else:
                task_classes = fraud_classes + fraud_classes + [self.normal_class]
        else:
            raise ValueError(f"No fraud classes available in {self.split} split")
        
        task_data = []
        for cls in task_classes:
            class_files = self.data_structure[self.split][cls]
            
            if len(class_files) >= self.k_shot + self.q_query:
                selected_files = np.random.choice(class_files, 
                                                self.k_shot + self.q_query, 
                                                replace=False)
            else:
                selected_files = np.random.choice(class_files, 
                                                self.k_shot + self.q_query, 
                                                replace=True)
            
            class_features = []
            for file_path in selected_files:
                corrected_path = self._fix_file_path(file_path)
                
                try:
                    features = np.load(corrected_path)
                    if features.ndim > 1:
                        features = features.flatten()
                    class_features.append(features)
                except Exception as e:
                    if idx == 0:
                        print(f"Error loading {corrected_path}: {e}")
                    class_features.append(np.zeros(1280))
            
            task_data.append(torch.tensor(np.array(class_features), dtype=torch.float32))
        
        return torch.stack(task_data)
    
    def _fix_file_path(self, original_path):
        if os.path.exists(original_path):
            return original_path
        
        possible_prefixes = ['../', '../../', './']
        for prefix in possible_prefixes:
            new_path = os.path.join(prefix, original_path)
            if os.path.exists(new_path):
                return os.path.abspath(new_path)
        
        return original_path
    
    def __len__(self):
        fraud_classes = [cls for cls in self.classes if cls != self.normal_class]
        if len(fraud_classes) >= 2:
            from math import comb
            return comb(len(fraud_classes), 2) * 100
        else:
            return 100

In [77]:
def get_meta_batch(meta_batch_size, k_shot, q_query, data_loader, iterator):
    """Get meta batch function"""
    data = []
    for _ in range(meta_batch_size):
        try:
            task_data = next(iterator)
        except StopIteration:
            iterator = iter(data_loader)
            task_data = next(iterator)
        
        task_data = task_data.squeeze(0)
        task_data = task_data.view(-1, task_data.size(-1))
        data.append(task_data)
    
    return torch.stack(data).to(device), iterator

In [78]:
# MAML++ MAIN ALGORITHM WITH ALL IMPROVEMENTS
def MAMLPlusPlusSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step,
    train,
    epoch=0,
    step_weights=None,
    use_first_order_epochs=15,
    return_labels=False,
):
    """MAML++ Algorithm with all six improvements"""
    criterion = loss_fn
    task_loss = []
    task_acc = []
    labels = []
    
    # MAML++ IMPROVEMENT 4: Derivative-Order Annealing (DA)
    use_second_order = epoch >= use_first_order_epochs
    
    if step_weights is None:
        step_weights = [1.0] * (inner_train_step + 1)
    
    for meta_batch in x:
        # Split support and query sets
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        # Copy the params for inner loop
        fast_weights = OrderedDict(model.named_parameters())
        
        # MAML++ IMPROVEMENT 5: Multi-Step Loss Optimization (MSL)
        # Store losses from all steps for multi-step optimization
        step_losses = []
        
        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):
            train_label = create_label(n_way, k_shot).to(device)
            
            # Forward pass with step-aware batch normalization
            logits = model.functional_forward(support_set, fast_weights, step=inner_step)
            loss = criterion(logits, train_label)
            
            # Calculate gradients with appropriate order
            if use_second_order and train:
                grads = torch.autograd.grad(loss, fast_weights.values(), 
                                          create_graph=True, allow_unused=True, retain_graph=True)
            else:
                grads = torch.autograd.grad(loss, fast_weights.values(), 
                                          create_graph=False, allow_unused=True, retain_graph=True)

            # MAML++ IMPROVEMENT 6: Per-Layer Per-Step Learning Rates (LSLR)
            # Update fast_weights using learned learning rates
            updated_params = OrderedDict()
            
            for (name, param), grad in zip(fast_weights.items(), grads):
                if grad is None:
                    # Handle unused parameters
                    updated_params[name] = param
                    continue
                    
                # Determine which learning rate to use
                layer_name = name.split('.')[0]  # e.g., 'fc1', 'bn1'
                
                if layer_name in model.layer_lrs and inner_step < len(model.layer_lrs[layer_name]):
                    lr = torch.abs(model.layer_lrs[layer_name][inner_step])  # Ensure positive LR
                    # Added clamping to keep learning rates in a reasonable range
                    lr = torch.clamp(lr, min=1e-6, max=0.1)
                else:
                    lr = 0.001  # Fallback learning rate

                # Debug info: print learning rate every 50 steps
                if inner_step == 0 and hasattr(MAMLPlusPlusSolver, '_debug_counter'):
                    MAMLPlusPlusSolver._debug_counter = getattr(MAMLPlusPlusSolver, '_debug_counter', 0) + 1
                    if MAMLPlusPlusSolver._debug_counter % 50 == 0:
                        print(f"Layer {layer_name} step {inner_step} LR: {lr:.6f}")
                
                updated_params[name] = param - lr * grad
            
            fast_weights = updated_params
            
            # MSL: Compute query loss for this step
            if not return_labels:
                val_label = create_label(n_way, q_query).to(device)
                query_logits = model.functional_forward(query_set, fast_weights, step=inner_step)
                query_loss = criterion(query_logits, val_label)
                step_losses.append(query_loss * step_weights[inner_step])

        ### ---------- FINAL STEP EVALUATION ---------- ###
        # MSL: Compute query loss for this step
        if not return_labels:
            # Evaluate final step
            val_label = create_label(n_way, q_query).to(device)
            final_logits = model.functional_forward(query_set, fast_weights, step=inner_train_step)
            final_loss = criterion(final_logits, val_label)
            step_losses.append(final_loss * step_weights[inner_train_step])
            
            # MSL: Combine all step losses
            total_loss = sum(step_losses) / sum(step_weights)
            task_loss.append(total_loss)
            task_acc.append(calculate_accuracy(final_logits, val_label))  # 現在用 final_logits
        else:
            # Testing mode
            logits = model.functional_forward(query_set, fast_weights, step=inner_train_step)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    # Update outer loop
    model.train()
    optimizer.zero_grad()

    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        meta_batch_loss.backward()
        optimizer.step()

    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc

In [79]:
# ADAPTIVE STEP WEIGHTS FOR MSL
def get_step_weights(epoch, max_epochs, num_steps):
    """MAML++ MSL: Adaptive step weights with annealing
    Early epochs: equal weights for all steps
    Later epochs: higher weight for final steps
    """
    progress = epoch / max_epochs
    weights = []
    
    for i in range(num_steps + 1):  # +1 for final evaluation step
        if i == num_steps:  # Final step - stronger weight increase
            # From 1.0 to 5.0
            weight = 1.0 + progress * 4.0  
        else:
            # More aggressive early step weight decay
            decay_factor = (num_steps - i) / num_steps
            weight = 1.0 - progress * 0.8 * decay_factor  # Originally 0.5, changed to 0.8

        # Ensure weights don't get too small
        weights.append(max(weight, 0.1))
    
    return weights

In [80]:
# SETUP TRAINING WITH MAML++ IMPROVEMENTS

# Check if data file exists
data_file = '../malware_data_structure.json'
if os.path.exists(data_file):
    # Prepare datasets and dataloaders
    train_dataset = MalwareDetection(data_file, 'train', k_shot, q_query)
    val_dataset = MalwareDetection(data_file, 'val', k_shot, q_query)

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    
    # Create MAML++ model with enhanced features
    meta_model = MalwarePlusPlusClassifier(
        input_dim=input_dim, 
        num_inner_steps=train_inner_train_step
    ).to(device)

    # MAML++ IMPROVEMENT: Meta-optimizer with cosine annealing (CA)
    optimizer = torch.optim.Adam(meta_model.parameters(), lr=meta_lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=max_epoch, eta_min=meta_lr * 0.01)

    loss_fn = nn.CrossEntropyLoss()

    print(f"MAML++ Model parameters: {sum(p.numel() for p in meta_model.parameters())}")
    print(f"Using first-order gradients for first {use_first_order_epochs} epochs")
    print(f"Meta-learning rate will be annealed from {meta_lr} to {meta_lr * 0.01}")
else:
    print(f"Data file {data_file} not found. Please make sure the data is available.")

MAML++ Model parameters: 434726
Using first-order gradients for first 15 epochs
Meta-learning rate will be annealed from 0.001 to 1e-05


In [81]:
# MAML++ TRAINING LOOP (only if data is available)
if 'meta_model' in locals():
    train_iter = iter(train_loader)
    val_iter = iter(val_loader)

    print("Starting MAML++ training")
    print("Improvements included:")
    print("1. Multi-Step Loss Optimization (MSL) - ✓")
    print("2. Derivative-Order Annealing (DA) - ✓")
    print("3. Per-Step Batch Normalization (BNRS + BNWB) - ✓")
    print("4. Per-Layer Per-Step Learning Rates (LSLR) - ✓")
    print("5. Cosine Annealing Meta-Optimizer (CA) - ✓")
    print("-" * 60)

    for epoch in range(max_epoch):
        print(f"Epoch {epoch+1}/{max_epoch}")
        
        # Get adaptive step weights for MSL
        step_weights = get_step_weights(epoch, max_epoch, train_inner_train_step)
        
        # Show current learning rate (CA)
        current_lr = scheduler.get_lr()[0]
        print(f"Meta-LR: {current_lr:.6f}, Using {'2nd' if epoch >= use_first_order_epochs else '1st'}-order gradients")
        
        # Training
        train_meta_loss = []
        train_acc = []
        
        for train_step in tqdm(range(len(train_loader) // meta_batch_size), desc="Training"):
            x, train_iter = get_meta_batch(
                meta_batch_size, k_shot, q_query, train_loader, train_iter
            )
            
            meta_loss, acc = MAMLPlusPlusSolver(
                meta_model,
                optimizer,
                x,
                n_way,
                k_shot,
                q_query,
                loss_fn,
                inner_train_step=train_inner_train_step,
                train=True,
                epoch=epoch,
                step_weights=step_weights,
                use_first_order_epochs=use_first_order_epochs
            )
            
            train_meta_loss.append(meta_loss.item())
            train_acc.append(acc)
        
        print(f"Loss: {np.mean(train_meta_loss):.3f}\tAccuracy: {np.mean(train_acc)*100:.3f}%")
        
        # Validation
        val_acc = []
        for eval_step in tqdm(range(min(eval_batches, len(val_loader) // meta_batch_size)), desc="Validation"):
            x, val_iter = get_meta_batch(
                meta_batch_size, k_shot, q_query, val_loader, val_iter
            )
            
            _, acc = MAMLPlusPlusSolver(
                meta_model,
                optimizer,
                x,
                n_way,
                k_shot,
                q_query,
                loss_fn,
                inner_train_step=val_inner_train_step,
                train=False,
                epoch=epoch,
                step_weights=step_weights,
                use_first_order_epochs=use_first_order_epochs
            )
            val_acc.append(acc)
        
        print(f"Validation accuracy: {np.mean(val_acc)*100:.3f}%")
        
        # Update learning rate scheduler (CA)
        scheduler.step()
        
        print("-" * 50)

    print("MAML++ Training Complete!")
else:
    print("Training skipped due to missing data file.")

Starting MAML++ training
Improvements included:
1. Multi-Step Loss Optimization (MSL) - ✓
2. Derivative-Order Annealing (DA) - ✓
3. Per-Step Batch Normalization (BNRS + BNWB) - ✓
4. Per-Layer Per-Step Learning Rates (LSLR) - ✓
5. Cosine Annealing Meta-Optimizer (CA) - ✓
------------------------------------------------------------
Epoch 1/30
Meta-LR: 0.001000, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:19<00:00,  3.24it/s]


Loss: 0.842	Accuracy: 54.382%


Validation: 100%|██████████| 6/6 [00:01<00:00,  4.92it/s]


Validation accuracy: 51.111%
--------------------------------------------------
Epoch 2/30
Meta-LR: 0.001000, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  3.92it/s]


Loss: 0.747	Accuracy: 58.589%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.54it/s]


Validation accuracy: 53.472%
--------------------------------------------------
Epoch 3/30
Meta-LR: 0.000997, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  3.90it/s]


Loss: 0.716	Accuracy: 59.899%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.86it/s]


Validation accuracy: 49.861%
--------------------------------------------------
Epoch 4/30
Meta-LR: 0.000989, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  4.08it/s]


Loss: 0.684	Accuracy: 61.546%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.29it/s]


Validation accuracy: 50.486%
--------------------------------------------------
Epoch 5/30
Meta-LR: 0.000976, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  4.01it/s]


Loss: 0.666	Accuracy: 62.796%


Validation: 100%|██████████| 6/6 [00:00<00:00,  6.56it/s]


Validation accuracy: 52.917%
--------------------------------------------------
Epoch 6/30
Meta-LR: 0.000957, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  4.05it/s]


Loss: 0.647	Accuracy: 63.723%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.76it/s]


Validation accuracy: 53.958%
--------------------------------------------------
Epoch 7/30
Meta-LR: 0.000934, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  3.97it/s]


Loss: 0.632	Accuracy: 65.255%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.43it/s]


Validation accuracy: 51.944%
--------------------------------------------------
Epoch 8/30
Meta-LR: 0.000905, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  3.93it/s]


Loss: 0.607	Accuracy: 66.505%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.61it/s]


Validation accuracy: 50.139%
--------------------------------------------------
Epoch 9/30
Meta-LR: 0.000873, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:14<00:00,  4.15it/s]


Loss: 0.586	Accuracy: 67.440%


Validation: 100%|██████████| 6/6 [00:00<00:00,  6.50it/s]


Validation accuracy: 53.750%
--------------------------------------------------
Epoch 10/30
Meta-LR: 0.000836, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  4.00it/s]


Loss: 0.571	Accuracy: 68.441%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.50it/s]


Validation accuracy: 51.111%
--------------------------------------------------
Epoch 11/30
Meta-LR: 0.000796, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:14<00:00,  4.14it/s]


Loss: 0.550	Accuracy: 69.483%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.79it/s]


Validation accuracy: 51.250%
--------------------------------------------------
Epoch 12/30
Meta-LR: 0.000753, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  4.02it/s]


Loss: 0.542	Accuracy: 70.249%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.78it/s]


Validation accuracy: 52.569%
--------------------------------------------------
Epoch 13/30
Meta-LR: 0.000706, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  4.04it/s]


Loss: 0.517	Accuracy: 71.801%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.81it/s]


Validation accuracy: 50.486%
--------------------------------------------------
Epoch 14/30
Meta-LR: 0.000658, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:15<00:00,  4.02it/s]


Loss: 0.496	Accuracy: 73.353%


Validation: 100%|██████████| 6/6 [00:00<00:00,  7.07it/s]


Validation accuracy: 51.528%
--------------------------------------------------
Epoch 15/30
Meta-LR: 0.000608, Using 1st-order gradients


Training: 100%|██████████| 62/62 [00:14<00:00,  4.15it/s]


Loss: 0.487	Accuracy: 74.160%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.36it/s]


Validation accuracy: 50.000%
--------------------------------------------------
Epoch 16/30
Meta-LR: 0.000557, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:19<00:00,  3.22it/s]


Loss: 0.462	Accuracy: 76.257%


Validation: 100%|██████████| 6/6 [00:00<00:00,  6.45it/s]


Validation accuracy: 53.472%
--------------------------------------------------
Epoch 17/30
Meta-LR: 0.000505, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:19<00:00,  3.11it/s]


Loss: 0.441	Accuracy: 77.466%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.78it/s]


Validation accuracy: 52.708%
--------------------------------------------------
Epoch 18/30
Meta-LR: 0.000453, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:19<00:00,  3.10it/s]


Loss: 0.426	Accuracy: 78.300%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.43it/s]


Validation accuracy: 52.222%
--------------------------------------------------
Epoch 19/30
Meta-LR: 0.000402, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


Loss: 0.404	Accuracy: 79.966%


Validation: 100%|██████████| 6/6 [00:00<00:00,  6.41it/s]


Validation accuracy: 53.472%
--------------------------------------------------
Epoch 20/30
Meta-LR: 0.000352, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:19<00:00,  3.16it/s]


Loss: 0.388	Accuracy: 80.780%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.73it/s]


Validation accuracy: 53.264%
--------------------------------------------------
Epoch 21/30
Meta-LR: 0.000304, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:20<00:00,  3.02it/s]


Loss: 0.366	Accuracy: 82.392%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.35it/s]


Validation accuracy: 52.292%
--------------------------------------------------
Epoch 22/30
Meta-LR: 0.000258, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:19<00:00,  3.10it/s]


Loss: 0.350	Accuracy: 83.401%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.73it/s]


Validation accuracy: 51.736%
--------------------------------------------------
Epoch 23/30
Meta-LR: 0.000214, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:20<00:00,  3.09it/s]


Loss: 0.336	Accuracy: 84.066%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.69it/s]


Validation accuracy: 52.083%
--------------------------------------------------
Epoch 24/30
Meta-LR: 0.000174, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:20<00:00,  3.02it/s]


Loss: 0.318	Accuracy: 85.363%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.41it/s]


Validation accuracy: 51.528%
--------------------------------------------------
Epoch 25/30
Meta-LR: 0.000137, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:20<00:00,  3.09it/s]


Loss: 0.305	Accuracy: 86.183%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.63it/s]


Validation accuracy: 51.111%
--------------------------------------------------
Epoch 26/30
Meta-LR: 0.000105, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:20<00:00,  3.10it/s]


Loss: 0.295	Accuracy: 86.835%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.69it/s]


Validation accuracy: 52.222%
--------------------------------------------------
Epoch 27/30
Meta-LR: 0.000076, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:19<00:00,  3.13it/s]


Loss: 0.278	Accuracy: 87.641%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.28it/s]


Validation accuracy: 52.708%
--------------------------------------------------
Epoch 28/30
Meta-LR: 0.000053, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:20<00:00,  3.01it/s]


Loss: 0.274	Accuracy: 87.675%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.64it/s]


Validation accuracy: 52.986%
--------------------------------------------------
Epoch 29/30
Meta-LR: 0.000034, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:19<00:00,  3.17it/s]


Loss: 0.266	Accuracy: 88.172%


Validation: 100%|██████████| 6/6 [00:01<00:00,  5.76it/s]


Validation accuracy: 52.639%
--------------------------------------------------
Epoch 30/30
Meta-LR: 0.000021, Using 2nd-order gradients


Training: 100%|██████████| 62/62 [00:19<00:00,  3.12it/s]


Loss: 0.259	Accuracy: 88.730%


Validation: 100%|██████████| 6/6 [00:00<00:00,  6.96it/s]

Validation accuracy: 52.569%
--------------------------------------------------
MAML++ Training Complete!





In [82]:
# Save the trained MAML++ model
if 'meta_model' in locals():
    torch.save({
        'model_state_dict': meta_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.__dict__,
        'hyperparameters': {
            'n_way': n_way,
            'k_shot': k_shot,
            'q_query': q_query,
            'input_dim': input_dim,
            'meta_lr': meta_lr,
            'train_inner_train_step': train_inner_train_step,
            'use_first_order_epochs': use_first_order_epochs
        },
        'improvements': {
            'MSL': 'Multi-Step Loss Optimization',
            'DA': 'Derivative-Order Annealing', 
            'BNRS': 'Per-Step Batch Normalization Running Statistics',
            'BNWB': 'Per-Step Batch Normalization Weights and Biases',
            'LSLR': 'Learning Per-Layer Per-Step Learning Rates',
            'CA': 'Cosine Annealing of Meta-Optimizer'
        }
    }, 'malware_maml++_model.pth')

    print("MAML++ model saved as malware_maml_plus_plus_model.pth")

MAML++ model saved as malware_maml_plus_plus_model.pth


In [83]:
# MAML++ Testing with accuracy calculation
def test_maml_plus_plus_model(model, test_data_path_or_dataset, inner_train_step=500, epoch=30):
    """
    MAML++ Test function that returns predicted and true labels for accuracy calculation
    3-way tasks: 2 abnormal + 1 normal
    """
    test_dataset = MalwareDetection(test_data_path_or_dataset, 'test', k_shot, q_query)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    test_iter = iter(test_loader)
    
    test_batches = min(20, len(test_loader))
    all_predicted_labels = []
    all_true_labels = []
    task_accuracies = []

    print("Starting MAML++ testing and accuracy calculation...")

    # Fix random seed for consistent label generation
    np.random.seed(42)
    
    # Get adaptive step weights for testing (same as final training epoch)
    step_weights = get_step_weights(epoch-1, max_epoch, inner_train_step)
    
    for batch_idx in tqdm(range(test_batches), desc="Testing MAML++ with Accuracy"):
        x, test_iter = get_meta_batch(1, k_shot, q_query, test_loader, test_iter)

        # Check the actual task dimensions
        batch_size, total_samples, feature_dim = x.shape
        actual_n_way = total_samples // (k_shot + q_query)

        # 3-way task query set labels
        task_true_labels = []
        for class_idx in range(3):
            task_true_labels.extend([class_idx] * q_query)

        # Get model predictions using MAML++
        predicted_labels = MAMLPlusPlusSolver(
            model,
            optimizer, 
            x,
            3,  
            k_shot,
            q_query,
            loss_fn,
            inner_train_step=inner_train_step,
            train=False,
            epoch=epoch,  # Use final epoch settings
            step_weights=step_weights,
            use_first_order_epochs=use_first_order_epochs,
            return_labels=True,
        )

        # Calculate current task accuracy
        task_true = np.array(task_true_labels)
        task_pred = np.array(predicted_labels)
        task_acc = (task_true == task_pred).mean()
        task_accuracies.append(task_acc)

        # Collect all labels
        all_predicted_labels.extend(predicted_labels)
        all_true_labels.extend(task_true_labels)

        if batch_idx % 5 == 0:  # Print every 5 batches
            print(f"Batch {batch_idx+1}/{test_batches} - Task Accuracy: {task_acc:.4f}")
    
    return all_predicted_labels, all_true_labels, task_accuracies

# Execute MAML++ testing with accuracy calculation
print("=" * 60)
print("FINAL MAML++ TESTING")
print("=" * 60)

test_predicted_labels, test_true_labels, test_task_accuracies = test_maml_plus_plus_model(
    meta_model, 
    '../malware_data_structure.json',
    inner_train_step=val_inner_train_step,  # Use validation inner steps for testing
    epoch=max_epoch
)

average_test_accuracy = np.mean(test_task_accuracies)
std_test_accuracy = np.std(test_task_accuracies)

print(f"MAML++ Final Test Results:")
print(f"Average Test Task Accuracy: {average_test_accuracy*100:.3f}% ± {std_test_accuracy*100:.3f}%")
print(f"Best Task Accuracy: {np.max(test_task_accuracies)*100:.3f}%")
print(f"Worst Task Accuracy: {np.min(test_task_accuracies)*100:.3f}%")

# Save MAML++ test results
results_df = pd.DataFrame({
    'id': range(len(test_predicted_labels)),
    'predicted_class': test_predicted_labels,
    'true_class': test_true_labels
})

results_df.to_csv('malware_maml_plus_plus_predictions.csv', index=False)
print("MAML++ test results saved as malware_maml_plus_plus_predictions.csv")

# Calculate detailed metrics
from sklearn.metrics import classification_report, confusion_matrix

print("\nDetailed Classification Report:")
print(classification_report(test_true_labels, test_predicted_labels, 
                          target_names=['Class 0', 'Class 1', 'Class 2']))

print("\nConfusion Matrix:")
print(confusion_matrix(test_true_labels, test_predicted_labels))

FINAL MAML++ TESTING
Starting MAML++ testing and accuracy calculation...


Testing MAML++ with Accuracy:  45%|████▌     | 9/20 [00:00<00:00, 83.98it/s]

Batch 1/20 - Task Accuracy: 0.8667
Batch 6/20 - Task Accuracy: 0.5333
Batch 11/20 - Task Accuracy: 0.5333
Batch 16/20 - Task Accuracy: 0.6000


Testing MAML++ with Accuracy: 100%|██████████| 20/20 [00:00<00:00, 89.30it/s]

MAML++ Final Test Results:
Average Test Task Accuracy: 61.000% ± 11.010%
Best Task Accuracy: 86.667%
Worst Task Accuracy: 40.000%
MAML++ test results saved as malware_maml_plus_plus_predictions.csv

Detailed Classification Report:
              precision    recall  f1-score   support

     Class 0       0.89      0.89      0.89       100
     Class 1       0.46      0.48      0.47       100
     Class 2       0.48      0.46      0.47       100

    accuracy                           0.61       300
   macro avg       0.61      0.61      0.61       300
weighted avg       0.61      0.61      0.61       300


Confusion Matrix:
[[89  6  5]
 [ 7 48 45]
 [ 4 50 46]]



