In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn
import matplotlib.pyplot as plt

from craft.craft_torch import Craft, torch_to_numpy
from data.dataloader import *
from data.dataprocess import *
from model.resnet import ResNet
from model.submodel import *
from model.method import *

In [None]:
# --- Configuration: everything a reader might want to tune lives here. ---
SEED = 0
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 10
LEARNING_RATE = 0.001
MOMENTUM = 0.9

# Seed before the model and data loaders are created so that weight initialisation,
# shuffling and any numpy-based randomness downstream are reproducible on re-run.
torch.manual_seed(SEED)
np.random.seed(SEED)

model = ResNet(n_class=10)
# Move the model to its device *before* building the optimizer: the torch.optim docs
# recommend this so the optimizer is constructed over the parameters that will be trained.
model = model.to(device=DEVICE)

# Reference torch.nn explicitly: `import torch.nn` binds only `torch`, so a bare `nn`
# would depend on a star import happening to re-export it.
CRITERION = torch.nn.CrossEntropyLoss()
OPTIMIZER = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

## Training

In [None]:
BATCH_SIZE = 100

# NOTE(review): the meaning of the third positional argument of `dataloader`
# is not visible here (shuffle? download?) — confirm in data/dataloader.py.
train_ld, test_ld = dataloader('CIFAR10', BATCH_SIZE, True)

# Class names; a name's position in this list is used as its class index below.
classes_name = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                'dog', 'frog', 'horse', 'ship', 'truck']

# Use EPOCHS from the configuration cell rather than a duplicated literal, so
# changing the constant actually changes the training length.
training(model=model,
         train_set=train_ld,
         criterion=CRITERION,
         optimizer=OPTIMIZER,
         device=DEVICE,
         epochs=EPOCHS)

evaluating(model=model,
           data_set=test_ld,
           criterion=CRITERION,
           device=DEVICE)

In [None]:
# Map every class name to the result of `separation(test_ld, class_index)`;
# presumably the test images belonging to that class — confirm in data/dataprocess.py.
img_by_cl = {}
for class_idx, class_label in enumerate(classes_name):
    img_by_cl[class_label] = separation(test_ld, class_idx)

## Concepts with CRAFT

In [None]:
# Split the trained network into the two halves CRAFT consumes:
#   g -> used as `input_to_latent`, h -> used as `latent_to_logit`.
model_halves = split(model=model)
g, h = model_halves

In [None]:
# Build the CRAFT concept extractor on top of the split network.
craft = Craft(input_to_latent=g,
              latent_to_logit=h,
              number_of_concepts=5,
              patch_size=12,
              batch_size=64,
              device=DEVICE)

# Fit the concepts on the airplane test images.
# CRAFT will (1) create the patches, (2) find the concepts,
# and (3) return the crops (crops), the embedding of the crops (crops_u), and the concept bank (w).
crops, crops_u, w = craft.fit(img_by_cl['airplane'])
# Move axis 1 to the last position — presumably channels-first -> channels-last
# so the crops can be displayed with matplotlib (TODO confirm layout).
crops = np.moveaxis(torch_to_numpy(crops), 1, -1)

crops.shape, crops_u.shape, w.shape

In [None]:
# Explain one class: derive its id from `classes_name` so the image set and the
# logit being explained cannot silently drift apart.
target_class = 'airplane'
target_class_id = classes_name.index(target_class)

importances = craft.estimate_importance(img_by_cl[target_class], class_id=target_class_id)

fig, ax = plt.subplots()
ax.bar(range(len(importances)), importances)
ax.set_xticks(range(len(importances)))
ax.set(title="Concept Importance", xlabel="Concept id", ylabel="Importance")
plt.show()  # render the figure before the printed ranking below

# Concept ids sorted by decreasing importance, keeping at most the top 5.
most_important_concepts = np.argsort(importances)[::-1][:5]

for c_id in most_important_concepts:
    print("Concept", c_id, " has an importance value of ", importances[c_id])

### Bank of concepts

## Adversarial Generation

### Projection

## SVM