In [1]:
import argparse
from copy import deepcopy
import logging
import random
from collections import defaultdict
from os.path import join
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score, f1_score, precision_recall_curve
from sklearn.model_selection import train_test_split
import joblib
import imodels
import inspect
import os.path
import imodelsx.cache_save_utils
import sys
import torch

#path_to_repo = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

#os.chdir(path_to_repo)
#os.chdir('/home/mattyshen/interpretableDistillation')
sys.path.append('..')

import idistill.model
import idistill.data

def distill_model(student, X_train_teacher, y_train_teacher, r, feature_names = None):
    """Fit the student on the teacher's data and hand back the fitted student.

    ``feature_names`` is forwarded to ``student.fit`` only when it is not None
    and the student's ``fit`` signature declares a ``feature_names`` parameter.

    Returns:
        The tuple ``(r, student)``; ``r`` is passed through untouched.
    """
    extra_fit_kwargs = {}
    accepts_names = "feature_names" in inspect.signature(student.fit).parameters
    if accepts_names and feature_names is not None:
        extra_fit_kwargs["feature_names"] = feature_names

    student.fit(X_train_teacher, y_train_teacher, **extra_fit_kwargs)

    return r, student

def evaluate_student(student, X_train, X_test, y_train, y_test, metric, task, r):
    """Score the student on the train and test splits and record both in ``r``.

    Student predictions are passed through ``process_student_eval`` before
    scoring. Results are stored under ``student_{task}_{split}_{metric}``.

    Raises:
        KeyError: if ``metric`` is not one of accuracy / mse / r2 / f1.
    """
    scorers = {
        "accuracy": accuracy_score,
        "mse": mean_squared_error,
        "r2": r2_score,
        "f1": f1_score,
    }
    score = scorers[metric]

    splits = {"train": (X_train, y_train), "test": (X_test, y_test)}
    for split_name, (features, targets) in splits.items():
        predictions = process_student_eval(student.predict(features))
        r[f"student_{task}_{split_name}_{metric}"] = score(targets, predictions)

    return r

def evaluate_teacher(y_train_teacher, y_test_teacher, y_train, y_test, metric, task, r):
    """Score the teacher's predictions against the reference labels on each split.

    Args:
        y_train_teacher, y_test_teacher: teacher predictions per split.
        y_train, y_test: reference labels per split.
        metric: one of "accuracy", "mse", "r2", "f1".
        task: label inserted into the result key.
        r: dict-like results container, updated in place.

    Returns:
        ``r`` with ``teacher_{task}_{split}_{metric}`` entries added.

    Raises:
        KeyError: if ``metric`` is not a supported metric name.
    """
    metrics = {
        "accuracy": accuracy_score,
        "mse": mean_squared_error,
        "r2": r2_score,
        "f1": f1_score,
    }

    metric_fn = metrics[metric]

    for split_name, (y_teacher_, y_) in zip(
        ["train", "test"], [(y_train_teacher, y_train), (y_test_teacher, y_test)]
    ):
        # The metric functions take (y_true, y_pred). The reference labels go
        # first, matching evaluate_student; the order matters for r2 and f1.
        r[f"teacher_{task}_{split_name}_{metric}"] = metric_fn(y_, y_teacher_)

    return r

def predict_teacher(teacher, X):
    """Run the teacher's ``sec_model`` head on a concept design matrix.

    Args:
        teacher: object exposing a callable ``sec_model``.
        X: pandas DataFrame; its ``.values`` are fed to ``sec_model`` as float32.

    Returns:
        A DataFrame holding the ``sec_model`` outputs, one row per row of ``X``
        (intended to be logits, per the original TODO).
    """
    sec_model = teacher.sec_model

    # Feed the inputs on whatever device the head already lives on. Fall back
    # to the previously hard-coded 'cuda:0' when that cannot be determined.
    try:
        device = next(sec_model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda:0')

    inputs = torch.tensor(X.values, dtype=torch.float32).to(device)

    # Inference only: skip building an autograd graph that would be discarded.
    with torch.no_grad():
        y_pred_torch = sec_model(inputs)

    y_pred = pd.DataFrame(y_pred_torch.detach().cpu().numpy())

    return y_pred

def load_teacher_model(teacher_path):
    """Load the pickled teacher model, move it to 'cuda:0' and set eval mode.

    Args:
        teacher_path: path handed to ``torch.load``.

    Returns:
        The loaded teacher object.
    """
    ### TODO: load in teacher model using model_path ###

    # The repo paths are added so the checkpoint's classes can be imported
    # during unpickling (presumably that is where they live; verify).
    # Guarded so that re-running the cell does not pile up duplicate entries.
    for repo_path in ('/home/mattyshen/ConceptBottleneck',):
        if repo_path not in sys.path:
            sys.path.append(repo_path)

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects and
    # can execute code embedded in the file. Only load trusted checkpoints.
    teacher = torch.load(teacher_path, weights_only=False)
    teacher.to('cuda:0')
    teacher.eval()

    if '/home/mattyshen/DistillationEdit' not in sys.path:
        sys.path.append('/home/mattyshen/DistillationEdit')

    return teacher

def generate_tabular_distillation_data(teacher, train_path, test_path):
    """Run the teacher over the train(+val) and test data and tabulate the results.

    Args:
        teacher: callable model; ``teacher(inputs)`` must return a sequence whose
            first element is the class output and whose remaining elements are
            the per-concept outputs.
        train_path: path passed to ``load_data`` for the training split (a
            hard-coded validation pickle is loaded alongside it).
        test_path: path passed to ``load_data`` for the test split.

    Returns:
        ``(X_train_teacher, X_test_teacher, X_train, X_test,
        y_train_teacher, y_test_teacher, y_train, y_test)`` where the
        ``*_teacher`` frames hold the teacher's concept / class outputs and the
        others hold the concept labels / class labels yielded by the loader.
    """
    ### TODO: generate teacher train and test data using model, train_path, and test_path ###
    
    # Make the CUB helper modules importable (presumably `dataset` and `config`
    # live in this directory; verify).
    sys.path.append('/home/mattyshen/ConceptBottleneck/CUB')
    from dataset import load_data
    from config import BASE_DIR
    
    def get_cub_data(teacher, path, data = 'train', override_train = True, batch_size = 32):
        """Collect teacher outputs and loader labels for one split.

        ``data == 'test'`` loads ``path`` alone; any other value loads ``path``
        together with the hard-coded validation pickle. ``override_train`` is
        only referenced by the commented-out ``load_data`` calls.
        Returns ``(X_hat, X, y_hat, y)``.
        """
        with torch.no_grad():
            if data == 'test':
                test_dir = path
                #print(test_dir)
                # loader = load_data([test_dir], True, False, batch_size, image_dir='images',
                #                    n_class_attr=2, override_train=override_train)
                loader = load_data([test_dir], True, False, batch_size, image_dir='AdversarialData/CUB_fixed/test',
                                   n_class_attr=2)
            else:
                train_dir = path
                val_dir = '/home/mattyshen/ConceptBottleneck/CUB_processed/class_attr_data_10/val.pkl'
                #print(train_dir, val_dir)
                # loader = load_data([train_dir, val_dir], True, False, batch_size, image_dir='images',
                #                    n_class_attr=2, override_train=override_train)
                loader = load_data([train_dir, val_dir], True, False, batch_size, image_dir='AdversarialData/CUB_fixed/train',
                                    n_class_attr=2)
                
            # NOTE(review): the seed is set after the loader is constructed;
            # presumably meant to fix the iteration order. Confirm it has that effect.
            torch.manual_seed(0)
            
            attrs_true = []
            attrs_hat = []
            labels_true = []
            labels_hat = []
            # NOTE: the loop variable `data` rebinds the `data` parameter; the
            # parameter is not read again after the branch above.
            for data_idx, data in enumerate(loader):
                inputs, labels, attr_labels = data
                attr_labels = torch.stack(attr_labels).t()

                inputs_var = torch.autograd.Variable(inputs).to('cuda:0')
                # labels_var is moved to the GPU but not used afterwards.
                labels_var = torch.autograd.Variable(labels).to('cuda:0')
                outputs = teacher(inputs_var)
                class_outputs = outputs[0]

                # Concept outputs are kept raw; the sigmoid is commented out.
                attr_outputs = outputs[1:] #[torch.nn.Sigmoid()(o) for o in outputs[1:]]
                attr_outputs_sigmoid = attr_outputs

                # assumes each concept output has a singleton dim 2-compatible
                # layout after stacking (squeeze(2)) — TODO confirm
                attrs_hat.append(torch.stack(attr_outputs).squeeze(2).detach().cpu().numpy())
                # .t() above followed by .T here restores the stacked orientation.
                attrs_true.append(attr_labels.T)
                labels_hat.append(class_outputs.detach().cpu().numpy())
                labels_true.append(labels)

            # Batches are joined along axis 1 and transposed so rows are samples.
            # The column list hard-codes 112 concepts (c1..c112).
            X_hat = pd.DataFrame(np.concatenate(attrs_hat, axis=1).T, columns = [f'c{i}' for i in range(1, 113)])
            X = pd.DataFrame(np.concatenate(attrs_true, axis = 1).T, columns = [f'c{i}' for i in range(1, 113)])

            y = pd.Series(np.concatenate([l.numpy().reshape(-1, ) for l in labels_true]))
            y_hat = pd.DataFrame(np.concatenate(labels_hat, axis = 0))

            # Drop references to the per-batch objects before emptying the CUDA cache.
            del attrs_hat
            del labels
            del labels_hat
            del loader
            del data
            del inputs
            del outputs
            del class_outputs
            del attr_outputs
            del attr_outputs_sigmoid
            del inputs_var
            del labels_var
            torch.cuda.empty_cache()

            return X_hat, X, y_hat, y

    X_train_teacher, X_train, y_train_teacher, y_train = get_cub_data(teacher, train_path)
    X_test_teacher, X_test, y_test_teacher, y_test = get_cub_data(teacher, test_path, data = 'test')
    
    sys.path.append('/home/mattyshen/DistillationEdit')
    
    return X_train_teacher, X_test_teacher, X_train, X_test, y_train_teacher, y_test_teacher, y_train, y_test

def process_distillation_data(X_train_teacher, X_test_teacher, X_train, X_test, y_train_teacher, y_test_teacher, y_train, y_test, thresh = 0):
    """Binarize the teacher's concept outputs for distillation.

    Args:
        X_train_teacher, X_test_teacher: teacher concept outputs (DataFrames).
        X_train, X_test, y_train, y_test: accepted for interface symmetry; only
            the alternative schemes sketched below would use them.
        y_train_teacher, y_test_teacher: teacher targets, returned unchanged.
        thresh: values strictly greater than this become 1, all others 0.
            Defaults to 0, the previously hard-coded value.

    Returns:
        ``(X_train_binarized, X_test_binarized, y_train_teacher, y_test_teacher)``.
    """
    ### TODO: process (i.e. binarize, F1-max binarize) data for distillation ###

    def find_optimal_threshold(y_true, y_probs):
        """Return the threshold that maximizes F1 on (y_true, y_probs)."""
        precisions, recalls, thresholds = precision_recall_curve(y_true, y_probs)
        denom = precisions + recalls
        # precision == recall == 0 would give 0/0 = NaN, and np.argmax would
        # then select that NaN entry; score such points as 0 instead.
        f1_scores = np.divide(
            2 * precisions * recalls,
            denom,
            out=np.zeros_like(denom, dtype=float),
            where=denom > 0,
        )
        # The final precision/recall point has no matching threshold; exclude it
        # so the index is always valid for `thresholds`.
        return thresholds[np.argmax(f1_scores[:-1])]

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # Alternative schemes that were tried (kept for reference; the helpers
    # above exist for them):
    #   * per-concept F1-optimal thresholds:
    #       find_optimal_threshold(X_train.iloc[:, j], X_train_teacher.iloc[:, j])
    #   * one global F1-optimal threshold over the flattened matrices:
    #       find_optimal_threshold(X_train.values.reshape(-1, ), X_train_teacher.values.reshape(-1, ))
    #   * no binarization, sigmoid only:
    #       X_train_teacher.applymap(sigmoid), X_test_teacher.applymap(sigmoid)
    #   * grid search over np.arange(0, 1, 0.01) minimizing squared error
    #     between (X_train_teacher > t) and X_train
    #   * accuracy-maximizing threshold over np.arange(0, 1.01, 0.01)

    return (X_train_teacher > thresh).astype(int), (X_test_teacher > thresh).astype(int), y_train_teacher, y_test_teacher

def process_student_eval(y_student):
    """Turn per-class student outputs into class indices.

    Returns, for every row of ``y_student``, the position of its largest value
    along axis 1.
    """
    ### TODO: handle student prediction outputs to match metrics ###
    return np.argmax(y_student, axis=1)

def process_teacher_eval(y_teacher):
    """Reduce the teacher's per-class outputs to integer class predictions.

    For each row of the DataFrame ``y_teacher`` the label of the column holding
    the largest value is taken, cast to int, and returned as a NumPy array.
    """
    ### TODO: process teacher model predictions for evaluations (sometimes we distill a teacher model using a regressor, but want to evaluate class prediction accuracy) ###
    winning_columns = y_teacher.idxmax(axis=1)
    return winning_columns.astype(int).values

def extract_interactions(student):
    """List every root-to-leaf path of every tree in ``student.trees_``.

    Each path is a tuple ``(features, spread)``. ``features`` names the split
    taken at each internal node: ``'c{feature+1}'`` when descending left and
    ``'!c{feature+1}'`` when descending right. ``spread`` is
    ``np.var(np.abs(leaf.value))``.

    Returns:
        One list of paths per tree, leaves ordered left-to-right (depth first).
    """

    def leaf_paths(root):
        found = []
        # Explicit stack; the right child is pushed first so the left subtree
        # is fully visited before the right one.
        pending = [(root, [])]
        while pending:
            node, path = pending.pop()
            has_left = node.left is not None
            has_right = node.right is not None
            if not (has_left or has_right):
                found.append((path, np.var(np.abs(node.value))))
                continue
            concept = 'c' + str(node.feature + 1)
            if has_right:
                pending.append((node.right, path + ['!' + concept]))
            if has_left:
                pending.append((node.left, path + [concept]))
        return found

    return [leaf_paths(tree) for tree in student.trees_]

def get_argmax_max(vals, index):
    """Return the ``index``-th largest value of each row and its column position.

    Args:
        vals: 2-D array-like; ranking is done along axis 1.
        index: 1-based rank (1 = largest, 2 = second largest, ...).

    Returns:
        ``(maxes, argmaxes)``: two 1-D arrays with one entry per row, where
        ``maxes[i] == vals[i, argmaxes[i]]``.

    Raises:
        ValueError: if ``index`` is smaller than 1.
        IndexError: if ``index`` exceeds the number of columns.
    """
    if index < 1:
        raise ValueError(f"index must be >= 1, got {index}")

    vals = np.asarray(vals)
    argmaxes = np.argsort(vals, axis=1)[:, -index]
    # Read the values through the same indices. A partition at kth=-2 only
    # orders the top two entries, so it is not a valid source for rank >= 3,
    # and it could disagree with `argmaxes`.
    maxes = np.take_along_axis(vals, argmaxes[:, None], axis=1)[:, 0]
    return maxes, argmaxes

def extract_adaptive_intervention(student, X, interactions, number_of_top_paths, tol = 0.0001):
    """Pick, per row of ``X``, the concept indices to intervene on.

    The student's per-tree predictions (``predict(X, by_tree=True)``) are
    reduced to a variance of absolute values along axis 1. For each of the
    ``number_of_top_paths`` highest-ranked trees of a row, every path in
    ``interactions[tree]`` whose stored variance lies within ``tol`` of the
    row's value contributes its concepts (0-based, parsed from 'cK' / '!cK').

    Returns:
        A list with one de-duplicated list of concept indices per row of ``X``.
    """
    per_tree_preds = student.predict(X, by_tree = True)
    tree_variances = np.var(np.abs(per_tree_preds), axis = 1)

    matched = [[] for _ in range(X.shape[0])]

    for rank in range(1, number_of_top_paths + 1):
        top_vars, top_trees = get_argmax_max(tree_variances, rank)
        for row, (tree_idx, row_var) in enumerate(zip(top_trees, top_vars)):
            for entry in interactions[tree_idx]:
                if abs(entry[1] - row_var) < tol:
                    matched[row].append(
                        [int(name[2:] if name[0] == '!' else name[1:]) - 1 for name in entry[0]]
                    )

    concepts_to_edit = []
    for groups in matched:
        flat = [concept for group in groups for concept in group]
        concepts_to_edit.append(list(set(flat)))

    return concepts_to_edit

  def noop(*args, **kwargs):  # type: ignore
2025-02-06 19:14:11.019346: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from urllib3.contrib.pyopenssl import orig_util_SSLContext as SSLContext
  torch.utils._pytree._register_pytree_node(


In [2]:
class ARGS:
    """Attribute-style view of a dict, standing in for parsed CLI arguments.

    Every key of ``a_dict`` becomes an instance attribute bound to the same
    value object.
    """

    def __init__(self, a_dict):
        # setattr instead of exec: no code is built from the key strings, so
        # keys containing quotes or other syntax cannot break or inject code.
        for k, v in a_dict.items():
            setattr(self, k, v)
        

In [3]:
# Notebook stand-in for the command-line arguments; every value is the default.
args = ARGS({
    # 'save_dir': join(path_to_repo, "results"),  (disabled)
    'teacher_path': '/home/mattyshen/DistillationEdit/models/travelingbirds/outputs/best_Joint0.01_Transformer_model_1.pth',
    'train_path': '/home/mattyshen/ConceptBottleneck/CUB_processed/class_attr_data_10/train.pkl',
    'test_path': '/home/mattyshen/ConceptBottleneck/CUB_processed/class_attr_data_10/test.pkl',
    'task_type': "regression",
    'student_name': "FIGSRegressor",
    'max_rules': 250,
    'max_trees': 40,
    'max_depth': 4,
    'metric': "accuracy",
    'num_interactions_intervention': 3,
})

In [4]:
# Results container; the evaluate_* helpers store one entry per metric key in it.
r = defaultdict(list)

In [5]:
# Load the teacher, then run it over the train/test data to obtain its concept
# and class outputs (*_t) alongside the loader's concept and class labels.
teacher = load_teacher_model(args.teacher_path)
    
X_train_t, X_test_t, X_train, X_test, y_train_t, y_test_t, y_train, y_test = generate_tabular_distillation_data(teacher, args.train_path, args.test_path)

In [6]:
# *_d: distillation inputs/targets (teacher concept outputs binarized, teacher
# class outputs passed through unchanged).
X_train_d, X_test_d, y_train_d, y_test_d = process_distillation_data(X_train_t, X_test_t, X_train, X_test, y_train_t, y_test_t, y_train, y_test)

# *_t_eval: the teacher's class outputs reduced to a single predicted class per row.
y_train_t_eval = process_teacher_eval(y_train_t)
y_test_t_eval = process_teacher_eval(y_test_t)

In [11]:
from idistill.whitebox_figs import FIGSRegressor

In [12]:
# Build the FIGS student directly (the generic idistill.model.get_model call is
# disabled) and fit it on the distillation data.
figs_student = FIGSRegressor(max_rules=args.max_rules, max_trees=args.max_trees, max_depth=args.max_depth) #idistill.model.get_model(args.task_type, args.student_name, args)

r, figs_student = distill_model(figs_student, X_train_d, y_train_d, r)

In [13]:
# "distillation": student vs. the teacher's predicted classes.
r = evaluate_student(figs_student, X_train_d, X_test_d, y_train_t_eval, y_test_t_eval, args.metric, "distillation", r)
# "prediction": student vs. the loader's class labels.
r = evaluate_student(figs_student, X_train_d, X_test_d, y_train, y_test, args.metric, "prediction", r)

# Teacher's predicted classes vs. the loader's class labels.
r = evaluate_teacher(y_train_t_eval, y_test_t_eval, y_train, y_test, args.metric, "prediction", r)

In [14]:
# Display the metrics collected so far.
r

defaultdict(list,
            {'student_distillation_train_accuracy': 0.7227606951871658,
             'student_distillation_test_accuracy': 0.6926130479806697,
             'student_prediction_train_accuracy': 0.6497326203208557,
             'student_prediction_test_accuracy': 0.5474628926475664,
             'teacher_prediction_train_accuracy': 0.6911764705882353,
             'teacher_prediction_test_accuracy': 0.5902657921988264})

In [None]:
# Adaptive intervention: for every row, overwrite the concepts chosen by
# extract_adaptive_intervention and re-evaluate student and teacher.
figs_interactions = extract_interactions(figs_student)

X_train_d_edit = X_train_d.copy()
X_train_t_edit = X_train_t.copy()

# Per-concept 5% / 95% quantiles of the teacher's training concept outputs;
# they are the replacement values used for both the train and the test edits.
train_q5 = np.quantile(X_train_t, 0.05, axis = 0)
train_q95 = np.quantile(X_train_t, 0.95, axis = 0)

# train_q5d = np.quantile(X_train_d, 0.05, axis = 0)
# train_q95d = np.quantile(X_train_d, 0.95, axis = 0)

cti_train = extract_adaptive_intervention(figs_student, X_train_d, figs_interactions, args.num_interactions_intervention)

# Student input: selected concepts take the label value from X_train.
# Teacher input: q5 where the label equals 0, plus q95 scaled by the label
# (i.e. q95 where the label is 1, assuming binary labels — TODO confirm).
for i in range(len(cti_train)):
    X_train_d_edit.iloc[i, cti_train[i]] = X_train.iloc[i, cti_train[i]] #train_q5d[cti_train[i]]*(X_train.iloc[i, cti_train[i]] == 0) + train_q95d[cti_train[i]]*(X_train.iloc[i, cti_train[i]])
    X_train_t_edit.iloc[i, cti_train[i]] = train_q5[cti_train[i]]*(X_train.iloc[i, cti_train[i]] == 0) + train_q95[cti_train[i]]*(X_train.iloc[i, cti_train[i]])

cti_test = extract_adaptive_intervention(figs_student, X_test_d, figs_interactions, args.num_interactions_intervention)

X_test_d_edit = X_test_d.copy()
X_test_t_edit = X_test_t.copy()

# Same edit on the test rows, still using the training quantiles.
for i in range(len(cti_test)):
    X_test_d_edit.iloc[i, cti_test[i]] = X_test.iloc[i, cti_test[i]] #train_q5d[cti_test[i]]*(X_test.iloc[i, cti_test[i]] == 0) + train_q95d[cti_test[i]]*(X_test.iloc[i, cti_test[i]])
    X_test_t_edit.iloc[i, cti_test[i]] = train_q5[cti_test[i]]*(X_test.iloc[i, cti_test[i]] == 0) + train_q95[cti_test[i]]*(X_test.iloc[i, cti_test[i]])

# Re-run the teacher's sec_model head on the edited concepts and take its predicted class.
y_train_t_eval_interv = process_teacher_eval(predict_teacher(teacher, X_train_t_edit))
y_test_t_eval_interv = process_teacher_eval(predict_teacher(teacher, X_test_t_edit))

r = evaluate_student(figs_student, X_train_d_edit, X_test_d_edit, y_train_t_eval_interv, y_test_t_eval_interv, args.metric, "distillation_adap_interv", r)
r = evaluate_student(figs_student, X_train_d_edit, X_test_d_edit, y_train, y_test, args.metric, "prediction_adap_interv", r)

r = evaluate_teacher(y_train_t_eval_interv, y_test_t_eval_interv, y_train, y_test, args.metric, "prediction_adap_interv", r)

In [None]:
# Number of selected paths times the maximum tree depth (presumably the upper
# bound on concepts edited per row — TODO confirm).
args.num_interactions_intervention * args.max_depth

In [None]:
# Histogram of how many concepts were selected per training row.
plt.hist([len(i) for i in cti_train])

In [None]:
# Histogram of how many concepts were selected per test row.
plt.hist([len(i) for i in cti_test])

In [None]:
# Random-intervention baseline: for every row draw as many distinct concept
# indices as the adaptive selection chose for that row.
# A seeded generator keeps the baseline reproducible across notebook runs, and
# the number of concepts is read from the data instead of being hard-coded.
rng = np.random.default_rng(0)
n_concepts = X_train.shape[1]
cti_r_train = [rng.choice(n_concepts, size=len(c), replace=False) for c in cti_train]
cti_r_test = [rng.choice(n_concepts, size=len(c), replace=False) for c in cti_test]

In [None]:
# Histogram of the random selection sizes per training row (drawn to match cti_train).
plt.hist([len(i) for i in cti_r_train])

In [None]:
# Histogram of the random selection sizes per test row (drawn to match cti_test).
plt.hist([len(i) for i in cti_r_test])

In [None]:
# Display the teacher's training concept outputs.
X_train_t

In [None]:
# Random intervention: same edit as the adaptive cell, but on the randomly
# drawn concept sets (cti_r_*). Reuses train_q5 / train_q95 from that cell.
X_train_d_r_edit = X_train_d.copy()
X_train_t_r_edit = X_train_t.copy()

# Student input takes the label value; teacher input takes q5 where the label
# equals 0 plus q95 scaled by the label.
for i in range(len(cti_r_train)):
    X_train_d_r_edit.iloc[i, cti_r_train[i]] = X_train.iloc[i, cti_r_train[i]] #train_q5d[cti_r_train[i]]*(X_train.iloc[i, cti_r_train[i]] == 0) + train_q95d[cti_r_train[i]]*(X_train.iloc[i, cti_r_train[i]])
    X_train_t_r_edit.iloc[i, cti_r_train[i]] = train_q5[cti_r_train[i]]*(X_train.iloc[i, cti_r_train[i]] == 0) + train_q95[cti_r_train[i]]*(X_train.iloc[i, cti_r_train[i]])

X_test_d_r_edit = X_test_d.copy()
X_test_t_r_edit = X_test_t.copy()

# Test rows are edited with the training quantiles as well.
for i in range(len(cti_r_test)):
    X_test_d_r_edit.iloc[i, cti_r_test[i]] = X_test.iloc[i, cti_r_test[i]] #train_q5d[cti_r_test[i]]*(X_test.iloc[i, cti_r_test[i]] == 0) + train_q95d[cti_r_test[i]]*(X_test.iloc[i, cti_r_test[i]])
    X_test_t_r_edit.iloc[i, cti_r_test[i]] = train_q5[cti_r_test[i]]*(X_test.iloc[i, cti_r_test[i]] == 0) + train_q95[cti_r_test[i]]*(X_test.iloc[i, cti_r_test[i]])

# Re-run the teacher's sec_model head on the edited concepts and take its predicted class.
y_train_t_eval_r_interv = process_teacher_eval(predict_teacher(teacher, X_train_t_r_edit))
y_test_t_eval_r_interv = process_teacher_eval(predict_teacher(teacher, X_test_t_r_edit))

r = evaluate_student(figs_student, X_train_d_r_edit, X_test_d_r_edit, y_train_t_eval_r_interv, y_test_t_eval_r_interv, args.metric, "distillation_rand_interv", r)
r = evaluate_student(figs_student, X_train_d_r_edit, X_test_d_r_edit, y_train, y_test, args.metric, "prediction_rand_interv", r)

r = evaluate_teacher(y_train_t_eval_r_interv, y_test_t_eval_r_interv, y_train, y_test, args.metric, "prediction_rand_interv", r)

In [None]:
# Number of selected paths times the maximum tree depth (presumably the upper
# bound on concepts edited per row — TODO confirm).
args.num_interactions_intervention * args.max_depth

In [None]:
# Results `r` for the run labelled "MLP2" (presumably a teacher variant — TODO confirm).


r

In [None]:
-----

In [None]:
# Results `r` for the run labelled "MLP1" (presumably a teacher variant — TODO confirm).

r

In [None]:
# Results `r` for the run whose binarization threshold was the median of the training values.
r

In [None]:
# Results `r` for the run whose binarization threshold was 0.
r

In [None]:
# Number of trees in the fitted FIGS student.
len(figs_student.trees_)