In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

# Experiment settings shared by every cell below.
#   batch_size / epochs / lr : training hyper-parameters
#   target_class             : class id whose training labels get flipped
#   malicious_label          : class id written in place of the flipped labels
#   poison_ratio             : fraction of target-class samples that are flipped
#   device                   : CUDA when available, otherwise CPU
CONFIG = dict(
    batch_size=128,
    epochs=10,
    lr=0.001,
    target_class=1,
    malicious_label=9,
    poison_ratio=0.4,
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)

# Seed both RNGs used by this notebook (NumPy's global generator and PyTorch).
np.random.seed(42)
torch.manual_seed(42)

In [None]:
class PoisonedCIFAR10(Dataset):
    """Dataset wrapper that serves the wrapped images with (optionally) flipped labels.

    The wrapped dataset is never modified: its ``targets`` are copied into a
    NumPy array owned by this wrapper, and only that copy is altered.

    Args:
        original_dataset: indexable dataset exposing a ``targets`` sequence and
            returning ``(image, label)`` pairs from ``__getitem__``.
        target_class: label value whose samples are candidates for flipping.
        malicious_label: label value written over the selected samples.
        ratio: fraction of the target-class samples to flip; values that are
            not greater than 0 disable poisoning.
        is_train: when False, no labels are flipped regardless of ``ratio``.
    """

    def __init__(self, original_dataset, target_class, malicious_label, ratio, is_train=True):
        self.dataset = original_dataset
        self.is_train = is_train
        # np.array() makes a private copy, so edits below stay local to this wrapper.
        self.targets = np.array(original_dataset.targets)

        # Nothing to poison: leave the copied labels exactly as they are.
        if not (is_train and ratio > 0):
            return

        victim_pool = np.nonzero(self.targets == target_class)[0]
        flip_count = int(ratio * len(victim_pool))
        # Sampling draws from NumPy's global RNG, so the selection depends on
        # whatever seed/state is active when the wrapper is constructed.
        flipped = np.random.choice(victim_pool, size=flip_count, replace=False)
        self.targets[flipped] = malicious_label

    def __getitem__(self, index):
        """Return ``(image, label)``; the label comes from this wrapper's own copy."""
        image, _original_label = self.dataset[index]
        return image, self.targets[index]

    def __len__(self):
        """Number of samples, identical to the wrapped dataset's length."""
        return len(self.dataset)

In [None]:
def get_model():
    """Create a fresh 10-class ResNet-18 on ``CONFIG["device"]``.

    The stock stem is altered before the model is moved to the device: the
    first convolution becomes a 3x3, stride-1, padding-1 layer without bias,
    and the max-pool stage is replaced by a pass-through (presumably so small
    inputs such as CIFAR-10 images are not down-sampled early).
    """
    net = torchvision.models.resnet18(num_classes=10)
    net.maxpool = nn.Identity()
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net = net.to(CONFIG["device"])
    return net

def train_and_evaluate(train_loader, description):
    """Train a freshly initialised model on ``train_loader`` and return it.

    Despite the name, this function performs no evaluation; predictions are
    computed separately on the returned model.

    Args:
        train_loader: iterable yielding ``(images, labels)`` batches.
        description: human-readable run name, printed as a banner so the two
            training runs can be told apart in the cell output.

    Returns:
        The trained model, left in training mode, on ``CONFIG["device"]``.
    """
    # `description` was previously accepted but ignored; announce the run.
    print(f"\n--- Starting: {description} ---")

    device = CONFIG["device"]
    model = get_model()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG["lr"])
    criterion = nn.CrossEntropyLoss()

    for _ in range(CONFIG["epochs"]):
        model.train()
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    return model

In [None]:
def get_predictions(model, loader):
    """Run ``model`` over every batch of ``loader`` and collect class predictions.

    The model is switched to eval mode (and left that way) and gradients are
    disabled for the forward passes. The predicted class of each sample is the
    index of its largest output along dimension 1.

    Args:
        model: callable module mapping an image batch to per-class scores.
        loader: iterable yielding ``(images, labels)`` batches; labels must be
            convertible with ``.numpy()``.

    Returns:
        ``(predictions, labels)`` as two NumPy arrays in loader order.
    """
    predicted_ids, true_ids = [], []
    device = CONFIG["device"]

    model.eval()
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            scores = model(batch_images.to(device))
            best = scores.max(dim=1).indices
            predicted_ids.extend(best.cpu().numpy())
            true_ids.extend(batch_labels.numpy())

    return np.array(predicted_ids), np.array(true_ids)

def plot_results(clean_preds, clean_labels, poisoned_preds, poisoned_labels, classes):
    """Show the clean and poisoned models' confusion matrices side by side.

    Args:
        clean_preds, clean_labels: predictions and ground truth for the model
            trained on clean labels.
        poisoned_preds, poisoned_labels: the same for the model trained on
            poisoned labels.
        classes: class names used as tick labels. Labels are assumed to be the
            integer ids ``0 .. len(classes) - 1`` in the same order.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Pin the label set so each matrix is always len(classes) x len(classes);
    # otherwise a class absent from both y_true and y_pred would shrink the
    # matrix and shift every tick name onto the wrong row/column.
    label_ids = list(range(len(classes)))

    panels = [
        (clean_preds, clean_labels, "Clean Model Confusion Matrix"),
        (poisoned_preds, poisoned_labels, "Poisoned Model Confusion Matrix"),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    for ax, (preds, labels, title) in zip(axes, panels):
        cm = confusion_matrix(labels, preds, labels=label_ids)
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax,
                    xticklabels=classes, yticklabels=classes)
        ax.set_title(title)
        # confusion_matrix puts true labels on rows and predictions on columns.
        ax.set_xlabel("Predicted label")
        ax.set_ylabel("True label")
    plt.tight_layout()
    plt.show()

In [None]:
# Normalisation shared by the training and evaluation pipelines.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))

# Training pipeline: random horizontal flips as augmentation.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: no random augmentation, so both models are scored on
# exactly the same test images and the results are repeatable.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

base_train = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
base_test = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)

# ratio=0 keeps every label intact; the poisoned copy flips a fraction of the
# target class. Each wrapper holds its own label copy, so they do not interfere.
clean_ds = PoisonedCIFAR10(base_train, CONFIG["target_class"], CONFIG["malicious_label"], ratio=0)
poison_ds = PoisonedCIFAR10(base_train, CONFIG["target_class"], CONFIG["malicious_label"], ratio=CONFIG["poison_ratio"])

clean_loader = DataLoader(clean_ds, batch_size=CONFIG["batch_size"], shuffle=True)
poison_loader = DataLoader(poison_ds, batch_size=CONFIG["batch_size"], shuffle=True)
test_loader = DataLoader(base_test, batch_size=CONFIG["batch_size"], shuffle=False)

clean_model = train_and_evaluate(clean_loader, "Clean Training")
poisoned_model = train_and_evaluate(poison_loader, "Poisoned Training")

classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

c_preds, c_true = get_predictions(clean_model, test_loader)
p_preds, p_true = get_predictions(poisoned_model, test_loader)

plot_results(c_preds, c_true, p_preds, p_true, classes)

# Report only the attacked class. target_names must line up one-to-one with
# `labels`; passing all ten names for a single label would print the row under
# the first name ("plane") instead of the attacked class.
target_idx = CONFIG["target_class"]
target_name = [classes[target_idx]]

print(classification_report(c_true, c_preds, labels=[target_idx], target_names=target_name))
print(classification_report(p_true, p_preds, labels=[target_idx], target_names=target_name))

100%|██████████| 170M/170M [00:05<00:00, 29.6MB/s]


[*] Poisoning Active: 2000 instances of Class 1 flipped to 9

--- Starting: Clean Training ---
