## Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.insert(0, '..')

In [4]:
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig

In [5]:
df = pd.read_csv('../data/data_version 3 - Sheet1(4).csv')
df.head()

Unnamed: 0,partner_id,specialist_id,status,gender,province_id,age,reason_combind,specialist_name
0,50.0,29.0,2.0,1,25.0,30,bệnh thiên đầu thống tái phát,chuyên khoa mắt
1,50.0,29.0,2.0,0,1.0,11,nghi ngờ glocom góc mở,chuyên khoa mắt
2,50.0,29.0,2.0,1,1.0,34,"khám mắt, viêm kết giác mạc",chuyên khoa mắt
3,50.0,29.0,2.0,1,1.0,29,khám mắt,chuyên khoa mắt
4,50.0,29.0,2.0,1,1.0,45,mắt bị nhòe,chuyên khoa mắt


In [6]:
df['gender'] = df['gender'].fillna('unknown')
df['gender'] = df['gender'].replace({1.0: 'female', 0.0: 'male'})

In [7]:
df['age'] = df['age'].fillna(0)
df['age_category'] = df['age'].apply(lambda x: 
                                                     'unknown' if x == 0 else 
                                                     'child' if 0 < x <= 15 else 
                                                     'adult')
# Display the results
print(df['age_category'].value_counts())

age_category
adult    46913
child     6479
Name: count, dtype: int64


In [9]:
df.specialist_name.unique(), len(df.specialist_name.unique())

(array(['chuyên khoa mắt', 'cơ xương khớp', 'da liễu', 'hô hấp - phổi',
        'nam học', 'tai mũi họng', 'thận - tiết niệu', 'ung bướu',
        'nha khoa', 'nhi khoa', 'nội khoa', 'tiểu đường - nội tiết',
        'tim mạch', 'tiêu hóa', 'sản phụ khoa', 'răng hàm mặt',
        'thần kinh', 'sức khỏe tâm thần', 'vô sinh - hiếm muộn'],
       dtype=object),
 19)

In [10]:
df.head()

Unnamed: 0,partner_id,specialist_id,status,gender,province_id,age,reason_combind,specialist_name,age_category
0,50.0,29.0,2.0,female,25.0,30,bệnh thiên đầu thống tái phát,chuyên khoa mắt,adult
1,50.0,29.0,2.0,male,1.0,11,nghi ngờ glocom góc mở,chuyên khoa mắt,child
2,50.0,29.0,2.0,female,1.0,34,"khám mắt, viêm kết giác mạc",chuyên khoa mắt,adult
3,50.0,29.0,2.0,female,1.0,29,khám mắt,chuyên khoa mắt,adult
4,50.0,29.0,2.0,female,1.0,45,mắt bị nhòe,chuyên khoa mắt,adult


In [11]:
class_counts = df['specialist_id'].value_counts()
print(f'number of samples per class: {class_counts}')

number of samples per class: specialist_id
1.0     8972
18.0    7332
11.0    5441
22.0    5087
17.0    4552
4.0     4106
27.0    3320
3.0     3167
26.0    2301
19.0    1598
29.0    1589
5.0     1466
15.0    1362
32.0    1096
21.0     761
43.0     650
33.0     313
67.0     253
Name: count, dtype: int64


In [12]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = df['specialist_name']
weight_class = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)

In [13]:
weight_class, len(weight_class)

(array([1.76736180e+00, 3.12511706e-01, 5.16753450e-01, 4.23846948e+00,
        1.24011706e+00, 8.94937982e+00, 1.91554551e+00, 6.28377742e-01,
        5.62021053e+02, 1.76514150e+00, 8.44636388e-01, 6.83890305e-01,
        3.82535429e-01, 2.47586367e+00, 8.87028177e-01, 5.49492621e-01,
        3.63532376e+00, 2.05417051e+00, 1.11071354e+01]),
 19)

In [14]:
type(weight_class[0])

numpy.float64

In [15]:
df.specialist_name.unique(), len(df.specialist_name.unique()), df.shape

(array(['chuyên khoa mắt', 'cơ xương khớp', 'da liễu', 'hô hấp - phổi',
        'nam học', 'tai mũi họng', 'thận - tiết niệu', 'ung bướu',
        'nha khoa', 'nhi khoa', 'nội khoa', 'tiểu đường - nội tiết',
        'tim mạch', 'tiêu hóa', 'sản phụ khoa', 'răng hàm mặt',
        'thần kinh', 'sức khỏe tâm thần', 'vô sinh - hiếm muộn'],
       dtype=object),
 19,
 (53392, 9))

In [16]:
len(df.specialist_name.unique())

19

In [17]:
df.isnull().sum()

partner_id         26
specialist_id      26
status             26
gender              0
province_id         1
age                 0
reason_combind      1
specialist_name     0
age_category        0
dtype: int64

In [18]:
class MedicalSpecialistClassifer(nn.Module):
    def __init__(self, num_specialists: int, model_name: str = "BookingCare/gte-multilingual-base-v2.1", user_feature_dim=2, dropout=0.2, load_pretrained=True, trust_remote_code=False):
        super(MedicalSpecialistClassifer, self).__init__()

        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        if load_pretrained:
            self.reason_encoder = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        else:
            self.reason_encoder = AutoModel.from_config(config, trust_remote_code=True)
        
        for param in self.reason_encoder.parameters():
            param.requires_grad = False
        self.reason_encoder_hidden_dim = self.reason_encoder.config.hidden_size
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        
        self.user_encoder = nn.Sequential(
            nn.Linear(user_feature_dim, 128),
            nn.BatchNorm1d(128),
            self.relu,
            self.dropout,
            nn.Linear(128, 256),
            self.relu
        )

        self.hidden_layer1 = nn.Sequential(
            nn.Linear(self.reason_encoder_hidden_dim, 512),
            nn.BatchNorm1d(512),
            self.relu,
            self.dropout,
        )

        self.level1_output = nn.Linear(512, num_specialists)

        self.hidden_layer2 = nn.Sequential(
            nn.Linear(self.reason_encoder_hidden_dim + 256 + 512, 768),
            nn.BatchNorm1d(768),
            self.relu,
            self.dropout
        )
        self.level2_output = nn.Linear(768, num_specialists)

    def forward(self, reason_text_ids, reason_text_mask, user_info=None):
        reason_outputs = self.reason_encoder(reason_text_ids, attention_mask=reason_text_mask)
        reason_embedding = reason_outputs.last_hidden_state[:, 0, :]

        hidden1 = self.hidden_layer1(reason_embedding)

        level1_logits = self.level1_output(hidden1)

        if user_info is None:
            return level1_logits, None

        user_features = self.user_encoder(user_info)

        combined_features = torch.cat((reason_embedding, hidden1, user_features), dim=1)

        hidden2 = self.hidden_layer2(combined_features)
        level2_logits = self.level2_output(hidden2)

        return level1_logits, level2_logits

    def predict(self, reason_text_ids, reason_text_mask, user_info=None, threshold=0.7):
        self.eval()
        with torch.no_grad():
            level1_logits, level2_logits = self.forward(reason_text_ids, reason_text_mask, user_info)
            level1_probs = F.softmax(level1_logits, dim=1)
            max_probs, preds = torch.max(level1_probs, dim=1)
            final_preds = preds.clone()

            if user_info is not None and level2_logits is not None:
                for i, prob in enumerate(max_probs):
                    final_preds[i] = torch.argmax(level2_logits[i])

            return final_preds

In [19]:
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

class MedicalDataFrameDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = 128
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        reason_text = str(row['reason_combind'])  # Ensure reason_text is a string
        # user_info = torch.tensor([row['gender'], row['age_category']], dtype=torch.float32)
        user_info = None
        label = row['specialist_name']
        encoded = self.tokenizer(
            [reason_text],  # Wrap reason_text in a list to ensure correct input type
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            "reason_text_ids": encoded['input_ids'].squeeze(0),
            "reason_text_mask": encoded['attention_mask'].squeeze(0),
            "user_info": user_info,
            "labels": torch.tensor(label, dtype=torch.long)
        }


In [20]:
from tqdm.notebook import tqdm

def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=5, patience=2, save_path="best_model.pt"):
    model.to(device)
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in tqdm(range(num_epochs)):
        model.train()
        running_loss = 0.0
        correct_level1 = 0
        correct_combined = 0
        total = 0

        for batch in tqdm(train_loader, desc="Training..."):
            reason_text_ids = batch['reason_text_ids'].to(device)
            reason_text_mask = batch['reason_text_mask'].to(device)
            user_info = batch.get('user_info', None)

            if user_info is not None:
                user_info = user_info.to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            level1_logits, level2_logits = model(
                reason_text_ids,
                reason_text_mask,
                user_info
            )

            _, level1_preds = torch.max(level1_logits, dim=1)
            correct_mask = (level1_preds == labels)
            loss = criterion(level1_logits, labels)

            if level2_logits is not None and (~correct_mask).any():
                incorrect_indices = (~correct_mask).nonzero(as_tuple=True)[0]
                level2_loss = criterion(level2_logits[incorrect_indices], labels[incorrect_indices])
                loss += level2_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct_level1 += correct_mask.sum().item()

            if level2_logits is not None:
                _, level2_preds = torch.max(level2_logits, dim=1)
                final_preds = torch.where(correct_mask, level1_preds, level2_preds)
            else:
                final_preds = level1_preds

            correct_combined += (final_preds == labels).sum().item()
            total += labels.size(0)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"Train Loss: {running_loss / len(train_loader): .4f}")
        print(f"Level 1 Accuracy: {100 * correct_level1 / total:.2f}")
        print(f"Final Accuracy (with Level 2 fallback): {100 * correct_combined / total: .2f}%\n")

        # validation
        model.eval()
        val_loss = 0.0
        correct_level1 = 0
        correct_combined = 0
        total = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validate...'):
                reason_text_ids = batch['reason_text_ids'].to(device)
                reason_text_mask = batch['reason_text_mask'].to(device)
                user_info = batch.get('user_info', None)
                if user_info is not None:
                    user_info = user_info.to(device)
                    
                    for p in model.hidden_layer2.parameters():
                        p.requires_grad = True
                    for p in model.level2_output.parameters():
                        p.requires_grad = True
                    ## experiment
                    for p in model.hidden_layer1.parameters():
                        p.requires_grad = False
                    for p in model.level1_output.parameters():
                        p.requires_grad = False
                    
                else:
                    for p in model.hidden_layer2.parameters():
                        p.requires_grad = False
                    for p in model.level2_output.parameters():
                        p.requires_grad = False
                labels = batch['labels'].to(device)

                level1_logits, level2_logits = model(
                    reason_text_ids,
                    reason_text_mask,
                    user_info
                )

                _, level1_preds = torch.max(level1_logits, dim=1)
                correct_mask = (level1_preds == labels)
                loss = criterion(level1_logits, labels)
                if level2_logits is not None and (~correct_mask).any():
                    incorrect_indices = (~correct_mask).nonzero(as_tuple=True)[0]
                    level2_loss = criterion(level2_logits[incorrect_indices], labels[incorrect_indices])
                    loss += level2_loss

                val_loss += loss.item()
                correct_level1 += correct_mask.sum().item()

                if level2_logits is not None:
                    _, level2_preds = torch.max(level2_logits, dim=1)
                    final_preds = torch.where(correct_mask, level1_preds, level2_preds)
                else:
                    final_preds = level1_preds

                correct_combined += (final_preds == labels).sum().item()
                total += labels.size(0)
        
        avg_val_loss = val_loss / len(val_loader)
        print(f"Validation loss: {avg_val_loss:.4f}")
        print(f"Validation level 1 Accuracy: {100 * correct_level1 / total:.2f}%")
        print(f"Validation Final Accuracy (with Level 2 fallback): {100 * correct_combined / total:.2f}%\n")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
            print("Model saved.")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break                

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [22]:
device

device(type='cuda')

In [23]:
tokenizer = AutoTokenizer.from_pretrained('../models/BookingCare/gte-multilingual-base-v2.1/')

In [24]:
specialist_encoder = LabelEncoder()
age_encoder = LabelEncoder()
gender_encoder = LabelEncoder()
df['age_category'] = age_encoder.fit_transform(df['age_category'])
df['gender'] = gender_encoder.fit_transform(df['gender'])
df['specialist_name'] = specialist_encoder.fit_transform(df['specialist_name'])

In [25]:
df

Unnamed: 0,partner_id,specialist_id,status,gender,province_id,age,reason_combind,specialist_name,age_category
0,50.0,29.0,2.0,0,25.0,30,bệnh thiên đầu thống tái phát,0,0
1,50.0,29.0,2.0,1,1.0,11,nghi ngờ glocom góc mở,0,1
2,50.0,29.0,2.0,0,1.0,34,"khám mắt, viêm kết giác mạc",0,0
3,50.0,29.0,2.0,0,1.0,29,khám mắt,0,0
4,50.0,29.0,2.0,0,1.0,45,mắt bị nhòe,0,0
...,...,...,...,...,...,...,...,...,...
53387,23.0,67.0,2.0,0,40.0,22,"buồng trứng đa nang, khám để làm ivf",18,0
53388,23.0,67.0,2.0,1,2.0,26,khám khả năng sinh sản để chuyển bị sinh con,18,0
53389,23.0,67.0,2.0,0,1.0,21,kiểm tra và tư vấn mong con,18,0
53390,23.0,67.0,2.0,0,24.0,21,khám sinh con chọn giới tính,18,0


In [26]:
dataset = MedicalDataFrameDataset(df, tokenizer)

In [27]:
dataset.__getitem__(0)

{'reason_text_ids': tensor([    0,  7417, 39429,  2494, 10657, 75626,  5152,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,    

In [28]:
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [29]:
model = MedicalSpecialistClassifer(model_name = "../models/BookingCare/gte-multilingual-base-v2.1/",num_specialists=len(specialist_encoder.classes_), user_feature_dim=2, load_pretrained=True, trust_remote_code=True)

In [30]:
model

MedicalSpecialistClassifer(
  (reason_encoder): NewModel(
    (embeddings): NewEmbeddings(
      (word_embeddings): Embedding(250048, 768, padding_idx=1)
      (rotary_emb): NTKScalingRotaryEmbedding()
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): NewEncoder(
      (layer): ModuleList(
        (0-11): 12 x NewLayer(
          (attention): NewSdpaAttention(
            (qkv_proj): Linear(in_features=768, out_features=2304, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp): NewGatedMLP(
            (up_gate_proj): Linear(in_features=768, out_features=6144, bias=False)
            (down_proj): Linear(in_features=3072, out_features=768, bias=True)
            (act_fn): GELUActivation()
            (hidden_dropout): Dropout(p=0.1, inp

In [31]:
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5)

In [32]:
weight_class = torch.tensor(weight_class, dtype=torch.float).to(device)

In [33]:
criterion = nn.CrossEntropyLoss(weight=weight_class)

In [None]:
train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=100, patience=3)

  0%|          | 0/100 [00:00<?, ?it/s]

Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 1/100:
Train Loss:  3.4953
Level 1 Accuracy: 61.17
Final Accuracy (with Level 2 fallback):  74.95%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 2.7433
Validation level 1 Accuracy: 69.51%
Validation Final Accuracy (with Level 2 fallback): 82.01%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 2/100:
Train Loss:  2.9307
Level 1 Accuracy: 68.08
Final Accuracy (with Level 2 fallback):  80.20%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 2.5384
Validation level 1 Accuracy: 71.04%
Validation Final Accuracy (with Level 2 fallback): 83.95%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 3/100:
Train Loss:  2.7746
Level 1 Accuracy: 69.00
Final Accuracy (with Level 2 fallback):  81.19%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 2.4208
Validation level 1 Accuracy: 71.71%
Validation Final Accuracy (with Level 2 fallback): 84.64%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 4/100:
Train Loss:  2.6780
Level 1 Accuracy: 69.77
Final Accuracy (with Level 2 fallback):  82.13%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 2.3227
Validation level 1 Accuracy: 72.77%
Validation Final Accuracy (with Level 2 fallback): 85.98%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 5/100:
Train Loss:  2.6183
Level 1 Accuracy: 70.37
Final Accuracy (with Level 2 fallback):  82.79%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 2.2470
Validation level 1 Accuracy: 72.46%
Validation Final Accuracy (with Level 2 fallback): 86.15%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 6/100:
Train Loss:  2.5422
Level 1 Accuracy: 70.58
Final Accuracy (with Level 2 fallback):  83.15%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 2.1656
Validation level 1 Accuracy: 73.52%
Validation Final Accuracy (with Level 2 fallback): 87.17%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 7/100:
Train Loss:  2.4886
Level 1 Accuracy: 71.03
Final Accuracy (with Level 2 fallback):  83.80%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 2.0958
Validation level 1 Accuracy: 73.71%
Validation Final Accuracy (with Level 2 fallback): 87.89%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 8/100:
Train Loss:  2.4265
Level 1 Accuracy: 71.36
Final Accuracy (with Level 2 fallback):  84.25%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 2.0493
Validation level 1 Accuracy: 74.09%
Validation Final Accuracy (with Level 2 fallback): 88.70%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 9/100:
Train Loss:  2.3987
Level 1 Accuracy: 71.35
Final Accuracy (with Level 2 fallback):  84.31%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.9888
Validation level 1 Accuracy: 74.20%
Validation Final Accuracy (with Level 2 fallback): 88.61%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 10/100:
Train Loss:  2.3664
Level 1 Accuracy: 71.75
Final Accuracy (with Level 2 fallback):  84.68%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.9577
Validation level 1 Accuracy: 74.25%
Validation Final Accuracy (with Level 2 fallback): 88.72%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 11/100:
Train Loss:  2.3232
Level 1 Accuracy: 71.81
Final Accuracy (with Level 2 fallback):  84.96%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.9179
Validation level 1 Accuracy: 74.58%
Validation Final Accuracy (with Level 2 fallback): 88.86%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 12/100:
Train Loss:  2.2954
Level 1 Accuracy: 72.18
Final Accuracy (with Level 2 fallback):  85.35%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.8949
Validation level 1 Accuracy: 74.61%
Validation Final Accuracy (with Level 2 fallback): 89.00%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 13/100:
Train Loss:  2.2469
Level 1 Accuracy: 72.03
Final Accuracy (with Level 2 fallback):  85.51%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.8173
Validation level 1 Accuracy: 75.21%
Validation Final Accuracy (with Level 2 fallback): 89.99%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 14/100:
Train Loss:  2.2354
Level 1 Accuracy: 72.29
Final Accuracy (with Level 2 fallback):  85.68%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.8037
Validation level 1 Accuracy: 74.89%
Validation Final Accuracy (with Level 2 fallback): 89.83%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 15/100:
Train Loss:  2.2017
Level 1 Accuracy: 72.45
Final Accuracy (with Level 2 fallback):  86.02%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.7724
Validation level 1 Accuracy: 75.42%
Validation Final Accuracy (with Level 2 fallback): 90.36%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 16/100:
Train Loss:  2.1779
Level 1 Accuracy: 72.60
Final Accuracy (with Level 2 fallback):  86.19%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.7350
Validation level 1 Accuracy: 75.34%
Validation Final Accuracy (with Level 2 fallback): 90.66%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 17/100:
Train Loss:  2.1666
Level 1 Accuracy: 72.57
Final Accuracy (with Level 2 fallback):  86.42%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.7053
Validation level 1 Accuracy: 75.43%
Validation Final Accuracy (with Level 2 fallback): 90.89%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 18/100:
Train Loss:  2.1442
Level 1 Accuracy: 72.77
Final Accuracy (with Level 2 fallback):  86.47%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.6853
Validation level 1 Accuracy: 75.72%
Validation Final Accuracy (with Level 2 fallback): 90.86%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 19/100:
Train Loss:  2.1219
Level 1 Accuracy: 72.90
Final Accuracy (with Level 2 fallback):  86.82%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.6556
Validation level 1 Accuracy: 75.67%
Validation Final Accuracy (with Level 2 fallback): 90.93%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 20/100:
Train Loss:  2.1003
Level 1 Accuracy: 73.13
Final Accuracy (with Level 2 fallback):  87.07%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.6247
Validation level 1 Accuracy: 75.71%
Validation Final Accuracy (with Level 2 fallback): 91.40%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 21/100:
Train Loss:  2.0714
Level 1 Accuracy: 73.01
Final Accuracy (with Level 2 fallback):  87.18%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.6125
Validation level 1 Accuracy: 75.73%
Validation Final Accuracy (with Level 2 fallback): 91.45%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 22/100:
Train Loss:  2.0466
Level 1 Accuracy: 73.28
Final Accuracy (with Level 2 fallback):  87.34%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.5608
Validation level 1 Accuracy: 76.22%
Validation Final Accuracy (with Level 2 fallback): 91.83%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 23/100:
Train Loss:  2.0382
Level 1 Accuracy: 73.17
Final Accuracy (with Level 2 fallback):  87.38%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.5711
Validation level 1 Accuracy: 76.25%
Validation Final Accuracy (with Level 2 fallback): 92.20%



Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 24/100:
Train Loss:  2.0025
Level 1 Accuracy: 73.49
Final Accuracy (with Level 2 fallback):  87.65%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.5263
Validation level 1 Accuracy: 76.46%
Validation Final Accuracy (with Level 2 fallback): 91.98%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

Epoch 25/100:
Train Loss:  2.0159
Level 1 Accuracy: 73.47
Final Accuracy (with Level 2 fallback):  87.66%



Validate...:   0%|          | 0/334 [00:00<?, ?it/s]

Validation loss: 1.5142
Validation level 1 Accuracy: 76.66%
Validation Final Accuracy (with Level 2 fallback): 92.03%

Model saved.


Training...:   0%|          | 0/1669 [00:00<?, ?it/s]

In [35]:
new_model = MedicalSpecialistClassifer(num_specialists=len(specialist_encoder.classes_), user_feature_dim=2, load_pretrained=False, trust_remote_code=True)

In [36]:
import pickle

# Save encoders to pickle files
with open("specialist_encoder.pkl", "wb") as f:
    pickle.dump(specialist_encoder, f)
    
with open("age_encoder.pkl", "wb") as f:
    pickle.dump(age_encoder, f)
    
with open("gender_encoder.pkl", "wb") as f:
    pickle.dump(gender_encoder, f)


In [37]:
len(specialist_encoder.classes_), specialist_encoder.classes_

(19,
 array(['chuyên khoa mắt', 'cơ xương khớp', 'da liễu', 'hô hấp - phổi',
        'nam học', 'nha khoa', 'nhi khoa', 'nội khoa', 'răng hàm mặt',
        'sản phụ khoa', 'sức khỏe tâm thần', 'tai mũi họng', 'thần kinh',
        'thận - tiết niệu', 'tim mạch', 'tiêu hóa',
        'tiểu đường - nội tiết', 'ung bướu', 'vô sinh - hiếm muộn'],
       dtype=object))

In [38]:
new_model

MedicalSpecialistClassifer(
  (reason_encoder): NewModel(
    (embeddings): NewEmbeddings(
      (word_embeddings): Embedding(250048, 768, padding_idx=1)
      (rotary_emb): NTKScalingRotaryEmbedding()
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): NewEncoder(
      (layer): ModuleList(
        (0-11): 12 x NewLayer(
          (attention): NewSdpaAttention(
            (qkv_proj): Linear(in_features=768, out_features=2304, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp): NewGatedMLP(
            (up_gate_proj): Linear(in_features=768, out_features=6144, bias=False)
            (down_proj): Linear(in_features=3072, out_features=768, bias=True)
            (act_fn): GELUActivation()
            (hidden_dropout): Dropout(p=0.1, inp

In [39]:
new_model.load_state_dict(torch.load("../notebooks/best_model.pt"))

<All keys matched successfully>

In [58]:
new_model.eval()
reason_text = "đau bụng dữ dội"
data = tokenizer(
    reason_text,
    padding='max_length',
    truncation=True,
    max_length=128,
    return_tensors='pt'
)
age_category = "child"
gender = "male"

In [59]:
# Fix the tensor shape for user_info
user_info = torch.tensor([[gender_encoder.transform([gender])[0], age_encoder.transform([age_category])[0]]], dtype=torch.float32)

In [60]:
data, user_info

({'input_ids': tensor([[    0, 24859, 92070, 67885,   104, 65166,     2,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,   

In [61]:
res = specialist_encoder.inverse_transform([new_model.predict(data['input_ids'], data['attention_mask'], user_info=None)])
res.tolist()

  y = column_or_1d(y, warn=True)


['tiêu hóa']

In [62]:
res = specialist_encoder.inverse_transform([new_model.predict(data['input_ids'], data['attention_mask'], user_info=user_info)])

  y = column_or_1d(y, warn=True)


In [63]:
res.tolist()

['nhi khoa']