### Setup
Before interacting with ProtoRSet, we run through a few setup steps:


#### 1: If you are running in Google Colab, run the following cell to mount your Drive and point this notebook to the correct directory


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')
# path_to_repo = "/content/drive/path/to/code"
# !cp -r "$path_to_repo"/* ./
# !pip install -r ./env/requirements-collab.txt
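
The cell above is commented out so that it does nothing outside Colab. As a minimal sketch, one way to decide at runtime whether the Colab-only steps apply is to check whether the `google.colab` package can be found. The helper name `running_in_colab` is hypothetical and not part of this repository.

```python
import importlib.util


def running_in_colab() -> bool:
    """Return True when the google.colab package can be found (hypothetical helper)."""
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        # Raised when the parent package `google` is not installed at all.
        return False


IN_COLAB = running_in_colab()
print(f"Running in Colab: {IN_COLAB}")
```

If `IN_COLAB` is true, the commented-out lines above can be uncommented and run as they are.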

#### 2: Prepare a dataset
In this demo, we'll use the CUB-200 image classification dataset. The following cell downloads, unzips, and splits the dataset. Note that this step may take a few minutes.

In [None]:
# Download and extract dataset
!wget https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz
!tar -xvzf ./CUB_200_2011.tgz
!python -m protopnet create-splits ./CUB_200_2011
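
Because the download and extraction take a few minutes, it can help to skip extraction when the dataset folder is already present. The following is a sketch using only the standard library. The helper `extract_once` is hypothetical, and it assumes the archive's top-level folder has the same name as the target directory.

```python
import tarfile
from pathlib import Path


def extract_once(archive_path, target_dir):
    """Extract a .tgz archive unless target_dir already exists.

    Returns True if extraction ran and False if it was skipped.
    Assumes the archive unpacks into a folder named like target_dir.
    """
    archive_path = Path(archive_path)
    target_dir = Path(target_dir)
    if target_dir.exists():
        return False
    # Use the safer "data" extraction filter where this Python version supports it.
    kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    with tarfile.open(archive_path, "r:gz") as tar:
        # Extract into the parent of target_dir, like `tar -xvzf` run from that folder.
        tar.extractall(path=target_dir.parent, **kwargs)
    return True


# Hypothetical usage with the paths from the cell above:
# extract_once("./CUB_200_2011.tgz", "./CUB_200_2011")
```

This only replaces the `tar` step. The `create-splits` command from the cell above still needs to be run afterwards.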

#### 3: Prepare a reference ProtoPNet
The following cell contains the code needed to train your own reference ProtoPNet from scratch. Because training a neural network typically takes several hours, the uncommented code instead downloads a trained ProtoPNet that we have provided for this demo.

In [None]:
# The following lines download a trained reference ProtoPNet
!pip install gdown
!gdown "https://drive.google.com/uc?id=1c79gyWC4I3J1FxCPKV6wpY-Uqdr16ocU"

# # Uncomment and run the following lines to train a reference ProtoPNet
# from protopnet.train_vanilla_cosine import run
# import torch

# ppn = run(
#     dataset="CUB200"
# )

# torch.save(ppn, "./resnet50_cub200_ref_protopnet.pth")
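
Before moving on, it may be worth confirming that the checkpoint file is actually on disk, since a failed download would otherwise only surface later. This is a sketch with a hypothetical helper `require_file`. The file name in the usage comment is an assumption based on the path used later in this notebook.

```python
from pathlib import Path


def require_file(path, hint=""):
    """Return path as a Path if it is an existing file, else raise FileNotFoundError."""
    path = Path(path)
    if not path.is_file():
        message = f"Expected file not found: {path}. {hint}".strip()
        raise FileNotFoundError(message)
    return path


# Hypothetical usage (file name assumed from the factory cell below):
# require_file(
#     "./resnet50_cub200_ref_protopnet.pth",
#     hint="Re-run the download cell or train a reference model.",
# )
```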

#### 4: Prepare your ProtoRSet

First we prepare dataloaders for the CUB dataset.

In [None]:
import os
from PIL import Image
from IPython.display import display
os.environ["CUB200_DIR"] = "./CUB_200_2011/"
from protopnet.datasets import training_dataloaders

batch_sizes = {"train": 20, "project": 20, "val": 20}
split_dataloaders = training_dataloaders(
    "CUB200",
    data_path=os.environ["CUB200_DIR"],
    batch_sizes=batch_sizes
)

Now it's finally time to fit our ProtoRSet! The following cell initializes and fits a ProtoRSet.

In [None]:
from rashomon_sets.protorset_factory import ProtoRSetFactory
from pathlib import Path
import torch

RSET_ARGS = {
    "rashomon_bound_multiplier": 1.1, # The ratio of the maximum allowable loss to the minimum
    "lam": 0.0001, # The weight to apply to our L2 regularization on the last layer
    "max_iter": 5000,  # The max number of iterations allowed when fitting the optimal regression
    "device": torch.device("cuda"), # The device to use
    "lr_for_opt": 1.0 # The learning rate to use when fitting the optimal regression
}

# The rset_factory object provides the main access point to interact with
# a ProtoRSet, including the ability to produce a ProtoPNet object
rset_factory = ProtoRSetFactory(
    split_dataloaders=split_dataloaders,
    initial_protopnet_path=Path("./resnet50_cub200_ref_protopnet.pth"),
    rashomon_set_args=RSET_ARGS,
    device="cuda",
    reproject=False, # If True, perform a projection step after loading the reference ProtoPNet
    verbose=False,
    analysis_save_dir = Path(
        "./visualizations/"  # This is where images used for visualization will be saved
    ), 
)
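
To make `rashomon_bound_multiplier` concrete, here is a small worked sketch that reads the comment above literally: the largest allowable loss is the multiplier times the smallest achievable loss. The function names are hypothetical, and how the library applies the multiplier internally (for example, whether the regularization term counts toward the loss) is an assumption not confirmed by this notebook.

```python
def rashomon_loss_bound(min_loss: float, multiplier: float) -> float:
    """Largest allowable loss, taken as multiplier * min_loss (assumed reading)."""
    if min_loss < 0:
        raise ValueError("min_loss must be non-negative")
    if multiplier < 1:
        raise ValueError("multiplier must be at least 1")
    return multiplier * min_loss


def in_rashomon_set(loss: float, min_loss: float, multiplier: float) -> bool:
    """True if a model with this loss falls within the bound."""
    return loss <= rashomon_loss_bound(min_loss, multiplier)


# With an illustrative optimal loss of 0.50 and the multiplier 1.1 used above,
# models with loss up to about 0.55 would be admitted.
bound = rashomon_loss_bound(min_loss=0.50, multiplier=1.1)
print(f"Loss bound: {bound:.3f}")
print(in_rashomon_set(0.54, 0.50, 1.1), in_rashomon_set(0.56, 0.50, 1.1))
```

Under this reading, a larger multiplier admits more models and so leaves more room to satisfy later constraints, at the cost of a looser guarantee on training loss.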

We then precompute some self-activations for our prototypes to enable fast, repeated interactions and visualization:

In [None]:
res_path = rset_factory.display_local_analysis(
    49, # The image to visualize
    run_proto_vis=True, # WARNING: Set this to False if running repeatedly. If true, this will save self-activations for every prototype.
    include_coef=True, # Whether to include the last layer coefficient in the visualization
    sort_using_logit=False # If True, show the 3 prototypes with the highest logit for any class; if False, instead sort by prototype activation
)

#### 5: Examine and interact with available ProtoPNets
We're now ready to start playing with a ProtoRSet. The following cell runs a local analysis on the specified image index, using the ProtoPNet that is optimal on the training set while meeting all current constraints.

In [None]:
res_path = rset_factory.display_local_analysis(
    49, # The image to visualize
    run_proto_vis=False, # WARNING: Set this to False if running repeatedly. If true, this will save self-activations for every prototype.
    include_coef=True, # Whether to include the last layer coefficient in the visualization
    sort_using_logit=False # If True, show the 3 prototypes with the highest logit for any class; if False, instead sort by prototype activation
)
img = Image.open(res_path[0])
display(img)

Using this visualization, we can identify prototypes we do or do not like. Say we think prototype 108 is confounded. We can then remove this prototype as follows:

In [None]:
target_proto = 108
# Check the accuracy of our best model before adding this constraint
print(f"Best validation accuracy before constraint: {rset_factory._best_val_acc().item()}")
# And check the coefficient on the target prototype in our optimal
# model before adding the constraint
pre_removal_coef = rset_factory.produce_protopnet_object().prototype_prediction_head.class_connection_layer.weight[:, target_proto].max()
print(f"Coefficient before constraint: {pre_removal_coef}")

In [None]:
successfully_removed = rset_factory.require_to_avoid_prototype(target_proto)
print("Successful removal!" if successfully_removed else "Cannot remove this prototype.")

In [None]:
# Check the accuracy of our best model after adding this constraint
print(f"Best validation accuracy after constraint: {rset_factory._best_val_acc().item()}")
# And check the coefficient on the target prototype in our optimal
# model after adding the constraint
post_removal_coef = rset_factory.produce_protopnet_object().prototype_prediction_head.class_connection_layer.weight[:, target_proto].max()
print(f"Coefficient after constraint: {post_removal_coef}")
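
The expression `weight[:, target_proto].max()` used above takes the largest coefficient in one prototype's column. The toy sketch below shows that operation on a plain list of lists. It assumes rows are classes and columns are prototypes, which matches the indexing above, and the helper names and numbers are made up for illustration.

```python
def max_coef_for_prototype(weight, proto_idx):
    """Largest coefficient connecting one prototype to any class."""
    return max(row[proto_idx] for row in weight)


def max_abs_coef_for_prototype(weight, proto_idx):
    """Largest coefficient magnitude for one prototype, ignoring sign."""
    return max(abs(row[proto_idx]) for row in weight)


# Toy last layer: 2 classes (rows) x 3 prototypes (columns).
toy_weight = [
    [0.0, 2.5, -4.0],
    [1.5, 0.0, -0.5],
]

for proto in range(3):
    print(
        proto,
        max_coef_for_prototype(toy_weight, proto),
        max_abs_coef_for_prototype(toy_weight, proto),
    )
```

For the third toy prototype the plain maximum is small even though one connection is strongly negative. If negative connections matter for your analysis, comparing magnitudes before and after the constraint may be more informative.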

If we see a prototype that we would like more weight placed on, for now we simply keep track of it. The following cell records that, later on, we will want a ProtoPNet with a coefficient of at least 10 on prototype 3, at least 5 on prototype 17, and at least 15 on prototype 100:

In [None]:
required_protos = [(3, 10), (17, 5), (100, 15)]
# To track additional requirements of this kind, add more (prototype, coefficient) tuples to this list, i.e.:
# required_protos.append((prototype_index, coefficient))
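
Since this list is edited by hand over a session, a quick sanity check can catch malformed or duplicated entries before they are passed on. The following sketch is a hypothetical helper, not part of the library. Treating duplicate prototype indices as an error is an assumption made here for illustration.

```python
def validate_requirements(required, num_prototypes=None):
    """Check a list of (prototype_index, coefficient) tuples and return a copy."""
    seen = set()
    for item in required:
        if len(item) != 2:
            raise ValueError(f"Expected (prototype, coefficient), got {item!r}")
        proto, coef = item
        if isinstance(proto, bool) or not isinstance(proto, int) or proto < 0:
            raise ValueError(f"Prototype index must be a non-negative int: {proto!r}")
        if num_prototypes is not None and proto >= num_prototypes:
            raise ValueError(f"Prototype index {proto} is out of range")
        if proto in seen:
            raise ValueError(f"Prototype {proto} is listed more than once")
        seen.add(proto)
    return list(required)


checked = validate_requirements([(3, 10), (17, 5), (100, 15)])
print(checked)
```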

Once we're done interacting with our ProtoRSet and think we've reached a satisfactory model, we can grab the optimal model that meets the given constraints using the following cell:

In [None]:
final_model = rset_factory.produce_protopnet_object_with_requirements(required_protos)

In [None]:
for r in required_protos:
    print(f"Prototype {r[0]} coefficient:\t {final_model.prototype_prediction_head.class_connection_layer.weight[:, r[0]].max()}")
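
Rather than comparing the printed coefficients by eye, the requirements can be checked programmatically. In this sketch, `unmet_requirements` is a hypothetical helper that takes a plain dictionary of observed coefficients. The small tolerance and the toy numbers are assumptions for illustration.

```python
def unmet_requirements(observed, required, tol=1e-6):
    """Return (prototype, required, observed) for each requirement not met.

    observed maps prototype index -> largest coefficient on that prototype.
    A requirement counts as met when observed >= required - tol.
    """
    return [
        (proto, coef, observed[proto])
        for proto, coef in required
        if observed[proto] < coef - tol
    ]


# Toy values, not output from a real model:
toy_observed = {3: 10.2, 17: 4.0, 100: 15.0}
toy_required = [(3, 10), (17, 5), (100, 15)]
print(unmet_requirements(toy_observed, toy_required))

# With the real model, the dictionary could be built from the same expression
# used in the loop above:
# observed = {
#     p: final_model.prototype_prediction_head.class_connection_layer.weight[:, p].max().item()
#     for p, _ in required_protos
# }
```

An empty result means every tracked requirement is satisfied.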