In [1]:
# TODO sampling logic (undersample)
# TODO binary qa prompt and co need to vary between a few samples
# TODO determine hyperparams
# TODO arg parsing
# TODO dataloader num workers set to default

# SOME NOTES
# PPO_TRAINER AUTOMATICALLY PADS THE INPUTS BY TOKENIZER.PADDING_SIDE AND TOKENIZER.PADDING_TOKEN_ID
# Uh-oh, because ppo termination token is set as the eos_seq_token, it'll stop when it sees a left padded sequence
# Skipping random exploration for now


# BATCH TIMING
# A batch of 8 samples take around 1-1.5-2-3min to process in a train step (so around 400 samples per hour is trainable, every 50th batch, we save a checkpoint, and do val)
# Lets save a checkpoint every half an hour or so
# Give validation around 15 mins => 100 samples or so
# Validation is around 8k so it'll be 1000 batches (1000*1.5 min = 25 hours)
# len(dataset_eval) = 8737

In [1]:
# %% Set script for interactive development and import modules
from RewardingVisualDoubt import infrastructure, training

infrastructure.make_ipython_reactive_to_changing_codebase()
infrastructure.supress_known_warnings()

import pathlib as path
import typing as t
import torch
import numpy as np
import time
import os
from torch.utils.data import DataLoader
import accelerate
import dataclasses

# from LLAVA_Biovil.llava.mm_utils import KeywordsStoppingCriteria
from trl import PPOConfig, PPOTrainer

from RewardingVisualDoubt import dataset, prompter, shared, vllm, response, reward
from RewardingVisualDoubt import training as training

os.environ["WANDB_API_KEY"] = "da3cb086bbc110c16cbc5ba4c284a19b0b461710"

from LLAVA_Biovil.llava.mm_utils import KeywordsStoppingCriteria

STOP_STR = prompter.Seperator.END_OF_SEQUENCE_SEPERATOR.value

Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

In [2]:
######################################## 0. Define the environment ########################################

DEFAULT_BATCH_SIZE = 8
DEFAULT_OUTPUT_DIR = path.Path("output")

device_str = (
    shared.torch_devices.cuda.value if torch.cuda.is_available() else shared.torch_devices.cpu.value
)
device = torch.device(device_str)

######################################## 1. Load the model and tokenizer ########################################

model = vllm.load_pretrained_llava_model_for_ppo_training(device_str=device_str, precision="4bit")
# model_ref = vllm.load_pretrained_llava_model_for_ppo_training(device_str=device_str)

tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
padding_tokenizer.padding_side = "left"  # Why? Because: A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


######################################## 2. Load the datasets and the dataloaders ########################################

print("Loading the datasets and the dataloaders...")
dataset_train = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.TRAIN,
    tokenizer=tokenizer,
    prompter=prompter.build_binary_qa_instruction_from_disease_under_study,
)
dataset_eval = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.VALIDATION,
    tokenizer=tokenizer,
    prompter=prompter.build_binary_qa_instruction_from_disease_under_study,
)

padding_tokenizer.pad_token = padding_tokenizer.bos_token  # TODO how about this?

dataloader_train = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_train,
    batch_size=DEFAULT_BATCH_SIZE,
    padding_tokenizer=padding_tokenizer,
    num_workers=8,  # Let Torch decide.
)

dataloader_eval = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_eval,
    batch_size=2 * DEFAULT_BATCH_SIZE,
    padding_tokenizer=padding_tokenizer,
    num_workers=8,  # Let Torch decide.
)

import sys

sys.path.append("..")  # Adds higher directory to python modules path.
from workflows import radialog_binary_qa_ppo_training

Loading model in non-trainable mode...
Precision: 4bit quantized
Model base:  liuhaotian/llava-v1.5-7b
Loading LLaVA from base model...


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading additional LLaVA weights...
Downloading https://cdn-lfs.hf.co/repos/63/2f/632fbb459426d5c3e8e64aa8be737ccf0c8ba541980f23a79ecf1ab6e87df8b4/b2399d73dc2a68b9f3a1950e864ae0ecd24093fb07aa459d7e65807ebdc0fb77?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27biovil_t_image_model_proj_size_128.pt%3B+filename%3D%22biovil_t_image_model_proj_size_128.pt%22%3B&Expires=1741272658&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0MTI3MjY1OH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy82My8yZi82MzJmYmI0NTk0MjZkNWMzZThlNjRhYThiZTczN2NjZjBjOGJhNTQxOTgwZjIzYTc5ZWNmMWFiNmU4N2RmOGI0L2IyMzk5ZDczZGMyYTY4YjlmM2ExOTUwZTg2NGFlMGVjZDI0MDkzZmIwN2FhNDU5ZDdlNjU4MDdlYmRjMGZiNzc%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=T6dSMsVr3C7-JJdUgqHmxYJ0rSAe%7ESyk30q2W8Wrvjmv1Kl7U2RF7vUfZorPhZFiL7IxGHkSVlzMY5U5CL-1P8ohATGgmB%7E1Ku4FRqNt9PuBnlCUYEm19UjBWUoWp7AEPh82EYx-rircYIaMoD8lrelPR%7EIx1xjQkgeUuFs1IXRjw910qLcNBNCuqGVmzLhWY6tsJynC4q

100%|██████████| 109745561/109745561 [00:00<00:00, 220940139.85it/s]


Loaded additional vision tower weights...
Adding pretrained RaDialog LoRA adapters and value head to the model...




Loading the datasets and the dataloaders...
Loading mimic_cxr_df from cache
Loading balanced_binary_qa_mimic_cxr_df from cache
Loading mimic_cxr_df from cache
Loading balanced_binary_qa_mimic_cxr_df from cache
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
######################################## 3. Define the PPO and generation configurations ########################################
epochs = 1
lr = 5e-6
log_with = "foo"
out_dir = "output"

ppo_config = PPOConfig(
    learning_rate=lr,
    task_name="gpt",
    batch_size=DEFAULT_BATCH_SIZE,
    mini_batch_size=int(DEFAULT_BATCH_SIZE / 4),
    log_with="wandb",
    project_kwargs=dataclasses.asdict(
        accelerate.utils.ProjectConfiguration(
            project_dir="radialog_binary_qa_ppo_training", logging_dir="logs"
        )
    ),
    remove_unused_columns=False,
    # optimize_device_cache=True,
    init_kl_coef=0.05,
)

generation_kwargs_ppo = {
    "min_length": -1,  # don't ignore the EOS token (see above)
    "top_k": 0.0,  # no top-k sampling
    "top_p": 1.0,  # no nucleus sampling
    "temperature": 0.5,  # DONT BE CREATIVE
    "do_sample": True,  # yes, we want to sample
    "pad_token_id": tokenizer.pad_token_id,  # most decoder models don't have a padding token - use EOS token instead (for this tokenizer it was already set to eos_token_id)
    "max_new_tokens": 50,  # let's not be chatty, we need a few tokens to generate confidence but also not limit the response too much
    "eos_token_id": tokenizer.eos_token_id,  # (instead of ppo_terminators list)
}

ppo_trainer = t.cast(
    PPOTrainer,
    PPOTrainer(
        model=model,
        config=ppo_config,
        tokenizer=tokenizer,
    ),
)

# not sure if needed but just to be safe for now
tokenizer.padding_side = "left"
model.config.padding_side = "left"
model.config.tokenizer_padding_side = "left"
# model.pad_token_id = tokenizer.eos_token_id

fatal: No names found, cannot describe anything.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33monurdenizguler[0m ([33monurdenizguler-technical-university-of-munich[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




In [20]:
rewards_epoch = []
iterator_train = iter(dataloader_train)
batch = next(iterator_train)

In [9]:
for i in range(1):
    print(i)
    batch = next(iterator_train)
    rewards, batch_report = radialog_binary_qa_ppo_training.radialog_binary_qa_ppo_training_step(
        model,
        device,
        tokenizer,
        generation_kwargs_ppo,
        ppo_trainer,
        batch,
    )

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
0




In [22]:
######### 5.1 Unpack the batch #########
batch: dataset.MimicCxrLlavaModelInputBatchDict = batch
batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
input_ids, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)
attention_mask = batch["batch_attention_mask"].to(device)
labels = t.cast(torch.Tensor, batch["batch_labels"]).to(device)
stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)

######### 5.2 Generate the binary q&a answer and remove trailing padding tokens #########
model.eval()
prompt_and_generated_answers_ids = model.generate(
    input_ids=input_ids,
    images=images,
    attention_mask=attention_mask,
    do_sample=False,
    use_cache=True,
    max_new_tokens=300,
    stopping_criteria=[stopping_criteria],
    pad_token_id=tokenizer.pad_token_id,
)

prompt_and_generated_answers_ids = training.remove_trailing_padding_from_prediction(
    prompt_and_generated_answers_ids, tokenizer.pad_token_id
)
######### 5.3 Append confidence request to the generated answers #########
prompt_and_generated_answers_with_confidence_requests_ids = []
for item in prompt_and_generated_answers_ids:
    confidence_request_input_ids = (
        tokenizer(
            prompter.build_post_generation_user_confidence_request(),
            return_tensors="pt",
        )
        .input_ids.to(device)
        .squeeze(0)
    )[
        1:
    ]  # drop start of sequence token
    prompt_and_generated_answers_with_confidence_requests_ids.append(
        torch.cat((item, confidence_request_input_ids), 0)
    )

model.train()
generated_confidences_ids = ppo_trainer.generate(
    prompt_and_generated_answers_with_confidence_requests_ids,  # ppo_trainer.generate() method admits list of tensors, handles padding and batching itself
    images=images,
    use_cache=True,
    return_prompt=False,
    **generation_kwargs_ppo,
)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


In [28]:
######### 5.1 Unpack the batch #########
batch: dataset.MimicCxrLlavaModelInputBatchDict = batch
batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
input_ids, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)
attention_mask = batch["batch_attention_mask"].to(device)
labels = t.cast(torch.Tensor, batch["batch_labels"]).to(device)
stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)

######### 5.2 Generate the binary q&a answer and remove trailing padding tokens #########
model.eval()
prompt_and_generated_answers_ids = model.generate(
    input_ids=input_ids,
    images=images,
    attention_mask=attention_mask,
    do_sample=False,
    use_cache=True,
    max_new_tokens=300,
    stopping_criteria=[stopping_criteria],
    pad_token_id=tokenizer.pad_token_id,
)

prompt_and_generated_answers_ids = training.remove_trailing_padding_from_prediction(
    prompt_and_generated_answers_ids, tokenizer.pad_token_id
)
######### 5.3 Append confidence request to the generated answers #########
prompt_and_generated_answers_with_confidence_requests_ids = []
for item in prompt_and_generated_answers_ids:
    confidence_request_input_ids = (
        tokenizer(
            prompter.build_post_generation_user_confidence_request(),
            return_tensors="pt",
        )
        .input_ids.to(device)
        .squeeze(0)
    )[
        1:
    ]  # drop start of sequence token
    prompt_and_generated_answers_with_confidence_requests_ids.append(
        torch.cat((item, confidence_request_input_ids), 0)
    )

model.train()
model.gradient_checkpointing_disable
generated_confidences_ids = ppo_trainer.generate(
    prompt_and_generated_answers_with_confidence_requests_ids,  # ppo_trainer.generate() method admits list of tensors, handles padding and batching itself
    images=images,
    use_cache=True,
    return_prompt=False,
    **generation_kwargs_ppo,
)



In [31]:
######### 5.5 Arrange all generations into useful variables #########
complete_conversation_ids = [
    torch.cat((p, c), 0)
    for p, c in zip(
        prompt_and_generated_answers_with_confidence_requests_ids,
        generated_confidences_ids,
    )
]
generated_answer_only_ids = [
    prompt_and_generated_answers_ids[i][len(input_ids[i]) :] for i in range(len(input_ids))
]
prompt_and_generated_answers_with_confidence_requests_ids = (
    training.replace_image_token_with_another_token_for_list_of_tensors(
        prompt_and_generated_answers_with_confidence_requests_ids
    )
)
generated_answers_texts = tokenizer.batch_decode(
    generated_answer_only_ids,
    skip_special_tokens=True,
)
generated_confidences_texts = tokenizer.batch_decode(
    generated_confidences_ids,
    skip_special_tokens=True,
)

######### 5.6 Parse the responses and compute the rewards #########
generated_answer_labels = response.parse_binary_labels(generated_answers_texts)
generated_confidence_values = response.parse_confidences(generated_confidences_texts)

rewards = [
    reward.generated_answer_and_confidence_to_reward(
        generated_answer_label, generated_confidence_value, ground_truth_label
    )
    for generated_answer_label, generated_confidence_value, ground_truth_label in zip(
        generated_answer_labels, generated_confidence_values, labels.bool().tolist()
    )
]

rewards = t.cast(
    list[torch.FloatTensor],
    [torch.tensor(r).to(device) for r in rewards],
)

In [32]:
stats = ppo_trainer.step(
    queries=t.cast(
        list[torch.LongTensor],
        prompt_and_generated_answers_with_confidence_requests_ids,
    ),
    responses=t.cast(list[torch.LongTensor], generated_answer_only_ids),
    scores=rewards,
)



In [9]:
batch = prompt_and_generated_answers_with_confidence_requests_ids[0:8]
batch_mask = [torch.ones_like(element) for element in batch]
inputs = {"input_ids": batch, "attention_mask": batch_mask}
padded_inputs = tokenizer.pad(
    inputs,
    padding=True,
    max_length=None,
    pad_to_multiple_of=None,
    return_tensors="pt",
).to("cuda")
padded_inputs["input_ids"][padded_inputs["input_ids"] == 2] = 1

In [12]:
generation = model.generate(
    **padded_inputs,
    images=images,
    do_sample=False,
    use_cache=True,
    max_new_tokens=50,
    stopping_criteria=[stopping_criteria],
    pad_token_id=tokenizer.pad_token_id,
)

In [None]:
generation[generation == -200] = 0

In [18]:
tokenizer.batch_decode(generation)

['<s><s> A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. The assistant gives professional, detailed, and polite answers to the user\'s questions. USER: <unk> . You are to act as a radiologist and answer the following question: Is the following disease visible in the given X-ray image: Atelectasis?  ASSISTANT: No, the patient does not have atelectasis. USER: Now evaluate your own response. How confident are you in your answer? Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. A value close to 0 means you think there is a high probability that the answer is wrong. The closer the value is to 10, the higher you think is the probability that the answer is correct.  ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy.The confidence JSON follows this structure: {\'confidence\': int}.Here\'s my confidence J

In [10]:
padded_inputs

{'input_ids': tensor([[    1,     1,   319,  ...,  2933, 29901,   259],
        [    1,     1,     1,  ...,  2933, 29901,   259],
        [    1,     1,     1,  ...,  2933, 29901,   259],
        ...,
        [    1,     1,     1,  ...,  2933, 29901,   259],
        [    1,     1,   319,  ...,  2933, 29901,   259],
        [    1,     1,     1,  ...,  2933, 29901,   259]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 1,  ..., 1, 1, 1],
        [0, 0, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 1,  ..., 1, 1, 1]], device='cuda:0')}

In [73]:
unwrapped_model = ppo_trainer.accelerator.unwrap_model(model)

In [77]:
stopping_criteria

<LLAVA_Biovil.llava.mm_utils.KeywordsStoppingCriteria at 0x7f8940841b40>

In [76]:
stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, padded_inputs["input_ids"])

In [90]:
padded_inputs["input_ids"]

tensor([[    1,     1,   319,  ...,  2933, 29901,   259],
        [    1,     1,     1,  ...,  2933, 29901,   259],
        [    1,     1,     1,  ...,  2933, 29901,   259],
        ...,
        [    1,     1,     1,  ...,  2933, 29901,   259],
        [    1,     1,   319,  ...,  2933, 29901,   259],
        [    1,     1,     1,  ...,  2933, 29901,   259]], device='cuda:0')

In [96]:
padded_inputs["input_ids"]

torch.Size([8, 261])

In [15]:
infrastructure.empty_torch_cuda_cache()

In [13]:
padded_inputs["input_ids"].shape

torch.Size([8, 261])

In [16]:
model.generate(
    **padded_inputs,
    images=images,
    do_sample=False,
    use_cache=False,
    max_new_tokens=300,
    stopping_criteria=[stopping_criteria],
    pad_token_id=tokenizer.pad_token_id,
)

OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 44.35 GiB total capacity; 11.01 GiB already allocated; 2.81 MiB free; 12.46 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [102]:
# input_ids=input_ids,
# images=images,
# attention_mask=attention_mask,
# do_sample=False,
# use_cache=True,
# max_new_tokens=300,
# stopping_criteria=[stopping_criteria],
# pad_token_id=tokenizer.pad_token_id,


generation_kwargs_ppo = {
    # "min_length": -1,  # don't ignore the EOS token (see above)
    # "top_k": 0.0,  # no top-k sampling
    # "top_p": 1.0,  # no nucleus sampling
    # "temperature": 0.5,  # DONT BE CREATIVE
    "do_sample": False,  # yes, we want to sample
    "use_cache": True,
    "pad_token_id": tokenizer.pad_token_id,  # most decoder models don't have a padding token - use EOS token instead (for this tokenizer it was already set to eos_token_id)
    "max_new_tokens": 50,  # let's not be chatty, we need a few tokens to generate confidence but also not limit the response too much
    # "eos_token_id": tokenizer.eos_token_id,  # (instead of ppo_terminators list)
}

unwrapped_model.generate(
    **padded_inputs, images=images, **generation_kwargs_ppo, stopping_criteria=[stopping_criteria]
)

OutOfMemoryError: CUDA out of memory. Tried to allocate 204.00 MiB (GPU 0; 44.35 GiB total capacity; 10.88 GiB already allocated; 58.81 MiB free; 12.41 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
model.train()
generated_confidences_ids = ppo_trainer.generate(
    prompt_and_generated_answers_with_confidence_requests_ids,  # ppo_trainer.generate() method admits list of tensors, not a batch tensor unfortunately
    images=images,
    return_prompt=False,
    **generation_kwargs_ppo,
)

In [49]:
# pad list of 1-d tensors to same length
# prompt_and_generated_answers_with_confidence_requests_ids = training.pad_list_of_tensors(
#     prompt_and_generated_answers_with_confidence_requests_ids, tokenizer.pad_token_id
# )

import torch.nn.functional as F

pad_token_id = int(tokenizer.bos_token_id)  # Define your pad token ID
max_length = max(
    len(seq) for seq in prompt_and_generated_answers_with_confidence_requests_ids
)  # Find max sequence length

# Pad from the left
padded_input_ids = [
    F.pad(seq, (max_length - len(seq), 0), value=pad_token_id)
    for seq in prompt_and_generated_answers_with_confidence_requests_ids
]
padded_input_ids = torch.stack(padded_input_ids)
attention_mask = (padded_input_ids != pad_token_id).int()

In [58]:
padded_input_ids.unsqueeze(0).shape

torch.Size([1, 8, 261])

In [57]:
attention_mask.shape
padded_input_ids.shape

torch.Size([8, 261])

In [39]:
batch_idx

0

In [44]:
image_token_indices

[-1, 42, 261]

In [48]:
labels

tensor([0., 1., 0., 0., 0., 1., 1., 1.], device='cuda:0')

In [43]:
IMAGE_TOKEN_INDEX = -200

new_input_embeds = []
new_labels = []
cur_image_idx = 0
batch_idx = 0
cur_input_ids = padded_input_ids[batch_idx]

num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()

image_token_indices = (
    [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
)

cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
    cur_input_ids_noim.append(
        cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
    )
    cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []

# for i in range(num_images + 1):
#     cur_new_input_embeds.append(cur_input_embeds_no_im[i])
#     cur_new_labels.append(cur_labels_noim[i])
#     if i < num_images:
#         cur_image_features = image_features[cur_image_idx]
#         cur_image_idx += 1
#         cur_new_input_embeds.append(cur_image_features)
#         cur_new_labels.append(
#             torch.full(
#                 (cur_image_features.shape[0],),
#                 IGNORE_INDEX,
#                 device=cur_labels.device,
#                 dtype=cur_labels.dtype,
#             )
#         )
# cur_new_input_embeds = torch.cat(cur_new_input_embeds)
# cur_new_labels = torch.cat(cur_new_labels)

# new_input_embeds.append(cur_new_input_embeds)
# new_labels.append(cur_new_labels)

In [42]:
rewards

NameError: name 'rewards' is not defined

In [11]:
batch_report

{'query': ["A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. The assistant gives professional, detailed, and polite answers to the user's questions. USER:  image . You are to act as a radiologist and answer the following question: Is the following disease visible in the given X-ray image: Pneumothorax?  ASSISTANT: No, the patient does not have pneumothorax. USER: Now evaluate your own response. How confident are you in your answer? Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. A value close to 0 means you think there is a high probability that the answer is wrong. The closer the value is to 10, the higher you think is the probability that the answer is correct.  ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy.The confidence JSON follows this structure: {'confidence': int}.Here's my confidence

# Prompting Experiments

In [17]:
POST_GENERATION_CONFIDENCE_REQUEST_1 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "A value close to 0 means you think there is a high probability that the answer is wrong. "
    "The closer the value is to 10, the higher you think is the probability that the answer is correct. "
    "The output should have the format 'Confidence: <confidence>' and nothing else. "
)


POST_GENERATION_CONFIDENCE_REQUEST_2 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "The output should have the format 'My Confidence: <confidence>'. "
    "ASSISTANT: Here is my self evaluation of my confidence. My Confidence (out of 10):"
)

POST_GENERATION_CONFIDENCE_REQUEST_3 = "USER: Irrelevant from my question, what is your name?"

POST_GENERATION_CONFIDENCE_REQUEST_4 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5 of how sure you are the answer is correct. "
    "The output should have the format 'My Confidence: <confidence>' My Reasoning for my Self Confidence Evaluation: <reasoning>. "
    "ASSISTANT: Here is my self evaluation of my confidence. My Confidence (out of 5):"
)

POST_GENERATION_CONFIDENCE_REQUEST_5 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  of how sure you are the answer is correct. "
    "The output should have the format 'My Confidence: <confidence>' My Very Short Reasoning for my Self Confidence Evaluation: <very_short_reasoning>. "
    "ASSISTANT: Here is my self evaluation of my confidence. My Confidence (out of 10):"
)

POST_GENERATION_CONFIDENCE_REQUEST_6 = (
    "USER: Now evaluate your own response. How confident are you in your answer and what is your explanation regarding your confidence?"
    "ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy."
    "My confidence is a value between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, where 0 means I am not confident at all and 10 means I am very confident."
    "The confidence JSON follows this structure: {'confidence': int, 'explanation': str}."
    "Here's my confidence JSON about my last response: "
)

POST_GENERATION_CONFIDENCE_REQUEST_7 = (
    "USER: Now evaluate your own response. How confident are you in your answer?"
    "ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy."
    "My confidence is a value between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, where 0 means I am not confident at all and 10 means I am very confident."
    "The confidence JSON follows this structure: {'confidence': int}."
    "Here's my confidence JSON about my last response: "
)

POST_GENERATION_CONFIDENCE_REQUEST_8 = (
    "</s> USER: Now evaluate your own response. How confident are you in your answer?"
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "A value close to 0 means you think there is a high probability that the answer is wrong. "
    "The closer the value is to 10, the higher you think is the probability that the answer is correct. "
    "ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy."
    "The confidence JSON follows this structure: {'confidence': int}."
    "Here's my confidence JSON about my last response: "
)

In [5]:
tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)

In [3]:
STOP_STR = prompter.Seperator.END_OF_SEQUENCE_SEPERATOR.value
from LLAVA_Biovil.llava.mm_utils import KeywordsStoppingCriteria
from RewardingVisualDoubt import inference

padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
padding_tokenizer.padding_side = "left"
padding_tokenizer.pad_token_id = padding_tokenizer.bos_token_id
dataset_test = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.TEST,
    tokenizer=tokenizer,
    prompter=prompter.build_binary_qa_instruction_from_disease_under_study,
)
dataloader_test = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_test, batch_size=1, padding_tokenizer=padding_tokenizer, num_workers=8
)

for idx, batch in enumerate(dataloader_test):
    batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
    batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
    batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
        batch_llava_model_input_dict, torch.device(shared.torch_devices.cuda.value)
    )
    input_ids, images = (
        batch_llava_model_input_dict["text_prompt_input_ids"],
        batch_llava_model_input_dict["images"],
    )
    stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)
    pred = inference.generate_radialog_answer_for_binary_qa_for_single_study(
        model, tokenizer, input_ids, images, stopping_criteria
    )
    confidence_request_prompt = (
        batch["batch_prompts"][0]
        + " "
        + pred
        + " "
        + prompter.build_post_generation_user_confidence_request()  # POST_GENERATION_CONFIDENCE_REQUEST_8
    )
    confidence_request_input_ids = torch.unsqueeze(
        torch.IntTensor(tokenizer(confidence_request_prompt)["input_ids"]), 0
    ).to(device)
    stopping_criteria = KeywordsStoppingCriteria(
        [STOP_STR], tokenizer, confidence_request_input_ids
    )
    pred_with_confidence = inference.generate_radialog_answer_for_binary_qa_for_single_study(
        model, tokenizer, confidence_request_input_ids, images, stopping_criteria
    )
    print(f"\n Metadata: {batch['batch_mimic_cxr_datapoint_metadata']}")
    print(f"Prompt: {batch['batch_prompts']}")
    print(f"Label:", batch["batch_labels"])
    print(f"File_idx {idx}, ASSISTANT: ", pred)
    print(f"File_idx {idx}, ASSISTANT (after confidence request): ", pred_with_confidence)
    if idx == 5:
        break


 Metadata: [MimicCxrBinaryQADatapoint(subject_id=18460230, study_id=53631792, img_path='/home/data/DIVA/mimic/mimic-cxr-jpg/2.0.0/files/p18/p18460230/s53631792/369dc5bd-70bd89d0-2d90fa80-f319ec1d-fb2802aa.jpg', disease=<ChexpertFinding.PLEURAL_EFFUSION: 'Pleural Effusion'>, label=<ChexpertLabel.POSITIVE: 1.0>)]
Prompt: ["A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. The assistant gives professional, detailed, and polite answers to the user's questions. USER: <image>. You are to act as a radiologist and answer the following question: Is the following disease visible in the given X-ray image: Pleural Effusion?  ASSISTANT:"]
Label: tensor([1.])
File_idx 0, ASSISTANT:  Yes, the image shows pleural effusion.
File_idx 0, ASSISTANT (after confidence request):  {"confidence": 9}

 Metadata: [MimicCxrBinaryQADatapoint(subject_id=13263843, study_id=52138943, img_path='/home/data/DIVA/mimic/mimic-cxr-jpg/2.0.0/files/p13/p13263843/s52

In [62]:
pred

'Based on the provided X-ray image, it is not possible to definitively determine the presence of pleural effusion without additional information or a more detailed analysis of the image. However, the image does show a chest X-ray of a person with a right-sided pleural effusion, which is a build-up of fluid in the pleural space surrounding the lungs. This can be a sign of various underlying conditions such as pneumonia, lung cancer, or heart failure. It is important to consult with a medical professional for a proper diagnosis and treatment plan.'

In [16]:
######################################## TEST TO SEE IF TEMPERATURE AND TOP_P PARAMS HELP WITH USER CONFIDENCE REQUEST WITHOUT ASSISTANT CONFIRMATION ########################################


from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token

iterator_train = iter(dataloader_train)
batch = next(iterator_train)

batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
_, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)

my_prompt = prompter.build_binary_qa_instruction_from_disease_under_study_with_confidence_request(
    "Cardiomegaly"
)
tokenized_prompt = tokenizer_image_token(my_prompt, tokenizer, return_tensors="pt").to(device)

stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, tokenized_prompt.unsqueeze(0))

prompt_and_generated_answers_ids = model.generate(
    input_ids=tokenized_prompt.unsqueeze(0),
    images=images[0].unsqueeze(0),
    # attention_mask=attention_mask,
    do_sample=True,
    use_cache=True,
    temperature=1.8,
    top_p=0.7,
    max_new_tokens=300,  # TODO maybe move to the kwargs
    stopping_criteria=[stopping_criteria],  # TODO understand better
    pad_token_id=tokenizer.pad_token_id,  # used in tokenizing after the generation, # TODO maybe move to the kwargs
    # **generation_kwargs_prediction,  # TODO check which args to pass.
)

tokenizer.decode(
    training.replace_image_token_with_another_token(prompt_and_generated_answers_ids)[0]
)

"<s> A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. The assistant gives professional, detailed, and polite answers to the user's questions. USER:  image . You are to act as a radiologist and answer a single question. After you respond, please provide your self evaluation of your confidence. Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. A value close to 0 means you think there is a high probability that the answer is wrong. Your confidence is to be reported in a JSON dictionary of the following format: {'confidence': int}. Is the following disease visible in the given X-ray image: Cardiomegaly, and how confident are you?  ASSISTANT: No.</s>"

# Archived

In [None]:
batch = next(iterator_train)

batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
input_ids, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)
attention_mask = batch["batch_attention_mask"].to(device)  # TODO handle elsewhere
labels = batch["batch_labels"].to(device)  # TODO handle elsewhere


model.eval()
stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)


t3 = time.time()
prompt_and_generated_answers_ids = model.generate(
    input_ids=input_ids,
    images=images,
    attention_mask=attention_mask,
    do_sample=False,
    use_cache=True,
    max_new_tokens=32,  # Limiting, YES, but binary q&a answers are not very long!
    stopping_criteria=[stopping_criteria],
    pad_token_id=tokenizer.pad_token_id,
)
t4 = time.time()
prompt_and_generated_answers_ids = training.remove_trailing_padding_from_prediction(
    prompt_and_generated_answers_ids, tokenizer.pad_token_id
)

# Append confidence request to the generated answers
prompt_and_generated_answers_with_confidence_requests_ids = []
for item in prompt_and_generated_answers_ids:
    confidence_request_input_ids = (
        tokenizer(prompter.build_post_generation_user_confidence_request(), return_tensors="pt")
        .input_ids.to(device)
        .squeeze(0)
    )[
        1:
    ]  # drop start of sequence token
    prompt_and_generated_answers_with_confidence_requests_ids.append(
        torch.cat((item, confidence_request_input_ids), 0)
    )
model.train()

t5 = time.time()
generated_confidences_ids = ppo_trainer.generate(
    prompt_and_generated_answers_with_confidence_requests_ids,  # ppo_trainer.generate() method admits list of tensors, not a batch tensor unfortunately
    images=images,
    return_prompt=False,
    **generation_kwargs_ppo,
)
t6 = time.time()


complete_conversation_ids = [
    torch.cat((p, c), 0)
    for p, c in zip(
        prompt_and_generated_answers_with_confidence_requests_ids,
        generated_confidences_ids,
    )
]
generated_answer_only_ids = [
    prompt_and_generated_answers_ids[i][len(input_ids[i]) :] for i in range(len(input_ids))
]

# Remove the unindex image token from the prompt
prompt_and_generated_answers_with_confidence_requests_ids = (
    training.replace_image_token_with_another_token_for_list_of_tensors(
        prompt_and_generated_answers_with_confidence_requests_ids
    )
)
generated_answers_texts = tokenizer.batch_decode(
    generated_answer_only_ids,
    skip_special_tokens=True,
)
generated_confidences_texts = tokenizer.batch_decode(
    generated_confidences_ids,
    skip_special_tokens=True,
)
generated_answer_labels = response.parse_binary_labels(generated_answers_texts)
generated_confidence_values = response.parse_confidences(generated_confidences_texts)

rewards = [
    reward.generated_answer_and_confidence_to_reward(
        generated_answer_label, generated_confidence_value, ground_truth_label
    )
    for generated_answer_label, generated_confidence_value, ground_truth_label in zip(
        generated_answer_labels, generated_confidence_values, labels.bool().tolist()
    )
]

report = {}
report["generated_answer_labels"] = generated_answer_labels

rewards_epoch += rewards
rewards = [torch.tensor(r).to(device) for r in rewards]

t7 = time.time()
stats = ppo_trainer.step(
    prompt_and_generated_answers_with_confidence_requests_ids, generated_answer_only_ids, rewards
)
t8 = time.time()

# ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "answer"])

# print(f"Finished epoch {epoch}. Average reward: {avg_reward}")
# ppo_trainer.save_pretrained(os.path.join(out_dir, "model_finetuned"))

# TODO: For random exploration
# chance_to_change_confidence -= reduce_per_step
# chance_to_change_confidence = max(0, chance_to_change_confidence)

In [None]:
for i in range(len(rewards_epoch) // batch_size):
    print(sum(rewards_epoch[i * batch_size : (i + 1) * batch_size]))

36.95582970558741


In [None]:
generated_confidence_values

[9, 4, 6, 7, 5, 7, 5, 6]

In [None]:
generated_answer_labels

[True, False, False, True, True, True, True, False]

In [None]:
generated_answers_texts

['Yes, the patient has consolidation.',
 'No, there is no evidence of that in the image.',
 'No, there is no evidence of that in the image.',
 'Yes, the patient has enlarged cardiomediastinum.',
 'Yes, there is evidence of that in the image.',
 'Yes, the patient has cardiomegaly.',
 'Yes, the patient has support devices.',
 'No, there is no evidence of that in the image.']

In [None]:
generated_confidences_texts

['\n{"confidence": 8}',
 '9',
 '\n{"confidence": 4}',
 '\n{"confidence": 5}',
 '\n{"confidence": 9}',
 '\n{"confidence": 6}',
 '\n{"confidence": 8}',
 '\n{"confidence": 6}']

In [None]:
answers_decoded

['Yes, the image shows atelectasis.',
 'Yes, the patient has pleural effusion.',
 'Yes, there is evidence of that in the image.',
 'Yes, the image shows pleural effusion.',
 'No, there is no evidence of that in the image.',
 'No, there is no evidence of that in the image.',
 'No, there is no evidence of that in the image.',
 'Yes, the image shows pleural effusion.']

In [None]:
confidences_decoded

['\n{"confidence": 8}',
 '\n{"confidence": 7}',
 '\n{"confidence": 9}',
 '\n{"confidence": 9}',
 '\n{"confidence": 3}',
 '\n{"confidence": 3}',
 "\n\n{'confidence': 3}",
 '\n{"confidence": 4}']

In [None]:
print((t8 - t7) * 1000, "time it took to ppo step")
print((t6 - t5) * 1000, "time it took to generate confidences")
print((t4 - t3) * 1000, "time it took to generate answers")
print((t2 - t1) * 1000, "time it took to get batch")

print("total time it took", int((t8 - t1)), "seconds")

9359.999895095825 time it took to ppo step
64616.09601974487 time it took to generate confidences
1864.7639751434326 time it took to generate answers
1835.4952335357666 time it took to get batch
total time it took 77 seconds
