# Concept Bottleneck Model: concept recall accuracy

In [None]:
# Reload all imported modules before each cell executes, so edits to the
# project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
# Reduce TensorFlow's C++ log noise. This must be set before TensorFlow is
# imported (presumably indirectly via `policy` / `jem` below — confirm).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # '0' all, '1' hide INFO, '2' hide INFO+WARNING, '3' hide all
import absl.logging
# Only surface absl messages at ERROR level or above.
absl.logging.set_verbosity(absl.logging.ERROR)
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import numpy as np

# Make `../src` (resolved relative to the current working directory)
# importable, so the project modules below can be found.
sys.path.append(os.path.abspath(os.path.join(os.path.pardir, 'src')))

from policy import ConceptNet
from jem import data_utils

# --- Experiment configuration ----------------------------------------------
# Which trained session to sample agents from.
board_size = 5
model_name, session_name = 'net', 'falcon'

# Epochs whose saved checkpoints are loaded as agents for data generation.
agents_to_sample = [0, 10, 20, 60, 100, 500]
# Number of cases requested from the data generator.
cases_to_sample = 1000

# Number of concepts, as reported by the project's data utilities.
nr_of_concepts = data_utils.get_number_of_concepts()

# Checkpoint of the concept bottleneck model being evaluated.
benchmark_model_path = "../models/cbm/conceptnet.keras"

# Directory (ending in '/') that holds the session's per-epoch checkpoints.
full_model_path = f"../models/saved_sessions/conceptnet/board_size_{board_size}/{session_name}/"

def load_model(full_name, model_name, epoch):
    """Build the ConceptNet for one saved training epoch.

    Args:
        full_name: Directory prefix of the session's checkpoints. It is
            concatenated directly, so it must end with a path separator.
        model_name: Checkpoint base name (e.g. ``'net'``).
        epoch: Epoch number appended to the base name.

    Returns:
        A ``ConceptNet`` constructed with the module-level ``board_size`` and
        ``nr_of_concepts`` and ``model_path=<full_name><model_name>_<epoch>.keras``.
    """
    model_path = f"{full_name}{model_name}_{epoch}.keras"
    # Mirror the benchmark model's constructor call: the second positional
    # argument is the number of concepts, so the checkpoint path must be
    # passed by keyword. Passing it positionally put the path into the
    # concept-count slot and left `model_path` unset.
    return ConceptNet(board_size, nr_of_concepts, model_path=model_path)

# One agent per sampled checkpoint epoch, kept in the order listed.
agents = list(
    map(lambda epoch: load_model(full_model_path, model_name, epoch), agents_to_sample)
)

# Concept bottleneck model under evaluation.
model = ConceptNet(board_size, nr_of_concepts, model_path=benchmark_model_path)

# Evaluation data generated with the sampled agents: states plus their concept
# labels (presumably one-hot, per the helper's name — confirm in data_utils).
states, concepts = data_utils.generate_concept_one_hot_encodings(
    agents, cases_to_sample, board_size
)

In [None]:
# Testing the accuracy of the concept bottleneck output.
# `concepts` holds one-hot ground-truth rows, while the predictions are
# treated as per-concept scores (the next cell takes their argmax).
# sklearn's accuracy_score cannot compare one-hot targets with continuous
# scores, so compare the true and predicted concept indices instead.
concepts_pred = model.predict_concepts(states)

true_concept_idx = np.argmax(np.asarray(concepts), axis=1)
pred_concept_idx = np.argmax(np.asarray(concepts_pred), axis=1)
accuracy = accuracy_score(true_concept_idx, pred_concept_idx)
print(f'Accuracy of bottleneck layer: {accuracy}')

In [None]:
# Per-concept tallies for ground truth and prediction. Each dict is assumed to
# map concept index -> count starting at 0.
# NOTE(review): confirm this against data_utils.concept_functions_to_dict.
concepts_gt_dict = data_utils.concept_functions_to_dict()
concepts_pred_dict = data_utils.concept_functions_to_dict()


# Count how often each concept is selected by the label and by the model.
# This compares the two distributions; it is not a per-concept accuracy.
for i, concept in enumerate(concepts):
    # Index of the highest-scoring concept in the prediction
    pred = concepts_pred[i]
    pred_idx = pred.argmax()
    # Index of the (first) active concept in the one-hot ground truth
    concept_idx = concept.argmax()

    concepts_gt_dict[concept_idx] += 1
    concepts_pred_dict[pred_idx] += 1

# Grouped bar chart: ground-truth and predicted counts side by side per concept.
fig, ax = plt.subplots()
concept_labels = list(concepts_gt_dict.keys())
index = np.arange(len(concept_labels))
bar_width = 0.35

# Pass plain lists (not dict views) to matplotlib, and read the prediction
# counts in the same key order as the ground-truth bars so they line up.
rects1 = ax.bar(index, [concepts_gt_dict[k] for k in concept_labels], bar_width, label='Ground truth')
rects2 = ax.bar(index + bar_width, [concepts_pred_dict[k] for k in concept_labels], bar_width, label='Prediction')

ax.set_xlabel('Concept')
ax.set_ylabel('Count')
ax.set_title('Concepts')
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(concept_labels)
ax.legend()

plt.show()