In [2]:
pip install gdown

Collecting gdown
  Downloading gdown-5.2.0-py3-none-any.whl.metadata (5.8 kB)
Collecting beautifulsoup4 (from gdown)
  Downloading beautifulsoup4-4.14.2-py3-none-any.whl.metadata (3.8 kB)
Collecting soupsieve>1.2 (from beautifulsoup4->gdown)
  Downloading soupsieve-2.8-py3-none-any.whl.metadata (4.6 kB)
Collecting PySocks!=1.5.7,>=1.5.6 (from requests[socks]->gdown)
  Downloading PySocks-1.7.1-py3-none-any.whl.metadata (13 kB)
Downloading gdown-5.2.0-py3-none-any.whl (18 kB)
Downloading beautifulsoup4-4.14.2-py3-none-any.whl (106 kB)
Downloading soupsieve-2.8-py3-none-any.whl (36 kB)
Downloading PySocks-1.7.1-py3-none-any.whl (16 kB)
Installing collected packages: soupsieve, PySocks, beautifulsoup4, gdown
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4/4[0m [gdown]
[1A[2KSuccessfully installed PySocks-1.7.1 beautifulsoup4-4.14.2 gdown-5.2.0 soupsieve-2.8

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49

In [12]:
# Import modules
import glob, random
from collections import OrderedDict
import os
import pandas as pd
import sys

import numpy as np
from tqdm.auto import tqdm
import json

import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from dotenv import load_dotenv

# Add src to path to import project modules
sys.path.append('../src')
from extraction.data_loader import create_dataloaders
from extraction.downloader import download_dataset

from PIL import Image
from IPython.display import display

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEVICE = {device}")

# Fix random seeds
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

DEVICE = cpu


In [13]:
# Hyperparameters
n_way = 3  # 2 abnormal classes + 1 normal (benign) class
k_shot = 1
q_query = 5
input_dim = 1280  # Feature dimension from the standardized dataset
train_inner_train_step = 5
val_inner_train_step = 5
inner_lr = 0.001
meta_lr = 0.001
meta_batch_size = 16
max_epoch = 30
eval_batches = 20

In [14]:
# Load dataset using the standardized dataloader
load_dotenv()

# Download dataset if not exists
dataset_dir = '../dataset'
if not os.path.exists(dataset_dir):
    download_dataset(dir=dataset_dir)

features_dir = os.path.join(dataset_dir, 'features')
split_csv_path = os.path.join(dataset_dir, 'label_split.csv')

# Create standardized dataloaders
dataloaders = create_dataloaders(
    features_dir=features_dir,
    split_csv_path=split_csv_path,
    batch_size=1,  # We'll handle batching in MAML
    val_ratio=0.1,
    test_ratio=0.1,
    generalized=False,
    num_workers=0  # Set to 0 for compatibility
)

print(f"Available dataloaders: {list(dataloaders.keys())}")
print(f"Train dataset size: {len(dataloaders['train'].dataset)}")
print(f"Validation dataset size: {len(dataloaders['val'].dataset)}")
print(f"Test seen dataset size: {len(dataloaders['test_seen'].dataset)}")
print(f"Test unseen dataset size: {len(dataloaders['test_unseen'].dataset)}")

Available dataloaders: ['train', 'val', 'test_seen', 'test_unseen']
Train dataset size: 46729
Validation dataset size: 5840
Test seen dataset size: 5840
Test unseen dataset size: 23081


In [15]:
class MAMLDatasetWrapper(Dataset):
    """
    Wrapper to adapt the standardized dataset for MAML's episodic training.
    Creates N-way K-shot tasks from the standardized malware dataset.
    """
    def __init__(self, dataloader, n_way=3, k_shot=1, q_query=5, num_tasks=1000):
        """
        Args:
            dataloader: One of the standardized dataloaders (train/val/test)
            n_way: Number of classes per task
            k_shot: Number of support samples per class
            q_query: Number of query samples per class
            num_tasks: Number of tasks to generate
        """
        self.dataloader = dataloader
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        self.num_tasks = num_tasks
        
        # Extract all data and organize by class
        self.class_data = self._organize_data_by_class()
        self.available_classes = list(self.class_data.keys())
        
        print(f"Available classes: {len(self.available_classes)}")
        print(f"Class distribution: {[(cls, len(samples)) for cls, samples in self.class_data.items()]}")
        
        # For MAML, we just need enough classes to form n_way tasks
        if len(self.available_classes) < self.n_way:
            raise ValueError(f"Need at least {self.n_way} classes for {self.n_way}-way classification, but only found {len(self.available_classes)}")
        
        # Validate we have enough samples per class
        min_samples_needed = self.k_shot + self.q_query
        for cls, samples in self.class_data.items():
            if len(samples) < min_samples_needed:
                print(f"Warning: Class {cls} has only {len(samples)} samples, need {min_samples_needed}. Will use replacement sampling.")
    
    def _organize_data_by_class(self):
        """Organize dataset samples by class label"""
        class_data = {}
        
        for features, label in self.dataloader:
            label_item = label.item()
            if label_item not in class_data:
                class_data[label_item] = []
            class_data[label_item].append(features.squeeze(0))  # Remove batch dimension
        
        return class_data
    
    def _sample_task_classes(self):
        """Sample classes for an n-way task"""
        # Randomly sample n_way classes from available classes
        sampled_classes = np.random.choice(self.available_classes, self.n_way, replace=False)
        return list(sampled_classes)
    
    def __len__(self):
        return self.num_tasks
    
    def __getitem__(self, idx):
        """Generate a single N-way K-shot task"""
        # Set seed for reproducibility based on index
        np.random.seed(42 + idx)
        
        # Sample classes for this task
        task_classes = self._sample_task_classes()
        
        task_data = []
        
        for cls in task_classes:
            class_samples = self.class_data[cls]
            
            # Sample support + query samples
            total_needed = self.k_shot + self.q_query
            
            if len(class_samples) >= total_needed:
                selected_indices = np.random.choice(len(class_samples), total_needed, replace=False)
            else:
                # Sample with replacement if not enough samples
                selected_indices = np.random.choice(len(class_samples), total_needed, replace=True)
            
            selected_samples = [class_samples[i] for i in selected_indices]
            task_data.append(torch.stack(selected_samples))
        
        # Stack all class data: [n_way, k_shot + q_query, feature_dim]
        task_tensor = torch.stack(task_data)
        
        # Reshape to [n_way * (k_shot + q_query), feature_dim]
        task_tensor = task_tensor.view(-1, task_tensor.size(-1))
        
        return task_tensor

In [16]:
# Data Structure Overview for MAML with Standardized Dataloader:

# Epoch Level (30 epochs)
# │
# ├── Meta-batch Level (16 batches per epoch)
# │   │
# │   ├── Task 1: [Malware A, Malware B, Benign] → [18, 1280] 18 = 3 categories * (1 support + 5 query); 1280 = feature dimension
# │   ├── Task 2: [Malware C, Malware D, Benign] → [18, 1280] 
# │   ├── ...
# │   └── Task 16: [Malware X, Malware Y, Benign] → [18, 1280]
# │   │
# │   └── Meta-batch: [16, 18, 1280]
# │
# └── How to process each Meta-batch in Solver(MAML Algorithm):
#     │
#     ├── Split Support/Query Set
#     │   ├── Support: [16, 3, 1280]  (3 samples(1 for 3 categories) per task)
#     │   └── Query:   [16, 15, 1280] (15 samples(5 for 3 categories) per task)
#     │
#     ├── Inner Training (5 steps, based on Support Set)
#     │   └── Fast Adaptation: θ → θ'
#     │
#     ├── Outer Validation (based on Query Set)
#     │   └── Compute meta-loss and accuracy
#     │
#     └── Outer Update (Meta-gradient)
#         └── Update original parameters: θ ← θ - β∇_θ L_meta

In [17]:
# Utility functions for labels and accuracy
def create_malware_label(k_shot, q_query):
    """
    Create labels for calculating accuracy in test phase.
    3 classes: 2 malware + 1 benign
    """
    n_way = 3  # 2 malware + 1 benign
    labels = []
    for class_idx in range(n_way):
        class_labels = [class_idx] * (k_shot + q_query)
        labels.extend(class_labels)
    
    return torch.tensor(labels, dtype=torch.long)

def create_label(n_way, k_shot):
    """
    Create labels for support set and query set.
    """
    return torch.arange(n_way).repeat_interleave(k_shot).long()

def calculate_accuracy(logits, labels):
    """utility function for accuracy calculation"""
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc

In [18]:
class MalwareClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, output_dim=3):
        """
        A simple feedforward neural network for malware classification.
        input_dim: 1280 (standardized feature dimension)
        output_dim: 3 (2 malware + 1 benign)
        """
        super(MalwareClassifier, self).__init__()
        
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(), 
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, output_dim)
        )
        
    def forward(self, x):
        return self.network(x)
    
    def functional_forward(self, x, params):
        for i, (name, module) in enumerate(self.network.named_children()):
            if isinstance(module, nn.Linear):
                weight_key = f'network.{i}.weight'
                bias_key = f'network.{i}.bias'
                
                x = F.linear(x, params.get(weight_key, module.weight), 
                           params.get(bias_key, module.bias))
            elif isinstance(module, nn.ReLU):
                x = F.relu(x)
            elif isinstance(module, nn.Dropout):
                x = F.dropout(x, training=self.training)
        return x

In [19]:
# Create MAML-compatible datasets from standardized dataloaders
print("Creating MAML-compatible datasets...")

# 先檢查可用的類別數量
temp_loader = dataloaders['train']
temp_classes = set()
for _, label in temp_loader:
    temp_classes.add(label.item())
actual_num_classes = len(temp_classes)
print(f"Actual number of classes: {actual_num_classes}")

# 調整 n_way 以符合可用類別
n_way = min(n_way, actual_num_classes)

train_maml_dataset = MAMLDatasetWrapper(
    dataloaders['train'], 
    n_way=n_way, 
    k_shot=k_shot, 
    q_query=q_query, 
    num_tasks=1000
)

val_maml_dataset = MAMLDatasetWrapper(
    dataloaders['val'], 
    n_way=n_way, 
    k_shot=k_shot, 
    q_query=q_query, 
    num_tasks=200
)

test_maml_dataset = MAMLDatasetWrapper(
    dataloaders['test_unseen'], 
    n_way=n_way, 
    k_shot=k_shot, 
    q_query=q_query, 
    num_tasks=300
)

# Create DataLoaders for MAML
train_loader = DataLoader(train_maml_dataset, batch_size=1, shuffle=True, num_workers=0)
val_loader = DataLoader(val_maml_dataset, batch_size=1, shuffle=False, num_workers=0)
test_loader = DataLoader(test_maml_dataset, batch_size=1, shuffle=False, num_workers=0)

print(f"Using {n_way}-way classification")
print(f"Train tasks: {len(train_maml_dataset)}")
print(f"Validation tasks: {len(val_maml_dataset)}")
print(f"Test tasks: {len(test_maml_dataset)}")

Creating MAML-compatible datasets...
Actual number of classes: 6
Available classes: 6
Class distribution: [(9, 8536), (2, 8176), (6, 7215), (0, 6712), (1, 9509), (7, 6581)]
Available classes: 6
Class distribution: [(7, 837), (2, 1030), (1, 1185), (6, 880), (9, 1048), (0, 860)]
Available classes: 4
Class distribution: [(4, 6760), (8, 6259), (5, 5131), (3, 4931)]
Using 3-way classification
Train tasks: 1000
Validation tasks: 200
Test tasks: 300


In [20]:
def get_meta_batch(meta_batch_size, k_shot, q_query, data_loader, iterator):
    """
    Get meta batch function adapted for standardized dataloader
    """
    data = []
    for _ in range(meta_batch_size):
        try:
            task_data = next(iterator)
        except StopIteration:
            iterator = iter(data_loader)
            task_data = next(iterator)
        
        # task_data shape: [1, n_way * (k_shot + q_query), feature_dim]
        # Remove the batch dimension
        task_data = task_data.squeeze(0)  # [n_way * (k_shot + q_query), feature_dim]
        data.append(task_data)
    
    return torch.stack(data).to(device), iterator

In [21]:
# Main MAML Algorithm (unchanged - works with any properly formatted data)
def Solver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step,
    inner_lr,
    train,
    return_labels=False,
):
    """
    Main MAML algorithm
    """
    criterion = loss_fn
    task_loss = []
    task_acc = []
    labels = []
    
    for meta_batch in x:
        # Split support and query sets
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        # Copy the params for inner loop
        fast_weights = OrderedDict(model.named_parameters())

        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):
            # Simply training
            train_label = create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)
            # Inner gradients update!
            # Calculate gradients
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)

            # Update fast_weights
            # θ' = θ - α * ∇loss
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )

        ### ---------- INNER VALID LOOP ---------- ###
        if not return_labels:
            """ training / validation """
            val_label = create_label(n_way, q_query).to(device)

            # Collect gradients for outer loop
            logits = model.functional_forward(query_set, fast_weights)
            loss = criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))
        else:
            """ testing """
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    # Update outer loop
    model.train()
    optimizer.zero_grad()

    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        """ Outer Loop Update """
        # φ backpropagation
        meta_batch_loss.backward()
        # Update parameters
        optimizer.step()

    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc

In [22]:
# Create model with correct output dimension
meta_model = MalwareClassifier(input_dim=input_dim, output_dim=n_way).to(device)
optimizer = torch.optim.Adam(meta_model.parameters(), lr=meta_lr)
loss_fn = nn.CrossEntropyLoss()

print(f"Model parameters: {sum(p.numel() for p in meta_model.parameters())}")
print(f"Using standardized malware dataset with {input_dim} features")
print(f"Task configuration: {n_way}-way {k_shot}-shot with {q_query} query samples")

Model parameters: 427011
Using standardized malware dataset with 1280 features
Task configuration: 3-way 1-shot with 5 query samples


In [23]:
# Training loop
train_iter = iter(train_loader)
val_iter = iter(val_loader)

print("Starting training with standardized dataloader...")
for epoch in range(max_epoch):
    print(f"Epoch {epoch+1}/{max_epoch}")
    
    # Training
    train_meta_loss = []
    train_acc = []
    
    for train_step in tqdm(range(len(train_loader) // meta_batch_size), desc="Training"):
        x, train_iter = get_meta_batch(
            meta_batch_size, k_shot, q_query, train_loader, train_iter
        )
        
        meta_loss, acc = Solver(
            meta_model,
            optimizer,
            x,
            n_way,
            k_shot,
            q_query,
            loss_fn,
            inner_train_step=train_inner_train_step,
            inner_lr=inner_lr,
            train=True,
        )
        
        train_meta_loss.append(meta_loss.item())
        train_acc.append(acc)
    
    print(f"Loss: {np.mean(train_meta_loss):.3f}\tAccuracy: {np.mean(train_acc)*100:.3f}%")
    
    # Validation
    val_acc = []
    for eval_step in tqdm(range(min(eval_batches, len(val_loader) // meta_batch_size)), desc="Validation"):
        x, val_iter = get_meta_batch(
            meta_batch_size, k_shot, q_query, val_loader, val_iter
        )
        
        _, acc = Solver(
            meta_model,
            optimizer,
            x,
            n_way,
            k_shot,
            q_query,
            loss_fn,
            inner_train_step=val_inner_train_step,
            inner_lr=inner_lr,
            train=False,
        )
        val_acc.append(acc)
    
    print(f"Validation accuracy: {np.mean(val_acc)*100:.3f}%")
    print("-" * 50)

print("Training completed!")

Starting training with standardized dataloader...
Epoch 1/30


Training: 100%|██████████| 62/62 [00:07<00:00,  8.71it/s]


Loss: 1.104	Accuracy: 33.656%


Validation: 100%|██████████| 12/12 [00:00<00:00, 22.66it/s]


Validation accuracy: 33.333%
--------------------------------------------------
Epoch 2/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.08it/s]


Loss: 1.101	Accuracy: 33.387%


Validation: 100%|██████████| 12/12 [00:00<00:00, 23.67it/s]


Validation accuracy: 33.576%
--------------------------------------------------
Epoch 3/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.08it/s]


Loss: 1.099	Accuracy: 34.583%


Validation: 100%|██████████| 12/12 [00:00<00:00, 23.27it/s]


Validation accuracy: 31.840%
--------------------------------------------------
Epoch 4/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.03it/s]


Loss: 1.100	Accuracy: 33.696%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.20it/s]


Validation accuracy: 32.639%
--------------------------------------------------
Epoch 5/30


Training: 100%|██████████| 62/62 [00:05<00:00, 10.57it/s]


Loss: 1.098	Accuracy: 34.819%


Validation: 100%|██████████| 12/12 [00:00<00:00, 23.26it/s]


Validation accuracy: 33.090%
--------------------------------------------------
Epoch 6/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.67it/s]


Loss: 1.098	Accuracy: 34.704%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.31it/s]


Validation accuracy: 34.062%
--------------------------------------------------
Epoch 7/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.00it/s]


Loss: 1.098	Accuracy: 34.469%


Validation: 100%|██████████| 12/12 [00:00<00:00, 22.39it/s]


Validation accuracy: 33.715%
--------------------------------------------------
Epoch 8/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.27it/s]


Loss: 1.097	Accuracy: 34.738%


Validation: 100%|██████████| 12/12 [00:00<00:00, 22.95it/s]


Validation accuracy: 33.611%
--------------------------------------------------
Epoch 9/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.93it/s]


Loss: 1.098	Accuracy: 33.763%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.49it/s]


Validation accuracy: 33.681%
--------------------------------------------------
Epoch 10/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.69it/s]


Loss: 1.097	Accuracy: 34.570%


Validation: 100%|██████████| 12/12 [00:00<00:00, 19.74it/s]


Validation accuracy: 33.854%
--------------------------------------------------
Epoch 11/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.26it/s]


Loss: 1.096	Accuracy: 34.489%


Validation: 100%|██████████| 12/12 [00:00<00:00, 20.52it/s]


Validation accuracy: 33.889%
--------------------------------------------------
Epoch 12/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.08it/s]


Loss: 1.096	Accuracy: 34.980%


Validation: 100%|██████████| 12/12 [00:00<00:00, 20.85it/s]


Validation accuracy: 32.917%
--------------------------------------------------
Epoch 13/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.54it/s]


Loss: 1.094	Accuracy: 35.094%


Validation: 100%|██████████| 12/12 [00:00<00:00, 20.99it/s]


Validation accuracy: 32.951%
--------------------------------------------------
Epoch 14/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.15it/s]


Loss: 1.095	Accuracy: 34.926%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.01it/s]


Validation accuracy: 33.646%
--------------------------------------------------
Epoch 15/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.00it/s]


Loss: 1.093	Accuracy: 35.773%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.64it/s]


Validation accuracy: 33.090%
--------------------------------------------------
Epoch 16/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.58it/s]


Loss: 1.094	Accuracy: 34.577%


Validation: 100%|██████████| 12/12 [00:00<00:00, 22.12it/s]


Validation accuracy: 32.431%
--------------------------------------------------
Epoch 17/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.01it/s]


Loss: 1.093	Accuracy: 35.020%


Validation: 100%|██████████| 12/12 [00:00<00:00, 22.03it/s]


Validation accuracy: 32.986%
--------------------------------------------------
Epoch 18/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.06it/s]


Loss: 1.092	Accuracy: 35.524%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.69it/s]


Validation accuracy: 33.924%
--------------------------------------------------
Epoch 19/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.06it/s]


Loss: 1.091	Accuracy: 35.101%


Validation: 100%|██████████| 12/12 [00:00<00:00, 22.45it/s]


Validation accuracy: 33.924%
--------------------------------------------------
Epoch 20/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.09it/s]


Loss: 1.091	Accuracy: 35.652%


Validation: 100%|██████████| 12/12 [00:00<00:00, 19.67it/s]


Validation accuracy: 34.306%
--------------------------------------------------
Epoch 21/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.72it/s]


Loss: 1.089	Accuracy: 35.847%


Validation: 100%|██████████| 12/12 [00:00<00:00, 20.43it/s]


Validation accuracy: 33.507%
--------------------------------------------------
Epoch 22/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.82it/s]


Loss: 1.088	Accuracy: 35.517%


Validation: 100%|██████████| 12/12 [00:00<00:00, 19.91it/s]


Validation accuracy: 32.222%
--------------------------------------------------
Epoch 23/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.75it/s]


Loss: 1.088	Accuracy: 36.364%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.19it/s]


Validation accuracy: 34.514%
--------------------------------------------------
Epoch 24/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.07it/s]


Loss: 1.087	Accuracy: 36.223%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.03it/s]


Validation accuracy: 33.715%
--------------------------------------------------
Epoch 25/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.46it/s]


Loss: 1.085	Accuracy: 35.746%


Validation: 100%|██████████| 12/12 [00:00<00:00, 20.23it/s]


Validation accuracy: 33.229%
--------------------------------------------------
Epoch 26/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.61it/s]


Loss: 1.085	Accuracy: 36.277%


Validation: 100%|██████████| 12/12 [00:00<00:00, 23.52it/s]


Validation accuracy: 33.090%
--------------------------------------------------
Epoch 27/30


Training: 100%|██████████| 62/62 [00:06<00:00,  9.36it/s]


Loss: 1.080	Accuracy: 36.512%


Validation: 100%|██████████| 12/12 [00:00<00:00, 20.46it/s]


Validation accuracy: 33.507%
--------------------------------------------------
Epoch 28/30


Training: 100%|██████████| 62/62 [00:05<00:00, 10.42it/s]


Loss: 1.082	Accuracy: 36.411%


Validation: 100%|██████████| 12/12 [00:00<00:00, 19.82it/s]


Validation accuracy: 33.194%
--------------------------------------------------
Epoch 29/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.05it/s]


Loss: 1.081	Accuracy: 37.056%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.19it/s]


Validation accuracy: 34.340%
--------------------------------------------------
Epoch 30/30


Training: 100%|██████████| 62/62 [00:06<00:00, 10.17it/s]


Loss: 1.080	Accuracy: 36.835%


Validation: 100%|██████████| 12/12 [00:00<00:00, 21.13it/s]

Validation accuracy: 33.264%
--------------------------------------------------
Training completed!





In [24]:
# Save the trained model
torch.save({
    'model_state_dict': meta_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'hyperparameters': {
        'n_way': n_way,
        'k_shot': k_shot,
        'q_query': q_query,
        'input_dim': input_dim,
        'inner_lr': inner_lr,
        'meta_lr': meta_lr
    },
    'dataset_info': {
        'features_dir': features_dir,
        'split_csv_path': split_csv_path,
        'train_tasks': len(train_maml_dataset),
        'val_tasks': len(val_maml_dataset),
        'test_tasks': len(test_maml_dataset)
    }
}, 'malware_maml_model_standardized.pth')

print("Model saved as malware_maml_model_standardized.pth")

Model saved as malware_maml_model_standardized.pth


In [25]:
def test_model_standardized(model, test_loader, inner_train_step=5):
    """
    Test function using standardized dataloader
    Returns predicted and true labels for accuracy calculation
    """
    test_iter = iter(test_loader)
    
    test_batches = min(20, len(test_loader))
    all_predicted_labels = []
    all_true_labels = []
    task_accuracies = []

    print("Starting testing with standardized dataloader...")

    for batch_idx in tqdm(range(test_batches), desc="Testing"):
        x, test_iter = get_meta_batch(1, k_shot, q_query, test_loader, test_iter)

        # 3-way task query set labels (0, 1, 2 for each class)
        task_true_labels = []
        for class_idx in range(n_way):
            task_true_labels.extend([class_idx] * q_query)

        # Get model predictions
        predicted_labels = Solver(
            model,
            optimizer,
            x,
            n_way,
            k_shot,
            q_query,
            loss_fn,
            inner_train_step=inner_train_step,
            inner_lr=inner_lr,
            train=False,
            return_labels=True,
        )

        # Calculate current task accuracy
        task_true = np.array(task_true_labels)
        task_pred = np.array(predicted_labels)
        task_acc = (task_true == task_pred).mean()
        task_accuracies.append(task_acc)

        # Collect all labels
        all_predicted_labels.extend(predicted_labels)
        all_true_labels.extend(task_true_labels)

        if batch_idx % 5 == 0:  # Print every 5 batches
            print(f"Batch {batch_idx+1}/{test_batches} - Task Accuracy: {task_acc:.4f}")
    
    return all_predicted_labels, all_true_labels, task_accuracies

In [26]:
# Execute testing with standardized dataloader
test_predicted_labels, test_true_labels, test_task_accuracies = test_model_standardized(
    meta_model, test_loader
)
average_test_accuracy = np.mean(test_task_accuracies)
print(f"Average Test Task Accuracy: {average_test_accuracy*100:.3f}%")
print(f"Total test samples: {len(test_predicted_labels)}")

Starting testing with standardized dataloader...


Testing: 100%|██████████| 20/20 [00:00<00:00, 150.73it/s]

Batch 1/20 - Task Accuracy: 0.4667
Batch 6/20 - Task Accuracy: 0.4667
Batch 11/20 - Task Accuracy: 0.4667
Batch 16/20 - Task Accuracy: 0.3333
Average Test Task Accuracy: 37.333%
Total test samples: 300





In [27]:
# Save test results
results_df = pd.DataFrame({
    'id': range(len(test_predicted_labels)),
    'predicted_class': test_predicted_labels,
    'true_class': test_true_labels
})

results_df.to_csv('malware_maml_predictions_standardized.csv', index=False)
print("Test results saved as malware_maml_predictions_standardized.csv")

# Calculate and print additional metrics
from sklearn.metrics import classification_report, confusion_matrix

print("\nClassification Report:")
print(classification_report(test_true_labels, test_predicted_labels, 
                          target_names=['Malware_A', 'Malware_B', 'Benign']))

print("\nConfusion Matrix:")
print(confusion_matrix(test_true_labels, test_predicted_labels))

Test results saved as malware_maml_predictions_standardized.csv

Classification Report:
              precision    recall  f1-score   support

   Malware_A       0.36      0.55      0.44       100
   Malware_B       0.44      0.29      0.35       100
      Benign       0.34      0.28      0.31       100

    accuracy                           0.37       300
   macro avg       0.38      0.37      0.36       300
weighted avg       0.38      0.37      0.36       300


Confusion Matrix:
[[55 18 27]
 [43 29 28]
 [53 19 28]]


In [28]:
# Verification: Check that we're using the standardized dataloader correctly
print("=== Verification of Standardized Dataloader Integration ===")
print(f"Dataset directory: {dataset_dir}")
print(f"Features directory: {features_dir}")
print(f"Split CSV path: {split_csv_path}")
print(f"Available dataloaders: {list(dataloaders.keys())}")
print(f"Feature dimension: {input_dim}")
print(f"Task configuration: {n_way}-way {k_shot}-shot classification")
print(f"Model trained for {max_epoch} epochs")
print(f"Final test accuracy: {average_test_accuracy*100:.3f}%")
print("\nIntegration with standardized dataloader: SUCCESS ✓")

=== Verification of Standardized Dataloader Integration ===
Dataset directory: ../dataset
Features directory: ../dataset/features
Split CSV path: ../dataset/label_split.csv
Available dataloaders: ['train', 'val', 'test_seen', 'test_unseen']
Feature dimension: 1280
Task configuration: 3-way 1-shot classification
Model trained for 30 epochs
Final test accuracy: 37.333%

Integration with standardized dataloader: SUCCESS ✓
