# Training a BERT Model for Medical Diagnosis Classification
This notebook demonstrates how to fine-tune a BERT model that assigns CIE-10 codes (CIE-10 is the Spanish-language name of the ICD-10 classification) to medical diagnosis descriptions. It covers loading the data, preprocessing, training, evaluation, and querying the model.

In [None]:
!pip install pandas numpy scikit-learn torch transformers datasets matplotlib accelerate

## Import Required Libraries
We need several libraries for data manipulation, model training, and evaluation.
- `pandas` and `numpy` for data manipulation.
- `scikit-learn` for data splitting.
- `torch` for PyTorch, the deep learning framework.
- `transformers` for BERT model and tokenizer.
- `datasets` for handling datasets.
- `matplotlib` for plotting graphs.
- `os` for checking whether a saved model already exists.

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from transformers import DataCollatorWithPadding
from datasets import Dataset
import matplotlib.pyplot as plt
import os

## Load the Training Dataset
We load the training dataset containing medical diagnosis descriptions and their corresponding CIE-10 codes.
- `pd.read_csv` is used to read the CSV file into a DataFrame.
- We select only the relevant columns (`description` and `code`).
- We rename the columns to `text` and `label` for consistency.
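
To see what these three steps do without the real file, here is a minimal sketch on a hypothetical two-row CSV read from an in-memory string with `io.StringIO`. The extra `chapter` column and the rows themselves are invented for illustration; the real file's other columns may differ.

```python
import io

import pandas as pd

# Hypothetical rows standing in for cie10-es-diagnoses.csv; the invented
# `chapter` column is discarded by the column selection below.
csv_text = """code,description,chapter
A00,Cólera,1
A01,Fiebres tifoidea y paratifoidea,1
"""

df = pd.read_csv(io.StringIO(csv_text))
# Keep only the two columns the classifier needs, in a fixed order
df = df[['description', 'code']]
# Rename to the column names used by the tokenization step
df = df.rename(columns={'description': 'text', 'code': 'label'})
print(df)
```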

In [None]:
# Load the training dataset
df_train = pd.read_csv('../csv_import_scrips/cie10-es-diagnoses.csv')
df_train = df_train[['description', 'code']]
df_train = df_train.rename(columns={'description': 'text', 'code': 'label'})
df_train.head()

## Load the Evaluation Dataset
We load a separate evaluation dataset to validate the model's performance.
- The same steps are applied as for the training dataset; here the source columns are `Diagnóstico` (diagnosis) and `CIE-10`.
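
Because the evaluation file is a separate, generated dataset, some of its CIE-10 codes may not occur in the training file, and the model has no output class for those. A minimal sketch, using plain Python sets and hypothetical codes, for measuring that overlap on the raw code strings (before they are converted to integers); `label_coverage` is a hypothetical helper, not part of the notebook:

```python
def label_coverage(train_labels, eval_labels):
    """Return (fraction of eval rows whose label occurs in train, sorted unseen labels)."""
    known = set(train_labels)
    eval_labels = list(eval_labels)
    unseen = sorted({label for label in eval_labels if label not in known})
    covered = sum(1 for label in eval_labels if label in known)
    fraction = covered / len(eval_labels) if eval_labels else 0.0
    return fraction, unseen


# Hypothetical codes; in the notebook these could be df_train['label'] and df_eval['label']
train_codes = ['A00', 'A01', 'B20']
eval_codes = ['A00', 'A00', 'B20', 'Z99']
fraction, unseen = label_coverage(train_codes, eval_codes)
print(f'{fraction:.0%} of evaluation rows use a known code; unseen: {unseen}')
```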

In [None]:
# Load the evaluation dataset
df_eval = pd.read_csv('../generated-diagnoses/diagnosticos_medicos_10000.csv')
df_eval = df_eval[['Diagnóstico', 'CIE-10']]
df_eval = df_eval.rename(columns={'Diagnóstico': 'text', 'CIE-10': 'label'})
df_eval.head()

## Preprocess Data
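We convert the CIE-10 codes to integer class IDs and keep a mapping from each ID back to its CIE-10 code.
- Convert the training `label` column to a categorical type; its categories define the set of classes.
- Create a dictionary (`label_to_code`) that maps each class ID to its CIE-10 code.
- Encode the evaluation labels with the same categories, so that each code receives the same ID in both datasets, and drop evaluation rows whose code does not occur in the training set.

In [None]:
# Preprocess data
df_train['label'] = df_train['label'].astype('category')
categories = df_train['label'].cat.categories
label_to_code = dict(enumerate(categories))
df_train['label'] = df_train['label'].cat.codes

# Encode evaluation labels with the training categories; unseen codes become -1
df_eval['label'] = pd.Categorical(df_eval['label'], categories=categories).codes
unseen = df_eval['label'] == -1
print(f'Dropping {unseen.sum()} evaluation rows whose CIE-10 code is not in the training set.')
df_eval = df_eval[~unseen].reset_index(drop=True)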
df_train.head(), df_eval.head()
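
Encoding each DataFrame's labels with its own categories gives the same CIE-10 code different IDs whenever the two files contain different sets of codes. A minimal sketch with hypothetical codes shows the effect, and how reusing the training categories through `pd.Categorical(..., categories=...)` keeps the IDs aligned:

```python
import pandas as pd

# Hypothetical labels: the evaluation set lacks 'A01' and contains an unseen 'Z99'
train = pd.Series(['A00', 'A01', 'B20'])
evaluation = pd.Series(['B20', 'A00', 'Z99'])

# Independent categories: 'B20' gets ID 2 in train but ID 1 in evaluation
independent_train = train.astype('category').cat.codes.tolist()
independent_eval = evaluation.astype('category').cat.codes.tolist()
print(independent_train, independent_eval)

# Shared categories: reuse the training categories for the evaluation labels
categories = train.astype('category').cat.categories
shared_eval = pd.Categorical(evaluation, categories=categories).codes.tolist()
print(shared_eval)  # labels missing from the training categories become -1
```

Rows coded `-1` have no matching output class; left in place, they make the cross-entropy loss fail with an out-of-range target, so they need to be dropped (or mapped to an explicit extra class) before evaluation.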

## Tokenize and Encode Data
We use the BERT tokenizer to tokenize and encode the text data.
- Load the BERT tokenizer.
- Define a function to tokenize the text data.
- Convert the DataFrame to a Dataset object.
- Apply the tokenizer to the dataset.
- Use `DataCollatorWithPadding` to pad each batch dynamically to the length of its longest sequence.
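
Dynamic padding means each batch is padded only to its own longest sequence instead of a fixed global maximum. A simplified pure-Python sketch of that behavior on already-tokenized IDs; the real collator delegates to `tokenizer.pad`, also pads `token_type_ids`, and returns tensors. `pad_batch` is a hypothetical helper, and `pad_id=0` matches the `[PAD]` token of `bert-base-uncased` but is an assumption for other vocabularies.

```python
def pad_batch(batch, pad_id=0):
    """Right-pad input_ids to the longest sequence in the batch and build attention masks."""
    max_len = max(len(example['input_ids']) for example in batch)
    padded = {'input_ids': [], 'attention_mask': []}
    for example in batch:
        ids = example['input_ids']
        n_pad = max_len - len(ids)
        padded['input_ids'].append(ids + [pad_id] * n_pad)
        # 1 marks real tokens; 0 marks padding that attention should ignore
        padded['attention_mask'].append([1] * len(ids) + [0] * n_pad)
    return padded


# Hypothetical token IDs; 101 and 102 are [CLS] and [SEP] in bert-base-uncased
batch = [
    {'input_ids': [101, 7592, 102]},
    {'input_ids': [101, 7592, 2088, 999, 102]},
]
print(pad_batch(batch))
```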

In [None]:
# Tokenize and encode data
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True)

train_dataset = Dataset.from_pandas(df_train)
eval_dataset = Dataset.from_pandas(df_eval)

tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_eval_dataset = eval_dataset.map(tokenize_function, batched=True)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

## Load or Train the Model
We check if a trained model already exists. If it does, we load it. Otherwise, we train a new model.
- Check if the model directory exists.
- If it exists, load the model and tokenizer from the saved files.
- If it doesn't exist, load a pre-trained BERT model and fine-tune it on our dataset.
- Save the trained model and tokenizer.
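
The cell below treats the mere existence of `./trained_model` as a sign that a finished model is on disk. A minimal, framework-free sketch of the same load-or-train pattern, with hypothetical `train_fn`, `save_fn`, and `load_fn` callables standing in for fine-tuning, `save_pretrained`, and `from_pretrained`. As a variation, it checks for a file that saving writes (`config.json`, which `save_pretrained` writes for Transformers models) rather than the bare directory, so an empty or interrupted save is not mistaken for a trained model.

```python
import json
import os
import tempfile


def load_or_train(model_dir, train_fn, save_fn, load_fn, marker='config.json'):
    """Load a cached model if model_dir contains `marker`; otherwise train and save it.

    Returns (model, was_trained).
    """
    if os.path.isfile(os.path.join(model_dir, marker)):
        return load_fn(model_dir), False
    model = train_fn()
    os.makedirs(model_dir, exist_ok=True)
    save_fn(model, model_dir)
    return model, True


# Hypothetical stand-ins for fine-tuning, save_pretrained and from_pretrained
def fake_train():
    return {'weights': [0.1, 0.2]}


def fake_save(model, path):
    with open(os.path.join(path, 'config.json'), 'w') as f:
        json.dump(model, f)


def fake_load(path):
    with open(os.path.join(path, 'config.json')) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    model_dir = os.path.join(tmp, 'trained_model')
    first, trained_first = load_or_train(model_dir, fake_train, fake_save, fake_load)
    second, trained_second = load_or_train(model_dir, fake_train, fake_save, fake_load)
    print(trained_first, trained_second, first == second)
```

Either way, the cache is keyed only on the path: after changing the data or hyperparameters, delete the directory to force retraining.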

In [None]:
# Reuse a fine-tuned model saved by a previous run; otherwise start from pre-trained BERT
model_path = './trained_model'
needs_training = not os.path.exists(model_path)

if needs_training:
    # Load pre-trained BERT with a new classification head (one output per CIE-10 code);
    # id2label/label2id store the code mapping in the saved model config
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=len(label_to_code),
        id2label=label_to_code,
        label2id={code: idx for idx, code in label_to_code.items()},
    )
else:
    # Load the saved model and tokenizer
    model = BertForSequenceClassification.from_pretrained(model_path)
    tokenizer = BertTokenizer.from_pretrained(model_path)
    print('Model loaded from saved files.')

# One TrainingArguments/Trainer pair serves both fine-tuning and evaluation, so the
# training history in trainer.state.log_history remains available for plotting
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',  # called `evaluation_strategy` in transformers < 4.41
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    tokenizer=tokenizer,  # `processing_class=` replaces this argument in transformers >= 4.46
    data_collator=data_collator,
)

if needs_training:
    # Fine-tune BERT model
    trainer.train()

    # Save the trained model and tokenizer
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)
    print('Model trained and saved.')

## Evaluate Model Performance
We evaluate the model's performance on the evaluation dataset.
- Use the `evaluate` method of the `Trainer` class to evaluate the model.
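
Without a `compute_metrics` function, `trainer.evaluate()` reports only the evaluation loss and runtime statistics. A minimal sketch of an accuracy metric that could be passed as `Trainer(..., compute_metrics=compute_metrics)`; `Trainer` calls it with an `EvalPrediction` that unpacks into logits and label IDs. The function body and the choice of accuracy are illustrative, not part of the original notebook.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy from (logits, label_ids), as produced by the Trainer's evaluation loop."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {'accuracy': float((predictions == labels).mean())}


# Hypothetical logits for three examples over three classes
logits = np.array([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.3, 0.2, 0.9]])
labels = np.array([0, 1, 1])
print(compute_metrics((logits, labels)))
```

The `Trainer` prefixes returned metric names with `eval_`, so this value would appear as `eval_accuracy`.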

In [None]:
# Evaluate model performance
trainer.evaluate()

## Plot Training and Validation Loss
We plot the training and validation loss recorded by the `Trainer` to see how the model's loss evolves during training.
- Extract the training (`loss`) and validation (`eval_loss`) entries from `trainer.state.log_history`.
- Plot each loss with `matplotlib`.
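
`trainer.state.log_history` is a list of dictionaries. Training entries carry `loss` and are written every `logging_steps` optimizer steps (10 here), evaluation entries carry `eval_loss`, and each entry records the global `step` and, once training has run, the fractional `epoch`. Because training loss is logged per logging step rather than per epoch, pairing every value with its own `step` or `epoch` gives a correct x-axis. A pure-Python sketch on a hypothetical log (all values are invented; `split_log_history` is a hypothetical helper):

```python
def split_log_history(log_history):
    """Separate training and evaluation losses, keeping each entry's own x value."""
    train_steps, train_losses = [], []
    eval_epochs, eval_losses = [], []
    for entry in log_history:
        if 'loss' in entry:
            train_steps.append(entry['step'])
            train_losses.append(entry['loss'])
        elif 'eval_loss' in entry:
            eval_epochs.append(entry['epoch'])
            eval_losses.append(entry['eval_loss'])
    return (train_steps, train_losses), (eval_epochs, eval_losses)


# Hypothetical log: two logging steps, one end-of-epoch evaluation, then the
# end-of-training summary (which uses `train_loss`, not `loss`, and is skipped)
history = [
    {'loss': 4.2, 'learning_rate': 2e-5, 'epoch': 0.5, 'step': 10},
    {'loss': 3.1, 'learning_rate': 1.9e-5, 'epoch': 1.0, 'step': 20},
    {'eval_loss': 2.8, 'epoch': 1.0, 'step': 20},
    {'train_loss': 3.6, 'epoch': 1.0, 'step': 20},
]
(train_steps, train_losses), (eval_epochs, eval_losses) = split_log_history(history)
print(train_steps, train_losses, eval_epochs, eval_losses)
```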

In [None]:
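# Plot training and validation loss recorded by this Trainer
log_history = trainer.state.log_history
# Training loss is logged every `logging_steps` optimizer steps, so plot it against the step
train_logs = [x for x in log_history if 'loss' in x]
# Validation loss is logged once per epoch; keying by epoch drops the repeat from trainer.evaluate()
eval_by_epoch = {x['epoch']: x['eval_loss'] for x in log_history if 'eval_loss' in x and 'epoch' in x}

if not train_logs:
    print('No training loss was logged by this Trainer (for example, the model was loaded from disk).')
else:
    plt.figure(figsize=(10, 5))
    plt.plot([x['step'] for x in train_logs], [x['loss'] for x in train_logs], label='Training Loss')
    plt.xlabel('Training step')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.plot(list(eval_by_epoch.keys()), list(eval_by_epoch.values()), marker='o', label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Validation Loss')
    plt.legend()
    plt.show()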

## Query the Model
We show how to query the model with new diagnosis descriptions and obtain the predicted CIE-10 codes.
- Define a `predict` function to process the input text and get the model's prediction.
- Ensure the inputs are on the same device as the model.
- Map the predicted label index back to the CIE-10 code using the `label_to_code` dictionary.
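
`predict` returns only the single most likely code. With many CIE-10 classes, the top few candidates and their probabilities are often more informative. A minimal NumPy sketch on a 1-D logits vector (in the notebook it could come from `model(**inputs).logits[0].detach().cpu().numpy()`); the `top_k_codes` helper, the toy mapping, and the logits are hypothetical:

```python
import numpy as np


def top_k_codes(logits, label_to_code, k=3):
    """Return the k most probable (code, probability) pairs from a 1-D logits vector."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the maximum before exponentiating
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # Indices sorted by descending probability
    top = np.argsort(probs)[::-1][:k]
    return [(label_to_code[int(i)], float(probs[i])) for i in top]


# Hypothetical mapping and logits for four classes
toy_mapping = {0: 'A00', 1: 'A01', 2: 'B20', 3: 'Z99'}
toy_logits = [1.0, 3.0, 2.0, 0.5]
for code, prob in top_k_codes(toy_logits, toy_mapping, k=2):
    print(f'{code}: {prob:.3f}')
```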

In [None]:
# Example queries to the model
def predict(text):
    """Return the predicted CIE-10 code for a single diagnosis description."""
    model.eval()  # disable dropout so predictions are deterministic
    inputs = tokenizer(text, return_tensors='pt', truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}  # Ensure inputs are on the same device as the model
    with torch.no_grad():  # gradients are not needed for inference
        outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=-1)
    return label_to_code[predictions.item()]

# Example queries
examples = [
    'Cambio en cerebro, de dispositivo de drenaje, abordaje externo',  # "Change in brain, of drainage device, external approach"
    'Escisión de cerebro, diagnóstico, abordaje abierto'  # "Excision of brain, diagnostic, open approach"
]

for example in examples:
    code = predict(example)
    print(f'Text: {example}\nPredicted CIE-10 code: {code}\n')