In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
def set_cuda_device(gpu_num: int):
    """Point the CUDA environment variables at a single GPU.

    Sets ``CUDA_DEVICE_ORDER`` to ``"PCI_BUS_ID"`` and ``CUDA_VISIBLE_DEVICES``
    to the string form of ``gpu_num``. Nothing else is touched.

    Args:
        gpu_num: GPU index to expose to this process.
    """
    from os import environ

    environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": str(gpu_num),
    })


# Expose GPU index 1 to the rest of the notebook.
set_cuda_device(1)

In [None]:
from pathlib import Path
# Parent of the current working directory (presumably the repository root — confirm).
root_dir = Path("../")
this_dir = root_dir / "Experiments"
import sys
# Put root_dir first on the import path so the project-local imports in the next cell resolve.
sys.path.insert(0, str(root_dir.absolute()))

In [None]:
import constants
from src.abstract.abs_nn_theta import NNthHelper
from src.nn_theta import ResNETNNthHepler
import src.main_helper as main_helper
from src.abstract.abs_data import DataHelper
import torch
import pandas as pd
from collections import defaultdict
import utils.common_utils as cu
#from plot_utils import plot_beta
from copy import deepcopy
import matplotlib.pyplot as plt

# Dataset


In [None]:
# Dataset identifier and the data helper that every later cell reads from.
dataset_name = constants.SHAPENET_NOISE_SMALL
dh = main_helper.get_data_helper(
    logger=None,
    dataset_name=dataset_name,
)

# NNth


In [None]:
# Settings shared by every fit_theta call in this notebook.
nnth_type = constants.RESNET
nnth_name = "final-baseline-noise-small"
nnth_epochs = 1
# NOTE(review): fit=False presumably loads saved weights instead of training — confirm in main_helper.
fit_th = False
data_subset = constants.FULL_DATA
logger = None
nnth_args = {constants.BATCH_SIZE: 32}

nnth_mh = main_helper.fit_theta(
    nn_theta_type=nnth_type,
    models_defname=nnth_name,
    dh=dh,
    nnth_epochs=nnth_epochs,
    fit=fit_th,
    data_subset=data_subset,
    logger=logger,
    **nnth_args,
)

In [None]:
# A second helper built with exactly the same arguments as nnth_mh; it is a separate object.
baseline_nnth = main_helper.fit_theta(
    nn_theta_type=nnth_type,
    models_defname=nnth_name,
    dh=dh,
    nnth_epochs=nnth_epochs,
    fit=fit_th,
    data_subset=data_subset,
    logger=logger,
    **nnth_args,
)

In [None]:
# Same settings as the baseline, but under the "greedy-final-greedy-noise-small" model name.
th100 = main_helper.fit_theta(
    nn_theta_type=nnth_type,
    models_defname="greedy-final-greedy-noise-small",
    dh=dh,
    nnth_epochs=nnth_epochs,
    fit=fit_th,
    data_subset=data_subset,
    logger=logger,
    **nnth_args,
)

## Greedy

In [None]:
# Greedy recourse helper built on top of nnth_mh (fit=False, init_theta=False).
greedy_name = "final-greedy-noise-small"
greedy_r = main_helper.fit_greedy(
    dataset_name=dataset_name,
    nnth=nnth_mh,
    load_th=True,
    dh=dh,
    budget=1000,
    num_badex=100,
    R_per_iter=10,
    models_defname=greedy_name,
    fit=False,
    init_theta=False,
    logger=logger,
)

In [None]:
# Predicted test labels of both theta models, both read through the baseline's test loader.
test_predlabels_baseline = baseline_nnth.predict_labels(
    loader=baseline_nnth._tst_loader,
)
test_predlabels_th100 = th100.predict_labels(
    loader=baseline_nnth._tst_loader,
)
# Per-example test losses of th100, read through th100's own test loader.
tst_pred_losses_th100 = th100.get_loaderlosses_perex(
    loader=th100._tst_loader,
)

In [None]:
# Last expression of the cell: whatever th100.beta_accuracy() returns is shown as the cell output.
th100.beta_accuracy()

## NNPhi

In [None]:
# Phi network helper, built from the greedy recourse result (fit=False).
phi_name = None
nnphi = main_helper.fit_nnphi(
    dataset_name=dataset_name,
    dh=dh,
    epochs=10,
    greedy_rec=greedy_r,
    models_defname=phi_name,
    fit=False,
    logger=None,
)

## NNPsi

In [None]:
# Psi network helper with targets constants.R_WRONG and architecture [32, 8] (fit=False).
psi_name = None
nnpsi = main_helper.fit_nnpsi(
    dataset_name=dataset_name,
    dh=dh,
    psi_tgts=constants.R_WRONG,
    nn_arch=[32, 8],
    synR=greedy_r,
    epochs=10,
    models_defname=psi_name,
    fit=False,
    logger=None,
)

## Ourm Improved

In [None]:
# Pretraining flags handed to the sequential helper: theta on, phi and psi off.
ourm_args_imporved = {
    constants.PRETRN_THPSIPSI: {
        constants.THETA: True,
        constants.PHI: False,
        constants.PSI: False,
    },
}

ourm_hlpr_improved = main_helper.get_ourm_hlpr(
    ourm_type=constants.SEQUENTIAL,
    dh=dh,
    nnth=nnth_mh,
    nnphi=nnphi,
    nnpsi=nnpsi,
    greedy_r=greedy_r,
    filter_grps=2500,
    logger=None,
    **ourm_args_imporved,
)
ourm_hlpr_improved.load_model_defname(suffix="--small-notheta-epoch-40")

# Labels from the helper's theta network, read through the baseline's test loader.
test_predlabels_thphi = ourm_hlpr_improved._nnth.predict_labels(
    loader=baseline_nnth._tst_loader,
)

# Betas predicted by the helper's phi network on its own test loader.
tst_predbetas_improved = ourm_hlpr_improved._nnphi.predict_beta(
    loader=ourm_hlpr_improved._nnphi._tst_loader,
)

# Globals read later by pred_after_recourse / pred_after_constant_recourse / rec_hist.
tst_betas = ourm_hlpr_improved._dh._test._Beta
tst_labels = dh._test._y

# Map the string form of each of the first 9 test betas to its position (0-8).
# NOTE(review): assumes test rows come in contiguous groups of 9 betas — confirm against the data helper.
beta_to_idx = {}
for i, b in enumerate(tst_betas[0:9]):
    beta_to_idx[str(b.tolist())] = i

# Cell output: the distinct predicted betas and how often each occurs.
torch.unique(tst_predbetas_improved, dim=0, return_counts=True)

## Ourm Vanilla

In [None]:
# Same pretraining flags as the improved variant: theta on, phi and psi off.
ourm_args_vannila = {
    constants.PRETRN_THPSIPSI: {
        constants.THETA: True,
        constants.PHI: False,
        constants.PSI: False,
    },
}

ourm_hlpr_vannila = main_helper.get_ourm_hlpr(
    ourm_type=constants.SEQUENTIAL,
    dh=dh,
    nnth=nnth_mh,
    nnphi=nnphi,
    nnpsi=nnpsi,
    greedy_r=greedy_r,
    filter_grps=2500,
    logger=None,
    **ourm_args_vannila,
)
ourm_hlpr_vannila.load_model_defname(suffix="--th_phi-vanilla-shapenet-noise-small")

# Betas predicted by the vanilla helper's phi network on its own test loader.
tst_predbetas_vannila = ourm_hlpr_vannila._nnphi.predict_beta(
    loader=ourm_hlpr_vannila._nnphi._tst_loader,
)


## Score based Triage

In [None]:
# Class probabilities of the baseline model on an unshuffled test loader.
score_loader = dh._test.get_theta_loader(batch_size=128, shuffle=False)
pred_probs_score = baseline_nnth.predict_proba(loader=score_loader)

# Highest class probability per example; the arg-max indices are not needed here.
test_max_predprob, _ = pred_probs_score.max(dim=1)

# Ascending sort: examples with the lowest top probability come first.
ind_order_score_baseline = test_max_predprob.argsort()
ind_order_score = ind_order_score_baseline

## Full Automation Triage

In [None]:
from torchvision import models
# Use the GPU when there is one, otherwise fall back to CPU so the cell still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The checkpoint holds a whole pickled model (it is used directly as a module below),
# so no architecture needs to be built beforehand. The previous code built a pretrained
# resnet18 here and immediately overwrote it with this load.
model_ft = torch.load(
    "../baselines/models/models/final-baseline-small-noise-losses.pt",
    map_location=device,
)

model_ft = model_ft.to(device)
model_ft.eval()

# Collect the model's output for every test batch; only the middle element of each
# loader item is fed to the model.
losses = []
with torch.no_grad():
    for _, inputs, _ in nnth_mh._tst_loader:
        inputs = inputs.to(device)
        loss_batch = model_ft(inputs)
        losses.append(loss_batch.cpu())

# Flatten to one value per example and rank with the largest model output first.
# NOTE(review): the outputs are presumably predicted per-example losses — confirm with the training code.
losses = torch.cat(losses).view(-1)
ind_order_full = torch.argsort(-losses)
        

## Gain - no training

In [None]:
# Class probabilities of th100 on its test loader, with the top probability and
# the arg-max class for every example.
pred_prob = th100.predict_proba(loader=th100._tst_loader)
pred_prob_max, pred_y = torch.max(pred_prob, dim=1)

# Prior looked up later as prior_ybeta_prob[label][str(beta_list)].
# The original statement ended in a stray line-continuation backslash, removed here.
prior_ybeta_prob = th100.get_conf_ybeta_prior(
    loader=th100._dh._train_test.get_theta_loader(batch_size=128, shuffle=False)
)

In [None]:
def _rank_by_gain(betas):
    """Rank test examples by prior-based gain for the given per-example betas.

    For each example the gain is the prior ``prior_ybeta_prob[y][str(beta)]``
    minus the model's top class probability ``pred_prob_max``. Reads the
    notebook globals ``dh``, ``pred_y``, ``prior_ybeta_prob`` and
    ``pred_prob_max`` at call time.

    Args:
        betas: iterable of beta tensors, one per entry of ``pred_y``.

    Returns:
        ``(gains, order)``: the gain tensor of length ``dh._test._num_data``
        and the example indices sorted with the largest gain first.
    """
    gains = torch.zeros(dh._test._num_data)
    for idx, (y, beta) in enumerate(zip(pred_y, betas)):
        gains[idx] = prior_ybeta_prob[y.item()][str(beta.tolist())] - pred_prob_max[idx]
    return gains, torch.argsort(-gains)


# Rankings for the betas predicted by the improved and the vanilla phi networks.
gains_no_training_improved = _rank_by_gain(tst_predbetas_improved)[1]
gains_no_training_vannila = _rank_by_gain(tst_predbetas_vannila)[1]

# Constant-beta baseline: every example is assigned beta [3, 1, 0], whose string
# form matches the str([3,1,0]) key used previously.
const_betas = torch.tensor([3, 1, 0]).expand(len(pred_y), -1)
gains, gains_no_training_const = _rank_by_gain(const_betas)



In [None]:
def pred_after_recourse(ind_order, frac, tst_predlabels, pred_betas,
                        labels=None, beta_index=None, group_size=9):
    """Accuracy when the first ``frac`` of ``ind_order`` receives recourse.

    A recoursed example ``i`` is scored with the prediction stored at the
    position of its predicted beta inside its own group of ``group_size``
    consecutive rows. If that beta is not in ``beta_index``, or the resulting
    position is out of range, the example falls back to its own prediction.
    All remaining examples are scored with their own prediction.

    Args:
        ind_order: 1-D tensor of example indices, highest recourse priority first.
        frac: fraction of ``labels.shape[0]`` that gets recourse.
        tst_predlabels: tensor of predicted labels, one per test row.
        pred_betas: per-example predicted beta tensors.
        labels: true labels; defaults to the notebook global ``tst_labels``.
        beta_index: maps ``str(beta.tolist())`` to the offset inside a group;
            defaults to the notebook global ``beta_to_idx``.
        group_size: number of consecutive rows that form one group (default 9).

    Returns:
        0-d tensor: correct predictions divided by ``len(ind_order)``.
    """
    # Resolve the notebook globals only when the caller did not supply values.
    labels = tst_labels if labels is None else labels
    beta_index = beta_to_idx if beta_index is None else beta_index

    order = ind_order.tolist()
    num_samples = int(labels.shape[0] * frac)

    rec_crcts = 0
    for i in order[:num_samples]:
        try:
            offset = beta_index[str(pred_betas[i].tolist())]
            label_pred = tst_predlabels[(i // group_size) * group_size + offset]
        except (KeyError, IndexError):
            # Unknown beta or position outside the data: keep the example's own
            # prediction. Anything else is a real error and is allowed to surface.
            label_pred = tst_predlabels[i]
        if label_pred == labels[i]:
            rec_crcts += 1

    ind_no_rec = order[num_samples:]
    rec_crcts += torch.sum(labels[ind_no_rec] == tst_predlabels[ind_no_rec])

    # Both slices together cover every entry of `order` exactly once.
    acc = rec_crcts / len(order)
    return acc

In [None]:
def pred_after_constant_recourse(ind_order, frac, tst_predlabels, beta_pred,
                                 labels=None, beta_index=None, group_size=9):
    """Accuracy when the first ``frac`` of ``ind_order`` gets one fixed beta.

    Every recoursed example ``i`` is scored with the prediction stored at the
    position of ``beta_pred`` inside its own group of ``group_size``
    consecutive rows; all other examples keep their own prediction.

    Args:
        ind_order: 1-D tensor of example indices, highest recourse priority first.
        frac: fraction of ``labels.shape[0]`` that gets recourse.
        tst_predlabels: tensor of predicted labels, one per test row.
        beta_pred: the constant beta; ``str(beta_pred)`` must be a key of
            ``beta_index`` whenever at least one example is recoursed.
        labels: true labels; defaults to the notebook global ``tst_labels``.
        beta_index: maps the string form of a beta to its offset inside a
            group; defaults to the notebook global ``beta_to_idx``.
        group_size: number of consecutive rows that form one group (default 9).

    Returns:
        0-d tensor: correct predictions divided by ``len(ind_order)``.

    Raises:
        KeyError: if an example is recoursed and ``str(beta_pred)`` is unknown.
    """
    # Resolve the notebook globals only when the caller did not supply values.
    labels = tst_labels if labels is None else labels
    beta_index = beta_to_idx if beta_index is None else beta_index

    order = ind_order.tolist()
    num_samples = int(labels.shape[0] * frac)
    beta_key = str(beta_pred)

    rec_crcts = 0
    for i in order[:num_samples]:
        idx = (i // group_size) * group_size + beta_index[beta_key]
        if tst_predlabels[idx] == labels[i]:
            rec_crcts += 1

    ind_no_rec = order[num_samples:]
    rec_crcts += torch.sum(labels[ind_no_rec] == tst_predlabels[ind_no_rec])

    # Both slices together cover every entry of `order` exactly once.
    acc = rec_crcts / len(order)
    return acc

## Phi Theirs


In [None]:
from  torchvision import models as tv_models
import torch.nn as nn
class ResNET(nn.Module):
    """ResNet-18 backbone feeding three independent linear heads.

    The torchvision ``resnet18(pretrained=True)`` classifier layer is replaced
    by ``nn.Identity`` so the backbone returns its embedding; head ``k`` maps
    that embedding to ``out_dim[k]`` logits.

    Args:
        out_dim: sequence of three ints, the output size of each head.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, out_dim, *args, **kwargs):
        super().__init__()
        self.out_dim = out_dim

        self.resnet_features = tv_models.resnet18(pretrained=True)
        self.emb_dim = self.resnet_features.fc.in_features
        # Drop the ImageNet classifier; the backbone now outputs the embedding.
        self.resnet_features.fc = nn.Identity()

        self.fc1 = nn.Linear(self.emb_dim, self.out_dim[0])
        self.fc2 = nn.Linear(self.emb_dim, self.out_dim[1])
        self.fc3 = nn.Linear(self.emb_dim, self.out_dim[2])

        self.sm = nn.Softmax(dim=1)

    def forward_proba(self, input):
        """Return the softmax (over dim 1) of each head's logits as a 3-tuple."""
        out1, out2, out3 = self.forward(input)
        return self.sm(out1), self.sm(out2), self.sm(out3)

    def forward(self, input):
        """Return the three heads' logits as a 3-tuple.

        The backbone is evaluated once and its embedding shared by all heads;
        the previous version ran the backbone three times on the same input.
        """
        features = self.resnet_features(input)
        return self.fc1(features), self.fc2(features), self.fc3(features)

    def forward_labels(self, input):
        """Return the arg-max class index of each head as a 3-tuple."""
        probs1, probs2, probs3 = self.forward_proba(input)
        _, labels1 = torch.max(probs1, dim=1)
        _, labels2 = torch.max(probs2, dim=1)
        _, labels3 = torch.max(probs3, dim=1)
        return labels1, labels2, labels3



from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch
# Pick the device once and load the checkpoint onto it. The previous code loaded
# with map_location="cuda:0" regardless of the CPU fallback chosen afterwards.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The checkpoint holds a whole pickled model, so only the ResNET class definition
# is required. The previous code also built a throwaway ResNET(out_dim=[6,6,6])
# (which fetches pretrained weights) and immediately overwrote it with this load.
model_ft = torch.load("../baselines/models/theirs_phi-small.pt", map_location=device)

model_ft.eval()
losses = []

pred_betas = []
unq_beta = dh._test._unq_beta
with torch.no_grad():
    for _, inputs, _ in nnth_mh._tst_loader:
        inputs = inputs.to(device)
        # One probability tensor per beta component, moved back to the CPU.
        pred_beta_probs = model_ft.forward_proba(inputs)
        pred_beta_probs = [entry.cpu() for entry in pred_beta_probs]

        pred_beta = []
        for idx in range(len(inputs)):
            # Among the betas listed in unq_beta, pick the one whose product of
            # per-component probabilities is largest (first one wins on ties).
            max_prob = 0
            sel_beta = None
            for beta_entry in unq_beta:
                beta_etry_probs = torch.Tensor(
                    [pred_beta_probs[entry][idx][beta_entry[entry]] for entry in range(len(beta_entry))]
                )
                beta_entry_prob = torch.prod(beta_etry_probs)
                if beta_entry_prob > max_prob:
                    sel_beta = beta_entry
                    max_prob = beta_entry_prob
            assert sel_beta is not None, "Why is sel beta none? We should have atleast one positive prob beta"
            pred_beta.append(sel_beta)
        pred_betas.append(torch.stack(pred_beta))
tst_predbetas_their = torch.cat(pred_betas)


# Rebuild the globals read by pred_after_recourse and rec_hist: the test betas
# and labels, and the mapping from the string form of each of the first 9 test
# betas to its position.
tst_betas = dh._test._Beta
beta_to_idx = {}
tst_labels = dh._test._y

for i, b in enumerate(tst_betas[0:9]):
    beta_to_idx[str(b.tolist())] = i

In [None]:
# Gain per test example for the betas predicted by the "theirs" phi model:
# the prior for (predicted label, predicted beta) minus the model's top probability.
gains = torch.zeros(dh._test._num_data)
for idx, (y, beta) in enumerate(zip(pred_y, tst_predbetas_their)):
    beta_key = str(beta.tolist())
    gains[idx] = prior_ybeta_prob[y.item()][beta_key] - pred_prob_max[idx]

# Ascending sort of the negated gains, i.e. largest gain first.
gains_no_training_their = gains.neg().argsort()

# Final plotting code

In [None]:
# Recourse budgets from 0.00 to 1.00 in steps of 0.02 (51 points).
frac_list = [0.02*i for i in range(51)]

# Triage comparison: three orderings, all scored with the improved phi's betas.
acc_list_score = [pred_after_recourse(ind_order_score, frac, test_predlabels_baseline, tst_predbetas_improved) for frac in frac_list]
acc_list_full = [pred_after_recourse(ind_order_full, frac, test_predlabels_baseline, tst_predbetas_improved) for frac in frac_list]
# Computed once; the previous cell evaluated this identical expression twice.
acc_gains_prior_th100_improved = [pred_after_recourse(gains_no_training_improved, frac, test_predlabels_th100, tst_predbetas_improved) for frac in frac_list]

# Phi comparison: each beta predictor is scored with its own gain ordering.
acc_list_gain_prior_their_th100 = [pred_after_recourse(gains_no_training_their, frac, test_predlabels_th100, tst_predbetas_their) for frac in frac_list]
acc_list_gain_prior_vannila_th100 = [pred_after_recourse(gains_no_training_vannila, frac, test_predlabels_th100, tst_predbetas_vannila) for frac in frac_list]
acc_list_gains_prior_const_pred = [pred_after_constant_recourse(gains_no_training_const, frac, test_predlabels_th100, [3,1,0]) for frac in frac_list]


## Triage

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl


import matplotlib
import matplotlib.font_manager
# 'normal' is not a font family name: matplotlib cannot find it, warns, and falls
# back to its default font. Requesting the generic 'sans-serif' family instead
# resolves through the font.sans-serif rcParam without the warnings.
font = {'family' : 'sans-serif',
        'size'   : 25}

matplotlib.rc('font', **font)
matplotlib.rc("text", usetex=False)

In [None]:
def plot_line(data_x, data_y_all, labels_all, colors_all, markers_all, title, save_name):
    """Draw several accuracy curves on one figure, save it, and show it.

    The current figure is cleared first. Each curve is a solid line with its
    own legend label, colour and marker. The figure is written to
    ``./{save_name}.png`` at 300 dpi with a tight bounding box.

    Args:
        data_x: shared x values (recourse fractions).
        data_y_all: one sequence of y values per curve.
        labels_all: legend label of each curve.
        colors_all: matplotlib colour of each curve.
        markers_all: matplotlib marker of each curve.
        title: figure title.
        save_name: output file name without the ``.png`` extension.
    """
    plt.clf()

    curves = zip(data_y_all, labels_all, colors_all, markers_all)
    for curve_y, curve_label, curve_color, curve_marker in curves:
        plt.plot(
            data_x,
            curve_y,
            scaley=True,
            color=curve_color,
            label=curve_label,
            marker=curve_marker,
            linestyle="-",
        )

    plt.xlabel("Fraction of Recourse ($b$)")
    plt.ylabel("Recourse Accuracy")
    # Legend sits outside the axes, to the right of the plot.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.5, linewidth=1, color="gray", linestyle=":")
    plt.title(title)

    out_path = f"./{save_name}.png"
    plt.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# Triage comparison: three ways of ordering examples for recourse, one curve each.
data_x = frac_list
data_y_all = [
    acc_list_score,
    acc_list_full,
    acc_gains_prior_th100_improved,
]
labels_all = [
    "Score based Triage",
    "Full automation Triage",
    "Gains Triage",
]
colors_all = ["b", "g", "r"]
markers_all = ["."] * 3
title = "Performance of on Shapenet-Small"
save_name = "triage_small"

plot_line(
    data_x,
    data_y_all,
    labels_all,
    colors_all,
    markers_all,
    title,
    save_name,
)

In [None]:
# Print the score-based triage accuracies as plain Python numbers (one per entry of the list).
print([entry.item() for entry in acc_list_score])

In [None]:
# Print the full-automation triage accuracies as plain Python numbers.
print([entry.item() for entry in acc_list_full])

In [None]:
# Print the gains-triage accuracies (improved phi) as plain Python numbers.
print([entry.item() for entry in acc_gains_prior_th100_improved])

## Phi

In [None]:
# Phi comparison: one accuracy curve per way of choosing the beta.
data_x = frac_list
data_y_all = [acc_gains_prior_th100_improved, acc_list_gain_prior_vannila_th100, acc_list_gain_prior_their_th100, acc_list_gains_prior_const_pred]
# Raw string for the LaTeX label: "\p" is not a valid escape in a normal string
# literal and triggers an invalid-escape warning. The label text is unchanged.
labels_all = ["Joint Prior",  "Joint", r"Only $\phi$", "Constant prediction"]
colors_all = ["r", "g", "b", "black"]
markers_all = [".", ".", ".", "."]
# NOTE(review): the title reads "Performance of on ..." — a word appears to be missing; left as is.
title = "Performance of on Shapenet-Small"
save_name = "phi_small"

plot_line(data_x, data_y_all, labels_all, colors_all, markers_all, title, save_name)

In [None]:
# Print the "Joint Prior" curve (improved phi) as plain Python numbers.
print([entry.item() for entry in acc_gains_prior_th100_improved])

In [None]:
# Print the "Joint" curve (vanilla phi) as plain Python numbers.
print([entry.item() for entry in acc_list_gain_prior_vannila_th100])

In [None]:
# Print the "Only phi" curve (betas from the "theirs" model) as plain Python numbers.
print([entry.item() for entry in acc_list_gain_prior_their_th100])

In [None]:
# Print the "Constant prediction" curve (fixed beta [3,1,0]) as plain Python numbers.
print([entry.item() for entry in acc_list_gains_prior_const_pred])

## Compute class dependent accuracies

In [None]:
def rec_hist(tst_predlabels, pred_betas, labels=None, beta_index=None,
             num_classes=None, group_size=9):
    """Per-class accuracy when every test example receives recourse.

    Example ``i`` is scored with the prediction stored at the position of its
    predicted beta inside its own group of ``group_size`` consecutive rows. If
    that beta is not in ``beta_index``, or a position is out of range, the
    example falls back to its own prediction.

    Args:
        tst_predlabels: tensor of predicted labels, one per test row.
        pred_betas: per-example predicted beta tensors.
        labels: true integer labels; defaults to the notebook global ``tst_labels``.
        beta_index: maps ``str(beta.tolist())`` to the offset inside a group;
            defaults to the notebook global ``beta_to_idx``.
        num_classes: number of classes; defaults to ``dh._test._num_classes``.
        group_size: number of consecutive rows that form one group (default 9).

    Returns:
        Float tensor of length ``num_classes`` with the accuracy of each true
        class. A class with no examples yields ``nan`` (0 / 0).
    """
    # Resolve the notebook globals only when the caller did not supply values.
    labels = tst_labels if labels is None else labels
    beta_index = beta_to_idx if beta_index is None else beta_index
    num_classes = dh._test._num_classes if num_classes is None else num_classes

    corrects = torch.zeros(num_classes)
    counts = torch.zeros(num_classes)
    for i in range(len(tst_predlabels)):
        true_label = labels[i]
        try:
            offset = beta_index[str(pred_betas[i].tolist())]
            label_pred = tst_predlabels[(i // group_size) * group_size + offset]
        except (KeyError, IndexError):
            # Unknown beta or position outside the data: keep the example's own
            # prediction. Anything else is a real error and is allowed to surface.
            label_pred = tst_predlabels[i]
        if label_pred == true_label:
            corrects[true_label] += 1
        counts[true_label] += 1

    return corrects / counts

In [None]:
tst_predlabels = test_predlabels_th100
pred_betas = tst_predbetas_improved

# Per-class recourse accuracy for th100 with the improved phi's betas.
# This cell used to repeat the body of rec_hist line for line; it now calls it.
corrects = rec_hist(tst_predlabels, pred_betas)

In [None]:
# Cell output: unweighted mean of the per-class accuracies in `corrects`.
torch.mean(corrects)

## Histogram Plot

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

# Per-class recourse accuracy of th100 under each way of choosing the beta.
joint_improved = rec_hist(tst_predlabels=test_predlabels_th100, pred_betas=tst_predbetas_improved).tolist()
joint = rec_hist(tst_predlabels=test_predlabels_th100, pred_betas=tst_predbetas_vannila).tolist()
# One constant beta row per test prediction; the row count was previously the literal 7200.
constant_betas = torch.tensor([3, 1, 0]).repeat(len(test_predlabels_th100), 1)
constant = rec_hist(tst_predlabels=test_predlabels_th100, pred_betas=constant_betas).tolist()
theirs = rec_hist(tst_predlabels=test_predlabels_th100, pred_betas=tst_predbetas_their).tolist()


# set width of bar
barWidth = 0.15

# Class names in x-axis order; the number of bar groups follows from this list.
class_names = ['Aeroplane', 'Bench', 'Bus', 'Cabinet', 'Chair', 'Display', 'Knife', 'Lamp', 'Speaker', 'Gun']

# Set position of bar on X axis: each series is shifted one bar width to the right.
r1 = np.arange(len(class_names))
r2 = [x + barWidth for x in r1]
r3 = [x + barWidth for x in r2]
r4 = [x + barWidth for x in r3]

# Make the plot. Raw strings keep the LaTeX "\phi" without an invalid "\p" escape;
# the label text itself is unchanged.
plt.bar(r1, theirs, color='blue', width=barWidth, edgecolor='white', label=r'Only $\phi$')
plt.bar(r2, joint, color='green', width=barWidth, edgecolor='white', label='Joint')
plt.bar(r3, constant, color='black', width=barWidth, edgecolor='white', label='Constant prediction')
plt.bar(r4, joint_improved, color='red', width=barWidth, edgecolor='white', label='Joint Prior')

# One tick per class, placed at the integer positions 0..len(class_names)-1.
plt.ylabel("Recourse Accuracy")
plt.xticks([r for r in range(len(class_names))], class_names, rotation=75)
plt.grid(linestyle="dotted")

plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.title('Classwise performance of different ' + r'$g_\phi$' + '\n on Shapenet-Small')

plt.savefig("./histogram_small.png", dpi=300, bbox_inches = "tight")

In [None]:
# Per-class accuracies behind the "Joint Prior" bars.
print(joint_improved)

In [None]:
# Per-class accuracies behind the "Joint" bars.
print(joint)

In [None]:
# Per-class accuracies behind the "Constant prediction" bars.
print(constant)

In [None]:
# Per-class accuracies behind the "Only phi" bars.
print(theirs)