In [8]:
import json

In [1]:
from datasets import load_dataset
import pandas as pd
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, Dataset 
import torch
import torch.nn as nn
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data = load_dataset("antareepdey/Patient_doctor_chat")
datasets = [i['Text'] for i in data['train']]
patient_query = []
doctor_response = []

for i in range(1000):
    inputs, outputs = datasets[i].split("###Output:")
    patient_query.append(inputs.replace("###Input:",""))
    doctor_response.append(outputs)
data = {"Patient query": patient_query, "Doctor response": doctor_response}
df = pd.DataFrame(data=data)
def preprocess_data(text):
    # preprocess
    text = text.lower()
    text = text.replace('?','')
    text = text.replace("'","")
    text = text.replace(","," ")
    text = text.replace("1)"," ")
    text = text.replace("2)"," ")
    text = text.replace("3)"," ")
    text = text.replace("4)"," ")
    text = text.replace("."," ")
    text = text.strip()
    return text

df['Patient query'] = df['Patient query'].apply(preprocess_data)
df['Doctor response'] = df['Doctor response'].apply(preprocess_data)

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
embed_model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')



In [4]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.max_length = 8192

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        question = df.iloc[index]['Patient query']
        answer = df.iloc[index]['Doctor response']

        question_ids = tokenizer(
            question,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )['input_ids'][0] 

        answer_ids = tokenizer(
            answer,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )['input_ids'][0]

        return question_ids, answer_ids


In [5]:
def collate_fn(batch):
    questions, answers = zip(*batch) 

    padded_questions = pad_sequence(questions, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_answers = pad_sequence(answers, batch_first=True, padding_value=tokenizer.pad_token_id)

    return padded_questions, padded_answers

In [6]:
dataset = CustomDataset(df) 

In [7]:
train_loader = DataLoader( dataset, batch_size=8, collate_fn=collate_fn, shuffle=True)

In [None]:
class EncoderDecoderWithKBAttention(nn.Module):
    def __init__(self, vocab_size, hidden_dim, embedding_dim=1024):
        super().__init__()
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.decoder = nn.LSTMCell(embedding_dim, hidden_dim*2)

        # Attention over encoder outputs
        self.W_denc1 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.W_denc2 = nn.Linear(hidden_dim, hidden_dim)
        self.w = nn.Linear(hidden_dim, 1)

        # Attention over KB keys
        self.W_kb1 = nn.Linear(hidden_dim + embedding_dim, hidden_dim)
        self.W_kb2 = nn.Linear(hidden_dim, hidden_dim)
        self.r = nn.Linear(hidden_dim, 1)

        # Final vocab projection
        self.U = nn.Linear(hidden_dim * 4, vocab_size)

        self.vocab_size = vocab_size
    

    def decoder_attention(self, decoder_hidden, encoder_hidden):
        #print(encoder_hidden.shape)
        B, T, H = encoder_hidden.shape
        decoder_exp = decoder_hidden.unsqueeze(1).expand(-1, T, -1)     # (B, T, H)
        #print(decoder_exp.shape)
        combined = torch.cat([encoder_hidden, decoder_exp], dim=2)     # (B, T, 2H)
        #print(combined.shape)
        x = torch.tanh(self.W_denc1(combined))                          # (B, T, H)
        #print(x.shape)
        x = torch.tanh(self.W_denc2(x))                                 # (B, T, H)
        #print(x.shape)
        u = self.w(x).squeeze(-1)                                       # (B, T)
        #print(u.shape)
        attn = torch.softmax(u, dim=1)                                  # (B, T)
        #print(attn.shape)

        context = torch.bmm(attn.unsqueeze(1), encoder_hidden).squeeze(1)  # (B, H)
        #print(context.shape)
        concat = torch.cat([decoder_hidden, context], dim=1)           # (B, 2H)
        #print(concat.shape)
        vocab_logits = self.U(concat)                                   # (B, vocab_size)
        return vocab_logits

    def kb_attention(self, decoder_hidden, kb_keys):
        B, H = decoder_hidden.shape
        num_kb = kb_keys.shape[0]

        decoder_exp = decoder_hidden.unsqueeze(1).expand(-1, num_kb, -1)        # (B, num_kb, H)
        kb_keys_exp = kb_keys.squeeze(1).unsqueeze(0).expand(B, -1, -1)         # (B, num_kb, E)

        combined = torch.cat([kb_keys_exp, decoder_exp], dim=-1)                # (B, num_kb, H+E)
        x = torch.tanh(self.W_kb1(combined))                                    # (B, num_kb, H)
        
        x = torch.tanh(self.W_kb2(x))                                           # (B, num_kb, H)
        scores = self.r(x).squeeze(-1)                                          # (B, num_kb)
        return scores

    def forward(self, inputs, targets, kb_keys=None, teacher_forcing=True, fine_tune=False):
        #print(targets.shape)
        batch, T_out , _ = targets.shape
        
        enc_out, (h_enc, c_enc) = self.encoder(inputs)             # enc_out: (B, T_in, H)
        # print("enc out", enc_out.shape)
        # print("h_enc",h_enc.shape)
        # print("c_enc", c_enc.shape)
        #print(h_enc[0,:,:].unsqueeze(0).shape)
        hidden, cell = torch.cat([h_enc[0],h_enc[-1]], dim=-1), torch.cat([c_enc[0], c_enc[1]], dim=-1)
        #print(hidden.shape)
       
        #print("start",start_token)
        logits = []
        
        #print(targets[0].shape)
        dec_input = targets[:,0,:]                   # (B, E)

        #print("dec_input",dec_input[-1].unsqueeze(0))
        #print("hidden_cell", hidden.shape)
        for t in range(1, T_out):
            #print("incode input shape", dec_input.shape)
            hidden_logits, cell = self.decoder(dec_input, (hidden, cell))      # (B, H)
            #print("hiddne_logits", hidden_logits.shape)
            #print(enc_out.shape)
            hidden_logits = self.decoder_attention(hidden_logits, enc_out)      # (B, V)
            #print(hidden_logits)
            if fine_tune:
                kb_logits = self.kb_attention(hidden, kb_keys)              # (B, n)
                
                hidden_logits = torch.cat([hidden_logits, kb_logits], dim=1)  # (B, V+n)
                
            logits.append(hidden_logits)  
            #print(hidden_logits.shape)
            

            if teacher_forcing:
                #target = targets[:, t].unsqueeze(0)
                #dec_input = self.embedding(target)['last_hidden_state'][0]
                #print(dec_input.shape)
                dec_input = targets[:,t,:]
            else:
                pred = torch.argmax(hidden_logits, dim=1)
                pred = torch.where(pred < self.vocab_size, pred, torch.tensor(0).to(pred.device))
                dec_input = self.embedding(pred)

        return torch.stack(logits, dim=1)  # (B, T_out-1, V+n)

In [9]:
with open('../appointments_data.json', 'r') as f:
    key_value_data = json.load(f)


In [10]:
key_value_data

{'Lauren_doctor': 'Dr. Wilson',
 'Lauren_appointment_date': '2026-03-27',
 'Lauren_checkup_department': 'Psychiatry',
 'Lauren_hospital_location': 'City Hospital',
 'Lauren_contact': '(557)935-7463x543',
 'Lauren_appointment_time': '16:25',
 'Amber_doctor': 'Dr. Wilson',
 'Amber_appointment_date': '2025-10-11',
 'Amber_checkup_department': 'Orthopedics',
 'Amber_hospital_location': 'City Hospital',
 'Amber_contact': '+1-502-840-1548x530',
 'Amber_appointment_time': '12:51',
 'Savannah_doctor': 'Dr. Anderson',
 'Savannah_appointment_date': '2026-05-28',
 'Savannah_checkup_department': 'Pediatrics',
 'Savannah_hospital_location': 'Mountain View Hospital',
 'Savannah_contact': '340.929.8569x91127',
 'Savannah_appointment_time': '01:55',
 'April_doctor': 'Dr. Smith',
 'April_appointment_date': '2026-04-11',
 'April_checkup_department': 'Dentist',
 'April_hospital_location': 'City Hospital',
 'April_contact': '001-328-824-4066x8146',
 'April_appointment_time': '18:01',
 'Laurie_doctor': 'Dr

In [11]:
with open('../qa_pairs.json' , 'r') as f:
    db_data = json.load(f)

In [31]:
patient , assistant = [ preprocess_data(i['question']) for i in db_data['Q&A'] ], [ preprocess_data(i['answer']) for i in db_data['Q&A'] ]


In [35]:
new_key_value_data = {}
for key, value in key_value_data.items():
    new_key = preprocess_data(key)
    new_value = preprocess_data(value)
    new_key_value_data[new_key] = new_value


In [36]:
new_key_value_data

{'jill_appointment_time': '21:07',
 'brooke_doctor': 'dr  roshan',
 'brooke_appointment_date': '2026-01-02',
 'brooke_checkup_department': 'dentist',
 'brooke_hospital_location': 'city hospital',
 'brooke_contact': '766-399-5098',
 'brooke_appointment_time': '19:19',
 'alexander_doctor': 'dr  taylor',
 'alexander_appointment_date': '2026-04-03',
 'alexander_checkup_department': 'neurology',
 'alexander_hospital_location': 'sunshine medical center',
 'alexander_contact': '555 944 7898',
 'alexander_appointment_time': '12:36',
 'samantha_doctor': 'dr  smith',
 'samantha_appointment_date': '2025-08-11',
 'samantha_checkup_department': 'neurology',
 'samantha_hospital_location': 'city hospital',
 'samantha_contact': '447-829-0210x8955',
 'samantha_appointment_time': '08:45',
 'tony_doctor': 'dr  roshan',
 'tony_appointment_date': '2026-02-27',
 'tony_checkup_department': 'dermatology',
 'tony_hospital_location': 'sunshine medical center',
 'tony_contact': '763 710 0415',
 'tony_appointment

In [43]:
def replace_values_with_keys(texts, dict):
    for i, text in enumerate(texts):
        for key, val in dict.items():
            key = key.split('_')[0]
            print(key)
            if key in text:
                texts[i] = text.replace(val, key)
    return texts

In [None]:
assistant = replace_values_with_keys(assistant[0:3], new_key_value_data)

In [42]:
assistant

['ryan is visiting the dentist department',
 'the appointment time of linda is Linda_appointment_time',
 'the appointment time of jason is Jason_appointment_time',
 'meghan is visiting the dermatology department',
 'jennifer is having their appointment at city hospital',
 'the doctor for timothy is dr  johnson',
 'ashley is having their appointment at sunshine medical center',
 'the appointment time of jill is Jill_appointment_time',
 'amber is visiting the orthopedics department',
 'lauren is visiting the psychiatry department',
 'andrew is visiting the neurology department',
 'the contact number for brenda is 744 670 0051x238',
 'jennifer is visiting the radiology department',
 'sherry is visiting the dermatology department',
 'the appointment time of sonia is Sonia_appointment_time',
 'the contact number for maria is Maria_contact',
 'heather is having their appointment at global care hospital',
 'the doctor for laurie is dr  smith',
 'the appointment time of blake is Blake_appointm