In [None]:
import os

import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import LongformerTokenizer, LongformerForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
from sklearn.metrics import f1_score

from sklearn.utils import resample
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import (
    auc,
    average_precision_score,
    classification_report,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
import matplotlib.pyplot as plt
import seaborn as sns

## Training setup

In [None]:
# Directory that receives the best checkpoint and tokenizer during training.
# exist_ok=True makes the cell idempotent without a separate existence check,
# and still raises if the path exists but is not a directory.
os.makedirs("./best_Longformer_model", exist_ok=True)

In [None]:
# Restrict this process to GPUs 0-3. The variable only takes effect if CUDA
# has not yet been initialised in the current kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load every note with its split assignment and label. The frame holds
# free-text clinical notes, so instead of rendering it (which would persist
# note text in the saved notebook output) show only label counts per split.
data = pd.read_csv("all_clinical_notes (Valid PS).csv")
data.groupby("split")["high_ps"].value_counts().unstack(fill_value=0)

In [None]:
# Tokenizer for the same pretrained checkpoint the classifier is loaded from.
tokenizer = LongformerTokenizer.from_pretrained(
    "allenai/longformer-large-4096"
)

In [None]:
MAX_TOKENS = 4096


def filter_exceeding_texts(notes, labels, tokenizer, max_tokens=MAX_TOKENS):
    """Replace over-long notes by their tail so they fit the model window.

    Despite the name, no note is dropped. Every (note, label) pair is kept.
    Notes whose token count exceeds the content budget are replaced by the
    text of their last tokens.

    The content budget is ``max_tokens`` minus the special tokens the
    tokenizer adds around a single sequence. The later encoding call
    (``truncation=True, max_length=max_tokens``) counts those special tokens
    against ``max_length``. Keeping a full ``max_tokens`` tail would therefore
    make that call right-truncate, cutting the end of the tail kept here.

    Args:
        notes: Iterable of note strings.
        labels: Iterable of labels aligned with ``notes``.
        tokenizer: Hugging Face tokenizer exposing ``tokenize``,
            ``convert_tokens_to_string`` and ``num_special_tokens_to_add``.
        max_tokens: Total encoded length allowed, special tokens included.

    Returns:
        ``(notes, labels)`` as two lists, pairwise aligned. Like ``zip``, the
        output stops at the shorter of the two inputs.

    Raises:
        ValueError: if ``max_tokens`` leaves no room for content tokens.
    """
    content_budget = max_tokens - tokenizer.num_special_tokens_to_add(pair=False)
    if content_budget <= 0:
        # tokens[-0:] would silently keep the whole note.
        raise ValueError(
            f"max_tokens={max_tokens} leaves no room for content tokens"
        )

    kept_notes = []
    kept_labels = []
    for note, label in zip(notes, labels):
        tokens = tokenizer.tokenize(note)
        if len(tokens) > content_budget:
            # Keep only the end of the note, converted back to a string
            # because encoding happens later in one batched tokenizer call.
            note = tokenizer.convert_tokens_to_string(tokens[-content_budget:])
        kept_notes.append(note)
        kept_labels.append(label)
    return kept_notes, kept_labels

In [None]:
# Partition rows by the precomputed `split` column.
train_data = data.loc[data["split"].eq("train")]
valid_data = data.loc[data["split"].eq("validation")]
test_data = data.loc[data["split"].eq("test")]

In [None]:
def _notes_and_labels(split_frame):
    """Return the (tail-truncated if too long) notes and their labels for one split."""
    return filter_exceeding_texts(
        split_frame["text_no_ps"].tolist(),
        split_frame["high_ps"].tolist(),
        tokenizer,
    )


train_notes, train_labels = _notes_and_labels(train_data)
val_notes, val_labels = _notes_and_labels(valid_data)
test_notes, test_labels = _notes_and_labels(test_data)

In [None]:
# Inverse-frequency weights: n_samples / (n_classes * n_samples_in_class).
# Requires train_labels to still be a Python list (list.count).
# NOTE(review): these weights feed `criterion`, but the training loop uses
# the unweighted criterion.
class_counts = [train_labels.count(label_value) for label_value in (0, 1)]
total_samples = len(train_labels)

class_weights = torch.tensor(
    [total_samples / (2.0 * n_in_class) for n_in_class in class_counts]
).to(device)

In [None]:
def warn_if_truncated(texts, max_length):
    """Print a warning for each text whose encoding would exceed ``max_length``.

    Uses the module-level ``tokenizer``. The length counted includes the
    special tokens the tokenizer adds around a single sequence. The
    ``max_length`` of the encoding call limits the full encoded sequence, so
    texts within 1-2 tokens of the limit are also reported.

    Args:
        texts: Iterable of strings to check.
        max_length: Maximum encoded length, special tokens included.
    """
    n_special = tokenizer.num_special_tokens_to_add(pair=False)
    for text in texts:
        # Tokenize once per text (previously done twice when warning).
        n_tokens = len(tokenizer.tokenize(text)) + n_special
        if n_tokens > max_length:
            print(f"Warning: Text with length {n_tokens} is truncated to {max_length} tokens.")

In [None]:
def encode_data(texts, labels, max_length=MAX_TOKENS):
    """Tokenize ``texts`` into padded PyTorch tensors and pair them with labels.

    Warns about texts that will be truncated, then encodes all of them in one
    tokenizer call. ``padding=True`` pads to the longest text in ``texts``.

    Returns:
        ``(input_ids, attention_mask, labels_tensor)``.
    """
    warn_if_truncated(texts, max_length)
    encoding = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return encoding['input_ids'], encoding['attention_mask'], torch.tensor(labels)

In [None]:
# Encode each split. NOTE: this rebinds *_labels from Python lists to label
# tensors, so the class-weight cell (which calls list.count) must run first.
train_input_ids, train_attention_masks, train_labels = encode_data(texts=train_notes, labels=train_labels)
val_input_ids, val_attention_masks, val_labels = encode_data(texts=val_notes, labels=val_labels)
test_input_ids, test_attention_masks, test_labels = encode_data(texts=test_notes, labels=test_labels)

In [None]:
# One (input_ids, attention_mask, label) triple per example, for each split.
train_dataset, val_dataset, test_dataset = (
    TensorDataset(ids, masks, targets)
    for ids, masks, targets in (
        (train_input_ids, train_attention_masks, train_labels),
        (val_input_ids, val_attention_masks, val_labels),
        (test_input_ids, test_attention_masks, test_labels),
    )
)

In [None]:
# Two-label classification head on the pretrained Longformer-large encoder,
# wrapped in DataParallel so batches are split across the visible GPUs.
model = torch.nn.DataParallel(
    LongformerForSequenceClassification.from_pretrained(
        "allenai/longformer-large-4096", num_labels=2
    )
)
model.to(device)

In [None]:
# Training parameters
epochs = 100
batch_size = 12
# Seed for the training-set shuffle order, so reruns visit batches identically.
SHUFFLE_SEED = 42

# Loaders are created before the scheduler: the schedule length is measured
# in optimizer steps (one per training batch), not in examples.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6
# matches the transformers implementation's default.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01, eps=1e-6)
# Linear decay to zero over the whole run: len(train_loader) batches per epoch.
# (Using len(train_dataset) here made the decay ~batch_size times too slow.)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=len(train_loader) * epochs
)

# Class-weighted loss (kept available; the training loop uses the unweighted one).
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
# Define the unweighted criterion for the training loop
unweighted_criterion = torch.nn.CrossEntropyLoss()

# Early stopping state
early_stop_counter = 0
best_val_f1 = -1.0
best_val_loss = float("inf")
EARLY_STOP_LIMIT = 5

## Training and validation

In [None]:
# Training loop. Each epoch does the following:
# - runs one optimisation pass over train_loader and one evaluation pass over
#   val_loader;
# - saves the weights whenever validation F1 improves;
# - stops after EARLY_STOP_LIMIT consecutive epochs without improvement.
#
# NOTE(review): `train_labels` / `val_labels` are reused below as per-epoch
# Python lists, shadowing the label tensors built during encoding. Later cells
# read `val_labels` and `val_logits_list` as left by the final epoch.
for epoch in range(epochs):
    model.train()

    train_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training", position=0, leave=True)
    train_loss = 0
    train_preds = []
    train_labels = []

    for batch in train_progress:
        inputs, masks, labels = batch[0].to(device), batch[1].to(device), batch[2].to(device)
        logits = model(inputs, attention_mask=masks).logits
        # Unweighted loss; the class-weighted `criterion` is not used here.
        loss = unweighted_criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        preds = torch.argmax(logits, dim=1)
        train_preds.extend(preds.tolist())
        train_labels.extend(labels.tolist())

        # Batch accuracy is for the progress bar only; the epoch accuracy
        # below is computed over all examples.
        acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())
        train_progress.set_description(f"Epoch {epoch+1}/{epochs} - Training Loss: {loss.item():.4f} Acc: {acc:.4f}")

    train_f1 = f1_score(train_labels, train_preds)
    # Example-level accuracy. A mean of per-batch accuracies would give the
    # final, possibly smaller, batch as much weight as a full one.
    train_accuracy = accuracy_score(train_labels, train_preds)
    print(f"Epoch: {epoch+1}, Average Training Loss: {train_loss/len(train_loader)}, Training F1 Score: {train_f1:.4f}, Training Acc: {train_accuracy:.4f}")

    # Validation
    model.eval()
    val_loss = 0
    val_progress = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Validation", position=0, leave=True)
    val_logits_list = []  # one [logit_0, logit_1] row per example, in val_loader order
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for batch in val_progress:
            inputs, masks, labels = batch[0].to(device), batch[1].to(device), batch[2].to(device)
            logits = model(inputs, attention_mask=masks).logits
            loss = unweighted_criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
            val_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            val_preds.extend(preds.tolist())
            val_labels.extend(labels.tolist())
            val_logits_list.extend(logits.tolist())

            val_progress.set_description(f"Epoch {epoch+1}/{epochs} - Validation Loss: {loss.item():.4f}")

    val_f1 = f1_score(val_labels, val_preds)
    val_accuracy = accuracy_score(val_labels, val_preds)
    val_loss_avg = val_loss/len(val_loader)
    print(f"Epoch: {epoch+1}, Average Validation Loss: {val_loss_avg}, Validation F1 Score: {val_f1:.4f}, Validation Acc: {val_accuracy:.4f}")

    # Early stopping on validation F1: checkpoint on improvement, otherwise
    # count a non-improving epoch.
    if val_f1 > best_val_f1:
        print(f"Saving best model associated with the current highest F1 score {val_f1:.4f}")
        best_val_f1 = val_f1
        # Save the unwrapped module's weights so they load without DataParallel.
        torch.save(model.module.state_dict(), "./best_Longformer_model/pytorch_model.bin")
        tokenizer.save_pretrained("./best_Longformer_model")
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    if early_stop_counter >= EARLY_STOP_LIMIT:
        print("Early stopping triggered!")
        break

## Validation and test set evaluation

In [None]:
# Evaluate the held-out test set with the checkpoint that had the best
# validation F1. After the training loop `model` holds the final epoch's
# weights, which may be several epochs past the best one, so reload the
# saved state dict (saved from model.module) first.
best_checkpoint_path = "./best_Longformer_model/pytorch_model.bin"
if os.path.exists(best_checkpoint_path):
    model.module.load_state_dict(torch.load(best_checkpoint_path, map_location=device))
else:
    print(f"Warning: {best_checkpoint_path} not found; evaluating the in-memory weights.")

model.eval()
test_loss = 0
test_progress = tqdm(test_loader, desc="Testing", position=0, leave=True)

test_logits_list = []  # one [logit_0, logit_1] row per example, in test_loader order
test_preds = []
test_labels = []  # NOTE: rebinds the encoded label tensor to a per-example list

with torch.no_grad():
    for batch in test_progress:
        inputs, masks, labels = batch[0].to(device), batch[1].to(device), batch[2].to(device)
        logits = model(inputs, attention_mask=masks).logits
        loss = unweighted_criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        test_loss += loss.item()
        preds = torch.argmax(logits, dim=1)
        test_preds.extend(preds.tolist())
        test_labels.extend(labels.tolist())
        test_logits_list.extend(logits.tolist())

        test_progress.set_description(f"Test Loss: {loss.item():.4f}")

test_f1 = f1_score(test_labels, test_preds)
# Accuracy over all test examples (not a mean of per-batch accuracies).
test_accuracy = accuracy_score(test_labels, test_preds)
print(f"Average Test Loss: {test_loss/len(test_loader)}, Test F1 Score: {test_f1:.4f}, Test Acc: {test_accuracy:.4f}")

In [None]:
# helper functions to plot model metrics
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print a confusion matrix and draw it as an annotated heatmap.

    Draws on the current matplotlib figure and calls ``plt.show()``.

    Args:
        cm: 2-D array. Rows are true labels (y axis), columns are predicted
            labels (x axis), e.g. the output of sklearn's ``confusion_matrix``.
        classes: Tick labels, one per row/column.
        normalize: If True, divide each row by its total so rows sum to 1.
            Rows with no samples are shown as zeros instead of NaN.
        title: Plot title.
        cmap: Matplotlib colormap for the heatmap.
    """
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Annotate each cell; switch to white text on the darker half of the scale.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.grid(False)
    plt.show()

In [None]:
# evaluate model auc
def eval_model(predicted, actual):
    """Print and plot binary-classification metrics for a vector of scores.

    The function does the following, in order:
    - reports ROC AUC and average precision;
    - plots the ROC and precision-recall curves;
    - shows confusion matrices and classification reports at two cut-offs:
      the threshold maximising F1 on this data, and a fixed 0.0;
    - finishes with a threshold-vs-precision plot and a histogram of the
      scores.

    Args:
        predicted: 1-D numpy array of positive-class scores (higher = more
            likely positive). It is compared elementwise with ``>=``, so it
            must be an array, not a list.
        actual: 1-D sequence of 0/1 labels aligned with ``predicted``. Both
            classes must be present.
    """
    print("AUC " + str(roc_auc_score(actual, predicted)))

    # calculate the fpr and tpr for all thresholds of the classification
    fpr, tpr, _ = roc_curve(actual, predicted)
    roc_auc = auc(fpr, tpr)

    average_precision = average_precision_score(actual, predicted)
    print('Average precision score: {0:0.2f}'.format(
        average_precision))

    # ROC curve against the chance diagonal
    plt.title('Receiver Operating Characteristic: ' )
    plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

    precision, recall, thresholds = precision_recall_curve(actual, predicted)

    # Positive-class prevalence, drawn as the baseline of the PR plot.
    outcome_counts = np.unique(actual, return_counts=True)[1]
    prob_outcome = outcome_counts[1] / (outcome_counts[0] + outcome_counts[1])
    print('Outcome probability:')
    print(prob_outcome)

    plt.plot(recall, precision, color='b')
    plt.plot([0,1],[prob_outcome,prob_outcome], 'r--')
    plt.step(recall, precision, color='b', alpha=0.2,
             where='post')
    plt.fill_between(recall, precision, alpha=0.2, color='b')

    plt.xlabel('Recall (Sensitivity)')
    plt.ylabel('Precision (PPV)')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format(
            average_precision))
    plt.show()

    # precision_recall_curve returns one more precision/recall point than
    # thresholds (the final precision=1, recall=0 point has no threshold).
    # Drop that point so F1[k] corresponds to thresholds[k].
    precision_at_t = precision[:-1]
    recall_at_t = recall[:-1]
    # F1 is 0 where precision + recall == 0. Left as NaN, such points would be
    # selected by np.argmax, which returns the index of the first NaN.
    pr_sum = precision_at_t + recall_at_t
    F1 = np.divide(2 * precision_at_t * recall_at_t, pr_sum,
                   out=np.zeros_like(pr_sum, dtype=float),
                   where=pr_sum > 0)
    print("Best F1 ")
    print(F1.max())

    # threshold for best F1
    bestF1_thresh = thresholds[np.argmax(F1)]
    print("Threshold for best F1:")
    print(bestF1_thresh)
    pred_outcome_best_f1_thresh = np.where(predicted >= bestF1_thresh,1,0)
    print(np.unique(pred_outcome_best_f1_thresh, return_counts=True))
    # Fixed cut-off at 0.0; NOTE(review): meaningful for logit-like scores
    # centred at 0, not for probabilities.
    pred_outcome_00_thresh = np.where(predicted >= 0.0,1,0)

    print("Confusion matrix at best F1 thresh:")
    cnf_matrix = confusion_matrix(actual, pred_outcome_best_f1_thresh)
    np.set_printoptions(precision=2)
    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=['No','Yes'],
                        title='Confusion matrix, without normalization')
    print("Metrics at best F1 thresh (specificity is recall for negative class):")
    print(classification_report(actual, pred_outcome_best_f1_thresh, target_names=['No','Yes']))

    print("Confusion matrix at 0.0 thresh:")
    cnf_matrix = confusion_matrix(actual, pred_outcome_00_thresh)
    np.set_printoptions(precision=2)
    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=['No','Yes'],
                        title='Confusion matrix, without normalization')
    print("Metrics at 0.0 thresh thresh (specificity is recall for negative class):")
    print(classification_report(actual, pred_outcome_00_thresh, target_names=['No','Yes']))

    # plot threshold vs ppv curve
    plt.plot(thresholds, precision_at_t, color='b')

    plt.xlabel('Threshold probability')
    plt.ylabel('Precision (PPV)')
    plt.ylim([0.0, 1.0])
    plt.xlim([0.0, 1.0])
    plt.title('Threshold vs precision')
    plt.show()

    # histogram
    plt.hist(predicted)
    plt.title("Histogram")
    plt.xlabel("Predicted probability" )
    plt.ylabel("Frequency")
    plt.show()

In [None]:
# Validation report. `val_logits_list` is a Python list of [logit_0, logit_1]
# rows, so convert it to an array before column slicing (list[:, 1] raises).
# The score is the logit margin logit_1 - logit_0. It ranks examples exactly
# like the softmax probability of class 1 (same AUC/AP), which the raw
# logit_1 column does not. Margin >= 0 also matches the argmax predictions
# behind the training-time F1 (up to exact ties), which is what eval_model's
# fixed 0.0 threshold then evaluates.
# NOTE(review): these logits come from the last validation epoch run in training.
val_logits_array = np.asarray(val_logits_list)
eval_model(val_logits_array[:, 1] - val_logits_array[:, 0], val_labels)

In [None]:
# Test report. `test_logits_list` is a Python list of [logit_0, logit_1] rows
# from the test-evaluation cell, so convert it to an array before column
# slicing. The logit margin (logit_1 - logit_0) orders examples like
# P(class 1), and margin >= 0 matches the argmax predictions (up to exact ties).
test_logits_array = np.asarray(test_logits_list)
eval_model(test_logits_array[:, 1] - test_logits_array[:, 0], test_labels)