In [2]:
# ================================================================
# Conv-4 on CIFAR-10 + filter-fault injection benchmark
# ================================================================
!pip install --quiet torch torchvision tqdm

import math, random, os, pathlib, torch, torch.nn as nn, torch.optim as optim
import torchvision, torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm

# ----------------------------- CONFIG ---------------------------
# Training hyper-parameters.
NUM_EPOCHS = 5  # raise as needed (author's note: ~100 epochs ≈ 88-90 % acc)
BATCH_SIZE = 128
LR = 0.1  # initial SGD learning rate

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Fault-injection settings.
FAULT_PERCENT = 10  # share (in %) of *all* conv filters to re-draw at random
RANDOM_SEED = 42

# Checkpoints: leave PRETRAINED_PATH empty to train from scratch, or point it
# at an existing file such as '/content/conv4_cifar10.pt' to load weights.
PRETRAINED_PATH = ''
SAVE_CHECKPOINT_TO = '/content/conv4_cifar10.pt'

# Seed both RNGs the script uses (Python's `random` and PyTorch).
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# -------------------------- DATASET -----------------------------
# Per-channel normalisation statistics, shared by the train and test pipelines
# so the two can never drift apart.
# NOTE(review): (0.2023, 0.1994, 0.2010) is the widely-copied CIFAR-10 "std";
# the per-pixel std over the training set is reportedly closer to
# (0.2470, 0.2435, 0.2616). Kept unchanged so existing checkpoints stay valid.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Train: light augmentation (flip + padded random crop), then normalise.
transform_train = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Test: deterministic, normalisation only.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)

# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# PyTorch warns about it, so enable it only when training on CUDA.
PIN_MEMORY = DEVICE == 'cuda'
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=2, pin_memory=PIN_MEMORY)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE,
                         shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

# --------------------------- MODEL ------------------------------
class ConvBlock(nn.Module):
    """3x3 conv (64 filters, no bias) -> BatchNorm -> ReLU -> 2x2 max-pool.

    The padded convolution keeps the spatial size; the pooling layer then
    halves it (floor division). The output always has 64 channels.

    Args:
        in_channels: number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # The attribute name `block` and the layer order define the
        # state_dict keys (e.g. `block.0.weight`) stored in checkpoints.
        # The conv has no bias because BatchNorm's shift makes it redundant.
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out

class Conv4(nn.Module):
    """Four stacked ConvBlocks (64 channels each) followed by a linear classifier.

    Each ConvBlock halves the spatial resolution with a 2x2 max-pool (floor
    division). A square ``input_size`` x ``input_size`` image therefore
    reaches the classifier as a 64 x s x s feature map with
    ``s = input_size // 16``.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input image (3 for RGB).
        input_size: height/width of the square input images. The default 32
            (CIFAR-10) gives the original 64*2*2 classifier input; e.g. 84
            gives 64*5*5.

    Raises:
        ValueError: if ``input_size`` is too small to survive four poolings.
    """

    def __init__(self, num_classes=10, in_channels=3, input_size=32):
        super().__init__()
        spatial = input_size // 16  # four floor-halvings == floor(n / 16)
        if spatial < 1:
            raise ValueError(
                f'input_size must be >= 16 for four 2x2 poolings, got {input_size}')
        self.features = nn.Sequential(
            ConvBlock(in_channels),
            ConvBlock(64),
            ConvBlock(64),
            ConvBlock(64),
        )
        self.classifier = nn.Linear(64 * spatial * spatial, num_classes)

    def forward(self, x):
        x = self.features(x)
        # flatten (unlike .view) also works on non-contiguous layouts such as
        # channels_last; for contiguous tensors the element order is identical.
        x = torch.flatten(x, 1)
        return self.classifier(x)


# Build the CIFAR-10 Conv-4 network and move its parameters/buffers to DEVICE
# (nn.Module.to moves in place and returns the same module).
model = Conv4(num_classes=10, in_channels=3)
model = model.to(DEVICE)

# ---------------------- TRAIN / LOAD ----------------------------
def accuracy(net, loader, device=None):
    """Return the top-1 accuracy of ``net`` on ``loader`` as a percentage.

    The model is run in eval mode without gradient tracking. Its previous
    train/eval mode is restored afterwards, so calling this mid-training
    has no lasting side effect.

    Args:
        net: classifier; the argmax of its output over dim 1 is the prediction.
        loader: iterable of ``(images, labels)`` batches.
        device: where each batch is moved. ``None`` (default) uses the device
            of ``net``'s first parameter (CPU for a parameter-less module).

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = net.training
    net.eval()
    correct = total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                preds = net(images).argmax(1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
    finally:
        # Restore the caller's mode even if evaluation raised.
        net.train(was_training)

    if total == 0:
        raise ValueError('accuracy() received a loader with no samples')
    return 100. * correct / total

if PRETRAINED_PATH and pathlib.Path(PRETRAINED_PATH).is_file():
    # The checkpoint holds a plain state_dict, so unpickling can be limited to
    # tensors and primitive containers (weights_only=True, torch >= 1.13)
    # instead of executing arbitrary pickled objects from the file.
    state_dict = torch.load(PRETRAINED_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    print(f'Loaded pretrained weights from {PRETRAINED_PATH}')
else:
    if PRETRAINED_PATH:
        # A checkpoint was requested but is missing: report it rather than
        # silently falling back to training from scratch.
        print(f'No checkpoint found at {PRETRAINED_PATH}; training from scratch')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9,
                          weight_decay=5e-4)
    # Step decay (x0.1) at 50 % and 75 % of training. Non-positive milestones
    # are dropped: a milestone of 0 would decay the LR before the first epoch
    # when NUM_EPOCHS < 2. Coinciding milestones are merged so short runs do
    # not get a compounded 100x drop.
    milestones = sorted({m for m in (NUM_EPOCHS // 2, int(NUM_EPOCHS * 0.75)) if m > 0})
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    for epoch in range(NUM_EPOCHS):
        model.train()
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}')
        for images, labels in pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix({'loss': f'{loss.item():.3f}'})
        scheduler.step()  # the schedule is defined per epoch

    # Saving is best-effort. The default path targets Colab's /content, and a
    # failed save must not discard the freshly trained model before the
    # evaluation below runs.
    try:
        checkpoint = pathlib.Path(SAVE_CHECKPOINT_TO)
        checkpoint.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), checkpoint)
    except (OSError, RuntimeError) as err:
        print(f'Could not save checkpoint to {SAVE_CHECKPOINT_TO}: {err}')
    else:
        print(f'Saved checkpoint to {SAVE_CHECKPOINT_TO}')

base_acc = accuracy(model, test_loader)
print(f'\nBaseline accuracy: {base_acc:5.2f} %')

# ------------------ FILTER FAULT INJECTION ----------------------
def _split_evenly(total, capacities):
    """Split ``total`` into per-bucket counts as evenly as possible, capped by ``capacities``.

    Any remainder goes to the earliest buckets. A bucket that would exceed
    its capacity is capped, and the excess is spread over buckets that still
    have room. Requires ``total <= sum(capacities)``.
    """
    counts = [0] * len(capacities)
    remaining = total
    while remaining > 0:
        open_slots = [i for i, cap in enumerate(capacities) if counts[i] < cap]
        base, extra = divmod(remaining, len(open_slots))
        for rank, i in enumerate(open_slots):
            take = min(base + (1 if rank < extra else 0), capacities[i] - counts[i])
            counts[i] += take
            remaining -= take
    return counts


def redraw_filters(net, percent: float, sigma: float = 0.05, seed=None):
    """
    Randomly re-draw ``percent`` % of all Conv filters (kernels) from N(0, sigma²), in place.

    The number of faulty filters is ``k = floor(total_filters * percent / 100)``.
    Here ``total_filters`` is the sum of dim 0 over every 4-D parameter in
    ``net`` (i.e. each Conv2d weight). ``k`` is split as evenly as possible
    across those layers, with the remainder going to the earliest layers. A
    layer never gets more faults than it has filters; any excess moves to
    layers that still have room.

    Args:
        net: module whose conv weights are overwritten in place.
        percent: share of all filters to corrupt, in [0, 100].
        sigma: standard deviation of the replacement weights.
        seed: seed of a private RNG used both to pick filters and to draw
            the noise. ``None`` uses the module-level ``RANDOM_SEED``. The
            global torch / ``random`` RNG state is not touched.

    Returns:
        One entry per conv weight, in ``net.parameters()`` order, holding
        the sorted indices of the filters that were re-drawn.

    Raises:
        ValueError: if ``percent`` is outside [0, 100].
    """
    if not 0 <= percent <= 100:
        raise ValueError(f'percent must be in [0, 100], got {percent}')
    if seed is None:
        seed = RANDOM_SEED

    conv_weights = [p for p in net.parameters() if p.ndim == 4]  # 4-D tensors
    sizes = [p.size(0) for p in conv_weights]
    # The epsilon stops float rounding (e.g. 6.9999999 for an exact 7) from
    # losing a filter; k can still never exceed sum(sizes).
    k = math.floor(sum(sizes) * percent / 100 + 1e-6)
    per_layer = _split_evenly(k, sizes)

    # A private generator makes the fault pattern reproducible for a given
    # seed without reseeding (and thereby perturbing) the global RNGs.
    gen = torch.Generator().manual_seed(seed)
    picked = []
    with torch.no_grad():
        for p, n_fault in zip(conv_weights, per_layer):
            if n_fault == 0:
                picked.append([])
                continue
            idx = torch.randperm(p.size(0), generator=gen)[:n_fault]
            noise = torch.randn((n_fault, *p.shape[1:]), generator=gen) * sigma
            p[idx.to(p.device)] = noise.to(device=p.device, dtype=p.dtype)
            picked.append(sorted(idx.tolist()))
    return picked

# Corrupt the (already evaluated) model in place, re-evaluate it on the test
# set, and report the accuracy lost relative to the baseline.
print(f'\nInjecting random-re-draw fault into {FAULT_PERCENT} % of all filters …')
redraw_filters(model, percent=FAULT_PERCENT)

faulty_acc = accuracy(model, loader=test_loader)
print(f'Post-fault accuracy: {faulty_acc:5.2f} %')
print(f'\nAccuracy drop: {base_acc - faulty_acc:5.2f} percentage points')


Epoch 1/5: 100%|██████████| 391/391 [00:20<00:00, 18.66it/s, loss=1.383]
Epoch 2/5: 100%|██████████| 391/391 [00:20<00:00, 19.07it/s, loss=1.084]
Epoch 3/5: 100%|██████████| 391/391 [00:20<00:00, 18.99it/s, loss=0.674]
Epoch 4/5: 100%|██████████| 391/391 [00:19<00:00, 19.80it/s, loss=0.786]
Epoch 5/5: 100%|██████████| 391/391 [00:21<00:00, 18.19it/s, loss=0.671]

Saved checkpoint to /content/conv4_cifar10.pt






Baseline accuracy: 72.95 %

Injecting random-re-draw fault into 10 % of all filters …
Post-fault accuracy: 35.28 %

Accuracy drop: 37.67 percentage points
