In [63]:
%load_ext autoreload
%autoreload 2

# add src to path
import sys
sys.path.append('/cluster/home/kheuto01/code/play-with-learning-army/src')

# change directory to this files directory
import os
os.chdir('/cluster/home/kheuto01/code/play-with-learning-army')
from data_loader import load_processed, make_dataset
from embedder_registry import initialize_embedding, initialize_criteria_embedding, initialize_combiner
from domain_models import initialize_domain_models
from loss_opt import initialize_loss, initialize_optimizer
import torch
import yaml
import transformers
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, average_precision_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [64]:
experiment_directory = '/cluster/tufts/hugheslab/kheuto01/sensemaking/bertfinetune_test/test15_lr1e-06_alpha0.1_beta0.01'
problem_config_path = '/cluster/home/kheuto01/code/play-with-learning-army/config/problem_config.yaml'
test_metrics_path = os.path.join(experiment_directory, 'test_metrics.csv')
config_path = os.path.join(experiment_directory, 'config.yaml')
model_path = os.path.join(experiment_directory, 'final_model.pth')
hyper_config = yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader)
problem_config = yaml.load(open(problem_config_path, 'r'), Loader=yaml.FullLoader)
num_domains = problem_config['num_domains']

In [65]:
save_dict = torch.load(model_path, weights_only=False, map_location=torch.device('cpu'))
domain_model_dict = save_dict['domain_model_dict']
embedder = save_dict['embed_func']
embedder.device = 'cpu'

In [66]:
hyper_config

{'embedder': 'bert',
 'criteria_embedder': 'identity',
 'combiner': 'concatenate',
 'opt_weight_decay': 0,
 'device': 'cuda',
 'finetune': True,
 'wandb_project': 'sensemaking_bert_finetune_test',
 'seed': 360,
 'num_epochs': 1000,
 'loss': 'l2sp',
 'learning_rate': 1e-06,
 'alpha': 0.1,
 'beta': 0.01,
 'batch_size': 32,
 'num_folds': 1,
 'experiment_name': 'test15_lr1e-06_alpha0.1_beta0.01_take3',
 'train_x_file': '/cluster/home/kheuto01/code/play-with-learning-army/data/clean/test_15/retrain_x.csv',
 'train_y_file': '/cluster/home/kheuto01/code/play-with-learning-army/data/clean/test_15/retrain_y.csv',
 'val_x_file': '/cluster/home/kheuto01/code/play-with-learning-army/data/clean/test_15/test_x.csv',
 'val_y_file': '/cluster/home/kheuto01/code/play-with-learning-army/data/clean/test_15/test_y.csv',
 'test_x_file': '/cluster/home/kheuto01/code/play-with-learning-army/data/clean/test_15/test_x.csv',
 'test_y_file': '/cluster/home/kheuto01/code/play-with-learning-army/data/clean/test_15

In [53]:
processed_test_features, processed_test_labels = load_processed(hyper_config['test_x_file'], 
                                                                hyper_config['test_y_file'])
(xs, ys, problem_ids, student_ids) = make_dataset(processed_test_features, processed_test_labels)
xs, ys, problem_ids, student_ids = embedder.preprocess_data((xs, ys, torch.tensor(problem_ids), torch.tensor(student_ids)), hyper_config)

In [62]:
all_preds, all_labels, all_weights, all_domains = [], [], [], []

if isinstance(xs, transformers.BatchEncoding):
    batch_length = len(next(iter(xs.values())))
else:
    batch_length = len(xs)
    
for i in range(batch_length):
    if isinstance(xs, transformers.BatchEncoding):
        x = {k: v[i].unsqueeze(0) for k, v in xs.items()}
    else:
        x = xs[i].unsqueeze(0)
    y = ys[i]
    p = problem_ids[i]
    s = student_ids[i]

    x_embed = embedder.forward(x)

    criteria_counter = 0
    for d in range(num_domains):
        num_criteria = problem_config['problems'][p]['domains'][d]["num_criteria"]
        for c in range(num_criteria):
            c_embed = torch.tensor([c])
            final_representation = torch.cat((x_embed, c_embed.unsqueeze(0)), dim=1)
            y_pred = domain_model_dict[d](final_representation)
            weight = 1/num_criteria

            all_preds.append(y_pred.detach().cpu().numpy())
            all_labels.append(y[criteria_counter].cpu().numpy())
            all_weights.append(weight)
            all_domains.append(d)
            criteria_counter += 1


In [68]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, average_precision_score

import torch.nn.functional as F

# Convert lists to numpy arrays
all_preds_np = np.concatenate(all_preds).flatten()
all_labels_np = np.array(all_labels).flatten()
all_weights_np = np.array(all_weights)
all_domains_np = np.array(all_domains)

# Calculate metrics
accuracy = accuracy_score(all_labels_np, all_preds_np.round())
precision, recall, f1, _ = precision_recall_fscore_support(all_labels_np, all_preds_np.round(), average='binary')
roc_auc = roc_auc_score(all_labels_np, all_preds_np)
avg_precision = average_precision_score(all_labels_np, all_preds_np)
bce_loss = F.binary_cross_entropy(torch.tensor(all_preds_np), torch.tensor(all_labels_np), reduction='mean').item()

# Calculate weighted metrics
weighted_accuracy = accuracy_score(all_labels_np, all_preds_np.round(), sample_weight=all_weights_np)
weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(all_labels_np, all_preds_np.round(), average='binary', sample_weight=all_weights_np)
weighted_roc_auc = roc_auc_score(all_labels_np, all_preds_np, sample_weight=all_weights_np)
weighted_avg_precision = average_precision_score(all_labels_np, all_preds_np, sample_weight=all_weights_np)
weighted_bce_loss = F.binary_cross_entropy(torch.tensor(all_preds_np), torch.tensor(all_labels_np), weight=torch.tensor(all_weights_np), reduction='mean').item()

# Calculate metrics by domain
unique_domains = np.unique(all_domains_np)
domain_metrics = {}
for domain in unique_domains:
    domain_mask = all_domains_np == domain
    domain_preds = all_preds_np[domain_mask]
    domain_labels = all_labels_np[domain_mask]
    domain_weights = all_weights_np[domain_mask]
    
    domain_accuracy = accuracy_score(domain_labels, domain_preds.round())
    domain_precision, domain_recall, domain_f1, _ = precision_recall_fscore_support(domain_labels, domain_preds.round(), average='binary')
    domain_roc_auc = roc_auc_score(domain_labels, domain_preds)
    domain_avg_precision = average_precision_score(domain_labels, domain_preds)
    domain_bce_loss = F.binary_cross_entropy(torch.tensor(domain_preds), torch.tensor(domain_labels), reduction='mean').item()
    
    domain_weighted_accuracy = accuracy_score(domain_labels, domain_preds.round(), sample_weight=domain_weights)
    domain_weighted_precision, domain_weighted_recall, domain_weighted_f1, _ = precision_recall_fscore_support(domain_labels, domain_preds.round(), average='binary', sample_weight=domain_weights)
    domain_weighted_roc_auc = roc_auc_score(domain_labels, domain_preds, sample_weight=domain_weights)
    domain_weighted_avg_precision = average_precision_score(domain_labels, domain_preds, sample_weight=domain_weights)
    domain_weighted_bce_loss = F.binary_cross_entropy(torch.tensor(domain_preds), torch.tensor(domain_labels), weight=torch.tensor(domain_weights), reduction='mean').item()
    
    domain_metrics[domain] = {
        'accuracy': domain_accuracy,
        'precision': domain_precision,
        'recall': domain_recall,
        'f1': domain_f1,
        'roc_auc': domain_roc_auc,
        'avg_precision': domain_avg_precision,
        'bce_loss': domain_bce_loss,
        'weighted_accuracy': domain_weighted_accuracy,
        'weighted_precision': domain_weighted_precision,
        'weighted_recall': domain_weighted_recall,
        'weighted_f1': domain_weighted_f1,
        'weighted_roc_auc': domain_weighted_roc_auc,
        'weighted_avg_precision': domain_weighted_avg_precision,
        'weighted_bce_loss': domain_weighted_bce_loss
    }

# Print results
print("Overall Metrics:")
print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")
print(f"ROC AUC: {roc_auc}")
print(f"Average Precision: {avg_precision}")
print(f"Binary Cross-Entropy Loss: {bce_loss}")

print("\nWeighted Metrics:")
print(f"Weighted Accuracy: {weighted_accuracy}")
print(f"Weighted Precision: {weighted_precision}")
print(f"Weighted Recall: {weighted_recall}")
print(f"Weighted F1 Score: {weighted_f1}")
print(f"Weighted ROC AUC: {weighted_roc_auc}")
print(f"Weighted Average Precision: {weighted_avg_precision}")
print(f"Weighted Binary Cross-Entropy Loss: {weighted_bce_loss}")

print("\nMetrics by Domain:")
for domain, metrics in domain_metrics.items():
    print(f"\nDomain {domain}:")
    for metric_name, metric_value in metrics.items():
        print(f"{metric_name}: {metric_value}")

Overall Metrics:
Accuracy: 0.9210526315789473
Precision: 0.8884615384615384
Recall: 0.9352226720647774
F1 Score: 0.9112426035502958
ROC AUC: 0.9864504079918778
Average Precision: 0.9828301936945492
Binary Cross-Entropy Loss: 0.1291058510541916

Weighted Metrics:
Weighted Accuracy: 0.9378787878787879
Weighted Precision: 0.9070048309178753
Weighted Recall: 0.9422835633626102
Weighted F1 Score: 0.9243076923076928
Weighted ROC AUC: 0.9918502499334465
Weighted Average Precision: 0.9883217108803226
Weighted Binary Cross-Entropy Loss: 0.0583498515188694

Metrics by Domain:

Domain 0:
accuracy: 0.9714285714285714
precision: 0.9759036144578314
recall: 0.9878048780487805
f1: 0.9818181818181818
roc_auc: 0.9946977730646872
avg_precision: 0.998512347423429
bce_loss: 0.05327007547020912
weighted_accuracy: 0.9805555555555556
weighted_precision: 0.983739837398374
weighted_recall: 0.9877551020408163
weighted_f1: 0.9857433808553973
weighted_roc_auc: 0.9984383318544809
weighted_avg_precision: 0.999265750