In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset


MalConv Model MalConv (Malware Convolutional Neural Network) is a deep learning model for malware detection that takes raw bytes of a binary file (like a Windows PE file) as input — without requiring manual feature extraction.

In [13]:
class MalConv(nn.Module): #Defines a PyTorch neural network class setting up a model that can learn patterns directly from raw bytes.
    def __init__(self, n_emb=257, emb_dim=8):
        super(MalConv, self).__init__()                               #initializes the base nn.Module class so we can use PyTorch functionalities.
        self.embedding = nn.Embedding(n_emb, emb_dim, padding_idx=0) #Converts each byte into a trainable vector of length 8 
        self.conv1 = nn.Conv1d(emb_dim, 128, kernel_size=500, stride=500) #1D convolution slides a filter over sequences of embeddings
        self.relu = nn.ReLU()                                         #Applies ReLU to introduce non-linearity Negative values become zero
        self.dropout = nn.Dropout(0.5)                                # Randomly sets 50% of inputs to zero during training-prevents overfitting
        self.fc = nn.Linear(128, 1)                                   # Maps the 128 features from the convolution to one scalar output/logit
        self.sigmoid = nn.Sigmoid()                                    #Converts the scalar output into a probability between 0 and 1

    def forward(self, x):
        x = self.embedding(x).permute(0, 2, 1)  # batch x emb_dim x seq_len
        x = self.conv1(x) #1D convolution scans 500-byte chunks with 128 filters
        x = self.relu(x)
        x = torch.max(x, 2)[0]  # global max pooling For each filter, pick the maximum activation across all chunks
        x = self.dropout(x) 
        x = self.fc(x) #Combines all 128 strongest clues into one scalar score.
        return self.sigmoid(x)



Create Synthetic Dataset

In [14]:
np.random.seed(42)
torch.manual_seed(42)

num_samples = 200
seq_len = 2000  # short for demo

# Malware: higher byte values, Benign: lower byte values
malware_data = np.random.randint(100, 256, (num_samples//2, seq_len))
benign_data = np.random.randint(1, 100, (num_samples//2, seq_len))


X = np.vstack((malware_data, benign_data)) #Combine Malware and Benign
y = np.hstack((np.ones(num_samples//2), np.zeros(num_samples//2))) #Assign labels: 1 = malware, 0 = benign.

#Convert numpy arrays to tensors for PyTorch
X_tensor = torch.tensor(X, dtype=torch.long)
y_tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(1)

#Create Dataset
dataset = TensorDataset(X_tensor, y_tensor)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)



Train the Model

In [15]:

model = MalConv() #Initializes your MalConv network with all layers
criterion = nn.BCELoss() #Binary Cross Entropy (BCE) loss compares the predicted probability vs the true label
optimizer = optim.Adam(model.parameters(), lr=0.001) #Adam optimizer updates the model parameters to reduce the loss.

epochs = 5
for epoch in range(epochs):
    total_loss = 0
    for inputs, labels in train_loader:      #loop over the training dataset in batches
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)    #compute the loss (error) between predictions and true labels.
        loss.backward()                      #computes the gradients of the loss with respect to all model parameters.
        optimizer.step()                     #The optimizer updates the model’s weights using the gradients computed in loss.backward()
        total_loss += loss.item()
    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {total_loss/len(train_loader):.4f}")



Epoch [1/5] - Loss: 0.3105
Epoch [2/5] - Loss: 0.0083
Epoch [3/5] - Loss: 0.0013
Epoch [4/5] - Loss: 0.0007
Epoch [5/5] - Loss: 0.0008


Pick a Malware Sample

In [16]:
#Feed a single sample through the model to check its true label and predicted score.
sample_idx = 0  # first malware sample
sample = X_tensor[sample_idx].unsqueeze(0)  # shape (1, seq_len)
label = y_tensor[sample_idx]

print("True label:", label.item())
print("Original score:", model(sample).item())



True label: 1.0
Original score: 0.9999697208404541


FGSM Adversarial Attack

In [25]:
# Get embedding weights for gradient computation
sample_input = sample.clone().detach()                                
embedded = model.embedding(sample_input).detach().requires_grad_(True) # coverting raw bytes to dense embeding vectos

# Forward pass manually through remaining layers
x = embedded.permute(0, 2, 1)    #Reorder tensor to match layer’s expected input
x = model.conv1(x)
x = model.relu(x)
x = torch.max(x, 2)[0]      #Global max pooling converts variable-length convolution outputs into a fixed-size vector
x = model.dropout(x)
x = model.fc(x)
output = model.sigmoid(x)   # activation fuction used to converts logits to probablities(0,1)

# Compute loss targeting benign (0.0)
loss = criterion(output, torch.tensor([[0.0]]))
loss.backward()

# FGSM on embeddings
epsilon = 3.0
perturbed_embedded = embedded - epsilon * embedded.grad.sign()

# Map perturbed embeddings back to nearest byte index to save it as file as adversarial prediction require new perturbed file to predict
embedding_weights = model.embedding.weight.detach()
byte_indices = torch.cdist(
    perturbed_embedded.view(-1, embedded.shape[2]),
    embedding_weights
).argmin(dim=1)

# Reshape back to original shape
adv_sample = byte_indices.view(sample_input.shape).clamp(1, 255)



Compare Scores

In [26]:
orig_score = model(sample).item()
adv_score = model(adv_sample).item()

print(f"Original detection score: {orig_score:.4f} (malware)")
print(f"Adversarial detection score: {adv_score:.4f} (closer to benign means evasion)")



Original detection score: 0.9998 (malware)
Adversarial detection score: 0.0000 (closer to benign means evasion)


Visual Difference

In [27]:
diff_bytes = (adv_sample != sample).sum().item()
print(f"Bytes changed: {diff_bytes} out of {seq_len}")



Bytes changed: 1969 out of 2000
