In [1]:
# TODO sampling logic (undersample)
# TODO binary qa prompt and co need to vary between a few samples
# TODO determine hyperparams
# TODO arg parsing
# TODO dataloader num workers set to default

# SOME NOTES
# PPO_TRAINER AUTOMATICALLY PADS THE INPUTS BY TOKENIZER.PADDING_SIDE AND TOKENIZER.PADDING_TOKEN_ID
# Uh-oh, because ppo termination token is set as the eos_seq_token, it'll stop when it sees a left padded sequence
# Skipping random exploration for now


# BATCH TIMING
# A batch of 8 samples takes around 1-3 min to process in a train step (so roughly 400 samples per hour are trainable); every 50th batch, we save a checkpoint and run validation.
# Let's save a checkpoint every half an hour or so.
# Give validation around 15 min => roughly 100 samples.
# The validation set has around 8k samples, so a full pass would be ~1000 batches (1000 * 1.5 min ≈ 25 hours).
# len(dataset_eval) = 8737

In [1]:
# %% Set script for interactive development and import modules
# The project helpers are imported first so the IPython / warning configuration below is
# applied before the remaining imports in this cell run.
from RewardingVisualDoubt import infrastructure, training, vllm

# Presumably enables IPython autoreload of the project package (a later cell output reports
# "The autoreload extension is already loaded") — confirm in `infrastructure`.
infrastructure.make_ipython_reactive_to_changing_codebase()
# Silences warnings the project treats as known/benign; see `infrastructure` for the list.
infrastructure.supress_known_warnings()

import pathlib as path
import typing as t
import torch
import numpy as np

import os
from torch.utils.data import DataLoader
import accelerate
import dataclasses
import functools
import wandb
from tqdm import tqdm

# from LLAVA_Biovil.llava.mm_utils import KeywordsStoppingCriteria
from trl import PPOConfig, PPOTrainer
import trl

from RewardingVisualDoubt import dataset, prompter, shared, response, reward, vllm
from RewardingVisualDoubt import training as training

# Never hardcode credentials in a notebook: they end up in version control and in shared
# outputs. wandb reads WANDB_API_KEY from the environment and otherwise falls back to the
# credentials stored by `wandb login`.
# NOTE(review): a key was previously committed on this line — it should be revoked.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; relying on credentials stored by `wandb login`.")

# Stopping criterion used by the generation cells below to halt at STOP_STR.
from LLAVA_Biovil.llava.mm_utils import KeywordsStoppingCriteria

# End-of-sequence separator from the prompter; generation stops once it is produced.
STOP_STR = prompter.Seperator.END_OF_SEQUENCE_SEPERATOR.value

Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

In [2]:
############################################ For prototyping only: Input hyperparameters ########################################
NUM_EPOCHS = 1
DEFAULT_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 2
MINI_BATCH_SIZE = DEFAULT_BATCH_SIZE // 2
LEARNING_RATE = 5e-5
DEFAULT_OUTPUT_DIR = path.Path("output")

# Runtime values derived from the constants above (placeholders until arg parsing exists).
# Each is assigned exactly once and refers to its constant instead of recomputing it.
num_epochs = NUM_EPOCHS
batch_size = DEFAULT_BATCH_SIZE
gradient_accumulation_steps = GRADIENT_ACCUMULATION_STEPS
mini_batch_size = MINI_BATCH_SIZE
learning_rate = LEARNING_RATE
out_dir: path.Path = DEFAULT_OUTPUT_DIR

# TRL's PPOConfig requires batch_size to be a multiple of
# mini_batch_size * gradient_accumulation_steps (the "backward batch"); fail fast here
# rather than deep inside the trainer setup.
assert batch_size % (mini_batch_size * gradient_accumulation_steps) == 0, (
    f"batch_size={batch_size} must be a multiple of mini_batch_size * "
    f"gradient_accumulation_steps = {mini_batch_size * gradient_accumulation_steps}"
)

######################################## 0. Define the environment ########################################

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_str = shared.torch_devices.cuda.value
else:
    device_str = shared.torch_devices.cpu.value
device = torch.device(device_str)

######################################## 1. Load the model and tokenizer ########################################

# Policy model: the SFT-merged RaDialog LLaVA checkpoint loaded in 4 bit, with fresh LoRA
# adapters (and, per the loader's log output, a fresh value head) for PPO training.
model = vllm.load_pretrained_llava_model_for_ppo_training_with_fresh_lora_adapters(
    device_str=device_str,
    llava_model_path=vllm.RadialogMergedLlavaModelPath.BINARY_QA_WITH_CONFIDENCE_SFT.value,
    precision="4bit",
)

# Two independent tokenizer instances: `tokenizer` is handed to the PPO trainer, while
# `padding_tokenizer` is used by the dataloaders for collation and pads on the left.
tokenizer_kwargs = {"model_base": vllm.LLAVA_BASE_MODEL_NAME}
tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(**tokenizer_kwargs)
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(**tokenizer_kwargs)
padding_tokenizer.padding_side = "left"


######################################## 2. Load the datasets and the dataloaders ########################################

# Worker processes per dataloader. Capped at the CPU count because requesting more workers
# than cores oversubscribes the machine (and DataLoader warns about it); os.cpu_count()
# may return None, hence the fallback.
NUM_DATALOADER_WORKERS = min(8, os.cpu_count() or 1)

print("Loading the datasets and the dataloaders...")
# `is_for_inference=True` presumably leaves the response/confidence part of the prompt open
# for the policy to generate — confirm in `prompter`.
prompter_ = functools.partial(
    prompter.build_binary_qa_prompt_with_response_and_confidence_for_sft, is_for_inference=True
)
dataset_train = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.TRAIN,
    tokenizer=tokenizer,
    prompter=prompter_,
)
dataset_eval = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.VALIDATION,
    tokenizer=tokenizer,
    prompter=prompter_,
)

# Both loaders pad with the left-padding tokenizer.
dataloader_train = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_train,
    batch_size=batch_size,
    padding_tokenizer=padding_tokenizer,
    num_workers=NUM_DATALOADER_WORKERS,
)

dataloader_eval = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_eval,
    batch_size=2 * batch_size,
    padding_tokenizer=padding_tokenizer,
    num_workers=NUM_DATALOADER_WORKERS,
)

# NOTE(review): not consumed by any cell in this notebook yet (validation is still a TODO).
eval_batch_iterator = iter(dataloader_eval)

import sys

# The `workflows` package lives two directories above the notebook's working directory.
# Resolve it to an absolute path (so a later os.chdir cannot break the import) and append it
# only once, so re-running this cell does not keep growing sys.path.
_WORKFLOWS_PARENT_DIR = os.path.abspath(os.path.join("..", ".."))
if _WORKFLOWS_PARENT_DIR not in sys.path:
    sys.path.append(_WORKFLOWS_PARENT_DIR)
from workflows import radialog_binary_qa_ppo_training

Adding fresh set of LoRA adapters and a fresh value head to the model for PPO training using Llava model loaded from: /home/guests/deniz_gueler/repos/RewardingVisualDoubt/models/radialog_binary_qa_with_confidence_sft_full_merged_model
Loading LLaVA model with the base LLM and with RaDialog finetuned vision modules...


Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

Model will be loaded at precision: 4bit
Loading LLaVA from base /home/guests/deniz_gueler/repos/RewardingVisualDoubt/models/radialog_binary_qa_with_confidence_sft_full_merged_model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading additional LLaVA weights...
Merging model with vision tower weights...
Using downloaded and verified file: /tmp/biovil_t_image_model_proj_size_128.pt




Loaded additional vision tower weights...
Adding pretrained RaDialog LoRA adapters (or fresh LoRa adapters) and value head to the model...




Loading the datasets and the dataloaders...
Loading mimic_cxr_df from cache
Loading balanced_binary_qa_mimic_cxr_df from cache
Loading mimic_cxr_df from cache
Loading balanced_binary_qa_mimic_cxr_df from cache
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
######################################## 3. Define the PPO and generation configurations ########################################

# TRL derives backward_batch_size = mini_batch_size * gradient_accumulation_steps itself;
# gradients of the mini batches inside one backward batch are accumulated before the
# optimizer steps, and each PPO epoch walks through the whole rollout batch that way.

# Accelerate project layout used by the wandb tracker.
ppo_project_configuration = accelerate.utils.ProjectConfiguration(
    project_dir="radialog_binary_qa_ppo_training", logging_dir="logs"
)

ppo_config = trl.PPOConfig(
    learning_rate=learning_rate,
    task_name="gpt",
    # TRL defaults to 4 passes over each batch; with plenty of fresh data a single pass is used.
    ppo_epochs=1,
    batch_size=batch_size,
    # backward_batch_size is not passed: PPOConfig overwrites it at __init__ time anyway.
    mini_batch_size=mini_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    log_with="wandb",
    tracker_project_name="radialog_binary_qa_ppo_training",
    project_kwargs=dataclasses.asdict(ppo_project_configuration),
    remove_unused_columns=False,
    # optimize_device_cache=True,
    # 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl),
    # 'full': the actual kl for all tokens in the distribution.
    kl_penalty="kl",
    init_kl_coef=0.05,
)

# Sampling settings for PPO rollouts: sample from the model's unmodified distribution (no
# top-k / nucleus filtering, no temperature scaling) and stop at EOS.
generation_kwargs_ppo = {
    "min_length": -1,  # no minimum length, so EOS is never suppressed
    "top_k": 0,  # 0 disables top-k filtering (top_k is an integer parameter)
    "top_p": 1.0,  # 1.0 disables nucleus filtering
    "temperature": 1.0,  # 1.0 leaves the logits unscaled: plain sampling, not greedy
    "do_sample": True,  # sample instead of greedy decoding
    # This tokenizer already defines a pad token (per the original note, equal to eos_token_id).
    "pad_token_id": tokenizer.pad_token_id,
    # An answer plus a confidence needs only a few tokens; the cap keeps rollouts short
    # without truncating ordinary responses.
    "max_new_tokens": 50,
    "eos_token_id": tokenizer.eos_token_id,  # (instead of ppo_terminators list)
}

# The multimodal PPO trainer (project subclass). t.cast only informs the type checker.
# It keeps a reference to the same `tokenizer` object that is modified below.
ppo_trainer = t.cast(
    training.MultimodalPPOTrainer,
    training.MultimodalPPOTrainer(
        model=model,
        config=ppo_config,
        tokenizer=tokenizer,
    ),
)

# Force left padding: per the notes at the top, the PPO trainer pads queries according to
# tokenizer.padding_side, and decoder-only generation needs left padding.
# NOTE(review): these are set after the trainer is constructed; if MultimodalPPOTrainer reads
# padding_side in __init__ it will have seen the previous value — consider moving them up.
tokenizer.padding_side = "left"
model.config.padding_side = "left"
model.config.tokenizer_padding_side = "left"
# model.pad_token_id = tokenizer.eos_token_id
# model.pad_token_id = tokenizer.eos_token_id

fatal: No names found, cannot describe anything.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33monurdenizguler[0m ([33monurdenizguler-technical-university-of-munich[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




In [6]:
# Accumulators for the training session. `rewards_epoch` collects raw rewards (the archived
# step below appends to it); `accumulating_game_logs` is handed to every training step.
rewards_epoch = []
GAME_LOG_KEYS = ("queries", "responses", "is_answer_correct", "scores", "confidences")
accumulating_game_logs: training.GameLogs = {key: [] for key in GAME_LOG_KEYS}
# Fresh iterator over the training data; each training step pulls one batch from it.
iterator_train = iter(dataloader_train)
# model.pretrained_model.enable_input_require_grads()
# custom_game_log_table = wandb.Table(columns=["query", "response", "reward"])

In [None]:
# Number of PPO steps (one batch each) to run in this session; each step consumes
# `batch_size` training samples.
NUM_TRAINING_STEPS = 360

_TRAIN_DATA_EXHAUSTED = object()  # sentinel for next() so exhaustion is detected explicitly
for _step in tqdm(range(NUM_TRAINING_STEPS)):
    next_batch = next(iterator_train, _TRAIN_DATA_EXHAUSTED)
    if next_batch is _TRAIN_DATA_EXHAUSTED:
        # Stop cleanly instead of raising StopIteration; `batch` keeps the last processed
        # batch for the debug cells below.
        print(f"Training data exhausted after {_step} steps.")
        break
    batch = next_batch
    radialog_binary_qa_ppo_training.radialog_binary_qa_ppo_training_step(
        model,
        device,
        tokenizer,
        generation_kwargs_ppo,
        ppo_trainer,
        batch,
        reward.reward_teachers_pet_behaviour,
        accumulating_game_logs,
    )

 43%|████▎     | 154/360 [58:21<1:17:47, 22.66s/it]

In [None]:
# Persist the trained model (LoRA adapters and whatever else the wrapper saves) for later use.
# NOTE(review): absolute, user-specific path — consider deriving it from a configurable
# models directory.
TEACHERS_PET_ADAPTER_DIR = (
    "/home/guests/deniz_gueler/repos/RewardingVisualDoubt/models/teachers_pet_ppo_adapter_1"
)
model.save_pretrained(TEACHERS_PET_ADAPTER_DIR)

# Debug the training pipeline

In [None]:
######### 5.1 Unpack the batch #########
# Re-annotates the global `batch` (the last batch pulled by the training loop) for typing.
batch: dataset.MimicCxrLlavaModelInputBatchDict = batch
batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
input_ids, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)
attention_mask = batch["batch_attention_mask"].to(device)
labels = t.cast(torch.Tensor, batch["batch_labels"]).to(device)
# Halts generation once STOP_STR has been produced.
stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)
# The PPO trainer takes a list of unpadded query tensors, so the (left) padding added by the
# dataloader is stripped from each row.
input_ids_list = training.remove_preciding_padding_from_batch_tensor(input_ids)

######### 5.2 Generate the binary q&a answer and remove trailing padding tokens #########
# Generation runs in eval mode with gradient checkpointing disabled, because use_cache=True
# below is not compatible with gradient checkpointing.
model.eval()
model.gradient_checkpointing_disable()
generated_ids = ppo_trainer.generate(
    query_tensor=input_ids_list,  # ppo_trainer.generate() method admits list of tensors, handles padding and batching itself
    images=images,
    return_prompt=False,
    batch_size=input_ids.shape[0],
    use_cache=True,  # => not compatible with gradient checkpointing!
    stopping_criteria=[stopping_criteria],
    **generation_kwargs_ppo,
)

In [None]:
######### 5.3 Parse the responses and compute the scores #########
generated_texts = tokenizer.batch_decode(generated_ids)
generated_answer_labels = response.parse_binary_labels(generated_texts)
generated_confidence_values = response.parse_confidences(generated_texts)

scores = [
    reward.generated_answer_and_confidence_to_reward(
        generated_answer_label, generated_confidence_value, ground_truth_label
    )
    for generated_answer_label, generated_confidence_value, ground_truth_label in zip(
        generated_answer_labels, generated_confidence_values, labels.bool().tolist()
    )
]

scores = t.cast(
    list[torch.FloatTensor],
    [torch.tensor(s).to(device) for s in scores],
)

In [None]:
######### 5.7 Take a PPO optimization step #########
model.train()
model.gradient_checkpointing_enable()
stats = ppo_trainer.multimodal_step(
    queries=t.cast(list[torch.LongTensor], input_ids_list),
    responses=t.cast(list[torch.LongTensor], generated_ids),
    scores=scores,
    images=images,
)

In [None]:
# Debug: reproduce one mini-batch forward pass of the PPO step by hand.
# The current pipeline produces a single generation per sample (`generated_ids`, cell 5.2),
# which is what `multimodal_step` receives as responses, so the same tensors are used here.
# (`generated_confidences_ids` only exists after running the archived two-stage cell.)
queries = t.cast(list[torch.LongTensor], input_ids_list)
responses = t.cast(list[torch.LongTensor], generated_ids)

model_inputs = ppo_trainer.prepare_model_inputs(queries=queries, responses=responses)
# prepare_model_inputs only sees the token ids, so the images are attached separately.
model_inputs["images"] = images
model_inputs_names = list(model_inputs.keys())

bs = len(queries)
fbs = ppo_trainer.config.mini_batch_size

# Which mini batch to inspect. It must address a non-empty slice: with `bs` samples split
# into mini batches of `fbs`, the valid indices are 0 .. ceil(bs / fbs) - 1.
mini_batch_idx = 1
num_mini_batches = -(-bs // fbs)  # ceiling division
assert 0 <= mini_batch_idx < num_mini_batches, (
    f"mini_batch_idx={mini_batch_idx} is out of range for {num_mini_batches} mini batches"
)
mini_batch_slice = slice(mini_batch_idx * fbs, (mini_batch_idx + 1) * fbs)
input_kwargs = {key: value[mini_batch_slice] for key, value in model_inputs.items()}
with torch.no_grad():
    logits_mini, _, values_mini = model(**input_kwargs)

input_ids_mini = input_kwargs["input_ids"]
attention_mask_mini = input_kwargs["attention_mask"]
images_mini = input_kwargs["images"]

In [None]:
# Full-batch forward pass through the trainer's own helper, without gradients. The second
# return value is the logits only when return_logits=True (not requested here).
with torch.no_grad():
    all_logprobs, logits_or_none, values, masks = ppo_trainer.batched_forward_pass(
        model,
        queries,
        responses,
        model_inputs,
    )

In [None]:
# Re-run the forward pass on a subset of the batch (mirroring the per-mini-batch dict TRL's
# PPOTrainer.step builds) to compare against the full-batch results above.
mini_batch_inds = [0, 1, 2, 3]
mini_batch_dict = {
    "logprobs": all_logprobs[mini_batch_inds],
    "values": values[mini_batch_inds],
    "masks": masks[mini_batch_inds],
    # hacks: the queries and responses are ragged.
    "queries": [queries[i] for i in mini_batch_inds],
    # One response tensor per selected sample (previously the whole `responses` list was
    # repeated for every index).
    "responses": [responses[i] for i in mini_batch_inds],
}
model_inputs_ = {
    "input_ids": model_inputs["input_ids"][mini_batch_inds],
    "attention_mask": model_inputs["attention_mask"][mini_batch_inds],
    "images": model_inputs["images"][mini_batch_inds],
}

with torch.no_grad():
    logprobs_new, logits_new, vpreds_new, _ = ppo_trainer.batched_forward_pass(
        model,
        mini_batch_dict["queries"],
        mini_batch_dict["responses"],
        model_inputs=model_inputs_,
        return_logits=True,
    )

# Archived

## Archived attempt to account for image embeddings

In [None]:
def account_for_image_embeddings_for_single_image_inputs(input_ids, logits, values, attention_mask):
    """
    Unfinished (archived) helper: locate the image placeholder token in each sequence.

    Args:
        input_ids (torch.Tensor): A tensor shaped (batch_size, sequence_length) including exactly 1 image token id (-200 for llava) for each sequence
        logits (torch.Tensor): A tensor shaped (batch_size, sequence_length, vocab_size)
        values: The values of the model
        attention_mask: The attention mask of the model

    Returns:
        None. The computed indexes are not returned yet.

    NOTE(review): `logits`, `values` and `attention_mask` are currently unused.
    """
    # Locate the image index: the column position of the image token in each row. nonzero()
    # yields one entry per match, so this relies on exactly one image token per row.
    indexes_of_image_token = (input_ids == training.LLAVA_IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[
        1
    ]

    # Shift the indexes taking into account where each sequence begins (the BOS token is taken as basis).
    # NOTE(review): `input_ids == 1` hard-codes the BOS id as 1 — confirm for this tokenizer.
    # NOTE(review): the positions from nonzero() are already absolute column indexes, so adding
    # the BOS position moves them further right rather than making them relative to the BOS;
    # subtraction may have been intended. Also, if the pad id equals the BOS id (as a later
    # cell sets for `padding_tokenizer`), argmax returns the first pad position instead.
    indexes_of_bos_token = (input_ids == 1).int().argmax(dim=1)
    indexes_of_image_token += indexes_of_bos_token

In [None]:
# Archived: PPO step from the earlier two-stage pipeline (answer first, then a separate
# confidence generation). It needs `generated_confidences_ids` and `rewards`, which are only
# defined by the archived two-stage cell further below — run that cell first.
# queries = training.replace_image_token_with_another_token_for_list_of_tensors(input_ids_list)
# queries = t.cast(list[torch.LongTensor], queries)
model.gradient_checkpointing_enable()
stats = ppo_trainer.multimodal_step(
    queries=t.cast(list[torch.LongTensor], input_ids_list),
    responses=t.cast(list[torch.LongTensor], generated_confidences_ids),
    scores=rewards,
    images=images,
)

In [None]:
# Archived: run the vision tower and the multimodal projector by hand to inspect the image
# features and their shape.
llava_inner_model = model.pretrained_model.base_model.model.model
image_patch_embeddings = llava_inner_model.get_vision_tower()(images).patch_embeddings
# Flatten everything after dim 1 into one axis and move it before the channel axis —
# presumably (B, C, H, W) -> (B, H*W, C); confirm against the vision tower.
image_features = image_patch_embeddings.flatten(2).transpose(1, 2)
image_features = llava_inner_model.mm_projector(image_features)
image_features.shape

In [None]:
# Archived: call LLaVA's prepare_inputs_labels_for_multimodal to inspect the merged
# text + image inputs.
# WARNING: this rebinds the globals `input_ids`, `attention_mask` and `labels` (and creates
# `position_ids`, `past_key_values`, `inputs_embeds`), clobbering the values the other debug
# cells rely on — re-run the batch-unpacking cell afterwards.
input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = (
    model.pretrained_model.base_model.model.prepare_inputs_labels_for_multimodal(
        input_ids=input_ids,
        position_ids=None,
        attention_mask=attention_mask,
        past_key_values=None,
        labels=None,
        images=images,
    )
)

## Prompting Experiments

In [None]:
# Candidate follow-up prompts that ask the model for its confidence after it has answered.
# Experimental only — the training cells above use `prompter` instead.
# Adjacent string literals are concatenated with no separator, so every fragment that is
# followed by another one must end with a space.
POST_GENERATION_CONFIDENCE_REQUEST_1 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "A value close to 0 means you think there is a high probability that the answer is wrong. "
    "The closer the value is to 10, the higher you think is the probability that the answer is correct. "
    "The output should have the format 'Confidence: <confidence>' and nothing else. "
)


POST_GENERATION_CONFIDENCE_REQUEST_2 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "The output should have the format 'My Confidence: <confidence>'. "
    "ASSISTANT: Here is my self evaluation of my confidence. My Confidence (out of 10):"
)

POST_GENERATION_CONFIDENCE_REQUEST_3 = "USER: Irrelevant from my question, what is your name?"

POST_GENERATION_CONFIDENCE_REQUEST_4 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5 of how sure you are the answer is correct. "
    "The output should have the format 'My Confidence: <confidence>' My Reasoning for my Self Confidence Evaluation: <reasoning>. "
    "ASSISTANT: Here is my self evaluation of my confidence. My Confidence (out of 5):"
)

POST_GENERATION_CONFIDENCE_REQUEST_5 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "The output should have the format 'My Confidence: <confidence>' My Very Short Reasoning for my Self Confidence Evaluation: <very_short_reasoning>. "
    "ASSISTANT: Here is my self evaluation of my confidence. My Confidence (out of 10):"
)

POST_GENERATION_CONFIDENCE_REQUEST_6 = (
    "USER: Now evaluate your own response. How confident are you in your answer and what is your explanation regarding your confidence? "
    "ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy. "
    "My confidence is a value between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, where 0 means I am not confident at all and 10 means I am very confident. "
    "The confidence JSON follows this structure: {'confidence': int, 'explanation': str}. "
    "Here's my confidence JSON about my last response: "
)

POST_GENERATION_CONFIDENCE_REQUEST_7 = (
    "USER: Now evaluate your own response. How confident are you in your answer? "
    "ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy. "
    "My confidence is a value between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, where 0 means I am not confident at all and 10 means I am very confident. "
    "The confidence JSON follows this structure: {'confidence': int}. "
    "Here's my confidence JSON about my last response: "
)

POST_GENERATION_CONFIDENCE_REQUEST_8 = (
    "</s> USER: Now evaluate your own response. How confident are you in your answer? "
    "Provide a confidence between 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, of how sure you are the answer is correct. "
    "A value close to 0 means you think there is a high probability that the answer is wrong. "
    "The closer the value is to 10, the higher you think is the probability that the answer is correct. "
    "ASSISTANT: When asked how confident I am about a response, I consistently provide it in a JSON object, adhering to my policy. "
    "The confidence JSON follows this structure: {'confidence': int}. "
    "Here's my confidence JSON about my last response: "
)

In [None]:
# Reload a fresh tokenizer for the prompting experiments.
# NOTE(review): this rebinds the global `tokenizer` used by the training/debug cells (the PPO
# trainer keeps its own reference to the old object); the new instance does not get the
# `padding_side = "left"` applied earlier.
tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)

In [None]:
# Prompting experiment: for the first few test studies, generate an answer, append a
# post-generation confidence request to the conversation and generate again.
import itertools

STOP_STR = prompter.Seperator.END_OF_SEQUENCE_SEPERATOR.value
from LLAVA_Biovil.llava.mm_utils import KeywordsStoppingCriteria
from RewardingVisualDoubt import inference

NUM_EXAMPLE_STUDIES = 6  # how many test studies to run and print

padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
padding_tokenizer.padding_side = "left"
padding_tokenizer.pad_token_id = padding_tokenizer.bos_token_id
dataset_test = dataset.get_binary_qa_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.TEST,
    tokenizer=tokenizer,
    prompter=prompter.build_binary_qa_instruction_from_disease_under_study,
)
dataloader_test = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset=dataset_test, batch_size=1, padding_tokenizer=padding_tokenizer, num_workers=8
)

for idx, batch in enumerate(itertools.islice(dataloader_test, NUM_EXAMPLE_STUDIES)):
    batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
    # Use the notebook-wide `device` (this was hard-coded to CUDA while the confidence
    # request below already used `device`).
    batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
        batch["batch_llava_model_input_dict"], device
    )
    input_ids, images = (
        batch_llava_model_input_dict["text_prompt_input_ids"],
        batch_llava_model_input_dict["images"],
    )
    stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)
    pred = inference.generate_radialog_answer_for_binary_qa_for_single_study(
        model, tokenizer, input_ids, images, stopping_criteria
    )
    # Continue the conversation: original prompt + model answer + confidence request.
    confidence_request_prompt = (
        batch["batch_prompts"][0]
        + " "
        + pred
        + " "
        + prompter.build_post_generation_user_confidence_request()  # POST_GENERATION_CONFIDENCE_REQUEST_8
    )
    # NOTE(review): if the prompt contains LLaVA's "<image>" placeholder, tokenizer_image_token
    # (used in the next cell) is needed to map it to the image token index; the plain tokenizer
    # call below tokenizes it as ordinary text — confirm this is intended.
    # Token ids as int64 (torch.long), the dtype HF tokenizers return for "pt" tensors.
    confidence_request_input_ids = (
        torch.tensor(tokenizer(confidence_request_prompt)["input_ids"], dtype=torch.long)
        .unsqueeze(0)
        .to(device)
    )
    stopping_criteria = KeywordsStoppingCriteria(
        [STOP_STR], tokenizer, confidence_request_input_ids
    )
    pred_with_confidence = inference.generate_radialog_answer_for_binary_qa_for_single_study(
        model, tokenizer, confidence_request_input_ids, images, stopping_criteria
    )
    print(f"\n Metadata: {batch['batch_mimic_cxr_datapoint_metadata']}")
    print(f"Prompt: {batch['batch_prompts']}")
    print("Label:", batch["batch_labels"])
    print(f"File_idx {idx}, ASSISTANT: ", pred)
    print(f"File_idx {idx}, ASSISTANT (after confidence request): ", pred_with_confidence)

In [None]:
######################################## TEST TO SEE IF TEMPERATURE AND TOP_P PARAMS HELP WITH USER CONFIDENCE REQUEST WITHOUT ASSISTANT CONFIRMATION ########################################
# Samples one free-form answer to a hand-built prompt that embeds the confidence request,
# with a high temperature (1.8) and nucleus sampling (top_p=0.7), then decodes the result.


from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token

# NOTE(review): rebinds the global `iterator_train`; the training loop would continue from
# this fresh iterator if it is run after this cell.
iterator_train = iter(dataloader_train)
batch = next(iterator_train)

batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
# Only the images are taken from the batch; the text prompt is built by hand below.
_, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)

my_prompt = prompter.build_binary_qa_instruction_from_disease_under_study_with_confidence_request(
    "Cardiomegaly"
)
# LLaVA's tokenizer_image_token handles the image placeholder in the prompt (the plain
# tokenizer call would not); the result is 1-D, hence the unsqueeze(0) below.
tokenized_prompt = tokenizer_image_token(my_prompt, tokenizer, return_tensors="pt").to(device)

stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, tokenized_prompt.unsqueeze(0))

# Batch of one: the single prompt with the first image of the batch.
prompt_and_generated_answers_ids = model.generate(
    input_ids=tokenized_prompt.unsqueeze(0),
    images=images[0].unsqueeze(0),
    # attention_mask=attention_mask,
    do_sample=True,
    use_cache=True,
    temperature=1.8,
    top_p=0.7,
    max_new_tokens=300,  # TODO maybe move to the kwargs
    stopping_criteria=[stopping_criteria],  # TODO understand better
    pad_token_id=tokenizer.pad_token_id,  # used in tokenizing after the generation, # TODO maybe move to the kwargs
    # **generation_kwargs_prediction,  # TODO check which args to pass.
)

# The image token index is replaced before decoding — presumably because it is not a valid
# vocabulary id for the tokenizer.
tokenizer.decode(
    training.replace_image_token_with_another_token(prompt_and_generated_answers_ids)[0]
)

## Archived

In [None]:
# Archived: earlier two-stage PPO step — greedy answer generation with model.generate, a
# sampled confidence generation with ppo_trainer.generate, then a text-only PPO step.
import time  # used for the t3..t8 timestamps below; it was never imported in this notebook

batch = next(iterator_train)

batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
batch_llava_model_input_dict = batch["batch_llava_model_input_dict"]
batch_llava_model_input_dict = dataset.move_llava_model_input_dict_to_device(
    batch_llava_model_input_dict, device
)
input_ids, images = (
    batch_llava_model_input_dict["text_prompt_input_ids"],
    batch_llava_model_input_dict["images"],
)
attention_mask = batch["batch_attention_mask"].to(device)  # TODO handle elsewhere
labels = batch["batch_labels"].to(device)  # TODO handle elsewhere


model.eval()
stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)


# Stage 1: greedy answer generation.
t3 = time.time()
prompt_and_generated_answers_ids = model.generate(
    input_ids=input_ids,
    images=images,
    attention_mask=attention_mask,
    do_sample=False,
    use_cache=True,
    max_new_tokens=32,  # Limiting, YES, but binary q&a answers are not very long!
    stopping_criteria=[stopping_criteria],
    pad_token_id=tokenizer.pad_token_id,
)
t4 = time.time()
prompt_and_generated_answers_ids = training.remove_trailing_padding_from_prediction(
    prompt_and_generated_answers_ids, tokenizer.pad_token_id
)

# Append confidence request to the generated answers
prompt_and_generated_answers_with_confidence_requests_ids = []
for item in prompt_and_generated_answers_ids:
    confidence_request_input_ids = (
        tokenizer(prompter.build_post_generation_user_confidence_request(), return_tensors="pt")
        .input_ids.to(device)
        .squeeze(0)
    )[
        1:
    ]  # drop start of sequence token
    prompt_and_generated_answers_with_confidence_requests_ids.append(
        torch.cat((item, confidence_request_input_ids), 0)
    )
model.train()

# Stage 2: sampled confidence generation by the PPO policy.
t5 = time.time()
generated_confidences_ids = ppo_trainer.generate(
    prompt_and_generated_answers_with_confidence_requests_ids,  # ppo_trainer.generate() method admits list of tensors, not a batch tensor unfortunately
    images=images,
    return_prompt=False,
    **generation_kwargs_ppo,
)
t6 = time.time()


# Query/response pairing: query = prompt + answer + confidence request, response = the
# sampled confidence tokens.
complete_conversation_ids = [
    torch.cat((p, c), 0)
    for p, c in zip(
        prompt_and_generated_answers_with_confidence_requests_ids,
        generated_confidences_ids,
    )
]
generated_answer_only_ids = [
    prompt_and_generated_answers_ids[i][len(input_ids[i]) :] for i in range(len(input_ids))
]

# Replace the image token (which has no vocabulary entry) in the prompts
prompt_and_generated_answers_with_confidence_requests_ids = (
    training.replace_image_token_with_another_token_for_list_of_tensors(
        prompt_and_generated_answers_with_confidence_requests_ids
    )
)
generated_answers_texts = tokenizer.batch_decode(
    generated_answer_only_ids,
    skip_special_tokens=True,
)
generated_confidences_texts = tokenizer.batch_decode(
    generated_confidences_ids,
    skip_special_tokens=True,
)
generated_answer_labels = response.parse_binary_labels(generated_answers_texts)
generated_confidence_values = response.parse_confidences(generated_confidences_texts)

rewards = [
    reward.generated_answer_and_confidence_to_reward(
        generated_answer_label, generated_confidence_value, ground_truth_label
    )
    for generated_answer_label, generated_confidence_value, ground_truth_label in zip(
        generated_answer_labels, generated_confidence_values, labels.bool().tolist()
    )
]

report = {}
report["generated_answer_labels"] = generated_answer_labels

rewards_epoch += rewards
rewards = [torch.tensor(r).to(device) for r in rewards]

# The responses must be the tokens the policy sampled after each query (the confidence
# generations), matching `complete_conversation_ids` above — not the greedy answers, which
# precede the confidence request and are already part of the query.
t7 = time.time()
stats = ppo_trainer.step(
    prompt_and_generated_answers_with_confidence_requests_ids, generated_confidences_ids, rewards
)
t8 = time.time()

# ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "answer"])

# print(f"Finished epoch {epoch}. Average reward: {avg_reward}")
# ppo_trainer.save_pretrained(os.path.join(out_dir, "model_finetuned"))

# TODO: For random exploration
# chance_to_change_confidence -= reduce_per_step
# chance_to_change_confidence = max(0, chance_to_change_confidence)

In [None]:
# Debug: decode the argmax prediction at every position of one sequence, split into the
# text before the image token, the image-embedding positions, and the text after them.
# Assumes LLaVA expands the single image token into NUM_IMAGE_EMBEDDING_POSITIONS positions
# in the logits (presumably 14 x 14 = 196 patches — TODO confirm for this vision tower).
NUM_IMAGE_EMBEDDING_POSITIONS = 196

working_set = "mini_batch"
input_ids_idx = 1


if working_set == "mini_batch":
    print("Working with mini batch")
    input_ids_working = input_ids_mini.clone().detach()
    logits_working = logits_mini.clone().detach()
elif working_set == "full_batch":
    print("Working with full batch")
    # `logits` was never computed anywhere in this notebook. Run the full-batch forward pass
    # here, mirroring the mini-batch pass, so ids and logits come from the same inputs.
    with torch.no_grad():
        logits_full, _, _ = model(**model_inputs)
    input_ids_working = model_inputs["input_ids"].clone().detach()
    logits_working = logits_full.clone().detach()
else:
    raise ValueError(
        f"working_set must be one of ['mini_batch', 'full_batch'], but got {working_set}"
    )


def _decode_argmax_tokens(logits_slice: torch.Tensor) -> list[str]:
    """Decode the most likely token at each position of a (positions, vocab) logits slice."""
    return tokenizer.batch_decode(torch.argmax(logits_slice, dim=-1))


# Column of the image token in each row (relies on exactly one image token per row).
indexes_of_image_token = (input_ids_working == training.LLAVA_IMAGE_TOKEN_INDEX).nonzero(
    as_tuple=True
)[1]

sequence_logits = logits_working[input_ids_idx]
image_start = indexes_of_image_token[input_ids_idx]
image_end = image_start + NUM_IMAGE_EMBEDDING_POSITIONS
print(_decode_argmax_tokens(sequence_logits[:image_start]))
print(_decode_argmax_tokens(sequence_logits[image_start:image_end]))
print(_decode_argmax_tokens(sequence_logits[image_end:]))
print("\n ")

print(input_ids_working[input_ids_idx])

print("\n \n ")