In [1]:
import json
import csv

In [2]:
from datasets import load_dataset
import pandas as pd
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, Dataset 
import torch
import torch.nn as nn
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
data = load_dataset("antareepdey/Patient_doctor_chat")
datasets = [i['Text'] for i in data['train']]
patient_query = []
doctor_response = []

for i in range(1000):
    inputs, outputs = datasets[i].split("###Output:")
    patient_query.append(inputs.replace("###Input:",""))
    doctor_response.append(outputs)
data = {"Patient query": patient_query, "Doctor response": doctor_response}
df = pd.DataFrame(data=data)
def preprocess_data(text):
    # preprocess
    text = text.lower()
    text = text.replace('?','')
    text = text.replace("'","")
    text = text.replace(","," ")
    text = text.replace("1)"," ")
    text = text.replace("2)"," ")
    text = text.replace("3)"," ")
    text = text.replace("4)"," ")
    text = text.replace("."," ")
    text = text.strip()
    return text

df['Patient query'] = df['Patient query'].apply(preprocess_data)
df['Doctor response'] = df['Doctor response'].apply(preprocess_data)

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
embed_model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')



In [6]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.max_length = 8192

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        question = self.data.iloc[index]['Patient query']
        answer = self.data.iloc[index]['Doctor response']

        question_ids = tokenizer(
            question,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )['input_ids'][0] 

        answer_ids = tokenizer(
            answer,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )['input_ids'][0]

        return question_ids, answer_ids


In [None]:
# def collate_fn(batch):
#     questions, answers = zip(*batch) 

#     padded_questions = pad_sequence(questions, batch_first=True, padding_value=tokenizer.pad_token_id)
#     padded_answers = pad_sequence(answers, batch_first=True, padding_value=tokenizer.pad_token_id)

#     return padded_questions, padded_answers

In [7]:
dataset = CustomDataset(df) 

In [8]:
train_loader = DataLoader( dataset, batch_size=1, shuffle=True)

In [9]:
with open('../appointments_data.json', 'r') as f:
    key_value_data = json.load(f)

In [10]:
key_value_data

{'Lauren_doctor': 'Dr. Wilson',
 'Lauren_appointment_date': '2026-03-27',
 'Lauren_checkup_department': 'Psychiatry',
 'Lauren_hospital_location': 'City Hospital',
 'Lauren_contact': '(557)935-7463x543',
 'Lauren_appointment_time': '16:25',
 'Amber_doctor': 'Dr. Wilson',
 'Amber_appointment_date': '2025-10-11',
 'Amber_checkup_department': 'Orthopedics',
 'Amber_hospital_location': 'City Hospital',
 'Amber_contact': '+1-502-840-1548x530',
 'Amber_appointment_time': '12:51',
 'Savannah_doctor': 'Dr. Anderson',
 'Savannah_appointment_date': '2026-05-28',
 'Savannah_checkup_department': 'Pediatrics',
 'Savannah_hospital_location': 'Mountain View Hospital',
 'Savannah_contact': '340.929.8569x91127',
 'Savannah_appointment_time': '01:55',
 'April_doctor': 'Dr. Smith',
 'April_appointment_date': '2026-04-11',
 'April_checkup_department': 'Dentist',
 'April_hospital_location': 'City Hospital',
 'April_contact': '001-328-824-4066x8146',
 'April_appointment_time': '18:01',
 'Laurie_doctor': 'Dr

In [11]:
with open('../qa_pairs.json' , 'r') as f:
    db_data = json.load(f)

In [12]:
patient , assistant = [ preprocess_data(i['question']) for i in db_data['Q&A'] ], [ preprocess_data(i['answer']) for i in db_data['Q&A'] ]


In [13]:
len(patient[-100])

31

In [15]:
new_key_value_data = {}
for key, value in key_value_data.items():
    new_key = preprocess_data(key)
    new_value = preprocess_data(value)
    new_key_value_data[new_key] = new_value


In [16]:
new_key_value_data

{'lauren_doctor': 'dr  wilson',
 'lauren_appointment_date': '2026-03-27',
 'lauren_checkup_department': 'psychiatry',
 'lauren_hospital_location': 'city hospital',
 'lauren_contact': '(557)935-7463x543',
 'lauren_appointment_time': '16:25',
 'amber_doctor': 'dr  wilson',
 'amber_appointment_date': '2025-10-11',
 'amber_checkup_department': 'orthopedics',
 'amber_hospital_location': 'city hospital',
 'amber_contact': '+1-502-840-1548x530',
 'amber_appointment_time': '12:51',
 'savannah_doctor': 'dr  anderson',
 'savannah_appointment_date': '2026-05-28',
 'savannah_checkup_department': 'pediatrics',
 'savannah_hospital_location': 'mountain view hospital',
 'savannah_contact': '340 929 8569x91127',
 'savannah_appointment_time': '01:55',
 'april_doctor': 'dr  smith',
 'april_appointment_date': '2026-04-11',
 'april_checkup_department': 'dentist',
 'april_hospital_location': 'city hospital',
 'april_contact': '001-328-824-4066x8146',
 'april_appointment_time': '18:01',
 'laurie_doctor': 'dr

In [17]:
assistant

['the appointment time of benjamin is 00:09',
 'the doctor for michael is dr  anderson',
 'the contact number for kenneth is 001-441-419-7256',
 'the appointment time of blake is 12:56',
 'brenda is having their appointment at mountain view hospital',
 'the contact number for savannah is 340 929 8569x91127',
 'the appointment time of sonia is 08:35',
 'the contact number for tyler is 938-908-4238',
 'the doctor for lisa is dr  johnson',
 'the doctor for robert is dr  johnson',
 'the contact number for yvette is 001-250-268-7144x3131',
 'anna is visiting the orthopedics department',
 'lisa is visiting the cardiology department',
 'the appointment time of kenneth is 02:52',
 'the appointment time of sue is 03:04',
 'sue is having their appointment at mountain view hospital',
 'the appointment time of lisa is 02:01',
 'jessica is visiting the radiology department',
 'james is visiting the orthopedics department',
 'the doctor for anna is dr  johnson',
 'the contact number for ricky is (95

In [18]:
def replace_values_with_keys(texts, mapping):
    for i, text in enumerate(texts):
        for key, val in mapping.items():
            name = key.split('_')[0]
            if name in text and val in text:
                texts[i] = text.replace(val, key)
                break
    return texts


In [19]:
assistant = replace_values_with_keys(assistant, new_key_value_data)

In [20]:
len(assistant)

4100

In [21]:
all_keys = [ i for i in new_key_value_data.keys()]

In [22]:
tokenizer.add_tokens(all_keys)

492

In [23]:
embed_model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(152161, 1024)

In [24]:
used_token_ids = set()
for target_text in all_keys:
    tokens = tokenizer.encode(target_text, add_special_tokens=False)
    used_token_ids.update(tokens)
    
token_id_list = sorted(list(used_token_ids))
old2new = {old: new for new, old in enumerate(token_id_list)}
new2old = {v: k for k, v in old2new.items()}

In [25]:
db_data_dict = {'Patient query': patient, 'Doctor response': assistant}
print(len(patient))
db_df = pd.DataFrame(db_data_dict)
db_df.head()

4100


Unnamed: 0,Patient query,Doctor response
0,what time is benjamins appointment,the appointment time of benjamin is benjamin_a...
1,who is michaels doctor,the doctor for michael is michael_doctor
2,what is the contact number for kenneth,the contact number for kenneth is kenneth_contact
3,what time is blakes appointment,the appointment time of blake is blake_appoint...
4,where is brenda having their appointment,brenda is having their appointment at brenda_h...


In [26]:
db_df['Patient query'] = db_df['Patient query'].apply(preprocess_data)
db_df['Doctor response'] = db_df['Doctor response'].apply(preprocess_data)

In [27]:
db_df.head()

Unnamed: 0,Patient query,Doctor response
0,what time is benjamins appointment,the appointment time of benjamin is benjamin_a...
1,who is michaels doctor,the doctor for michael is michael_doctor
2,what is the contact number for kenneth,the contact number for kenneth is kenneth_contact
3,what time is blakes appointment,the appointment time of blake is blake_appoint...
4,where is brenda having their appointment,brenda is having their appointment at brenda_h...


In [28]:
db_df.shape

(4100, 2)

In [29]:
dataset_db = CustomDataset(db_df)

In [30]:
dataset_db.__len__()

4100

In [31]:
len(dataset_db)

4100

In [32]:
train_set, val_set = torch.utils.data.random_split(dataset_db, [3500, 600])

In [33]:
train_loader_db = DataLoader(train_set , batch_size=1, shuffle=True)
val_loader_db = DataLoader(val_set , batch_size=1, shuffle=True)

In [34]:
for x, y in train_loader_db:
    print(x.shape)
    break

torch.Size([1, 9])


In [35]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [36]:
embed_model = embed_model.to(device)

In [37]:
keys_embed = []
for key in all_keys:
    key = tokenizer(key, return_tensors='pt')['input_ids'][:,:-1]
    key_embed = embed_model(key.to(device), )['last_hidden_state'][0]
    keys_embed.append(key_embed)

keys_embed = torch.stack(keys_embed, dim=0).squeeze(1).detach()

In [None]:
class EncoderDecoderWithKBAttention(nn.Module):
    def __init__(self, embed_model, vocab_size, hidden_dim, embedding_dim=1024):
        super().__init__()
        self.embed_model = embed_model
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, dropout=0.2, batch_first=True, bidirectional=True)
        self.decoder = nn.LSTMCell(embedding_dim, hidden_dim * 2)

        # Attention over encoder outputs
        self.W_denc1 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.W_denc2 = nn.Linear(hidden_dim, hidden_dim)
        self.w = nn.Linear(hidden_dim, 1)

        # Attention over KB keys (commented out, unused)
        self.W_kb1 = nn.Linear(hidden_dim * 2, embedding_dim)

        # Final vocab projection
        self.U = nn.Linear(hidden_dim * 4, vocab_size)

        self.vocab_size = vocab_size

    def decoder_attention(self, decoder_hidden, encoder_hidden):
        B, T, H = encoder_hidden.shape                               # B=batch, T=encoder seq len, H=hidden size
        decoder_exp = decoder_hidden.unsqueeze(1).expand(-1, T, -1)  # (B, T, H)
        combined = torch.cat([encoder_hidden, decoder_exp], dim=2)   # (B, T, 2H)
        x = torch.tanh(self.W_denc1(combined))                       # (B, T, H)
        x = torch.tanh(self.W_denc2(x))                              # (B, T, H)
        u = self.w(x).squeeze(-1)                                    # (B, T)
        attn = torch.softmax(u, dim=1)                               # (B, T)

        context = torch.bmm(attn.unsqueeze(1), encoder_hidden).squeeze(1)  # (B, H)
        concat = torch.cat([decoder_hidden, context], dim=1)          # (B, 2H)
        vocab_logits = self.U(concat)                                 # (B, vocab_size)
        return vocab_logits
    

    
    def kb_attention(self, decoder_hidden, kb_keys):
        B, _ = decoder_hidden.shape # (B,H)
        x = torch.tanh(self.W_kb1(decoder_hidden))  #(B, E)
       
        x_norm = F.normalize(x, p=2, dim=1)              #(B, E)
        kb_key_norm = F.normalize(kb_keys, p=2, dim=1)    #(N, E)
       

        # similarity
        cosine_sim = torch.matmul(x_norm, kb_key_norm.T)  #(B, N)
        cosine_sim = cosine_sim.squeeze(0)
        
        kb_atten_score_dict = {new2old[idx]:sim for idx, sim in enumerate(cosine_sim)} #(token_id, similarity)
   
        kb_attention = torch.zeros(B, self.vocab_size).to('cuda')


        for idx, val in kb_atten_score_dict.items():
            kb_attention[:, idx] += val

        return kb_attention
    
    


    def forward(self, inputs, targets, kb_keys, teacher_forcing_ratio=1.0):
    

        n_t = len(targets)

        enc_out, (h_enc, c_enc) = self.encoder(inputs)  # enc_out: (B, T_in, H)
        hidden = torch.cat([h_enc[0], h_enc[-1]], dim=-1)  # (B, hidden_dim*2)
        cell = torch.cat([c_enc[0], c_enc[-1]], dim=-1)    # (B, hidden_dim*2)

        logits = []

        dec_input = targets[0][0]  

        for t in range(1, n_t):

            
            hidden_state, cell_state = self.decoder(dec_input, (hidden, cell))
            hidden = hidden_state
            cell = cell_state
        
            
            hidden_logits = self.decoder_attention(hidden_state, enc_out)
            
            kb_attention_logits = self.kb_attention(hidden_state, kb_keys).to('cuda')
            #print(kb_attention_logits.shape)
            hidden_logits = kb_attention_logits + hidden_logits
            #print(kb_attention_logits.shape)
       
            logits.append(hidden_logits)

            # Decide whether to do teacher forcing this step
            use_teacher_forcing = (torch.rand(1).item() < teacher_forcing_ratio)

            if use_teacher_forcing:
                # Use ground-truth target embedding for next input
                dec_input = targets[t][0]
            else:
                # Use model prediction:
                pred_tokens = torch.argmax(hidden_logits, dim=1)  # (B,)
                with torch.no_grad():
                    # Get embeddings for predicted tokens
                    # Assuming embed_model accepts token IDs and returns embeddings
                    dec_input = self.embed_model(pred_tokens.unsqueeze(0))['last_hidden_state'][0].detach()

        return torch.stack(logits, dim=1)  # (B, T_out-1, vocab_size)


In [44]:
use_teacher_forcing = (torch.rand(1).item() < 0.5)
use_teacher_forcing 

True

In [39]:
start_token = tokenizer('<|im_start|>', return_tensors='pt')['input_ids'][:,0].unsqueeze(0).to(device)

In [40]:
vocab_size = len(tokenizer.vocab)
model = EncoderDecoderWithKBAttention(embed_model, vocab_size, 320, embedding_dim=1024).to(device)




In [81]:

tokens = []

for input_texts, target_tokens in train_loader_db:
    input_texts = input_texts.to(device)
    target_tokens = target_tokens.to(device)
      
    # Add start token to targets
    start_token_expanded = start_token.expand(target_tokens.size(0), 1).to(device)
    target_with_sos = torch.cat([start_token_expanded, target_tokens], dim=1)


    with torch.no_grad():
        input_embeds = embed_model(input_texts)['last_hidden_state'].detach()

        target_embeds = [embed_model(i.unsqueeze(0).unsqueeze(0))['last_hidden_state'] for i in target_with_sos[0]]
    
    print(target_embeds)

    print("target_tokens:", target_tokens)
    #print("target", target_embeds)
    # Forward pass
    logits = model(input_embeds, target_embeds, kb_keys=keys_embed, teacher_forcing_ratio=0)
       
    preds = torch.argmax(logits, dim=-1)
    
    print("pred:", preds)
    
    
    break

[tensor([[[  0.8472,  -4.0913,   0.2632,  ...,  -4.2332, -10.3624,   0.6719]]],
       device='cuda:0'), tensor([[[  1.9247, -17.8928,  -0.0284,  ...,  -7.3079, -12.2634,   1.2074]]],
       device='cuda:0'), tensor([[[  3.1244, -13.9529,  -0.0792,  ...,  -7.4514, -13.7968,   0.6089]]],
       device='cuda:0'), tensor([[[ 1.9287e+00, -1.0204e+01,  1.3403e-02,  ..., -7.2744e+00,
          -1.3729e+01,  6.2118e-01]]], device='cuda:0'), tensor([[[  1.6676,  -8.5780,   0.1047,  ...,  -7.5663, -12.6896,   0.8222]]],
       device='cuda:0'), tensor([[[  2.3460,   7.2528,  -0.1433,  ...,  -7.3429, -13.5800,   0.1104]]],
       device='cuda:0'), tensor([[[  2.9029,  -4.5368,  -0.0450,  ...,  -7.9930, -13.2434,   0.4468]]],
       device='cuda:0'), tensor([[[  1.0000, -10.6392,   0.0482,  ...,  -7.5135, -12.0549,   1.1845]]],
       device='cuda:0'), tensor([[[  4.0107, -15.3276,  -0.0967,  ...,  -8.3391, -13.3001,   0.3539]]],
       device='cuda:0'), tensor([[[-1.4002, 39.6511,  0.5486,  ...,

In [57]:
for input , target in train_loader_db:
    input = input.to(device)
    target = target.to(device)
    input = input[:, :-1]
    target_with_sos = torch.cat([start_token, target], dim=1)
    with torch.no_grad():
        input_embeds = embed_model(input)['last_hidden_state']
        target_embeds = [embed_model(i.unsqueeze(0).unsqueeze(0))['last_hidden_state'] for i in target_with_sos[0]]
    
    logits = model(input_embeds, target_embeds,  kb_keys=keys_embed, teacher_forcing_ratio=0.1)


    break

torch.Size([1, 1024])
torch.Size([1, 1024])
torch.Size([1, 1024])
torch.Size([1, 1024])
torch.Size([1, 1024])
torch.Size([1, 1024])
torch.Size([1, 1024])
torch.Size([1, 1024])
torch.Size([1, 1024])


In [67]:
learning_rate = 0.001
epochs = 25
criterion = nn.CrossEntropyLoss()
#weight_decay = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [69]:
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

monitor_csv = 'kb_monitor.csv'
best_val_loss = float('inf')  # For saving best model

# CSV header
with open(monitor_csv, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['epoch', 'train_loss', 'train_accuracy', 'val_loss', 'val_accuracy'])

for epoch in range(epochs):
    model.train()
    total_train_loss = 0.0
    total_train_correct = 0
    total_train_tokens = 0

    for input_texts, target_tokens in tqdm(train_loader_db, desc=f"Training Epoch {epoch + 1}"):
        input_texts = input_texts.to(device)
        target_tokens = target_tokens.to(device)

        input_texts = input_texts[:, :-1]
        start_token_expanded = start_token.expand(target_tokens.size(0), 1).to(device)
        target_with_sos = torch.cat([start_token_expanded, target_tokens], dim=1)

        optimizer.zero_grad()

        with torch.no_grad():
            input_embeds = embed_model(input_texts)['last_hidden_state'].detach()
            target_embeds = [embed_model(i.unsqueeze(0).unsqueeze(0))['last_hidden_state'].detach() for i in target_with_sos[0]]
        if epoch > 5:
            logits = model(input_embeds, target_embeds, kb_keys=keys_embed, teacher_forcing_ratio=0.5)
        else:
            logits = model(input_embeds, target_embeds, kb_keys=keys_embed, teacher_forcing_ratio=0.9)
        loss = criterion(logits.view(-1, logits.size(-1)), target_tokens.view(-1))
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        preds = torch.argmax(logits, dim=-1)
        total_train_correct += (preds == target_tokens).float().sum().item()
        total_train_tokens += target_tokens.numel()

    avg_train_loss = total_train_loss / len(train_loader_db)
    train_accuracy = total_train_correct / total_train_tokens

    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)

    model.eval()
    total_val_loss = 0.0
    total_val_correct = 0
    total_val_tokens = 0

    with torch.no_grad():
        for val_input_texts, val_target_tokens in tqdm(val_loader_db, desc="Validation: "):
            val_input_texts = val_input_texts.to(device)
            val_target_tokens = val_target_tokens.to(device)

            val_input_texts = val_input_texts[:, :-1]
            val_start_token_expanded = start_token.expand(val_target_tokens.size(0), 1).to(device)
            val_target_with_sos = torch.cat([val_start_token_expanded, val_target_tokens], dim=1)

            val_input_embeds = embed_model(val_input_texts)['last_hidden_state'].detach()
            val_target_embeds = [embed_model(i.unsqueeze(0).unsqueeze(0))['last_hidden_state'].detach() for i in val_target_with_sos[0]]
            if epoch > 5:
                val_logits = model(val_input_embeds, val_target_embeds, kb_keys=keys_embed, teacher_forcing_ratio=0.5)
            else:
                val_logits = model(val_input_embeds, val_target_embeds, kb_keys=keys_embed, teacher_forcing_ratio=0.9)

            val_loss = criterion(val_logits.view(-1, val_logits.size(-1)), val_target_tokens.view(-1))

            total_val_loss += val_loss.item()
            val_preds = torch.argmax(val_logits, dim=-1)
            total_val_correct += (val_preds == val_target_tokens).float().sum().item()
            total_val_tokens += val_target_tokens.numel()

    avg_val_loss = total_val_loss / len(val_loader_db)
    val_accuracy = total_val_correct / total_val_tokens

    val_losses.append(avg_val_loss)
    val_accuracies.append(val_accuracy)

    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), f"../models/best_model_weight_kb.pth")

    # Save current model checkpoint
    torch.save(model.state_dict(), f"../models/model_weight_kb_epoch_{epoch+1}.pth")

    # Log all metrics
    with open(monitor_csv, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch + 1, avg_train_loss, train_accuracy, avg_val_loss, val_accuracy])

    print(f"Epoch {epoch + 1}: "
          f"Train Loss = {avg_train_loss:.4f}, Train Acc = {train_accuracy:.4f} | "
          f"Val Loss = {avg_val_loss:.4f}, Val Acc = {val_accuracy:.4f}")


Training Epoch 1: 100%|██████████| 3500/3500 [52:10<00:00,  1.12it/s]
Validation: 100%|██████████| 600/600 [03:16<00:00,  3.06it/s]


Epoch 1: Train Loss = 0.3742, Train Acc = 0.8966 | Val Loss = 0.3359, Val Acc = 0.9136


Training Epoch 2: 100%|██████████| 3500/3500 [50:55<00:00,  1.15it/s]
Validation: 100%|██████████| 600/600 [03:18<00:00,  3.02it/s]


Epoch 2: Train Loss = 0.2244, Train Acc = 0.9356 | Val Loss = 0.1521, Val Acc = 0.9591


Training Epoch 3: 100%|██████████| 3500/3500 [51:38<00:00,  1.13it/s]
Validation: 100%|██████████| 600/600 [03:17<00:00,  3.04it/s]


Epoch 3: Train Loss = 0.2379, Train Acc = 0.9343 | Val Loss = 0.1719, Val Acc = 0.9541


Training Epoch 4: 100%|██████████| 3500/3500 [52:58<00:00,  1.10it/s]
Validation: 100%|██████████| 600/600 [03:21<00:00,  2.97it/s]


Epoch 4: Train Loss = 0.1303, Train Acc = 0.9644 | Val Loss = 0.1102, Val Acc = 0.9710


Training Epoch 5: 100%|██████████| 3500/3500 [52:36<00:00,  1.11it/s]
Validation: 100%|██████████| 600/600 [03:20<00:00,  3.00it/s]


Epoch 5: Train Loss = 0.1077, Train Acc = 0.9718 | Val Loss = 0.0694, Val Acc = 0.9817


Training Epoch 6: 100%|██████████| 3500/3500 [51:55<00:00,  1.12it/s]
Validation: 100%|██████████| 600/600 [03:18<00:00,  3.02it/s]


Epoch 6: Train Loss = 0.0883, Train Acc = 0.9762 | Val Loss = 0.1211, Val Acc = 0.9682


Training Epoch 7: 100%|██████████| 3500/3500 [56:14<00:00,  1.04it/s] 
Validation: 100%|██████████| 600/600 [04:07<00:00,  2.42it/s]


Epoch 7: Train Loss = 0.1198, Train Acc = 0.9706 | Val Loss = 0.1848, Val Acc = 0.9496


Training Epoch 8: 100%|██████████| 3500/3500 [55:59<00:00,  1.04it/s] 
Validation: 100%|██████████| 600/600 [04:02<00:00,  2.48it/s]


Epoch 8: Train Loss = 0.1166, Train Acc = 0.9698 | Val Loss = 0.1634, Val Acc = 0.9603


Training Epoch 9: 100%|██████████| 3500/3500 [55:05<00:00,  1.06it/s] 
Validation: 100%|██████████| 600/600 [03:59<00:00,  2.51it/s]


Epoch 9: Train Loss = 0.1061, Train Acc = 0.9740 | Val Loss = 0.1167, Val Acc = 0.9708


Training Epoch 10: 100%|██████████| 3500/3500 [55:15<00:00,  1.06it/s]
Validation: 100%|██████████| 600/600 [04:02<00:00,  2.48it/s]


Epoch 10: Train Loss = 0.0857, Train Acc = 0.9784 | Val Loss = 0.1084, Val Acc = 0.9757


Training Epoch 11: 100%|██████████| 3500/3500 [57:36<00:00,  1.01it/s] 
Validation: 100%|██████████| 600/600 [04:09<00:00,  2.41it/s]


Epoch 11: Train Loss = 0.1047, Train Acc = 0.9732 | Val Loss = 0.1375, Val Acc = 0.9654


Training Epoch 12: 100%|██████████| 3500/3500 [55:37<00:00,  1.05it/s] 
Validation: 100%|██████████| 600/600 [04:01<00:00,  2.48it/s]


Epoch 12: Train Loss = 0.0934, Train Acc = 0.9756 | Val Loss = 0.0315, Val Acc = 0.9943


Training Epoch 13: 100%|██████████| 3500/3500 [54:59<00:00,  1.06it/s]
Validation: 100%|██████████| 600/600 [03:56<00:00,  2.54it/s]


Epoch 13: Train Loss = 0.0753, Train Acc = 0.9812 | Val Loss = 0.0846, Val Acc = 0.9803


Training Epoch 14: 100%|██████████| 3500/3500 [56:40<00:00,  1.03it/s] 
Validation: 100%|██████████| 600/600 [04:09<00:00,  2.41it/s]


Epoch 14: Train Loss = 0.0536, Train Acc = 0.9874 | Val Loss = 0.0778, Val Acc = 0.9814


Training Epoch 15: 100%|██████████| 3500/3500 [57:04<00:00,  1.02it/s] 
Validation: 100%|██████████| 600/600 [04:05<00:00,  2.45it/s]


Epoch 15: Train Loss = 0.0622, Train Acc = 0.9844 | Val Loss = 0.0588, Val Acc = 0.9873


Training Epoch 16: 100%|██████████| 3500/3500 [59:22<00:00,  1.02s/it]  
Validation: 100%|██████████| 600/600 [04:10<00:00,  2.39it/s]


Epoch 16: Train Loss = 0.0578, Train Acc = 0.9855 | Val Loss = 0.0473, Val Acc = 0.9917


Training Epoch 17: 100%|██████████| 3500/3500 [56:40<00:00,  1.03it/s] 
Validation: 100%|██████████| 600/600 [04:04<00:00,  2.45it/s]


Epoch 17: Train Loss = 0.0575, Train Acc = 0.9860 | Val Loss = 0.0143, Val Acc = 0.9965


Training Epoch 18: 100%|██████████| 3500/3500 [56:45<00:00,  1.03it/s] 
Validation: 100%|██████████| 600/600 [04:13<00:00,  2.36it/s]


Epoch 18: Train Loss = 0.0596, Train Acc = 0.9862 | Val Loss = 0.0312, Val Acc = 0.9911


Training Epoch 19: 100%|██████████| 3500/3500 [57:20<00:00,  1.02it/s] 
Validation: 100%|██████████| 600/600 [04:09<00:00,  2.41it/s]


Epoch 19: Train Loss = 0.0570, Train Acc = 0.9864 | Val Loss = 0.1386, Val Acc = 0.9677


Training Epoch 20:  62%|██████▏   | 2172/3500 [35:48<21:53,  1.01it/s] 


KeyboardInterrupt: 

In [82]:
input = "What time is Edwin's appointment?"
input = preprocess_data(input)
start_token = tokenizer(
        '<|im_start|>',
        max_length= 8192,
        return_tensors="pt",
    )['input_ids'][:,0].unsqueeze(0).to(device)
input_tokens = tokenizer(
        input,
        max_length= 8192,
        return_tensors="pt",
    )['input_ids'][:, :-1].to(device)

embed_model = embed_model.to(device)

embed_query = embed_model(input_tokens)['last_hidden_state'].to(device)
target = embed_model(start_token)['last_hidden_state'].to(device)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [83]:
target

tensor([[[  0.8472,  -4.0913,   0.2632,  ...,  -4.2332, -10.3624,   0.6719]]],
       device='cuda:0', grad_fn=<MulBackward0>)

In [84]:
enc_out, (h_enc, c_enc) = model.encoder(embed_query.to(device))
hidden = torch.cat([h_enc[0], h_enc[1]], dim=-1)  
cell   = torch.cat([c_enc[0], c_enc[1]], dim=-1) 

In [85]:
input_emb = embed_model(start_token.to(device))['last_hidden_state'][:, 0, :]

In [93]:
tokenizer.eos_token_id

151645

In [79]:
state_dict = torch.load("../models/model_weight_kb_epoch_19.pth")
load_result = model.load_state_dict(state_dict)

In [94]:
def inference(model=model, tokenizer=tokenizer, embed_model=embed_model,preprocess_data=preprocess_data, keys_embed=keys_embed, device=device, ):
    query = input("Enter what is you query? ")
    start_token = tokenizer(
                '<|im_start|>',
                max_length= 8192,
                return_tensors="pt",
            )['input_ids'][:,0].unsqueeze(0).to(device)
    embed_model = embed_model.to(device)
    start_embed = embed_model(start_token)['last_hidden_state'][:, 0, :].to(device)
    input_tokens = tokenizer(
        query,
        max_length= 8192,
        return_tensors="pt",
        )['input_ids'][:, :-1].to(device)
    embed_query = embed_model(input_tokens)['last_hidden_state']
    enc_out, (h_enc, c_enc) = model.encoder(embed_query)
    hidden = torch.cat([h_enc[0], h_enc[1]], dim=-1)  
    cell   = torch.cat([c_enc[0], c_enc[1]], dim=-1) 
    tokens = []
    stop_token = -1
    input_emb = start_embed
   
   
    for i in range(20):
        
        hidden, cell = model.decoder(input_emb, (hidden, cell))

        vocab_logits = model.decoder_attention(hidden, enc_out)  

        kb_attention_logits = model.kb_attention(hidden, keys_embed).to('cuda')
                
        hidden_logits = kb_attention_logits + vocab_logits
        pred_tokens = torch.argmax(hidden_logits, dim=1) 
        stop_token = pred_tokens.item()
        if pred_tokens.item() == 151643:
            break
        tokens.append(pred_tokens)

        input_emb = embed_model(pred_tokens.unsqueeze(0))['last_hidden_state'][0]


    pred_tokens = tokenizer.decode(torch.stack(tokens,  dim=1)[0])

    text = ""
    for id , token in enumerate(pred_tokens.split()):
        if token in all_keys:
            text += new_key_value_data[token]
        else:
            if id != 0:
                if pred_tokens.split()[id -1 ] != token:
                    text += token
            else:       
                text += token
        text += " "

    

    return text

    



In [95]:
inference()

'ariana is visiting is pediatrics department '