In [1]:
from datasets import load_dataset
import pandas as pd
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, Dataset 
import torch
import torch.nn as nn
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data = load_dataset("antareepdey/Patient_doctor_chat")
datasets = [i['Text'] for i in data['train']]
patient_query = []
doctor_response = []

for i in range(1000):
    inputs, outputs = datasets[i].split("###Output:")
    patient_query.append(inputs.replace("###Input:",""))
    doctor_response.append(outputs)
data = {"Patient query": patient_query, "Doctor response": doctor_response}
df = pd.DataFrame(data=data)
def preprocess_data(text):
    # preprocess
    text = text.lower()
    text = text.replace('?','')
    text = text.replace("'","")
    text = text.replace(","," ")
    text = text.replace("1)"," ")
    text = text.replace("2)"," ")
    text = text.replace("3)"," ")
    text = text.replace("4)"," ")
    text = text.replace("."," ")
    text = text.strip()
    return text

df['Patient query'] = df['Patient query'].apply(preprocess_data)
df['Doctor response'] = df['Doctor response'].apply(preprocess_data)

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
embed_model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')



In [3]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.max_length = 8192

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        question = df.iloc[index]['Patient query']
        answer = df.iloc[index]['Doctor response']

        question_ids = tokenizer(
            question,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )['input_ids'][0] 

        answer_ids = tokenizer(
            answer,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )['input_ids'][0]

        return question_ids, answer_ids


In [4]:
def collate_fn(batch):
    questions, answers = zip(*batch) 

    padded_questions = pad_sequence(questions, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_answers = pad_sequence(answers, batch_first=True, padding_value=tokenizer.pad_token_id)

    return padded_questions, padded_answers


In [5]:
dataset = CustomDataset(df) 

In [6]:
train_loader = DataLoader( dataset, batch_size=8, collate_fn=collate_fn, shuffle=True)

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class EncoderDecoderWithKBAttention(nn.Module):
    def __init__(self, vocab_size, hidden_dim, embedding_dim=1024):
        super().__init__()
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.decoder = nn.LSTMCell(embedding_dim, hidden_dim*2)

        # Attention over encoder outputs
        self.W_denc1 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.W_denc2 = nn.Linear(hidden_dim, hidden_dim)
        self.w = nn.Linear(hidden_dim, 1)

        # Attention over KB keys
        self.W_kb1 = nn.Linear(hidden_dim + embedding_dim, hidden_dim)
        self.W_kb2 = nn.Linear(hidden_dim, hidden_dim)
        self.r = nn.Linear(hidden_dim, 1)

        # Final vocab projection
        self.U = nn.Linear(hidden_dim * 4, vocab_size)

        self.vocab_size = vocab_size
    

    def decoder_attention(self, decoder_hidden, encoder_hidden):
        #print(encoder_hidden.shape)
        B, T, H = encoder_hidden.shape
        decoder_exp = decoder_hidden.unsqueeze(1).expand(-1, T, -1)     # (B, T, H)
        #print(decoder_exp.shape)
        combined = torch.cat([encoder_hidden, decoder_exp], dim=2)     # (B, T, 2H)
        #print(combined.shape)
        x = torch.tanh(self.W_denc1(combined))                          # (B, T, H)
        #print(x.shape)
        x = torch.tanh(self.W_denc2(x))                                 # (B, T, H)
        #print(x.shape)
        u = self.w(x).squeeze(-1)                                       # (B, T)
        #print(u.shape)
        attn = torch.softmax(u, dim=1)                                  # (B, T)
        #print(attn.shape)

        context = torch.bmm(attn.unsqueeze(1), encoder_hidden).squeeze(1)  # (B, H)
        #print(context.shape)
        concat = torch.cat([decoder_hidden, context], dim=1)           # (B, 2H)
        #print(concat.shape)
        vocab_logits = self.U(concat)                                   # (B, vocab_size)
        return vocab_logits

    def kb_attention(self, decoder_hidden, kb_keys):
        B, H = decoder_hidden.shape
        num_kb = kb_keys.shape[0]

        decoder_exp = decoder_hidden.unsqueeze(1).expand(-1, num_kb, -1)        # (B, num_kb, H)
        kb_keys_exp = kb_keys.squeeze(1).unsqueeze(0).expand(B, -1, -1)         # (B, num_kb, E)

        combined = torch.cat([kb_keys_exp, decoder_exp], dim=-1)                # (B, num_kb, H+E)
        x = torch.tanh(self.W_kb1(combined))                                    # (B, num_kb, H)
        
        x = torch.tanh(self.W_kb2(x))                                           # (B, num_kb, H)
        scores = self.r(x).squeeze(-1)                                          # (B, num_kb)
        return scores

    def forward(self, inputs, targets, kb_keys=None, teacher_forcing=True, fine_tune=False):
        #print(targets.shape)
        batch, T_out , _ = targets.shape
        
        enc_out, (h_enc, c_enc) = self.encoder(inputs)             # enc_out: (B, T_in, H)
        # print("enc out", enc_out.shape)
        # print("h_enc",h_enc.shape)
        # print("c_enc", c_enc.shape)
        #print(h_enc[0,:,:].unsqueeze(0).shape)
        hidden, cell = torch.cat([h_enc[0],h_enc[-1]], dim=-1), torch.cat([c_enc[0], c_enc[1]], dim=-1)
        #print(hidden.shape)
       
        #print("start",start_token)
        logits = []
        
        #print(targets[0].shape)
        dec_input = targets[:,0,:]                   # (B, E)

        #print("dec_input",dec_input[-1].unsqueeze(0))
        #print("hidden_cell", hidden.shape)
        for t in range(1, T_out):
            #print("incode input shape", dec_input.shape)
            hidden_logits, cell = self.decoder(dec_input, (hidden, cell))      # (B, H)
            #print("hiddne_logits", hidden_logits.shape)
            #print(enc_out.shape)
            hidden_logits = self.decoder_attention(hidden_logits, enc_out)      # (B, V)
            #print(hidden_logits)
            if fine_tune:
                kb_logits = self.kb_attention(hidden, kb_keys)              # (B, n)
                hidden_logits = torch.cat([hidden_logits, kb_logits], dim=1)  # (B, V+n)
                
            logits.append(hidden_logits)  
            #print(hidden_logits.shape)

            if teacher_forcing:
                #target = targets[:, t].unsqueeze(0)
                #dec_input = self.embedding(target)['last_hidden_state'][0]
                #print(dec_input.shape)
                dec_input = targets[:,t,:]
            else:
                pred = torch.argmax(hidden_logits, dim=1)
                pred = torch.where(pred < self.vocab_size, pred, torch.tensor(0).to(pred.device))

        return torch.stack(logits, dim=1)  # (B, T_out-1, V+n)

In [9]:
vocab_size = len(tokenizer.vocab)
model = EncoderDecoderWithKBAttention(vocab_size=vocab_size,hidden_dim=320, embedding_dim=1024).to(device)

In [10]:
maxlength = 8192
start_token = tokenizer(
        '<|im_start|>',
        max_length= maxlength,
        return_tensors="pt",
    )['input_ids']


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [11]:
embed_model

Qwen3Model(
  (embed_tokens): Embedding(151669, 1024)
  (layers): ModuleList(
    (0-27): 28 x Qwen3DecoderLayer(
      (self_attn): Qwen3Attention(
        (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
        (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
        (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
      )
      (mlp): Qwen3MLP(
        (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
        (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
        (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
      (post_attention_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
    )
  )
  (norm): Qwen3RMSNorm((102

In [12]:
for input, target in train_loader:
    input = embed_model(input)['last_hidden_state']
    #print(input.shape)
    #print(target.shape)
    batch_size = target.shape[0]
    target = torch.cat([start_token[:,0].repeat(batch_size, 1),target] , dim=1)
    #print(target.shape)
    outputs = embed_model(target)['last_hidden_state']
    #print(outputs)
    #print(outputs.shape)
    #print(input.shape)
    #print(input)
    input = input.to(device)
    outputs = outputs.to(device)
    target = target.to(device)
    logits =  model(input, outputs,device)
    # #print(target[0][1:].unsqueeze(0).shape)
    print(logits.shape)
    print(target.shape)
    break
    

torch.Size([8, 130, 151669])
torch.Size([8, 131])


In [13]:
learning_rate = 0.0001
epochs = 10
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
train_loss = []
for epoch in range(epochs):
    model.train()
    total_loss = 0

    for input, target in tqdm(train_loader, desc=f"Training Epochs - {epoch+1}"):
        
        #print(input.shape)
        #print(target.shape)
        batch_size = target.shape[0]
        target = torch.cat([start_token[:,0].repeat(batch_size, 1),target] , dim=1)
        #print(target.shape)
        with torch.no_grad():
            input = embed_model(input)['last_hidden_state']
            outputs = embed_model(target)['last_hidden_state']
        #print(outputs)
        #print(outputs.shape)
        #print(input.shape)
        #print(input)
        input = input.to(device)
        outputs = outputs.to(device)
        target = target.to(device)
        logits =  model(input, outputs,device)
        
        # input = input.to(device)
        # target = target.to(device)
        # target = torch.cat([start_token,target, end_token] , dim=1)
        # outputs = embed_model(target)['last_hidden_state'][0]
        # # print("input shape",input.shape)
        # # print("len ", len(input))
        # if input.shape[1]==0:
        #     continue
        optimizer.zero_grad()
        # logits = model(input,outputs,device)  # (B, T-1, V+K)
        target_trimmed = target[:, 1:]  
        loss = criterion(
            logits.reshape(-1, logits.shape[-1]),
            target_trimmed.reshape(-1)
        )
        loss.backward()
        optimizer.step()
        #print(loss)

        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    train_loss.append(avg_loss)
    print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

Training Epochs - 1: 100%|██████████| 125/125 [09:24<00:00,  4.52s/it]


Epoch 1, Loss: 6.0091


Training Epochs - 2: 100%|██████████| 125/125 [09:23<00:00,  4.51s/it]


Epoch 2, Loss: 4.1590


Training Epochs - 3: 100%|██████████| 125/125 [09:09<00:00,  4.40s/it]


Epoch 3, Loss: 3.6237


Training Epochs - 4: 100%|██████████| 125/125 [09:13<00:00,  4.43s/it]


Epoch 4, Loss: 3.2383


Training Epochs - 5: 100%|██████████| 125/125 [09:00<00:00,  4.33s/it]


Epoch 5, Loss: 2.9634


Training Epochs - 6: 100%|██████████| 125/125 [09:28<00:00,  4.55s/it]


Epoch 6, Loss: 2.7455


Training Epochs - 7: 100%|██████████| 125/125 [09:35<00:00,  4.60s/it]


Epoch 7, Loss: 2.5247


Training Epochs - 8: 100%|██████████| 125/125 [09:07<00:00,  4.38s/it]


Epoch 8, Loss: 2.3755


Training Epochs - 9: 100%|██████████| 125/125 [09:09<00:00,  4.39s/it]


Epoch 9, Loss: 2.2165


Training Epochs - 10:  42%|████▏     | 53/125 [03:57<05:03,  4.21s/it]

In [None]:
input = "Hi doctor,I have a small white scar on the lower lip for seven months. It was not white at first. But now it just stays there and does not show any change."
input = preprocess_data(input)
start_token = tokenizer(
        '<|im_start|>',
        max_length= maxlength,
        return_tensors="pt",
    )['input_ids']
input_tokens = tokenizer(
        input,
        max_length= maxlength,
        return_tensors="pt",
    )['input_ids']

embed_query = embed_model(input_tokens)


In [None]:
pred_text_token = []
