In [None]:
import os
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.seed import seed_everything
import clustering_utils
import data_utils
import model_utils
import persistence_utils
import visualisation_utils
import models
import numpy as np
import wandb


from experiments_clevr_extra import (train_graph_class,
                                     plot_samples, save_centroids,
                                     print_near_example, test_retrieval,
                                     test_missing_modality)
# Apply the project's plotting defaults before any figure is drawn
# (presumably matplotlib rcParams, judging by the name — confirm in visualisation_utils).
visualisation_utils.set_rc_params()

In [None]:
# Start a Weights & Biases run; the shared-space and combined-concept cells
# further down report their metrics to it through wandb.log.
# NOTE(review): no project or run name is passed here, so wandb uses its defaults.
wandb.init()

In [None]:
# Experiment configuration: seeding, output location, model/training
# hyper-parameters and the compute device.
seed = 0
seed_everything(seed)

DATASET_NAME = 'clevr'
MODE = 'SHARCS'

# Output directory for this run; it is handed to the plotting and
# persistence helpers in the cells below.
path = os.path.join("output", DATASET_NAME, MODE, f"seed_{seed}")
data_utils.create_path(path)
MODEL_NAME = f"{DATASET_NAME}_{MODE}"

NUM_CLASSES = 2
TRAIN_TEST_SPLIT = 0.8
VOCAB_DIM = 22
CLUSTER_ENCODING_SIZE = 24
EPOCHS = 10
LR = 0.001
BATCH_SIZE = 5

# `dev` is the device name as a string, `device` the matching torch.device.
# Both are kept because later cells use each of them. (A `global` statement is
# a no-op at module level, so plain assignment is all that is needed here.)
dev = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

LAYER_NUM = 0

In [None]:
# load data
print("Reading dataset")

train_loader, train, test_loader, test = data_utils.create_clevr()

# Large-batch loaders consumed by the concept-analysis cells: each batch holds
# 10% of the corresponding split. The size is clamped to at least 1 for both
# splits, because a split with fewer than ten samples would otherwise yield
# batch_size=0, which DataLoader rejects.
full_train_loader = DataLoader(train, batch_size=max(int(len(train) * 0.1), 1), shuffle=True)
full_test_loader = DataLoader(test, batch_size=max(int(len(test) * 0.1), 1))

print('Done!')

In [None]:
# Build the CLEVR SHARCS model, move it to the selected device and train it.
interpretable = True
model = models.CLEVR_SHARCS(VOCAB_DIM, CLUSTER_ENCODING_SIZE, NUM_CLASSES)
model.to(dev)

# Keep a reference to the model object as it was before hook registration.
model_to_return = model
model = model_utils.register_hooks(model)

# Run training; the four returned values are bound to the train/test
# accuracy and loss names below.
(train_acc, test_acc,
 train_loss, test_loss) = train_graph_class(
    model,
    train_loader,
    test_loader,
    EPOCHS,
    LR,
    if_interpretable_model=interpretable,
    mode=MODE,
)

In [None]:
# additional experiments setup
# Take one (large) batch from each split, run it through the trained model and
# read back the concept activations that the forward pass stores on the model.
# NOTE(review): the model stays in whatever train/eval mode train_graph_class
# left it in; call model.eval() first if it has dropout or batch-norm layers.
train_input_image, train_input_text, train_questions, train_img_index, train_y, _, _, train_txt_aux = next(iter(full_train_loader))
train_input_image = train_input_image.to(device)
train_input_text = train_input_text.to(device)
train_y = train_y.to(device)

# Inference only: no_grad avoids building an autograd graph for this big batch.
with torch.no_grad():
    train_node_concepts, _, _, _ = model(train_input_text, train_input_image, train_y)
# These attributes reflect the most recent forward pass, so they must be read
# before the test batch is pushed through the model.
train_graph_concepts = model.gnn_graph_shared_concepts.cpu()
train_graph_local_concepts = model.gnn_graph_local_concepts.cpu()
train_tab_local_concepts = model.x_tab_local_concepts.cpu()
train_tab_concepts = model.tab_shared_concepts.cpu()

test_input_image, test_input_text, test_questions, test_img_index, test_y, _, _, test_txt_aux = next(iter(full_test_loader))
test_input_image = test_input_image.to(device)
test_input_text = test_input_text.to(device)
test_y = test_y.to(device)
with torch.no_grad():
    test_node_concepts, _, _, _ = model(test_input_text, test_input_image, test_y)
test_graph_concepts = model.gnn_graph_shared_concepts.cpu()
test_graph_local_concepts = model.gnn_graph_local_concepts.cpu()
test_tab_local_concepts = model.x_tab_local_concepts.cpu()
test_tab_concepts = model.tab_shared_concepts.cpu()

# Everything below is concatenated in the order [train batch, test batch];
# train_mask / test_mask at the end rely on that ordering.
graph_concepts = torch.vstack([train_graph_concepts, test_graph_concepts])
graph_local_concepts = torch.vstack([train_graph_local_concepts, test_graph_local_concepts])
tab_local_concepts = torch.vstack([train_tab_local_concepts, test_tab_local_concepts])
tab_concepts = torch.vstack([train_tab_concepts, test_tab_concepts])

q_train = train_input_text.cpu()
q_test = test_input_text.cpu()
q = torch.cat((q_train, q_test))

idx_train = train_img_index.cpu()
idx_test = test_img_index.cpu()

img_train = train_input_image.cpu()
img_test = test_input_image.cpu()
img = torch.cat((img_train, img_test))

y_train = train_y.cpu()
y_test = test_y.cpu()
y = torch.cat((y_train, y_test))

questions = train_questions + test_questions
txt_aux = train_txt_aux + test_txt_aux
# Built from the CPU copies so the conversion to numpy cannot hit a device tensor.
idx = torch.cat((idx_train, idx_test)).numpy()

# Boolean masks over the concatenated samples: the first len(y_train) rows are
# the train batch, the remainder the test batch.
train_mask = np.zeros(y.shape[0], dtype=bool)
train_mask[:y_train.shape[0]] = True
test_mask = ~train_mask

In [None]:
# Local concept space of the text (graph) branch.
print("\n_____________THIS IS FOR TEXT____________")

# Cluster the local graph activations against the labels.
concepts_g_local = torch.Tensor(graph_local_concepts).detach()
centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(
    concepts_g_local, y
)
print(f"Number of graph cenroids: {len(centroids)}")

# Cluster sizes, then a classifier fitted on the cluster assignments whose
# accuracy is reported as the concept completeness score.
cluster_counts = visualisation_utils.print_cluster_counts(used_centroid_labels)
classifier = models.ActivationClassifierConcepts(
    y, used_centroid_labels, train_mask, test_mask
)

print(f"Classifier Concept completeness score: {classifier.accuracy}")
concept_metrics = [('cluster_count', cluster_counts)]

# Plot of the clustered local text space.
visualisation_utils.plot_clustering(
    seed, concepts_g_local, y, centroids, centroid_labels, used_centroid_labels,
    MODEL_NAME, LAYER_NUM, path,
    task="graph local",
    id_path="_graph",
)

# NumPy view of the same activations, as expected by plot_samples.
g_concepts = concepts_g_local.detach().cpu().numpy()

print('TEXT CONCEPTS')

sample_graphs, sample_feat = plot_samples(
    None, g_concepts, 5, questions, concepts_g_local, path,
    concepts=centroids,
    task='local',
)

In [None]:
# Local concept space of the image (tab) branch.
print("\n_____________THIS IS FOR IMAGE____________")

# Cluster the local image activations against the labels.
concepts_g_img_local = torch.Tensor(tab_local_concepts).detach()
centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(
    concepts_g_img_local, y
)
print(f"Number of graph cenroids: {len(centroids)}")

# Persist the clustering results, in the same order as before.
for _artefact, _file_name in (
    (centroids, 'centroids_g.z'),
    (centroid_labels, 'centroid_labels_g.z'),
    (used_centroid_labels, 'used_centroid_labels_g.z'),
):
    persistence_utils.persist_experiment(_artefact, path, _file_name)

# Cluster sizes and the completeness classifier on the cluster assignments.
cluster_counts = visualisation_utils.print_cluster_counts(used_centroid_labels)
classifier = models.ActivationClassifierConcepts(
    y, used_centroid_labels, train_mask, test_mask
)

print(f"Classifier Concept completeness score: {classifier.accuracy}")
concept_metrics = [('cluster_count', cluster_counts)]
persistence_utils.persist_experiment(concept_metrics, path, 'image_concept_metrics.z')
# NOTE(review): unlike the shared and combined cells, nothing is sent to wandb
# here; the original call was disabled:
# wandb.log({'local graph completeness': classifier.accuracy, 'num clusters local graph': len(centroids)})

# Plot of the clustered local image space.
visualisation_utils.plot_clustering(
    seed, concepts_g_img_local, y, centroids, centroid_labels, used_centroid_labels,
    MODEL_NAME, LAYER_NUM, path,
    task="graph image local",
    id_path="_graph",
)
g_img_concepts = concepts_g_img_local.detach().cpu().numpy()

print('IMAGE CONCEPTS')
sample_graphs, sample_feat = plot_samples(
    None, g_img_concepts, 5, idx, concepts_g_img_local, path,
    concepts=centroids,
    task='local',
    mod='image',
)

In [None]:
# Shared concept space: text (graph) and image (tab) activations stacked together.
print("\n_____________THIS IS FOR GRAPHS AND TABLE____________")

# Rows are ordered [graph concepts, tab concepts]; the labels are duplicated
# so that each stacked row keeps its label.
concepts = torch.vstack([torch.Tensor(graph_concepts), torch.Tensor(tab_concepts)]).detach()
y_double = torch.cat((y, y), dim=0)

# Cluster both modalities jointly.
centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(
    concepts, y_double
)
print(f"Number of graph cenroids: {len(centroids)}")

cluster_counts = visualisation_utils.print_cluster_counts(used_centroid_labels)

# The train/test masks are duplicated the same way as the labels.
train_mask_double = np.tile(train_mask, 2)
test_mask_double = np.tile(test_mask, 2)
classifier = models.ActivationClassifierConcepts(
    y_double, used_centroid_labels, train_mask_double, test_mask_double
)

print(f"Classifier Concept completeness score: {classifier.accuracy}")
concept_metrics = [('cluster_count', cluster_counts)]
persistence_utils.persist_experiment(concept_metrics, path, 'shared_concept_metrics.z')
wandb.log({'shared completeness': classifier.accuracy, 'num clusters shared': len(centroids)})

# Modality indicator used as the colouring label in the plot: 1 for the first
# half of the rows (graph concepts), 0 for the second half (tab concepts).
_half = concepts.shape[0] // 2
tab_or_graph = torch.cat((torch.ones(_half), torch.zeros(_half)), dim=0)
visualisation_utils.plot_clustering(
    seed, concepts, tab_or_graph, centroids, centroid_labels, used_centroid_labels,
    MODEL_NAME, LAYER_NUM, path,
    task="shared",
    id_path="_graph",
)

In [None]:
# text shared concepts
# NOTE: g_concepts is re-bound here from the local text activations (assigned in
# the local-text cell) to the shared-space text activations; the combined-concept,
# nearest-example, missing-modality and retrieval cells all use this version.
print('TEXT CONCEPTS')
g_concepts = graph_concepts.detach().cpu().numpy()
# `concepts` and `centroids` are the stacked shared space and its centroids from
# the shared-space cell above.
# NOTE(review): this call passes task='local' while the image counterpart below
# passes task='global' for the same shared space — confirm which is intended.
top_plot, top_concepts = plot_samples(None, g_concepts, 5, questions, concepts, path, concepts=centroids, task='local')

In [None]:
# image shared concepts
print('Image CONCEPTS')
# tab_concepts was already moved to the CPU in the setup cell, so only detach()
# is needed before converting to NumPy. t_concepts is reused by the later cells.
t_concepts = tab_concepts.detach().numpy()
# Same shared space and centroids as the text call above, but indexed by image
# ids (`idx`) and with mod='image'.
top_plot_images, top_concepts_img = plot_samples(None, t_concepts, 5, idx, concepts, path, concepts=centroids, task='global', mod='image')

In [None]:
# Joint plot of the shared space using the samples selected by the two
# plot_samples calls above (text first, then image).
print('------SHARED SPACE-----')
# presumably both operands are Python lists, so `+` concatenates them —
# verify against plot_samples' return type.
top_concepts_both = np.array(top_concepts + top_concepts_img)

top_plot_both = top_plot + top_plot_images

# Skip the figure entirely when neither modality produced any sample.
if len(top_concepts + top_concepts_img) > 0:
    visualisation_utils.plot_clustering_images_inside(seed, concepts, top_concepts_both, top_plot_both,
                                                      used_centroid_labels, path,
                                                      'shared space with images')

In [None]:
# Combined concept space: for every sample, the shared text (graph) and image
# (tab) activations are concatenated along the feature axis.
print("\n_____________THIS IS FOR COMBINED CONCEPTS____________")
union_concepts = torch.cat(
    [torch.Tensor(graph_concepts), torch.Tensor(tab_concepts)],
    dim=-1,
).detach()

# Cluster the concatenated representation against the labels.
centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(
    union_concepts, y
)
print(f"Number of graph cenroids: {len(centroids)}")

cluster_counts = visualisation_utils.print_cluster_counts(used_centroid_labels)

classifier = models.ActivationClassifierConcepts(
    y, used_centroid_labels, train_mask, test_mask
)

# Store the centroids together with the per-modality activations and their
# identifiers (questions for text, image indices for images).
save_centroids(
    centroids, y, used_centroid_labels, union_concepts,
    g_concepts, questions,
    t_concepts, idx,
    path,
)
classifier.plot2(path)

print(f"Classifier Concept completeness score: {classifier.accuracy}")
concept_metrics = [('cluster_count', cluster_counts)]
# NOTE(review): the combined-space metrics are stored under a "graph" file name.
persistence_utils.persist_experiment(concept_metrics, path, 'graph_concept_metrics.z')
wandb.log({'combined completeness': classifier.accuracy, 'num clusters combined': len(centroids)})

In [None]:
# plot clustering
# Plots the combined concept space using the centroids computed in the
# combined-concepts cell above. Unlike the earlier plot_clustering calls, this
# one also passes extra=True, the train/test masks and the number of classes.
visualisation_utils.plot_clustering(seed, union_concepts, y, centroids, centroid_labels,
                                    used_centroid_labels,
                                    MODEL_NAME, LAYER_NUM, path, task="combined", id_path="_graph",
                                    extra=True, train_mask=train_mask, test_mask=test_mask,
                                    n_classes=NUM_CLASSES)

In [None]:
# Nearest-example inspection across modalities. Inputs are the shared-space text
# activations with their questions and the shared-space image activations with
# their image indices; `path` is presumably where results are written — confirm.
print_near_example(g_concepts, questions, t_concepts, idx, path)

In [None]:
# Missing-modality experiment on the test loader.
# The module is imported by name so that its module-level `dev` attribute can be
# set to the device string chosen in the config cell before the helper runs.
# NOTE(review): presumably the helpers in experiments_clevr_extra read this
# global; train_graph_class from the same module ran before it was set — confirm
# that it does not depend on it.
import experiments_clevr_extra
experiments_clevr_extra.dev = dev
missing_accuracy = test_missing_modality(full_test_loader, model, t_concepts, g_concepts)
print(missing_accuracy)

In [None]:
# Retrieval experiment on the test loader, using the shared-space text and image
# activations together with the questions and the auxiliary text data.
retrieval_acc = test_retrieval(full_test_loader, model, g_concepts, questions, t_concepts, txt_aux)
print(retrieval_acc)