In [1]:
import csv
import json
import os

In [2]:
from datasets import load_dataset
import pandas as pd
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, Dataset 
import torch
import torch.nn as nn
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the raw chat corpus; every record's 'Text' field holds
# "###Input:<patient text>###Output:<doctor text>".
raw_dataset = load_dataset("antareepdey/Patient_doctor_chat")
datasets = [i['Text'] for i in raw_dataset['train']]
patient_query = []
doctor_response = []

# Only the first 1000 records are used. Slicing (instead of indexing with
# range(1000)) avoids an IndexError when fewer records are available.
for text in datasets[:1000]:
    # partition() splits on the first marker only, so a record with no marker
    # or with a repeated marker no longer raises on tuple unpacking.
    inputs, _marker, outputs = text.partition("###Output:")
    patient_query.append(inputs.replace("###Input:", ""))
    doctor_response.append(outputs)

# `data` ends up as the column dict, exactly as in the original cell; the raw
# dataset now keeps its own name instead of being overwritten.
data = {"Patient query": patient_query, "Doctor response": doctor_response}
df = pd.DataFrame(data=data)
def preprocess_data(text):
    """Normalise a piece of text for the chat / KB pipeline.

    The text is lower-cased, then a fixed sequence of substring replacements
    is applied in order ('?' and apostrophes are removed; commas, the list
    markers "1)".."4)" and periods become a single space), and finally
    leading/trailing whitespace is stripped.

    Args:
        text: input string.

    Returns:
        The normalised string. Runs of spaces created by the replacements are
        not collapsed (e.g. "Dr. Wilson" -> "dr  wilson").
    """
    # Order matters: it is the same order the replacements were applied in
    # originally, one after another on the running result.
    substitutions = (
        ('?', ''),
        ("'", ""),
        (",", " "),
        ("1)", " "),
        ("2)", " "),
        ("3)", " "),
        ("4)", " "),
        (".", " "),
    )
    cleaned = text.lower()
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)
    return cleaned.strip()

# Normalise both text columns of the chat frame with the same cleaner.
for column in ('Patient query', 'Doctor response'):
    df[column] = df[column].apply(preprocess_data)

# Tokenizer (left padding) and embedding model come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    'Qwen/Qwen3-Embedding-0.6B',
    padding_side='left',
)
embed_model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')



In [4]:
class CustomDataset(Dataset):
    """Dataset yielding (question_ids, answer_ids) token-id tensors.

    Each item is built from one row of a DataFrame that has the columns
    'Patient query' and 'Doctor response'.

    Args:
        data: pandas DataFrame with the two text columns above.
        tokenizer: optional callable with the Hugging Face tokenizer call
            signature. When omitted, the notebook-level ``tokenizer`` is
            looked up at access time (the original behaviour).
        max_length: truncation length passed to the tokenizer (default 8192,
            the value that used to be hard-coded).
    """

    def __init__(self, data, tokenizer=None, max_length=8192):
        self.data = data
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize one string and return its 1-D tensor of input ids."""
        # Resolved lazily so tokens added to the global tokenizer after the
        # dataset was created are still honoured.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        encoded = tok(
            text,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a (1, seq_len) batch; drop the batch dim.
        return encoded['input_ids'][0]

    def __getitem__(self, index):
        """Return (question_ids, answer_ids) for the row at position `index`."""
        row = self.data.iloc[index]  # single positional lookup for both columns
        return self._encode(row['Patient query']), self._encode(row['Doctor response'])


In [None]:
# def collate_fn(batch):
#     questions, answers = zip(*batch) 

#     padded_questions = pad_sequence(questions, batch_first=True, padding_value=tokenizer.pad_token_id)
#     padded_answers = pad_sequence(answers, batch_first=True, padding_value=tokenizer.pad_token_id)

#     return padded_questions, padded_answers

In [5]:
# Token-id dataset over the (preprocessed) Patient_doctor_chat frame `df`.
dataset = CustomDataset(df) 

In [6]:
# One sample per batch, reshuffled each epoch. With batch_size=1 no padding
# collate_fn is required — presumably why the one above is commented out.
# NOTE(review): `train_loader` is not referenced again in the cells visible here.
train_loader = DataLoader( dataset, batch_size=1, shuffle=True)

In [7]:
# Flat "<name>_<field>" -> value knowledge base of appointments.
# The encoding is pinned because JSON is UTF-8 by specification; relying on
# the platform default can mis-decode non-ASCII text on some systems.
with open('../appointments_data.json', 'r', encoding='utf-8') as f:
    key_value_data = json.load(f)

In [8]:
key_value_data

{'Lauren_doctor': 'Dr. Wilson',
 'Lauren_appointment_date': '2026-03-27',
 'Lauren_checkup_department': 'Psychiatry',
 'Lauren_hospital_location': 'City Hospital',
 'Lauren_contact': '(557)935-7463x543',
 'Lauren_appointment_time': '16:25',
 'Amber_doctor': 'Dr. Wilson',
 'Amber_appointment_date': '2025-10-11',
 'Amber_checkup_department': 'Orthopedics',
 'Amber_hospital_location': 'City Hospital',
 'Amber_contact': '+1-502-840-1548x530',
 'Amber_appointment_time': '12:51',
 'Savannah_doctor': 'Dr. Anderson',
 'Savannah_appointment_date': '2026-05-28',
 'Savannah_checkup_department': 'Pediatrics',
 'Savannah_hospital_location': 'Mountain View Hospital',
 'Savannah_contact': '340.929.8569x91127',
 'Savannah_appointment_time': '01:55',
 'April_doctor': 'Dr. Smith',
 'April_appointment_date': '2026-04-11',
 'April_checkup_department': 'Dentist',
 'April_hospital_location': 'City Hospital',
 'April_contact': '001-328-824-4066x8146',
 'April_appointment_time': '18:01',
 'Laurie_doctor': 'Dr

In [9]:
# Question/answer pairs generated over the appointment KB (read below via
# the 'Q&A' key). The encoding is pinned because JSON is UTF-8 by
# specification and the platform default may differ.
with open('../qa_pairs.json', 'r', encoding='utf-8') as f:
    db_data = json.load(f)

In [10]:
# Normalised questions and answers, kept as two parallel lists in the order
# the pairs appear under 'Q&A'.
qa_pairs = db_data['Q&A']
patient = [preprocess_data(pair['question']) for pair in qa_pairs]
assistant = [preprocess_data(pair['answer']) for pair in qa_pairs]


In [11]:
len(patient[-100])

20

In [12]:
# Run both the keys and the values of the KB through the same normaliser used
# for the Q&A text, so values can later be matched inside the answers.
new_key_value_data = {
    preprocess_data(key): preprocess_data(value)
    for key, value in key_value_data.items()
}


In [13]:
new_key_value_data

{'lauren_doctor': 'dr  wilson',
 'lauren_appointment_date': '2026-03-27',
 'lauren_checkup_department': 'psychiatry',
 'lauren_hospital_location': 'city hospital',
 'lauren_contact': '(557)935-7463x543',
 'lauren_appointment_time': '16:25',
 'amber_doctor': 'dr  wilson',
 'amber_appointment_date': '2025-10-11',
 'amber_checkup_department': 'orthopedics',
 'amber_hospital_location': 'city hospital',
 'amber_contact': '+1-502-840-1548x530',
 'amber_appointment_time': '12:51',
 'savannah_doctor': 'dr  anderson',
 'savannah_appointment_date': '2026-05-28',
 'savannah_checkup_department': 'pediatrics',
 'savannah_hospital_location': 'mountain view hospital',
 'savannah_contact': '340 929 8569x91127',
 'savannah_appointment_time': '01:55',
 'april_doctor': 'dr  smith',
 'april_appointment_date': '2026-04-11',
 'april_checkup_department': 'dentist',
 'april_hospital_location': 'city hospital',
 'april_contact': '001-328-824-4066x8146',
 'april_appointment_time': '18:01',
 'laurie_doctor': 'dr

In [14]:
assistant

['ryan is visiting the dentist department',
 'the appointment time of linda is 15:01',
 'the appointment time of jason is 19:49',
 'meghan is visiting the dermatology department',
 'jennifer is having their appointment at city hospital',
 'the doctor for timothy is dr  johnson',
 'ashley is having their appointment at sunshine medical center',
 'the appointment time of jill is 21:07',
 'amber is visiting the orthopedics department',
 'lauren is visiting the psychiatry department',
 'andrew is visiting the neurology department',
 'the contact number for brenda is 744 670 0051x238',
 'jennifer is visiting the radiology department',
 'sherry is visiting the dermatology department',
 'the appointment time of sonia is 08:35',
 'the contact number for maria is +1-459-688-7610x1277',
 'heather is having their appointment at global care hospital',
 'the doctor for laurie is dr  smith',
 'the appointment time of blake is 12:56',
 'stephanie is having their appointment at city hospital',
 'tyler

In [15]:
def replace_values_with_keys(texts, mapping):
    """Replace a KB value mentioned in each text with the KB key that holds it.

    For every text, the mapping is scanned in insertion order. The first entry
    whose name prefix (the part of the key before the first '_') and whose
    value both occur in the text as substrings wins: every occurrence of the
    value is replaced by the full key and the scan moves on to the next text.
    At most one mapping entry is therefore applied per text.

    Args:
        texts: list of strings. Modified IN PLACE.
        mapping: dict of "<name>_<field>" -> value string.

    Returns:
        The same list object that was passed in.
    """
    # The name prefix depends only on the key, so compute it once rather than
    # once per (text, key) pair. Entries with an empty value are skipped: ''
    # is a substring of every string and str.replace('', key) would insert
    # the key between all characters of the text.
    candidates = [
        (key.split('_')[0], val, key)
        for key, val in mapping.items()
        if val
    ]
    for i, text in enumerate(texts):
        for name, val, key in candidates:
            if name in text and val in text:
                texts[i] = text.replace(val, key)
                break
    return texts


In [16]:
# Swap concrete KB values in the answers for their KB keys, so the decoder's
# targets contain the key tokens. The list is also modified in place by the call.
assistant = replace_values_with_keys(assistant, new_key_value_data)

In [17]:
len(assistant)

164000

In [18]:
# KB keys in insertion order; they are added to the tokenizer next.
all_keys = list(new_key_value_data)

In [19]:
# Register every KB key as a vocabulary entry; the call returns how many
# tokens were actually added.
tokenizer.add_tokens(all_keys)

492

In [20]:
# Resize the embedding matrix to the tokenizer's new length so the added KB
# key tokens have rows of their own.
embed_model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(152161, 1024)

In [21]:
# Every token id produced when encoding the KB keys (no special tokens).
used_token_ids = {
    token_id
    for target_text in all_keys
    for token_id in tokenizer.encode(target_text, add_special_tokens=False)
}

# Dense re-indexing of those ids: position in the sorted list <-> token id.
token_id_list = sorted(used_token_ids)
new2old = dict(enumerate(token_id_list))
old2new = {old: new for new, old in new2old.items()}

In [22]:
# Assemble the KB question/answer frame. Column names match the ones
# CustomDataset reads ('Patient query', 'Doctor response').
db_data_dict = {'Patient query': patient, 'Doctor response': assistant}
print(len(patient))
db_df = pd.DataFrame(db_data_dict)
db_df.head()

164000


Unnamed: 0,Patient query,Doctor response
0,what department is ryan visiting,ryan is visiting the ryan_checkup_department d...
1,what time is lindas appointment,the appointment time of linda is linda_appoint...
2,what time is jasons appointment,the appointment time of jason is jason_appoint...
3,what department is meghan visiting,meghan is visiting the meghan_checkup_departme...
4,where is jennifer having their appointment,jennifer is having their appointment at jennif...


In [23]:
# Apply the shared text normaliser to both columns of the KB frame.
# NOTE(review): `patient` and `assistant` were already passed through
# preprocess_data when they were built, so this pass looks redundant — confirm.
for column in ('Patient query', 'Doctor response'):
    db_df[column] = db_df[column].apply(preprocess_data)

In [24]:
db_df.head()

Unnamed: 0,Patient query,Doctor response
0,what department is ryan visiting,ryan is visiting the ryan_checkup_department d...
1,what time is lindas appointment,the appointment time of linda is linda_appoint...
2,what time is jasons appointment,the appointment time of jason is jason_appoint...
3,what department is meghan visiting,meghan is visiting the meghan_checkup_departme...
4,where is jennifer having their appointment,jennifer is having their appointment at jennif...


In [25]:
db_df.shape

(164000, 2)

In [26]:
# Token-id dataset over the KB question/answer frame (the data trained on below).
dataset_db = CustomDataset(db_df)

In [23]:
dataset_db.__len__()

164000

In [27]:
len(dataset_db)

164000

In [28]:
# One (question, answer) pair per step, reshuffled every epoch. The dataset
# returns variable-length id tensors and no collate_fn is supplied, so larger
# batches would presumably fail to stack — TODO confirm before raising batch_size.
train_loader_db = DataLoader(dataset_db , batch_size=1, shuffle=True)

In [29]:
# Sanity check: pull one batch and show the shape of the question ids.
x, y = next(iter(train_loader_db))
print(x.shape)

torch.Size([1, 9])


In [30]:
class EncoderDecoderWithKBAttention(nn.Module):
    """Bi-LSTM encoder / LSTMCell decoder whose vocabulary logits are biased
    towards knowledge-base key tokens.

    At each decoding step the logits are the sum of
      * an attention read over the encoder states projected to the vocabulary
        (`decoder_attention`), and
      * a sparse vector that is zero everywhere except at the KB-key token
        ids, where it holds the cosine similarity between a projection of the
        decoder state and each KB key embedding (`kb_attention`).

    Args:
        embed_model: module mapping token ids of shape (B, T) to a mapping
            with a 'last_hidden_state' entry of shape (B, T, embedding_dim).
            Only called inside `forward` when a step is not teacher-forced.
            It is stored as an attribute, so if it is an nn.Module its
            parameters are part of this module's parameters()/state_dict().
        vocab_size: size of the output vocabulary.
        hidden_dim: hidden size of ONE encoder direction; the decoder state
            has size 2 * hidden_dim.
        embedding_dim: size of the input and target embeddings.
        kb_token_ids: optional sequence/tensor of length N giving, for KB key
            row i of `kb_keys`, the vocabulary id that receives its score.
            When omitted, the ids are read from the notebook-level `new2old`
            dict (index -> token id), which is what the original code did.
    """

    def __init__(self, embed_model, vocab_size, hidden_dim, embedding_dim=1024, kb_token_ids=None):
        super().__init__()
        self.embed_model = embed_model
        # Single-layer LSTM. The former `dropout=0.2` argument is dropped:
        # nn.LSTM only applies dropout between stacked layers, so with one
        # layer it had no effect and merely emitted a warning.
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.decoder = nn.LSTMCell(embedding_dim, hidden_dim * 2)

        # Attention over encoder outputs: scores are computed from
        # [encoder state (2H) ; decoder state (2H)] -> 4H inputs.
        self.W_denc1 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.W_denc2 = nn.Linear(hidden_dim, hidden_dim)
        self.w = nn.Linear(hidden_dim, 1)

        # Projects the decoder state into the KB key embedding space
        # (used by kb_attention).
        self.W_kb1 = nn.Linear(hidden_dim * 2, embedding_dim)

        # Final vocab projection from [decoder state (2H) ; context (2H)].
        self.U = nn.Linear(hidden_dim * 4, vocab_size)

        self.vocab_size = vocab_size

        # Non-persistent buffer: follows .to(device) but is NOT written to
        # state_dict, so checkpoints keep the same keys as before.
        if kb_token_ids is not None:
            kb_token_ids = torch.as_tensor(kb_token_ids, dtype=torch.long)
        self.register_buffer('kb_token_ids', kb_token_ids, persistent=False)

    def decoder_attention(self, decoder_hidden, encoder_hidden):
        """Attend over the encoder states and project to vocabulary logits.

        Args:
            decoder_hidden: (B, 2H) current decoder hidden state.
            encoder_hidden: (B, T, 2H) encoder outputs.

        Returns:
            (B, vocab_size) logits.
        """
        T = encoder_hidden.size(1)
        decoder_exp = decoder_hidden.unsqueeze(1).expand(-1, T, -1)   # (B, T, 2H)
        combined = torch.cat([encoder_hidden, decoder_exp], dim=2)    # (B, T, 4H)
        x = torch.tanh(self.W_denc1(combined))                        # (B, T, H)
        x = torch.tanh(self.W_denc2(x))                               # (B, T, H)
        u = self.w(x).squeeze(-1)                                     # (B, T)
        attn = torch.softmax(u, dim=1)                                # (B, T)

        context = torch.bmm(attn.unsqueeze(1), encoder_hidden).squeeze(1)  # (B, 2H)
        concat = torch.cat([decoder_hidden, context], dim=1)          # (B, 4H)
        return self.U(concat)                                         # (B, vocab_size)

    def kb_attention(self, decoder_hidden, kb_keys):
        """Score every KB key against the decoder state and scatter the scores
        onto the vocabulary positions of the KB key tokens.

        Args:
            decoder_hidden: (B, 2H) current decoder hidden state.
            kb_keys: (N, E) KB key embeddings, E == embedding_dim.

        Returns:
            (B, vocab_size) tensor on the same device as the inputs; entry
            [b, kb_token_ids[i]] is the cosine similarity between the
            projected decoder state b and key i, all other entries are 0.
        """
        query = torch.tanh(self.W_kb1(decoder_hidden))        # (B, E)
        query = F.normalize(query, p=2, dim=1)                # (B, E)
        keys = F.normalize(kb_keys, p=2, dim=1)               # (N, E)
        cosine_sim = torch.matmul(query, keys.T)              # (B, N)

        token_ids = self.kb_token_ids
        if token_ids is None:
            # Backward-compatible fallback to the notebook-level mapping.
            token_ids = torch.tensor(
                [new2old[i] for i in range(kb_keys.size(0))], dtype=torch.long
            )
        token_ids = token_ids.to(cosine_sim.device)

        # One vectorised scatter-add replaces the former per-key Python loop.
        # It works for any batch size and allocates on the inputs' device
        # instead of a hard-coded 'cuda'.
        kb_logits = cosine_sim.new_zeros(decoder_hidden.size(0), self.vocab_size)
        return kb_logits.index_add(1, token_ids, cosine_sim)

    def forward(self, inputs, targets, kb_keys, teacher_forcing_ratio=1.0):
        """Decode one step per target position after the first.

        Args:
            inputs: (B, T_in, E) embedded source sequence.
            targets: (B, T_out, E) embedded target sequence whose first
                position is the start-of-sequence embedding.
            kb_keys: (N, E) KB key embeddings.
            teacher_forcing_ratio: probability, drawn independently at each
                step, of feeding the ground-truth embedding as the next
                decoder input instead of the embedding of the argmax token.

        Returns:
            (B, T_out - 1, vocab_size) logits; step t predicts target t + 1.
        """
        _, T_out, _ = targets.shape

        enc_out, (h_enc, c_enc) = self.encoder(inputs)     # enc_out: (B, T_in, 2H)
        # Initial decoder state = [forward-direction final ; backward-direction final].
        hidden = torch.cat([h_enc[0], h_enc[-1]], dim=-1)  # (B, 2H)
        cell = torch.cat([c_enc[0], c_enc[-1]], dim=-1)    # (B, 2H)

        logits = []
        dec_input = targets[:, 0, :]

        for t in range(1, T_out):
            hidden, cell = self.decoder(dec_input, (hidden, cell))

            step_logits = self.decoder_attention(hidden, enc_out) + self.kb_attention(hidden, kb_keys)
            logits.append(step_logits)

            # Decide whether to do teacher forcing this step.
            if torch.rand(1).item() < teacher_forcing_ratio:
                # Ground-truth target embedding is the next input.
                dec_input = targets[:, t, :]
            else:
                pred_tokens = torch.argmax(step_logits, dim=1)  # (B,)
                with torch.no_grad():
                    # embed_model expects (B, T) ids: add a length-1 sequence
                    # axis so that [:, 0, :] below selects that single step.
                    dec_input = self.embed_model(pred_tokens.unsqueeze(1))['last_hidden_state'][:, 0, :]

        return torch.stack(logits, dim=1)  # (B, T_out-1, vocab_size)


In [31]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [32]:
# Id of '<|im_start|>', prepended to every target as the decoder's first input.
# [:, 0] keeps only the first id the tokenizer returns; unsqueeze(0) turns it
# into a (1, 1) tensor so it can be concatenated in front of (B, T) id tensors.
start_token = tokenizer('<|im_start|>', return_tensors='pt')['input_ids'][:,0].unsqueeze(0).to(device)

In [33]:
start_token

tensor([[151644]], device='cuda:0')

In [34]:
# Output vocabulary size is taken from the tokenizer's vocab mapping.
# NOTE(review): presumably equal to len(tokenizer), the size the embedding
# matrix was resized to — confirm.
vocab_size = len(tokenizer.vocab)
# hidden_dim=320 per encoder direction; embedding_dim=1024 matches the width
# of embed_model's embedding table. .to(device) also moves embed_model,
# because it is stored as a submodule of the model.
model = EncoderDecoderWithKBAttention(embed_model, vocab_size, 320, embedding_dim=1024).to(device)




In [35]:
model

EncoderDecoderWithKBAttention(
  (embed_model): Qwen3Model(
    (embed_tokens): Embedding(152161, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attent

In [36]:
# Embed every KB key once, up front. The result is a constant input to the
# model, so the forward passes run under no_grad: no autograd graph is built
# (and then thrown away) for each of the keys.
keys_embed = []
with torch.no_grad():
    for key_text in all_keys:
        # [:, :-1] drops the last id the tokenizer emits — presumably an
        # appended end-of-text token; confirm against the tokenizer config.
        key_ids = tokenizer(key_text, return_tensors='pt')['input_ids'][:, :-1]
        key_embed = embed_model(key_ids.to(device))['last_hidden_state'][0]

        keys_embed.append(key_embed)

# Stacking needs every key to have the same token count; squeeze(1) then
# removes the sequence axis when that count is 1.
keys_embed = torch.stack(keys_embed, dim=0).squeeze(1).detach()


In [37]:
# Smoke test: push a single batch through the model before training.
# Names are chosen so the builtin `input` is no longer shadowed for the rest
# of the notebook session.
sample_input, sample_target = next(iter(train_loader_db))
sample_input = sample_input.to(device)[:, :-1]   # drop the last input token
sample_target = sample_target.to(device)
target_with_sos = torch.cat([start_token, sample_target], dim=1)

# The embedding model is only a feature extractor here: no gradients needed.
with torch.no_grad():
    input_embeds = embed_model(sample_input)['last_hidden_state']
    target_embeds = embed_model(target_with_sos)['last_hidden_state']

logits = model(input_embeds, target_embeds, kb_keys=keys_embed)

In [38]:
# Optimisation hyper-parameters.
learning_rate = 0.001
epochs = 25
# Token-level cross entropy over the vocabulary logits.
criterion = nn.CrossEntropyLoss()
#weight_decay = 1e-2
# NOTE(review): model.parameters() also contains embed_model's weights, since
# it is a submodule of `model`. In the training loop visible here the
# embeddings are computed under no_grad, so those weights presumably never
# receive gradients — confirm they are meant to stay frozen.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
train_losses = []
train_accuracies = []

monitor_csv = 'kb_monitor.csv'
checkpoint_path = "../models/model_weight_kb.pth"

# Create the checkpoint directory up front. Without this the first
# torch.save can fail only after a whole epoch has already been trained.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Start a fresh metrics log with a header row.
with open(monitor_csv, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['epoch', 'train_loss', 'train_accuracy'])

for epoch in range(epochs):
    model.train()
    total_train_loss = 0.0
    total_train_correct = 0
    total_train_tokens = 0

    for input_texts, target_tokens in tqdm(train_loader_db, desc=f"Training Epoch {epoch + 1}"):
        input_texts = input_texts.to(device)
        target_tokens = target_tokens.to(device)

        # Remove last token from input
        input_texts = input_texts[:, :-1]

        # Add start token to targets. The decoder consumes positions
        # 0..T-1 of this sequence and is scored against target_tokens.
        start_token_expanded = start_token.expand(target_tokens.size(0), 1).to(device)
        target_with_sos = torch.cat([start_token_expanded, target_tokens], dim=1)

        optimizer.zero_grad()

        # Embeddings are fixed features: computed without tracking gradients
        # (tensors created under no_grad are already detached).
        with torch.no_grad():
            input_embeds = embed_model(input_texts)['last_hidden_state']
            target_embeds = embed_model(target_with_sos)['last_hidden_state']

        # Forward pass: logits has one step per target token.
        logits = model(input_embeds, target_embeds, kb_keys=keys_embed)

        # Compute loss over all (batch, step) positions.
        loss = criterion(logits.view(-1, logits.size(-1)), target_tokens.view(-1))
        loss.backward()
        optimizer.step()

        # Accumulate loss and token-level accuracy.
        total_train_loss += loss.item()
        preds = torch.argmax(logits, dim=-1)
        total_train_correct += (preds == target_tokens).float().sum().item()
        total_train_tokens += target_tokens.numel()

    # Epoch metrics
    avg_train_loss = total_train_loss / len(train_loader_db)
    train_accuracy = total_train_correct / total_train_tokens

    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)

    # Save model checkpoint (same file every epoch: latest weights only).
    torch.save(model.state_dict(), checkpoint_path)

    # Log to CSV
    with open(monitor_csv, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch + 1, avg_train_loss, train_accuracy])

    print(f"Epoch {epoch + 1}: "
          f"Train Loss = {avg_train_loss:.4f}, "
          f"Train Acc = {train_accuracy:.4f}")

Training Epoch 1:   3%|▎         | 5162/164000 [1:05:50<33:20:33,  1.32it/s]