In [1]:
import requests
import torch
import torch.nn as nn
import os
from torchvision import models
import torch.optim as optim
from torch.utils.data import Dataset
from typing import Tuple
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import copy
import matplotlib as plt
import numpy as np



In [2]:
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])


In [3]:
class TaskDataset(Dataset):
    def __init__(self, images, labels):
        self.images = [transform(img) for img in images]
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

In [4]:
public_dataset = torch.load("./Train.pt")
dataset = TaskDataset(public_dataset.imgs, public_dataset.labels)

In [11]:
dataset_size = len(dataset)
test_size = int(0.075 * dataset_size)
train_size = dataset_size - test_size

In [12]:
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

In [13]:
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Train dataset size: 92500
Test dataset size: 7500


In [15]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [16]:
model_name = 'resnet50'

In [17]:
# Model
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)



In [18]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [19]:
checkpoint = torch.load(f"./out/models/{model_name}_pgd.pt", map_location=device)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [20]:
def PGD(net,x,y,alpha,epsilon,iter):
    delta = torch.zeros_like(x, requires_grad=True)
    for i in range(iter):
        criterion=nn.CrossEntropyLoss()
        loss = criterion(net(x + delta), y)
        loss.backward()
        delta.data = (delta + x.shape[0]*alpha*delta.grad.data).clamp(-epsilon,epsilon)
        delta.grad.zero_()
    pert = delta.detach()
    x_adv = x + pert
    h_adv = net(x_adv)
    _,y_adv = torch.max(h_adv.data,1)
    return x_adv, h_adv, y_adv, pert

In [21]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=3e-5, weight_decay=8e-4, momentum=0.8)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)

In [22]:
def train_pgd(net, alpha, epsilon, iter):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for _, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        x_adv,_,_,_ = PGD(net,images,labels,alpha,epsilon,iter)
        optimizer.zero_grad()
        outputs = net(x_adv)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return train_loss/len(train_loader)

In [23]:
def test(net):
    global acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(test_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            loss = criterion(outputs, labels)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    acc = 100 * correct / total
    return test_loss/len(test_loader)

In [24]:
import time

In [28]:
train_losses_pgd = []
test_losses_pgd = [] 
accuracy_list = []
epochs = 50
alpha = 0.01
epsilon = 0.1
iter = 3
patience = 3
best_loss = float('inf')
epochs_no_improve = 0

for epoch in range(epochs):
    start_time = time.time()
    train_loss = train_pgd(model, alpha, epsilon, iter)
    test_loss = test(model)

    train_losses_pgd.append(train_loss)
    test_losses_pgd.append(test_loss)
    accuracy_list.append(acc)

    scheduler.step()

    end_time = time.time()
    epoch_time = end_time - start_time
    print(f'Time taken for epoch {epoch+1}: {epoch_time:.2f} seconds')
    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Accuracy: {acc:.2f}%')

    if test_loss < best_loss:
        best_loss = test_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f'Early stopping at epoch {epoch+1}')
        epochs = epoch+1
        break


In [27]:
print('Accuracy of the network on the test images: %d %%' % (acc))

Accuracy of the network on the test images: 75 %


In [28]:
import matplotlib.pyplot as plt
import numpy as np


In [29]:
print(len(train_losses_pgd))

9


In [26]:
torch.save(model.state_dict(), f'out/models/{model_name}_pgd_incomplete.pt')

In [32]:
allowed_models = {
    "resnet18": models.resnet18,
    "resnet34": models.resnet34,
    "resnet50": models.resnet50,
}
with open("out/models/resnet50_pgd.pt", "rb") as f:
    try:
        model: torch.nn.Module = allowed_models["resnet50"](weights=None)
        model.fc = torch.nn.Linear(model.fc.weight.shape[1], 10)
    except Exception as e:
        raise Exception(
            f"Invalid model class, {e=}, only {allowed_models.keys()} are allowed",
        )
    try:
        state_dict = torch.load(f, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        out = model(torch.randn(1, 3, 32, 32))
    except Exception as e:
        raise Exception(f"Invalid model, {e=}")

    assert out.shape == (1, 10), "Invalid output shape"


In [27]:
response = requests.post("http://34.71.138.79:9090/robustness", files={"file": open("out/models/resnet50_pgd_incomplete.pt", "rb")}, headers={"token": "40034445", "model-name":model_name})
print(response.json())

{'clean_accuracy': 0.599, 'fgsm_accuracy': 0.15333333333333332, 'pgd_accuracy': 0.022}
