In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from experiments import intervention_utils
import os
from cem.data.CUB200.cub_loader import load_data, find_class_imbalance
from torchvision.models import resnet50, resnet34
from cem.models.cem import ConceptEmbeddingModel
import pytorch_lightning as pl
import numpy as np
import cem.train.training as cem_train
import re
import random
import matplotlib.pyplot as plt
import joblib
from sklearn.metrics.pairwise import cosine_similarity
from vae_model import *

In [4]:
# Make the tab10 palette the default colour cycle for every plot in this notebook.
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.tab10.colors)

## Dataset and Model Setup

In [5]:
# --- Paths -------------------------------------------------------------------
CUB_DIR = 'CUB/'
BASE_DIR = os.path.join(CUB_DIR, 'images/CUB_200_2011/images')  # not read by the cells shown in this notebook

# --- Experiment constants ----------------------------------------------------
num_workers = 8          # worker processes for the sub-sampled training DataLoader
n_tasks = 200            # forwarded to intervention_utils.load_trained_model
n_concepts = 112         # forwarded to load_trained_model; also the number of concepts kept
gpu = 0                  # forwarded to pl.Trainer(gpus=...)
sample_train = 0.1       # fraction of the training set kept for the sub-sampled loader
concept_group_map = intervention_utils.CUB_CONCEPT_GROUP_MAP
num_epochs = 100         # not read by the cells shown in this notebook
seed = 42

# --- Reproducibility ---------------------------------------------------------
# `seed` used to be defined but never applied, so every run drew a different
# training subset and different random related-concept maps. Seed every RNG the
# notebook imports, once, before any stochastic cell runs.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [6]:
# Pickled train / val / test splits all live side by side in CUB_DIR/preprocessed.
train_data_path, val_data_path, test_data_path = (
    os.path.join(CUB_DIR, f'preprocessed/{split_name}.pkl')
    for split_name in ('train', 'val', 'test')
)

In [7]:
# Experiment configuration saved with the trained model; later cells read its
# 'weight_loss', 'batch_size' and 'num_workers' entries and pass it on to intervention_utils.
# NOTE(review): joblib.load unpickles the file, so only load configs from a trusted source.
config = joblib.load("results/ConceptEmbeddingModelNew_resnet34_fold_1_experiment_config.joblib")

In [8]:
# Class-imbalance statistics are only computed when the config enables loss
# weighting; otherwise the model is loaded with no imbalance information.
imbalance = (
    find_class_imbalance(train_data_path, True)
    if config['weight_loss']
    else None
)

In [9]:
# Indices of the concepts to keep; here every one of the n_concepts concepts.
selected_concepts = np.arange(n_concepts)


def subsample_transform(sample):
    """Return the entries of `sample` at the positions in `selected_concepts`.

    A plain Python list is converted to a NumPy array first so that it can be
    indexed with an index array; any other input is indexed as it is.
    """
    concept_values = np.array(sample) if isinstance(sample, list) else sample
    return concept_values[selected_concepts]

In [10]:
def _make_cub_loader(pkl_path):
    """Build the CUB data loader for one preprocessed split.

    The train, validation and test loaders use exactly the same settings and
    differ only in the pickle they read, so the shared arguments live here
    instead of being repeated (and allowed to drift) in three separate calls.

    Args:
        pkl_path: path to the pickled split to load.

    Returns:
        Whatever `load_data` returns for that split.
    """
    return load_data(
        pkl_paths=[pkl_path],
        use_attr=True,
        no_img=False,
        batch_size=config['batch_size'],
        uncertain_label=False,
        n_class_attr=2,
        image_dir='images',
        resampling=False,
        root_dir='.',
        num_workers=config['num_workers'],
        concept_transform=subsample_transform,
        # Remove every "CUB//" occurrence from the image paths handed to the loader.
        path_transform=lambda path: path.replace("CUB//",""),
    )


train_dl = _make_cub_loader(train_data_path)
val_dl = _make_cub_loader(val_data_path)
test_dl = _make_cub_loader(test_data_path)

In [11]:
# Build the loader that is handed to load_trained_model as `train_dl`.
if sample_train < 1.0:
    train_dataset = train_dl.dataset
    train_size = round(len(train_dataset) * sample_train)
    # Draw the subset from a dedicated RNG seeded with `seed`, so the same
    # examples are selected on every run regardless of what ran before.
    train_subset = random.Random(seed).sample(range(len(train_dataset)), train_size)
    train_dataset = torch.utils.data.Subset(train_dataset, train_subset)
    sample_train_dl = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
    )
else:
    # No sub-sampling requested: fall back to the full training loader so that
    # `sample_train_dl` is always defined for the cells below.
    sample_train_dl = train_dl



In [12]:
# Single Lightning trainer reused for every evaluation run below.
trainer = pl.Trainer(gpus=gpu)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


## Baseline Intervention Results

In [13]:
def get_results(related_concepts,test_range=range(0,len(concept_group_map),4)):
    """Evaluate the trained model while intervening on a varying number of concept groups.

    For every entry of `test_range`, intervention indices are drawn with
    `intervention_utils.random_int_policy`, the trained model is reloaded with
    those indices and with `related_concepts`, and it is evaluated on `val_dl`
    using the shared `trainer`.

    Args:
        related_concepts: forwarded unchanged to `load_trained_model`; the cells
            in this notebook pass `None` or a nested dict built from concept
            similarities.
        test_range: iterable of group counts to evaluate.

    Returns:
        dict mapping each group count to the metrics dict returned by
        `trainer.test`.
    """
    def _evaluate(num_groups_intervened):
        # Choose which concepts to intervene on for this group count.
        chosen_idxs = intervention_utils.random_int_policy(
            num_groups_intervened=num_groups_intervened,
            concept_group_map=concept_group_map,
            config=config,
        )
        loaded_model = intervention_utils.load_trained_model(
            config=config,
            n_tasks=n_tasks,
            n_concepts=n_concepts,
            result_dir="results/",
            split=0,
            imbalance=imbalance,
            intervention_idxs=chosen_idxs,
            train_dl=sample_train_dl,
            sequential=False,
            independent=False,
            related_concepts=related_concepts,
        )
        # trainer.test is expected to return a one-element list of metric dicts.
        [metrics] = trainer.test(loaded_model, val_dl, verbose=False)
        return metrics

    return {
        num_groups_intervened: _evaluate(num_groups_intervened)
        for num_groups_intervened in test_range
    }

In [40]:
# Baseline run: no related-concept map is supplied.
results_baseline = get_results(None)

Testing: 0it [00:00, ?it/s]

c shape torch.Size([128, 3])


IndexError: index 3 is out of bounds for dimension 1 with size 3

In [None]:
# Task accuracy against the number of concept groups intervened on (baseline run).
x_vals = sorted(results_baseline)
y_vals = [results_baseline[i]['test_y_accuracy'] for i in x_vals]
plt.ylim([.75,1])
# Label the figure so it can be read without the surrounding code.
plt.xlabel('Number of concept groups intervened')
plt.ylabel('Task accuracy (test_y_accuracy)')
plt.title('Baseline interventions')
plt.plot(x_vals,y_vals)
plt.scatter(x_vals,y_vals);

## Random Intervention Results

In [113]:
# Control condition: every concept is paired with `num_random_per` randomly
# chosen other concepts, each carrying a random confidence and the identity mapping.
random_related = {}
num_random_per = 10
identity_function = lambda s: s

for i in range(n_concepts):
    # Candidates are all concepts except the concept itself.
    other_concepts = [candidate for candidate in range(n_concepts) if candidate != i]
    our_concepts = random.sample(other_concepts, num_random_per)
    # One confidence per sampled concept, drawn after the sample itself.
    confidences = [random.random() for _ in our_concepts]

    random_related[i] = {
        concept: (identity_function, confidence)
        for concept, confidence in zip(our_concepts, confidences)
    }
    

In [114]:
# Evaluate with the randomly generated related-concept map.
results_random = get_results(random_related)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

In [81]:
# Task accuracy against the number of concept groups intervened on (random related concepts).
x_vals = sorted(results_random)
y_vals = [results_random[i]['test_y_accuracy'] for i in x_vals]
plt.ylim([.75,1])
# Label the figure so it can be read without the surrounding code.
plt.xlabel('Number of concept groups intervened')
plt.ylabel('Task accuracy (test_y_accuracy)')
plt.title('Random related-concept interventions')
plt.plot(x_vals,y_vals)
plt.scatter(x_vals,y_vals);

<matplotlib.collections.PathCollection at 0x14e7df025890>

## Intervening with Concept Vectors

In [101]:
def find_closest_vectors(matrix, top_k=10):
    """Find, for every row of `matrix`, the other rows most similar to it.

    Rows are compared with cosine similarity and ranked by the *magnitude* of
    that similarity, so a strongly anti-correlated row ranks as highly as a
    strongly correlated one. The signed similarity is what gets returned, which
    lets callers tell the two cases apart.

    Args:
        matrix: 2-D array with one vector per row, in any form accepted by
            sklearn's `cosine_similarity`.
        top_k: maximum number of neighbours to return per row. Defaults to 10,
            the value that used to be hard-coded.

    Returns:
        A list with one entry per row of `matrix`. Entry ``i`` is a list of at
        most `top_k` ``(j, similarity)`` tuples with ``j != i``, ordered by
        decreasing ``abs(similarity)``; ties keep ascending index order.

    Raises:
        ValueError: if `top_k` is negative.
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative")

    sim_matrix = cosine_similarity(matrix)

    closest_indices = []
    for i, row in enumerate(sim_matrix):
        # Drop the self-similarity entry up front instead of masking the
        # diagonal and filtering it out again after sorting.
        candidates = [(j, sim) for j, sim in enumerate(row) if j != i]
        # list.sort is stable, so equal magnitudes stay in index order.
        candidates.sort(key=lambda pair: abs(pair[1]), reverse=True)
        closest_indices.append(candidates[:top_k])

    return closest_indices

In [68]:
# Give np.load the path directly: the handle from the previous bare open() call
# was never closed, whereas NumPy closes a file it opened itself once the array is read.
concept2vec = np.load("concept_vectors/concept2vec.npy")
closest_vectors = find_closest_vectors(concept2vec)

In [70]:
# Turn the nearest-neighbour lists into the nested mapping consumed by get_results:
# concept -> {neighbour: (mapping function, confidence)}.
related_concept2vec = {}
identity_function = lambda s: s
opposite_function = lambda s: 1-s

for i in range(n_concepts):
    # Positive similarity keeps the value as is, anything else flips it to 1 - s;
    # the confidence is the squared similarity in both cases.
    related_concept2vec[i] = {
        index: (
            identity_function if similarity > 0 else opposite_function,
            similarity**2,
        )
        for index, similarity in closest_vectors[i]
    }

In [78]:
# Evaluate with the related-concept map derived from the concept2vec vectors.
results_concept2vec = get_results(related_concept2vec)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

In [79]:
# Task accuracy against the number of concept groups intervened on (concept2vec neighbours).
x_vals = sorted(results_concept2vec)
y_vals = [results_concept2vec[i]['test_y_accuracy'] for i in x_vals]
plt.ylim([.75,1])
# Label the figure so it can be read without the surrounding code.
plt.xlabel('Number of concept groups intervened')
plt.ylabel('Task accuracy (test_y_accuracy)')
plt.title('Concept2vec related-concept interventions')
plt.plot(x_vals,y_vals)
plt.scatter(x_vals,y_vals);

<matplotlib.collections.PathCollection at 0x14e7e2a771d0>

In [102]:
# Give np.load the path directly: the handle from the previous bare open() call
# was never closed, whereas NumPy closes a file it opened itself once the array is read.
labels = np.load("concept_vectors/labels.npy")
# Note: this rebinds `closest_vectors`, which previously held the concept2vec neighbours.
closest_vectors = find_closest_vectors(labels)

In [107]:
# Same nested mapping as for concept2vec, built from the label-vector neighbours:
# concept -> {neighbour: (mapping function, confidence)}.
related_labels = {}
identity_function = lambda s: s
opposite_function = lambda s: 1-s

for i in range(n_concepts):
    # Positive similarity keeps the value as is, anything else flips it to 1 - s;
    # here the confidence is the absolute similarity.
    related_labels[i] = {
        index: (
            identity_function if similarity > 0 else opposite_function,
            abs(similarity),
        )
        for index, similarity in closest_vectors[i]
    }

In [108]:
# Display the full mapping for inspection.
# NOTE(review): this prints an entry for every concept; consider showing a single concept instead.
related_labels

{0: {51: (<function __main__.<lambda>(s)>, 0.49999999999999967),
  68: (<function __main__.<lambda>(s)>, 0.4874202371449994),
  48: (<function __main__.<lambda>(s)>, 0.4206063459554704),
  14: (<function __main__.<lambda>(s)>, 0.4148539997914772),
  94: (<function __main__.<lambda>(s)>, 0.41442728702466974),
  29: (<function __main__.<lambda>(s)>, 0.39963658895641585),
  62: (<function __main__.<lambda>(s)>, 0.38281628470929807),
  20: (<function __main__.<lambda>(s)>, 0.3645609089292081),
  43: (<function __main__.<lambda>(s)>, 0.3589432799485324),
  74: (<function __main__.<lambda>(s)>, 0.34982025576313147)},
 1: {79: (<function __main__.<lambda>(s)>, 0.5192500062124917),
  107: (<function __main__.<lambda>(s)>, 0.5135081333589402),
  58: (<function __main__.<lambda>(s)>, 0.4824355175435877),
  51: (<function __main__.<lambda>(s)>, 0.38543001476736205),
  68: (<function __main__.<lambda>(s)>, 0.32522181779399545),
  94: (<function __main__.<lambda>(s)>, 0.3113473572652852),
  38: (<f

In [110]:
# Evaluate with the related-concept map derived from the label vectors, then display the metrics.
results_labels = get_results(related_labels)
results_labels

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Testing: 0it [00:00, ?it/s]

{0: {'test_c_accuracy': 0.961543083190918,
  'test_c_auc': 0.9402127265930176,
  'test_c_f1': 0.9409788846969604,
  'test_y_accuracy': 0.7712854743003845,
  'test_y_auc': 0.0,
  'test_y_f1': 0.6555336713790894,
  'test_concept_loss': 0.6598849296569824,
  'test_task_loss': 1.6107624769210815,
  'test_loss': 4.910186767578125,
  'test_avg_c_y_acc': 0.8664143085479736},
 4: {'test_c_accuracy': 0.961543083190918,
  'test_c_auc': 0.9402127265930176,
  'test_c_f1': 0.9409788846969604,
  'test_y_accuracy': 0.8372287154197693,
  'test_y_auc': 0.0,
  'test_y_f1': 0.7421911954879761,
  'test_concept_loss': 0.6598849296569824,
  'test_task_loss': 0.7085254192352295,
  'test_loss': 4.007950305938721,
  'test_avg_c_y_acc': 0.899385929107666},
 8: {'test_c_accuracy': 0.961543083190918,
  'test_c_auc': 0.9402127265930176,
  'test_c_f1': 0.9409788846969604,
  'test_y_accuracy': 0.9056761264801025,
  'test_y_auc': 0.0,
  'test_y_f1': 0.8395071625709534,
  'test_concept_loss': 0.6598849296569824,
  'te

In [82]:
# Task accuracy against the number of concept groups intervened on (label-vector neighbours).
x_vals = sorted(results_labels)
y_vals = [results_labels[i]['test_y_accuracy'] for i in x_vals]
plt.ylim([.75,1])
# Label the figure so it can be read without the surrounding code.
plt.xlabel('Number of concept groups intervened')
plt.ylabel('Task accuracy (test_y_accuracy)')
plt.title('Label-similarity related-concept interventions')
plt.plot(x_vals,y_vals)
plt.scatter(x_vals,y_vals);

<matplotlib.collections.PathCollection at 0x14e7e223e190>

In [121]:
# Overlay all four conditions on one figure.
# The baseline curve is read from `results_baseline`, computed in the baseline
# section; the former `results_by_num_groups` is only assigned by a later cell,
# so referencing it here fails when the notebook is run top to bottom.
for dataset,name in zip([results_baseline,results_random,results_concept2vec,results_labels],['normal','random','concept2vec','labels']):
    x_vals = sorted(dataset)
    y_vals = [dataset[i]['test_y_accuracy'] for i in x_vals]
    plt.ylim([.75,1])
    plt.plot(x_vals,y_vals,label=name)
    plt.scatter(x_vals,y_vals)
plt.xlabel('Number of concept groups intervened')
plt.ylabel('Task accuracy (test_y_accuracy)')
plt.title('Intervention strategies compared')
plt.legend();

<IPython.core.display.Javascript object>

RecursionError: maximum recursion depth exceeded while calling a Python object

## VAE-based concepts

In [14]:
def fix_concept(model,concepts):
    """Round-trip one concept vector through the model's decoder and then its encoder.

    `concepts` is wrapped into a batch of size one and passed to
    `model.decoder.predict`; that output is passed to `model.encoder.predict`,
    whose result is returned unchanged.
    """
    concept_batch = np.array([concepts])
    decoder_output = model.decoder.predict(concept_batch)
    return model.encoder.predict(decoder_output)
    

In [40]:
def get_results_vae(vae_model,test_range=range(0,len(concept_group_map),4)):
    """Evaluate the trained model with a VAE attached, for several intervention budgets.

    Mirrors `get_results`, except that `vae_model` is forwarded to
    `load_trained_model` in place of a related-concept map.

    NOTE(review): the recorded run of this notebook failed with
    "load_trained_model() got an unexpected keyword argument 'vae_model'";
    confirm that the installed `intervention_utils` accepts this argument.

    Args:
        vae_model: forwarded unchanged to `load_trained_model`.
        test_range: iterable of group counts to evaluate.

    Returns:
        dict mapping each group count to the metrics dict returned by
        `trainer.test` on `val_dl`.
    """
    results_by_num_groups = {}
    for num_groups_intervened in test_range:
        # Arguments are assembled first; the intervention indices are drawn
        # before the model is loaded, exactly as in a direct call.
        load_kwargs = dict(
            config=config,
            n_tasks=n_tasks,
            n_concepts=n_concepts,
            result_dir="results/",
            split=0,
            imbalance=imbalance,
            intervention_idxs=intervention_utils.random_int_policy(
                num_groups_intervened=num_groups_intervened,
                concept_group_map=concept_group_map,
                config=config,
            ),
            train_dl=sample_train_dl,
            sequential=False,
            independent=False,
            vae_model=vae_model,
        )
        loaded_model = intervention_utils.load_trained_model(**load_kwargs)

        # trainer.test is expected to return a one-element list of metric dicts.
        [metrics] = trainer.test(loaded_model, val_dl, verbose=False)
        results_by_num_groups[num_groups_intervened] = metrics
    return results_by_num_groups

In [23]:
# Build the VAE used for concept correction.
# latent_dim has the same value as n_concepts (112).
latent_dim = 112
# NOTE(review): `size` and the literal 3 are passed positionally to the helpers from
# vae_model; presumably image side length and channel count — confirm against vae_model.
size = 64
decoder_3 = create_decoder(size,3,latent_dim)
encoder_3 = create_encoder(size,3,latent_dim)

model = VAE(encoder_3, decoder_3,concept_alignment=True)
# NOTE(review): presumably marks the model as built so that weights can be loaded
# without calling it on data first — confirm.
model.built = True

In [24]:
# Restore the VAE weights from disk.
model.load_weights('concept_vectors/vae_concept.h5')

In [38]:
# Check the decoder on one random latent vector. The width comes from
# `latent_dim` rather than a repeated literal 112, so it cannot drift away
# from the size the decoder was built with.
predicted_image = model.decoder.predict(np.random.random((1, latent_dim)))



In [42]:
# Evaluate the VAE variant with zero intervened groups only.
# NOTE(review): the recorded run raised a TypeError about the 'vae_model' keyword — see get_results_vae.
results_vae = get_results_vae(model,test_range=[0])

TypeError: load_trained_model() got an unexpected keyword argument 'vae_model'

## Previous Code

In [44]:
# Legacy, unrolled version of the baseline sweep, kept for reference.
# The accumulator is created here: the loop used to write into
# `results_by_num_groups` without ever defining it, which fails on a fresh kernel.
results_by_num_groups = {}
for num_groups_intervened in range(0,len(concept_group_map),4):
    intervention_idxs = intervention_utils.random_int_policy(
        num_groups_intervened=num_groups_intervened,
        concept_group_map=concept_group_map,
        config=config,
    )

    # Bound to its own name so that the VAE stored in `model` is not overwritten.
    cem_model = intervention_utils.load_trained_model(
        config=config,
        n_tasks=n_tasks,
        n_concepts=n_concepts,
        result_dir="results/",
        split=0,
        imbalance=imbalance,
        intervention_idxs=intervention_idxs,
        train_dl=sample_train_dl,
        sequential=False,
        independent=False,
    )

    # trainer.test is expected to return a one-element list of metric dicts.
    [test_results] = trainer.test(cem_model, val_dl, verbose=False,)
    results_by_num_groups[num_groups_intervened] = test_results

Testing: 0it [00:00, ?it/s]

tensor([[0., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 0., 0.],
        [0., 1., 0.],
        [1

TypeError: cannot unpack non-iterable NoneType object

In [21]:
# Display the metrics collected by the legacy loop.
results_by_num_groups

{0: {'test_c_accuracy': 0.961543083190918,
  'test_c_auc': 0.9402127265930176,
  'test_c_f1': 0.9409788846969604,
  'test_y_accuracy': 0.7712854743003845,
  'test_y_auc': 0.0,
  'test_y_f1': 0.6555336713790894,
  'test_concept_loss': 0.6598849296569824,
  'test_task_loss': 1.6107624769210815,
  'test_loss': 4.910186767578125,
  'test_avg_c_y_acc': 0.8664143085479736},
 4: {'test_c_accuracy': 0.961543083190918,
  'test_c_auc': 0.9402127265930176,
  'test_c_f1': 0.9409788846969604,
  'test_y_accuracy': 0.826377272605896,
  'test_y_auc': 0.0,
  'test_y_f1': 0.7257368564605713,
  'test_concept_loss': 0.6598849296569824,
  'test_task_loss': 0.9781456589698792,
  'test_loss': 4.277569770812988,
  'test_avg_c_y_acc': 0.8939602375030518},
 8: {'test_c_accuracy': 0.961543083190918,
  'test_c_auc': 0.9402127265930176,
  'test_c_f1': 0.9409788846969604,
  'test_y_accuracy': 0.8797996640205383,
  'test_y_auc': 0.0,
  'test_y_f1': 0.8019772171974182,
  'test_concept_loss': 0.6598849296569824,
  'te