In [1]:
import sys
import torch
import numpy as np
import argparse
import pickle
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.image import imread
from torch.utils.data import DataLoader
from datetime import datetime
from sklearn.tree import DecisionTreeClassifier
from sklearn import metrics

from retrievalbinder import NeuralConceptBinder
from retrievalbinder import SysBinderImageAutoEncoder
from data import CLEVR4_1_WithAnnotations_LeftRight
from utils_bnr import set_seed

DEVICE = 'cuda'

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--image_size', type=int, default=128)
parser.add_argument('--image_channels', type=int, default=3)
parser.add_argument('--data_path', default='data/*.png')
parser.add_argument('--perc_imgs', type=float, default=1.0,
                    help='The percent of images which the clf model should receive as training images '
                         '(between 0 and 1). The test set is always the full one (i.e. 1. is 100%)')
parser.add_argument('--log_path', default='../logs/')
parser.add_argument('--checkpoint_path', default='../logs/sysbind_orig_seed0/best_model.pt')
parser.add_argument('--model_type', choices=['retbind', 'sysbind', 'sysbind_hard', 'sysbind_step'],
                    help='Specify whether model type. Either original sysbinder (sysbind) or bind&retrieve (bnr).', default='retbind')
parser.add_argument('--use_dp', default=False, action='store_true')
parser.add_argument('--name', default=datetime.now().strftime('%Y-%m-%d_%H:%M:%S'),
                    help='Name to store the log file as')

# arguments for linear probing
parser.add_argument('--num_categories', type=int, default=3,
                    help='how many categories of attributes')
# parser.add_argument('--clf_label_type', default='individual', choices=['combined', 'individual'],
#                     help='Specify whether the classification labels should consist of the combined attributes or '
#                          'each attribute individually.')
parser.add_argument('--clf_type', default=None, choices=['dt', 'rg'],
                    help='Specify the linear classifier model. Either decision tree (dt) or ridge regression model '
                         '(rg)')

# Sysbinder arguments
parser.add_argument('--lr_dvae', type=float, default=3e-4)
parser.add_argument('--lr_enc', type=float, default=1e-4)
parser.add_argument('--lr_dec', type=float, default=3e-4)
parser.add_argument('--lr_warmup_steps', type=int, default=30000)
parser.add_argument('--lr_half_life', type=int, default=250000)
parser.add_argument('--clip', type=float, default=0.05)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--num_iterations', type=int, default=3)
parser.add_argument('--num_slots', type=int, default=4)
parser.add_argument('--num_blocks', type=int, default=8)
parser.add_argument('--cnn_hidden_size', type=int, default=512)
parser.add_argument('--slot_size', type=int, default=2048)
parser.add_argument('--mlp_hidden_size', type=int, default=192)
parser.add_argument('--num_prototypes', type=int, default=64)
parser.add_argument('--temp', type=float, default=1., help='softmax temperature for prototype binding')
parser.add_argument('--temp_step', default=False, action='store_true')
parser.add_argument('--vocab_size', type=int, default=4096)
parser.add_argument('--num_decoder_layers', type=int, default=8)
parser.add_argument('--num_decoder_heads', type=int, default=4)
parser.add_argument('--d_model', type=int, default=192)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--tau_start', type=float, default=1.0)
parser.add_argument('--tau_final', type=float, default=0.1)
parser.add_argument('--tau_steps', type=int, default=30000)
parser.add_argument('--lr', type=float, default=1e-2, help='Outer learning rate of model')
parser.add_argument('--binarize', default=False, action='store_true',
                    help='Should the encodings of the sysbinder be binarized?')
parser.add_argument('--attention_codes', default=False, action='store_true',
                    help='Should the sysbinder prototype attention values be used as encodings?')

# R&B arguments
parser.add_argument('--retrieval_corpus_path', default='../logs/sysbind_orig_seed0/block_concept_dicts.pkl')
parser.add_argument('--retrieval_encs', default='proto-exem',
                    choices=['proto', 'exem', 'basis', 'proto-exem', 'proto-exem-basis'])
parser.add_argument('--majority_vote', default=True, action='store_true',
                    help='If set then the retrieval binder takes the majority vote of the topk nearest encodings to '
                         'select the final cluster id label')
parser.add_argument('--topk', type=int, default=5,
                    help='if majority_vote is set to True, then retrieval binder selects the topk nearest encodings at '
                         'inference to identify the most likely cluster assignment')
parser.add_argument('--thresh_attn_obj_slots', type=float, default=0.98,
                    help='threshold value for determining the object slots from set of slots, '
                         'based on attention weight values (between 0. and 1.)(see retrievalbinder for usage).'
                         'This should be reestimated for every dataset individually if thresh_count_obj_slots is '
                         'not set to 0.')
parser.add_argument('--thresh_count_obj_slots', type=int, default=-1,
                    help='threshold value (>= -1) for determining the number of object slots from a set of slots, '
                         '-1 indicates we wish to use all slots, i.e. no preselection is made'
                         '0 indicates we just take that slot with the maximum slot attention value,'
                         '1 indicates we take the maximum count of high attn weights (based on thresh_attn_ob_slots), '
                         'otherwise those slots that contain a number of values above thresh_attn_obj_slots are chosen' 
                         '(see retrievalbinder for usage)')

parser.add_argument('--device', default=DEVICE)

args = parser.parse_args(args = [])

set_seed(1)

MODEL_SEED=1
BLOCK_ID_POS = 11
N_TRAIN_BATCHES=40

args.checkpoint_path = f'logs/clevr4_600_epochs/clevr4_sysbind_orig_seed{MODEL_SEED}/best_model.pt'
args.retrieval_corpus_path = f'logs/clevr4_600_epochs/clevr4_sysbind_orig_seed{MODEL_SEED}/block_concept_dicts.pkl'
args.thresh_count_obj_slots = 0
args.num_blocks = 16
args.num_categories = 5 
args.majority_vote = False
args.data_path = '/workspace/datasets-local/CLEVR-4-1/'



In [2]:
# we train the classifier on the original validation set and test on the original test set
train_dataset = CLEVR4_1_WithAnnotations_LeftRight(
    root=args.data_path, phase="val", img_size=args.image_size, max_num_objs=args.num_slots,
    num_categories=args.num_categories, perc_imgs=args.perc_imgs
)
test_dataset = CLEVR4_1_WithAnnotations_LeftRight(
    root=args.data_path, phase="test", img_size=args.image_size, max_num_objs=args.num_slots,
    num_categories=args.num_categories, perc_imgs=1.
)

loader_kwargs = {
    "batch_size": 20,
    "shuffle": False,
    "num_workers": args.num_workers,
    "pin_memory": True,
    "drop_last": True,
}
train_loader = DataLoader(train_dataset, **loader_kwargs)
loader_kwargs = {
    "batch_size": args.batch_size,
    "shuffle": False,
    "num_workers": args.num_workers,
    "pin_memory": True,
    "drop_last": True,
}
test_loader = DataLoader(test_dataset, **loader_kwargs)


In [3]:
def gather_block_encs_and_pos_labels(loader, retbind_model, n_batches=-1, cont=True):

    torch.set_grad_enabled(True)

    if n_batches == -1:
        n_batches = len(loader)
    
    all_labels = []
    all_codes = []
    all_imgs = []
    for i, sample in enumerate(loader):
        
        if i == n_batches:
            break

        img_locs = sample[-1]
        sample = sample[:-1]
        # imgs, _, annotations, _, class_labels, _ = map(lambda x: x.to(args.device), sample)
        imgs, _, annotations, annotations_multihot = map(lambda x: x.to(args.device), sample)

        # encode image with whatever model is being used
        if cont:
            encs = retbind_model.model.encode(imgs)
            # make sure only 1 object is selected
            assert encs[0].shape[1] == 1
            encs_blocked = encs[3][0].squeeze(dim=1)
        else:
            encs = retbind_model.encode(imgs)[0]
            assert encs.shape[1] == 1
            encs_blocked = encs.squeeze(dim=1)
            
        # get position label
        pos_label = annotations[:, :, 4].squeeze(dim=1)

        all_labels.extend(pos_label.detach().cpu().numpy())
        all_codes.extend(encs_blocked.detach().cpu().numpy())
        all_imgs.extend(imgs.detach().cpu().numpy())

    all_labels = np.array(all_labels)
    all_codes = np.array(all_codes)
    all_imgs = np.array(all_imgs)

    return all_codes, all_labels, all_imgs


def comp_pos_acc_per_dt(train_encs, train_labels, test_encs, test_labels):
    clf = DecisionTreeClassifier(random_state=0)
    
    # fit clf on training encodings and labels
    clf.fit(train_encs, train_labels)
    # apply to test encodings
    test_pred = clf.predict(test_encs)
    # compute accuracy per block on test set
    return metrics.balanced_accuracy_score(test_labels, test_pred)

# 1. Identify the relevant position encoding block

### Gather all block-wise codes

In [6]:
# retbind_model = NeuralConceptBinder(args)           # automatically loads the model internally
#                                         # if I want to have "normal" model encodings, I should use the SysBinder...
# retbind_model.to(DEVICE);
# retbind_model.eval();

# train_encs, train_labels = gather_block_encs_and_pos_labels(train_loader, retbind_model, n_batches=-1, cont=True)
# test_encs, test_labels = gather_block_encs_and_pos_labels(test_loader, retbind_model, n_batches=-1, cont=True)

Loading retrieval corpus from logs/clevr4_600_epochs/clevr4_sysbind_orig_seed1/block_concept_dicts.pkl ...
loaded ...logs/clevr4_600_epochs/clevr4_sysbind_orig_seed1/best_model.pt


### Iterate over each block and test if DT can somewhat classify the positions based on the block encodings

In [9]:
# accs_per_block = []
# for block_id in range(args.num_blocks):
#     acc = comp_pos_acc_per_dt(train_encs[:, block_id], train_labels, test_encs[:, block_id], test_labels)
#     accs_per_block.append(acc)

# for block_id in range(args.num_blocks):
#     print(f'{block_id}: {np.round(100*accs_per_block[block_id], 2)}')

# BLOCK_ID_POS = np.argmax(accs_per_block)
# print(f'\nRelevant block id for position: {BLOCK_ID_POS}')

0: 57.940000000000005
1: 75.17
2: 56.88999999999999
3: 64.71000000000001
4: 62.62
5: 61.370000000000005
6: 53.769999999999996
7: 67.89
8: 58.01
9: 59.37
10: 57.24
11: 94.75
12: 63.13999999999999
13: 61.22
14: 58.64
15: 62.29

Relevant block id for position: 11


# 2. Compute Acc of unrevised Retrievalbinder

In [4]:
args.checkpoint_path = f'logs/clevr4_600_epochs/clevr4_sysbind_orig_seed{MODEL_SEED}/best_model.pt'
args.retrieval_corpus_path = f'logs/clevr4_600_epochs/clevr4_sysbind_orig_seed{MODEL_SEED}/block_concept_dicts.pkl'

retbind_model = NeuralConceptBinder(args)           # automatically loads the model internally
                                        # if I want to have "normal" model encodings, I should use the SysBinder...

retbind_model.to(DEVICE);
retbind_model.eval();

Loading retrieval corpus from logs/clevr4_600_epochs/clevr4_sysbind_orig_seed1/block_concept_dicts.pkl ...


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


loaded ...logs/clevr4_600_epochs/clevr4_sysbind_orig_seed1/best_model.pt


In [5]:
train_encs, train_labels, train_imgs = gather_block_encs_and_pos_labels(train_loader, retbind_model, 
                                                            n_batches=-1, cont=False)
train_encs = train_encs[:, BLOCK_ID_POS]

test_encs, test_labels, _ = gather_block_encs_and_pos_labels(test_loader, retbind_model, 
                                                          n_batches=-1, cont=False)
test_encs = test_encs[:, BLOCK_ID_POS]

In [95]:
# # select one example of each cluster
# unique_ids = np.unique(train_encs, return_index=True)[1]
# train_encs = train_encs[unique_ids]
# train_labels = train_labels[unique_ids]
# train_imgs = train_imgs[unique_ids]

In [96]:
# acc = comp_pos_acc_per_dt(np.expand_dims(train_encs, axis=1), train_labels, 
#                           np.expand_dims(test_encs, axis=1), test_labels)
# print(f'Balanced Acc. of unrevised retrieval binder for left-right position classification: {np.round(100*acc, 2)}')

Balanced Acc. of unrevised retrieval binder for left-right position classification: 94.88


In [21]:
acc = np.round(100 * metrics.balanced_accuracy_score(test_labels, test_encs), 2)
print(f'Balanced Acc. of unrevised retrieval binder for left-right position classification: {acc}')

Balanced Acc. of unrevised retrieval binder for left-right position classification: 0.54




# 3. Merge concepts

In [28]:
clf = DecisionTreeClassifier(random_state=0)
    
# fit clf on training encodings and labels
clf.fit(np.expand_dims(train_encs, axis=1), train_labels)

cls_per_cluster = []
for i in range(int(np.max(train_encs))):
    cls_per_cluster.append(clf.predict([[int(i)]]))

cls_per_cluster = np.array(cls_per_cluster).squeeze(1)
cls_per_cluster.shape

ids_0 = np.where(cls_per_cluster == 0)[0]
ids_1 = np.where(cls_per_cluster == 1)[0]

In [36]:
test_encs_merge = []
for test_enc in test_encs:
    if test_enc in ids_0:
        test_encs_merge.append(0)
    elif test_enc in ids_1:
        test_encs_merge.append(1)
    else:
        p = np.random.rand()
        if p > 0.5:
            test_encs_merge.append(1)
        else:
            test_encs_merge.append(0)

In [37]:
acc = np.round(100 * metrics.balanced_accuracy_score(test_labels, test_encs_merge), 2)
print(f'Balanced Acc. of merged retrieval binder for left-right position classification: {acc}')

Balanced Acc. of merged retrieval binder for left-right position classification: 95.5


# 4. Add positional exemplars to corpus at relevant block

In [38]:
def remove_concept(block_corpus, delete_id):
    """
    Removes encodings and corresponding information entirely from the cluster identified via 'delete_id'
    """
    # i.e. 'prototypes', 'exemplars', 'sivm_basis'
    representation_keys = list(block_corpus.keys())
    representation_keys.remove('ids')

    # identify which encodings to keep and which not to keep
    del_ids = np.where(block_corpus['ids'].detach().cpu().numpy() == delete_id)[0]
    keep_ids = np.where(block_corpus['ids'].detach().cpu().numpy() != delete_id)[0]
    # number of individual clusters altogether
    n_clusters = len(np.unique(block_corpus['ids'].detach().cpu().numpy()))

    block_corpus['types'] = [ele for idx, ele in enumerate(block_corpus['types']) if idx in keep_ids]
    block_corpus['encs'] = block_corpus['encs'][keep_ids]
    # finally remove the ids themselves, i.e. keep only relevant ones
    block_corpus['ids'] = block_corpus['ids'][keep_ids]

### First delete all encodings from block corpus

In [39]:
args.checkpoint_path = f'logs/clevr4_600_epochs/clevr4_sysbind_orig_seed{MODEL_SEED}/best_model.pt'
args.retrieval_corpus_path = f'logs/clevr4_600_epochs/clevr4_sysbind_orig_seed{MODEL_SEED}/block_concept_dicts.pkl'

retbind_model_revise = NeuralConceptBinder(args)           # automatically loads the model internally
                                        # if I want to have "normal" model encodings, I should use the SysBinder...

retbind_model_revise.to(DEVICE);
retbind_model_revise.eval();

concept_ids = np.unique(retbind_model_revise.retrieval_corpus[BLOCK_ID_POS]['ids'].detach().cpu().numpy())
for delete_concept_id in concept_ids:
    remove_concept(
        retbind_model_revise.retrieval_corpus[BLOCK_ID_POS],
        delete_id=delete_concept_id,
    )

Loading retrieval corpus from logs/clevr4_600_epochs/clevr4_sysbind_orig_seed1/block_concept_dicts.pkl ...
loaded ...logs/clevr4_600_epochs/clevr4_sysbind_orig_seed1/best_model.pt


### Now extract the block_encodings of users exemplars

In [40]:
fns_right_5 = [
    'CLEVR_4_classid_0_000000.png',
    'CLEVR_4_classid_0_000003.png',
    'CLEVR_4_classid_0_000007.png',
    'CLEVR_4_classid_0_000017.png',
    'CLEVR_4_classid_0_000028.png',
]

fns_left_5 = [
    'CLEVR_4_classid_0_000004.png',
    'CLEVR_4_classid_0_000010.png',
    'CLEVR_4_classid_0_000018.png',
    'CLEVR_4_classid_0_000031.png',
    'CLEVR_4_classid_0_000041.png',
]

fns_right_20 = [
    'CLEVR_4_classid_0_000000.png',
    'CLEVR_4_classid_0_000003.png',
    'CLEVR_4_classid_0_000007.png',
    'CLEVR_4_classid_0_000017.png',
    'CLEVR_4_classid_0_000028.png',
    'CLEVR_4_classid_0_000073.png',
    'CLEVR_4_classid_0_000085.png',
    'CLEVR_4_classid_0_000088.png',
    'CLEVR_4_classid_0_000090.png',
    'CLEVR_4_classid_0_000109.png',
    'CLEVR_4_classid_0_000128.png',
    'CLEVR_4_classid_0_000021.png',
    'CLEVR_4_classid_0_000029.png',
    'CLEVR_4_classid_0_000046.png',
    'CLEVR_4_classid_0_000052.png',
    'CLEVR_4_classid_0_000056.png',
    'CLEVR_4_classid_0_000065.png',
    'CLEVR_4_classid_0_000078.png',
    'CLEVR_4_classid_0_000079.png',
    'CLEVR_4_classid_0_000083.png',
]

fns_left_20 = [
    'CLEVR_4_classid_0_000004.png',
    'CLEVR_4_classid_0_000010.png',
    'CLEVR_4_classid_0_000018.png',
    'CLEVR_4_classid_0_000031.png',
    'CLEVR_4_classid_0_000041.png',
    'CLEVR_4_classid_0_000053.png',
    'CLEVR_4_classid_0_000055.png',
    'CLEVR_4_classid_0_000064.png',
    'CLEVR_4_classid_0_000089.png',
    'CLEVR_4_classid_0_000094.png',
    'CLEVR_4_classid_0_000103.png',
    'CLEVR_4_classid_0_000041.png',
    'CLEVR_4_classid_0_000044.png',
    'CLEVR_4_classid_0_000049.png',
    'CLEVR_4_classid_0_000058.png',
    'CLEVR_4_classid_0_000059.png',
    'CLEVR_4_classid_0_000116.png',
    'CLEVR_4_classid_0_000160.png',
    'CLEVR_4_classid_0_000149.png',
    'CLEVR_4_classid_0_000240.png',
]


revision_sample_pths_right_5 = []
for sample_fn in fns_right_5:
    revision_sample_pths_right_5.append(f'{args.data_path}train/images/{sample_fn}')
    
revision_sample_pths_left_5 = []
for sample_fn in fns_left_5:
    revision_sample_pths_left_5.append(f'{args.data_path}train/images/{sample_fn}')

revision_sample_pths_right_20 = []
for sample_fn in fns_right_20:
    revision_sample_pths_right_20.append(f'{args.data_path}train/images/{sample_fn}')
    
revision_sample_pths_left_20 = []
for sample_fn in fns_left_20:
    revision_sample_pths_left_20.append(f'{args.data_path}train/images/{sample_fn}')

In [43]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class CustomImageDataset(Dataset):
    def __init__(self, img_pths, img_size):
        self.img_pths = img_pths
        self.img_size = img_size
        self.transform = transforms.ToTensor()

    def __len__(self):
        return len(self.img_pths)

    def __getitem__(self, idx):
        img_path = self.img_pths[idx]
        image = Image.open(img_path).convert("RGB")
        image = image.resize((self.img_size, self.img_size))
        image = self.transform(image)
        return image

# dataset_right_5 = CustomImageDataset(revision_sample_pths_right_5, args.image_size)
# dataset_left_5 = CustomImageDataset(revision_sample_pths_left_5, args.image_size)

dataset_right_20 = CustomImageDataset(revision_sample_pths_right_20, args.image_size)
dataset_left_20 = CustomImageDataset(revision_sample_pths_left_20, args.image_size)

loader_kwargs = {
    "batch_size": 1,
    "shuffle": False,
    "num_workers": args.num_workers,
    "pin_memory": True,
    "drop_last": True,
}

# revision_loader_right_5 = DataLoader(dataset_right_5, **loader_kwargs)
# revision_loader_left_5 = DataLoader(dataset_left_5, **loader_kwargs)

revision_loader_right_20 = DataLoader(dataset_right_20, **loader_kwargs)
revision_loader_left_20 = DataLoader(dataset_left_20, **loader_kwargs)

In [44]:
def encode_dataset_and_add_to_corpus(model, loader, rel_block_id, new_id, args):
    # iterate over revision samples and encode them
    for i, imgs in enumerate(loader):
        imgs = imgs.to(args.device)

        # encode image 
        encs = model.model.encode(imgs)
        encs_blocked = encs[3][0][:,:,rel_block_id].squeeze(dim=0)
        # add encoding and corresponding information to retrieval corpus
        model.retrieval_corpus[rel_block_id]['types'].append('exemplar')
        model.retrieval_corpus[rel_block_id]['encs'] = torch.cat(
            (model.retrieval_corpus[rel_block_id]['encs'], encs_blocked), 0
        )
        model.retrieval_corpus[rel_block_id]['ids'] = torch.cat(
            (model.retrieval_corpus[rel_block_id]['ids'], torch.tensor([new_id], device=args.device)), 0
        )

# encode_dataset_and_add_to_corpus(retbind_model_revise, revision_loader_right_5, BLOCK_ID_POS, 1, args)
# encode_dataset_and_add_to_corpus(retbind_model_revise, revision_loader_left_5, BLOCK_ID_POS, 0, args)
        
encode_dataset_and_add_to_corpus(retbind_model_revise, revision_loader_right_20, BLOCK_ID_POS, 1, args)
encode_dataset_and_add_to_corpus(retbind_model_revise, revision_loader_left_20, BLOCK_ID_POS, 0, args)
# retbind_model_revise.retrieval_corpus[BLOCK_ID_POS];

# 4. Compute Acc of revised Retrievalbinder

In [45]:
test_encs_revise, test_labels_revise, _ = gather_block_encs_and_pos_labels(test_loader, retbind_model_revise, 
                                                                        n_batches=-1, cont=False)
test_encs_revise = test_encs_revise[:, BLOCK_ID_POS]

In [106]:
# # for each sample used ro ret_binder compute new encoding
# train_encs_revise = []
# for img in train_imgs:
#     encs = retbind_model_revise.encode(torch.tensor(img).unsqueeze(dim=0))[0]
#     train_encs_revise.append(encs.squeeze(dim=1).detach().cpu().numpy())

# train_encs_revise = np.array(train_encs_revise).squeeze(axis=1)
# train_labels_revise = train_labels

In [47]:
# acc_revise = comp_pos_acc_per_dt(np.expand_dims(train_encs_revise[:, BLOCK_ID_POS], axis=1), train_labels_revise, 
#                           np.expand_dims(test_encs_revise[:, BLOCK_ID_POS], axis=1), test_labels_revise)
# print(f'Balanced Acc. of revised retrieval binder for left-right position classification: {np.round(100*acc_revise, 2)}')

### 20 examples:

In [46]:
acc = np.round(100 * metrics.balanced_accuracy_score(test_labels_revise, test_encs_revise), 2)
print(f'Balanced Acc. of revised retrieval binder for left-right position classification: {acc}')

Balanced Acc. of revised retrieval binder for left-right position classification: 93.97


### 5 examples:

In [12]:
acc = np.round(100 * metrics.balanced_accuracy_score(test_labels_revise, test_encs_revise), 2)
print(f'Balanced Acc. of revised retrieval binder for left-right position classification: {acc}')

Balanced Acc. of revised retrieval binder for left-right position classification: 88.71
