In [None]:
"""
Evaluate trained models on the official CUB test set
"""
import os
import sys
import torch
import joblib
import argparse
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
sys.path.append('/home/mattyshen/interpretableDistillation')
from interpretDistill import fourierDistill
sys.path.append('/home/mattyshen/iCBM')

from CUB.dataset import load_data
from CUB.config import BASE_DIR, N_CLASSES, N_ATTRIBUTES, DEVICE, get_device, set_device
from analysis import AverageMeter, multiclass_metric, accuracy, binary_accuracy

from imodels import FIGSClassifierCV


class ARGS:
    """Lightweight attribute bag standing in for an ``argparse`` namespace.

    Every key of ``a_dict`` becomes an instance attribute holding the
    corresponding value, so downstream code can write ``args.gpu``,
    ``args.log_dir`` and so on.
    """

    def __init__(self, a_dict):
        # Use setattr rather than exec on a formatted string: no source code
        # is generated from the keys, so unusual keys cannot break or inject code.
        for k, v in a_dict.items():
            setattr(self, k, v)

    def __repr__(self):
        # Show the stored options so that print(args) is informative.
        fields = ', '.join(f'{k}={v!r}' for k, v in vars(self).items())
        return f'ARGS({fields})'
            
# Option names, in the positional order used by the value lists in this cell:
# entry i of a value list is the setting for parser_args[i].
parser_args = [
    # output location and checkpoints
    'log_dir', 'model_dirs', 'model_dirs2',
    # which split to evaluate and which inputs the model consumes
    'eval_data', 'use_attr', 'no_img', 'bottleneck',
    # dataset location and attribute layout
    'image_dir', 'n_class_attr', 'data_dir', 'n_attributes',
    # per-attribute reporting
    'attribute_group', 'feature_group_results',
    # activation / architecture switches
    'use_relu', 'use_sigmoid', 'use_gbsm', 'expand_gbsm_dim',
    # device index
    'gpu',
]
# Settings for the joint sigmoid models; each value lines up positionally with
# the option name noted beside it.
parser_sigmoid = [
    '/home/mattyshen/iCBM/CUB/eval/JointSigmoidModels/outputs',  # log_dir
    [  # model_dirs: one checkpoint per seed
        '/home/mattyshen/iCBM/CUB/best_models/Joint0.01SigmoidModel__Seed1/outputs/best_model_1.pth',
        '/home/mattyshen/iCBM/CUB/best_models/Joint0.01SigmoidModel__Seed2/outputs/best_model_2.pth',
        '/home/mattyshen/iCBM/CUB/best_models/Joint0.01SigmoidModel__Seed3/outputs/best_model_3.pth',
    ],
    None,                                # model_dirs2
    'test',                              # eval_data
    True,                                # use_attr
    False,                               # no_img
    False,                               # bottleneck
    'images',                            # image_dir
    2,                                   # n_class_attr
    'CUB_processed/class_attr_data_10',  # data_dir
    112,                                 # n_attributes
    None,                                # attribute_group
    False,                               # feature_group_results
    False,                               # use_relu
    True,                                # use_sigmoid
    False,                               # use_gbsm
    False,                               # expand_gbsm_dim
    2,                                   # gpu
]
# Settings for the joint GBSM models, laid out in the same positional order as
# the option names noted beside each value.
parser_gbsm = [
    '/home/mattyshen/iCBM/CUB/eval/JointGBSMModels/outputs',  # log_dir
    [  # model_dirs: one checkpoint per seed
        '/home/mattyshen/iCBM/CUB/best_models/Joint0.01GBSMModel__Seed1/outputs/best_model_1.pth',
        '/home/mattyshen/iCBM/CUB/best_models/Joint0.01GBSMModel__Seed2/outputs/best_model_2.pth',
        '/home/mattyshen/iCBM/CUB/best_models/Joint0.01GBSMModel__Seed3/outputs/best_model_3.pth',
    ],
    None,                                # model_dirs2
    'test',                              # eval_data
    True,                                # use_attr
    False,                               # no_img
    False,                               # bottleneck
    'images',                            # image_dir
    2,                                   # n_class_attr
    'CUB_processed/class_attr_data_10',  # data_dir
    112,                                 # n_attributes
    None,                                # attribute_group
    False,                               # feature_group_results
    False,                               # use_relu
    True,                                # use_sigmoid
    True,                                # use_gbsm
    False,                               # expand_gbsm_dim
    2,                                   # gpu
]

In [None]:
# Pair each option name with its value from the sigmoid configuration and
# expose the result as attributes on `args`.
args_dict = {name: value for name, value in zip(parser_args, parser_sigmoid)}
args = ARGS(args_dict)

# Turn on the cuDNN benchmark flag and select the configured device.
torch.backends.cudnn.benchmark = True
set_device(args.gpu)

# Derived / fixed options that are not part of the positional lists.
args.batch_size = 16
args.three_class = args.n_class_attr == 3

print(args)

In [None]:
def get_FT_data(args, data='trainval', p_thresh=0.5):
    """
    Run the first trained model over a data split and collect its outputs.

    Only ``args.model_dirs[0]`` is evaluated (TODO: loop over all models).

    Args:
        args: options object; reads ``model_dirs``, ``data_dir``, ``use_attr``,
            ``no_img``, ``image_dir`` and ``n_class_attr``.
        data: ``'test'`` loads ``test.pkl``; any other value loads ``train.pkl``
            and ``val.pkl`` together.
        p_thresh: currently unused. It belonged to a disabled variant that
            binarised the attribute probabilities; kept so callers do not break.

    Returns:
        Tuple ``(X, y, y_hat)``:
            X: DataFrame of sigmoid-activated attribute outputs, one row per
               sample and one column per attribute, named ``c1 .. cM``.
            y: Series of the class labels yielded by the loader.
            y_hat: Series of the argmax of the model's class outputs.
    """
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    model = torch.load(args.model_dirs[0])
    model = model.to(get_device())
    model.eval()

    split_files = ['test.pkl'] if data == 'test' else ['train.pkl', 'val.pkl']
    pkl_paths = [os.path.join(BASE_DIR, args.data_dir, name) for name in split_files]
    loader = load_data(pkl_paths, args.use_attr, args.no_img, 32, image_dir=args.image_dir,
                       n_class_attr=args.n_class_attr)

    attr_batches, label_batches, pred_batches = [], [], []
    # Inference only: disable autograd so no graph is kept alive per batch.
    with torch.no_grad():
        # The loop variable is deliberately not called `data`, to avoid
        # shadowing the `data` parameter.
        for inputs, labels, _attr_labels in loader:
            outputs = model(inputs.to(get_device()))
            class_outputs = outputs[0]
            # outputs[1:] holds one tensor per attribute; squash each to (0, 1).
            attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs[1:]]

            # Stack to (n_attr, batch, 1), drop the trailing axis, then
            # transpose to (batch, n_attr).
            attr_batches.append(torch.stack(attr_outputs).squeeze(2).detach().cpu().numpy().T)
            pred_batches.append(np.argmax(class_outputs.detach().cpu().numpy(), axis=1))
            label_batches.append(labels.numpy().reshape(-1, ))

    attrs = np.concatenate(attr_batches, axis=0)
    # Column names follow the actual number of attributes instead of assuming 112.
    columns = [f'c{i}' for i in range(1, attrs.shape[1] + 1)]
    X_train = pd.DataFrame(attrs, columns=columns)
    y_train = pd.Series(np.concatenate(label_batches))
    y_train_hat = pd.Series(np.concatenate(pred_batches))

    return X_train, y_train, y_train_hat

In [None]:
# Collect attribute activations (X_train), loader labels (y_train) and the model's
# predicted labels (y_train_hat) for the default 'trainval' split.
X_train, y_train, y_train_hat = get_FT_data(args)

In [None]:
# Ask PyTorch to release its cached GPU memory now that the inference pass is done.
torch.cuda.empty_cache()

In [None]:
# Sanity check: print every class whose one-vs-rest indicator contains NaN.
# NaN never compares equal to anything (itself included), so `== np.nan` is
# always False; np.isnan is the test that can actually detect it.
for i in range(200):
    if np.isnan(np.where(y_train_hat == i, 1, 0)).any():
        print(i)

In [None]:
# NOTE(review): in the visible notebook `figs_models` is first assigned in a later
# cell, so on a fresh top-to-bottom run this raises NameError — confirm cell order.
len(figs_models)

In [None]:
# Train one-vs-rest FIGS models for classes 145-199 only.
# Results go into the list this cell declares (`figs_models2`); the original
# appended to `figs_models`, which this cell never creates.
figs_models2 = []
for i in range(145, 200):
    if i%25 == 0:
        print(f'training class {i}')
    figs_i = FIGSClassifierCV(n_rules_list = [20, 20], n_trees_list = [5, 10])
    target_i = np.where(y_train_hat == i, 1, 0)
    # Same guard as the full training loop: only fit when the class was
    # predicted at least once; otherwise keep the unfitted placeholder.
    if np.sum(target_i) > 0:
        figs_i.fit(X_train, target_i)
    figs_models2.append(figs_i)

In [None]:
# Number of samples the model assigned to class 147.
np.sum(y_train_hat == 147)

In [None]:
# Fit a single one-vs-rest FIGS model for class `i`.
# NOTE(review): `i` is not set in this cell; it is whatever an earlier loop left
# behind — confirm which class is intended.
figs_i = FIGSClassifierCV(n_rules_list = [20, 20], n_trees_list = [5, 10])
figs_i.fit(X_train, np.where(y_train_hat == i, 1, 0))

In [None]:
# One-vs-rest distillation: one FIGS model per class index 0..199, each trained
# to reproduce "the network predicted this class" from the attribute activations.
# Classes the network never predicted keep an unfitted placeholder so that list
# position still equals class index.
figs_models = []
for i in range(200):
    if not i % 25:
        print(f'training class {i}')
    is_class_i = np.where(y_train_hat == i, 1, 0)
    figs_i = FIGSClassifierCV(n_rules_list=[20, 20], n_trees_list=[5, 10])
    if is_class_i.sum() > 0:
        figs_i.fit(X_train, is_class_i)
    figs_models.append(figs_i)

In [None]:
# Number of per-class models collected by the training loop.
len(figs_models)

In [None]:
# Per-class score vectors for the one-vs-rest ensemble (one entry per model,
# each of length n_samples).
predictions = []
for figs_model in figs_models:
    # Models that were skipped during training are told apart by the missing
    # `figs` attribute (presumably set by fit — confirm against imodels).
    if hasattr(figs_model, 'figs'):
        # Column 1 of predict_proba is P(target == 1), i.e. "is this class".
        # Column 0 is the probability of NOT being the class, and taking the
        # argmax over that would select the least likely class.
        predictions.append(figs_model.predict_proba(X_train)[:, 1])
    else:
        # Never-predicted class: score 0 everywhere so it cannot win the argmax.
        predictions.append(np.zeros((X_train.shape[0], )))
print(len(predictions), 'score vectors of length', X_train.shape[0])

In [None]:
# Shape of the ensemble's predicted-class vector: rows of the stacked matrix are
# classes, so the argmax over axis 0 picks one class per sample.
np.vstack(predictions).argmax(axis=0).shape

In [None]:
# Distinct score values produced across all per-class models and samples.
np.unique(np.vstack((predictions)))

In [None]:
# Number of samples that receive an ensemble prediction.
len(np.vstack(predictions).argmax(axis=0))

In [None]:
# Fraction of samples where the ensemble's top-scoring class equals the label in y_train.
ovr_pred = np.vstack(predictions).argmax(axis=0)
np.mean(ovr_pred == y_train)

In [None]:
# Arrange the per-class score vectors as an (n_samples, n_classes) matrix.
# Each entry of `predictions` is 1-D, so np.concatenate(..., axis=1) raises
# AxisError; np.stack creates the new class axis instead.
np.stack(predictions, axis = 1)

In [None]:
# Inspect the shape and values of predict_proba for a single per-class model.
# NOTE(review): `i` is not set in this cell and must be an integer list index
# here; it depends on what an earlier loop left behind — confirm.
figs_models[i].predict_proba(X_train).shape, figs_models[i].predict_proba(X_train)

In [None]:
# Largest class index the network predicted on this split.
np.max(y_train_hat)

In [None]:
# Fraction of samples where the network's predicted class matches the loader label.
np.mean(y_train_hat == y_train)

In [None]:
# FIGS model with default settings, to be fitted on the loader labels directly.
figs = FIGSClassifierCV()

In [None]:
# Fit on attribute activations -> loader labels (no distillation).
figs.fit(X_train, y_train)

In [None]:
# Distilled variant: same model class, but trained to reproduce the network's
# predicted labels (y_train_hat) instead of the loader labels.
figs_distill = FIGSClassifierCV()
figs_distill.fit(X_train, y_train_hat)

In [None]:
# Mean of the distilled model's predicted values on the training features
# (a quick look at what it outputs, not an accuracy).
np.mean(figs_distill.predict(X_train))

In [None]:
# Agreement rates on the training features, in order:
#   direct model vs labels, distilled model vs labels, distilled model vs network predictions.
figs_pred = figs.predict(X_train)
figs_distill_pred = figs_distill.predict(X_train)
(np.mean(figs_pred == y_train),
 np.mean(figs_distill_pred == y_train),
 np.mean(figs_distill_pred == y_train_hat))

In [None]:
# Configure the FTDistill cross-validated classifier from interpretDistill.
# The exact meaning of each option is defined by that library (not visible here);
# the values below are the ones used for this experiment.
ftd = fourierDistill.FTDistillClassifierCV(pre_interaction='l0l2', 
                             pre_max_features=75,
                             post_interaction='l0l2', 
                             post_max_features=50,
                             size_interactions=3,  
                             cv=3)

In [None]:
# Fit the FTDistill classifier on attribute activations -> loader labels.
ftd.fit(X_train, y_train)

In [None]:
# get_FT_data returns three values (features, labels, predicted labels);
# unpacking only two raises "too many values to unpack".
X_test, y_test, y_test_hat = get_FT_data(args, data='test')

In [None]:
# Flatten y_test into a single Series of labels.
# np.asarray handles both per-batch CPU tensors and the plain values of an
# already-flattened Series (whose elements have no .numpy() method), so this
# cell works with either form and is safe to re-run.
y_test = pd.Series(np.concatenate([np.asarray(l).reshape(-1, ) for l in y_test]))

In [None]:
# Smallest count of non-zero coefficients found in any row of the fitted
# post-sparsity model's coef_ matrix.
np.min(np.sum(ftd.post_sparsity_model.coef_ != 0, axis = 1))

In [None]:
# FTDistill agreement with the labels on the training and test features.
ftd_train_acc = np.mean(ftd.predict(X_train) == y_train)
ftd_test_acc = np.mean(ftd.predict(X_test) == y_test)
ftd_train_acc, ftd_test_acc

In [None]:
K = [1, 3, 5] #top k class accuracies to compute

@torch.no_grad()  # inference only: do not build autograd graphs
def eval(args):
    """
    Run inference using model (and model2 if bottleneck) and report accuracies.

    Reads from ``args``: model_dir, model_dir2, attribute_group, use_attr,
    no_img, bottleneck, use_relu, use_sigmoid, feature_group_results,
    n_attributes, eval_data, data_dir, image_dir, n_class_attr, batch_size
    and log_dir.

    Side effects: prints the metrics, sets numpy print options, and (when
    ``feature_group_results`` is set) writes ``concepts.txt`` into ``log_dir``.

    Returns a 10-tuple (for notebook analysis):
        class_acc_meter: list of AverageMeter, one per entry of K (top-k class accuracy)
        attr_acc_meter: list of AverageMeter (index 0 = all attributes; one more per
            attribute when feature_group_results is set), or None if not use_attr
        all_class_labels: flattened list of class labels for each image
        topk_class_outputs: array of top k class ids predicted for each image,
            shape = number of images * max(K)
        all_class_logits: array of all class outputs, one row per image
        all_attr_labels: flattened list of attribute labels (image-major order)
        all_attr_outputs: flattened list of attribute outputs after ReLU / Sigmoid / no
            activation, depending on use_relu / use_sigmoid
        all_attr_outputs_sigmoid: flattened list of attribute outputs after Sigmoid
        wrong_idx: indices of images whose label is not among the top max(K) predictions
        all_attr_outputs2: always an empty list (kept for interface compatibility)

    Raises:
        ValueError: if neither args.model_dir nor args.attribute_group is given.
    """
    # NOTE: torch.load / joblib.load unpickle arbitrary objects; only load trusted files.
    model = torch.load(args.model_dir) if args.model_dir else None
    if model is None and not args.attribute_group:
        raise ValueError('eval needs args.model_dir or args.attribute_group to produce outputs')

    # Only touch the model when one was actually loaded; setting attributes on
    # None would raise AttributeError.
    if model is not None:
        # Older checkpoints may lack these attributes; back-fill them from args.
        if not hasattr(model, 'use_relu'):
            model.use_relu = bool(args.use_relu)
        if not hasattr(model, 'use_sigmoid'):
            model.use_sigmoid = bool(args.use_sigmoid)
        if not hasattr(model, 'cy_fc'):
            model.cy_fc = None
        model = model.to(get_device())
        model.eval()

    if args.model_dir2:
        if 'rf' in args.model_dir2:
            model2 = joblib.load(args.model_dir2)
        else:
            model2 = torch.load(args.model_dir2)
        if not hasattr(model2, 'use_relu'):
            model2.use_relu = bool(args.use_relu)
        if not hasattr(model2, 'use_sigmoid'):
            model2.use_sigmoid = bool(args.use_sigmoid)
        model2 = model2.to(get_device())
        model2.eval()
    else:
        model2 = None

    if args.use_attr:
        attr_acc_meter = [AverageMeter()]
        if args.feature_group_results:  # compute acc for each feature individually in addition to the overall accuracy
            for _ in range(args.n_attributes):
                attr_acc_meter.append(AverageMeter())
    else:
        attr_acc_meter = None

    class_acc_meter = [AverageMeter() for _ in K]

    if args.eval_data == 'trainval':
        pkl_paths = [os.path.join(BASE_DIR, args.data_dir, 'train.pkl'),
                     os.path.join(BASE_DIR, args.data_dir, 'val.pkl')]
    else:
        pkl_paths = [os.path.join(BASE_DIR, args.data_dir, args.eval_data + '.pkl')]
    loader = load_data(pkl_paths, args.use_attr, args.no_img, args.batch_size, image_dir=args.image_dir,
                       n_class_attr=args.n_class_attr)

    all_attr_labels, all_attr_outputs, all_attr_outputs_sigmoid, all_attr_outputs2 = [], [], [], []
    all_class_labels, all_class_logits = [], []
    topk_class_labels, topk_class_outputs = [], []

    np.set_printoptions(threshold=sys.maxsize)

    for data in loader:
        if args.use_attr:
            if args.no_img:  # A -> Y
                inputs, labels = data
                if isinstance(inputs, list):
                    inputs = torch.stack(inputs).t().float()
                inputs = inputs.float()
            else:
                inputs, labels, attr_labels = data
                attr_labels = torch.stack(attr_labels).t()  # one row per image, one column per attribute
        else:  # simple finetune
            inputs, labels = data

        inputs_var = inputs.to(get_device())
        labels = labels.to(get_device()) if torch.cuda.is_available() else labels

        if args.attribute_group:
            # One attribute model per line of the listing file; the file is
            # closed again after every batch instead of leaking a handle.
            outputs = []
            with open(args.attribute_group, 'r') as f:
                for line in f:
                    attr_model = torch.load(line.strip())
                    outputs.extend(attr_model(inputs_var))
        else:
            outputs = model(inputs_var)

        if args.use_attr:
            if args.no_img:  # A -> Y
                class_outputs = outputs
            else:
                if args.bottleneck:
                    if args.use_relu:
                        attr_outputs = [torch.nn.ReLU()(o) for o in outputs]
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs]
                    elif args.use_sigmoid:
                        attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs]
                        attr_outputs_sigmoid = attr_outputs
                    else:
                        attr_outputs = outputs
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs]
                    # Explicit None test: the truth value of a module can depend on
                    # its length (e.g. a container), which is not what is meant here.
                    if model2 is not None:
                        stage2_inputs = torch.cat(attr_outputs, dim=1)
                        class_outputs = model2(stage2_inputs)
                    else:  # for debugging bottleneck performance without running stage 2
                        class_outputs = torch.zeros([inputs.size(0), N_CLASSES],
                                                    dtype=torch.float64).to(get_device())  # ignore this
                else:  # cotraining, end2end
                    # outputs[0] is the class output, outputs[1:] the per-attribute outputs
                    if args.use_relu:
                        attr_outputs = [torch.nn.ReLU()(o) for o in outputs[1:]]
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs[1:]]
                    elif args.use_sigmoid:
                        attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs[1:]]
                        attr_outputs_sigmoid = attr_outputs
                    else:
                        attr_outputs = outputs[1:]
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs[1:]]

                    class_outputs = outputs[0]

                for i in range(args.n_attributes):
                    acc = binary_accuracy(attr_outputs_sigmoid[i].squeeze(), attr_labels[:, i])
                    acc = acc.data.cpu().numpy()
                    attr_acc_meter[0].update(acc, inputs.size(0))
                    if args.feature_group_results:  # keep track of accuracy of individual attributes
                        attr_acc_meter[i + 1].update(acc, inputs.size(0))

                # Flatten per batch in image-major order so that entry j of the
                # flat lists belongs to attribute j % n_attributes.
                attr_outputs = torch.cat([o.unsqueeze(1) for o in attr_outputs], dim=1)
                attr_outputs_sigmoid = torch.cat([o for o in attr_outputs_sigmoid], dim=1)
                all_attr_outputs.extend(list(attr_outputs.flatten().data.cpu().numpy()))
                all_attr_outputs_sigmoid.extend(list(attr_outputs_sigmoid.flatten().data.cpu().numpy()))
                all_attr_labels.extend(list(attr_labels.flatten().data.cpu().numpy()))
        else:
            class_outputs = outputs[0]

        _, topk_preds = class_outputs.topk(max(K), 1, True, True)
        _, preds = class_outputs.topk(1, 1, True, True)
        all_class_labels.extend(list(labels.data.cpu().numpy()))
        all_class_logits.extend(class_outputs.detach().cpu().numpy())
        topk_class_outputs.extend(topk_preds.detach().cpu().numpy())
        topk_class_labels.extend(labels.view(-1, 1).expand_as(preds))

        class_acc = accuracy(class_outputs, labels, topk=K)  # only class prediction accuracy
        for m in range(len(class_acc_meter)):
            class_acc_meter[m].update(class_acc[m], inputs.size(0))

    all_class_logits = np.vstack(all_class_logits)
    topk_class_outputs = np.vstack([tco if isinstance(tco, np.ndarray) else tco.cpu() for tco in topk_class_outputs])
    topk_class_labels = np.vstack([tcl if isinstance(tcl, np.ndarray) else tcl.cpu() for tcl in topk_class_labels])
    # The label column broadcasts against the max(K) prediction columns: a row
    # with no match means the true class is missing from the top max(K).
    wrong_idx = np.where(np.sum(topk_class_outputs == topk_class_labels, axis=1) == 0)[0]

    for j in range(len(K)):
        print('Average top %d class accuracy: %.5f' % (K[j], class_acc_meter[j].avg))

    if args.use_attr and not args.no_img:  # print some metrics for attribute prediction performance
        print('Average attribute accuracy: %.5f' % attr_acc_meter[0].avg)
        all_attr_outputs_int = np.array(all_attr_outputs_sigmoid) >= 0.5
        if args.feature_group_results:
            all_attr_labels_arr = np.array(all_attr_labels)
            all_attr_acc, all_attr_f1 = [], []
            for i in range(args.n_attributes):
                attr_acc = float(attr_acc_meter[1 + i].avg)
                # Strided slices pick attribute i out of the image-major flat
                # arrays without scanning every element once per attribute.
                attr_preds = all_attr_outputs_int[i::args.n_attributes]
                attr_true = all_attr_labels_arr[i::args.n_attributes]
                all_attr_acc.append(attr_acc)
                all_attr_f1.append(f1_score(attr_true, attr_preds))

            # (A histogram plot of these two distributions used to be sketched
            # here; it was disabled and only the binned counts are reported.)
            bins = np.arange(0, 1.01, 0.1)
            acc_bin_ids = np.digitize(np.array(all_attr_acc) / 100.0, bins)
            acc_counts_per_bin = [np.sum(acc_bin_ids == (i + 1)) for i in range(len(bins))]
            f1_bin_ids = np.digitize(np.array(all_attr_f1), bins)
            f1_counts_per_bin = [np.sum(f1_bin_ids == (i + 1)) for i in range(len(bins))]
            print("Accuracy bins:")
            print(acc_counts_per_bin)
            print("F1 bins:")
            print(f1_counts_per_bin)
            os.makedirs(args.log_dir, exist_ok=True)  # the output directory may not exist yet
            np.savetxt(os.path.join(args.log_dir, 'concepts.txt'), f1_counts_per_bin)

        balanced_acc, report = multiclass_metric(all_attr_outputs_int, all_attr_labels)
        f1 = f1_score(all_attr_labels, all_attr_outputs_int)
        print("Total 1's predicted:", np.sum(all_attr_outputs_int) / len(all_attr_outputs_sigmoid))
        print('Avg attribute balanced acc: %.5f' % (balanced_acc))
        print("Avg attribute F1 score: %.5f" % f1)
        print(report + '\n')
    return class_acc_meter, attr_acc_meter, all_class_labels, topk_class_outputs, all_class_logits, all_attr_labels, all_attr_outputs, all_attr_outputs_sigmoid, wrong_idx, all_attr_outputs2

In [None]:

# Evaluate every checkpoint in args.model_dirs and summarise the class (y) and
# attribute (C) error rates as mean and standard deviation across checkpoints.
y_results, c_results = [], []
for i, model_dir in enumerate(args.model_dirs):
    args.model_dir = model_dir
    args.model_dir2 = args.model_dirs2[i] if args.model_dirs2 else None
    result = eval(args)
    class_acc_meter, attr_acc_meter = result[0], result[1]
    # Meters hold percentages; convert top-1 accuracy to an error rate in [0, 1].
    y_results.append(1 - class_acc_meter[0].avg[0].item() / 100.)
    if attr_acc_meter is not None:
        c_results.append(1 - attr_acc_meter[0].avg.item() / 100.)
    else:
        c_results.append(-1)  # sentinel: no attribute accuracy available
values = (np.mean(y_results), np.std(y_results), np.mean(c_results), np.std(c_results))
output_string = '%.4f %.4f %.4f %.4f' % values
print_string = 'Error of y: %.4f +- %.4f, Error of C: %.4f +- %.4f' % values
print(print_string)
os.makedirs(args.log_dir, exist_ok=True)  # the log directory may not exist yet
# Append one record per run; the trailing newline keeps successive records
# from running together on a single line.
with open(os.path.join(args.log_dir, 'results.txt'), "a") as f:
    f.write(output_string + '\n')

In [None]:
# Append the summary record to results.txt.
# NOTE(review): the evaluation cell already writes the same string, so running
# both cells stores the record twice — confirm this cell is still wanted.
with open(os.path.join(args.log_dir, 'results.txt'), "a") as f:
    # Trailing newline so appended records stay on separate lines.
    f.write(output_string + '\n')

In [None]:
# Show the summary record: mean/std of class error, then mean/std of attribute error.
output_string

In [None]:
# Path of the results file the summary records are appended to.
os.path.join(args.log_dir, 'results.txt')

In [None]:
# Dataset directory (joined onto BASE_DIR when the pickles are loaded).
args.data_dir

In [None]:
def sig(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` for a scalar or numpy array.

    The exponent must be negated: ``1 / (1 + exp(x))`` is the sigmoid of
    ``-x``, which decreases as ``x`` grows.
    """
    return 1/(1+np.exp(-x))

In [None]:
# Print everything recorded so far in results.txt. The context manager closes
# the file even if reading fails.
with open(os.path.join(args.log_dir, 'results.txt'), "r") as file:
    content = file.read()
print(content)