## Set up

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys

sys.path.insert(0, '..')

In [4]:
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig

In [5]:
df = pd.read_csv('../data/data_version 3 - Sheet1(2).csv')
df.head()

Unnamed: 0,partner_id,specialist_id,status,gender,province_id,age,reason_combind,specialist_name
0,23.0,67.0,2.0,1,38.0,25,muốn làm ivf,vô sinh - hiếm muộn
1,23.0,67.0,2.0,1,27.0,26,mong con,vô sinh - hiếm muộn
2,23.0,67.0,2.0,0,27.0,28,hiếm muộn,vô sinh - hiếm muộn
3,23.0,67.0,2.0,1,31.0,26,"khám vô sinh, hiếm muộn",vô sinh - hiếm muộn
4,23.0,67.0,2.0,1,1.0,34,thả hơn 1 năm mà không có thai nên muốn siêu â...,vô sinh - hiếm muộn


In [6]:
df['gender'] = df['gender'].fillna('unknown')
df['gender'] = df['gender'].replace({1.0: 'female', 0.0: 'male'})

In [7]:
df['age'] = df['age'].fillna(0)
df['age_category'] = df['age'].apply(lambda x: 
                                                     'unknown' if x == 0 else 
                                                     'child' if 0 < x <= 15 else 
                                                     'adult')
# Display the results
print(df['age_category'].value_counts())

age_category
adult    46919
child     6481
Name: count, dtype: int64


In [10]:
df.specialist_name.unique()

array(['vô sinh - hiếm muộn', 'ung bướu', 'tim mạch', 'tiêu hóa',
       'tiêu hoá', 'tiểu đường - nội tiết', 'thần kinh', 'thần kin',
       'thận - tiết niệu', 'tai mũi họng', 'sức khỏe tâm thần',
       'sản phụ khoa', 'nội khoa', 'nam học', 'cơ xương khớp',
       'hô hấp - phổi', 'răng hàm mặt', 'nhi khoa', 'nha khoa', 'da liễu',
       'chuyên khoa mắt'], dtype=object)

In [11]:
class_counts = df['specialist_id'].value_counts()
print(f'number of samples per class: {class_counts}')

number of samples per class: specialist_id
1.0     8972
18.0    7332
11.0    5441
22.0    5092
17.0    4551
4.0     4106
27.0    3320
3.0     3167
26.0    2305
19.0    1598
29.0    1589
5.0     1466
15.0    1362
32.0    1096
21.0     761
43.0     650
33.0     313
67.0     253
Name: count, dtype: int64


In [13]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = df['specialist_name']
weight_class = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)

In [14]:
weight_class, len(weight_class)

(array([1.60028769e+00, 2.82759607e-01, 4.67522917e-01, 3.84117393e+00,
        1.11823093e+00, 8.12414423e+00, 1.73455467e+00, 5.60965617e-01,
        6.35714286e+02, 1.60028769e+00, 7.64999140e-01, 6.19604567e-01,
        2.54285714e+03, 3.47052974e-01, 2.25031606e+00, 8.03430377e-01,
        4.98501694e-01, 3.17857143e+02, 3.33708286e+00, 1.86563253e+00,
        1.00508187e+01]),
 21)

In [15]:
type(weight_class[0])

numpy.float64

In [16]:
df.specialist_name.unique(), len(df.specialist_name.unique()), df.shape

(array(['vô sinh - hiếm muộn', 'ung bướu', 'tim mạch', 'tiêu hóa',
        'tiêu hoá', 'tiểu đường - nội tiết', 'thần kinh', 'thần kin',
        'thận - tiết niệu', 'tai mũi họng', 'sức khỏe tâm thần',
        'sản phụ khoa', 'nội khoa', 'nam học', 'cơ xương khớp',
        'hô hấp - phổi', 'răng hàm mặt', 'nhi khoa', 'nha khoa', 'da liễu',
        'chuyên khoa mắt'], dtype=object),
 21,
 (53400, 9))

In [17]:
len(df.specialist_name.unique())

21

In [18]:
df.isnull().sum()

partner_id         26
specialist_id      26
status             26
gender              0
province_id         1
age                 0
reason_combind      1
specialist_name     0
age_category        0
dtype: int64

In [19]:
class MedicalSpecialistClassifer(nn.Module):
    def __init__(self, num_specialists: int, model_name: str = "BookingCare/gte-multilingual-base-v2.1", user_feature_dim=2, dropout=0.2, load_pretrained=True, trust_remote_code=False):
        super(MedicalSpecialistClassifer, self).__init__()

        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        if load_pretrained:
            self.reason_encoder = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        else:
            self.reason_encoder = AutoModel.from_config(config, trust_remote_code=True)
        
        for param in self.reason_encoder.parameters():
            param.requires_grad = False
        self.reason_encoder_hidden_dim = self.reason_encoder.config.hidden_size
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        
        self.user_encoder = nn.Sequential(
            nn.Linear(user_feature_dim, 128),
            nn.BatchNorm1d(128),
            self.relu,
            self.dropout,
            nn.Linear(128, 256),
            self.relu
        )

        self.hidden_layer1 = nn.Sequential(
            nn.Linear(self.reason_encoder_hidden_dim, 512),
            nn.BatchNorm1d(512),
            self.relu,
            self.dropout,
        )

        self.level1_output = nn.Linear(512, num_specialists)

        self.hidden_layer2 = nn.Sequential(
            nn.Linear(self.reason_encoder_hidden_dim + 256 + 512, 768),
            nn.BatchNorm1d(768),
            self.relu,
            self.dropout
        )
        self.level2_output = nn.Linear(768, num_specialists)

    def forward(self, reason_text_ids, reason_text_mask, user_info=None):
        reason_outputs = self.reason_encoder(reason_text_ids, attention_mask=reason_text_mask)
        reason_embedding = reason_outputs.last_hidden_state[:, 0, :]

        hidden1 = self.hidden_layer1(reason_embedding)

        level1_logits = self.level1_output(hidden1)

        if user_info is None:
            return level1_logits, None

        user_features = self.user_encoder(user_info)

        combined_features = torch.cat((reason_embedding, hidden1, user_features), dim=1)

        hidden2 = self.hidden_layer2(combined_features)
        level2_logits = self.level2_output(hidden2)

        return level1_logits, level2_logits

    def predict(self, reason_text_ids, reason_text_mask, user_info=None, threshold=0.7):
        self.eval()
        with torch.no_grad():
            level1_logits, level2_logits = self.forward(reason_text_ids, reason_text_mask, user_info)
            level1_probs = F.softmax(level1_logits, dim=1)
            max_probs, preds = torch.max(level1_probs, dim=1)
            final_preds = preds.clone()

            if user_info is not None and level2_logits is not None:
                for i, prob in enumerate(max_probs):
                    final_preds[i] = torch.argmax(level2_logits[i])

            return final_preds

In [35]:
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

class MedicalDataFrameDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = 128
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        reason_text = str(row['reason_combind'])  # Ensure reason_text is a string
        user_info = torch.tensor([row['gender'], row['age_category']], dtype=torch.float32)
        label = row['specialist_name']
        encoded = self.tokenizer(
            [reason_text],  # Wrap reason_text in a list to ensure correct input type
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            "reason_text_ids": encoded['input_ids'].squeeze(0),
            "reason_text_mask": encoded['attention_mask'].squeeze(0),
            "user_info": user_info,
            "labels": torch.tensor(label, dtype=torch.long)
        }


In [37]:
from tqdm.notebook import tqdm

def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=5, patience=2, save_path="best_model.pt"):
    model.to(device)
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in tqdm(range(num_epochs)):
        model.train()
        running_loss = 0.0
        correct_level1 = 0
        correct_combined = 0
        total = 0

        for batch in tqdm(train_loader, desc="Training..."):
            reason_text_ids = batch['reason_text_ids'].to(device)
            reason_text_mask = batch['reason_text_mask'].to(device)
            user_info = batch.get('user_info', None)

            if user_info is not None:
                user_info = user_info.to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            level1_logits, level2_logits = model(
                reason_text_ids,
                reason_text_mask,
                user_info
            )

            _, level1_preds = torch.max(level1_logits, dim=1)
            correct_mask = (level1_preds == labels)
            loss = criterion(level1_logits, labels)

            if level2_logits is not None and (~correct_mask).any():
                incorrect_indices = (~correct_mask).nonzero(as_tuple=True)[0]
                level2_loss = criterion(level2_logits[incorrect_indices], labels[incorrect_indices])
                loss += level2_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct_level1 += correct_mask.sum().item()

            if level2_logits is not None:
                _, level2_preds = torch.max(level2_logits, dim=1)
                final_preds = torch.where(correct_mask, level1_preds, level2_preds)
            else:
                final_preds = level1_preds

            correct_combined += (final_preds == labels).sum().item()
            total += labels.size(0)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"Train Loss: {running_loss / len(train_loader): .4f}")
        print(f"Level 1 Accuracy: {100 * correct_level1 / total:.2f}")
        print(f"Final Accuracy (with Level 2 fallback): {100 * correct_combined / total: .2f}%\n")

        # validation
        model.eval()
        val_loss = 0.0
        correct_level1 = 0
        correct_combined = 0
        total = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validate...'):
                reason_text_ids = batch['reason_text_ids'].to(device)
                reason_text_mask = batch['reason_text_mask'].to(device)
                user_info = batch.get('user_info', None)
                if user_info is not None:
                    user_info = user_info.to(device)

                    for p in model.hidden_layer2.parameters():
                        p.requires_grad = True
                    for p in model.level2_output.parameters():
                        p.requires_grad = True
                else:
                    for p in model.hidden_layer2.parameters():
                        p.requires_grad = False
                    for p in model.level2_output.parameters():
                        p.requires_grad = False
                labels = batch['labels'].to(device)

                level1_logits, level2_logits = model(
                    reason_text_ids,
                    reason_text_mask,
                    user_info
                )

                _, level1_preds = torch.max(level1_logits, dim=1)
                correct_mask = (level1_preds == labels)
                loss = criterion(level1_logits, labels)
                if level2_logits is not None and (~correct_mask).any():
                    incorrect_indices = (~correct_mask).nonzero(as_tuple=True)[0]
                    level2_loss = criterion(level2_logits[incorrect_indices], labels[incorrect_indices])
                    loss += level2_loss

                val_loss += loss.item()
                correct_level1 += correct_mask.sum().item()

                if level2_logits is not None:
                    _, level2_preds = torch.max(level2_logits, dim=1)
                    final_preds = torch.where(correct_mask, level1_preds, level2_preds)
                else:
                    final_preds = level1_preds

                correct_combined += (final_preds == labels).sum().item()
                total += labels.size(0)
        
        avg_val_loss = val_loss / len(val_loader)
        print(f"Validation loss: {avg_val_loss:.4f}")
        print(f"Validation level 1 Accuracy: {100 * correct_level1 / total:.2f}%")
        print(f"Validation Final Accuracy (with Level 2 fallback): {100 * correct_combined / total:.2f}%\n")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
            print("Model saved.")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break                

In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [39]:
device

device(type='cuda')

In [40]:
tokenizer = AutoTokenizer.from_pretrained('../models/BookingCare/gte-multilingual-base-v2.1/')

In [41]:
specialist_encoder = LabelEncoder()
age_encoder = LabelEncoder()
gender_encoder = LabelEncoder()
df['age_category'] = age_encoder.fit_transform(df['age_category'])
df['gender'] = gender_encoder.fit_transform(df['gender'])
df['specialist_name'] = specialist_encoder.fit_transform(df['specialist_name'])

In [42]:
dataset = MedicalDataFrameDataset(df, tokenizer)

In [43]:
dataset.__getitem__(0)

{'reason_text_ids': tensor([   0, 6542, 1839,   17,  334,  420,    2,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1]),
 'reason_text_mask': tensor([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0,

In [44]:
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [45]:
model = MedicalSpecialistClassifer(model_name = "../models/BookingCare/gte-multilingual-base-v2.1/",num_specialists=len(specialist_encoder.classes_), user_feature_dim=2, load_pretrained=True, trust_remote_code=True)

In [46]:
model

MedicalSpecialistClassifer(
  (reason_encoder): NewModel(
    (embeddings): NewEmbeddings(
      (word_embeddings): Embedding(250048, 768, padding_idx=1)
      (rotary_emb): NTKScalingRotaryEmbedding()
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): NewEncoder(
      (layer): ModuleList(
        (0-11): 12 x NewLayer(
          (attention): NewSdpaAttention(
            (qkv_proj): Linear(in_features=768, out_features=2304, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp): NewGatedMLP(
            (up_gate_proj): Linear(in_features=768, out_features=6144, bias=False)
            (down_proj): Linear(in_features=3072, out_features=768, bias=True)
            (act_fn): GELUActivation()
            (hidden_dropout): Dropout(p=0.1, inp

In [47]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

In [48]:
weight_class = torch.tensor(weight_class, dtype=torch.float).to(device)

  weight_class = torch.tensor(weight_class, dtype=torch.float).to(device)


In [49]:
criterion = nn.CrossEntropyLoss(weight=weight_class)

In [50]:
train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=30, patience=3)

  0%|          | 0/30 [00:00<?, ?it/s]

Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 1/30:
Train Loss:  3.7751
Level 1 Accuracy: 59.21
Final Accuracy (with Level 2 fallback):  73.04%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.9702
Validation level 1 Accuracy: 68.07%
Validation Final Accuracy (with Level 2 fallback): 80.94%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 2/30:
Train Loss:  3.1013
Level 1 Accuracy: 66.98
Final Accuracy (with Level 2 fallback):  79.34%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.7376
Validation level 1 Accuracy: 69.95%
Validation Final Accuracy (with Level 2 fallback): 83.35%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 3/30:
Train Loss:  2.9503
Level 1 Accuracy: 67.96
Final Accuracy (with Level 2 fallback):  80.36%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.6078
Validation level 1 Accuracy: 70.30%
Validation Final Accuracy (with Level 2 fallback): 84.03%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 4/30:
Train Loss:  2.8594
Level 1 Accuracy: 68.61
Final Accuracy (with Level 2 fallback):  81.00%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.5092
Validation level 1 Accuracy: 70.66%
Validation Final Accuracy (with Level 2 fallback): 84.55%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 5/30:
Train Loss:  2.7822
Level 1 Accuracy: 69.18
Final Accuracy (with Level 2 fallback):  81.73%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.4241
Validation level 1 Accuracy: 71.24%
Validation Final Accuracy (with Level 2 fallback): 85.69%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 6/30:
Train Loss:  2.7178
Level 1 Accuracy: 69.41
Final Accuracy (with Level 2 fallback):  82.11%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.3562
Validation level 1 Accuracy: 71.39%
Validation Final Accuracy (with Level 2 fallback): 85.97%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 7/30:
Train Loss:  2.6699
Level 1 Accuracy: 69.72
Final Accuracy (with Level 2 fallback):  82.92%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.3120
Validation level 1 Accuracy: 71.60%
Validation Final Accuracy (with Level 2 fallback): 85.93%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 8/30:
Train Loss:  2.6222
Level 1 Accuracy: 70.07
Final Accuracy (with Level 2 fallback):  83.14%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.2488
Validation level 1 Accuracy: 72.25%
Validation Final Accuracy (with Level 2 fallback): 87.27%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 9/30:
Train Loss:  2.5815
Level 1 Accuracy: 70.38
Final Accuracy (with Level 2 fallback):  83.62%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.2146
Validation level 1 Accuracy: 72.49%
Validation Final Accuracy (with Level 2 fallback): 87.57%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 10/30:
Train Loss:  2.5730
Level 1 Accuracy: 70.34
Final Accuracy (with Level 2 fallback):  83.69%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.1528
Validation level 1 Accuracy: 72.84%
Validation Final Accuracy (with Level 2 fallback): 87.98%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 11/30:
Train Loss:  2.5203
Level 1 Accuracy: 70.90
Final Accuracy (with Level 2 fallback):  84.15%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.1159
Validation level 1 Accuracy: 72.94%
Validation Final Accuracy (with Level 2 fallback): 87.97%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 12/30:
Train Loss:  2.4852
Level 1 Accuracy: 70.92
Final Accuracy (with Level 2 fallback):  84.33%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.0893
Validation level 1 Accuracy: 73.36%
Validation Final Accuracy (with Level 2 fallback): 88.49%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 13/30:
Train Loss:  2.4626
Level 1 Accuracy: 71.33
Final Accuracy (with Level 2 fallback):  84.52%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.0442
Validation level 1 Accuracy: 73.56%
Validation Final Accuracy (with Level 2 fallback): 88.72%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 14/30:
Train Loss:  2.4365
Level 1 Accuracy: 71.10
Final Accuracy (with Level 2 fallback):  84.80%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 2.0176
Validation level 1 Accuracy: 73.56%
Validation Final Accuracy (with Level 2 fallback): 89.00%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 15/30:
Train Loss:  2.4085
Level 1 Accuracy: 71.46
Final Accuracy (with Level 2 fallback):  85.01%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.9858
Validation level 1 Accuracy: 74.12%
Validation Final Accuracy (with Level 2 fallback): 89.19%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 16/30:
Train Loss:  2.3903
Level 1 Accuracy: 71.49
Final Accuracy (with Level 2 fallback):  84.93%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.9598
Validation level 1 Accuracy: 73.93%
Validation Final Accuracy (with Level 2 fallback): 89.49%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 17/30:
Train Loss:  2.3609
Level 1 Accuracy: 71.70
Final Accuracy (with Level 2 fallback):  85.36%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.9391
Validation level 1 Accuracy: 74.28%
Validation Final Accuracy (with Level 2 fallback): 89.32%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 18/30:
Train Loss:  2.3423
Level 1 Accuracy: 71.62
Final Accuracy (with Level 2 fallback):  85.43%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.9043
Validation level 1 Accuracy: 74.33%
Validation Final Accuracy (with Level 2 fallback): 89.90%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 19/30:
Train Loss:  2.3279
Level 1 Accuracy: 71.88
Final Accuracy (with Level 2 fallback):  85.77%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.8882
Validation level 1 Accuracy: 74.08%
Validation Final Accuracy (with Level 2 fallback): 89.62%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 20/30:
Train Loss:  2.3105
Level 1 Accuracy: 71.84
Final Accuracy (with Level 2 fallback):  85.74%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.8729
Validation level 1 Accuracy: 74.14%
Validation Final Accuracy (with Level 2 fallback): 90.02%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 21/30:
Train Loss:  2.2930
Level 1 Accuracy: 72.09
Final Accuracy (with Level 2 fallback):  86.04%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.8149
Validation level 1 Accuracy: 74.79%
Validation Final Accuracy (with Level 2 fallback): 90.80%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 22/30:
Train Loss:  2.2754
Level 1 Accuracy: 71.97
Final Accuracy (with Level 2 fallback):  86.01%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.7923
Validation level 1 Accuracy: 74.81%
Validation Final Accuracy (with Level 2 fallback): 90.81%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 23/30:
Train Loss:  2.2657
Level 1 Accuracy: 72.10
Final Accuracy (with Level 2 fallback):  86.20%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.7769
Validation level 1 Accuracy: 75.08%
Validation Final Accuracy (with Level 2 fallback): 91.01%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 24/30:
Train Loss:  2.2285
Level 1 Accuracy: 72.28
Final Accuracy (with Level 2 fallback):  86.40%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.7664
Validation level 1 Accuracy: 75.13%
Validation Final Accuracy (with Level 2 fallback): 90.89%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 25/30:
Train Loss:  2.2131
Level 1 Accuracy: 72.33
Final Accuracy (with Level 2 fallback):  86.69%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.7468
Validation level 1 Accuracy: 75.33%
Validation Final Accuracy (with Level 2 fallback): 90.96%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 26/30:
Train Loss:  2.2101
Level 1 Accuracy: 72.38
Final Accuracy (with Level 2 fallback):  86.46%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.7274
Validation level 1 Accuracy: 75.47%
Validation Final Accuracy (with Level 2 fallback): 91.22%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 27/30:
Train Loss:  2.2019
Level 1 Accuracy: 72.61
Final Accuracy (with Level 2 fallback):  86.74%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.7228
Validation level 1 Accuracy: 75.21%
Validation Final Accuracy (with Level 2 fallback): 90.95%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 28/30:
Train Loss:  2.1725
Level 1 Accuracy: 72.62
Final Accuracy (with Level 2 fallback):  86.98%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.6835
Validation level 1 Accuracy: 75.67%
Validation Final Accuracy (with Level 2 fallback): 91.62%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 29/30:
Train Loss:  2.1681
Level 1 Accuracy: 72.70
Final Accuracy (with Level 2 fallback):  87.03%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.6803
Validation level 1 Accuracy: 75.62%
Validation Final Accuracy (with Level 2 fallback): 91.78%

Model saved.


Training...:   0%|          | 0/3338 [00:00<?, ?it/s]

Epoch 30/30:
Train Loss:  2.1432
Level 1 Accuracy: 72.60
Final Accuracy (with Level 2 fallback):  87.00%



Validate...:   0%|          | 0/668 [00:00<?, ?it/s]

Validation loss: 1.6454
Validation level 1 Accuracy: 75.78%
Validation Final Accuracy (with Level 2 fallback): 91.91%

Model saved.


In [51]:
new_model = MedicalSpecialistClassifer(num_specialists=len(specialist_encoder.classes_), user_feature_dim=2, load_pretrained=False, trust_remote_code=True)

In [26]:
import pickle

# Save encoders to pickle files
with open("specialist_encoder.pkl", "wb") as f:
    pickle.dump(specialist_encoder, f)
    
with open("age_encoder.pkl", "wb") as f:
    pickle.dump(age_encoder, f)
    
with open("gender_encoder.pkl", "wb") as f:
    pickle.dump(gender_encoder, f)


In [27]:
len(specialist_encoder.classes_)

18

In [70]:
new_model

MedicalSpecialistClassifer(
  (reason_encoder): NewModel(
    (embeddings): NewEmbeddings(
      (word_embeddings): Embedding(250048, 768, padding_idx=1)
      (rotary_emb): NTKScalingRotaryEmbedding()
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): NewEncoder(
      (layer): ModuleList(
        (0-11): 12 x NewLayer(
          (attention): NewSdpaAttention(
            (qkv_proj): Linear(in_features=768, out_features=2304, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp): NewGatedMLP(
            (up_gate_proj): Linear(in_features=768, out_features=6144, bias=False)
            (down_proj): Linear(in_features=3072, out_features=768, bias=True)
            (act_fn): GELUActivation()
            (hidden_dropout): Dropout(p=0.1, inp

In [71]:
new_model.load_state_dict(torch.load("../notebooks/best_model.pt"))

<All keys matched successfully>

In [72]:
new_model.eval()
reason_text = "lười ăn, không tăng cân, constipation"
data = tokenizer(
    reason_text,
    padding='max_length',
    truncation=True,
    max_length=128,
    return_tensors='pt'
)
age_category = "child"
gender = "male"

# Fix the tensor shape for user_info
user_info = torch.tensor([[gender_encoder.transform([gender])[0], age_encoder.transform([age_category])[0]]], dtype=torch.float32)

In [73]:
data, user_info

({'input_ids': tensor([[     0,     96, 150365,   6687,      4,    687,  11122,  24376,      4,
             158,      7,  30019,   2320,      2,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
  

In [74]:
new_model.predict(data['input_ids'], data['attention_mask'], user_info=user_info)

tensor([5])

In [78]:
res = specialist_encoder.inverse_transform([new_model.predict(data['input_ids'], data['attention_mask'], user_info)])

  y = column_or_1d(y, warn=True)


In [80]:
res.tolist()

['nhi khoa']