In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import time
import pickle
import matplotlib.cm as cm
import seaborn as sns
from torchvision import transforms

from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

import sys
sys.path.append("../ProtoLearning/")
from models.icsn import iCSN

In [2]:
def load_pretrained(model, ckpt):
    """Restore a model from a checkpoint dictionary and return it.

    The weights are read from ``ckpt['model']``; the prototype dictionary and
    the softmax temperature are read from ``ckpt['model_misc']`` and attached
    to the model as ``proto_dict`` and ``softmax_temp``. The model is modified
    in place and the same object is returned.
    """
    model.load_state_dict(ckpt['model'])

    # (attribute on the model, key inside ckpt['model_misc'])
    extra_fields = (
        ('proto_dict', 'prototypes'),
        ('softmax_temp', 'softmax_temp'),
    )
    for attr_name, misc_key in extra_fields:
        setattr(model, attr_name, ckpt['model_misc'][misc_key])

    return model


def plot_single_img(model, imgs, idx):
    """Plot one sample together with its predicted code and reconstruction.

    Puts the model in eval mode, runs ``model.forward_single(imgs)`` and draws
    a three-row figure for sample ``idx``: the input image, the predicted code
    as a one-row grayscale strip, and the reconstruction rescaled to 0..255.

    Args:
        model: object exposing ``eval()``, ``forward_single(imgs)``,
            ``attr_positions`` and ``n_groups``.
        imgs: image batch; it is permuted with (0, 2, 3, 1) before plotting,
            so it is presumably channels-first (N, C, H, W) -- TODO confirm.
        idx: position of the sample inside the batch.

    Returns:
        ``(fig, preds_as_ids)`` where ``preds_as_ids`` contains, for every
        attribute group, the argmax position within that group's slice of the
        prediction vector (slice bounds come from ``model.attr_positions``).
    """
    model.eval()
    preds, recons = model.forward_single(imgs)

    # channels-last copies on the CPU, detached from autograd, for imshow
    channels_last = (0, 2, 3, 1)
    recons = recons.permute(*channels_last).detach().cpu()
    imgs = imgs.permute(*channels_last).detach().cpu()

    # prediction vector of the selected sample
    preds = preds.unsqueeze(dim=0).detach().cpu()
    preds = preds[0, idx]

    # per attribute group: which entry of the group's slice is the largest
    preds_as_ids = []
    for group in range(model.n_groups):
        start = model.attr_positions[group]
        stop = model.attr_positions[group + 1]
        preds_as_ids.append(torch.argmax(preds[start:stop]).numpy())

    # reconstruction as numpy array, min-max rescaled to 0..255 for plotting
    recons_np = recons[idx].squeeze().numpy()
    shifted = recons_np - recons_np.min()
    scale = 1 / (recons_np.max() - recons_np.min()) * 255
    recons_np = (shifted * scale).astype('uint8')

    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(15, 10))
    input_ax, code_ax, recon_ax = axes[0], axes[1], axes[2]

    input_ax.imshow(imgs[idx])
    input_ax.axis('off')
    input_ax.set_title("Input Img")

    code_ax.imshow(preds.unsqueeze(dim=0), cmap=cm.gray)
    code_ax.axes.yaxis.set_visible(False)
    code_ax.set_title("I believe it has these properties")

    recon_ax.imshow(recons_np)
    recon_ax.axis('off')
    recon_ax.set_title("Because it is close to these composed prototypes:")

    plt.show()

    return fig, preds_as_ids


### Create a dictionary that maps a multi-label (one index per attribute group) to a single class id

In [3]:
# Map every multi-label combination to a unique integer class id.
# A key is the four group indices concatenated into a string; the first two
# groups take 4 values each and the last two take 2 values each, which gives
# 4 * 4 * 2 * 2 = 64 ids numbered in nested-loop order (last group fastest).
# A comprehension is used so that no counter named `id` is left in the
# notebook namespace, where it would shadow the builtin `id()`.
convert_multilabel_to_label_id = {
    key: label_id
    for label_id, key in enumerate(
        f'{a}{b}{c}{d}'
        for a in range(4)
        for b in range(4)
        for c in range(2)
        for d in range(2)
    )
}

### Load the training and validation sets (containing old and new objects) for linear probing

In [5]:
# Load the probing-train and validation splits, normalise the images, and
# derive one integer class id per sample from its four-column multi-label.
# NOTE(review): both np.load(allow_pickle=True) and pickle.load execute
# arbitrary code from the file -- only use with trusted data files.

# ---- train (probing) data set ----
train_probing_data_path = f"../Data/ECR/train_probing/train_probing_ecr_spot.npy"
train_probing_labels_path = f"../Data/ECR/train_probing/train_probing_ecr_spot_labels.pkl"

train_probing_imgs = np.load(train_probing_data_path, allow_pickle=True)
# min-max normalise to [0, 1] using the global min/max of the whole array
train_probing_imgs = (train_probing_imgs - train_probing_imgs.min()) / (train_probing_imgs.max() - train_probing_imgs.min())

with open(train_probing_labels_path, 'rb') as f:
    labels_dict = pickle.load(f)
    train_probing_labels = labels_dict['labels']
  
# move the last axis to position 1 (presumably NHWC -> NCHW -- TODO confirm)
train_probing_imgs = torch.tensor(np.moveaxis(train_probing_imgs, (0, 1, 2, 3), (0, 2, 3, 1)))
train_probing_imgs = train_probing_imgs.type('torch.FloatTensor')
train_probing_labels = torch.tensor(train_probing_labels)    

# cast labels to integers so that `.item()` yields plain ints for the
# string keys built further down
train_probing_labels = train_probing_labels.int()

# one batch holds the entire split; shuffle=True permutes the sample order on
# every run (images and labels stay paired through the TensorDataset)
train_probing_dataset = torch.utils.data.TensorDataset(train_probing_imgs, train_probing_labels)
train_probing_dataloader = torch.utils.data.DataLoader(train_probing_dataset, batch_size=len(train_probing_dataset),
                                                  shuffle=True)

# ---- validation data set (same steps as above) ----
val_data_path = f"../Data/ECR/val_ecr_spot.npy"
val_labels_path = f"../Data/ECR/val_ecr_spot_labels.pkl"

val_imgs = np.load(val_data_path, allow_pickle=True)
# min-max normalise to [0, 1] using the global min/max of the whole array
val_imgs = (val_imgs - val_imgs.min()) / (val_imgs.max() - val_imgs.min())

with open(val_labels_path, 'rb') as f:
    labels_dict = pickle.load(f)
    val_labels = labels_dict['labels']
  
# move the last axis to position 1 (presumably NHWC -> NCHW -- TODO confirm)
val_imgs = torch.tensor(np.moveaxis(val_imgs, (0, 1, 2, 3), (0, 2, 3, 1)))
val_imgs = val_imgs.type('torch.FloatTensor')
val_labels = torch.tensor(val_labels)    

# cast labels to integers (see the train split above)
val_labels = val_labels.int()

val_dataset = torch.utils.data.TensorDataset(val_imgs, val_labels)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=len(val_dataset),
                                                  shuffle=True)

# pull the single full-size batch from each loader; note that this rebinds
# train_probing_imgs / val_imgs to the shuffled tensors, and the *_label_ids
# returned here are in that same shuffled order
train_probing_imgs, train_probing_label_ids = next(iter(train_probing_dataloader))
val_imgs, val_label_ids = next(iter(val_dataloader))

# convert each four-column multi-label into a single class id by
# concatenating the four integers into the lookup key of
# convert_multilabel_to_label_id
train_probing_single_labels = torch.tensor([convert_multilabel_to_label_id[
    f'{train_probing_label_ids[i][0].item()}'+
    f'{train_probing_label_ids[i][1].item()}'+
    f'{train_probing_label_ids[i][2].item()}'+
    f'{train_probing_label_ids[i][3].item()}'
] for i in range(train_probing_label_ids.shape[0])])
val_single_labels = torch.tensor([convert_multilabel_to_label_id[
    f'{val_label_ids[i][0].item()}'+
    f'{val_label_ids[i][1].item()}'+
    f'{val_label_ids[i][2].item()}'+
    f'{val_label_ids[i][3].item()}'
] for i in range(val_label_ids.shape[0])])

### Variance calculation for Proto-Swap-AE

In [6]:
def compute_code_and_variance(model, imgs, gt_labels_ids, config,
                              n_categories=3, n_concepts=4):
    """Encode ``imgs`` and measure how much the codes vary within each GT concept.

    For every (category, concept) pair the samples whose ground-truth label in
    that category equals the concept are selected, the part of the code that
    belongs to the category is extracted, and the per-dimension variance
    (``torch.var`` over the sample axis) is summed. Pairs whose variance is
    NaN (e.g. no or only one matching sample) are left out of the statistics.

    Args:
        model: model providing ``forward_single(imgs) -> (codes, _)``; if that
            call fails, ``forward(imgs) -> (_, codes)`` is used instead and
            the codes are squeezed along dim 2.
        imgs: input batch passed to the model unchanged.
        gt_labels_ids: ground-truth labels, indexed as
            ``gt_labels_ids.T[cat_id]`` (assumed shape (n_samples,
            n_label_columns) -- TODO confirm against caller).
        config: mapping; if it has ``'prototype_cumsum'``, consecutive entries
            give the column range of each category inside the code. Otherwise
            column ``cat_id`` of the code is used.
        n_categories: number of label categories to evaluate (default 3, the
            value that used to be hard-coded).
        n_concepts: number of concept ids checked per category (default 4, the
            value that used to be hard-coded).

    Returns:
        dict with
          * ``'codes'``: the codes returned by the model,
          * ``'collect_codes'``: ``"cat-concept"`` -> selected codes (CPU),
          * ``'variances'``: ``"cat-concept"`` -> summed variance (non-NaN only),
          * ``'avg_variance'``: mean over ``'variances'``; a NaN tensor when
            no pair produced a finite variance.
    """
    # Pure inference: no autograd graph is needed for the statistics below.
    with torch.no_grad():
        try:
            codes, _ = model.forward_single(imgs)
        except Exception:
            # Best-effort fallback for models without a usable forward_single.
            # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
            _, codes = model.forward(imgs)
            codes = codes.squeeze(dim=2)

    collect_codes = {}
    collect_variance = {}

    for cat_id in range(n_categories):
        cat_labels = gt_labels_ids.T[cat_id]

        for concept_id in range(n_concepts):
            key = f"{cat_id}-{concept_id}"

            # samples that carry this concept in this category
            rel_ids = torch.where(cat_labels == concept_id)[0]

            # slice of the code that belongs to this category
            try:
                bounds = config['prototype_cumsum']
                rel_codes = codes[rel_ids, bounds[cat_id]:bounds[cat_id + 1]]
            except Exception:
                # no usable 'prototype_cumsum': one code column per category
                rel_codes = codes[rel_ids, cat_id]

            rel_codes = rel_codes.detach().cpu()
            collect_codes[key] = rel_codes

            # NaN arises when torch.var has too few samples; skip those pairs
            var = torch.sum(torch.var(rel_codes, dim=0))
            if not torch.isnan(var):
                collect_variance[key] = var

    if collect_variance:
        sum_variance = sum(collect_variance.values()) / len(collect_variance)
    else:
        # previously a ZeroDivisionError; report "undefined" instead
        sum_variance = torch.tensor(float('nan'))

    print(sum_variance)

    return {'codes': codes, 'collect_codes': collect_codes,
            'variances': collect_variance, 'avg_variance': sum_variance}

# Load the checkpoints of the models trained without the novel category
# (runs "icsn-rr-<seed>-ecr-extramlp") and compute codes / code variances on
# the probing-train and validation images for every seed.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
icsn_prior_results_all = {}
for model_id in [0, 1, 3, 13, 21]:    
    ckpt_fp = f"../ProtoLearning/runs/icsn-rr-{model_id}-ecr-extramlp/states/00999.pth"
    print(f"Loading model {ckpt_fp}")
    ckpt = torch.load(ckpt_fp, map_location=torch.device('cpu'))
    # reuse the training config stored in the checkpoint, redirected to this
    # machine (CPU, local data directory)
    config = ckpt['config']
    config['device'] = 'cpu'
    config['data_dir'] = '../Data/ECR/'

    # architecture sizes are hard-coded here; everything else comes from config
    icsn_prior_model = iCSN(num_hiddens=64, num_residual_layers=2, num_residual_hiddens=64,
                    n_proto_vecs=config['prototype_vectors'], enc_size=config['enc_size'],
                    proto_dim=config['proto_dim'], softmax_temp=config['temperature'],
                    extra_mlp_dim=config['extra_mlp_dim'],
                    multiheads=config['multiheads'], train_protos=config['train_protos'],
                    device=config['device'])

    icsn_prior_model = icsn_prior_model.to(config['device'])
    icsn_prior_model = load_pretrained(icsn_prior_model, ckpt)
    # NOTE(review): the constructor argument and load_pretrained both use the
    # name `softmax_temp`; whether iCSN reads a `.temperature` attribute is not
    # visible here -- confirm this assignment actually affects the model.
    icsn_prior_model.temperature = 0.000001
    icsn_prior_model.eval();
        
    train_result_dict = compute_code_and_variance(model=icsn_prior_model, imgs=train_probing_imgs, 
                                                  gt_labels_ids=train_probing_label_ids, config=config)
    val_result_dict = compute_code_and_variance(model=icsn_prior_model, imgs=val_imgs, 
                                                  gt_labels_ids=val_label_ids, config=config)

    # results are keyed by seed id, then by split
    icsn_prior_results_all[model_id] = {'train': train_result_dict, 'val': val_result_dict}

# Same procedure for the models trained with the novel category
# (runs "icsn-<seed>-ecr-6666", checkpoint 01999). Apart from the checkpoint
# path and the variable names this loop is identical to the one above.
icsn_post_results_all = {}
for model_id in [0, 1, 3, 13, 21]:    
    ckpt_fp = f"../ProtoLearning/runs/icsn-{model_id}-ecr-6666/states/01999.pth"
    print(f"Loading model {ckpt_fp}")
    ckpt = torch.load(ckpt_fp, map_location=torch.device('cpu'))
    config = ckpt['config']
    config['device'] = 'cpu'
    config['data_dir'] = '../Data/ECR/'

    icsn_post_model = iCSN(num_hiddens=64, num_residual_layers=2, num_residual_hiddens=64,
                    n_proto_vecs=config['prototype_vectors'], enc_size=config['enc_size'],
                    proto_dim=config['proto_dim'], softmax_temp=config['temperature'],
                    extra_mlp_dim=config['extra_mlp_dim'],
                    multiheads=config['multiheads'], train_protos=config['train_protos'],
                    device=config['device'])

    icsn_post_model = icsn_post_model.to(config['device'])
    icsn_post_model = load_pretrained(icsn_post_model, ckpt)
    # NOTE(review): see the remark on `.temperature` vs `softmax_temp` above
    icsn_post_model.temperature = 0.000001
    icsn_post_model.eval();
        
    train_result_dict = compute_code_and_variance(model=icsn_post_model, imgs=train_probing_imgs, 
                                                  gt_labels_ids=train_probing_label_ids, config=config)
    val_result_dict = compute_code_and_variance(model=icsn_post_model, imgs=val_imgs, 
                                                  gt_labels_ids=val_label_ids, config=config)

    icsn_post_results_all[model_id] = {'train': train_result_dict, 'val': val_result_dict}


Loading model ../WeakAEProtoLearning/runs/ae-swap-rr-0-simpleshapescolorvarshapesizenospotpairsmult-trainprotos-nopretrain-666-extramlp/states/00999.pth
tensor(0.0527)
tensor(0.0470)
Loading model ../WeakAEProtoLearning/runs/ae-swap-rr-1-simpleshapescolorvarshapesizenospotpairsmult-trainprotos-nopretrain-666-extramlp/states/00999.pth
tensor(0.0008)
tensor(0.)
Loading model ../WeakAEProtoLearning/runs/ae-swap-rr-3-simpleshapescolorvarshapesizenospotpairsmult-trainprotos-nopretrain-666-extramlp/states/00999.pth
tensor(0.0251)
tensor(0.0215)
Loading model ../WeakAEProtoLearning/runs/ae-swap-rr-13-simpleshapescolorvarshapesizenospotpairsmult-trainprotos-nopretrain-666-extramlp/states/00999.pth
tensor(0.)
tensor(0.0004)
Loading model ../WeakAEProtoLearning/runs/ae-swap-rr-21-simpleshapescolorvarshapesizenospotpairsmult-trainprotos-nopretrain-666-extramlp/states/00999.pth
tensor(0.0155)
tensor(0.0138)
Loading model ../WeakAEProtoLearning/runs/ae-swap-0-simpleshapescolorvarshapesizespotpairsm

In [9]:
def fit_and_predict_lr_dt(results_dict, gt_single_labels_val, model_seed_id, verbose=0,
                          gt_single_labels_train=None):
    """Probe the codes of one model seed with a decision tree and logistic regression.

    Both classifiers are fitted on the training codes stored under
    ``results_dict[model_seed_id]['train']['codes']`` and evaluated on the
    codes under ``results_dict[model_seed_id]['val']['codes']``.

    Args:
        results_dict: mapping seed id -> ``{'train': {...}, 'val': {...}}``
            where each split dict has a ``'codes'`` tensor.
        gt_single_labels_val: tensor with one class id per validation sample.
        model_seed_id: key of the model inside ``results_dict``.
        verbose: if > 0, print the two accuracies.
        gt_single_labels_train: tensor with one class id per training sample.
            When omitted, the notebook-level ``train_probing_single_labels``
            is used, which is what this function always did before.

    Returns:
        ``(accuracy_dt, accuracy_lr)``: validation accuracies in percent.
    """
    if gt_single_labels_train is None:
        # backward-compatible default: fall back to the notebook global
        gt_single_labels_train = train_probing_single_labels

    def _as_numpy(tensor):
        # codes may still be attached to an autograd graph / another device
        return tensor.detach().cpu().numpy()

    x_train = _as_numpy(results_dict[model_seed_id]['train']['codes'])
    x_val = _as_numpy(results_dict[model_seed_id]['val']['codes'])
    y_train = _as_numpy(gt_single_labels_train)
    y_val = _as_numpy(gt_single_labels_val)

    # decision tree
    clf_icsn_dt = DecisionTreeClassifier(random_state=21, max_depth=8)
    clf_icsn_dt.fit(x_train, y_train)

    # logistic regression
    clf_icsn_lr = LogisticRegression(random_state=0, C=0.316, max_iter=1000)
    clf_icsn_lr.fit(x_train, y_train)

    # evaluate both classifiers on the validation codes
    predictions_dt = clf_icsn_dt.predict(x_val)
    predictions_lr = clf_icsn_lr.predict(x_val)

    # The mean of a boolean array is already the hit rate. The former
    # `.astype(np.float)` breaks on NumPy >= 1.24, where the alias was removed.
    accuracy_dt = np.mean(y_val == predictions_dt) * 100.
    accuracy_lr = np.mean(y_val == predictions_lr) * 100.

    if verbose > 0:
        print(f"\nSeed {model_seed_id} Val accuracy DT codes = {accuracy_dt:.3f}")
        print(f"Seed {model_seed_id} Val accuracy LR codes = {accuracy_lr:.3f}")
    return accuracy_dt, accuracy_lr

# --- linear probing on the codes of the "prior" models, one run per seed ---
print('------------------------------------------------------')
print('iCSN with prior novel category')
acc_icsn_prior_dt, acc_icsn_prior_lr, code_vars_icsn_prior_val = [], [], []
for seed_id in [0, 1, 3, 13, 21]:
    accuracy_dt, accuracy_lr = fit_and_predict_lr_dt(
        results_dict=icsn_prior_results_all,
        gt_single_labels_val=val_single_labels,
        model_seed_id=seed_id,
    )
    acc_icsn_prior_dt += [accuracy_dt]
    acc_icsn_prior_lr += [accuracy_lr]
    # validation code variance of the same seed, kept alongside the accuracies
    code_vars_icsn_prior_val += [icsn_prior_results_all[seed_id]['val']['avg_variance']]

# --- same for the "post" models ---
print('------------------------------------------------------')
print('iCSN with post novel category')
acc_icsn_post_dt, acc_icsn_post_lr, code_vars_icsn_post_val = [], [], []
for seed_id in [0, 1, 3, 13, 21]:
    accuracy_dt, accuracy_lr = fit_and_predict_lr_dt(
        results_dict=icsn_post_results_all,
        gt_single_labels_val=val_single_labels,
        model_seed_id=seed_id,
    )
    acc_icsn_post_dt += [accuracy_dt]
    acc_icsn_post_lr += [accuracy_lr]
    code_vars_icsn_post_val += [icsn_post_results_all[seed_id]['val']['avg_variance']]

# --- summary over seeds: mean followed by standard deviation ---
print('------------------------------------------------------')
print('------------------------Mean--------------------------')
print('------------------------------------------------------')
print(
    f"Mean acc. Prosa prior\nDT: {np.mean(acc_icsn_prior_dt)} {np.std(acc_icsn_prior_dt)}"
    f"\nLR:  {np.mean(acc_icsn_prior_lr)} {np.std(acc_icsn_prior_lr)}\n"
)
print(
    f"Mean acc. Prosa post\nDT: {np.mean(acc_icsn_post_dt)} {np.std(acc_icsn_post_dt)}"
    f"\nLR:  {np.mean(acc_icsn_post_lr)} {np.std(acc_icsn_post_lr)}\n"
)

# --- summary over seeds: median, again followed by the standard deviation ---
print('------------------------------------------------------')
print('----------------------Median--------------------------')
print('------------------------------------------------------')
print(
    f"Median acc. Prosa prior\nDT: {np.median(acc_icsn_prior_dt)} {np.std(acc_icsn_prior_dt)}"
    f"\nLR:  {np.median(acc_icsn_prior_lr)} {np.std(acc_icsn_prior_lr)}\n"
)
print(
    f"Median acc. Prosa post\nDT: {np.median(acc_icsn_post_dt)} {np.std(acc_icsn_post_dt)}"
    f"\nLR:  {np.median(acc_icsn_post_lr)} {np.std(acc_icsn_post_lr)}\n"
)


------------------------------------------------------
ProSA with prior novel category
------------------------------------------------------
ProSA with post novel category
------------------------------------------------------
------------------------Mean--------------------------
------------------------------------------------------
Mean acc. Prosa prior
DT: 93.1 4.459147900664428
LR:  67.54 9.071901674952175

Mean acc. Prosa post
DT: 99.85 0.3
LR:  98.28999999999999 3.4199999999999986

------------------------------------------------------
----------------------Median--------------------------
------------------------------------------------------
Median acc. Prosa prior
DT: 95.45 4.459147900664428
LR:  65.7 9.071901674952175

Median acc. Prosa post
DT: 100.0 0.3
LR:  100.0 3.4199999999999986

