In [1]:
from datasets import load_dataset
import pandas as pd
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, Dataset 
import torch
import torch.nn as nn
from tqdm import tqdm
import csv
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data = load_dataset("antareepdey/Patient_doctor_chat")
datasets = [i['Text'] for i in data['train']]
patient_query = []
doctor_response = []
patient_query_val = []
doctor_response_val = []

for i in range(1000):
    inputs, outputs = datasets[i].split("###Output:")
    patient_query.append(inputs.replace("###Input:",""))
    doctor_response.append(outputs)
for i in range(1000,1200):
    inputs, outputs = datasets[i].split("###Output:")
    patient_query_val.append(inputs.replace("###Input:",""))
    doctor_response_val.append(outputs)

data = {"Patient query": patient_query, "Doctor response": doctor_response}
val_data = {"Patient query": patient_query_val, "Doctor response": doctor_response_val}
df = pd.DataFrame(data=data)
val_df = pd.DataFrame(data=val_data)
def preprocess_data(text):
    # preprocess
    text = text.lower()
    # text = text.replace('?','')
    # text = text.replace("'","")
    # text = text.replace("1)"," ")
    # text = text.replace("2)"," ")
    # text = text.replace("3)"," ")
    # text = text.replace("4)"," ")
    # text = text.replace("."," ")
    text = text.strip()
    return text

df['Patient query'] = df['Patient query'].apply(preprocess_data)
df['Doctor response'] = df['Doctor response'].apply(preprocess_data)

val_df['Patient query'] = val_df['Patient query'].apply(preprocess_data)
val_df['Doctor response'] = val_df['Doctor response'].apply(preprocess_data)

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
embed_model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')


# make own tokens 

In [3]:
all_text_data = patient_query + doctor_response + patient_query_val + doctor_response_val

In [4]:
len(all_text_data)

2400

In [5]:
used_token_ids = set()
for target_text in all_text_data:
    tokens = tokenizer.encode(target_text, add_special_tokens=False)
    used_token_ids.update(tokens)

In [6]:
tokenizer.add_special_tokens({
    'bos_token': '<sos>'
})

1

In [7]:
special_ids = [tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.bos_token_id]
used_token_ids.update([tid for tid in special_ids if tid is not None])

In [8]:
token_id_list = sorted(list(used_token_ids))
old2new = {old: new for new, old in enumerate(token_id_list)}
new2old = {v: k for k, v in old2new.items()}

#custom dataset

In [9]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.max_length = 8192

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        question = df.iloc[index]['Patient query']
        answer = df.iloc[index]['Doctor response']

        question_ids = tokenizer(
            question,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False,
            return_tensors="pt"
        )['input_ids'][0] 

        answer_ids = tokenizer(
            answer,
            truncation=True,
            max_length=self.max_length,
             add_special_tokens=False,
            return_tensors="pt"
        )['input_ids'][0]

        return question_ids, answer_ids


In [4]:
def collate_fn(batch):
    questions, answers = zip(*batch) 

    padded_questions = pad_sequence(questions, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_answers = pad_sequence(answers, batch_first=True, padding_value=tokenizer.pad_token_id)

    return padded_questions, padded_answers


In [10]:
dataset = CustomDataset(df) 
val_dataset = CustomDataset(val_df)

In [11]:
train_loader = DataLoader( dataset, batch_size=1, shuffle=True)
val_loader = DataLoader( val_dataset, batch_size=1,  shuffle=True)

In [12]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class EncoderDecoderWithKBAttention(nn.Module):
    def __init__(self, embed_model, vocab_size, hidden_dim, embedding_dim=1024):
        super().__init__()
        self.embed_model = embed_model
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.decoder = nn.LSTMCell(embedding_dim, hidden_dim * 2)

        # Attention over encoder outputs
        self.W_denc1 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.W_denc2 = nn.Linear(hidden_dim, hidden_dim)
        self.w = nn.Linear(hidden_dim, 1)

        # Attention over KB keys (commented out, unused)
        # self.W_kb1 = nn.Linear(hidden_dim + embedding_dim, hidden_dim)
        # self.W_kb2 = nn.Linear(hidden_dim, hidden_dim)
        # self.r = nn.Linear(hidden_dim, 1)

        # Final vocab projection
        self.U = nn.Linear(hidden_dim * 4, vocab_size)

        self.vocab_size = vocab_size

    def decoder_attention(self, decoder_hidden, encoder_hidden):
        B, T, H = encoder_hidden.shape  # B=batch, T=encoder seq len, H=hidden size
        decoder_exp = decoder_hidden.unsqueeze(1).expand(-1, T, -1)  # (B, T, H)
        combined = torch.cat([encoder_hidden, decoder_exp], dim=2)  # (B, T, 2H)
        x = torch.tanh(self.W_denc1(combined))                      # (B, T, H)
        x = torch.tanh(self.W_denc2(x))                             # (B, T, H)
        u = self.w(x).squeeze(-1)                                   # (B, T)
        attn = torch.softmax(u, dim=1)                              # (B, T)

        context = torch.bmm(attn.unsqueeze(1), encoder_hidden).squeeze(1)  # (B, H)
        concat = torch.cat([decoder_hidden, context], dim=1)          # (B, 2H)
        vocab_logits = self.U(concat)                                 # (B, vocab_size)
        return vocab_logits

    def forward(self, inputs, targets, inference=False, teacher_forcing_ratio=1.0, kb_keys=None, fine_tune=False):
    

        batch_size, T_out, _ = targets.shape

        enc_out, (h_enc, c_enc) = self.encoder(inputs)  # enc_out: (B, T_in, H)
        hidden = torch.cat([h_enc[0], h_enc[-1]], dim=-1)  # (B, hidden_dim*2)
        cell = torch.cat([c_enc[0], c_enc[-1]], dim=-1)    # (B, hidden_dim*2)

        logits = []

        dec_input = targets[:, 0, :]  # Initial decoder input (usually <sos> embedding)

        if inference:
            # For inference, decode only one step here
            hidden_state, cell_state = self.decoder(dec_input, (hidden, cell))
            hidden_logits = self.decoder_attention(hidden_state, enc_out)
            logits.append(hidden_logits)
            return torch.stack(logits, dim=1)

        else:
            for t in range(1, T_out):
                hidden_state, cell_state = self.decoder(dec_input, (hidden, cell))
                hidden = hidden_state
                cell = cell_state

                hidden_logits = self.decoder_attention(hidden_state, enc_out)
                logits.append(hidden_logits)

                # Decide whether to do teacher forcing this step
                use_teacher_forcing = (torch.rand(1).item() < teacher_forcing_ratio)

                if use_teacher_forcing:
                    # Use ground-truth target embedding for next input
                    dec_input = targets[:, t, :]
                else:
                    # Use model prediction:
                    pred_tokens = torch.argmax(hidden_logits, dim=1)  # (B,)
                    with torch.no_grad():
                        # Get embeddings for predicted tokens
                        # Assuming embed_model accepts token IDs and returns embeddings
                        dec_input = self.embed_model(pred_tokens)['last_hidden_state'][:, 0, :]

        return torch.stack(logits, dim=1)  # (B, T_out-1, vocab_size)


In [14]:
vocab_size = len(token_id_list)
model = EncoderDecoderWithKBAttention(embed_model=embed_model, vocab_size=vocab_size, hidden_dim=320, embedding_dim=1024).to(device)

In [15]:
maxlength = 8192
start_token = tokenizer(
        '<|im_start|>',
        max_length= maxlength,
        return_tensors="pt",
    )['input_ids']

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [9]:
start_token

tensor([[151644, 151643]])

In [15]:
embed_model

Qwen3Model(
  (embed_tokens): Embedding(151669, 1024)
  (layers): ModuleList(
    (0-27): 28 x Qwen3DecoderLayer(
      (self_attn): Qwen3Attention(
        (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
        (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
        (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
      )
      (mlp): Qwen3MLP(
        (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
        (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
        (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
      (post_attention_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
    )
  )
  (norm): Qwen3RMSNorm((102

In [16]:
new_embed_model =  AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')

In [18]:
for input, target in train_loader:
    input = new_embed_model(input)['last_hidden_state']
    #print(input.shape)
    #print(target.shape)
    batch_size = target.shape[0]
    target = torch.cat([start_token[:,0].repeat(batch_size, 1),target] , dim=1)
    #print(target.shape)
    outputs = new_embed_model(target)['last_hidden_state']
    #print(outputs)
    #print(outputs.shape)
    #print(input.shape)
    #print(input)
    input = input.to(device)
    outputs = outputs.to(device)
    target = target.to(device)
    logits =  model(input, outputs)
    # #print(target[0][1:].unsqueeze(0).shape)
    print(logits.shape)
    print(target.shape)
    break
    

torch.Size([8, 173, 151669])
torch.Size([8, 174])


In [17]:
learning_rate = 0.001
epochs = 25
criterion = nn.CrossEntropyLoss()
#weight_decay = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

monitor_csv = 'monitor.csv'

# Write CSV header
with open(monitor_csv, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['epoch', 'train_loss', 'val_loss', 'train_accuracy', 'val_accuracy'])

for epoch in range(0, epochs):
    model.train()
    total_train_loss = 0.0
    total_train_correct = 0
    total_train_tokens = 0

    for input_texts, target_tokens in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        batch_size = target_tokens.size(0)

        # Prepare inputs
        input_texts = input_texts[:, :-1]
        target_with_bos = torch.cat([start_token[:, 0].repeat(batch_size, 1), target_tokens], dim=1)

        with torch.no_grad():
            input_embeds = new_embed_model(input_texts)['last_hidden_state']
            target_embeds = new_embed_model(target_with_bos)['last_hidden_state']

        input_embeds = input_embeds.to(device)
        target_embeds = target_embeds.to(device)
        target_tokens = target_tokens.to(device)

        # Map target_tokens to reduced vocab IDs
        target_tokens_np = target_tokens.cpu().numpy()
        target_reduced_np = np.vectorize(lambda x: old2new.get(x, 0))(target_tokens_np)
        target_reduced = torch.tensor(target_reduced_np, dtype=torch.long, device=target_tokens.device)

        optimizer.zero_grad()
        logits = model(input_embeds, target_embeds)

        loss = criterion(
            logits.view(-1, logits.shape[-1]),
            target_reduced.reshape(-1)
        )
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

        preds = torch.argmax(logits, dim=-1)
        correct = (preds == target_reduced).float()
        total_train_correct += correct.sum().item()
        total_train_tokens += target_reduced.numel()

    avg_train_loss = total_train_loss / len(train_loader)
    train_accuracy = total_train_correct / total_train_tokens

    model.eval()
    total_val_loss = 0.0
    total_val_correct = 0
    total_val_tokens = 0

    with torch.no_grad():
        for input_texts, target_tokens in tqdm(val_loader, desc="Validating"):
            batch_size = target_tokens.size(0)
            target_with_bos = torch.cat([start_token[:, 0].repeat(batch_size, 1), target_tokens], dim=1)

            input_embeds = new_embed_model(input_texts.to(device))['last_hidden_state']
            target_embeds = new_embed_model(target_with_bos.to(device))['last_hidden_state']

            input_embeds = input_embeds.to(device)
            target_embeds = target_embeds.to(device)
            target_tokens = target_tokens.to(device)

            # Map target tokens to reduced vocab IDs
            target_tokens_np = target_tokens.cpu().numpy()
            target_reduced_np = np.vectorize(lambda x: old2new.get(x, 0))(target_tokens_np)
            target_reduced = torch.tensor(target_reduced_np, dtype=torch.long, device=target_tokens.device)

            logits = model(input_embeds, target_embeds)

            loss = criterion(
                logits.view(-1, logits.shape[-1]),
                target_reduced.reshape(-1)
            )
            total_val_loss += loss.item()

            preds = torch.argmax(logits, dim=-1)
            correct = (preds == target_reduced).float()
            total_val_correct += correct.sum().item()
            total_val_tokens += target_reduced.numel()

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = total_val_correct / total_val_tokens

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    # Save model checkpoint
    torch.save(model.state_dict(), f"../new_models/model_weights_{epoch:02d}.pth")

    # Save metrics to CSV
    with open(monitor_csv, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch + 1, avg_train_loss, avg_val_loss, train_accuracy, val_accuracy])

    print(f"Epoch {epoch + 1}, "
          f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
          f"Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")


Training Epoch 1:  19%|█▉        | 191/1000 [01:36<06:20,  2.12it/s]

In [14]:
maxlength = 8192

In [20]:
input = "Hello doctor,I have infection of the scrotal hair. At the hair root, there is some white bump. Please suggest medicine for the problem."
input = preprocess_data(input)
start_token = tokenizer(
        '<|im_start|>',
        max_length= maxlength,
        return_tensors="pt",
    )['input_ids'][:,0].unsqueeze(0).to(device)
input_tokens = tokenizer(
        input,
        max_length= maxlength,
        return_tensors="pt",
    )['input_ids'][:, :-1].to(device)

print(input_tokens)
embed_model = embed_model.to(device)

embed_query = embed_model(input_tokens)['last_hidden_state'].to(device)
target = embed_model(start_token)['last_hidden_state'].to(device)


tensor([[14990, 10668,   600,   614, 18873,   315,   279,  1136,  4640,   278,
          6869,   220,   518,   279,  6869,  3704,   220,  1052,   374,  1045,
          4158, 27575,   220,  4486,  4190, 15712,   369,   279,  3491]],
       device='cuda:0')


In [21]:
state_dict = torch.load("../models/model_weights_024.pth")
load_result = model.load_state_dict(state_dict)

In [12]:
load_result

<All keys matched successfully>

In [17]:
target.shape

torch.Size([1, 1, 1024])

In [31]:
start_token.item()

151644

In [23]:
model.eval()
tokens = []
token = start_token  # integer token ID for <sos>

with torch.no_grad():
    # 1. Encode input sequence once
    enc_out, (h_enc, c_enc) = model.encoder(embed_query)
    hidden = torch.cat([h_enc[0], h_enc[1]], dim=-1)  # (B, 2H)
    cell   = torch.cat([c_enc[0], c_enc[1]], dim=-1)  # (B, 2H)

    # 2. Start with <sos> token embedding
    input_emb = embed_model(torch.tensor([[token]]).to(device))['last_hidden_state'][:, 0, :]  # (1, E)

    for i in range(20):  # max decoding steps
        # One LSTMCell step
        hidden, cell = model.decoder(input_emb, (hidden, cell))

        # Attention over encoder outputs
        vocab_logits = model.decoder_attention(hidden, enc_out)

        # Predict next token
        pred_token = torch.argmax(vocab_logits, dim=1).item()
        tokens.append(pred_token)

        # Stop if EOS token predicted
        if pred_token == 151644:
            break

        # Embed predicted token for next step
        input_emb = embed_model(torch.tensor([[pred_token]]).to(device))['last_hidden_state'][:, 0, :]


NameError: name 'embed_query' is not defined

In [23]:
tokens

[6023,
 1850,
 2518,
 1850,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518,
 2518]

In [24]:
tokens = torch.tensor(tokens)

In [25]:
tokenizer.convert_ids_to_tokens(tokens)

['hi',
 'Ġbest',
 'Ġred',
 'Ġbest',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred',
 'Ġred']

In [22]:
token = torch.argmax(logits[0], dim=1)

In [33]:
new_list = []

In [36]:
new_list.append(token.item())

In [37]:
new_list

[tensor([9330], device='cuda:0'), 9330]

In [25]:
embed_model(token.unsqueeze(0))

BaseModelOutputWithPast(last_hidden_state=tensor([[[  2.3688, -15.1576,  -0.0910,  ...,  -7.9240, -11.6597,   0.9951]]],
       device='cuda:0', grad_fn=<MulBackward0>), past_key_values=<transformers.cache_utils.DynamicCache object at 0x00000201FF880190>, hidden_states=None, attentions=None)

In [21]:
tokenizer.convert_ids_to_tokens(9330)

'you'

In [73]:
load_result

<All keys matched successfully>

In [93]:
start_token

tensor([[151644, 151643]])

In [94]:
target

tensor([[[  0.8472,  -4.0913,   0.2632,  ...,  -4.2332, -10.3624,   0.6719]]],
       device='cuda:0', grad_fn=<MulBackward0>)

Training Epoch 1:  12%|█▏        | 124/1000 [01:02<07:25,  1.97it/s]

In [1]:
val_acc = [0.8739, 0.8879, 0.9021, 0.9004, 0.9212]
train_acc = [0.8456, 0.8663, 0.8772, 0.8773, 0.9036]
train_loss = [0.5003, 0.4344, 0.4007, 0.3974, 0.3111]
val_loss = [0.4197, 0.3763, 0.3406, 0.3411, 0.2604]

In [66]:
tokenizer.convert_ids_to_tokens(16)

'1'

In [32]:
embed_query

tensor([[[ 2.3917e+00,  1.2132e+00, -1.8796e-01,  ..., -8.5305e+00,
          -1.1822e+01, -4.7525e-03],
         [ 2.0638e-01, -4.2831e+00, -1.1676e+00,  ..., -2.1979e+00,
          -1.1188e+00, -3.0773e+00],
         [-2.8330e+00, -8.4577e-01, -1.0637e+00,  ..., -2.3205e+00,
           2.6245e+00, -8.2585e-01],
         ...,
         [-1.7195e+00,  1.7569e+00, -7.6966e-01,  ..., -1.3342e+00,
          -2.3326e-01, -1.9689e+00],
         [-6.4266e+00, -5.7917e+00, -7.3048e-01,  ...,  2.4106e+00,
          -2.4566e+00, -7.1773e+00],
         [-1.5489e+00,  2.2582e+00, -8.7037e-01,  ...,  2.0526e+00,
           1.2427e+00, -2.1785e+00]]], device='cuda:0',
       grad_fn=<ToCopyBackward0>)

In [33]:
target

tensor([[[  0.8472,  -4.0913,   0.2632,  ...,  -4.2332, -10.3623,   0.6719]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>)