## Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.insert(0, '..')

In [3]:
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig

In [4]:
df = pd.read_csv('../data/data_version3/dataset_processed.csv')
df.head()

Unnamed: 0,partner_id,specialist_id,status,gender,province_id,age,reason_combind,specialist_name,correct_prediction,ids,logs,processed_symptoms
0,2,18,2,,1,57,mất ngủ,thần kinh,,,,mất ngủ
1,4,18,2,,1,35,rối loạn thần kinh thực vật,thần kinh,,,,rối loạn thần kinh thực vật
2,4,18,2,,11,36,đau đầu,thần kinh,,,,đau đầu
3,4,18,2,,1,40,"đau đầu,đau sau ngực gần phổi",thần kinh,,,,"đau đầu, đau sau ngực gốc phổi"
4,3,18,2,,1,12,co giật 3 lần,thần kinh,,,,có giật 3 lần


In [5]:
df['gender'] = df['gender'].fillna('unknown')
df['gender'] = df['gender'].replace({1.0: 'female', 0.0: 'male'})

In [6]:
df['age'] = df['age'].fillna(0)
df['age_category'] = df['age'].apply(lambda x: 
                                                     'unknown' if x == 0 else 
                                                     'child' if 0 < x <= 15 else 
                                                     'adult')
# Display the results
print(df['age_category'].value_counts())

age_category
adult      34894
unknown    13124
child       5606
Name: count, dtype: int64


In [7]:
df.shape# Filter data to keep only records with non-null processed_symptoms
processed_df = df.dropna(subset=['processed_symptoms'])

# Check the shape after filtering
print(f"Number of records after filtering for non-null processed_symptoms: {processed_df.shape[0]}")
print(f"Percentage of data retained: {processed_df.shape[0]/len(processed_df)*100:.2f}%")

Number of records after filtering for non-null processed_symptoms: 45959
Percentage of data retained: 100.00%


In [8]:
processed_df.specialist_name.unique()

array(['thần kinh', 'vô sinh - hiếm muộn', 'nhi khoa', 'thận - tiết niệu',
       'ung bướu', 'hô hấp - phổi', 'chuyên khoa mắt', 'cơ xương khớp',
       'nha khoa', 'tim mạch', 'tiêu hoá', 'sức khỏe tâm thần',
       'nội khoa', 'tiểu đường - nội tiết', 'tai mũi họng', 'nam học',
       'da liễu', 'sản phụ khoa'], dtype=object)

In [9]:
class_counts = processed_df['specialist_id'].value_counts()
print(f'number of samples per class: {class_counts}')

number of samples per class: specialist_id
1     8029
18    6897
11    5113
22    4123
4     3824
17    3680
27    2998
3     2622
26    1953
29    1462
5     1358
15    1068
32     898
43     613
21     593
19     285
33     272
67     171
Name: count, dtype: int64


In [10]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = df['specialist_id']
weight_class = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)

In [11]:
weight_class, len(weight_class)

(array([ 0.33182347,  0.93889414,  0.72555068,  2.03075059,  0.54602476,
         2.18891338,  0.65388743,  0.3957374 ,  1.85845983,  3.91473208,
         0.58402492,  1.28965849,  0.89732262,  1.87483393,  2.71074714,
         9.51792687,  4.58324786, 11.77514273]),
 18)

In [12]:
type(weight_class[0])

numpy.float64

In [13]:
processed_df['reason_combind'] = processed_df['processed_symptoms']

missing_specialists = set(df['specialist_name'].unique()) - set(processed_df['specialist_name'].unique())
print(f"Number of specialists missing from processed_df: {len(missing_specialists)}")
print(f"Missing specialists: {missing_specialists}")

# Add records with missing specialists from df to processed_df
if len(missing_specialists) > 0:
    missing_df = pd.DataFrame()
    for specialist in missing_specialists:
        specialist_data = df[df['specialist_name'] == specialist].copy()
        
        # Since there's no main_reason column, just use empty string for reason_combind
        missing_df = pd.concat([missing_df, specialist_data])
    
    # Combine with processed_df
    processed_df = pd.concat([processed_df, missing_df])
    
print(f"Final processed_df shape: {processed_df.shape}")
print(f"Number of specialists in final dataset: {len(processed_df['specialist_name'].unique())}")

# Replace the original df with processed_df for further processing
df = processed_df.copy()

Number of specialists missing from processed_df: 0
Missing specialists: set()
Final processed_df shape: (45959, 13)
Number of specialists in final dataset: 18


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  processed_df['reason_combind'] = processed_df['processed_symptoms']


In [13]:
df.specialist_name.unique(), len(df.specialist_name.unique()), df.shape

(array(['thần kinh', 'vô sinh - hiếm muộn', 'nhi khoa', 'thận - tiết niệu',
        'ung bướu', 'hô hấp - phổi', 'chuyên khoa mắt', 'cơ xương khớp',
        'nha khoa', 'tim mạch', 'tiêu hoá', 'sức khỏe tâm thần',
        'nội khoa', 'tiểu đường - nội tiết', 'tai mũi họng', 'nam học',
        'da liễu', 'sản phụ khoa'], dtype=object),
 18,
 (53624, 13))

In [14]:
len(df.specialist_name.unique())

18

In [15]:
df.isnull().sum()

partner_id                0
specialist_id             0
status                    0
gender                    0
province_id               0
age                       0
reason_combind            2
specialist_name           0
correct_prediction    43185
ids                   46013
logs                  46013
processed_symptoms     7665
age_category              0
dtype: int64

In [15]:
# Extract data for model training
# Take 1000 examples per specialist_name
specialist_samples = {}

# Get list of unique specialists
specialists = df['specialist_name'].unique()

# Sample 1000 examples per specialist (or all if fewer than 1000)
balanced_df = pd.DataFrame()
for specialist in specialists:
    specialist_df = df[df['specialist_name'] == specialist]
    if len(specialist_df) > 1000:
        sampled_df = specialist_df.sample(1000, random_state=42)
    else:
        sampled_df = specialist_df
    balanced_df = pd.concat([balanced_df, sampled_df])

# Reset index of the final balanced dataset
balanced_df = balanced_df.reset_index(drop=True)
print(f"Total samples in balanced dataset: {len(balanced_df)}")
print(balanced_df['specialist_name'].value_counts())

Total samples in balanced dataset: 15977
specialist_name
thần kinh                1000
nhi khoa                 1000
da liễu                  1000
thận - tiết niệu         1000
ung bướu                 1000
chuyên khoa mắt          1000
tim mạch                 1000
cơ xương khớp            1000
sức khỏe tâm thần        1000
tiêu hoá                 1000
sản phụ khoa             1000
nam học                  1000
nội khoa                 1000
tai mũi họng             1000
tiểu đường - nội tiết     761
hô hấp - phổi             650
nha khoa                  313
vô sinh - hiếm muộn       253
Name: count, dtype: int64


In [48]:
balanced_df

Unnamed: 0,partner_id,specialist_id,status,gender,province_id,age,reason_combind,specialist_name,correct_prediction,ids,logs,processed_symptoms,age_category
0,1,18,2,female,30,43,mắt ngứa đã nhiều năm,thần kinh,,,,mắt ngứa đã nhiều năm,adult
1,49,18,2,female,1,12,ngủ chập chờn rung giật cả,thần kinh,,,,ngủ chập chờn rung giật cả,child
2,8,18,2,female,1,26,đau đầu 2 ngày chưa khỏi,thần kinh,,,,đau đầu 2 ngày chưa khỏi,adult
3,1,18,2,female,34,0,"đau đầu không đáp, cảm giác các dây thần kinh ...",thần kinh,,,,"đau đầu không đáp, cảm giác các dây thần kinh ...",unknown
4,41,18,2,female,1,4,muốn khám động kinh do bác sĩ tư vấn!,thần kinh,,,,muốn khám động kinh do bác sĩ tư vấn!,child
...,...,...,...,...,...,...,...,...,...,...,...,...,...
15710,109,11,2,female,6,29,"bị ngứa, sưng nổi mảng ở chân trái",da liễu,,,,,adult
15711,48,11,2,female,1,31,"da đầu nhiều gầu, nghi nhiễm nấm",da liễu,,,,,adult
15712,48,11,2,male,1,26,ngứa không rõ lý do,da liễu,,,,,adult
15713,5,11,2,male,1,6,mụn nhọt sát ngón tay cái,da liễu,,,,,child


In [16]:
class MedicalSpecialistClassifer(nn.Module):
    def __init__(self, num_specialists: int, model_name: str = "BookingCare/gte-multilingual-base-v2.1", user_feature_dim=2, dropout=0.2, load_pretrained=True, trust_remote_code=False):
        super(MedicalSpecialistClassifer, self).__init__()

        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        if load_pretrained:
            self.reason_encoder = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        else:
            self.reason_encoder = AutoModel.from_config(config, trust_remote_code=True)
        
        for param in self.reason_encoder.parameters():
            param.requires_grad = False
        self.reason_encoder_hidden_dim = self.reason_encoder.config.hidden_size
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        
        self.user_encoder = nn.Sequential(
            nn.Linear(user_feature_dim, 128),
            nn.BatchNorm1d(128),
            self.relu,
            self.dropout,
            nn.Linear(128, 256),
            self.relu
        )

        self.hidden_layer1 = nn.Sequential(
            nn.Linear(self.reason_encoder_hidden_dim, 512),
            nn.BatchNorm1d(512),
            self.relu,
            self.dropout,
        )

        self.level1_output = nn.Linear(512, num_specialists)

        self.hidden_layer2 = nn.Sequential(
            nn.Linear(self.reason_encoder_hidden_dim + 256 + 512, 768),
            nn.BatchNorm1d(768),
            self.relu,
            self.dropout
        )
        self.level2_output = nn.Linear(768, num_specialists)

    def forward(self, reason_text_ids, reason_text_mask, user_info=None):
        reason_outputs = self.reason_encoder(reason_text_ids, attention_mask=reason_text_mask)
        reason_embedding = reason_outputs.last_hidden_state[:, 0, :]

        hidden1 = self.hidden_layer1(reason_embedding)

        level1_logits = self.level1_output(hidden1)

        if user_info is None:
            return level1_logits, None

        user_features = self.user_encoder(user_info)

        combined_features = torch.cat((reason_embedding, hidden1, user_features), dim=1)

        hidden2 = self.hidden_layer2(combined_features)
        level2_logits = self.level2_output(hidden2)

        return level1_logits, level2_logits

    def predict(self, reason_text_ids, reason_text_mask, user_info=None, threshold=0.7):
        self.eval()
        with torch.no_grad():
            level1_logits, level2_logits = self.forward(reason_text_ids, reason_text_mask, user_info)
            level1_probs = F.softmax(level1_logits, dim=1)
            max_probs, preds = torch.max(level1_probs, dim=1)
            final_preds = preds.clone()

            if user_info is not None and level2_logits is not None:
                for i, prob in enumerate(max_probs):
                    final_preds[i] = torch.argmax(level2_logits[i])

            return final_preds

In [17]:
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

class MedicalDataFrameDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = 128
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        reason_text = row['reason_combind']
        user_info = torch.tensor([row['gender'], row['age_category']], dtype=torch.float32)
        label = row['specialist_name']
        encoded = self.tokenizer(
            reason_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            "reason_text_ids": encoded['input_ids'].squeeze(0),
            "reason_text_mask": encoded['attention_mask'].squeeze(0),
            "user_info": user_info,
            "labels": torch.tensor(label, dtype=torch.long)
        }

In [18]:
from tqdm.notebook import tqdm

def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=5, patience=2, save_path="best_model.pt"):
    model.to(device)
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in tqdm(range(num_epochs)):
        model.train()
        running_loss = 0.0
        correct_level1 = 0
        correct_combined = 0
        total = 0

        for batch in tqdm(train_loader, desc="Training..."):
            reason_text_ids = batch['reason_text_ids'].to(device)
            reason_text_mask = batch['reason_text_mask'].to(device)
            user_info = batch.get('user_info', None)

            if user_info is not None:
                user_info = user_info.to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            level1_logits, level2_logits = model(
                reason_text_ids,
                reason_text_mask,
                user_info
            )

            _, level1_preds = torch.max(level1_logits, dim=1)
            correct_mask = (level1_preds == labels)
            loss = criterion(level1_logits, labels)

            if level2_logits is not None and (~correct_mask).any():
                incorrect_indices = (~correct_mask).nonzero(as_tuple=True)[0]
                level2_loss = criterion(level2_logits[incorrect_indices], labels[incorrect_indices])
                loss += level2_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct_level1 += correct_mask.sum().item()

            if level2_logits is not None:
                _, level2_preds = torch.max(level2_logits, dim=1)
                final_preds = torch.where(correct_mask, level1_preds, level2_preds)
            else:
                final_preds = level1_preds

            correct_combined += (final_preds == labels).sum().item()
            total += labels.size(0)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"Train Loss: {running_loss / len(train_loader): .4f}")
        print(f"Level 1 Accuracy: {100 * correct_level1 / total:.2f}")
        print(f"Final Accuracy (with Level 2 fallback): {100 * correct_combined / total: .2f}%\n")

        # validation
        model.eval()
        val_loss = 0.0
        correct_level1 = 0
        correct_combined = 0
        total = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validate...'):
                reason_text_ids = batch['reason_text_ids'].to(device)
                reason_text_mask = batch['reason_text_mask'].to(device)
                user_info = batch.get('user_info', None)
                if user_info is not None:
                    user_info = user_info.to(device)

                    for p in model.hidden_layer2.parameters():
                        p.requires_grad = True
                    for p in model.level2_output.parameters():
                        p.requires_grad = True
                else:
                    for p in model.hidden_layer2.parameters():
                        p.requires_grad = False
                    for p in model.level2_output.parameters():
                        p.requires_grad = False
                labels = batch['labels'].to(device)

                level1_logits, level2_logits = model(
                    reason_text_ids,
                    reason_text_mask,
                    user_info
                )

                _, level1_preds = torch.max(level1_logits, dim=1)
                correct_mask = (level1_preds == labels)
                loss = criterion(level1_logits, labels)
                if level2_logits is not None and (~correct_mask).any():
                    incorrect_indices = (~correct_mask).nonzero(as_tuple=True)[0]
                    level2_loss = criterion(level2_logits[incorrect_indices], labels[incorrect_indices])
                    loss += level2_loss

                val_loss += loss.item()
                correct_level1 += correct_mask.sum().item()

                if level2_logits is not None:
                    _, level2_preds = torch.max(level2_logits, dim=1)
                    final_preds = torch.where(correct_mask, level1_preds, level2_preds)
                else:
                    final_preds = level1_preds

                correct_combined += (final_preds == labels).sum().item()
                total += labels.size(0)
        
        avg_val_loss = val_loss / len(val_loader)
        print(f"Validation loss: {avg_val_loss:.4f}")
        print(f"Validation level 1 Accuracy: {100 * correct_level1 / total:.2f}%")
        print(f"Validation Final Accuracy (with Level 2 fallback): {100 * correct_combined / total:.2f}%\n")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
            print("Model saved.")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break                

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [20]:
device

device(type='cuda')

In [21]:
tokenizer = AutoTokenizer.from_pretrained('../models/gte/')

In [22]:
specialist_encoder = LabelEncoder()
age_encoder = LabelEncoder()
gender_encoder = LabelEncoder()
df['age_category'] = age_encoder.fit_transform(df['age_category'])
df['gender'] = gender_encoder.fit_transform(df['gender'])
df['specialist_name'] = specialist_encoder.fit_transform(df['specialist_name'])

In [23]:
dataset = MedicalDataFrameDataset(df, tokenizer)

In [24]:
dataset.__getitem__(0)

{'reason_text_ids': tensor([    0, 57695,  4868,    18, 27421,     2,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,    

In [25]:
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [26]:
model = MedicalSpecialistClassifer(model_name = "../models/gte/",num_specialists=len(specialist_encoder.classes_), user_feature_dim=2, load_pretrained=True, trust_remote_code=True)

In [27]:
model

MedicalSpecialistClassifer(
  (reason_encoder): NewModel(
    (embeddings): NewEmbeddings(
      (word_embeddings): Embedding(250048, 768, padding_idx=1)
      (rotary_emb): NTKScalingRotaryEmbedding()
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): NewEncoder(
      (layer): ModuleList(
        (0-11): 12 x NewLayer(
          (attention): NewSdpaAttention(
            (qkv_proj): Linear(in_features=768, out_features=2304, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp): NewGatedMLP(
            (up_gate_proj): Linear(in_features=768, out_features=6144, bias=False)
            (down_proj): Linear(in_features=3072, out_features=768, bias=True)
            (act_fn): GELUActivation()
            (hidden_dropout): Dropout(p=0.1, inp

In [28]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

In [29]:
weight_class = torch.tensor(weight_class, dtype=torch.float).to(device)

In [30]:
criterion = nn.CrossEntropyLoss(weight=weight_class)

In [31]:
train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=30, patience=3)

  0%|          | 0/30 [00:00<?, ?it/s]

Training...:   0%|          | 0/3352 [00:00<?, ?it/s]

ValueError: text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

In [69]:
new_model = MedicalSpecialistClassifer(num_specialists=len(specialist_encoder.classes_), user_feature_dim=2, load_pretrained=False, trust_remote_code=True)

In [26]:
import pickle

# Save encoders to pickle files
with open("specialist_encoder.pkl", "wb") as f:
    pickle.dump(specialist_encoder, f)
    
with open("age_encoder.pkl", "wb") as f:
    pickle.dump(age_encoder, f)
    
with open("gender_encoder.pkl", "wb") as f:
    pickle.dump(gender_encoder, f)


In [27]:
len(specialist_encoder.classes_)

18

In [70]:
new_model

MedicalSpecialistClassifer(
  (reason_encoder): NewModel(
    (embeddings): NewEmbeddings(
      (word_embeddings): Embedding(250048, 768, padding_idx=1)
      (rotary_emb): NTKScalingRotaryEmbedding()
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): NewEncoder(
      (layer): ModuleList(
        (0-11): 12 x NewLayer(
          (attention): NewSdpaAttention(
            (qkv_proj): Linear(in_features=768, out_features=2304, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp): NewGatedMLP(
            (up_gate_proj): Linear(in_features=768, out_features=6144, bias=False)
            (down_proj): Linear(in_features=3072, out_features=768, bias=True)
            (act_fn): GELUActivation()
            (hidden_dropout): Dropout(p=0.1, inp

In [71]:
new_model.load_state_dict(torch.load("../notebooks/best_model.pt"))

<All keys matched successfully>

In [72]:
new_model.eval()
reason_text = "lười ăn, không tăng cân, constipation"
data = tokenizer(
    reason_text,
    padding='max_length',
    truncation=True,
    max_length=128,
    return_tensors='pt'
)
age_category = "child"
gender = "male"

# Fix the tensor shape for user_info
user_info = torch.tensor([[gender_encoder.transform([gender])[0], age_encoder.transform([age_category])[0]]], dtype=torch.float32)

In [73]:
data, user_info

({'input_ids': tensor([[     0,     96, 150365,   6687,      4,    687,  11122,  24376,      4,
             158,      7,  30019,   2320,      2,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
  

In [74]:
new_model.predict(data['input_ids'], data['attention_mask'], user_info=user_info)

tensor([5])

In [78]:
res = specialist_encoder.inverse_transform([new_model.predict(data['input_ids'], data['attention_mask'], user_info)])

  y = column_or_1d(y, warn=True)


In [80]:
res.tolist()

['nhi khoa']