In [1]:
# TODO sampling logic (undersample)
# TODO binary qa prompt and co need to vary between a few samples
# TODO determine hyperparams
# TODO arg parsing
# TODO dataloader num workers set to default

# NOTES
# PPOTrainer automatically pads the inputs using tokenizer.padding_side and tokenizer.pad_token_id.
# Caveat: the PPO termination token is the EOS token and this tokenizer also uses EOS as its pad token,
# so generation may stop as soon as it sees a left-padded sequence.
# Random exploration is skipped for now.


# BATCH TIMING (measured when the batch size was 8; DEFAULT_BATCH_SIZE is now 16)
# One train step on a batch of 8 samples takes roughly 1-3 min (so about 400 samples per hour are trainable;
# every 50th batch we save a checkpoint and run validation).
# Save a checkpoint roughly every half hour.
# Budget around 15 min for validation => roughly 100 samples.
# The full validation set has ~8.7k samples, i.e. ~1,100 batches of 8 (~1,100 * 1.5 min ≈ 27 hours),
# so only a subset can be validated within that budget.
# len(dataset_eval) = 8737

In [1]:
# %% Set script for interactive development and import modules
from RewardingVisualDoubt import infrastructure, training

infrastructure.make_ipython_reactive_to_changing_codebase()
infrastructure.supress_known_warnings()

import pathlib as path
import typing as t
import torch
import numpy as np
import time
import os
from torch.utils.data import DataLoader
import accelerate
import dataclasses
import functools

from trl import PPOConfig, PPOTrainer

from RewardingVisualDoubt import dataset, prompter, shared, vllm, response, reward
from RewardingVisualDoubt import training as training

# Weights & Biases credentials must come from the environment (`export WANDB_API_KEY=...`)
# or from a prior `wandb login`; never hard-code API keys in a notebook.
# NOTE(review): the key that used to be hard-coded here was committed in plain text — revoke/rotate it.
if "WANDB_API_KEY" not in os.environ:
    print(
        "WANDB_API_KEY is not set; wandb will rely on existing `wandb login` credentials, if any."
    )

from LLAVA_Biovil.llava.mm_utils import KeywordsStoppingCriteria

# Stop string marking the end of a generated turn; used to build KeywordsStoppingCriteria below.
STOP_STR = prompter.Seperator.END_OF_SEQUENCE_SEPERATOR.value

Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

In [2]:
######################################## 0. Define the environment ########################################

DEFAULT_BATCH_SIZE = 16
# NOTE(review): DEFAULT_OUTPUT_DIR is not read by any visible cell (the PPO config cell uses its own `out_dir`).
DEFAULT_OUTPUT_DIR = path.Path("output")


batch_size = DEFAULT_BATCH_SIZE

# Use CUDA when available, otherwise fall back to CPU.
device_str = (
    shared.torch_devices.cuda.value if torch.cuda.is_available() else shared.torch_devices.cpu.value
)
device = torch.device(device_str)

######################################## 1. Load the model and tokenizer ########################################

# 4-bit LLaVA policy, initialised from the binary-QA-with-confidence SFT LoRA weights (the loader's log
# output says it also adds the value head needed for PPO).
model = vllm.load_pretrained_llava_model_for_ppo_training(
    device_str=device_str,
    precision="4bit",
    radialog_lora_weights_path=vllm.RadialogLoraWeightsPath.BINARY_QA_WITH_CONFIDENCE_SFT.value,
)

tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
# Separate tokenizer instance used only by the dataloaders' collation, so its padding side can be set independently.
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
padding_tokenizer.padding_side = "left"  # Why? Because: A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


######################################## 2. Load the datasets and the dataloaders ########################################

print("Loading the datasets and the dataloaders...")
# Prompt builder for the SFT-style "answer + confidence" prompt, in inference mode.
prompter_ = functools.partial(
    prompter.build_binary_qa_prompt_with_response_and_confidence_for_sft, is_for_inference=True
)
dataset_train = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.TRAIN,
    tokenizer=tokenizer,
    prompter=prompter_,
)
dataset_eval = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.VALIDATION,
    tokenizer=tokenizer,
    prompter=prompter_,
)

# NOTE(review): num_workers is hard-coded to 8 (see the "num workers" TODO at the top of the notebook).
dataloader_train = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_train,
    batch_size=batch_size,
    padding_tokenizer=padding_tokenizer,
    num_workers=8,
)

# Evaluation uses twice the training batch size.
dataloader_eval = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_eval,
    batch_size=2 * batch_size,
    padding_tokenizer=padding_tokenizer,
    num_workers=8,
)

# NOTE(review): eval_batch_iterator is not used by any visible cell.
eval_batch_iterator = iter(dataloader_eval)

import sys

# Makes the repo's `workflows` package importable; the path is relative to the kernel's current working directory.
sys.path.append("../..")  # Adds higher directory to python modules path.
from workflows import radialog_binary_qa_ppo_training

Adding LoRA adapters and value head to the model for PPO training using Radialog Lora Weights path: /home/guests/deniz_gueler/repos/RewardingVisualDoubt/workflows/training_checkpoints/best_model_epoch0_step179.pth/adapter_model.bin
Loading the model in non-trainable mode...
Precision: 4bit quantized
Model base:  liuhaotian/llava-v1.5-7b
Loading LLaVA from base model...


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading additional LLaVA weights...
Downloading https://cdn-lfs.hf.co/repos/63/2f/632fbb459426d5c3e8e64aa8be737ccf0c8ba541980f23a79ecf1ab6e87df8b4/b2399d73dc2a68b9f3a1950e864ae0ecd24093fb07aa459d7e65807ebdc0fb77?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27biovil_t_image_model_proj_size_128.pt%3B+filename%3D%22biovil_t_image_model_proj_size_128.pt%22%3B&Expires=1744374786&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NDM3NDc4Nn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy82My8yZi82MzJmYmI0NTk0MjZkNWMzZThlNjRhYThiZTczN2NjZjBjOGJhNTQxOTgwZjIzYTc5ZWNmMWFiNmU4N2RmOGI0L2IyMzk5ZDczZGMyYTY4YjlmM2ExOTUwZTg2NGFlMGVjZDI0MDkzZmIwN2FhNDU5ZDdlNjU4MDdlYmRjMGZiNzc%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=GLWo1SvfxTaTIglO1ojgVZV%7EP5QN1hGo6mzIekryAgIPI%7EcwvdFKNgSeuMSyNDKhUsbalrAh1Jck%7ESQDbD1b1pqj4-THEZc-5lExlTlafbDI2jcWRJwPLLXxuwxNKatKB8EtAoutBJkQ-3uxNePRX5m-wy1CgvEIVUsUOZL03sXN5YU1OPuPCYX8d048lvbAX2Ra7WKCrg

100%|██████████| 109745561/109745561 [00:00<00:00, 218901060.96it/s]


Loaded additional vision tower weights...
Adding pretrained RaDialog LoRA adapters and value head to the model...




Loading the datasets and the dataloaders...
Loading mimic_cxr_df from cache
Loading balanced_binary_qa_mimic_cxr_df from cache
Loading mimic_cxr_df from cache
Loading balanced_binary_qa_mimic_cxr_df from cache
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# Scratch inspection. NOTE(review): `input_kwargs` is only assigned in the mini-batch forward cell further down, so this fails on Restart & Run All.
type(input_kwargs["input_ids"])

torch.Tensor

In [None]:
# Scratch inspection of the policy logits shape. NOTE(review): `logits` is only assigned further down (mini-batch forward cell).
logits.shape

torch.Size([4, 396, 32000])

In [23]:
# Column of the image placeholder token in each row of the left-padded batch; `nonzero(...)[1]` keeps only the column indices.
# NOTE(review): `input_ids` is only assigned in a later cell.
indexes_of_image_token = (input_ids == training.LLAVA_IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[1]

In [24]:
# Show the image-token column of every row.
indexes_of_image_token

tensor([46, 47, 51, 47, 49, 52, 46, 48, 47, 51, 47, 41, 46, 46, 52, 45],
       device='cuda:0')

In [28]:
# Scratch inspection of the mini-batch input ids: left padding (id 2), then BOS (id 1), then the prompt, as in the recorded output.
input_kwargs["input_ids"]

tensor([[    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
         21082, 20255, 16684,   408,   385, 18860, 17937, 19915, 29889,   450,
         20255,  4076, 10257, 29892, 13173, 29892,   322,  1248,   568,  6089,
           304,   278,  1404, 29915, 29879,  5155, 29889,  3148,  1001, 29901,
         29871,  -200,   869,   887,   526,   304,  1044,   408,   263, 17937,
         19915,   322,  1234,   263,  2323,  1139, 29889,  2860,   366, 10049,
         29892,  3113,  3867,   596,  1583, 17983,   310,   596, 16420, 29889,
          9133,   680,   263, 16420,  1546, 29871, 29900, 29892, 29871, 29896,
         29892, 29871, 29906, 29892, 29871, 29941, 29892, 29871, 29946, 29892,
         29871, 29945, 29892, 29871, 29953, 29892, 29871, 29955, 29892, 29871,
         29947, 29892, 29871, 29929, 29892, 29871, 29896, 29900, 29892,   310,
           920,  1854,   366,   526,   278,  1234,  

In [41]:
# First five tokens of the last row: left padding (2) followed by BOS (1) in the recorded output.
input_ids[-1][:5]

tensor([2, 2, 2, 2, 1], device='cuda:0')

In [42]:
# Show the image-token column of every row (same as above, for side-by-side comparison with the BOS positions below).
indexes_of_image_token

tensor([46, 47, 51, 47, 49, 52, 46, 48, 47, 51, 47, 41, 46, 46, 52, 45],
       device='cuda:0')

In [35]:
# Position of the first BOS token (id 1) per row, i.e. the number of left-padding tokens.
# In the recorded output, image-token column minus this value is 41 for every row: the image token sits at a fixed offset from BOS.
(input_ids == 1).int().argmax(dim=1)

tensor([ 5,  6, 10,  6,  8, 11,  5,  7,  6, 10,  6,  0,  5,  5, 11,  4],
       device='cuda:0')

In [None]:
def account_for_image_embeddings_for_single_image_inputs(input_ids, logits, values, attention_mask):
    """
    Work-in-progress helper for aligning per-token outputs with `input_ids` when each sequence holds one image token.

    NOTE(review): unfinished — it currently only locates the image token and the first BOS token of each
    row, then returns None. `logits`, `values` and `attention_mask` are not used yet.

    Args:
        input_ids (torch.Tensor): A tensor shaped (batch_size, sequence_length) including exactly 1 image token id (-200 for llava) for each sequence
        logits (torch.Tensor): A tensor shaped (batch_size, sequence_length, vocab_size)
        values: The values of the model
        attention_mask: The attention mask of the model
    """
    # Column of the single image placeholder token in every row.
    image_token_hits = input_ids == training.LLAVA_IMAGE_TOKEN_INDEX
    indexes_of_image_token = image_token_hits.nonzero(as_tuple=True)[1]

    # Column where each row's content starts (first BOS token, id 1), i.e. the amount of left padding.
    # argmax returns the first maximal position, so a row without BOS yields 0.
    bos_token_hits = (input_ids == 1).int()
    indexes_of_bos_token = bos_token_hits.argmax(dim=1)

In [3]:
######################################## 3. Define the PPO and generation configurations ########################################
epochs = 1
lr = 5e-5
log_with = "wandb"  # experiment tracker; passed to PPOConfig below (was an unused "foo" placeholder)
out_dir = "output"

# PPO batching (TRL semantics): each PPO step consumes `batch_size` rollouts and goes over them
# `ppo_epochs` times. Every pass is split into mini batches of `mini_batch_size`. The gradients of
# `gradient_accumulation_steps` consecutive mini batches are accumulated before one optimizer step,
# so one optimizer step ("backward batch") covers mini_batch_size * gradient_accumulation_steps samples.
# With batch size 16: mini batches of 4, accumulated over 4 mini batches -> 16 samples per optimizer step.
PPO_EPOCHS = 4  # TRL default: go over each batch 4 times
PPO_MINI_BATCH_SIZE = DEFAULT_BATCH_SIZE // 4
PPO_GRADIENT_ACCUMULATION_STEPS = 4
# NOTE(review): TRL's PPOConfig recomputes backward_batch_size as mini_batch_size * gradient_accumulation_steps
# in __post_init__, so the DEFAULT_BATCH_SIZE / 2 passed previously was silently replaced. Pass the value TRL
# actually uses so the config reads truthfully — confirm against the installed TRL version.
PPO_BACKWARD_BATCH_SIZE = PPO_MINI_BATCH_SIZE * PPO_GRADIENT_ACCUMULATION_STEPS

ppo_config = PPOConfig(
    learning_rate=lr,
    task_name="gpt",
    ppo_epochs=PPO_EPOCHS,
    batch_size=DEFAULT_BATCH_SIZE,
    backward_batch_size=PPO_BACKWARD_BATCH_SIZE,
    mini_batch_size=PPO_MINI_BATCH_SIZE,
    gradient_accumulation_steps=PPO_GRADIENT_ACCUMULATION_STEPS,
    log_with=log_with,
    project_kwargs=dataclasses.asdict(
        accelerate.utils.ProjectConfiguration(
            project_dir="radialog_binary_qa_ppo_training", logging_dir="logs"
        )
    ),
    remove_unused_columns=False,
    # optimize_device_cache=True,
    init_kl_coef=0.05,
)

# Sampling settings for PPO rollouts: sample from the policy's unmodified distribution.
generation_kwargs_ppo = {
    "min_length": -1,  # no forced minimum length, so EOS is never suppressed
    "top_k": 0,  # 0 disables top-k filtering
    "top_p": 1.0,  # 1.0 disables nucleus filtering
    "temperature": 1.0,  # no temperature scaling (1.0 is neither sharper nor more "creative")
    "do_sample": True,  # PPO needs stochastic on-policy rollouts
    "pad_token_id": tokenizer.pad_token_id,  # for this tokenizer the pad token is already the EOS token
    "max_new_tokens": 50,  # enough for an answer plus a confidence, without letting the model ramble
    "eos_token_id": tokenizer.eos_token_id,  # stop at EOS (used instead of a list of PPO terminator tokens)
}

ppo_trainer = t.cast(
    training.MultimodalPPOTrainer,
    training.MultimodalPPOTrainer(
        model=model,
        config=ppo_config,
        tokenizer=tokenizer,
    ),
)

# Force left padding everywhere: decoder-only generation expects left-padded prompts.
# The trainer received this same tokenizer object above, so it presumably sees this change too.
tokenizer.padding_side = "left"
model.config.padding_side = "left"
model.config.tokenizer_padding_side = "left"

fatal: No names found, cannot describe anything.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33monurdenizguler[0m ([33monurdenizguler-technical-university-of-munich[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




In [4]:
# Fresh training iterator and a first batch for the step-by-step debugging cells below.
# NOTE(review): the next cell calls next() again before using `batch`, so it trains on the second batch.
rewards_epoch = []
iterator_train = iter(dataloader_train)
batch = next(iterator_train)

In [None]:
# Run one full PPO training step through the workflow helper on a freshly drawn batch.
# NOTE(review): the loop variable `i` leaks into the notebook namespace.
for i in range(1):
    print(i)
    batch = next(iterator_train)
    rewards, batch_report = radialog_binary_qa_ppo_training.radialog_binary_qa_ppo_training_step(
        model,
        device,
        tokenizer,
        generation_kwargs_ppo,
        ppo_trainer,
        batch,
    )

In [5]:
######### 5.1 Unpack the batch #########
# Step-by-step replica of a training step for debugging; uses the `batch` currently in the namespace.
batch: dataset.MimicCxrLlavaModelInputBatchDict = batch
batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
input_ids, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)
attention_mask = batch["batch_attention_mask"].to(device)
labels = t.cast(torch.Tensor, batch["batch_labels"]).to(device)
# Stop generation once STOP_STR is produced.
stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)
# Presumably one tensor per row with its left padding removed; ppo_trainer.generate re-pads (see comment below).
input_ids_list = training.remove_preciding_padding_from_batch_tensor(input_ids)

# model.train()
# Gradient checkpointing is disabled because generation below uses the KV cache.
model.gradient_checkpointing_disable()
generated_confidences_ids = ppo_trainer.generate(
    query_tensor=input_ids_list,  # ppo_trainer.generate() method admits list of tensors, handles padding and batching itself
    images=images,
    return_prompt=False,
    batch_size=DEFAULT_BATCH_SIZE,
    use_cache=True,  # => not compatible with gradient checkpointing!
    stopping_criteria=[stopping_criteria],
    **generation_kwargs_ppo,
)

In [6]:
# Decode the sampled responses and parse the answer label and confidence from the same text.
# NOTE(review): decoded without skip_special_tokens, so pad/EOS tokens may appear in the texts — presumably the
# parsers tolerate this (the archived cell below decodes with skip_special_tokens=True).
generated_confidences_texts = tokenizer.batch_decode(generated_confidences_ids)
generated_answer_labels = response.parse_binary_labels(generated_confidences_texts)
generated_confidence_values = response.parse_confidences(generated_confidences_texts)

# One scalar reward per sample, from (predicted label, predicted confidence, ground-truth label).
rewards = [
    reward.generated_answer_and_confidence_to_reward(
        generated_answer_label, generated_confidence_value, ground_truth_label
    )
    for generated_answer_label, generated_confidence_value, ground_truth_label in zip(
        generated_answer_labels, generated_confidence_values, labels.bool().tolist()
    )
]

# The PPO trainer expects the scores as a list of tensors on the training device.
rewards = t.cast(
    list[torch.FloatTensor],
    [torch.tensor(r).to(device) for r in rewards],
)

In [7]:
# Build padded model inputs for the PPO forward pass from the unpadded queries and the generated responses.
model_inputs = ppo_trainer.prepare_model_inputs(
    queries=t.cast(list[torch.LongTensor], input_ids_list),
    responses=t.cast(list[torch.LongTensor], generated_confidences_ids),
)

# prepare_model_inputs only builds the text inputs; attach the images for the multimodal forward pass.
model_inputs["images"] = images
model_inputs_names = list(model_inputs.keys())

# Plain lists: the previous trailing commas wrapped each list in a 1-tuple, which made `bs == 1`
# and turned every mini-batch slice `queries[i * fbs : (i + 1) * fbs]` with i >= 1 into an empty tuple.
queries = t.cast(list[torch.LongTensor], input_ids_list)
responses = t.cast(list[torch.LongTensor], generated_confidences_ids)
bs = len(queries)
fbs = ppo_trainer.config.mini_batch_size
# Accumulators for the per-mini-batch forward pass (mirroring TRL's batched_forward_pass).
all_logprobs = []
all_logits = []
all_masks = []
all_values = []

In [8]:
# Scratch: forward through `model.pretrained_model` (presumably the wrapped LM, without the value head) keeping hidden states.
# NOTE(review): `input_kwargs` is defined only in a later cell — this raised NameError in the recorded run.
with torch.no_grad():
    base_model_output = model.pretrained_model(**input_kwargs, output_hidden_states=True)

NameError: name 'input_kwargs' is not defined

In [31]:
# Scratch inspection of the base model's logits shape.
base_model_output.logits.shape

torch.Size([4, 395, 32000])

In [15]:
# Final-layer hidden states of the base model; fed to the value head below.
last_hidden_state = base_model_output.hidden_states[-1]

In [16]:
# Scratch inspection.
last_hidden_state.shape

torch.Size([4, 395, 4096])

In [85]:
# Per-position values from the value head; squeeze(-1) presumably drops a trailing size-1 output dim.
# NOTE(review): the recorded shape (4, 209) does not match last_hidden_state (4, 395) — stale output from out-of-order execution.
val = model.v_head(last_hidden_state).squeeze(-1)

In [86]:
# Scratch inspection.
val.shape

torch.Size([4, 209])

In [37]:
# Scratch inspection.
logits.shape

torch.Size([4, 390, 32000])

In [None]:
# Positions of the image placeholder token (-200) in the first sequence of the current mini batch.
# NOTE(review): this overwrites the global `indexes_of_image_token` computed for the whole batch earlier.
sequence = input_kwargs["input_ids"][0]
indexes_of_image_token = (sequence == -200).nonzero(as_tuple=True)[0]

In [None]:
# Scratch: text-only forward pass on the first 20 tokens of `sequence` (padding + start of the system prompt).
# NOTE(review): `[1]["logits"]` assumes the second output element is a mapping with a "logits" key — confirm against the model wrapper.
logitsss = model(input_ids=sequence[:20].unsqueeze(0))[1]["logits"]

In [68]:
# Show the first sequence of the mini batch.
sequence

tensor([    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     1,   319, 13563,  1546,   263, 12758,  1404,
          322,   385, 23116, 21082, 20255, 16684,   408,   385, 18860, 17937,
        19915, 29889,   450, 20255,  4076, 10257, 29892, 13173, 29892,   322,
         1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155, 29889,
         3148,  1001, 29901, 29871,  -200,   869,   887,   526,   304,  1044,
          408,   263, 17937, 19915,   322,  1234,   263,  2323,  1139, 29889,
         2860,   366, 10049, 29892,  3113,  3867,   596,  1583, 17983,   310,
          596, 16420, 29889,  9133,   680,   263, 16420,  1546, 29871, 29900,
        29892, 29871, 29896, 29892, 29871, 29906, 29892, 29871, 29941, 29892,
        29871, 29946, 29892, 29871, 29945, 29892, 29871, 29953, 29892, 29871,
        29955, 29892, 29871, 29947, 29892, 29871, 29929, 29892, 29871, 29896,
        29900, 29892,   310,   920,  1854,   366,   526,   278, 

In [None]:
# Logits at the 12 positions preceding the image token in row 0 (slice bounds come from a one-element index tensor).
logits[0][indexes_of_image_token - 12 : indexes_of_image_token]

tensor([[ -4.4936, -10.1075,   2.7072,  ...,   2.1521,  -3.0400,   1.7643],
        [ -2.5677,  -0.9008,   5.9041,  ...,   3.0179,  -2.7296,   1.1002],
        [ -2.2757,   0.0464,   2.4819,  ...,   2.1444,  -3.4312,  -1.7103],
        ...,
        [ -3.0267,   1.4197,   3.3870,  ...,   2.2771,  -0.7604,  -0.3429],
        [ -3.3076,   2.2019,   2.7390,  ...,   1.8549,  -1.8728,  -0.1617],
        [ -2.6943,   2.8725,   2.5231,  ...,   2.4108,  -1.4761,  -1.0808]],
       device='cuda:0')

torch.Size([20, 32000])

In [85]:
# Greedy (argmax) token at each position of the 20-token prefix forward pass, decoded as a sanity check.
tokenizer.batch_decode(torch.argmax(logitsss.squeeze(0), dim=-1))

['<s>',
 '<s>',
 '<s>',
 '<s>',
 '<s>',
 '<s>',
 '<s>',
 '<s>',
 '<s>',
 '<s>',
 '<s>',
 '<s>',
 '<s>',
 'A',
 '.',
 ',',
 'MS',
 'PA',
 'PA',
 'PA']

In [None]:
# Greedy tokens for row 0 up to (image-token index - 12); in the recorded output these decode to the system-prompt text.
tokenizer.batch_decode(torch.argmax(logits[0][: indexes_of_image_token - 12], 1))

['<s>',
 'A',
 'chat',
 'between',
 'a',
 'curious',
 'user',
 'and',
 'an',
 'artificial',
 'intelligence',
 'assistant',
 'acting',
 'as',
 'an',
 'experienced',
 'radi',
 'ologist',
 '.',
 'The',
 'assistant',
 'gives',
 'professional',
 ',',
 'detailed',
 ',',
 'and',
 'pol',
 'ite',
 'answers',
 'to',
 'the',
 'user',
 "'",
 's',
 'questions',
 '.',
 'US',
 'ER',
 ':',
 '',
 '1']

In [9]:
import math  # NOTE(review): unused in this cell

# Slice mini batch number `i` (of size `fbs`) out of the prepared model inputs and run a no-grad forward pass.
# NOTE(review): if `queries`/`responses` are still the 1-tuples built by the earlier cell, these slices are empty for i >= 1.
i = 3

input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
with torch.no_grad():
    # The model returns three values here: logits, an unused middle element, and per-position values.
    logits, _, values = model(**input_kwargs)

In [32]:
# Attention mask of the first row: the leading zeros mark left padding.
input_kwargs["attention_mask"][0]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')

In [29]:
# Scratch inspection.
values.shape

torch.Size([4, 390])

In [33]:
# Scratch inspection.
logits.shape

torch.Size([4, 390, 32000])

In [9]:
# NOTE(review): rebinds the global `input_ids`/`attention_mask` (the full batch from the unpack cell) to the current
# mini batch; any later cell expecting the full batch now sees the mini batch instead.
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]

In [14]:
# Scratch inspection. NOTE(review): the recorded 395 does not equal the recorded logits length 390 minus 1 — stale output.
logits[:, :-1, :].shape

torch.Size([4, 395, 32000])

In [10]:
import torch.nn.functional as F

# Log-probabilities over the vocabulary for every position except the last (which has no next token to score).
logp = F.log_softmax(logits[:, :-1, :], dim=2)

In [None]:
# Copy of input_ids with the image placeholder presumably replaced by id 0, so the ids can serve as gather indices
# (the negative placeholder id is not a valid index).
input_ids_without_image_token = training.replace_image_token_with_another_token(
    input_ids.clone(), replacement_token_id=0
)

In [None]:
# Scratch inspection.
input_ids

tensor([[    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
         21082, 20255, 16684,   408,   385, 18860, 17937, 19915, 29889,   450,
         20255,  4076, 10257, 29892, 13173, 29892,   322,  1248,   568,  6089,
           304,   278,  1404, 29915, 29879,  5155, 29889,  3148,  1001, 29901,
         29871,  -200,   869,   887,   526,   304,  1044,   408,   263, 17937,
         19915,   322,  1234,   263,  2323,  1139, 29889,  2860,   366, 10049,
         29892,  3113,  3867,   596,  1583, 17983,   310,   596, 16420, 29889,
          9133,   680,   263, 16420,  1546, 29871, 29900, 29892, 29871, 29896,
         29892, 29871, 29906, 29892, 29871, 29941, 29892, 29871, 29946, 29892,
         29871, 29945, 29892, 29871, 29953, 29892, 29871, 29955, 29892, 29871,
         29947, 29892, 29871, 29929, 29892, 29871, 29896, 29900, 29892,   310,
           920,  1854,   366,   526,   278,  1234,  

In [13]:
# Log-prob of each actual next token. NOTE(review): the logits are longer than input_ids here (recorded ~390 vs ~206),
# presumably because the image placeholder is expanded into image embeddings, so positions after the image token don't line up.
logpy = torch.gather(logp, 2, input_ids_without_image_token[:, 1:].unsqueeze(2)).squeeze(-1)

In [34]:
from trl.core import logprobs_from_logits

# TRL's helper: log-prob of each next token (logits at t score input_ids at t+1); the mask is shifted the same way.
# NOTE(review): the recorded error shows `input_ids` was the full 16-row batch while `logits` came from a 4-row mini batch
# — a hidden-state mismatch from out-of-order execution.
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]

RuntimeError: Size does not match at dimension 0 expected index [16, 181, 1] to be smaller than self [4, 389, 32000] apart from dimension 2

In [35]:
# Full multimodal PPO step on the whole batch: unpadded prompts as queries, generated confidences as responses.
# NOTE(review): the recorded run failed when concatenating per-mini-batch outputs of different lengths (395 vs 393),
# presumably because each mini batch is padded only to its own maximum length.
# queries = training.replace_image_token_with_another_token_for_list_of_tensors(input_ids_list)
# queries = t.cast(list[torch.LongTensor], queries)
model.gradient_checkpointing_enable()
stats = ppo_trainer.multimodal_step(
    queries=t.cast(list[torch.LongTensor], input_ids_list),
    responses=t.cast(list[torch.LongTensor], generated_confidences_ids),
    scores=rewards,
    images=images,
)

First pass for logprobs
torch.Size([16, 207])
torch.Size([4, 395])
torch.Size([4, 206])
torch.Size([16, 207])
torch.Size([4, 393])
torch.Size([4, 206])
torch.Size([16, 207])
torch.Size([4, 402])
torch.Size([4, 206])
torch.Size([16, 207])
torch.Size([4, 390])
torch.Size([4, 206])


RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 395 but got size 393 for tensor number 1 in the list.

# Prompting Experiments

In [17]:
# Candidate follow-up prompts asking the model to rate its confidence in its previous answer.
# Adjacent string literals are concatenated by Python, so every sentence-ending fragment carries its own trailing space.
POST_GENERATION_CONFIDENCE_REQUEST_1 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "A value close to 0 means you think there is a high probability that the answer is wrong. "
    "The closer the value is to 10, the higher you think is the probability that the answer is correct. "
    "The output should have the format 'Confidence: <confidence>' and nothing else. "
)


POST_GENERATION_CONFIDENCE_REQUEST_2 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "The output should have the format 'My Confidence: <confidence>'. "
    "ASSISTANT: Here is my self evaluation of my confidence. My Confidence (out of 10):"
)

# Control prompt unrelated to confidence.
POST_GENERATION_CONFIDENCE_REQUEST_3 = "USER: Irrelevant from my question, what is your name?"

POST_GENERATION_CONFIDENCE_REQUEST_4 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5 of how sure you are the answer is correct. "
    "The output should have the format 'My Confidence: <confidence>' My Reasoning for my Self Confidence Evaluation: <reasoning>. "
    "ASSISTANT: Here is my self evaluation of my confidence. My Confidence (out of 5):"
)

POST_GENERATION_CONFIDENCE_REQUEST_5 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "The output should have the format 'My Confidence: <confidence>' My Very Short Reasoning for my Self Confidence Evaluation: <very_short_reasoning>. "
    "ASSISTANT: Here is my self evaluation of my confidence. My Confidence (out of 10):"
)

POST_GENERATION_CONFIDENCE_REQUEST_6 = (
    "USER: Now evaluate your own response. How confident are you in your answer and what is your explanation regarding your confidence? "
    "ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy. "
    "My confidence is a value between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, where 0 means I am not confident at all and 10 means I am very confident. "
    "The confidence JSON follows this structure: {'confidence': int, 'explanation': str}. "
    "Here's my confidence JSON about my last response: "
)

POST_GENERATION_CONFIDENCE_REQUEST_7 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy. "
    "My confidence is a value between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, where 0 means I am not confident at all and 10 means I am very confident. "
    "The confidence JSON follows this structure: {'confidence': int}. "
    "Here's my confidence JSON about my last response: "
)

# Starts with "</s>" to close the previous assistant turn before the new user turn.
POST_GENERATION_CONFIDENCE_REQUEST_8 = (
    "</s> USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "A value close to 0 means you think there is a high probability that the answer is wrong. "
    "The closer the value is to 10, the higher you think is the probability that the answer is correct. "
    "ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy. "
    "The confidence JSON follows this structure: {'confidence': int}. "
    "Here's my confidence JSON about my last response: "
)

In [5]:
# NOTE(review): re-creates `tokenizer`, discarding the padding_side="left" set in the PPO config cell;
# `ppo_trainer` presumably still holds the previous tokenizer object.
tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)

In [None]:
# Two-turn demo on a few test studies: (1) answer the binary question, (2) append the answer and a
# confidence request to the prompt and generate again to see how the model states its confidence.
STOP_STR = prompter.Seperator.END_OF_SEQUENCE_SEPERATOR.value
from LLAVA_Biovil.llava.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token
from RewardingVisualDoubt import inference

NUM_DEMO_STUDIES = 6  # number of test studies to run the demo on

padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
padding_tokenizer.padding_side = "left"
# Pad with BOS rather than EOS so a left-padded prompt does not look like an already finished sequence.
padding_tokenizer.pad_token_id = padding_tokenizer.bos_token_id
dataset_test = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.TEST,
    tokenizer=tokenizer,
    prompter=prompter.build_binary_qa_instruction_from_disease_under_study,
)
dataloader_test = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_test, batch_size=1, padding_tokenizer=padding_tokenizer, num_workers=8
)

for idx, batch in enumerate(dataloader_test):
    batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
    batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
    # Use the same device for both turns (previously the first turn hard-coded CUDA).
    batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
        batch_llava_model_input_dict, device
    )
    input_ids, images = (
        batch_llava_model_input_dict["text_prompt_input_ids"],
        batch_llava_model_input_dict["images"],
    )
    stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)
    pred = inference.generate_radialog_answer_for_binary_qa_for_single_study(
        model, tokenizer, input_ids, images, stopping_criteria
    )
    confidence_request_prompt = (
        batch["batch_prompts"][0]
        + " "
        + pred
        + " "
        + prompter.build_post_generation_user_confidence_request()  # POST_GENERATION_CONFIDENCE_REQUEST_8
    )
    # tokenizer_image_token converts the image placeholder into the image token index (the plain tokenizer
    # would tokenize it as text, leaving the model no image slot) and yields int64 ids. Without a
    # placeholder it tokenizes exactly like the plain tokenizer.
    # NOTE(review): assumes `batch_prompts` carries the image placeholder, as the first-turn input_ids do.
    confidence_request_input_ids = (
        tokenizer_image_token(confidence_request_prompt, tokenizer, return_tensors="pt")
        .unsqueeze(0)
        .to(device)
    )
    stopping_criteria = KeywordsStoppingCriteria(
        [STOP_STR], tokenizer, confidence_request_input_ids
    )
    pred_with_confidence = inference.generate_radialog_answer_for_binary_qa_for_single_study(
        model, tokenizer, confidence_request_input_ids, images, stopping_criteria
    )
    print(f"\n Metadata: {batch['batch_mimic_cxr_datapoint_metadata']}")
    print(f"Prompt: {batch['batch_prompts']}")
    print(f"Label:", batch["batch_labels"])
    print(f"File_idx {idx}, ASSISTANT: ", pred)
    print(f"File_idx {idx}, ASSISTANT (after confidence request): ", pred_with_confidence)
    if idx == NUM_DEMO_STUDIES - 1:
        break

In [None]:
# Last first-turn answer produced by the demo loop above.
pred

In [None]:
######################################## TEST TO SEE IF TEMPERATURE AND TOP_P PARAMS HELP WITH USER CONFIDENCE REQUEST WITHOUT ASSISTANT CONFIRMATION ########################################
# Single-sample experiment: a prompt that already includes the confidence request, sampled with high temperature / top_p.


from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token

# NOTE(review): re-creates the global `iterator_train`, restarting the training stream used by other cells.
iterator_train = iter(dataloader_train)
batch = next(iterator_train)

batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
# Only the images are used; the prompt is built by hand below.
_, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)

# NOTE(review): the question is fixed to "Cardiomegaly" regardless of which study the first image belongs to.
my_prompt = prompter.build_binary_qa_instruction_from_disease_under_study_with_confidence_request(
    "Cardiomegaly"
)
tokenized_prompt = tokenizer_image_token(my_prompt, tokenizer, return_tensors="pt").to(device)

stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, tokenized_prompt.unsqueeze(0))

prompt_and_generated_answers_ids = model.generate(
    input_ids=tokenized_prompt.unsqueeze(0),
    images=images[0].unsqueeze(0),
    # attention_mask=attention_mask,
    do_sample=True,
    use_cache=True,
    temperature=1.8,
    top_p=0.7,
    max_new_tokens=300,  # TODO maybe move to the kwargs
    stopping_criteria=[stopping_criteria],  # TODO understand better
    pad_token_id=tokenizer.pad_token_id,  # used in tokenizing after the generation, # TODO maybe move to the kwargs
    # **generation_kwargs_prediction,  # TODO check which args to pass.
)

# Replace the image placeholder before decoding, since it is not a real vocabulary id.
tokenizer.decode(
    training.replace_image_token_with_another_token(prompt_and_generated_answers_ids)[0]
)

# Archived

In [None]:
# Archived single training step: greedily answer the question, append a confidence request, sample a
# confidence with the PPO trainer, score it, and run one PPO update. Records wall-clock timings t1..t8
# for the timing cell at the end of the notebook.
t1 = time.time()
batch = next(iterator_train)
t2 = time.time()

batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
input_ids, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)
attention_mask = batch["batch_attention_mask"].to(device)  # TODO handle elsewhere
labels = batch["batch_labels"].to(device)  # TODO handle elsewhere


# Greedy answer generation (not part of the PPO rollout).
model.eval()
stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)


t3 = time.time()
prompt_and_generated_answers_ids = model.generate(
    input_ids=input_ids,
    images=images,
    attention_mask=attention_mask,
    do_sample=False,
    use_cache=True,
    max_new_tokens=32,  # Limiting, YES, but binary q&a answers are not very long!
    stopping_criteria=[stopping_criteria],
    pad_token_id=tokenizer.pad_token_id,
)
t4 = time.time()
prompt_and_generated_answers_ids = training.remove_trailing_padding_from_prediction(
    prompt_and_generated_answers_ids, tokenizer.pad_token_id
)

# Append the confidence request to each prompt+answer (the request is built per sample on purpose,
# so per-sample prompt variation keeps working).
prompt_and_generated_answers_with_confidence_requests_ids = []
for item in prompt_and_generated_answers_ids:
    confidence_request_input_ids = (
        tokenizer(prompter.build_post_generation_user_confidence_request(), return_tensors="pt")
        .input_ids.to(device)
        .squeeze(0)
    )[
        1:
    ]  # drop start of sequence token
    prompt_and_generated_answers_with_confidence_requests_ids.append(
        torch.cat((item, confidence_request_input_ids), 0)
    )
model.train()

# PPO rollout: the sampled confidence is the response that gets scored and optimized.
t5 = time.time()
generated_confidences_ids = ppo_trainer.generate(
    prompt_and_generated_answers_with_confidence_requests_ids,  # ppo_trainer.generate() method admits list of tensors, not a batch tensor unfortunately
    images=images,
    return_prompt=False,
    **generation_kwargs_ppo,
)
t6 = time.time()


complete_conversation_ids = [
    torch.cat((p, c), 0)
    for p, c in zip(
        prompt_and_generated_answers_with_confidence_requests_ids,
        generated_confidences_ids,
    )
]
# Answer tokens only (generation output minus the padded prompt length).
generated_answer_only_ids = [
    prompt_and_generated_answers_ids[i][len(input_ids[i]) :] for i in range(len(input_ids))
]

# Replace the image placeholder token in the queries, presumably because the text-only PPO step cannot embed it.
prompt_and_generated_answers_with_confidence_requests_ids = (
    training.replace_image_token_with_another_token_for_list_of_tensors(
        prompt_and_generated_answers_with_confidence_requests_ids
    )
)
generated_answers_texts = tokenizer.batch_decode(
    generated_answer_only_ids,
    skip_special_tokens=True,
)
generated_confidences_texts = tokenizer.batch_decode(
    generated_confidences_ids,
    skip_special_tokens=True,
)
generated_answer_labels = response.parse_binary_labels(generated_answers_texts)
generated_confidence_values = response.parse_confidences(generated_confidences_texts)

rewards = [
    reward.generated_answer_and_confidence_to_reward(
        generated_answer_label, generated_confidence_value, ground_truth_label
    )
    for generated_answer_label, generated_confidence_value, ground_truth_label in zip(
        generated_answer_labels, generated_confidence_values, labels.bool().tolist()
    )
]

report = {}
report["generated_answer_labels"] = generated_answer_labels

rewards_epoch += rewards
rewards = [torch.tensor(r).to(device) for r in rewards]

t7 = time.time()
# The response scored by the reward is the generated confidence; the answer is already part of the query.
# (Passing `generated_answer_only_ids` here trained on duplicated answer tokens instead of the confidence.)
stats = ppo_trainer.step(
    prompt_and_generated_answers_with_confidence_requests_ids, generated_confidences_ids, rewards
)
t8 = time.time()

# ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "answer"])

# print(f"Finished epoch {epoch}. Average reward: {avg_reward}")
# ppo_trainer.save_pretrained(os.path.join(out_dir, "model_finetuned"))

# TODO: For random exploration
# chance_to_change_confidence -= reduce_per_step
# chance_to_change_confidence = max(0, chance_to_change_confidence)

In [None]:
# Summed reward of each complete batch accumulated in `rewards_epoch`; a trailing partial batch is not printed.
num_full_batches = len(rewards_epoch) // batch_size
for i in range(num_full_batches):
    batch_rewards = rewards_epoch[i * batch_size : (i + 1) * batch_size]
    print(sum(batch_rewards))

In [None]:
# Parsed confidence values from the most recent step.
generated_confidence_values

In [None]:
# Parsed answer labels from the most recent step.
generated_answer_labels

In [None]:
# Decoded answers from the archived training-step cell.
generated_answers_texts

In [None]:
# Decoded confidence responses from the most recent step.
generated_confidences_texts

In [None]:
# NOTE(review): `answers_decoded` is not assigned anywhere in this notebook — NameError on a fresh kernel.
answers_decoded

In [None]:
# NOTE(review): `confidences_decoded` is not assigned anywhere in this notebook — NameError on a fresh kernel.
confidences_decoded

In [None]:
# Wall-clock timings (ms) of the stages recorded by the archived training-step cell (t3..t8).
print((t8 - t7) * 1000, "time it took to ppo step")
print((t6 - t5) * 1000, "time it took to generate confidences")
print((t4 - t3) * 1000, "time it took to generate answers")
# t1/t2 bracket the batch fetch but may be missing from the namespace; fall back to t3 so the
# report still prints instead of raising NameError (the batch-fetch time then reads as 0).
batch_fetch_start = globals().get("t1", t3)
batch_fetch_end = globals().get("t2", t3)
print((batch_fetch_end - batch_fetch_start) * 1000, "time it took to get batch")

print("total time it took", int((t8 - batch_fetch_start)), "seconds")