# Preprocessing

In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2

In [None]:
# Load the training table: an 'Image' column of space-separated pixel strings plus
# keypoint-coordinate columns (presumably Kaggle's "Facial Keypoints Detection" training.csv).
df = pd.read_csv('facial-keypoints-detection/training.csv')

In [None]:
# Targets: every column except 'Image' (keypoint coordinates; may contain NaN,
# which is imputed after the train/test split).
y_full = df.drop(columns=['Image']).values.astype(np.float32)
# Each 'Image' entry is a whitespace-separated string of pixel values; numpy parses
# the tokens directly, which is much faster than calling int() per pixel.
x_full = np.stack(
    [np.array(img_str.split(), dtype=np.float32) for img_str in df['Image']]
).reshape((-1, 96, 96))
# Upsample to 224x224 for ResNet. `interpolation` must be passed by keyword: the
# third positional parameter of cv2.resize is `dst`, so passing INTER_CUBIC
# positionally does not select bicubic interpolation.
x_full = np.array([cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) for img in x_full])

In [None]:
def plot(img, keypoint_sets):
    """Display one image with one or more keypoint sets overlaid.

    Args:
        img: 2-D image (shown with a gray colormap). Keypoints are rescaled from
            the 96-pixel coordinate frame to ``img.shape[0]`` pixels.
        keypoint_sets: iterable of flat keypoint vectors laid out as
            (x0, y0, x1, y1, ...). They are not modified.
    """
    scale = img.shape[0] / 96
    plt.imshow(img, cmap='gray')
    for keypoints in keypoint_sets:
        # Multiply out of place: `reshape` can return a view, so an in-place `*=`
        # would rescale the caller's keypoint array.
        xy = keypoints.reshape((-1, 2)) * scale
        plt.scatter(xy[:, 0], xy[:, 1])
    plt.axis('off')
    plt.show()

In [None]:
def multi_plot(imgs, key_point_sets, cols=4):
    """Show images in a grid (``cols`` per row) with their keypoints overlaid.

    Args:
        imgs: array of 2-D images, shape (n, H, W). Keypoints are rescaled from
            the 96-pixel coordinate frame to ``H`` pixels.
        key_point_sets: per-image flat keypoint vectors (x0, y0, x1, y1, ...).
            They are not modified.
        cols: number of subplot columns.
    """
    n = len(imgs)
    rows = (n + cols - 1) // cols  # ceiling division
    for i in range(n):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(imgs[i], cmap='gray')
        plt.axis('off')
        # Multiply out of place: the caller's keypoints are often a view into the
        # label array (e.g. y_full[:12]), and an in-place `*=` on the reshape view
        # would rescale those labels.
        key_points = key_point_sets[i].reshape((-1, 2)) * (imgs.shape[1] / 96)
        plt.scatter(key_points[:, 0], key_points[:, 1], s=8, c='lime')
    plt.tight_layout()
    plt.show()

In [None]:
# Pass a copy so plotting can never modify the training labels held in y_full.
multi_plot(x_full[:12], y_full[:12].copy())

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.impute import KNNImputer

# Hold out a test set (12.5%) first, then carve a validation set (15%) out of the rest.
x_train_full, x_test, y_train_full, y_test = train_test_split(
    x_full,
    y_full,
    test_size=0.125,
    shuffle=True,
    random_state=0,
)

x_train, x_valid, y_train, y_valid = train_test_split(
    x_train_full,
    y_train_full,
    test_size=0.15,
    shuffle=True,
    random_state=0,
)

# Fill missing keypoints (NaN) with KNN imputation. The imputer is fit on the
# training split only, so validation targets cannot influence the imputed training
# labels; the validation split is then filled by the same fitted imputer.
# y_train_full and y_test are left unimputed (they may still contain NaN).
imputer = KNNImputer()
y_train = imputer.fit_transform(y_train)
y_valid = imputer.transform(y_valid)

print(f'{len(x_train)} train, {len(x_valid)} valid, {len(x_test)} test')

# Default Model

In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import transforms

# Prefer the Apple-silicon GPU (MPS), then CUDA, then CPU, so the notebook also
# runs on machines without MPS.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
def np2torch(x, device=DEVICE):
    """Wrap NumPy array ``x`` as a torch tensor and place it on ``device``."""
    as_tensor = torch.from_numpy(x)
    return as_tensor.to(device)


def torch2np(x):
    """Turn tensor ``x`` into a NumPy array: detach from autograd, move to the CPU, convert."""
    on_cpu = x.detach().to('cpu')
    return on_cpu.numpy()


def transform_imgs(x, device=DEVICE, max_pixel_value=255.0):
    """Convert raw grayscale images into normalized 3-channel ResNet input.

    Args:
        x: float NumPy array of 224x224 images holding raw intensities (assumed
            8-bit, i.e. roughly 0-255, as parsed from the CSV; confirm). It is
            reshaped to (N, 1, 224, 224).
        device: device the result is moved to.
        max_pixel_value: intensity that maps to 1.0 before normalization. The
            ImageNet mean/std below are defined for inputs in [0, 1], so raw
            pixels are divided by this first. Pass 1.0 to reproduce the
            unscaled normalization (e.g. for checkpoints trained that way).

    Returns:
        Tensor of shape (N, 3, 224, 224) on ``device``.
    """
    imgs = torch.from_numpy(x).reshape((-1, 1, 224, 224)) / max_pixel_value
    # Replicate the single gray channel into the 3 channels ResNet expects.
    imgs = torch.repeat_interleave(imgs, 3, dim=1)
    imgs = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(imgs)
    return imgs.to(device)

In [None]:
from torchvision.models import resnet18, ResNet18_Weights
import ssl

# SECURITY WARNING: this replaces the ssl module's default HTTPS context factory
# with one that skips certificate verification, for the rest of the kernel session
# (the global opt-out described in PEP 476). Presumably added so the ImageNet
# weight download in get_resnet() works on a machine with a broken CA store.
# TODO: fix the local certificate store instead and remove this line.
ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
def get_resnet(weights_path=None, frozen=True, n_outputs=30):
    """Build a ResNet-18 whose final ``fc`` layer regresses ``n_outputs`` values.

    Args:
        weights_path: optional path to a state_dict saved from a model built by
            this function. When given, the ImageNet weights are not downloaded.
        frozen: if True, only the final ``fc`` layer is trainable.
        n_outputs: size of the regression head (30 = 15 keypoints x (x, y)).

    Returns:
        The model on the CPU; callers in this notebook move it with ``.to(DEVICE)``.
    """
    if weights_path:
        model = resnet18()
    else:
        model = resnet18(weights=ResNet18_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, n_outputs)
    if weights_path:
        # map_location='cpu' lets checkpoints saved on MPS/CUDA load on any machine.
        model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    if frozen:
        for name, param in model.named_parameters():
            param.requires_grad = name.startswith('fc.')
    return model

In [None]:
def batch_iterate(x, y, batch_size, device=DEVICE):
    """Yield shuffled ``(images, targets)`` mini-batches as tensors on ``device``.

    The order comes from ``np.random.permutation``, so it depends on NumPy's
    global RNG state. The last batch may be smaller than ``batch_size``.

    Args:
        x: images accepted by ``transform_imgs`` (indexed along axis 0).
        y: target array with one row per image.
        batch_size: maximum number of samples per batch.
        device: device both images and targets are moved to.
    """
    order = np.random.permutation(y.shape[0])
    for start in range(0, len(order), batch_size):
        idxs = order[start:start + batch_size]
        # Forward `device` to transform_imgs too, so images and targets share a device.
        yield transform_imgs(x[idxs], device), np2torch(y[idxs], device)

In [None]:
def evaluate(model, x, y, max_batches=0, device=DEVICE):
    """Print and return the mean MSE loss and R^2 of ``model`` on ``(x, y)``.

    Batches of 100 come from ``batch_iterate`` (shuffled), so with
    ``max_batches`` > 0 the metrics describe a random subset.

    R^2 is computed per batch as ``1 - MSE / Var(y)``, with the variance taken
    around the batch's per-output mean. Both metrics are averaged over batches,
    weighted by batch size.

    The model's train/eval mode is restored afterwards, so calling this from a
    training loop does not leave the model in eval mode.

    Args:
        model: network mapping transformed images to target vectors.
        x: images accepted by ``transform_imgs``.
        y: target array with one row per image.
        max_batches: stop after this many batches; 0 means use all of them.
        device: device the batches are moved to.

    Returns:
        ``(loss, r2)``; both are NaN when no batch was evaluated.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    r2_sum = 0.0
    n_samples = 0
    n_batches = 0
    try:
        with torch.no_grad():  # metrics only; no autograd graph needed
            for xb, yb in batch_iterate(x, y, batch_size=100, device=device):
                y_pred = model(xb)
                loss = F.mse_loss(y_pred, yb).item()
                var = torch.mean(torch.square(yb - torch.mean(yb, dim=0))).item()
                # A batch with zero target variance (e.g. a single sample) has no defined R^2.
                r2 = 1 - loss / var if var > 0 else float('nan')
                batch_n = yb.shape[0]
                loss_sum += loss * batch_n
                r2_sum += r2 * batch_n
                n_samples += batch_n
                n_batches += 1
                if max_batches and n_batches >= max_batches:
                    break
    finally:
        model.train(was_training)
    mean_loss = loss_sum / n_samples if n_samples else float('nan')
    mean_r2 = r2_sum / n_samples if n_samples else float('nan')
    print(f'loss: {mean_loss:.3f}, R^2: {mean_r2:.3f}')
    return mean_loss, mean_r2

In [None]:
# Fully fine-tuned ResNet-18; the learning rate decays by 1% after every epoch
# (ExponentialLR(gamma) is the same schedule as StepLR(step_size=1, gamma)).
model = get_resnet(frozen=False).to(DEVICE)
opt = optim.Adam(params=model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)

In [None]:
for epoch in range(100):
    # Re-enter training mode every epoch: the periodic evaluate() calls below
    # may switch the model to eval mode, which would otherwise leave BatchNorm in
    # inference behaviour for the rest of training.
    model.train()
    for x, y in batch_iterate(x_train, y_train, batch_size=100):
        opt.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
        opt.step()

    scheduler.step()
    if (epoch + 1) % 5 == 0:
        print(f'[epoch {epoch + 1}]')
        # Quick estimates on 5 random batches each.
        print('    train: ', end='')
        evaluate(model, x_train, y_train, max_batches=5)
        print('    valid: ', end='')
        evaluate(model, x_valid, y_valid, max_batches=5)

# Make sure the checkpoint directory exists so 100 epochs of work are not lost at save time.
os.makedirs('checkpoints', exist_ok=True)
torch.save(model.state_dict(), 'checkpoints/resnet1.pth')

In [None]:
# Metrics over the whole validation set (max_batches defaults to 0 = all batches).
evaluate(model, x=x_valid, y=y_valid)

In [None]:
# Set eval mode explicitly rather than depending on the mode earlier cells left
# the model in; only predictions are needed, so skip autograd.
model.eval()
with torch.no_grad():
    # Clip to the 96-pixel keypoint frame (0..95) before plotting.
    y_pred = torch2np(model(transform_imgs(x_full[:12])).clip(0, 95))
multi_plot(x_full[:12], y_pred[:12])

# Attacking the Model

In [None]:
def fgsm(model, x, y, eps):
    """Fast Gradient Sign Method perturbation for a regression model.

    Returns ``eps * sign(dL/d delta)`` at ``delta = 0``, where
    ``L = MSE(model(x + delta), y)``. This is a perturbation with every element
    in ``{-eps, 0, eps}`` that locally increases the loss.

    Only the input gradient is computed (``torch.autograd.grad``), so the
    model's parameter ``.grad`` buffers are left untouched; ``loss.backward()``
    would accumulate into them.

    Args:
        model: callable mapping an input batch to predictions.
        x: input batch; the result has the same shape.
        y: targets for the MSE loss.
        eps: perturbation magnitude.
    """
    delta = torch.zeros_like(x, requires_grad=True)
    loss = F.mse_loss(model(x + delta), y)
    (grad,) = torch.autograd.grad(loss, delta)
    return eps * torch.sign(grad)

In [None]:
def pgd(model, x, y, eps, alpha=0.1, n_iters=100, signed=False):
    """Projected gradient ascent on the MSE loss inside an L-infinity ball.

    Starting from ``delta = 0``, repeats ``n_iters`` times:
    1. Step ``delta`` by ``alpha`` times the loss gradient w.r.t. the input, or
       by ``alpha`` times its sign when ``signed=True`` (the usual L-infinity
       PGD step).
    2. Clamp every element of ``delta`` to ``[-eps, eps]``.

    Only the input gradient is computed (``torch.autograd.grad``), so the
    model's parameter ``.grad`` buffers are not modified across iterations.

    Args:
        model: callable mapping an input batch to predictions.
        x: input batch; the result has the same shape.
        y: targets for the MSE loss.
        eps: maximum absolute value of each perturbation element.
        alpha: step size.
        n_iters: number of ascent steps.
        signed: use ``sign(grad)`` instead of the raw gradient for each step.

    Returns:
        Detached perturbation tensor shaped like ``x``.
    """
    delta = torch.zeros_like(x)
    for _ in range(n_iters):
        delta.requires_grad_(True)
        loss = F.mse_loss(model(x + delta), y)
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            step = torch.sign(grad) if signed else grad
            delta = (delta + alpha * step).clamp(-eps, eps)
    return delta.detach()

In [None]:
def evaluate_fgsm(model, x, y, eps, device=DEVICE):
    """Print and return loss and R^2 of ``model`` on FGSM-perturbed inputs.

    Each batch of 100 from ``batch_iterate`` is shifted by
    ``fgsm(model, batch, targets, eps)`` before prediction. Metrics are computed
    as in ``evaluate``: per-batch MSE and ``1 - MSE / Var(y)``, averaged over
    batches weighted by batch size. The model's train/eval mode is restored
    afterwards.

    Args:
        model: network mapping transformed images to target vectors.
        x: images accepted by ``transform_imgs``.
        y: target array with one row per image.
        eps: FGSM perturbation magnitude (in normalized-input units).
        device: device the batches are moved to.

    Returns:
        ``(loss, r2)``; both are NaN when no batch was evaluated.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    r2_sum = 0.0
    n_samples = 0
    try:
        for xb, yb in batch_iterate(x, y, batch_size=100, device=device):
            # The attack needs input gradients, so it runs with autograd enabled.
            x_adv = xb + fgsm(model, xb, yb, eps)
            with torch.no_grad():
                y_pred = model(x_adv)
                loss = F.mse_loss(y_pred, yb).item()
                var = torch.mean(torch.square(yb - torch.mean(yb, dim=0))).item()
            r2 = 1 - loss / var if var > 0 else float('nan')
            batch_n = yb.shape[0]
            loss_sum += loss * batch_n
            r2_sum += r2 * batch_n
            n_samples += batch_n
    finally:
        model.train(was_training)
    mean_loss = loss_sum / n_samples if n_samples else float('nan')
    mean_r2 = r2_sum / n_samples if n_samples else float('nan')
    print(f'loss: {mean_loss:.3f}, R^2: {mean_r2:.3f}')
    return mean_loss, mean_r2

In [None]:
def evaluate_pgd(model, x, y, eps, device=DEVICE):
    """Print and return loss and R^2 of ``model`` on PGD-perturbed inputs.

    Each batch of 100 from ``batch_iterate`` is shifted by
    ``pgd(model, batch, targets, eps)`` (default step size and iteration count)
    before prediction. Metrics are computed as in ``evaluate``: per-batch MSE and
    ``1 - MSE / Var(y)``, averaged over batches weighted by batch size. The
    model's train/eval mode is restored afterwards.

    Args:
        model: network mapping transformed images to target vectors.
        x: images accepted by ``transform_imgs``.
        y: target array with one row per image.
        eps: L-infinity bound of the PGD perturbation (normalized-input units).
        device: device the batches are moved to.

    Returns:
        ``(loss, r2)``; both are NaN when no batch was evaluated.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    r2_sum = 0.0
    n_samples = 0
    try:
        for xb, yb in batch_iterate(x, y, batch_size=100, device=device):
            # The attack needs input gradients, so it runs with autograd enabled.
            x_adv = xb + pgd(model, xb, yb, eps)
            with torch.no_grad():
                y_pred = model(x_adv)
                loss = F.mse_loss(y_pred, yb).item()
                var = torch.mean(torch.square(yb - torch.mean(yb, dim=0))).item()
            r2 = 1 - loss / var if var > 0 else float('nan')
            batch_n = yb.shape[0]
            loss_sum += loss * batch_n
            r2_sum += r2 * batch_n
            n_samples += batch_n
    finally:
        model.train(was_training)
    mean_loss = loss_sum / n_samples if n_samples else float('nan')
    mean_r2 = r2_sum / n_samples if n_samples else float('nan')
    print(f'loss: {mean_loss:.3f}, R^2: {mean_r2:.3f}')
    return mean_loss, mean_r2

In [None]:
model.eval()  # explicit: don't depend on the mode earlier cells left the model in
x_demo = transform_imgs(x_full[:12])
# NOTE(review): assumes rows 0-11 have no missing keypoints; a NaN target would
# make the attack gradient NaN. TODO confirm.
delta = fgsm(model, x_demo, np2torch(y_full[:12]), 0.001)
with torch.no_grad():
    y_pred = torch2np(model(x_demo + delta).clip(0, 95))
multi_plot(x_full[:12], y_pred[:12])

In [None]:
model.eval()  # explicit: don't depend on the mode earlier cells left the model in
x_demo = transform_imgs(x_full[:12])
# NOTE(review): assumes rows 0-11 have no missing keypoints; a NaN target would
# make the attack gradient NaN. TODO confirm.
delta = pgd(model, x_demo, np2torch(y_full[:12]), 0.001)
with torch.no_grad():
    y_pred = torch2np(model(x_demo + delta).clip(0, 95))
multi_plot(x_full[:12], y_pred[:12])

In [None]:
# Validation metrics for the default model: clean, then under FGSM and PGD (eps = 0.01).
print('no attack:'); evaluate(model, x_valid, y_valid)
print('fgsm attack:'); evaluate_fgsm(model, x_valid, y_valid, eps=0.01)
print('pgd attack:'); evaluate_pgd(model, x_valid, y_valid, eps=0.01)

# Robust Model

In [None]:
# Fresh, fully trainable ResNet-18 for adversarial training. Note that Adam's
# weight_decay adds an L2 term to the gradient (it is not AdamW's decoupled decay).
# ExponentialLR(gamma) gives the same per-epoch decay as StepLR(step_size=1, gamma).
robust_model = get_resnet(frozen=False).to(DEVICE)
opt = optim.Adam(params=robust_model.parameters(), lr=0.001, weight_decay=0.1)
scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)

In [None]:
# FGSM adversarial training. For every batch, take one optimizer step on the clean
# inputs, then one on inputs shifted by eps * sign(d loss / d input). The input
# gradient comes from the clean step's backward pass. eps starts at 0.005 and
# grows by 0.0001 per epoch.
robust_model.train()

eps = 0.005
batch_size = 100

for epoch in range(100):
    if (epoch + 1) % 10 == 0:
        print('epoch:', epoch + 1)

    for xb, yb in batch_iterate(x_train, y_train, batch_size=batch_size):
        # Track gradients w.r.t. the batch itself rather than a fixed-size
        # (100, 3, 224, 224) delta buffer, so the final partial batch is used too.
        xb.requires_grad_(True)

        opt.zero_grad()
        loss = F.mse_loss(robust_model(xb), yb)
        loss.backward()
        opt.step()

        with torch.no_grad():
            x_adv = xb + eps * xb.grad.sign()

        opt.zero_grad()
        loss = F.mse_loss(robust_model(x_adv), yb)
        loss.backward()
        opt.step()

    eps += 0.0001
    scheduler.step()

# Make sure the checkpoint directory exists so the training run is not lost at save time.
os.makedirs('checkpoints', exist_ok=True)
torch.save(robust_model.state_dict(), 'checkpoints/robust_model2.pth')

In [None]:
# Validation metrics for the adversarially trained model: clean, then under FGSM and PGD (eps = 0.01).
print('no attack:'); evaluate(robust_model, x_valid, y_valid)
print('fgsm attack:'); evaluate_fgsm(robust_model, x_valid, y_valid, eps=0.01)
print('pgd attack:'); evaluate_pgd(robust_model, x_valid, y_valid, eps=0.01)

In [None]:
robust_model.eval()  # explicit: don't depend on the mode earlier cells left the model in
x_demo = transform_imgs(x_full[:12])
# NOTE(review): assumes rows 0-11 have no missing keypoints; a NaN target would
# make the attack gradient NaN. TODO confirm.
delta = fgsm(robust_model, x_demo, np2torch(y_full[:12]), 0.01)
with torch.no_grad():
    y_pred = torch2np(robust_model(x_demo + delta).clip(0, 95))
# delta has shape (12, 3, 224, 224); take channel 0 so it lines up with the
# (12, 224, 224) images. It is in normalized-input units, not raw pixel units.
multi_plot(x_full[:12] + torch2np(delta)[:, 0], y_pred[:12])

In [None]:
robust_model.eval()  # explicit: don't depend on the mode earlier cells left the model in
x_demo = transform_imgs(x_full[:12])
# NOTE(review): assumes rows 0-11 have no missing keypoints; a NaN target would
# make the attack gradient NaN. TODO confirm.
delta = pgd(robust_model, x_demo, np2torch(y_full[:12]), 0.01)
with torch.no_grad():
    y_pred = torch2np(robust_model(x_demo + delta).clip(0, 95))
# delta has shape (12, 3, 224, 224); take channel 0 so it lines up with the
# (12, 224, 224) images. It is in normalized-input units, not raw pixel units.
multi_plot(x_full[:12] + torch2np(delta)[:, 0], y_pred[:12])