In [2]:
from datasets import load_dataset
import pandas as pd
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, Dataset 
import torch
import torch.nn as nn
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data = load_dataset("antareepdey/Patient_doctor_chat")
datasets = [i['Text'] for i in data['train']]
patient_query = []
doctor_response = []
patient_query_val = []
doctor_response_val = []

for i in range(1000):
    inputs, outputs = datasets[i].split("###Output:")
    patient_query.append(inputs.replace("###Input:",""))
    doctor_response.append(outputs)
for i in range(1000,1200):
    inputs, outputs = datasets[i].split("###Output:")
    patient_query_val.append(inputs.replace("###Input:",""))
    doctor_response_val.append(outputs)

data = {"Patient query": patient_query, "Doctor response": doctor_response}
val_data = {"Patient query": patient_query_val, "Doctor response": doctor_response_val}
df = pd.DataFrame(data=data)
val_df = pd.DataFrame(data=val_data)
def preprocess_data(text):
    # preprocess
    text = text.lower()
    text = text.replace('?','')
    text = text.replace("'","")
    text = text.replace(","," ")
    text = text.replace("1)"," ")
    text = text.replace("2)"," ")
    text = text.replace("3)"," ")
    text = text.replace("4)"," ")
    text = text.replace("."," ")
    text = text.strip()
    return text

df['Patient query'] = df['Patient query'].apply(preprocess_data)
df['Doctor response'] = df['Doctor response'].apply(preprocess_data)

val_df['Patient query'] = val_df['Patient query'].apply(preprocess_data)
val_df['Doctor response'] = val_df['Doctor response'].apply(preprocess_data)

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
embed_model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')


In [4]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.max_length = 8192

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        question = df.iloc[index]['Patient query']
        answer = df.iloc[index]['Doctor response']

        question_ids = tokenizer(
            question,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )['input_ids'][0] 

        answer_ids = tokenizer(
            answer,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )['input_ids'][0]

        return question_ids, answer_ids


In [4]:
def collate_fn(batch):
    questions, answers = zip(*batch) 

    padded_questions = pad_sequence(questions, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_answers = pad_sequence(answers, batch_first=True, padding_value=tokenizer.pad_token_id)

    return padded_questions, padded_answers


In [5]:
dataset = CustomDataset(df) 
val_dataset = CustomDataset(val_df)

In [6]:
train_loader = DataLoader( dataset, batch_size=1, shuffle=True)
val_loader = DataLoader( val_dataset, batch_size=1,  shuffle=True)

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
class EncoderDecoderWithKBAttention(nn.Module):
    def __init__(self, embed_model , vocab_size, hidden_dim, embedding_dim=1024):
        super().__init__()
        self.embed_model = embed_model
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.decoder = nn.LSTMCell(embedding_dim, hidden_dim*2)

        # Attention over encoder outputs
        self.W_denc1 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.W_denc2 = nn.Linear(hidden_dim, hidden_dim)
        self.w = nn.Linear(hidden_dim, 1)

        # Attention over KB keys
        self.W_kb1 = nn.Linear(hidden_dim + embedding_dim, hidden_dim)
        self.W_kb2 = nn.Linear(hidden_dim, hidden_dim)
        self.r = nn.Linear(hidden_dim, 1)

        # Final vocab projection
        self.U = nn.Linear(hidden_dim * 4, vocab_size)

        self.vocab_size = vocab_size
    

    def decoder_attention(self, decoder_hidden, encoder_hidden):
        #print(encoder_hidden.shape)
        B, T, H = encoder_hidden.shape
        decoder_exp = decoder_hidden.unsqueeze(1).expand(-1, T, -1)     # (B, T, H)
        #print(decoder_exp.shape)
        combined = torch.cat([encoder_hidden, decoder_exp], dim=2)     # (B, T, 2H)
        #print(combined.shape)
        x = torch.tanh(self.W_denc1(combined))                          # (B, T, H)
        #print(x.shape)
        x = torch.tanh(self.W_denc2(x))                                 # (B, T, H)
        #print(x.shape)
        u = self.w(x).squeeze(-1)                                       # (B, T)
        #print(u.shape)
        attn = torch.softmax(u, dim=1)                                  # (B, T)
        #print(attn.shape)

        context = torch.bmm(attn.unsqueeze(1), encoder_hidden).squeeze(1)  # (B, H)
        #print(context.shape)
        concat = torch.cat([decoder_hidden, context], dim=1)           # (B, 2H)
        #print(concat.shape)
        vocab_logits = self.U(concat)                                   # (B, vocab_size)
        return vocab_logits

    def kb_attention(self, decoder_hidden, kb_keys):
        B, H = decoder_hidden.shape
        num_kb = kb_keys.shape[0]

        decoder_exp = decoder_hidden.unsqueeze(1).expand(-1, num_kb, -1)        # (B, num_kb, H)
        kb_keys_exp = kb_keys.squeeze(1).unsqueeze(0).expand(B, -1, -1)         # (B, num_kb, E)

        combined = torch.cat([kb_keys_exp, decoder_exp], dim=-1)                # (B, num_kb, H+E)
        x = torch.tanh(self.W_kb1(combined))                                    # (B, num_kb, H)
        
        x = torch.tanh(self.W_kb2(x))                                           # (B, num_kb, H)
        scores = self.r(x).squeeze(-1)                                          # (B, num_kb)
        return scores

    def forward(self, inputs, targets, inference = False, teacher_forcing=True,  kb_keys=None,  fine_tune=False):
        #print(targets.shape)
        batch, T_out , _ = targets.shape
        
        enc_out, (h_enc, c_enc) = self.encoder(inputs)             # enc_out: (B, T_in, H)
        # print("enc out", enc_out.shape)
        # print("h_enc",h_enc.shape)
        # print("c_enc", c_enc.shape)
        #print(h_enc[0,:,:].unsqueeze(0).shape)
        hidden, cell = torch.cat([h_enc[0],h_enc[-1]], dim=-1), torch.cat([c_enc[0], c_enc[1]], dim=-1)
        #print(hidden.shape)
       
        #print("start",start_token)
        logits = []
        
        #print(targets[0].shape)
        dec_input = targets[:,0,:]                   # (B, E)
        #print(dec_input)

        #print("dec_input",dec_input[-1].unsqueeze(0))
        #print("hidden_cell", hidden.shape)
        if inference:
            #print("incode input shape", dec_input.shape)
            hidden_logits, cell = self.decoder(dec_input, (hidden, cell))      # (B, H)
            #print("hiddne_logits", hidden_logits.shape)
            #print(enc_out.shape)
            hidden_logits = self.decoder_attention(hidden_logits, enc_out)      # (B, V)
            #print(hidden_logits)
            if fine_tune:
                kb_logits = self.kb_attention(hidden, kb_keys)              # (B, n)
                hidden_logits = torch.cat([hidden_logits, kb_logits], dim=1)  # (B, V+n)
                
            logits.append(hidden_logits)  
            #print(hidden_logits.shape)

            if teacher_forcing:
                #target = targets[:, t].unsqueeze(0)
                #dec_input = self.embedding(target)['last_hidden_state'][0]
                #print(dec_input.shape)
                dec_input = targets[:,t,:]
            else:
                pred = torch.argmax(hidden_logits, dim=1)
                pred = torch.where(pred < self.vocab_size, pred, torch.tensor(0).to(pred.device))
        else:
            for t in range(1, T_out):
                #print("incode input shape", dec_input.shape)
                hidden_logits, cell = self.decoder(dec_input, (hidden, cell))      # (B, H)
                #print("hiddne_logits", hidden_logits.shape)
                #print(enc_out.shape)
                hidden_logits = self.decoder_attention(hidden_logits, enc_out)      # (B, V)
                #print(hidden_logits)
                if fine_tune:
                    kb_logits = self.kb_attention(hidden, kb_keys)              # (B, n)
                    hidden_logits = torch.cat([hidden_logits, kb_logits], dim=1)  # (B, V+n)
                    
                logits.append(hidden_logits)  
                #print(hidden_logits.shape)

                if teacher_forcing:
                    #target = targets[:, t].unsqueeze(0)
                    #dec_input = self.embedding(target)['last_hidden_state'][0]
                    #print(dec_input.shape)
                    dec_input = targets[:,t,:]
                else:
                    pred = torch.argmax(hidden_logits, dim=1)
                    pred = torch.where(pred < self.vocab_size, pred, torch.tensor(0).to(pred.device))

        return torch.stack(logits, dim=1)  # (B, T_out-1, V+n)

In [9]:
vocab_size = len(tokenizer.vocab)
model = EncoderDecoderWithKBAttention(embed_model=embed_model, vocab_size=vocab_size, hidden_dim=320, embedding_dim=1024).to(device)

In [9]:
maxlength = 8192
start_token = tokenizer(
        '<|im_start|>',
        max_length= maxlength,
        return_tensors="pt",
    )['input_ids']

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [9]:
start_token

tensor([[151644, 151643]])

In [15]:
embed_model

Qwen3Model(
  (embed_tokens): Embedding(151669, 1024)
  (layers): ModuleList(
    (0-27): 28 x Qwen3DecoderLayer(
      (self_attn): Qwen3Attention(
        (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
        (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
        (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
      )
      (mlp): Qwen3MLP(
        (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
        (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
        (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
      (post_attention_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
    )
  )
  (norm): Qwen3RMSNorm((102

In [10]:
new_embed_model =  AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')

In [18]:
for input, target in train_loader:
    input = new_embed_model(input)['last_hidden_state']
    #print(input.shape)
    #print(target.shape)
    batch_size = target.shape[0]
    target = torch.cat([start_token[:,0].repeat(batch_size, 1),target] , dim=1)
    #print(target.shape)
    outputs = new_embed_model(target)['last_hidden_state']
    #print(outputs)
    #print(outputs.shape)
    #print(input.shape)
    #print(input)
    input = input.to(device)
    outputs = outputs.to(device)
    target = target.to(device)
    logits =  model(input, outputs)
    # #print(target[0][1:].unsqueeze(0).shape)
    print(logits.shape)
    print(target.shape)
    break
    

torch.Size([8, 173, 151669])
torch.Size([8, 174])


In [11]:
learning_rate = 0.001
epochs = 25
criterion = nn.CrossEntropyLoss()
#weight_decay = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Training Epochs - 1: 100%|██████████| 125/125 [09:57<00:00,  4.78s/it]
Validation: 100%|██████████| 25/25 [01:34<00:00,  3.77s/it]


Epoch 1, Train Loss: 4.0543 , Val Loss: 2.7719


Training Epochs - 2: 100%|██████████| 125/125 [10:02<00:00,  4.82s/it]
Validation: 100%|██████████| 25/25 [01:38<00:00,  3.94s/it]


Epoch 2, Train Loss: 2.4440 , Val Loss: 1.9502


Training Epochs - 3: 100%|██████████| 125/125 [10:28<00:00,  5.03s/it]
Validation: 100%|██████████| 25/25 [01:40<00:00,  4.02s/it]


Epoch 3, Train Loss: 1.9312 , Val Loss: 1.5661


Training Epochs - 4: 100%|██████████| 125/125 [10:39<00:00,  5.12s/it]
Validation: 100%|██████████| 25/25 [01:41<00:00,  4.06s/it]


Epoch 4, Train Loss: 1.5467 , Val Loss: 1.2751


Training Epochs - 5: 100%|██████████| 125/125 [10:40<00:00,  5.13s/it]
Validation: 100%|██████████| 25/25 [01:31<00:00,  3.67s/it]


Epoch 5, Train Loss: 1.2866 , Val Loss: 1.1280


Training Epochs - 6: 100%|██████████| 125/125 [10:29<00:00,  5.03s/it]
Validation: 100%|██████████| 25/25 [01:38<00:00,  3.96s/it]


Epoch 6, Train Loss: 1.1229 , Val Loss: 0.9211


Training Epochs - 7: 100%|██████████| 125/125 [10:34<00:00,  5.07s/it]
Validation: 100%|██████████| 25/25 [01:43<00:00,  4.16s/it]


Epoch 7, Train Loss: 0.9661 , Val Loss: 0.7992


Training Epochs - 8: 100%|██████████| 125/125 [10:38<00:00,  5.11s/it]
Validation: 100%|██████████| 25/25 [01:41<00:00,  4.05s/it]


Epoch 8, Train Loss: 0.8323 , Val Loss: 0.7006


Training Epochs - 9: 100%|██████████| 125/125 [10:23<00:00,  4.99s/it]
Validation: 100%|██████████| 25/25 [01:39<00:00,  3.98s/it]


Epoch 9, Train Loss: 0.7298 , Val Loss: 0.5948


Training Epochs - 10: 100%|██████████| 125/125 [10:38<00:00,  5.11s/it]
Validation: 100%|██████████| 25/25 [01:38<00:00,  3.93s/it]


Epoch 10, Train Loss: 0.6198 , Val Loss: 0.5138


Training Epochs - 11: 100%|██████████| 125/125 [10:34<00:00,  5.08s/it]
Validation: 100%|██████████| 25/25 [01:40<00:00,  4.04s/it]


Epoch 11, Train Loss: 0.5292 , Val Loss: 0.4396


Training Epochs - 12: 100%|██████████| 125/125 [09:59<00:00,  4.80s/it]
Validation: 100%|██████████| 25/25 [01:36<00:00,  3.84s/it]


Epoch 12, Train Loss: 0.4502 , Val Loss: 0.3726


Training Epochs - 13: 100%|██████████| 125/125 [10:14<00:00,  4.92s/it]
Validation: 100%|██████████| 25/25 [01:39<00:00,  3.97s/it]


Epoch 13, Train Loss: 0.3901 , Val Loss: 0.3036


Training Epochs - 14:  13%|█▎        | 16/125 [01:23<09:17,  5.12s/it]

In [13]:
maxlength = 8192

In [52]:
input = "Hello doctor"
input = preprocess_data(input)
start_token = tokenizer(
        '<|im_start|>',
        max_length= maxlength,
        return_tensors="pt",
    )['input_ids'][:,0].unsqueeze(0).to(device)
input_tokens = tokenizer(
        input,
        max_length= maxlength,
        return_tensors="pt",
    )['input_ids'][:, :-1].to(device)

print(input_tokens)
embed_model = embed_model.to(device)

embed_query = embed_model(input_tokens)['last_hidden_state'].to(device)
target = embed_model(start_token)['last_hidden_state'].to(device)


tensor([[14990, 10668]], device='cuda:0')


In [15]:
state_dict = torch.load("../models/model_weights_024.pth")
load_result = model.load_state_dict(state_dict)

In [16]:
load_result

<All keys matched successfully>

In [17]:
target.shape

torch.Size([1, 1, 1024])

In [31]:
start_token.item()

151644

In [53]:

# Make sure the model is in eval mode for inference
model.eval()
tokens = []
token = start_token.item()
# Now call the model, not `load_result`:
with torch.no_grad():
    for i in range(20):
        logits = model(embed_query, target, True, False)
        token = torch.argmax(logits[0], dim=1)
        target = embed_model(token.unsqueeze(0))['last_hidden_state'].to(device)
        token = token.item()
        tokens.append(token)
        print(token)


        


2152
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850
1850


In [54]:
tokens = torch.tensor(tokens)

In [55]:
tokenizer.convert_ids_to_tokens(tokens)

['no',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest',
 'Ġbest']

In [22]:
token = torch.argmax(logits[0], dim=1)

In [33]:
new_list = []

In [36]:
new_list.append(token.item())

In [37]:
new_list

[tensor([9330], device='cuda:0'), 9330]

In [25]:
embed_model(token.unsqueeze(0))

BaseModelOutputWithPast(last_hidden_state=tensor([[[  2.3688, -15.1576,  -0.0910,  ...,  -7.9240, -11.6597,   0.9951]]],
       device='cuda:0', grad_fn=<MulBackward0>), past_key_values=<transformers.cache_utils.DynamicCache object at 0x00000201FF880190>, hidden_states=None, attentions=None)

In [21]:
tokenizer.convert_ids_to_tokens(9330)

'you'

In [73]:
load_result

<All keys matched successfully>

In [93]:
start_token

tensor([[151644, 151643]])

In [94]:
target

tensor([[[  0.8472,  -4.0913,   0.2632,  ...,  -4.2332, -10.3624,   0.6719]]],
       device='cuda:0', grad_fn=<MulBackward0>)

In [13]:
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(23,epochs):
    model.train()
    total_train_loss = 0.0
    total_val_loss = 0.0
    total_train_correct = 0
    total_train_tokens = 0

    for input_texts, target_tokens in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        batch_size = target_tokens.size(0)
        input_texts = input_texts[:, :-1]

        target_with_bos = torch.cat([start_token[:, 0].repeat(batch_size, 1), target_tokens], dim=1)

        with torch.no_grad():
            input_embeds = new_embed_model(input_texts)['last_hidden_state']
            target_embeds = new_embed_model(target_with_bos)['last_hidden_state']

        input_embeds = input_embeds.to(device)
        target_embeds = target_embeds.to(device)
        target_tokens = target_tokens.to(device)

        optimizer.zero_grad()
        logits = model(input_embeds, target_embeds)

        loss = criterion(
            logits.view(-1, logits.shape[-1]),
            target_tokens.reshape(-1)
        )

        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

        # Calculate accuracy
        preds = torch.argmax(logits, dim=-1)
        correct = (preds == target_tokens).float()
        total_train_correct += correct.sum().item()
        total_train_tokens += target_tokens.numel()

    model.eval()
    total_val_correct = 0
    total_val_tokens = 0

    with torch.no_grad():
        for input_texts, target_tokens in tqdm(val_loader, desc="Validating"):
            batch_size = target_tokens.size(0)
            target_with_bos = torch.cat([start_token[:, 0].repeat(batch_size, 1), target_tokens], dim=1)

            input_embeds = new_embed_model(input_texts)['last_hidden_state']
            target_embeds = new_embed_model(target_with_bos)['last_hidden_state']

            input_embeds = input_embeds.to(device)
            target_embeds = target_embeds.to(device)
            target_tokens = target_tokens.to(device)

            logits = model(input_embeds, target_embeds)

            loss = criterion(
                logits.view(-1, logits.shape[-1]),
                target_tokens.reshape(-1)
            )

            total_val_loss += loss.item()

            # Calculate accuracy
            preds = torch.argmax(logits, dim=-1)
            correct = (preds == target_tokens).float()
            total_val_correct += correct.sum().item()
            total_val_tokens += target_tokens.numel()

    avg_train_loss = total_train_loss / len(train_loader)
    avg_val_loss = total_val_loss / len(val_loader)
    train_accuracy = total_train_correct / total_train_tokens
    val_accuracy = total_val_correct / total_val_tokens

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    torch.save(model.state_dict(), f"../models/model_weights_0{epoch}.pth")

    print(f"Epoch {epoch + 1}, "
          f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
          f"Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")


Training Epoch 24: 100%|██████████| 1000/1000 [17:54<00:00,  1.07s/it]
Validating: 100%|██████████| 200/200 [01:53<00:00,  1.76it/s]


Epoch 24, Train Loss: 0.3553, Val Loss: 0.3077, Train Acc: 0.8887, Val Acc: 0.9055


Training Epoch 25: 100%|██████████| 1000/1000 [18:07<00:00,  1.09s/it]
Validating: 100%|██████████| 200/200 [01:39<00:00,  2.02it/s]


Epoch 25, Train Loss: 0.3111, Val Loss: 0.2604, Train Acc: 0.9036, Val Acc: 0.9212


In [1]:
val_acc = [0.8739, 0.8879, 0.9021, 0.9004, 0.9212]
train_acc = [0.8456, 0.8663, 0.8772, 0.8773, 0.9036]
train_loss = [0.5003, 0.4344, 0.4007, 0.3974, 0.3111]
val_loss = [0.4197, 0.3763, 0.3406, 0.3411, 0.2604]

In [66]:
tokenizer.convert_ids_to_tokens(16)

'1'

In [32]:
embed_query

tensor([[[ 2.3917e+00,  1.2132e+00, -1.8796e-01,  ..., -8.5305e+00,
          -1.1822e+01, -4.7525e-03],
         [ 2.0638e-01, -4.2831e+00, -1.1676e+00,  ..., -2.1979e+00,
          -1.1188e+00, -3.0773e+00],
         [-2.8330e+00, -8.4577e-01, -1.0637e+00,  ..., -2.3205e+00,
           2.6245e+00, -8.2585e-01],
         ...,
         [-1.7195e+00,  1.7569e+00, -7.6966e-01,  ..., -1.3342e+00,
          -2.3326e-01, -1.9689e+00],
         [-6.4266e+00, -5.7917e+00, -7.3048e-01,  ...,  2.4106e+00,
          -2.4566e+00, -7.1773e+00],
         [-1.5489e+00,  2.2582e+00, -8.7037e-01,  ...,  2.0526e+00,
           1.2427e+00, -2.1785e+00]]], device='cuda:0',
       grad_fn=<ToCopyBackward0>)

In [33]:
target

tensor([[[  0.8472,  -4.0913,   0.2632,  ...,  -4.2332, -10.3623,   0.6719]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>)