In [1]:
# !pip install transformers torch pandas matplotlib peft

In [2]:
import json
import math
import random
import torch



import os

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from tqdm.auto import tqdm


import matplotlib.pyplot as plt
import pandas as pd


In [3]:
from string import Template
prompt_q_without_contex_train= Template('''Instruct: Youre a Medical Question Answering Expert, answer the following question. Please generate only answer choice (1, 2, 3 or 4)\n: 
$question
$options
''')


prompt_without_contex_train= Template('''Instruct: Youre a Medical Question Answering Expert, answer the following question. Please generate only answer choice (1, 2, 3 or 4)\n
$question                                           
$options
Output: option ''')
# prompt_without_context = f'Hello {planet}'

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda"
torch.set_default_device(device)
# BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
BASE_MODEL_ID = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
class MyLLMDataloader:
    def __init__(self, batch_size, tokenizer, data, shuffle = False, val= False):
        ## initializations
        self.batch_size  = batch_size
        self.tokenizer  = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.data = []
        self.val = val
        with open(data, "r") as f:
            test_data = f.readlines()
        if self.val:
            for line in test_data:
                self.data.append(json.loads(line))
        else:
             for line in test_data[:30000]:
                self.data.append(json.loads(line))           
        # self.all_examples = list(self.data.keys())
        self.shuffle = shuffle
        # self.val = val
        
        self.n_data_points = math.ceil(len(self.data)/self.batch_size)
        print("datapoints",self.n_data_points , len(self.data))
        self.indices = [i for i in range(self.n_data_points)]
        
    def __getitem__(self, idx):
        ## this gets a batch 
        option_header = ["option A ", "option B ", "option C ", "option D ", "option E "]
        batch_start_id = idx * self.batch_size
        batch_end_id  = min(len(self.data), batch_start_id + self.batch_size) 
        batch = {"question_context":[], "answer":[]}
        mapper_ans = {"a":1, "b":2, "c":3, "d":4, "e":5}
        mapper_ans_inverted = {1:"a", 2:"b", 3:"c", 4:"d", 5:"e"}
        for i in range(batch_start_id, batch_end_id):
            example = self.data[i]
            options = []
            for key in example.keys():
                if key.startswith("op"):
                    options.append(example[key])

            correct_option_id = example["cop"]
            correct_option_txt =example["op"+mapper_ans_inverted[correct_option_id]]
            explanation = ''
            if example["exp"] != None:
                explanation  = example["exp"]
            # correct_option_txt += '\n' + 'explanation: ' + explanation
            # correct_option_id = int(example["answer"].split("option ")[1][0])  - 1
           

            if self.shuffle:
                random.shuffle(options)
                correct_option_id = options.index(correct_option_txt) +1
            
                options_with_header = [option_header[i] +options[i] for i in range(len(options)) ]
                options_with_header = "\n".join(options_with_header)
                # correct_option_txt = example["answer"].split(": ")[1]
                # correct_option_txt = example["answer"][10:]

                correct_option_txt_header = mapper_ans_inverted[correct_option_id].upper() +" " + correct_option_txt

                prompt = prompt_without_contex_train.substitute(question = example["question"],options =options_with_header)

            else:
                correct_option_id = correct_option_id
                correct_option_txt_header = mapper_ans_inverted[correct_option_id].upper() +" " + correct_option_txt
                options_with_header = [option_header[i] +options[i] for i in range(len(options)) ]
                options_with_header = "\n".join(options_with_header)
                prompt = prompt_without_contex_train.substitute(question = example["question"] , options =options_with_header)
            ## context TBD
            # answer =  f"{correct_option_txt_header}  \nExplanation: {example['explanation']}"
            if not self.val:
                answer =  f"{correct_option_txt_header}  \nExplanation: {explanation}"
            else:
                answer =  f"{correct_option_txt_header}"



            # print(prompt, answer)
            batch["question_context"] += [prompt]

            batch["answer"] += [answer]

        self.tokenizer.padding_side = "left"
        q_tokens = self.tokenizer(batch["question_context"], padding="longest", return_tensors="pt")  
        self.tokenizer.padding_side = "right"
        a_tokens = self.tokenizer(batch["answer"], padding="longest", return_tensors="pt")
        tokens = torch.cat([q_tokens["input_ids"], a_tokens["input_ids"]], dim=1)
        attn_masks = torch.cat([q_tokens["attention_mask"], a_tokens["attention_mask"]], dim=1)
        loss_mask = torch.cat([torch.zeros_like(q_tokens["attention_mask"]), a_tokens["attention_mask"]], dim=1)[:,1:]
   
        result = {
        "inp_ids":tokens[:,:-1],
        "inp_mask":attn_masks[:,:-1],## Causal Training
        "out_ids":tokens[:,1:], ## Causal Labels
        "out_mask":attn_masks[:,1:],
        "q_tokens": q_tokens,
        "a_tokens": a_tokens,}
        result["loss_mask"] = loss_mask * result["out_mask"]
        # result["out_ids"][:,:q_tokens["input_ids"].size(1)-10] = self.tokenizer.eos_token_id

        return result       


            

    def __iter__(self):
        self.idx = 0
        return self

    def __next__(self):
        if self.idx >= self.n_data_points:
            self.idx = 0
            raise StopIteration
        temp_idx = self.indices[self.idx]
        self.idx += 1
        return self[temp_idx]
             








            




    
    def __len__(self):
        return self.n_data_points
    



        

In [6]:
train_data = MyLLMDataloader(4, tokenizer, "train.json", val=False)
for item in train_data:
    # print(item)
    break
# valLoader = MyLLMDataloader(1, tokenizer, "questions_365_val.json", val=True, shuffle=False)
#     # if int(tokenizer.decode(item["a_tokens"].input_ids[:,0], skip_special_tokens=True)) ==0:
#     print(tokenizer.decode(item['inp_ids'][0])),
#     print(tokenizer.decode(item["a_tokens"].input_ids[:,0], skip_special_tokens=True))
#     break

datapoints 7500 30000


In [7]:
train_data = MyLLMDataloader(4, tokenizer, "dev.json", shuffle=False)
for item in train_data:
    # print(item)
    break

datapoints 1046 4183


In [8]:
def forward_pass(model, batch):
    inp_ids = batch["inp_ids"].to(model.device)
    attn_mask = batch["inp_mask"].to(model.device)
    result = model(input_ids=inp_ids, attention_mask=attn_mask)
    logits = result.logits
    return logits


def calc_loss(loss_fn, logits, batch):
    B, L, C = logits.shape
    target = batch["out_ids"].to(logits.device)
    mask = batch["loss_mask"].to(logits.device)
    loss = loss_fn(logits.reshape(-1, C), target.reshape(-1)) * mask.reshape(-1)
    loss = loss.sum()/mask.sum()
    return loss

def update(model, optimizer, loss_fn, batch, accumulate_grad=True):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = forward_pass(model, batch)
        loss = calc_loss(loss_fn, logits, batch)
    
    scaler.scale(loss).backward()
    if not accumulate_grad:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    return loss.item(), logits




def train_on_batches(model, train_data_loader, optimizer, loss_fn, grad_accum_bs=16):
    train_loss = 0
    score = 0
    num_correct = 0
    total_num = 0
    mapper_ans = {"A":1, "B":2, "C":3, "D":4, "E":5}

    pbar = tqdm(range(len(train_data_loader)))
    option_ids = [tokenizer(o).input_ids[0] for o in ["A", "B", "C", "D", "E"]]
    model.train()
    for i,batch in enumerate(train_data_loader):
        grad_step_bs = grad_accum_bs/train_data_loader.batch_size
        if (i>=grad_step_bs and i%grad_step_bs == 0) or i == len(train_data_loader)-1:
            accumulate_grad = False
        else:
            accumulate_grad = True
        loss, logits = update(model, optimizer, loss_fn, batch, accumulate_grad=accumulate_grad)
        train_loss += (1/(i+1))*(loss-train_loss)
        # if i% 50 == 0:
        #     print(logits.shape, logits[:,batch["q_tokens"].input_ids.size(1)-1,option_ids].shape, logits[ :,batch["q_tokens"].input_ids.size(1)-1:  ].shape, train_data_loader.tokenizer.batch_decode(logits[ :,batch["q_tokens"].input_ids.size(1)-1:  ].argmax(dim=2)))
        pred = (logits[:,batch["q_tokens"].input_ids.size(1)-1,option_ids].argmax(dim=1) + 1).tolist()
        target = train_data_loader.tokenizer.batch_decode(batch["a_tokens"].input_ids[:,0], skip_special_tokens=True)
        for p,t in zip(pred, target):
            total_num += 1
            # print('p', p, 't',t)
            t = mapper_ans[t]

            num_correct += (int(p) == int(t))
            score = num_correct/total_num
        

        pbar.set_description(f"Train Loss: {train_loss:.4f} Score: {score:.4f}")
        pbar.update(1)
    pbar.close()

    return train_loss, score



In [9]:
def train(model, tdataloader,vdataloader, optimizer, loss_fn, scheduler, epochs=3, log_dir="logs/medmcq_phi"):
    os.makedirs(log_dir, exist_ok=True)
    with open(f"{log_dir}/train.txt", "w") as tf, open(f"{log_dir}/val.txt", "w") as vf:
        tf.write(f"Epoch,Loss,Score\n")
        vf.write("Epoch,Loss,Score\n")
    best_val_score = 0
    
    for epoch in range(1,1+epochs):
        train_loss, tscore = train_on_batches(model, tdataloader, optimizer, loss_fn, grad_accum_bs=16)
        scheduler.step()
        val_loss, score = validation(model, vdataloader, loss_fn)
        with open(f"{log_dir}/train.txt", "a+") as tf, open(f"{log_dir}/val.txt", "a+") as vf:
            tf.write(f"{epoch},{train_loss},{tscore}\n")
            vf.write(f"{epoch},{val_loss},{score}\n")
        lr = optimizer.param_groups[0]["lr"]
        tqdm.write(f"Epoch: {epoch} | LR: {lr:.7f} | Train Loss: {train_loss:.4f} | Train Score: {tscore:.4f} | Val Loss: {val_loss:.4f} | Val Score: {score:.4f}")
        
        if score > best_val_score:
            model.save_pretrained(log_dir+"/model")
            best_val_score = score

        # update the learning rate
        # for param_group in optimizer.param_groups:
        #     param_group['lr'] = 1.0*float(param_group['lr'])

        train_df = pd.read_csv(f"{log_dir}/train.txt")
        val_df = pd.read_csv(f"{log_dir}/val.txt")

        fig, ax = plt.subplots(1, 2, figsize=(20,8))
        ax[0].plot(range(1,len(train_df)+1), train_df["Loss"], label="Train")
        ax[0].plot(range(1,len(val_df)+1), val_df["Loss"], label="Val")

        ax[1].plot(range(1,len(train_df)+1), train_df["Score"], label="Train")
        ax[1].plot(range(1,len(val_df)+1), val_df["Score"], label="Val")

        ax[0].set_xlabel("Epochs")
        ax[1].set_xlabel("Epochs")

        ax[0].set_ylabel("Loss")
        ax[1].set_ylabel("Score")

        ax[0].legend()
        ax[1].legend()

        fig.suptitle('', fontsize=16)
        fig.savefig("plt.png")


def validation(model, val_data_loader, loss_fn):
    val_loss = 0
    mapper_ans = {"A":1, "B":2, "C":3, "D":4, "E":5}

    score = 0
    num_correct = 0
    pbar = tqdm(range(len(val_data_loader)))
    option_ids = [tokenizer(o).input_ids[0] for o in ["A", "B", "C", "D", "E"]]
    total_num = 0
    model.eval()
    with torch.inference_mode():
        for i, batch in enumerate(val_data_loader):
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                logits = forward_pass(model, batch)
                loss = calc_loss(loss_fn, logits, batch).item()
            val_loss += (1/(i+1))*(loss - val_loss)
            pred = (logits[:,batch["q_tokens"].input_ids.size(1)-1,option_ids].argmax(dim=1) + 1).tolist()
            target = val_data_loader.tokenizer.batch_decode(batch["a_tokens"].input_ids[:,0], skip_special_tokens=True)
            for p,t in zip(pred, target):
             
                total_num += 1
                t = mapper_ans[t]
                num_correct += (int(p) == int(t))
                score = num_correct/total_num
                
            pbar.set_description(f"Val Loss: {val_loss:.4f} Score: {score:.4f}")
            pbar.update(1)
            
        gen_tokens = model.generate(**batch["q_tokens"].to(model.device), max_new_tokens=10)[:,batch["q_tokens"].input_ids.size(1):]
        gen_txt = val_data_loader.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
        target_txt = val_data_loader.tokenizer.batch_decode(batch["a_tokens"].input_ids, skip_special_tokens=True)
        tqdm.write("Target: " + "\n" + "\n".join(target_txt) + "\nGenerated: \n " + "\n".join(gen_txt) +"\n")
        pbar.close()

    return val_loss, score


In [10]:
# device = "cuda"
# torch.set_default_device(device)
# tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, trust_remote_code=True, device_map ="auto")
epochs = 10
lr = 1e-4
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,  # reduce if running into out-of-memory issues
    lora_alpha=16,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'dense'],
    # modules_to_save=["lm_head"],
    lora_dropout=0.05,
)
model = get_peft_model(model, peft_config)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.eos_token_id, reduction='none').cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay= 0.001)
optimizer.zero_grad()


# trainLoader = TrainDataLoader(2, tokenizer, topk=topk)
trainLoader = MyLLMDataloader(2, tokenizer, "train.json", val=False, shuffle=True)

valLoader = MyLLMDataloader(1, tokenizer, "dev.json", val=True)
scaler =  torch.cuda.amp.GradScaler(enabled=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
            T_max = epochs*2, eta_min=1E-8)
train(model, trainLoader, valLoader,optimizer, loss_fn, scheduler, epochs=10)



datapoints 15000 30000
datapoints 4183 4183


  scaler =  torch.cuda.amp.GradScaler(enabled=True)


  0%|          | 0/15000 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (2055 > 2048). Running this sequence through the model will result in indexing errors


OutOfMemoryError: CUDA out of memory. Tried to allocate 168.00 MiB. GPU 0 has a total capacity of 47.71 GiB of which 118.56 MiB is free. Including non-PyTorch memory, this process has 42.59 GiB memory in use. Of the allocated memory 40.24 GiB is allocated by PyTorch, and 1.84 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
torch.cuda.empty_cache()
import gc
gc.collect()

In [13]:
# train_loader = torch.utils.data.DataLoader(train_data, num_workers=2, batch_size =1,)
