In [38]:
import attack.fgsm
import models.lenet5_cifar10
import torch
from torch import nn, optim, utils
from torchvision import datasets, transforms
from matplotlib import pyplot

In [39]:
model = torch.jit.load("models/ensemble_scripted.pt")
model.eval()

RecursiveScriptModule(
  original_name=ensemble_lenet5
  (models): RecursiveScriptModule(
    original_name=ModuleList
    (0): RecursiveScriptModule(
      original_name=lenet5
      (conv1): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(original_name=Conv2d)
        (1): RecursiveScriptModule(original_name=ReLU)
        (2): RecursiveScriptModule(original_name=Conv2d)
      )
      (conv2): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(original_name=Conv2d)
        (1): RecursiveScriptModule(original_name=ReLU)
        (2): RecursiveScriptModule(original_name=Conv2d)
      )
      (conv3): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(original_name=Conv2d)
        (1): RecursiveScriptModule(original_name=ReLU)
        (2): RecursiveScriptModule(original_name=Conv2d)
      )
      (dropout): RecursiveScriptModule(original_name=Dropout)
      (fc1): Rec

In [40]:
class wrap_aleatoric(nn.Module):
  
  def __init__(self, model):
    super().__init__()
    self.model = model
  
  def forward(self, x):
    outputs = self.model(x)
    output_mean, output_std = torch.chunk(outputs, 2, dim=1)
    output_var = torch.square(output_std).reshape(-1, 5, 10).sum(dim=1) # [batch_size, 10]
    output_means = torch.reshape(output_mean, [-1, 5, 10])
    output_mean = torch.mean(output_means, dim=1)
    output_var = output_var + torch.var(output_means, dim=1)
    output_std = torch.sqrt(output_var)
    eps = torch.normal(0, 1, output_mean.shape, device='cuda')
    outputs = output_mean + eps * output_std
    return outputs

In [41]:
config = {"random_init":False, "eps":0.1}
new_model = wrap_aleatoric(model)
new_model = new_model.cuda()
fgsm = attack.fgsm.FGSM(new_model, config)

transform = transforms.ToTensor()
root = "./models/CIFAR10_DATASET"
train_dataset = datasets.CIFAR10(root, transform=transform, train=True, download=True)
train_dataloader = utils.data.DataLoader(train_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified


In [42]:
train_loss = 0.0
train_count = 0
adv_loss = 0.0
adv_count = 0
criterion = nn.CrossEntropyLoss()
valid_num_sample = 100
for data in train_dataloader:
  new_model.zero_grad()
  x, y = data
  x = x.cuda()
  y = y.cuda()
  output = new_model(x)
  train_loss += criterion(output, y).item()
  s_y = y.repeat(5)
  
  outputs = model(x)
  output_mean, output_std = torch.chunk(outputs, 2, dim=1)
  for _ in range(valid_num_sample):
    eps = torch.normal(0, 1, output_mean.shape, device='cuda')
    output_sample = output_mean + eps * output_std
    output_pred = torch.argmax(output_sample.detach(), dim=1)
    acc_count = torch.count_nonzero(output_pred == s_y).item() / (5 * valid_num_sample)
    train_count += acc_count
  
  x_adv = fgsm(x, y)
  adv_output = new_model(x_adv)
  adv_loss += criterion(adv_output, y).item()

  outputs = model(x_adv)
  output_mean, output_std = torch.chunk(outputs, 2, dim=1)
  for _ in range(valid_num_sample):
    eps = torch.normal(0, 1, output_mean.shape, device='cuda')
    output_sample = output_mean + eps * output_std
    output_pred = torch.argmax(output_sample.detach(), dim=1)
    acc_count = torch.count_nonzero(output_pred == s_y).item() / (5 * valid_num_sample)
    adv_count += acc_count
  
print(train_loss, adv_loss)
print(train_count / len(train_dataset), adv_count / len(train_dataset))

24786.13035964966 41292.71825027466
0.9521679200001744 0.675954119999956
