In [None]:
import torch
from transformers import (
    BertForSequenceClassification,
    AutoTokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import pandas as pd
import numpy as np


from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score, precision_score, recall_score
import matplotlib.pyplot as plt


In [None]:
# Tokenizer for the ClinicalBERT checkpoint. It IS used further down: the
# encoding cell calls tokenizer.encode on every note.
TOKENIZER_CHECKPOINT = "medicalai/ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

# Test-set CSV; later cells read its "Text" and "Smoking Status" columns.
TEST_CSV_PATH = "/workspaces/NLP_Proj2/01_intermediate-files/synthetic_test_data.csv"
test_data = pd.read_csv(TEST_CSV_PATH)

In [None]:
import re

# Window pattern: up to five whitespace-separated tokens, then a token that
# contains a smoking-related stem, then up to five more tokens. Compiled once
# here instead of being rebuilt for every message inside the loop.
SMOKING_CONTEXT_RE = re.compile(
    r"(\S+\s){0,5}\S*(smok|tobacco|cigar|pack|ppd)\S*(\s\S+){0,5}",
    re.IGNORECASE,
)


def extract_smoking_context(message):
    """Return the text window around the first smoking keyword in ``message``.

    Args:
        message: the note text (a string).

    Returns:
        The substring matched by ``SMOKING_CONTEXT_RE`` when a keyword is
        present; otherwise ``message`` unchanged.
    """
    match = SMOKING_CONTEXT_RE.search(message)
    return match.group(0) if match else message


# Preprocess the test data (similar to training data prep).
# Missing notes are turned into empty strings first: a NaN in "Text" would
# otherwise reach the regex search and raise a TypeError.
test_data["Text"] = test_data["Text"].fillna("").astype(str).str.lower()
clean_test_messages = test_data["Text"]
clean_test_text = [extract_smoking_context(message) for message in clean_test_messages]

test_data["Text"] = clean_test_text

In [None]:
# Preview the first rows of the "Text" column as it stands after preprocessing
print(test_data['Text'].head())
# (disabled) optional export of the column for manual inspection:
#test_data['Text'].to_csv('text_data.csv', index=False)

In [None]:
# Fold the "SMOKER" label into "PAST SMOKER" (only exact, whole-value matches are replaced)
status_remap = {"SMOKER": "PAST SMOKER"}
test_data["Smoking Status"] = test_data["Smoking Status"].replace(status_remap)

In [None]:
# Turn the string smoking-status labels into integer class ids.
# NOTE(review): the encoder is fitted on the test labels themselves; presumably
# the ids must line up with the ones used when the model was trained — confirm
# that the test set contains the same set of classes.
le = LabelEncoder()
smoking_status = test_data["Smoking Status"]
test_data["Smoking_enc"] = le.fit_transform(smoking_status)

# Spot-check six random rows
display(test_data.sample(6))

In [None]:
# Tokenize and encode sentences in the test dataset
def get_sentence_embedding(sentences, encoder=None, max_length=512):
    """Encode each sentence with the tokenizer's ``encode`` method.

    Despite the name, no embeddings are computed here: the function returns
    one ``encode`` result per sentence (token-id lists for a Hugging Face
    tokenizer), with special tokens added and truncation enabled.

    Args:
        sentences: iterable of strings to encode.
        encoder: object exposing
            ``encode(text, add_special_tokens=..., truncation=..., max_length=...)``.
            Defaults to the notebook-level ``tokenizer``.
        max_length: truncation limit forwarded to ``encode`` (default 512,
            the value previously hard-coded).

    Returns:
        A list with one encoded sequence per input sentence, in input order.
    """
    if encoder is None:
        # Fall back to the tokenizer loaded at the top of the notebook
        encoder = tokenizer
    return [
        encoder.encode(
            sentence, add_special_tokens=True, truncation=True, max_length=max_length
        )
        for sentence in sentences
    ]

# Encode every preprocessed note into its token-id sequence
sentences_test = test_data["Text"]
indexed_tokens_test = get_sentence_embedding(sentences_test)

# Sanity check: show the encoding of the first note
first_encoded = indexed_tokens_test[0]
print(first_encoded)

In [None]:
# Right-pad every encoded note with 0 up to the length of the longest one.
# (The attention-mask cell treats id 0 as padding, so the fill value must stay 0.)
max_length_test = max(map(len, indexed_tokens_test))
padded_tokens_test = []
for tokens in indexed_tokens_test:
    n_missing = max_length_test - len(tokens)
    padded_tokens_test.append(tokens + [0] * n_missing)

In [None]:
# Convert the padded (equal-length) token-id lists into a single tensor
input_ids_test = torch.tensor(padded_tokens_test)
# (disabled) debug print of the first row:
#print(input_ids_test[0])

In [None]:
# Attention mask: 1 where the token id is non-zero, 0 where it is the 0 padding value
mask_rows = []
for tokens in padded_tokens_test:
    mask_rows.append([0 if token == 0 else 1 for token in tokens])
attention_masks_test = torch.tensor(mask_rows)

In [None]:
# Encoded class ids as a tensor
encoded_labels = test_data["Smoking_enc"].values
labels_test = torch.tensor(encoded_labels)

# Check input shapes (one line per tensor)
shape_report = [
    f"Input IDs shape: {input_ids_test.shape}",
    f"Attention Masks shape: {attention_masks_test.shape}",
    f"Labels shape: {labels_test.shape}",
]
print("\n".join(shape_report))


In [None]:
# Bundle ids, masks and labels so each dataset item is an (ids, mask, label) triple
test_dataset = TensorDataset(input_ids_test, attention_masks_test, labels_test)

# Batches of 16; no sampler or shuffle flag is passed, so DataLoader defaults apply
EVAL_BATCH_SIZE = 16
test_dataloader = DataLoader(dataset=test_dataset, batch_size=EVAL_BATCH_SIZE)

In [None]:
# Load trained model
model_path = "/workspaces/NLP_Proj2/model_test_syn_23.pth"

# Resolve the device first so the checkpoint can be mapped onto it while loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): confirm the architecture and num_labels match the training run —
# load_state_dict is strict by default and raises on any key/shape mismatch.
model = BertForSequenceClassification.from_pretrained("medicalai/ClinicalBERT", num_labels=4)

# map_location makes a checkpoint that was saved from a GPU run loadable on a
# CPU-only machine (without it torch.load tries to restore the CUDA tensors).
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

model.to(device)
# Inference mode for dropout etc.; as the last expression this also displays the model
model.eval()

In [None]:
# Evaluate the model on the test set.
#
# Two kinds of metrics are produced:
#   * per-batch values (batch_accuracy, ...) — used only for the diagnostic plots;
#   * whole-test-set values bound to `accuracy`, `precision`, `recall`, `f1` after
#     the loop. Previously these names were overwritten inside the loop and ended
#     up holding the LAST batch's values, which were then reported as the result.
batch_accuracy = []
batch_precision = []
batch_recall = []
batch_f1 = []

# Predictions / labels accumulated over all batches for the overall metrics
all_labels = []
all_preds = []

for batch in test_dataloader:
    input_ids_batch, attention_masks_batch, labels_batch = batch
    input_ids_batch = input_ids_batch.to(device)
    attention_masks_batch = attention_masks_batch.to(device)

    with torch.no_grad():
        # Labels are not passed: only the logits are needed, so the loss is not computed
        outputs = model(input_ids_batch, attention_mask=attention_masks_batch)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1).flatten()

    # Convert once to CPU numpy arrays for scikit-learn
    true_np = labels_batch.cpu().numpy()
    pred_np = preds.cpu().numpy()
    all_labels.extend(true_np.tolist())
    all_preds.extend(pred_np.tolist())

    # Per-batch metrics (for the plots only)
    batch_accuracy.append(accuracy_score(true_np, pred_np))
    batch_precision.append(
        precision_score(true_np, pred_np, average='weighted', zero_division=0)
    )
    batch_recall.append(
        recall_score(true_np, pred_np, average='weighted', zero_division=0)
    )
    batch_f1.append(f1_score(true_np, pred_np, average='weighted', zero_division=0))

if not all_labels:
    raise ValueError("test_dataloader produced no samples; nothing to evaluate")

# Overall metrics computed on every test prediction at once
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
print(f"Accuracy: {accuracy:.4f}")

# Plotting: one panel per batch-wise metric
batches = range(1, len(batch_accuracy) + 1)
panels = [
    (batch_accuracy, 'b-', 'Batch-wise Test Accuracy', 'Accuracy'),
    (batch_precision, 'g-', 'Batch-wise Test Precision', 'Precision'),
    (batch_recall, 'r-', 'Batch-wise Test Recall', 'Recall'),
    (batch_f1, 'y-', 'Batch-wise Test F1 Score', 'F1 Score'),
]
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
for ax, (values, line_style, title, y_label) in zip(axes.flat, panels):
    ax.plot(batches, values, line_style)
    ax.set_title(title)
    ax.set_xlabel('Batch')
    ax.set_ylabel(y_label)

fig.tight_layout()
plt.show()

In [None]:
# Report the precision / recall / F1 values left in scope by the evaluation cell
metric_lines = [
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1:.4f}",
]
print("\n".join(metric_lines))
