In [None]:
import os
import torch
from transformers import AutoModelForSequenceClassification, AutoModel
import CXRBERT
from collections import OrderedDict
from transformers import AutoTokenizer
import numpy as np
from tqdm.notebook import tqdm
from eval import evaluate, plot_roc, accuracy, sigmoid, bootstrap, compute_cis
from typing import Tuple
# Prefer the first CUDA device when one is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
# Sanity-check that GPU 0 can be queried. The call is guarded because
# torch.cuda.get_device_name raises on hosts without a usable CUDA device,
# which would defeat the CPU fallback chosen above.
if torch.cuda.is_available():
    torch.cuda.get_device_name(0)
import sys

In [None]:
from datasets import load_dataset
# Read the evaluation CSV and expose it as a single split named "test";
# every later cell accesses it as dataset['test'].
dataset = load_dataset("csv", data_files={ "test": "full_test_with_c.csv"})

In [None]:
# Every column that is not an identifier or a report text is treated as a label.
labels = [
    column
    for column in dataset['test'].features
    if column not in ('subject_id', 'study_id', 'report', 'c_report')
]
# Index <-> name lookups that follow the column order of `labels`.
id2label = dict(enumerate(labels))
label2id = {name: position for position, name in id2label.items()}
labels

In [None]:
# Pair of prompt templates; "{}" is filled with each label name when the
# zero-shot class embeddings are built.
# Order matters: predict() reads column 0 of the per-class similarities, i.e.
# the score of the FIRST template in this tuple.
# Earlier alternative, kept for reference (note that it lists the negated
# prompt first, the opposite of the active pair below):
#cxr_pair_template = ("no {}", "has {}")
cxr_pair_template = ("findings suggesting {}", "no evidence of {}")

In [None]:
# Identifier handed to every from_pretrained call below.
checkpoint = "allenai/scibert_scivocab_cased"
config = CXRBERT.CXRBertConfig.from_pretrained(checkpoint)
# NOTE(review): padding/truncation/max_length are passed to from_pretrained here,
# not to a tokenizer call; the cells that tokenize text pass their own values
# explicitly, so these may have no effect - confirm against CXRBERT.
tokenizer = CXRBERT.CXRBertTokenizer.from_pretrained(checkpoint,padding="max_length", truncation=True, max_length=512)
# Call from_pretrained on the class itself. The previous form,
# CXRBertModel(config).from_pretrained(checkpoint), first built a complete
# randomly-initialised model only to discard it immediately.
# Assumes from_pretrained is a classmethod (Hugging Face convention).
model = CXRBERT.CXRBertModel.from_pretrained(checkpoint)

In [None]:
# Fine-tuned checkpoint whose weights replace the pretrained ones loaded above.
path = './CXR-BERT/1qbqcmfd/checkpoints/epoch=49-step=25000.ckpt'
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on
# a CPU-only host as well; the model is moved to `device` at the end of the cell.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model_cpt = torch.load(path, map_location="cpu")

# Drop the first six characters of every parameter name so the keys match the
# bare model. NOTE(review): the original comment said this removes `module.`,
# but that prefix is seven characters long; six characters matches a prefix
# such as `model.` - confirm against the keys stored in the checkpoint.
new_state_dict = OrderedDict(
    (key[6:], tensor) for key, tensor in model_cpt['state_dict'].items()
)

model.load_state_dict(new_state_dict)
model.eval()  # evaluation mode is required for inference
model = model.to(device)

In [None]:
def preprocess_data(examples):
    """Tokenize a batch of reports and attach one label row per report.

    args:
        * examples - mapping of column name to a list of values for the batch;
          must contain "c_report" and one column for every name in the
          module-level `labels` list.

    Returns the tokenizer output with an extra "labels" entry: a list of
    rows (one per report) holding one float per entry of `labels`.
    """
    reports = examples["c_report"]
    encoding = tokenizer(reports, padding="max_length", truncation=True, max_length=512)

    # Rows follow the reports, columns follow the order of `labels`.
    label_matrix = np.zeros((len(reports), len(labels)))
    for column, label_name in enumerate(labels):
        label_matrix[:, column] = examples[label_name]

    encoding["labels"] = label_matrix.tolist()
    return encoding

In [None]:
# Run preprocess_data over the dataset in batches. All original columns of the
# test split are removed, so only the fields returned by preprocess_data remain.
encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['test'].column_names)
# Serve items in "torch" format; presumably the tensors are created directly on
# `device` because of the device keyword - TODO confirm with the datasets docs.
encoded_dataset.set_format("torch", device = device)

In [None]:
def my_zeroshot_classifier(classnames, templates, model, context_length=77):
    """Build unit-norm prompt embeddings for every class (CLIP-style zero-shot).

    Each template is formatted with the class name, tokenized with the
    module-level `tokenizer`, embedded by `model`, and the hidden state at
    position 0 of each prompt is L2-normalised.

    args:
        * classnames - iterable of label names inserted into the templates
        * templates - iterable of format strings containing one "{}"
        * model - text encoder called as model(input_ids, attention_mask); its
          output must expose 'last_hidden_state'
        * context_length (optional) - accepted for interface compatibility but
          not used; prompts are padded/truncated to 20 tokens

    Returns a tensor with one entry per class, each holding one embedding per
    template (stacked along dim 0 in the order of `classnames`).
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            prompts = [template.format(classname) for template in templates]
            # truncation=True keeps every prompt at exactly max_length tokens;
            # without it an over-long prompt yields rows of different lengths
            # that cannot be turned into a tensor. return_tensors="pt" gives
            # integer tensors directly, so the attention mask is no longer
            # silently converted to float by torch.Tensor(...).
            encoded = tokenizer(
                prompts,
                padding="max_length",
                truncation=True,
                max_length=20,
                return_tensors="pt",
            )

            outputs = model(
                encoded['input_ids'].to(device),
                encoded['attention_mask'].to(device),
            )

            # Position 0 of the sequence is used as the prompt embedding.
            class_embeddings = outputs['last_hidden_state'][:, 0, :]
            # Out-of-place division so the model's output tensor is not mutated.
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)

            # Templates are deliberately NOT averaged: each class keeps one
            # embedding per template so they can be compared against each other.
            zeroshot_weights.append(class_embeddings)

        zeroshot_weights = torch.stack(zeroshot_weights, dim=0)
    return zeroshot_weights

In [None]:
from sklearn.preprocessing import normalize
def predict(loader, model, zeroshot_weights, softmax_eval=True, verbose=0):
    """
    FUNCTION: predict
    ---------------------------------
    Runs the tokenized reports through the model and scores each report
    against the zero-shot prompt embeddings of every class.

    For each batch, the hidden state at position 0 of the model's first output
    is used as the report embedding. Its dot products with the prompt
    embeddings of one class are L1-normalised across that class's prompts, and
    the normalised value of the FIRST prompt is the class score.

    args:
        * loader - iterable of batches with 'input_ids' and 'attention_mask' tensors
        * model - text encoder called as model(input_ids, attention_mask)
        * zeroshot_weights - prompt embeddings, one (n_prompts, hidden) entry per class
        * softmax_eval (optional) - accepted for interface compatibility; not used
        * verbose (optional) - accepted for interface compatibility; not used

    Returns a tuple (y_pred, y_scores) of numpy arrays, each of shape
    (n_samples, n_classes): y_scores holds the L1-normalised first-prompt
    scores and y_pred holds the sigmoid of those scores. A pair is returned
    because the evaluation code in this notebook unpacks two values and applies
    its own sigmoid to the second one.
    """
    n_classes = len(zeroshot_weights)
    score_batches = []
    with torch.no_grad():
        for data in tqdm(loader):
            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)

            # Report embedding: position 0 of the first model output.
            text_features = model(input_ids, attention_mask)[0][:, 0, :]

            # One row per report in the batch, one column per class.
            batch_scores = np.empty((text_features.shape[0], n_classes), dtype=np.float64)
            for class_idx, class_weight in enumerate(zeroshot_weights):
                logits = (text_features @ class_weight.T).cpu().numpy()  # (batch, n_prompts)

                # L1-normalise each row across the prompts; rows whose norm is
                # zero are left unchanged instead of dividing by zero.
                l1_norm = np.abs(logits).sum(axis=1, keepdims=True)
                l1_norm[l1_norm == 0] = 1.0
                batch_scores[:, class_idx] = logits[:, 0] / l1_norm[:, 0]

            score_batches.append(batch_scores)

    if score_batches:
        y_scores = np.concatenate(score_batches, axis=0)
    else:
        # Empty loader: keep the (n_samples, n_classes) layout with zero rows.
        y_scores = np.zeros((0, n_classes), dtype=np.float64)

    # NumPy sigmoid; torch.nn.Sigmoid cannot be applied to a NumPy array.
    y_pred = 1.0 / (1.0 + np.exp(-y_scores))
    return y_pred, y_scores

In [None]:
# Embed both prompts of cxr_pair_template for every label; the result is what
# predict() compares each report embedding against.
zeroshot_weights = my_zeroshot_classifier(labels, cxr_pair_template,model)

In [None]:
from torch.utils.data import DataLoader
# One report per batch, in dataset order (shuffle=False), so the predictions
# line up row-for-row with encoded_dataset['test']['labels'].
test_dataloader = DataLoader(encoded_dataset['test'], batch_size=1, shuffle=False)

In [None]:
# Two values are unpacked here, so predict must return a pair. Only test_probs
# is read by the evaluation cells below; test_y_pred is not used again in this
# part of the notebook.
test_y_pred, test_probs = predict(test_dataloader, model,zeroshot_weights)
# Ground-truth label matrix, moved to the CPU for the metric computations.
test_y_true = encoded_dataset['test']['labels'].cpu()

In [None]:
import pandas as pd
import numpy as np

# Per-label decision thresholds: the 'value' column of the CSV whose name is
# derived from the checkpoint path, materialised as a NumPy array.
val_best_thresholds = np.array(
    pd.read_csv('./validation/thresholds/' + path + '.csv')['value']
)

In [None]:
# Multi-label evaluation of the zero-shot predictions on the test split.
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, precision_recall_fscore_support,precision_recall_curve
# NOTE: rebinds `sigmoid`, which the first cell imported from the local `eval` module.
sigmoid = torch.nn.Sigmoid()
# NOTE: rebinds `evaluate`, which the first cell imported from the local `eval`
# module; from here on the name refers to this package (presumably the
# Hugging Face `evaluate` library - TODO confirm).
import evaluate
# Squash the scores returned by predict into (0, 1); used for every ROC-AUC below.
probs = sigmoid(torch.Tensor(test_probs))
#y_pred = np.zeros(probs.shape)

# Binarise one label column at a time with its validation threshold.
# NOTE(review): the thresholds are compared with test_probs (before the sigmoid
# above), not with probs - confirm this matches how the thresholds were derived.
pred_labels = np.zeros_like(probs)
for i in range(test_probs.shape[1]):
    pred_labels[:, i] = np.where(test_probs[:, i] > val_best_thresholds[i], 1, 0)

y_pred = pred_labels
y_true = test_y_true

# F1 and accuracy use the thresholded labels; ROC-AUC uses the continuous probs.
f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
roc_auc = roc_auc_score(y_true, probs, average = 'micro')
# NOTE: rebinds `accuracy`, which the first cell imported from the local `eval` module.
accuracy = accuracy_score(y_true, y_pred)
f1_macro_average = f1_score(y_true=y_true, y_pred=y_pred, average='macro')
#roc_auc_mac = roc_auc_score(y_true, y_pred, average = 'macro')
roc_auc_mac = roc_auc_score(y_true, probs, average = 'macro')
f1_w_average = f1_score(y_true=y_true, y_pred=y_pred, average='weighted')
roc_auc_w = roc_auc_score(y_true, probs, average = 'weighted')
#############################################################################
# Per-class ROC-AUC (average=None) computed through the `evaluate` package.
roc_auc_score2 = evaluate.load("roc_auc", "multilabel")
results = roc_auc_score2.compute(references=y_true, prediction_scores=probs, average = None)['roc_auc']

# Collect every metric in one dictionary. It is only assigned to `metrics`;
# this cell does not display it.
metrics = {'f1_micro': f1_micro_average,
        'roc_auc_micro': roc_auc,
        'f1_macro': f1_macro_average,
        'roc_auc_macro': roc_auc_mac,
        'f1_weighted': f1_w_average,
        'roc_auc_weighted': roc_auc_w,
        'accuracy': accuracy,
        'roc_auc_per_class': [round(res, 3) for res in results]
        }