In [2]:
import base64
import copy
import glob
import os
import pickle
import re
import time
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from natsort import natsorted

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)  # for exponential backoff


import sys
sys.path.insert(0, '..')
from utils_bnr import set_seed

class Args:
    """
    Static configuration for this revision notebook; values are read as attributes of ``args = Args()``.
    Edit ``base_path`` / ``rand_imgs_path`` to point at the run being revised.
    """

    # --- reproducibility and loading ---
    seed = 0
    batch_size = 1
    num_workers = 4

    # --- image format ---
    image_size = 128
    image_channels = 3

    # --- slots / blocks ---
    num_slots = 4
    num_blocks = 8

    # --- run directories: !!! change these as necessary !!! ---
    base_path = '../logs/clevr_easy_500_epochs/sysbind_orig_seed2'
    rand_imgs_path = '../logs/clevr_easy_500_epochs/random_imgs'
    # alternative run (CLEVR4, 600 epochs):
    # base_path = '../logs/clevr4_600_epochs/clevr4_sysbind_orig_seed2'
    # rand_imgs_path = '../logs/clevr4_600_epochs/random_imgs'

args = Args()
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed via the project helper, then seed torch explicitly as well.
# NOTE(review): set_seed (utils_bnr) may already seed torch — if so the manual_seed call is redundant.
set_seed(args.seed)
torch.manual_seed(args.seed)

# All revision artifacts (delete/merge decisions, revised corpora) are written under <base_path>/revision.
# exist_ok=True replaces the check-then-create pattern (no race, idempotent on re-run).
revise_dir = f'{args.base_path}/revision'
os.makedirs(revise_dir, exist_ok=True)
args.save_path = revise_dir



In [4]:
def extract_answer_from_response_json(response, split_str='Final Answer: ', verbose=0):
    """
    Parse a binary (Yes/No) answer from a VLM chat-completion response.

    response: response from the VLM model; ``response.json()`` must contain
        ``['choices'][0]['message']['content']``.
    split_str: string that specifies where to cut the VLM response in order to extract the final task answer;
        only the text after its last occurrence is inspected.
    verbose: if truthy, print the full answer text.

    returns: 0 if the extracted answer contains 'No', otherwise 1 if it contains 'Yes'
        ('No' takes precedence when both appear).
    raises: ValueError if the response JSON lacks the expected structure;
        Exception if the extracted answer contains neither 'Yes' nor 'No'.
    """
    payload = response.json()
    try:
        answer = payload['choices'][0]['message']['content']
    except (KeyError, IndexError, TypeError) as err:
        # A bare ``except`` used to print and fall through, crashing later with an unrelated
        # UnboundLocalError on ``answer``; surface the real problem instead.
        print(f'{payload} \n Error')
        raise ValueError('Unexpected VLM response structure') from err

    if verbose:
        print(answer)

    answer_y_n = answer.split(split_str)[-1]
    # NOTE(review): plain substring tests — e.g. 'Not' or 'None' also count as 'No'.
    if 'No' in answer_y_n:
        return 0
    if 'Yes' in answer_y_n:
        return 1
    print(answer)
    raise Exception("This response isn't binary")
     
    
def extract_properties_from_response(response):
    r"""
    Return the property dictionary of a VLM response as a '{...}' string.

    Takes the message content after its last '{\n', cuts it at the first '\n}' and re-wraps the
    remaining text in braces.
    """
    content = response.json()['choices'][0]['message']['content']
    inner = content.split('{\n')[-1]
    inner = inner.split('\n}')[0]
    return f'{{{inner}}}'
    
        
def find_majority(responses, split_str='Final Answer: ', verbose=0):
    """
    Majority vote over repeated binary VLM answers.

    responses: list of response objects obtained by querying the model several times with the same
        sample and prompt.
    split_str, verbose: forwarded to ``extract_answer_from_response_json``.

    returns: the most common extracted answer; if the two most common answers are tied, 0.
    """
    tally = Counter(
        extract_answer_from_response_json(resp, split_str, verbose) for resp in responses
    )
    ranked = tally.most_common(2)
    is_tie = len(ranked) == 2 and ranked[0][1] == ranked[1][1]
    return 0 if is_tie else ranked[0][0]


def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes base64-encoded, as a UTF-8 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
    
    
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def completion_with_backoff(**kwargs):
    """
    Forward ``kwargs`` to ``client.chat.completions.create``, retried by tenacity with random
    exponential backoff (min 1, max 60) for at most 6 attempts.

    NOTE(review): ``client`` is not defined in any visible cell (the query cells of sections 1 and 2
    are not shown); it must exist as a global before this is called.
    """
    create_completion = client.chat.completions.create
    return create_completion(**kwargs)

## 1. Query VLM for property lists - N times

In [16]:
# with open(f'{args.save_path}/responses_knowledge_properties.pkl', 'wb') as f:
#     pickle.dump(knowledge_responses, f)

## 2. Query the VLM for the occurrence of each sub-property per object in the image

In [37]:
# with open(f'{args.save_path}/responses_description.pkl', 'wb') as f:
#     pickle.dump(responses_description, f)

## 3. Analyse whether concepts should be deleted

In [None]:
# # load property list responses
# with open(f'{args.save_path}/responses_knowledge_properties.pkl', 'rb') as f:
#     knowledge_responses = pickle.load(f)
# # load image descriptions based on these property lists
# with open(f'{args.save_path}/responses_description.pkl', 'rb') as f:
#     responses_description = pickle.load(f)

In [None]:
# def extract_obj_list_from_response(response):
#     # get object dict as str from response
#     obj_dict_str = response.json()['choices'][0]['message']['content'].split('{')[-1].split('}')[0]
#     # extract only property description lists
#     obj_list_str = re.split(r'Object\d+: ', obj_dict_str)[1:]
#     obj_list_str = [re.split(r'], ', s)[0]+']' for s in obj_list_str] 
# #     obj_list_str = [s.split('[')[1].split(']')[0].split(', ') for s in obj_list_str]
#     return obj_list_str


# def extract_property_lists(response):
#     # given response get only the property list relevant string
#     property_str = extract_properties_from_response(response)
#     # now clean it such that we have two arrays of strings, with each string representing a sub-property
#     property_str = property_str.split('{')[-1].split(',}')[0]
#     property_str = re.split(r': ', property_str)[1:]
#     property_str = [s.split('[')[-1].split(']')[0] for s in property_str]
#     property_str = [s.split(', ') for s in property_str]    
#     return property_str


# def get_bool_prop_list_for_response(response, knowledge_response, verbose=0):
    
#     # get the object list contained within it
#     obj_list_str =  extract_obj_list_from_response(response)
#     if verbose:
#         print(obj_list_str)

#     # get the property list that the description prompt was based on
#     property_list_str = extract_property_lists(knowledge_response)

#     # iterate over each property and subproperty and check if it occurs in the object description of each object
#     # properties_obj_bool represents a dictionary where each property id key contains a boolean array as value
#     properties_obj_bool = {}
#     for property_id in range(len(property_list_str)):
#         # booolean array, for each object we store if it depicts a property
#         property_obj_bool = np.zeros((len(obj_list_str), len(property_list_str[property_id])))

#         for sub_property_id in range(len(property_list_str[property_id])):

#             for obj_id in range(len(obj_list_str)):

#                 property_obj_bool[obj_id, sub_property_id] = \
#                     property_list_str[property_id][sub_property_id] in obj_list_str[obj_id]

#         properties_obj_bool[property_id] = property_obj_bool
    
#     return properties_obj_bool


# def check_for_deletion(prop_obj_bool_dict):
#     delete_bool = []
#     for prop_id in prop_obj_bool_dict.keys():
#         # count the number of occurances of each subproperty over the objects
#         occur_subprop = np.sum(prop_obj_bool_dict[prop_id], axis=0)
#         # check if each object has one sub-property, i.e. the occurance of a subproperty equals the number of objs
#         delete_bool.append(prop_obj_bool_dict[prop_id].shape[0] in occur_subprop)
#     # if at least one subproperty is present across all objects return true
#     if np.sum(delete_bool) > 0:
#         return False
#     else:
#         return True
   

# N_QUERIES = 4
# # a little messed up here, but for each block and each concept we now collect the deletion booleans across samples
# delete_block_dict = {}
# for block_id in responses_description[0].keys():
#     delete_concept_dict = {}
#     for concept_id in responses_description[0][block_id].keys():
#         delete_samples = []
# #         for sample_id in responses_description.keys():
#         for sample_id in range(N_QUERIES):
#             # get description resonse
#             response = responses_description[sample_id][block_id][concept_id]
#             # get a boolean representation of the object description list, where per object we check if a subproperty is in 
#             # the corresponding obj list
#             # prop_obj_bool_dict[0]: N_Obj x N_SubProps
#             prop_obj_bool_dict = get_bool_prop_list_for_response(response, knowledge_responses[sample_id])
#             # check if any subprop occurs across all objects
#             delete_bool = check_for_deletion(prop_obj_bool_dict)

#             delete_samples.append(delete_bool)

#         delete_concept_dict[concept_id] = delete_samples

#     delete_block_dict[block_id] = delete_concept_dict
    
# delete_concepts_dict = {}
# # now we perform majority vote over the collected deletion booleans
# for block_id in delete_block_dict.keys():
#     delete_concepts = []
#     for concept_id in delete_block_dict[block_id].keys():
# #         assert len(delete_block_dict[block_id][concept_id]) % 2 == 1
#         vote_count = Counter(delete_block_dict[block_id][concept_id])
#         top_two = vote_count.most_common(2)
#         if len(top_two)>1 and top_two[0][1] == top_two[1][1]:
#             # It is a tie, in this case we choose benefit of the doubt, i.e., don't delete
#             pass
#         # if majarity vote is true we delete the concept
#         elif top_two[0][0]:
#             delete_concepts.append(concept_id)
#     delete_concepts_dict[block_id] = delete_concepts

# delete_concepts_dict

In [None]:
# with open(f'{args.save_path}/concepts_delete_dict.pkl', 'wb') as f:
#     pickle.dump(delete_concepts_dict, f)

### 3.1 Now integrate this feedback into the retrieval corpus

In [102]:
# Load the original retrieval corpus and the per-block lists of concept ids voted for deletion.
# NOTE: pickle.load can execute arbitrary code; only load files produced by this pipeline.
corpus_file = f'{args.base_path}/block_concept_dicts.pkl'
delete_dict_file = f'{args.save_path}/concepts_delete_dict.pkl'
with open(corpus_file, 'rb') as fh:
    ret_corpus = pickle.load(fh)
with open(delete_dict_file, 'rb') as fh:
    delete_concepts_dict = pickle.load(fh)

In [103]:
def set_id_from_id_list(id_list, delete_id, set_id=-1):
    """
    Relabel, in place, every entry of ``id_list`` equal to ``delete_id`` with ``set_id``.

    id_list: array of cluster ids supporting element-wise ``==`` and boolean-mask assignment
        (e.g. a numpy array); it is modified in place.
    delete_id: cluster id whose entries are relabelled.
    set_id: replacement id (default -1).

    returns: None.
    raises: TypeError if ``id_list`` does not compare element-wise (e.g. a plain Python list, where
        ``id_list == delete_id`` is a single bool and ``id_list[False] = set_id`` would silently
        overwrite element 0).
    """
    # boolean mask of the entries belonging to the concept being relabelled
    rel_ids = id_list == delete_id
    if isinstance(rel_ids, bool):
        raise TypeError(
            f'id_list must support element-wise comparison (e.g. numpy array), got {type(id_list).__name__}'
        )
    # overwrite them with the replacement id
    id_list[rel_ids] = set_id
    return None


def set_id_over_all_representations(block_corpus, delete_id, set_id=-1):
    """
    Relabel cluster ``delete_id`` as ``set_id`` in every representation of one block's corpus.

    block_corpus: per-block corpus dict whose 'prototypes', 'exemplars' and 'sivm_basis' entries each
        hold an 'ids' array; those arrays are modified in place via ``set_id_from_id_list``.
    delete_id: cluster id to relabel.
    set_id: the new id for that cluster's encodings (default -1).
    """
    for representation in ('prototypes', 'exemplars', 'sivm_basis'):
        set_id_from_id_list(block_corpus[representation]['ids'], delete_id=delete_id, set_id=set_id)


# def remove_concept(block_corpus, delete_id):
#     """
#     Removes encodings and corresponding information entirely from the cluster identified via 'delete_id'
#     """
#     # i.e. 'prototypes', 'exemplars', 'sivm_basis'
#     representation_keys = list(block_corpus.keys())
#     representation_keys.remove('params_clustering')

#     for representation_key in representation_keys:
#         # e.g. 'exemplars', 'exemplar_ids'
#         data_keys = list(block_corpus[representation_key].keys())
#         # we handle the ids separately
#         data_keys.remove('ids')
#         # identify which encodings to keep and which not to keep
#         del_ids = np.where(block_corpus[representation_key]['ids'] == delete_id)[0]
#         keep_ids = np.where(block_corpus[representation_key]['ids'] != delete_id)[0]
#         # number of individual clusters altogether
#         n_clusters = len(np.unique(block_corpus[representation_key]['ids']))

#         for data_key in data_keys:
#             # handle the case where we have a list of arrays (e.g. 'exemplar_ids') vs a long list (e.g. 'exemplars')
#             if '_ids' in data_key:
#                 #len(block_corpus[representation_key][data_key]) == n_clusters:
#                 block_corpus[representation_key][data_key] = [
#                     ele for idx, ele in enumerate(block_corpus[representation_key][data_key]) if idx != delete_id
#                 ]
#             else:
#                 block_corpus[representation_key][data_key] = [
#                     ele for idx, ele in enumerate(block_corpus[representation_key][data_key]) if idx not in del_ids
#                 ]
#         # finally remove the ids themselves, i.e. keep only relevant ones
#         block_corpus[representation_key]['ids'] = block_corpus[representation_key]['ids'][keep_ids]

    
# Update ids in the retrieval corpus based on the deletion feedback.
# Work on a deep copy: the id arrays are relabelled in place, so a shallow ``dict.copy()`` would also mutate
# ``ret_corpus`` (and re-running this cell would compound the edits on the "original" corpus).
ret_corpus_delete = copy.deepcopy(ret_corpus)
for block_id in delete_concepts_dict.keys():
    delete_ids = delete_concepts_dict[block_id]

    # identify how many clusters exist per block (counted as the number of clustered-exemplar images)
    # NOTE(review): the 'block{id}*' pattern would also match e.g. block10 for block1 once there are >= 10 blocks.
    image_path = f"{args.base_path}/clustered_exemplars/block{block_id}*.png"
    exemplar_paths = natsorted(glob.glob(image_path))
    n_clusters_block = len(exemplar_paths)

    print(block_id)
    print(delete_ids)

    if not delete_ids:
        continue

    # we now have three cases how to handle deletion
    # case 1: all clusters are to be deleted --> we set all cluster ids to 0
    # case 2: all clusters but 1 are to be deleted --> we merge all to-delete clusters,
    #         i.e., set all to-delete cluster ids to one of this set
    # case 3: at least two clusters should not be deleted --> the to-delete clusters get id -1
    #         (complete removal of their encodings via remove_concept is currently disabled)
    n_delete = len(delete_ids)
    if n_delete == n_clusters_block:
        print(f'Block {block_id}: case 1')
        target_id = 0
    elif n_delete == n_clusters_block - 1:
        print(f'Block {block_id}: case 2')
        # the first to-delete cluster absorbs the others, essentially merging them
        target_id = delete_ids[0]
    elif n_delete <= n_clusters_block - 2:
        print(f'Block {block_id}: case 3')
        target_id = -1
    else:
        # more deletions than clusters: the exemplar images and the feedback disagree (e.g. wrong base_path);
        # previously no branch matched and the block was silently left unchanged
        raise ValueError(
            f'Block {block_id}: {n_delete} concepts marked for deletion but only '
            f'{n_clusters_block} exemplar images found'
        )

    for delete_concept_id in delete_ids:
        set_id_over_all_representations(
            ret_corpus_delete[block_id],
            delete_id=delete_concept_id,
            set_id=target_id
        )


0
[0, 1, 2, 3]
Block 0: case 1
1
[3]
Block 1: case 3
2
[]
3
[]
4
[0, 1]
Block 4: case 1
5
[0, 1, 2]
Block 5: case 1
6
[4, 5]
Block 6: case 3
7
[7]
Block 7: case 3


In [107]:
# Persist the deletion-revised corpus; section 4.1 reloads it from this file.
revise_delete_file = f'{args.save_path}/block_concept_dicts_revise_delete.pkl'
with open(revise_delete_file, 'wb') as fh:
    pickle.dump(ret_corpus_delete, fh)

## 4. Analyse if concepts should be merged

In [None]:
# load property list responses
with open(f'{args.save_path}/responses_knowledge_properties.pkl', 'rb') as f:
    knowledge_responses = pickle.load(f)
# load image descriptions based on these property lists
with open(f'{args.save_path}/responses_description.pkl', 'rb') as f:
    responses_description = pickle.load(f)
with open(f'{args.save_path}/concepts_delete_dict.pkl', 'rb') as f:
    delete_concepts_dict = pickle.load(f)

In [None]:
def get_common_properties_bool(prop_obj_bool_dict):
    """
    For every property, flag the sub-properties that are present for *all* described objects.

    prop_obj_bool_dict: dict mapping property id -> (n_objects, n_sub_properties) array of 0/1 entries.

    returns: dict with the same property ids -> boolean array of length n_sub_properties.
        If a description yielded no objects (n_objects == 0) every entry is False. The plain comparison
        ``count == n_objects`` would be vacuously True (0 == 0) there, marking every sub-property as
        common and making the concept look mergeable with every other concept.
    """
    common_properties = {}
    for prop_id, obj_subprop in prop_obj_bool_dict.items():
        n_objects = obj_subprop.shape[0]
        # count, per sub-property, how many objects exhibit it
        occur_subprop = np.sum(obj_subprop, axis=0)
        if n_objects == 0:
            common_properties[prop_id] = np.zeros(occur_subprop.shape, dtype=bool)
        else:
            common_properties[prop_id] = occur_subprop == n_objects
    return common_properties


# --- helpers ---------------------------------------------------------------------------------------------
# These were only defined in the commented-out section-3 cell, so this cell raised a NameError on a fresh
# kernel (Restart & Run All). They reproduce that commented-out code.

def extract_obj_list_from_response(response):
    """
    Return the per-object description strings from a VLM description response.

    The text between the last '{' and the following '}' of the message content is split on 'Object<k>: '
    labels; for each object the text before its first '], ' is kept and a closing ']' appended.
    """
    obj_dict_str = response.json()['choices'][0]['message']['content'].split('{')[-1].split('}')[0]
    obj_list_str = re.split(r'Object\d+: ', obj_dict_str)[1:]
    return [re.split(r'], ', s)[0] + ']' for s in obj_list_str]


def extract_property_lists(response):
    """
    Return the property lists of a knowledge response as a list (one entry per property) of lists of
    sub-property strings, parsed from the '{...}' string of ``extract_properties_from_response``.
    """
    property_str = extract_properties_from_response(response)
    property_str = property_str.split('{')[-1].split(',}')[0]
    property_strs = re.split(r': ', property_str)[1:]
    property_strs = [s.split('[')[-1].split(']')[0] for s in property_strs]
    return [s.split(', ') for s in property_strs]


def get_bool_prop_list_for_response(response, knowledge_response, verbose=0):
    """
    Build, per property, an (n_objects, n_sub_properties) 0/1 matrix telling whether each sub-property
    string of ``knowledge_response`` occurs (as a substring) in each object description of ``response``.

    returns: dict mapping property index -> float array of 0.0/1.0 entries.
    """
    obj_list_str = extract_obj_list_from_response(response)
    if verbose:
        print(obj_list_str)
    property_list_str = extract_property_lists(knowledge_response)

    properties_obj_bool = {}
    for property_id, sub_properties in enumerate(property_list_str):
        property_obj_bool = np.zeros((len(obj_list_str), len(sub_properties)))
        for sub_property_id, sub_property in enumerate(sub_properties):
            for obj_id, obj_str in enumerate(obj_list_str):
                property_obj_bool[obj_id, sub_property_id] = sub_property in obj_str
        properties_obj_bool[property_id] = property_obj_bool
    return properties_obj_bool


# Number of VLM query samples evaluated here.
# NOTE(review): the merge vote in the next cell only uses the first 4 of these samples.
N_QUERIES = 7

# common_props[sample][block][concept][prop] -> boolean array flagging the sub-properties shared by every
# object in that (non-deleted) concept's description
common_props = {}
for sample_id in range(N_QUERIES):
    common_props_block = {}
    for block_id in responses_description[0].keys():
        common_props_concept = {}
        for concept_id in responses_description[0][block_id].keys():
            # deleted concepts take no part in merging
            if concept_id in delete_concepts_dict[block_id]:
                continue
            response = responses_description[sample_id][block_id][concept_id]
            # prop_obj_bool_dict[prop]: N_Obj x N_SubProps
            prop_obj_bool_dict = get_bool_prop_list_for_response(response, knowledge_responses[sample_id])
            common_props_concept[concept_id] = get_common_properties_bool(prop_obj_bool_dict)
        common_props_block[block_id] = common_props_concept
    common_props[sample_id] = common_props_block

# TODO double check this + see if maybe if at least one false we don't consider a merge? --> bronze and yellow get merged via gold
# Pairwise comparison of the remaining concepts, each unordered pair once (concept_id2 after concept_id1):
# a pair is flagged (1) when both concepts share at least one common sub-property of any property.
common_props_pairwise = {}
for sample_id in range(N_QUERIES):
    common_props_pairwise_block = {}
    for block_id in responses_description[0].keys():
        concept_ids_list = list(common_props[0][block_id].keys())
        common_props_pairwise_concept1 = {}
        for position, concept_id1 in enumerate(concept_ids_list):
            concept1 = common_props[sample_id][block_id][concept_id1]
            common_props_pairwise_concept2 = {}
            for concept_id2 in concept_ids_list[position + 1:]:
                concept2 = common_props[sample_id][block_id][concept_id2]
                shares_subprop = any(
                    np.any(np.logical_and(concept1[prop_id], concept2[prop_id])) for prop_id in concept1.keys()
                )
                common_props_pairwise_concept2[concept_id2] = int(shares_subprop)
            common_props_pairwise_concept1[concept_id1] = common_props_pairwise_concept2
        common_props_pairwise_block[block_id] = common_props_pairwise_concept1
    common_props_pairwise[sample_id] = common_props_pairwise_block



In [None]:
# Number of VLM query samples taking part in the merge vote (the first N_QUERIES samples).
N_QUERIES = 4
# Minimum number of those samples that must flag a pair for it to be merged.
# NOTE(review): 2 of 4 also accepts ties, unlike the (commented-out) deletion vote in section 3,
# where a tie means "keep".
MIN_MERGE_VOTES = 2

# merge_block_dict_samples[block][concept1][concept2] -> list of per-sample 0/1 merge flags
merge_block_dict_samples = {
    block_id: {
        concept_id: {
            concept_id2: [common_props_pairwise[sample_id][block_id][concept_id][concept_id2]
                          for sample_id in range(N_QUERIES)]
            for concept_id2 in partners.keys()
        }
        for concept_id, partners in common_props_pairwise[0][block_id].items()
    }
    for block_id in common_props_pairwise[0].keys()
}

# merge_block_dict[block][concept1][concept2] -> 1 if at least MIN_MERGE_VOTES samples flagged the pair, else 0
merge_block_dict = {
    block_id: {
        concept_id: {
            concept_id2: int(sum(votes) >= MIN_MERGE_VOTES)
            for concept_id2, votes in partners.items()
        }
        for concept_id, partners in concept_dict.items()
    }
    for block_id, concept_dict in merge_block_dict_samples.items()
}

In [None]:
# Persist both the per-sample merge flags and the voted merge decisions.
merge_samples_file = f'{args.save_path}/concepts_merge_samples_dict.pkl'
merge_decisions_file = f'{args.save_path}/concepts_merge_dict.pkl'
with open(merge_samples_file, 'wb') as fh:
    pickle.dump(merge_block_dict_samples, fh)
with open(merge_decisions_file, 'wb') as fh:
    pickle.dump(merge_block_dict, fh)

### 4.1 Now integrate this feedback into the retrieval corpus

In [None]:
# Reload the deletion-revised corpus saved in section 3.1.
# NOTE: pickle.load can execute arbitrary code; only load files produced by this pipeline.
revise_delete_file = f'{args.save_path}/block_concept_dicts_revise_delete.pkl'
with open(revise_delete_file, 'rb') as fh:
    ret_corpus_delete = pickle.load(fh)

In [None]:
def merge_id_with_id2_in_id_list(id_list, to_be_merged_id, merge_to_id):
    """
    Merge one cluster into another by relabelling, in place, every entry of ``id_list`` equal to
    ``to_be_merged_id`` with ``merge_to_id``.

    id_list: array of cluster ids supporting element-wise ``==`` and boolean-mask assignment
        (e.g. a numpy array); it is modified in place.
    to_be_merged_id: cluster id whose entries are relabelled.
    merge_to_id: cluster id they are merged into.

    returns: None.
    raises: TypeError if ``id_list`` does not compare element-wise (e.g. a plain Python list, where
        ``id_list == to_be_merged_id`` is a single bool and ``id_list[False] = ...`` would silently
        overwrite element 0).
    """
    # boolean mask of the entries of the cluster being merged away
    rel_ids = id_list == to_be_merged_id
    if isinstance(rel_ids, bool):
        raise TypeError(
            f'id_list must support element-wise comparison (e.g. numpy array), got {type(id_list).__name__}'
        )
    # relabel them with the id of the cluster they are merged into
    id_list[rel_ids] = merge_to_id
    return None

def _resolve_merge_target(merged_into, concept_id):
    """Follow ``merged_into`` links from ``concept_id`` to the id that currently carries its encodings."""
    while concept_id in merged_into:
        concept_id = merged_into[concept_id]
    return concept_id


# Update ids in the retrieval corpus based on the merge feedback: for every pair flagged in merge_block_dict,
# relabel the compared-to concept (merge_concept_id2) with the comparing concept (merge_concept_id).
# Work on a deep copy: the id arrays are relabelled in place, so a shallow ``dict.copy()`` would also mutate
# ``ret_corpus_delete``.
ret_corpus_delete_merge = copy.deepcopy(ret_corpus_delete)
for block_id in merge_block_dict.keys():
    # concept id -> id it has been merged into. Relabelling pairs one after another without this breaks
    # chains: after merging 1 into 0, "merge 2 into 1" would relabel 2 to the now-empty label 1 instead of
    # joining 0; after merging 2 into 0, "merge 2 into 1" would find no entries left and do nothing.
    # Resolving both ends first applies every flagged merge (merges are therefore transitive).
    merged_into = {}
    for merge_concept_id in merge_block_dict[block_id].keys():
        for merge_concept_id2 in merge_block_dict[block_id][merge_concept_id].keys():
            if not merge_block_dict[block_id][merge_concept_id][merge_concept_id2]:
                continue
            merge_to_id = _resolve_merge_target(merged_into, merge_concept_id)
            to_be_merged_id = _resolve_merge_target(merged_into, merge_concept_id2)
            if merge_to_id == to_be_merged_id:
                # already share a label through earlier merges
                continue
            for representation_key in ('prototypes', 'exemplars', 'sivm_basis'):
                merge_id_with_id2_in_id_list(ret_corpus_delete_merge[block_id][representation_key]['ids'],
                                             to_be_merged_id=to_be_merged_id,
                                             merge_to_id=merge_to_id)
            merged_into[to_be_merged_id] = merge_to_id

In [None]:
# Persist the corpus after both the deletion and the merge revisions.
revise_delete_merge_file = f'{args.save_path}/block_concept_dicts_revise_delete_merge.pkl'
with open(revise_delete_merge_file, 'wb') as fh:
    pickle.dump(ret_corpus_delete_merge, fh)

## (5. Make final query if the concepts should be merged?)