In [1]:
!pip install timm



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import timm  # library for state-of-the-art vision models
import matplotlib.pyplot as plt
import os

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [3]:
# Define data transformations for training and validation
train_transform = transforms.Compose([
    # Convert grayscale to RGB (repeat the channel 3 times)
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # data augmentation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

val_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# We first convert the images to 3 channels using Grayscale(num_output_channels=3). 
# Then we resize them to 224×224 (the common input size for many pre-trained ViT models), 
# apply a random horizontal flip for augmentation (for training), and finally normalize them. 
# The normalization values here (mean=0.5 and std=0.5) rescale our pixel values from [-1, 1] 
# if you used Tanh earlier; adjust these if needed

In [4]:
# Path to your organized images folder
data_dir = os.path.expanduser('~/Downloads/BIG_2015/organized_images')

# Use ImageFolder to create a dataset; ImageFolder expects subdirectories per class.
full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)

# Optional: Inspect class names (the order is alphabetical)
print("Classes:", full_dataset.classes)

# Split dataset into training and validation sets
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# For validation, override transform if needed:
val_dataset.dataset.transform = val_transform


Classes: ['Adware', 'Backdoor_Gatak', 'Backdoor_Kelihos_ver1', 'Backdoor_Kelihos_ver3', 'Backdoor_Simda', 'Benign', 'Obfuscated_Malware', 'Trojan', 'Trojan_Downloader', 'Worm']


In [8]:
BATCH_SIZE = 4  # adjust based on your GPU/CPU memory

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 9424
Validation samples: 2357


In [9]:
NUM_CLASSES = len(full_dataset.classes)
print("Number of classes:", NUM_CLASSES)

# Create a ViT model from timm, pretrained on ImageNet
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=NUM_CLASSES)

# Move model to device
model = model.to(device)

# Print model architecture summary (optional)
print(model)

Number of classes: 10
VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
   

In [12]:
from tqdm import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

EPOCHS = 20  # set as needed

def train(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    epoch_loss = running_loss / total
    accuracy = correct / total * 100
    return epoch_loss, accuracy

def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    epoch_loss = running_loss / total
    accuracy = correct / total * 100
    return epoch_loss, accuracy

# Training loop
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

for epoch in range(EPOCHS):
    for i, (real_imgs, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")):
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        
        print(f"Epoch [{epoch+1}/{EPOCHS}]")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.2f}%")


Epoch 1/20:   0%|                                      | 0/2356 [06:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

# Plot losses
plt.figure(figsize=(10, 4))
plt.plot(range(1, EPOCHS+1), train_losses, label="Train Loss")
plt.plot(range(1, EPOCHS+1), val_losses, label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Loss Curves")
plt.show()

# Plot accuracies
plt.figure(figsize=(10, 4))
plt.plot(range(1, EPOCHS+1), train_accuracies, label="Train Accuracy")
plt.plot(range(1, EPOCHS+1), val_accuracies, label="Val Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.title("Accuracy Curves")
plt.show()
