## Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.insert(0, '..')

In [45]:
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig

In [4]:
df = pd.read_csv('../data/data_version1/combined_data.csv')
df.head()

Unnamed: 0,gender,age,reason_combind,specialist_name,source_file,age_category
0,unknown,57,mất ngủ,thần kinh,reason_specialist - thần kinh.csv,adult
1,unknown,35,rối loạn thần kinh thực vật,thần kinh,reason_specialist - thần kinh.csv,adult
2,unknown,36,đau đầu,thần kinh,reason_specialist - thần kinh.csv,adult
3,unknown,40,"đau đầu,đau sau ngực gần phổi",thần kinh,reason_specialist - thần kinh.csv,adult
4,unknown,12,co giật 3 lần,thần kinh,reason_specialist - thần kinh.csv,child


In [5]:
df.columns, df.gender.unique(), df.age_category.unique(), df.specialist_name.unique()

(Index(['gender', 'age', 'reason_combind', 'specialist_name', 'source_file',
        'age_category'],
       dtype='object'),
 array(['unknown', 'male', 'female'], dtype=object),
 array(['adult', 'child', 'unknown'], dtype=object),
 array(['thần kinh', 'vô sinh - hiếm muộn', 'nhi khoa', 'thận - tiết niệu',
        'ung bướu', 'hô hấp - phổi', 'chuyên khoa mắt', 'cơ xương khớp',
        'tim mạch', 'tiêu hoá', 'sức khỏe tâm thần', 'nội khoa',
        'tiểu đường - nội tiết', 'tai mũi họng', 'nam học', 'da liễu',
        'sản phụ khoa'], dtype=object))

In [6]:
# Extract data for model training
# Take 1000 examples per specialist_name
specialist_samples = {}

# Get list of unique specialists
specialists = df['specialist_name'].unique()

# Sample 1000 examples per specialist (or all if fewer than 1000)
balanced_df = pd.DataFrame()
for specialist in specialists:
    specialist_df = df[df['specialist_name'] == specialist]
    if len(specialist_df) > 1000:
        sampled_df = specialist_df.sample(300, random_state=42)
    else:
        sampled_df = specialist_df
    balanced_df = pd.concat([balanced_df, sampled_df])

# Reset index of the final balanced dataset
balanced_df = balanced_df.reset_index(drop=True)
print(f"Total samples in balanced dataset: {len(balanced_df)}")
print(balanced_df['specialist_name'].value_counts())

Total samples in balanced dataset: 5864
specialist_name
tiểu đường - nội tiết    761
hô hấp - phổi            650
thần kinh                300
ung bướu                 300
nhi khoa                 300
cơ xương khớp            300
chuyên khoa mắt          300
tim mạch                 300
thận - tiết niệu         300
tai mũi họng             300
tiêu hoá                 300
sức khỏe tâm thần        300
nội khoa                 300
da liễu                  300
nam học                  300
sản phụ khoa             300
vô sinh - hiếm muộn      253
Name: count, dtype: int64


In [7]:
balanced_df

Unnamed: 0,gender,age,reason_combind,specialist_name,source_file,age_category
0,female,73,"tay chân run lâu ngày, đau đầu",thần kinh,reason_specialist - thần kinh.csv,adult
1,female,57,"đau đầu, rối loạn tiền đình",thần kinh,reason_specialist - thần kinh.csv,adult
2,female,0,bị tai nạn mờ mắt,thần kinh,reason_specialist - thần kinh.csv,unknown
3,female,0,"u tuyến yên 9,2*9,7mm(bv tim và bv phụ sản tru...",thần kinh,reason_specialist - thần kinh.csv,unknown
4,female,44,"nghi ngờ tuyến yên, đau đầu",thần kinh,reason_specialist - thần kinh.csv,adult
...,...,...,...,...,...,...
5859,female,29,khám thai,sản phụ khoa,reason_specialist - sản phụ khoa.csv,adult
5860,female,49,kiem tra dinh ky vu va phu khoa,sản phụ khoa,reason_specialist - sản phụ khoa.csv,adult
5861,female,32,"ra nhiều khí hư màu vàng, ngứa, có mùi. ngoài ...",sản phụ khoa,reason_specialist - sản phụ khoa.csv,adult
5862,female,29,"kinh nguyệt không đều, chảy máu ngoài chu kỳ k...",sản phụ khoa,reason_specialist - sản phụ khoa.csv,adult


In [68]:
class MedicalSpecialistClassifer(nn.Module):
    def __init__(self, num_specialists: int, model_name: str = "BookingCare/gte-multilingual-base-v2.1", user_feature_dim=2, dropout=0.2, load_pretrained=True, trust_remote_code=False):
        super(MedicalSpecialistClassifer, self).__init__()

        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        if load_pretrained:
            self.reason_encoder = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        else:
            self.reason_encoder = AutoModel.from_config(config, trust_remote_code=True)
        
        for param in self.reason_encoder.parameters():
            param.requires_grad = False
        self.reason_encoder_hidden_dim = self.reason_encoder.config.hidden_size
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        
        self.user_encoder = nn.Sequential(
            nn.Linear(user_feature_dim, 128),
            nn.BatchNorm1d(128),
            self.relu,
            self.dropout,
            nn.Linear(128, 256),
            self.relu
        )

        self.hidden_layer1 = nn.Sequential(
            nn.Linear(self.reason_encoder_hidden_dim, 512),
            nn.BatchNorm1d(512),
            self.relu,
            self.dropout,
        )

        self.level1_output = nn.Linear(512, num_specialists)

        self.hidden_layer2 = nn.Sequential(
            nn.Linear(self.reason_encoder_hidden_dim + 256 + 512, 768),
            nn.BatchNorm1d(768),
            self.relu,
            self.dropout
        )
        self.level2_output = nn.Linear(768, num_specialists)

    def forward(self, reason_text_ids, reason_text_mask, user_info=None):
        reason_outputs = self.reason_encoder(reason_text_ids, attention_mask=reason_text_mask)
        reason_embedding = reason_outputs.last_hidden_state[:, 0, :]

        hidden1 = self.hidden_layer1(reason_embedding)

        level1_logits = self.level1_output(hidden1)

        if user_info is None:
            return level1_logits, None

        user_features = self.user_encoder(user_info)

        combined_features = torch.cat((reason_embedding, hidden1, user_features), dim=1)

        hidden2 = self.hidden_layer2(combined_features)
        level2_logits = self.level2_output(hidden2)

        return level1_logits, level2_logits

    def predict(self, reason_text_ids, reason_text_mask, user_info=None, threshold=0.7):
        self.eval()
        with torch.no_grad():
            level1_logits, level2_logits = self.forward(reason_text_ids, reason_text_mask, user_info)
            level1_probs = F.softmax(level1_logits, dim=1)
            max_probs, preds = torch.max(level1_probs, dim=1)
            final_preds = preds.clone()

            if user_info is not None and level2_logits is not None:
                for i, prob in enumerate(max_probs):
                    final_preds[i] = torch.argmax(level2_logits[i])

            return final_preds

In [9]:
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

class MedicalDataFrameDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = 128
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        reason_text = row['reason_combind']
        user_info = torch.tensor([row['gender'], row['age_category']], dtype=torch.float32)
        label = row['specialist_name']
        encoded = self.tokenizer(
            reason_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            "reason_text_ids": encoded['input_ids'].squeeze(0),
            "reason_text_mask": encoded['attention_mask'].squeeze(0),
            "user_info": user_info,
            "labels": torch.tensor(label, dtype=torch.long)
        }

In [10]:
from tqdm.notebook import tqdm

def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=5, patience=2, save_path="best_model.pt"):
    model.to(device)
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in tqdm(range(num_epochs)):
        model.train()
        running_loss = 0.0
        correct_level1 = 0
        correct_combined = 0
        total = 0

        for batch in tqdm(train_loader, desc="Training..."):
            reason_text_ids = batch['reason_text_ids'].to(device)
            reason_text_mask = batch['reason_text_mask'].to(device)
            user_info = batch.get('user_info', None)

            if user_info is not None:
                user_info = user_info.to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            level1_logits, level2_logits = model(
                reason_text_ids,
                reason_text_mask,
                user_info
            )

            _, level1_preds = torch.max(level1_logits, dim=1)
            correct_mask = (level1_preds == labels)
            loss = criterion(level1_logits, labels)

            if level2_logits is not None and (~correct_mask).any():
                incorrect_indices = (~correct_mask).nonzero(as_tuple=True)[0]
                level2_loss = criterion(level2_logits[incorrect_indices], labels[incorrect_indices])
                loss += level2_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct_level1 += correct_mask.sum().item()

            if level2_logits is not None:
                _, level2_preds = torch.max(level2_logits, dim=1)
                final_preds = torch.where(correct_mask, level1_preds, level2_preds)
            else:
                final_preds = level1_preds

            correct_combined += (final_preds == labels).sum().item()
            total += labels.size(0)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"Train Loss: {running_loss / len(train_loader): .4f}")
        print(f"Level 1 Accuracy: {100 * correct_level1 / total:.2f}")
        print(f"Final Accuracy (with Level 2 fallback): {100 * correct_combined / total: .2f}%\n")

        # validation
        model.eval()
        val_loss = 0.0
        correct_level1 = 0
        correct_combined = 0
        total = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validate...'):
                reason_text_ids = batch['reason_text_ids'].to(device)
                reason_text_mask = batch['reason_text_mask'].to(device)
                user_info = batch.get('user_info', None)
                if user_info is not None:
                    user_info = user_info.to(device)

                    for p in model.hidden_layer2.parameters():
                        p.requires_grad = True
                    for p in model.level2_output.parameters():
                        p.requires_grad = True
                else:
                    for p in model.hidden_layer2.parameters():
                        p.requires_grad = False
                    for p in model.level2_output.parameters():
                        p.requires_grad = False
                labels = batch['labels'].to(device)

                level1_logits, level2_logits = model(
                    reason_text_ids,
                    reason_text_mask,
                    user_info
                )

                _, level1_preds = torch.max(level1_logits, dim=1)
                correct_mask = (level1_preds == labels)
                loss = criterion(level1_logits, labels)
                if level2_logits is not None and (~correct_mask).any():
                    incorrect_indices = (~correct_mask).nonzero(as_tuple=True)[0]
                    level2_loss = criterion(level2_logits[incorrect_indices], labels[incorrect_indices])
                    loss += level2_loss

                val_loss += loss.item()
                correct_level1 += correct_mask.sum().item()

                if level2_logits is not None:
                    _, level2_preds = torch.max(level2_logits, dim=1)
                    final_preds = torch.where(correct_mask, level1_preds, level2_preds)
                else:
                    final_preds = level1_preds

                correct_combined += (final_preds == labels).sum().item()
                total += labels.size(0)
        
        avg_val_loss = val_loss / len(val_loader)
        print(f"Validation loss: {avg_val_loss:.4f}")
        print(f"Validation level 1 Accuracy: {100 * correct_level1 / total:.2f}%")
        print(f"Validation Final Accuracy (with Level 2 fallback): {100 * correct_combined / total:.2f}%\n")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
            print("Model saved.")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break                

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
device

device(type='cuda')

In [13]:
tokenizer = AutoTokenizer.from_pretrained('../models/gte/')

In [14]:
specialist_encoder = LabelEncoder()
age_encoder = LabelEncoder()
gender_encoder = LabelEncoder()
balanced_df['age_category'] = age_encoder.fit_transform(balanced_df['age_category'])
balanced_df['gender'] = gender_encoder.fit_transform(balanced_df['gender'])
balanced_df['specialist_name'] = specialist_encoder.fit_transform(balanced_df['specialist_name'])

In [15]:
dataset = MedicalDataFrameDataset(balanced_df, tokenizer)

In [16]:
dataset.__getitem__(0)

{'reason_text_ids': tensor([    0,  6329, 18973, 11675, 25825,  3063,     4, 24859,  2494,     2,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,    

In [17]:
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [18]:
model = MedicalSpecialistClassifer(model_name = "../models/gte/",num_specialists=len(specialist_encoder.classes_), user_feature_dim=2)

In [19]:
model

MedicalSpecialistClassifer(
  (reason_encoder): NewModel(
    (embeddings): NewEmbeddings(
      (word_embeddings): Embedding(250048, 768, padding_idx=1)
      (rotary_emb): NTKScalingRotaryEmbedding()
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): NewEncoder(
      (layer): ModuleList(
        (0-11): 12 x NewLayer(
          (attention): NewSdpaAttention(
            (qkv_proj): Linear(in_features=768, out_features=2304, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp): NewGatedMLP(
            (up_gate_proj): Linear(in_features=768, out_features=6144, bias=False)
            (down_proj): Linear(in_features=3072, out_features=768, bias=True)
            (act_fn): GELUActivation()
            (hidden_dropout): Dropout(p=0.1, inp

In [20]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

In [21]:
criterion = nn.CrossEntropyLoss()

In [22]:
train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=30, patience=3)

  0%|          | 0/30 [00:00<?, ?it/s]

Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 1/30:
Train Loss:  4.7374
Level 1 Accuracy: 33.34
Final Accuracy (with Level 2 fallback):  54.84%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 4.0079
Validation level 1 Accuracy: 56.86%
Validation Final Accuracy (with Level 2 fallback): 72.55%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 2/30:
Train Loss:  3.8877
Level 1 Accuracy: 56.11
Final Accuracy (with Level 2 fallback):  72.12%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 3.4837
Validation level 1 Accuracy: 62.83%
Validation Final Accuracy (with Level 2 fallback): 78.26%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 3/30:
Train Loss:  3.5952
Level 1 Accuracy: 61.41
Final Accuracy (with Level 2 fallback):  75.27%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 3.2384
Validation level 1 Accuracy: 65.81%
Validation Final Accuracy (with Level 2 fallback): 79.88%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 4/30:
Train Loss:  3.4004
Level 1 Accuracy: 63.78
Final Accuracy (with Level 2 fallback):  77.20%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 3.0541
Validation level 1 Accuracy: 66.58%
Validation Final Accuracy (with Level 2 fallback): 80.90%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 5/30:
Train Loss:  3.2556
Level 1 Accuracy: 64.43
Final Accuracy (with Level 2 fallback):  77.76%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.9233
Validation level 1 Accuracy: 67.01%
Validation Final Accuracy (with Level 2 fallback): 82.78%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 6/30:
Train Loss:  3.1808
Level 1 Accuracy: 64.63
Final Accuracy (with Level 2 fallback):  78.27%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.8110
Validation level 1 Accuracy: 68.12%
Validation Final Accuracy (with Level 2 fallback): 82.61%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 7/30:
Train Loss:  3.0775
Level 1 Accuracy: 66.05
Final Accuracy (with Level 2 fallback):  79.67%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.7269
Validation level 1 Accuracy: 69.65%
Validation Final Accuracy (with Level 2 fallback): 84.57%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 8/30:
Train Loss:  3.0175
Level 1 Accuracy: 66.56
Final Accuracy (with Level 2 fallback):  80.71%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.6276
Validation level 1 Accuracy: 69.74%
Validation Final Accuracy (with Level 2 fallback): 86.02%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 9/30:
Train Loss:  2.9254
Level 1 Accuracy: 67.19
Final Accuracy (with Level 2 fallback):  81.21%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.5484
Validation level 1 Accuracy: 69.65%
Validation Final Accuracy (with Level 2 fallback): 85.93%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 10/30:
Train Loss:  2.9076
Level 1 Accuracy: 66.73
Final Accuracy (with Level 2 fallback):  81.36%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.4932
Validation level 1 Accuracy: 70.08%
Validation Final Accuracy (with Level 2 fallback): 87.04%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 11/30:
Train Loss:  2.8206
Level 1 Accuracy: 67.85
Final Accuracy (with Level 2 fallback):  82.44%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.4397
Validation level 1 Accuracy: 70.76%
Validation Final Accuracy (with Level 2 fallback): 87.72%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 12/30:
Train Loss:  2.7705
Level 1 Accuracy: 67.70
Final Accuracy (with Level 2 fallback):  83.27%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.3724
Validation level 1 Accuracy: 71.10%
Validation Final Accuracy (with Level 2 fallback): 88.24%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 13/30:
Train Loss:  2.7072
Level 1 Accuracy: 67.77
Final Accuracy (with Level 2 fallback):  83.25%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.3138
Validation level 1 Accuracy: 71.01%
Validation Final Accuracy (with Level 2 fallback): 89.26%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 14/30:
Train Loss:  2.6743
Level 1 Accuracy: 68.37
Final Accuracy (with Level 2 fallback):  84.06%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.2554
Validation level 1 Accuracy: 71.44%
Validation Final Accuracy (with Level 2 fallback): 89.77%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 15/30:
Train Loss:  2.6853
Level 1 Accuracy: 68.40
Final Accuracy (with Level 2 fallback):  83.99%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.2108
Validation level 1 Accuracy: 71.44%
Validation Final Accuracy (with Level 2 fallback): 90.79%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 16/30:
Train Loss:  2.6309
Level 1 Accuracy: 68.83
Final Accuracy (with Level 2 fallback):  84.07%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.1716
Validation level 1 Accuracy: 71.95%
Validation Final Accuracy (with Level 2 fallback): 91.39%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 17/30:
Train Loss:  2.5705
Level 1 Accuracy: 69.25
Final Accuracy (with Level 2 fallback):  85.20%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.1281
Validation level 1 Accuracy: 72.29%
Validation Final Accuracy (with Level 2 fallback): 91.05%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 18/30:
Train Loss:  2.5535
Level 1 Accuracy: 69.29
Final Accuracy (with Level 2 fallback):  85.59%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.0819
Validation level 1 Accuracy: 71.87%
Validation Final Accuracy (with Level 2 fallback): 91.30%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 19/30:
Train Loss:  2.5480
Level 1 Accuracy: 69.15
Final Accuracy (with Level 2 fallback):  85.52%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 2.0201
Validation level 1 Accuracy: 72.55%
Validation Final Accuracy (with Level 2 fallback): 92.92%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 20/30:
Train Loss:  2.4663
Level 1 Accuracy: 69.80
Final Accuracy (with Level 2 fallback):  86.15%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.9872
Validation level 1 Accuracy: 72.89%
Validation Final Accuracy (with Level 2 fallback): 93.01%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 21/30:
Train Loss:  2.4576
Level 1 Accuracy: 69.97
Final Accuracy (with Level 2 fallback):  86.53%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.9511
Validation level 1 Accuracy: 72.89%
Validation Final Accuracy (with Level 2 fallback): 93.27%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 22/30:
Train Loss:  2.4266
Level 1 Accuracy: 69.49
Final Accuracy (with Level 2 fallback):  86.26%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.9194
Validation level 1 Accuracy: 73.06%
Validation Final Accuracy (with Level 2 fallback): 92.75%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 23/30:
Train Loss:  2.3813
Level 1 Accuracy: 70.53
Final Accuracy (with Level 2 fallback):  87.62%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.8942
Validation level 1 Accuracy: 73.06%
Validation Final Accuracy (with Level 2 fallback): 93.52%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 24/30:
Train Loss:  2.3636
Level 1 Accuracy: 69.97
Final Accuracy (with Level 2 fallback):  87.43%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.8475
Validation level 1 Accuracy: 73.23%
Validation Final Accuracy (with Level 2 fallback): 94.03%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 25/30:
Train Loss:  2.3461
Level 1 Accuracy: 70.41
Final Accuracy (with Level 2 fallback):  87.76%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.7969
Validation level 1 Accuracy: 73.57%
Validation Final Accuracy (with Level 2 fallback): 94.29%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 26/30:
Train Loss:  2.3041
Level 1 Accuracy: 71.06
Final Accuracy (with Level 2 fallback):  88.32%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.7560
Validation level 1 Accuracy: 73.91%
Validation Final Accuracy (with Level 2 fallback): 94.54%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 27/30:
Train Loss:  2.2685
Level 1 Accuracy: 71.04
Final Accuracy (with Level 2 fallback):  88.47%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.7222
Validation level 1 Accuracy: 74.60%
Validation Final Accuracy (with Level 2 fallback): 95.06%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 28/30:
Train Loss:  2.2585
Level 1 Accuracy: 71.09
Final Accuracy (with Level 2 fallback):  88.28%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.6924
Validation level 1 Accuracy: 74.17%
Validation Final Accuracy (with Level 2 fallback): 94.80%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 29/30:
Train Loss:  2.2488
Level 1 Accuracy: 71.66
Final Accuracy (with Level 2 fallback):  88.85%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.6410
Validation level 1 Accuracy: 74.42%
Validation Final Accuracy (with Level 2 fallback): 95.74%

Model saved.


Training...:   0%|          | 0/367 [00:00<?, ?it/s]

Epoch 30/30:
Train Loss:  2.1954
Level 1 Accuracy: 71.73
Final Accuracy (with Level 2 fallback):  89.34%



Validate...:   0%|          | 0/74 [00:00<?, ?it/s]

Validation loss: 1.6189
Validation level 1 Accuracy: 75.45%
Validation Final Accuracy (with Level 2 fallback): 96.42%

Model saved.


In [69]:
new_model = MedicalSpecialistClassifer(num_specialists=len(specialist_encoder.classes_), user_feature_dim=2, load_pretrained=False, trust_remote_code=True)

In [77]:
import pickle

# Save encoders to pickle files
with open("specialist_encoder.pkl", "wb") as f:
    pickle.dump(specialist_encoder, f)
    
with open("age_encoder.pkl", "wb") as f:
    pickle.dump(age_encoder, f)
    
with open("gender_encoder.pkl", "wb") as f:
    pickle.dump(gender_encoder, f)


In [76]:
len(specialist_encoder.classes_)

17

In [70]:
new_model

MedicalSpecialistClassifer(
  (reason_encoder): NewModel(
    (embeddings): NewEmbeddings(
      (word_embeddings): Embedding(250048, 768, padding_idx=1)
      (rotary_emb): NTKScalingRotaryEmbedding()
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): NewEncoder(
      (layer): ModuleList(
        (0-11): 12 x NewLayer(
          (attention): NewSdpaAttention(
            (qkv_proj): Linear(in_features=768, out_features=2304, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp): NewGatedMLP(
            (up_gate_proj): Linear(in_features=768, out_features=6144, bias=False)
            (down_proj): Linear(in_features=3072, out_features=768, bias=True)
            (act_fn): GELUActivation()
            (hidden_dropout): Dropout(p=0.1, inp

In [71]:
new_model.load_state_dict(torch.load("../notebooks/best_model.pt"))

<All keys matched successfully>

In [72]:
new_model.eval()
reason_text = "lười ăn, không tăng cân, constipation"
data = tokenizer(
    reason_text,
    padding='max_length',
    truncation=True,
    max_length=128,
    return_tensors='pt'
)
age_category = "child"
gender = "male"

# Fix the tensor shape for user_info
user_info = torch.tensor([[gender_encoder.transform([gender])[0], age_encoder.transform([age_category])[0]]], dtype=torch.float32)

In [73]:
data, user_info

({'input_ids': tensor([[     0,     96, 150365,   6687,      4,    687,  11122,  24376,      4,
             158,      7,  30019,   2320,      2,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
               1,      1,      1,      1,      1,      1,      1,      1,      1,
  

In [74]:
new_model.predict(data['input_ids'], data['attention_mask'], user_info=user_info)

tensor([5])

In [78]:
res = specialist_encoder.inverse_transform([new_model.predict(data['input_ids'], data['attention_mask'], user_info)])

  y = column_or_1d(y, warn=True)


In [80]:
res.tolist()

['nhi khoa']