#Final Project



## Importing Libraries


In [1]:
!pip install transformers torch pandas wandb -q

from google.colab import drive
import pandas as pd
import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering, AdamW
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
from tqdm.auto import tqdm
import os

## Setting up the Home in Google Drive

In [2]:
drive.mount('/content/drive')

SAVE_DIR = '/content/drive/MyDrive/clinical_qa_models'
os.makedirs(SAVE_DIR, exist_ok=True)

Mounted at /content/drive


From here, skip over to the `Load Pretrained Model` codeblock and run from there if you have the `.pt` file. [~1.2GB]

Save it in the above created folder

## Data Preprocessing

In [None]:
# Load that parquet file
df = pd.read_parquet('validation-00000-of-00001.parquet')

# Get our tokenizer
tokenizer = BertTokenizerFast.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

def preprocess_dataset(df, tokenizer, max_len=512):
    processed_data = []

    for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing data 🎭"):
        context = row['context']
        question = row['question']
        answers = row['answers']

        answer_start = answers['answer_start'][0]
        answer_end = answers['answer_end'][0]

        encoding = tokenizer.encode_plus(
            question, context,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=True,
            return_tensors='pt'
        )

        start_token = None
        end_token = None

        for idx, (start, end) in enumerate(encoding['offset_mapping'][0]):
            if start == answer_start:
                start_token = idx
            if end == answer_end:
                end_token = idx

        if start_token is None or end_token is None:
            continue

        processed_data.append({
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'start_positions': torch.tensor(start_token),
            'end_positions': torch.tensor(end_token)
        })

    return processed_data


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

## Quick EDA step

In [None]:
def analyze_training_data(df):
    """
    Analyze training data characteristics
    """
    print("Dataset Analysis:")
    print("-" * 50)

    # Basic stats
    print(f"\nTotal QA pairs: {len(df)}")

    # Question type analysis
    def get_first_word(q):
        return q.lower().strip().split()[0]

    question_starts = pd.Series([get_first_word(q) for q in df['question']])
    print("\nQuestion Types Distribution:")
    print(question_starts.value_counts(normalize=True).mul(100).round(1).head())

    # Pattern analysis
    patterns = {
        'Medication': len(df[df['question'].str.lower().str.contains('medic|prescribed|taking')]),
        'Symptoms': len(df[df['question'].str.lower().str.contains('symptom|present|complain')]),
        'History': len(df[df['question'].str.lower().str.contains('history|previous|prior')])
    }

    print("\nQuestion Pattern Distribution:")
    for pattern, count in patterns.items():
        print(f"{pattern}: {count} ({count/len(df)*100:.1f}%)")

# Run EDA
analyze_training_data(df)

Dataset Analysis:
--------------------------------------------------

Total QA pairs: 32739

Question Types Distribution:
has     48.6
what    21.8
is       8.1
why      5.7
was      4.5
Name: proportion, dtype: float64

Question Pattern Distribution:
Medication: 8005 (24.5%)
Symptoms: 46 (0.1%)
History: 5095 (15.6%)


Analysis of our training data reveals significant imbalances that guided our hybrid approach.


The dataset is heavily skewed towards medication-related queries (24.5%) with very few symptom questions (0.1%). Additionally, almost half (48.6%) of all questions begin with "has", suggesting the model will be primarily trained on yes/no medication history questions.

This explains why we developed a hybrid approach: using regex pattern matching for symptoms (due to limited training examples) while leveraging BERT's strong performance on medication queries (where we had abundant training data).

## Create the dataloaders

Using 80/20 train/validation split

In [None]:
# Process and split our data
processed_dataset = preprocess_dataset(df, tokenizer)
print(f"Processed {len(processed_dataset)} entries! 🎉")

input_ids = torch.stack([item['input_ids'] for item in processed_dataset])
attention_masks = torch.stack([item['attention_mask'] for item in processed_dataset])
start_positions = torch.stack([item['start_positions'] for item in processed_dataset])
end_positions = torch.stack([item['end_positions'] for item in processed_dataset])

dataset = TensorDataset(input_ids, attention_masks, start_positions, end_positions)

# Split into train/val
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    batch_size=16,
    pin_memory=True,
    num_workers=2
)

val_loader = DataLoader(
    val_dataset,
    sampler=SequentialSampler(val_dataset),
    batch_size=16,
    pin_memory=True,
    num_workers=2
)


Processing data 🎭:   0%|          | 0/32739 [00:00<?, ?it/s]

Processed 25813 entries! 🎉


## Setup Model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForQuestionAnswering.from_pretrained('emilyalsentzer/Bio_ClinicalBERT').to(device)
optimizer = AdamW(model.parameters(), lr=3e-5)


pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Training Function

In [None]:
def train_model(epochs=30):
    best_loss = float('inf')

    for epoch in range(epochs):
        print(f"✨ Epoch {epoch+1}/{epochs} ✨")

        # Training phase
        model.train()
        train_loss = 0
        for batch in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()

            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)
            start_positions = batch[2].to(device)
            end_positions = batch[3].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                start_positions=start_positions,
                end_positions=end_positions
            )

            loss = outputs.loss
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating"):
                input_ids = batch[0].to(device)
                attention_mask = batch[1].to(device)
                start_positions = batch[2].to(device)
                end_positions = batch[3].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    start_positions=start_positions,
                    end_positions=end_positions
                )

                val_loss += outputs.loss.item()

        avg_val_loss = val_loss / len(val_loader)

        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f}")

        # Save if it's our best model yet!
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'loss': best_loss
            }, os.path.join(SAVE_DIR, 'best_model.pt'))
            print("💃 New best model! Saved to Drive!")

## Training go BRRRRR

In [None]:
torch.cuda.empty_cache()
print(f"🎮 Using device: {device}")

# The moment you've been waiting for
train_model()

🎮 Using device: cuda
✨ Epoch 1/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 1.3291
Val Loss: 0.5583
💃 New best model! Saved to Drive!
✨ Epoch 2/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Train Loss: 0.4902
Val Loss: 0.3336
💃 New best model! Saved to Drive!
✨ Epoch 3/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.2768
Val Loss: 0.2516
💃 New best model! Saved to Drive!
✨ Epoch 4/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Train Loss: 0.1975
Val Loss: 0.1958
💃 New best model! Saved to Drive!
✨ Epoch 5/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.1448
Val Loss: 0.1907
💃 New best model! Saved to Drive!
✨ Epoch 6/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.1198
Val Loss: 0.1574
💃 New best model! Saved to Drive!
✨ Epoch 7/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Train Loss: 0.0962
Val Loss: 0.1704
✨ Epoch 8/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0845
Val Loss: 0.1513
💃 New best model! Saved to Drive!
✨ Epoch 9/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0704
Val Loss: 0.1825
✨ Epoch 10/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Train Loss: 0.0702
Val Loss: 0.1336
💃 New best model! Saved to Drive!
✨ Epoch 11/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0570
Val Loss: 0.1511
✨ Epoch 12/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0504
Val Loss: 0.1432
✨ Epoch 13/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Train Loss: 0.0469
Val Loss: 0.1406
✨ Epoch 14/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Train Loss: 0.0473
Val Loss: 0.1303
💃 New best model! Saved to Drive!
✨ Epoch 15/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0375
Val Loss: 0.1427
✨ Epoch 16/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0421
Val Loss: 0.1229
💃 New best model! Saved to Drive!
✨ Epoch 17/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Train Loss: 0.0381
Val Loss: 0.1434
✨ Epoch 18/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0403
Val Loss: 0.1411
✨ Epoch 19/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0400
Val Loss: 0.1170
💃 New best model! Saved to Drive!
✨ Epoch 20/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Train Loss: 0.0306
Val Loss: 0.1300
✨ Epoch 21/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010><function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
        if w.is_alive():if w.is_alive():

  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only te

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0332
Val Loss: 0.1545
✨ Epoch 22/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0334
Val Loss: 0.1360
✨ Epoch 23/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0257
Val Loss: 0.1405
✨ Epoch 24/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0304
Val Loss: 0.1268
✨ Epoch 25/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Train Loss: 0.0306
Val Loss: 0.1491
✨ Epoch 26/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validating:   0%|          | 0/323 [00:10<?, ?it/s]

Train Loss: 0.0275
Val Loss: 0.1144
💃 New best model! Saved to Drive!
✨ Epoch 27/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0279
Val Loss: 0.1100
💃 New best model! Saved to Drive!
✨ Epoch 28/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>
<function _MultiProcessingDataLoaderIter.__del__ at 0x794dd2137010>Traceback (most recent call last):

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()    
self._shutdown_workers()  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers

    if w.is_alive():    if w.is_alive():

  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only te

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0269
Val Loss: 0.1232
✨ Epoch 29/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0270
Val Loss: 0.1249
✨ Epoch 30/30 ✨


Training:   0%|          | 0/1291 [00:00<?, ?it/s]

Validating:   0%|          | 0/323 [00:00<?, ?it/s]

Train Loss: 0.0284
Val Loss: 0.1353


Ignore those multithreading errors, not sure what happened there.

After everything, I got the best validation loss around 0.1-ish

## Load Pretrained Model

In [3]:
def load_pretrained_model(checkpoint_path):
    """
    Load pretrained model from checkpoint
    """
    # Initialize model and tokenizer
    tokenizer = BertTokenizerFast.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
    model = BertForQuestionAnswering.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

    # Load checkpoint
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model, tokenizer, device

# Use this instead of training
model, tokenizer, device = load_pretrained_model('/content/drive/MyDrive/clinical_qa_models/best_model.pt')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  checkpoint = torch.load(checkpoint_path, map_location=device)


## Now to test the given model

In [4]:
import re
def get_focused_medical_answer(question, context, model, tokenizer, device):
    """
    Last version with regex-first approach for medications too!
    """
    question_lower = question.lower()
    is_symptom_question = any(word in question_lower for word in ['symptom', 'present', 'complain', 'report', 'feeling'])

    if is_symptom_question:
        # Keep our winning symptom approach
        symptom_patterns = [
            r"presents? with (.*?)(?=\. |$)",
            r"presenting with (.*?)(?=\. |$)",
            r"complains? of (.*?)(?=\. |$)",
            r"experiencing (.*?)(?=\. |$)",
            r"reports? (.*?)(?=\. |$)",
            r"chief complaint:?\s*(.*?)(?=\. |$)"
        ]

        for pattern in symptom_patterns:
            match = re.search(pattern, context, re.IGNORECASE)
            if match:
                answer = match.group(1).strip()
                answer = re.sub(r'\s*x\s*\d+\s*days?', '', answer)
                answer = re.sub(r'\s*over.*?(?=\.|$)', '', answer)
                return answer, 75.0

        sentences = re.split('[.!?]', context)
        symptom_words = ['pain', 'fever', 'cough', 'dyspnea', 'fatigue', 'weakness',
                        'stiffness', 'swelling', 'palpitation', 'tremor', 'intolerance']

        for sentence in sentences:
            if any(word in sentence.lower() for word in symptom_words):
                answer = sentence.strip()
                answer = re.sub(r'^(patient|noted to have|observed to have|chief complaint:)\s*', '', answer, flags=re.IGNORECASE)
                answer = re.sub(r'\s*x\s*\d+\s*days?', '', answer)
                return answer, 50.0

        return "", 0.0

    else:
        # Try regex for medications first
        med_patterns = [
            r"medications? include:?\s*(.*?)(?=\.|$)",
            r"current(?:ly)? (?:on|taking):?\s*(.*?)(?=\.|$)",
            r"(?:started|continuing) on:?\s*(.*?)(?=\.|$)",
            r"home medications?:?\s*(.*?)(?=\.|$)"
        ]

        # Try direct regex match first
        for pattern in med_patterns:
            match = re.search(pattern, context, re.IGNORECASE)
            if match:
                meds = match.group(1).strip()
                if ',' in meds or ' and ' in meds:  # Looks like a proper list
                    return meds, 85.0

        # If no clear medication list found, look for medication changes
        change_patterns = [
            r"(?:started|begun) on\s*(.*?)(?=\.|$)",
            r"(?:continued|continuing) on\s*(.*?)(?=\.|$)",
            r"(?:changed|switched) to\s*(.*?)(?=\.|$)"
        ]

        med_sections = []
        for pattern in change_patterns:
            matches = re.finditer(pattern, context, re.IGNORECASE)
            for match in matches:
                if 'discontinue' not in match.group(1).lower():
                    med_sections.append(match.group(1).strip())

        if med_sections:
            return ', '.join(med_sections), 80.0

        # Fallback to BERT if regex didn't find good matches
        inputs = tokenizer.encode_plus(
            question, context,
            return_tensors='pt',
            max_length=512,
            truncation=True,
            padding='max_length'
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            start_scores = outputs.start_logits[0]
            end_scores = outputs.end_logits[0]
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

            start_idx = torch.argmax(start_scores)
            max_length = 100
            end_scores_slice = end_scores[start_idx:start_idx + max_length]
            end_idx = start_idx + torch.argmax(end_scores_slice)

            answer_tokens = [t for t in tokens[start_idx:end_idx + 1]
                           if t not in ['[CLS]', '[SEP]', '[PAD]']]
            answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))

            answer = re.sub(r'^(medications?:?\s*|prescribed\s*|taking\s*|include:?\s*)', '',
                          answer.lower())
            answer = re.sub(r'\s+(currently|also|additionally)\.?$', '', answer)
            answer = ' '.join(answer.split())

            confidence = float(torch.sigmoid(torch.max(start_scores))) * 100

            return answer, min(confidence, 100)

test_cases = [
    {
        'context': "Patient presents with severe migraine and photophobia since this morning. Currently taking Imitrex 50mg PRN for headaches.",
        'question': "What are the current symptoms?",
        'expected': "severe migraine and photophobia"  # What we WANT
    },
    {
        'context': "Current medications include Metformin 500mg twice daily. Patient reports persistent fatigue and weight loss.",
        'question': "What medications is the patient taking?",
        'expected': "Metformin 500mg twice daily"
    },
    {
        'context': "Patient is a 45-year-old female presenting with acute onset chest pain, radiating to left arm, and diaphoresis. Taking aspirin 81mg daily for prevention.",
        'question': "What are the presenting symptoms?",
        'expected': "acute onset chest pain, radiating to left arm, and diaphoresis"
    },
    {
        'context': "Managing hypertension with Lisinopril 20mg daily and Hydrochlorothiazide 25mg daily. Blood pressure remains elevated at 150/95.",
        'question': "What medications is the patient on?",
        'expected': "Lisinopril 20mg daily and Hydrochlorothiazide 25mg daily"
    },
    {
        'context': "Three-day history of fever (102.1°F), productive cough with green sputum, and dyspnea on exertion. No medications currently.",
        'question': "What symptoms does the patient have?",
        'expected': "fever (102.1°F), productive cough with green sputum, and dyspnea on exertion"
    },
    {
        'context': "Diabetes managed with insulin glargine 30 units at bedtime and metformin 1000mg BID. Recent A1c: 7.2%",
        'question': "List current medications",
        'expected': "insulin glargine 30 units at bedtime and metformin 1000mg BID"
    },
    {
        'context': "Patient complains of intermittent lower back pain, worse with movement, and morning stiffness lasting >1 hour. Using ibuprofen 600mg PRN.",
        'question': "What are the reported symptoms?",
        'expected': "intermittent lower back pain, worse with movement, and morning stiffness lasting >1 hour"
    },
    {
        'context': "Current psychiatric medications: Prozac 40mg daily, Wellbutrin SR 150mg BID, and Xanax 0.5mg PRN for anxiety attacks.",
        'question': "What medications is the patient prescribed?",
        'expected': "Prozac 40mg daily, Wellbutrin SR 150mg BID, and Xanax 0.5mg PRN"
    },
    {
        'context': "Patient presents with generalized weakness, shortness of breath on minimal exertion, and orthopnea x 3 days. Currently on Lasix 40mg daily, Coreg 25mg BID, and Lisinopril 20mg daily for CHF.",
        'question': "What are the current symptoms?",
        'expected': "generalized weakness, shortness of breath on minimal exertion, and orthopnea"
    },
    {
        'context': "Multiple medication changes at last visit. Discontinued Metformin, started on Jardiance 10mg daily and Ozempic 0.25mg weekly, continued on Lantus 30 units nightly.",
        'question': "What medications is the patient currently on?",
        'expected': "Jardiance 10mg daily and Ozempic 0.25mg weekly and Lantus 30 units nightly"
    },
    {
        'context': "Chief complaint: 2 weeks of worsening joint pain affecting bilateral knees, wrists, and MCP joints with associated morning stiffness lasting > 2 hours. Notable swelling and erythema of affected joints.",
        'question': "What symptoms does the patient report?",
        'expected': "worsening joint pain affecting bilateral knees, wrists, and MCP joints with associated morning stiffness lasting > 2 hours"
    },
    {
        'context': "Home medications include: Synthroid 125mcg daily, Vitamin D3 2000 units daily, Calcium carbonate 600mg BID, Iron sulfate 325mg TID, and prenatal vitamins.",
        'question': "List current medications",
        'expected': "Synthroid 125mcg daily, Vitamin D3 2000 units daily, Calcium carbonate 600mg BID, Iron sulfate 325mg TID, and prenatal vitamins"
    },
    {
        'context': "Patient experiencing persistent palpitations, tremors, heat intolerance, and unintentional 10lb weight loss over past month despite good appetite.",
        'question': "What are the presenting symptoms?",
        'expected': "persistent palpitations, tremors, heat intolerance, and unintentional 10lb weight loss"
    }
]

print("✨ Results! ✨")
for case in test_cases:
    print("\n" + "💫"*20)
    print(f"Context: {case['context']}")
    print(f"Q: {case['question']}")
    answer, conf = get_focused_medical_answer(case['question'], case['context'], model, tokenizer, device)
    print(f"A: {answer}")
    print(f"Expected: {case['expected']}")
    print(f"Confidence: {conf:.2f}%")

    # Calculate similarity
    from difflib import SequenceMatcher
    similarity = SequenceMatcher(None, answer.lower(), case['expected'].lower()).ratio()
    print(f"Similarity to expected: {similarity:.2%}")

✨ Results! ✨

💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫
Context: Patient presents with severe migraine and photophobia since this morning. Currently taking Imitrex 50mg PRN for headaches.
Q: What are the current symptoms?
A: severe migraine and photophobia since this morning
Expected: severe migraine and photophobia
Confidence: 75.00%
Similarity to expected: 76.54%

💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫
Context: Current medications include Metformin 500mg twice daily. Patient reports persistent fatigue and weight loss.
Q: What medications is the patient taking?
A: current medications include metformin 500mg twice daily.
Expected: Metformin 500mg twice daily
Confidence: 76.90%
Similarity to expected: 65.06%

💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫💫
Context: Patient is a 45-year-old female presenting with acute onset chest pain, radiating to left arm, and diaphoresis. Taking aspirin 81mg daily for prevention.
Q: What are the presenting symptoms?
A: acute onset chest pain, radiating to left arm, and diaphoresis
Expected: acute onset chest pain, 

In [5]:
def analyze_results(test_cases):
    results = {
        'symptoms': {'count': 0, 'perfect_matches': 0, 'avg_similarity': 0},
        'medications': {'count': 0, 'perfect_matches': 0, 'avg_similarity': 0}
    }

    for case in test_cases:
        answer, conf = get_focused_medical_answer(case['question'], case['context'], model, tokenizer, device)
        similarity = SequenceMatcher(None, answer.lower(), case['expected'].lower()).ratio()

        category = 'symptoms' if 'symptom' in case['question'].lower() else 'medications'
        results[category]['count'] += 1
        results[category]['avg_similarity'] += similarity
        if similarity > 0.95:  # Almost perfect match
            results[category]['perfect_matches'] += 1

    # Calculate averages
    for category in results:
        results[category]['avg_similarity'] = (results[category]['avg_similarity'] /
                                             results[category]['count'] * 100)

    return results

# Print summary
results = analyze_results(test_cases)
print("Performance Summary:")
print("-" * 50)
for category in results:
    print(f"\n{category.title()} Extraction:")
    print(f"Total Cases: {results[category]['count']}")
    print(f"Perfect Matches (>95%): {results[category]['perfect_matches']}")
    print(f"Average Similarity: {results[category]['avg_similarity']:.1f}%")

Performance Summary:
--------------------------------------------------

Symptoms Extraction:
Total Cases: 7
Perfect Matches (>95%): 5
Average Similarity: 84.3%

Medications Extraction:
Total Cases: 6
Perfect Matches (>95%): 1
Average Similarity: 66.6%
