In [1]:
# %% Set script for interactive development and import modules
# Project helpers are imported and invoked before the remaining imports of this cell.
from RewardingVisualDoubt import infrastructure, training

# NOTE(review): judging by its name this makes IPython pick up edits to the project code
# (autoreload-style) -- confirm against the helper's implementation.
infrastructure.make_ipython_reactive_to_changing_codebase()
# NOTE(review): presumably installs warning filters for known, harmless warnings -- confirm.
infrastructure.supress_known_warnings()

import pathlib as path
import typing as t
import torch
import numpy as np
import time
import os
from torch.utils.data import DataLoader
import accelerate
import dataclasses
import transformers
import wandb

from trl import PPOConfig, PPOTrainer

from RewardingVisualDoubt import dataset, prompter, shared, vllm, response, reward
from RewardingVisualDoubt import training as training

# SECURITY: credentials must never be written into the notebook. The Weights & Biases
# API key is expected to come from the environment (e.g. `export WANDB_API_KEY=...`
# before starting Jupyter) or from credentials stored by `wandb login`.
# The key that used to be hard-coded on this line has been exposed through the notebook
# and should be revoked in the W&B account settings.
if not os.environ.get("WANDB_API_KEY"):
    print(
        "WANDB_API_KEY is not set; wandb will rely on credentials stored by "
        "`wandb login` or ask for a key when a run is started."
    )

from LLAVA_Biovil.llava.mm_utils import KeywordsStoppingCriteria

# End-of-sequence separator string, passed to KeywordsStoppingCriteria in the generation cells.
STOP_STR = prompter.Seperator.END_OF_SEQUENCE_SEPERATOR.value

Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

In [2]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device_str = shared.torch_devices.cuda.value
else:
    device_str = shared.torch_devices.cpu.value
device = torch.device(device_str)

# LLaVA model for supervised fine-tuning, loaded with 4-bit precision.
model = vllm.load_pretrained_llava_model_for_sft_training(
    device_str=device_str,
    precision="4bit",
)

# Two tokenizers are built from the same base model: `tokenizer` is left as loaded,
# while `padding_tokenizer` is reconfigured below for batching.
tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME,
)
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME,
)

# Pad on the left and reuse the BOS token as the padding token.
padding_tokenizer.padding_side = "left"
padding_tokenizer.pad_token = padding_tokenizer.bos_token

Loading model in non-trainable mode...
Precision: 4bit quantized
Model base:  liuhaotian/llava-v1.5-7b
Loading LLaVA from base model...


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading additional LLaVA weights...
Using downloaded and verified file: /tmp/biovil_t_image_model_proj_size_128.pt
Loaded additional vision tower weights...
Adding pretrained RaDialog LoRA adapters to the model...


In [3]:
######################################## 2. Load the datasets and the dataloaders ########################################

DEFAULT_BATCH_SIZE = 8
DEFAULT_OUTPUT_DIR = path.Path("output")

print("Loading the datasets and the dataloaders...")

# The train and validation datasets are built identically except for the split;
# the train split is created first, then the validation split.
dataset_train, dataset_eval = (
    dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset_for_sft(
        split=split,
        tokenizer=tokenizer,
        prompter=prompter.build_binary_qa_prompt_with_response_and_confidence_for_sft,
    )
    for split in (dataset.DatasetSplit.TRAIN, dataset.DatasetSplit.VALIDATION)
)

# Re-applies the pad-token choice made when the tokenizers were loaded.
padding_tokenizer.pad_token = padding_tokenizer.bos_token  # TODO how about this?

# Both loaders use 8 worker processes (set explicitly, not chosen by torch) and the
# non-simplified batch layout; the evaluation loader uses twice the training batch size.
dataloader_train, dataloader_eval = (
    dataset.get_mimic_cxr_llava_model_input_dataloader_for_sft(
        dataset=split_dataset,
        batch_size=split_batch_size,
        padding_tokenizer=padding_tokenizer,
        num_workers=8,
        simplified_batch=False,
    )
    for split_dataset, split_batch_size in (
        (dataset_train, DEFAULT_BATCH_SIZE),
        (dataset_eval, DEFAULT_BATCH_SIZE * 2),
    )
)

# Shared iterator over the evaluation loader; later cells draw batches from it with next().
eval_batch_iterator = iter(dataloader_eval)

Loading the datasets and the dataloaders...
Loading mimic_cxr_df from cache
Loading balanced_binary_qa_mimic_cxr_df from cache
Loading mimic_cxr_df from cache
Loading balanced_binary_qa_mimic_cxr_df from cache


In [8]:
# Draw the next evaluation batch and place every tensor the model needs on `device`.
batch: dataset.MimicCxrLlavaModelInputBatchDictForSFT = next(eval_batch_iterator)
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch["batch_llava_model_input_dict"], device
)
input_ids = batch_llava_model_input_dict["text_prompt_input_ids"]
images = batch_llava_model_input_dict["images"]
attention_mask = batch["batch_attention_mask"].to(device)
labels = batch["batch_expected_output_ids"].to(device)

In [9]:
# Single forward pass on the batch prepared in the previous cell; labels are passed
# so that the returned output carries a loss.
resp = model(
    input_ids=input_ids,
    images=images,
    attention_mask=attention_mask,
    labels=labels,
)

In [6]:
# ---- Training & Checkpointing ----
# A BATCH OF 8 SAMPLES TAKES 10sec to take a training step
NUM_EPOCHS = 1
CHECKPOINT_DIR = "training_checkpoints"
GRAD_ACCUM_STEPS = 4  # TODO select value
LOGGING_STEPS = 1
STEPS_UNTIL_CHECKPOINT = 1000
NUM_BATCHES_TO_EVALUATE = 15
LR = 5e-5

# The optimizer/scheduler advance once per GRAD_ACCUM_STEPS batches (plus once for a
# trailing partial group), so the schedule length is counted in updates, not batches.
# Sizing it in batches would end training with the linear decay only partly applied.
NUM_BATCHES_PER_EPOCH = len(dataloader_train)
NUM_UPDATES_PER_EPOCH = (NUM_BATCHES_PER_EPOCH + GRAD_ACCUM_STEPS - 1) // GRAD_ACCUM_STEPS

# ---- Optimizer & Scheduler ----
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
lr_scheduler = transformers.get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=NUM_UPDATES_PER_EPOCH * NUM_EPOCHS,
)

from tqdm import tqdm


def _sft_batch_to_model_kwargs(sft_batch, target_device):
    """Turn one SFT dataloader batch into keyword arguments for the model's forward pass.

    Args:
        sft_batch: A batch from the SFT dataloaders, read through the keys
            "batch_llava_model_input_dict", "batch_expected_output_ids" and
            "batch_attention_mask".
        target_device: Device all returned tensors are moved to.

    Returns:
        A dict with the keys "input_ids", "images", "attention_mask" and "labels".
    """
    llava_inputs = dataset.move_llava_model_input_dict_to_device(
        sft_batch["batch_llava_model_input_dict"], target_device
    )
    return {
        "input_ids": llava_inputs["text_prompt_input_ids"],
        "images": llava_inputs["images"],
        "attention_mask": sft_batch["batch_attention_mask"].to(target_device),
        "labels": sft_batch["batch_expected_output_ids"].to(target_device),
    }


model.train()
# Gradients left on the parameters by an earlier, interrupted run of this cell must not
# leak into the first update of this run.
optimizer.zero_grad()
best_val_loss = float("inf")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
wandb.init(project="radialog-confidence-score-sft")

try:
    for epoch in range(NUM_EPOCHS):
        # A single progress bar per epoch; wrapping it in a second tqdm only produced a
        # duplicate, total-less counter.
        loop = tqdm(dataloader_train, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

        for step, batch in enumerate(loop):
            model_kwargs = _sft_batch_to_model_kwargs(batch, device)
            input_ids, labels = model_kwargs["input_ids"], model_kwargs["labels"]

            # Verify that labels only differ when -100 token is present
            assert torch.all(
                (input_ids == labels) | (labels == -100)
            ).item(), "labels differ from input_ids at a position that is not masked with -100"

            # ---- Forward & Backward Pass ----
            outputs = model(**model_kwargs)
            # Scale so that the accumulated gradient is the mean over the group.
            loss = outputs.loss / GRAD_ACCUM_STEPS
            loss.backward()

            # Update after every full accumulation group, and once more at the end of the
            # epoch so the gradients of a trailing partial group are not discarded.
            is_last_batch = (step + 1) == NUM_BATCHES_PER_EPOCH
            if (step + 1) % GRAD_ACCUM_STEPS == 0 or is_last_batch:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # ---- Logging ----
            if step % LOGGING_STEPS == 0:
                # Log the unscaled per-batch loss.
                wandb.log({"train_loss": outputs.loss.item()})

            # ---- Validation & Checkpointing ----
            if (step + 1) % STEPS_UNTIL_CHECKPOINT == 0:
                print(f"Arrived at checkpoint {step + 1}.")
                model.eval()
                val_losses = []
                with torch.no_grad():
                    for _ in range(NUM_BATCHES_TO_EVALUATE):
                        try:
                            eval_batch = next(eval_batch_iterator)
                        except StopIteration:
                            # Evaluation loader exhausted: start over from its beginning.
                            eval_batch_iterator = iter(dataloader_eval)
                            eval_batch = next(eval_batch_iterator)

                        # Evaluation batches have the same nested layout as training
                        # batches, so they are unpacked and moved to the device the same way.
                        val_outputs = model(**_sft_batch_to_model_kwargs(eval_batch, device))
                        val_losses.append(val_outputs.loss.item())

                avg_val_loss = sum(val_losses) / len(val_losses)
                wandb.log({"val_loss": avg_val_loss})

                # ---- Checkpoint Saving ----
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    checkpoint_path = os.path.join(
                        CHECKPOINT_DIR, f"best_model_epoch{epoch}_step{step}.pth"
                    )
                    torch.save(model.state_dict(), checkpoint_path)
                    print(f"ðŸ”¥ Saved best model at {checkpoint_path}")

                model.train()  # Resume training
finally:
    # Close the wandb run even when training is interrupted (e.g. KeyboardInterrupt).
    wandb.finish()
# torch.save(model.state_dict(), "llava_lora_final.pth")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33monurdenizguler[0m ([33monurdenizguler-technical-university-of-munich[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


27it [04:09,  9.23s/it]   | 27/135572 [04:00<333:25:27,  8.86s/it]
Epoch 1/1:   0%|          | 27/135572 [04:09<347:25:00,  9.23s/it]


KeyboardInterrupt: 

# Archived

In [3]:
from RewardingVisualDoubt import inference

# Fresh tokenizers for the inference demo; the padding tokenizer pads on the left and
# reuses the BOS token id as padding id.
tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
padding_tokenizer.padding_side = "left"
padding_tokenizer.pad_token_id = padding_tokenizer.bos_token_id
dataset_test = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.TEST,
    tokenizer=tokenizer,
    prompter=prompter.build_binary_qa_instruction_from_disease_under_study,
)
dataloader_test = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_test, batch_size=1, padding_tokenizer=padding_tokenizer, num_workers=8
)

# Index of the last test sample to show (samples 0..5 are processed).
LAST_FILE_IDX_TO_SHOW = 5

for idx, batch in enumerate(dataloader_test):
    batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
    batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
    # Use the notebook-wide `device` (as the confidence prompt below does) instead of a
    # hard-coded CUDA device, so both prompts always land on the same device.
    batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
        batch_llava_model_input_dict, device
    )
    input_ids, images = (
        batch_llava_model_input_dict["text_prompt_input_ids"],
        batch_llava_model_input_dict["images"],
    )

    # First turn: answer the binary question about the image.
    stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)
    pred = inference.generate_radialog_answer_for_binary_qa_for_single_study(
        model, tokenizer, input_ids, images, stopping_criteria
    )

    # Second turn: original prompt + generated answer + request for a confidence score.
    # batch_size is 1, so index 0 is the only prompt in the batch.
    confidence_request_prompt = (
        batch["batch_prompts"][0]
        + " "
        + pred
        + " "
        + prompter.build_post_generation_user_confidence_request()  # POST_GENERATION_CONFIDENCE_REQUEST_8
    )
    confidence_request_input_ids = torch.unsqueeze(
        torch.IntTensor(tokenizer(confidence_request_prompt)["input_ids"]), 0
    ).to(device)
    stopping_criteria = KeywordsStoppingCriteria(
        [STOP_STR], tokenizer, confidence_request_input_ids
    )
    pred_with_confidence = inference.generate_radialog_answer_for_binary_qa_for_single_study(
        model, tokenizer, confidence_request_input_ids, images, stopping_criteria
    )

    print(f"\n Metadata: {batch['batch_mimic_cxr_datapoint_metadata']}")
    print(f"Prompt: {batch['batch_prompts']}")
    print("Label:", batch["batch_labels"])
    print(f"File_idx {idx}, ASSISTANT: ", pred)
    print(f"File_idx {idx}, ASSISTANT (after confidence request): ", pred_with_confidence)
    if idx == LAST_FILE_IDX_TO_SHOW:
        break

Loading mimic_cxr_df from cache
Loading balanced_binary_qa_mimic_cxr_df from cache

 Metadata: [MimicCxrBinaryQADatapoint(subject_id=18460230, study_id=53631792, img_path='/home/data/DIVA/mimic/mimic-cxr-jpg/2.0.0/files/p18/p18460230/s53631792/369dc5bd-70bd89d0-2d90fa80-f319ec1d-fb2802aa.jpg', disease=<ChexpertFinding.PLEURAL_EFFUSION: 'Pleural Effusion'>, label=<ChexpertLabel.POSITIVE: 1.0>)]
Prompt: ["A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. The assistant gives professional, detailed, and polite answers to the user's questions. USER: <image>. You are to act as a radiologist and answer the following question: Is the following disease visible in the given X-ray image: Pleural Effusion?  ASSISTANT:"]
Label: tensor([1.])
File_idx 0, ASSISTANT:  Yes, the patient has pleural effusion.
File_idx 0, ASSISTANT (after confidence request):  {"confidence": 9}

 Metadata: [MimicCxrBinaryQADatapoint(subject_id=13263843, study_id=52

# Archived Experiments

In [None]:
# Toy forward pass on two short, unrelated sentences: one tokenized as the input,
# the other as the labels.
my_prompt = "Hello, how are you?"
input_ids = tokenizer(my_prompt, return_tensors="pt").input_ids.to(device)
labels = tokenizer("Say, lost in London?", return_tensors="pt").input_ids.to(device)
resp = model(input_ids=input_ids, labels=labels)

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from functools import partial


# ---- Config ----
NUM_EPOCHS = 3
BATCH_SIZE = 8
GRAD_ACCUM_STEPS = 4
LR = 5e-5
LOGGING_STEPS = 1
DEFAULT_OUTPUT_DIR = path.Path("output")

# ---- Init wandb ----
wandb.init(project="radialog-confidence-score-sft")

# Report how many of the model's parameters are trainable.
model.print_trainable_parameters()

# ---- Training args ----
training_args = transformers.TrainingArguments(
    # Where checkpoints go, and how many are kept.
    output_dir=DEFAULT_OUTPUT_DIR,
    save_strategy="epoch",
    save_total_limit=2,
    # Evaluate once per epoch; log every LOGGING_STEPS steps to wandb.
    evaluation_strategy="epoch",
    logging_strategy="steps",
    logging_steps=LOGGING_STEPS,
    report_to="wandb",
    # Batch sizes and schedule.
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    num_train_epochs=NUM_EPOCHS,
    # Optimizer settings (fp16 was tried and left disabled).
    learning_rate=LR,
    weight_decay=0.01,
)

# Simplified SFT collate function with the padding tokenizer already bound.
collate_with_tokenizer = partial(
    dataset.prompted_mimic_cxr_llava_model_input_collate_fn_for_sft_simplified,
    padding_tokenizer=padding_tokenizer,
)

from transformers import Trainer

from transformers import Trainer


class MyTrainer(Trainer):
    """Trainer that can be handed ready-made dataloaders.

    When a dataloader is supplied for a phase it is returned as-is; otherwise the
    parent class builds the dataloader for that phase.
    """

    def __init__(self, *args, my_train_dataloader=None, my_eval_dataloader=None, **kwargs):
        """Forward all regular arguments to Trainer and remember the optional dataloaders."""
        super().__init__(*args, **kwargs)
        self.my_eval_dataloader = my_eval_dataloader
        self.my_train_dataloader = my_train_dataloader

    def get_train_dataloader(self):
        """Return the supplied training dataloader, or the parent's when none was given."""
        supplied = self.my_train_dataloader
        if supplied is None:
            return super().get_train_dataloader()
        return supplied

    def get_eval_dataloader(self, eval_dataset=None):
        """Return the supplied evaluation dataloader, or the parent's when none was given."""
        supplied = self.my_eval_dataloader
        if supplied is None:
            return super().get_eval_dataloader(eval_dataset)
        return supplied


# ---- Trainer ----
# The pre-built dataloaders are injected, so no datasets are handed to the trainer.
trainer = MyTrainer(
    my_eval_dataloader=dataloader_eval,
    my_train_dataloader=dataloader_train,
    args=training_args,
    model=model,
)

# ---- Train, then save the final model and close the wandb run ----
trainer.train()
trainer.save_model(DEFAULT_OUTPUT_DIR)
wandb.finish()

trainable params: 319,815,680 || all params: 3,866,461,096 || trainable%: 8.271534927142067




Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Show the loss of the most recent forward pass that was assigned to `resp`
# (which pass that is depends on the order in which the cells were executed).
resp.loss

tensor(9.5176, device='cuda:0', grad_fn=<NllLossBackward0>)