In [49]:
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import numpy as np

In [41]:
# Load the dataset
data = pd.read_csv('medquad.csv')
data.dropna(subset=['focus_area'], inplace=True)

# Basic text preprocessing
data['processed_question'] = data['question'].apply(lambda x: x.lower())

In [42]:
# Split the dataset
train_texts, val_texts, train_labels, val_labels = train_test_split(
    data['processed_question'], data['focus_area'], test_size=0.1, random_state=42)

In [43]:
# Initialize the tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Encode the text data for DistilBERT
# Define text encoding function
# Function to encode texts
def encode_texts(texts):
    return tokenizer(texts.tolist(), padding=True, truncation=True, max_length=128, return_tensors='pt')

train_encodings = encode_texts(train_texts)
val_encodings = encode_texts(val_texts)

In [44]:
# Create label dictionary and encode labels
label_dict = {label: idx for idx, label in enumerate(pd.unique(data['focus_area']))}
train_labels_encoded = [label_dict[label] for label in train_labels]
val_labels_encoded = [label_dict[label] for label in val_labels]

# Dataset class
class TextDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [45]:
# Create data loaders
train_dataset = TextDataset(train_encodings, train_labels_encoded)
val_dataset = TextDataset(val_encodings, val_labels_encoded)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [46]:
# Load DistilBERT model for sequence classification
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=len(label_dict))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [47]:
# Define optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training and evaluation functions (simplified for brevity)
def train(model, loader, optimizer, device):
    model.train()
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()




In [50]:
def evaluate(model, loader, device, label_dict):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Convert numerical labels back to original labels for reporting
    target_names = [k for k, v in sorted(label_dict.items(), key=lambda item: item[1])]
    return classification_report(all_labels, all_preds, target_names=target_names, labels=np.arange(len(label_dict)))

# Update your evaluation call
for epoch in range(3):
    train(model, train_loader, optimizer, device)
    print(f"Epoch {epoch+1} Evaluation:")
    report = evaluate(model, val_loader, device, label_dict)
    print(report)

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


Epoch 1 Evaluation:


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


                                                                                                                                          precision    recall  f1-score   support

                                                                                                                                Glaucoma       0.67      1.00      0.80         2
                                                                                                                     High Blood Pressure       1.00      1.00      1.00         3
                                                                                                                 Paget's Disease of Bone       0.00      0.00      0.00         0
                                                                                                                Urinary Tract Infections       0.00      0.00      0.00         3
                                                                                                            A

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


                                                                                                                                          precision    recall  f1-score   support

                                                                                                                                Glaucoma       0.67      1.00      0.80         2
                                                                                                                     High Blood Pressure       1.00      1.00      1.00         3
                                                                                                                 Paget's Disease of Bone       0.00      0.00      0.00         0
                                                                                                                Urinary Tract Infections       0.00      0.00      0.00         3
                                                                                                            A

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
