In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn
import matplotlib.pyplot as plt

from craft.craft_torch import Craft, torch_to_numpy
from art.attacks.evasion import FastGradientMethod
from art.estimators.classification import PyTorchClassifier
from data.dataloader import *
from data.dataprocess import *
from model.resnet import ResNet
from model.submodel import *
from model.method import *

In [None]:
# Training configuration -- everything a reader might want to tune lives here.
SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 10
LEARNING_RATE = 0.001
MOMENTUM = 0.9

# Seed every RNG involved so weight init / shuffling are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

# 10 output classes: one per CIFAR-10 label.
model = ResNet(n_class=10)
# Move the model *before* building the optimizer, as recommended by the
# torch.optim docs, so the optimizer references parameters in their final placement.
model.to(device=DEVICE)

# `import torch.nn` binds `torch`, not `nn`, so go through the full module path.
CRITERION = torch.nn.CrossEntropyLoss()
OPTIMIZER = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

## Training

In [None]:
# CIFAR-10 train/test loaders with batch size 100.
# NOTE(review): the meaning of the third positional argument (True) is defined in
# data.dataloader -- confirm (shuffle? download?) and pass it by keyword if possible.
train_ld, test_ld = dataloader('CIFAR10', 100, True)

# Label names in CIFAR-10 index order; classes_name[i] is treated as class id i below.
classes_name = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                'dog', 'frog', 'horse', 'ship', 'truck']

# Use the EPOCHS constant from the config cell so the epoch count has one source of truth.
training(model=model,
         train_set=train_ld,
         criterion=CRITERION,
         optimizer=OPTIMIZER,
         device=DEVICE,
         epochs=EPOCHS)

evaluating(model=model,
           data_set=test_ld,
           criterion=CRITERION,
           device=DEVICE)

In [None]:
# Map each class name to the test images returned by `separation` for its index
# (enumerate order of classes_name, so index i <-> classes_name[i]).
img_by_cl = dict(
    (class_name, separation(test_ld, class_idx))
    for class_idx, class_name in enumerate(classes_name)
)

In [None]:
# Materialise the loaders into full arrays (missing comma in the original
# unpacking made this cell a SyntaxError).
# NOTE(review): the extra `True` for the test split is defined in get_tensors;
# the ART cell below arg-maxes y_test per row, so presumably it yields one-hot labels -- confirm.
x_train, y_train = get_tensors(train_ld)
x_test, y_test = get_tensors(test_ld, True)

x_train.shape, x_test.shape

## Concepts with CRAFT

In [None]:
# Cut the trained network in two; the halves are passed to Craft below as
# input_to_latent (g) and latent_to_logit (h).
(g, h) = split(model=model)

In [None]:
craft = Craft(input_to_latent=g,
              latent_to_logit=h,
              number_of_concepts=5,
              patch_size=12,
              batch_size=64,
              device=DEVICE)

# Fit CRAFT once per class. Each `craft.fit` call (1) creates patches from the
# images, (2) finds the concepts, and (3) returns the crops, the embeddings of the
# crops (crops_u) and the concept bank (w). Only the concept bank is kept:
# H[i] is the concept bank of classes_name[i].
# NOTE(review): the same `craft` object is refit in every iteration, so after this
# loop it presumably holds only the concepts of the last class ('truck').
H = []
for name in classes_name:
    _, _, w = craft.fit(img_by_cl[name])
    H.append(w)

### Concepts visualization

In [None]:
for class_id, name in enumerate(classes_name):
    # The shared `craft` object keeps whatever concepts were fitted last, so
    # refit on this class before measuring importances; otherwise every class
    # is scored against another class's concept bank.
    craft.fit(img_by_cl[name])
    # Use this class's images and label id (the original always used 'airplane' / 0).
    importances = craft.estimate_importance(img_by_cl[name], class_id=class_id)

    # One figure per class instead of overlaying every class on the same axes.
    fig, ax = plt.subplots()
    ax.bar(range(len(importances)), importances)
    ax.set_xticks(range(len(importances)))
    ax.set_xlabel("Concept id")
    ax.set_ylabel("Importance")
    ax.set_title(f"Concept Importance - {name}")
    plt.show()

    # Concept ids sorted by decreasing importance (top 5).
    most_important_concepts = np.argsort(importances)[::-1][:5]
    for c_id in most_important_concepts:
        print(f"[{name}] Concept", c_id, " has an importance value of ", importances[c_id])

### Bank of concepts

In [None]:
# Bank of concepts: join the per-class concept banks collected in H along the
# first axis, in the order they were appended (classes_name order).
HBD = np.concatenate(tuple(H), axis=0)
HBD.shape

## Adversarial Generation

In [None]:
"""
The script demonstrates a simple example of using ART with PyTorch. The example train a small model on the MNIST dataset
and creates adversarial examples using the Fast Gradient Sign Method. Here we use the ART classifier to train the model,
it would also be possible to provide a pretrained model to the ART classifier.
The parameters are chosen for reduced computational requirements of the script and not optimised for accuracy.
"""

# Step 1a: Swap axes to PyTorch's NCHW format

x_test = x_test.astype(np.float32)

# Step 3: Create the ART classifier

classifier = PyTorchClassifier(
    model=model,
    clip_values=(0, 1),
    loss=CRITERION,
    optimizer=OPTIMIZER,
    input_shape=(3, 32, 32),
    nb_classes=10,
)

# Step 5: Evaluate the ART classifier on benign test examples

predictions = classifier.predict(x_test)
accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)
print("Accuracy on benign test examples: {}%".format(accuracy * 100))

# Step 6: Generate adversarial test examples
attack = FastGradientMethod(estimator=classifier, eps=0.2)
x_test_adv = attack.generate(x=x_test)

# Step 7: Evaluate the ART classifier on adversarial test examples

predictions = classifier.predict(x_test_adv)
accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)
print("Accuracy on adversarial test examples: {}%".format(accuracy * 100))

### Projection

## SVM