In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import pandas as pd
import numpy as np
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer

In [24]:
# Load and preprocess the dataset
data = pd.read_csv('medquad.csv')
data.dropna(subset=['focus_area'], inplace=True)
data['processed_question'] = data['question'].apply(lambda x: x.lower())

In [25]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [26]:
# Tokenization
tokenizer = get_tokenizer('basic_english')
def yield_tokens(data_iter):
    for text in data_iter:
        yield tokenizer(text)
vocab = build_vocab_from_iterator(yield_tokens(data['processed_question']), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# Encoding texts
def encode_text(text, vocab, tokenizer):
    return [vocab[token] for token in tokenizer(text)]

In [27]:
data['encoded'] = data['processed_question'].apply(lambda x: encode_text(x, vocab, tokenizer))

# Prepare labels
label_dict = {label: idx for idx, label in enumerate(pd.unique(data['focus_area']))}
data['labels'] = data['focus_area'].map(label_dict)


In [28]:
# Data Augmentation: Oversample minority classes
def oversample_data(data):
    max_size = data['labels'].value_counts().max()
    lst = [data]
    for class_index, group in data.groupby('labels'):
        lst.append(group.sample(max_size - len(group), replace=True))
    data_balanced = pd.concat(lst)
    return data_balanced

data = oversample_data(data)

# Split data
train_data, val_data = train_test_split(data, test_size=0.1, random_state=42)

In [29]:
# Dataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __getitem__(self, idx):
        return torch.tensor(self.texts[idx], dtype=torch.long), self.labels[idx]

    def __len__(self):
        return len(self.labels)

In [30]:
# Collate function to handle padding
def collate_batch(batch):
    label_list, text_list = [], []
    for _text, _label in batch:
        label_list.append(_label)
        if isinstance(_text, torch.Tensor):
            text_list.append(_text.clone().detach())  # Use clone to avoid the warning
        else:
            text_list.append(torch.tensor(_text, dtype=torch.long))  # Convert list to tensor if not already
    text_list = pad_sequence(text_list, batch_first=True, padding_value=0)
    label_list = torch.tensor(label_list, dtype=torch.long)
    return text_list, label_list

In [31]:
train_dataset = TextDataset(list(train_data['encoded']), list(train_data['labels']))
val_dataset = TextDataset(list(val_data['encoded']), list(val_data['labels']))

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_batch)


In [32]:
# Define the LSTM with Attention model
class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.hidden_dim = hidden_dim
        self.main = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(True),
            nn.Linear(64, 1)
        )

    def forward(self, lstm_output, final_state):
        attn_weights = self.main(lstm_output)
        soft_attn_weights = F.softmax(attn_weights, 1)
        new_hidden_state = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights).squeeze(2)
        return new_hidden_state

In [33]:
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super(LSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, text):
        embedded = self.embedding(text)
        lstm_out, (hidden, cell) = self.lstm(embedded)
        attn_out = self.attention(lstm_out, hidden)
        return self.fc(attn_out)

In [34]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

In [35]:
# Initialize model, loss, and optimizer
num_classes = 5126
model = LSTMClassifier(len(vocab), 100, 128, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [38]:
# Training and evaluation functions
def train(model, iterator, optimizer, criterion, device):
    model.train()
    epoch_loss = 0
    for texts, labels in iterator:
        texts, labels = texts.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(texts)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion, device):
    model.eval()
    epoch_loss = 0
    predictions, true_labels = [], []
    with torch.no_grad():
        for texts, labels in iterator:
            texts, labels = texts.to(device), labels.to(device)
            outputs = model(texts)
            loss = criterion(outputs, labels)
            epoch_loss += loss.item()
            preds = outputs.argmax(dim=1)
            predictions.extend(preds.tolist())
            true_labels.extend(labels.tolist())

    # Ensure all classes are included in the classification report
    all_labels = list(label_dict.values())
    report = classification_report(true_labels, predictions, labels=all_labels, target_names=list(label_dict.keys()))
    return epoch_loss / len(iterator), report


In [39]:
# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    val_loss, report = evaluate(model, val_loader, criterion, device)
    print(f'Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    print(report)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 1, Train Loss: 0.1044, Val Loss: 0.0958
                                                                                                                                          precision    recall  f1-score   support

                                                                                                                                Glaucoma       1.00      1.00      1.00        13
                                                                                                                     High Blood Pressure       1.00      1.00      1.00         4
                                                                                                                 Paget's Disease of Bone       1.00      1.00      1.00         5
                                                                                                                Urinary Tract Infections       1.00      1.00      1.00         3
                                                              

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 2, Train Loss: 0.0878, Val Loss: 0.0778
                                                                                                                                          precision    recall  f1-score   support

                                                                                                                                Glaucoma       1.00      1.00      1.00        13
                                                                                                                     High Blood Pressure       1.00      1.00      1.00         4
                                                                                                                 Paget's Disease of Bone       1.00      1.00      1.00         5
                                                                                                                Urinary Tract Infections       1.00      1.00      1.00         3
                                                              

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [20]:
num_classes = data['focus_area'].nunique()
print("Number of unique classes:", num_classes)

Number of unique classes: 5126
