In [None]:
import os

import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import LongformerTokenizer, LongformerForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
from sklearn.metrics import f1_score

from sklearn.utils import resample
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from sklearn.metrics import roc_curve
from sklearn.metrics import average_precision_score

import itertools

In [None]:
# Restrict this process to GPUs 0-3 through the CUDA_VISIBLE_DEVICES variable.
# NOTE(review): this runs after `import torch`; it presumably still takes effect
# because CUDA has not been initialised yet at this point — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the notes table from CSV. Later cells filter it on the "split" column
# and read the note text from the "text_no_ps" column.
data = pd.read_csv("all_clinical_notes (Valid PS).csv")
# Bare last expression: renders the DataFrame as the cell output.
data

In [None]:
# Load the tokenizer stored in the same directory as the fine-tuned checkpoint
tokenizer = LongformerTokenizer.from_pretrained("./best_Longformer_model")

# Initialize the model architecture with a 2-label classification head
model = LongformerForSequenceClassification.from_pretrained("allenai/longformer-large-4096", num_labels=2)

# Load the saved weights into the model.
# map_location="cpu" makes the checkpoint loadable on a CPU-only machine even
# if it was saved from CUDA tensors; the model is moved to `device` just below.
model.load_state_dict(torch.load("./best_Longformer_model/pytorch_model.bin", map_location="cpu"))

# If using GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Last expression of the cell: moves the model and displays its repr.
model.to(device)

In [None]:
MAX_TOKENS = 4096


def filter_exceeding_texts(notes, tokenizer, max_tokens=None):
    """Shorten over-long notes to their tail so they fit the encoder's length limit.

    Each note is tokenized with ``tokenizer``. A note with more tokens than
    the available budget is replaced by the string rebuilt from its *last*
    tokens; every other note is returned unchanged.

    The budget is ``max_tokens`` minus the number of special tokens the
    tokenizer adds when encoding a single sequence. Without that reservation
    a note cut to exactly ``max_tokens`` tokens would be truncated again (from
    the right) at encoding time, losing the very end of the note that this
    function is trying to keep.

    Args:
        notes: Iterable of note strings.
        tokenizer: Object providing ``tokenize`` and
            ``convert_tokens_to_string``. If it also provides
            ``num_special_tokens_to_add``, that many positions are reserved.
        max_tokens: Total sequence budget, special tokens included. ``None``
            (the default) means the module-level ``MAX_TOKENS``, read at call
            time.

    Returns:
        list[str]: One entry per input note, in the same order.
    """
    if max_tokens is None:
        max_tokens = MAX_TOKENS

    # Reserve room for the special tokens added at encoding time.
    count_special = getattr(tokenizer, "num_special_tokens_to_add", None)
    reserved = count_special() if callable(count_special) else 0
    budget = max(max_tokens - reserved, 0)

    filtered_notes = []
    for note in notes:
        tokens = tokenizer.tokenize(note)
        if len(tokens) > budget:
            # Keep only the tail. `tokens[-0:]` would return the whole list,
            # so a zero budget is handled explicitly.
            tail = tokens[-budget:] if budget else []
            # NOTE(review): re-tokenizing the rebuilt string is not guaranteed
            # to yield exactly these tokens again — confirm for this tokenizer.
            filtered_notes.append(tokenizer.convert_tokens_to_string(tail))
        else:
            filtered_notes.append(note)

    return filtered_notes

In [None]:
def warn_if_truncated(texts, max_length):
    """Print a warning for each text whose encoded length exceeds ``max_length``.

    The length that is checked and reported is the number of tokens produced
    by the module-level ``tokenizer`` plus the special tokens it adds to a
    single sequence, because the encoder's length limit counts those too.

    Args:
        texts: Iterable of strings to check.
        max_length: Maximum encoded sequence length, special tokens included.

    Returns:
        None. Warnings are written with ``print``.
    """
    # Positions taken by special tokens; constant for a given tokenizer.
    reserved = tokenizer.num_special_tokens_to_add()
    for text in texts:
        # Tokenize once and reuse the count for both the check and the message.
        total_length = len(tokenizer.tokenize(text)) + reserved
        if total_length > max_length:
            print(f"Warning: Text with length {total_length} is truncated to {max_length} tokens.")

In [None]:
def encode_data(texts, max_length=MAX_TOKENS):
    """Tokenize ``texts`` into padded PyTorch tensors.

    Calls ``warn_if_truncated`` first, then encodes all texts in one call to
    the module-level ``tokenizer`` with truncation and padding enabled.

    Args:
        texts: List of strings to encode.
        max_length: Length limit passed to the tokenizer.

    Returns:
        tuple: ``(input_ids, attention_mask)`` taken from the tokenizer output.
    """
    warn_if_truncated(texts, max_length)
    batch = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return batch['input_ids'], batch['attention_mask']

In [None]:
def softmax(logits):
    """Row-wise softmax of a 2-D array of logits.

    Each row has its maximum subtracted before exponentiation, then the
    exponentials are divided by their row sum.

    Args:
        logits: 2-D array-like; reductions run along axis 1.

    Returns:
        numpy.ndarray with the same shape as ``logits``.
    """
    row_max = np.max(logits, axis=1, keepdims=True)
    unnormalized = np.exp(logits - row_max)
    row_totals = unnormalized.sum(axis=1, keepdims=True)
    return unnormalized / row_totals

In [None]:
def generate_no_PS_CSV(section, batch_size=12):
    """Run the classifier over one split and return its rows with model outputs.

    Selects the rows of the module-level ``data`` frame whose ``split`` column
    equals ``section``, takes their ``text_no_ps`` text, shortens and encodes
    it, runs the module-level ``model`` in evaluation mode without gradients,
    and attaches the predictions to the selected rows.

    Args:
        section: Value of the ``split`` column to select (the notebook uses
            "train", "validation" and "test").
        batch_size: Number of notes per forward pass. Defaults to 12.

    Returns:
        pandas.DataFrame: An independent copy of the selected rows with the
        added columns "Prediction", "Logits (Class 0)", "Logits (Class 1)",
        "Probability (Class 0)" and "Probability (Class 1)".

    Raises:
        ValueError: If no row of ``data`` has ``split == section``.
    """
    # Copy the selection: the result columns are assigned onto it below, and
    # assigning onto a bare boolean-mask slice triggers SettingWithCopyWarning
    # and leaves it ambiguous whether `data` itself is affected.
    sub_data = data[data["split"] == section].copy()
    if sub_data.empty:
        # Fail early with a clear message instead of an indexing error later.
        raise ValueError(f"No rows in `data` with split == {section!r}")

    # Anything that is not a string (presumably NaN for missing text — confirm)
    # is treated as an empty note.
    sub_notes = [
        note if isinstance(note, str) else ""
        for note in sub_data["text_no_ps"].tolist()
    ]
    sub_notes = filter_exceeding_texts(sub_notes, tokenizer)

    sub_input_ids, sub_attention_masks = encode_data(sub_notes)
    sub_dataset = TensorDataset(sub_input_ids, sub_attention_masks)

    # shuffle=False keeps the outputs in the same order as the rows of sub_data.
    sub_loader = DataLoader(sub_dataset, batch_size=batch_size, shuffle=False)

    model.eval()

    # Progress bar over the batches, labelled with the split name.
    sub_progress = tqdm(sub_loader, desc=section, position=0, leave=True)

    sub_logits_list = []  # one [logit_0, logit_1] entry per note
    sub_preds = []  # argmax class index per note

    with torch.no_grad():
        for batch in sub_progress:
            inputs, masks = batch[0].to(device), batch[1].to(device)
            logits = model(inputs, attention_mask=masks).logits
            sub_preds.extend(torch.argmax(logits, dim=1).tolist())
            sub_logits_list.extend(logits.tolist())

    sub_logits_list = np.array(sub_logits_list)
    probability = softmax(sub_logits_list)

    sub_data["Prediction"] = sub_preds
    sub_data["Logits (Class 0)"] = sub_logits_list[:, 0]
    sub_data["Logits (Class 1)"] = sub_logits_list[:, 1]
    sub_data["Probability (Class 0)"] = probability[:, 0]
    sub_data["Probability (Class 1)"] = probability[:, 1]
    return sub_data

In [None]:
# Run inference on the rows whose split is "train".
train_result = generate_no_PS_CSV("train")
# Bare last expression: renders the resulting DataFrame as the cell output.
train_result

In [None]:
# Write the train-split results to CSV, leaving out the DataFrame index.
train_result.to_csv("LongFormer train result (Valid PS - PS Removed Text).csv", index=False)

In [None]:
# Run inference on the rows whose split is "validation".
val_result = generate_no_PS_CSV("validation")
# Bare last expression: renders the resulting DataFrame as the cell output.
val_result

In [None]:
# Write the validation-split results to CSV, leaving out the DataFrame index.
val_result.to_csv("LongFormer validation result (Valid PS - PS Removed Text).csv", index=False)

In [None]:
# Run inference on the rows whose split is "test".
test_result = generate_no_PS_CSV("test")
# Bare last expression: renders the resulting DataFrame as the cell output.
test_result

In [None]:
# Write the test-split results to CSV, leaving out the DataFrame index.
test_result.to_csv("LongFormer test result (Valid PS - PS Removed Text).csv", index=False)