In [None]:
from RewardingVisualDoubt import infrastructure

infrastructure.make_ipython_reactive_to_changing_codebase()


from RewardingVisualDoubt import (
    dataset,
    green,
    evaluation,
    training,
    vllm,
    prompter,
    inference,
    response,
    reward,
    shared,
)

import accelerate
import dataclasses
import torch
import functools
import pathlib as path
import math
import os
import itertools
import time

from tqdm import tqdm

import typing as t

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
BATCH_SIZE = 24

# Generations (vanilla verbalized confidence, in-distribution test set) that the cells
# below re-score with alternative confidence methods.
# NOTE(review): user-specific absolute path; adjust REPO_ROOT when running elsewhere.
REPO_ROOT = "/home/guests/deniz_gueler/repos/RewardingVisualDoubt"
RECORDS_JSON_PATH = os.path.join(
    REPO_ROOT,
    "workflows/report_generation/evaluations/results/testset_in_dist/vanilla_verbalize.json",
)
# Backward-compatible alias kept for any cell that still uses the old name; note the
# value is a JSON *file* path, not a directory.
BASE_JSON_DIR = RECORDS_JSON_PATH

device, device_str = shared.get_device_and_device_str()
model = vllm.shortcut_load_the_original_radialog_model()
tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
# Tokenizer used only to pad batches of token ids; loaded once and reused.
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    for_use_in_padding=True
)

# NOTE(review): `padding_side` is normally a tokenizer attribute. Upstream LLaVA code (as
# far as I know) reads `config.tokenizer_padding_side` when re-padding the multimodal
# embeddings — confirm the model code used here honors `config.padding_side`, because
# the scoring cells below index logits from the END of each sequence and so rely on
# left padding.
model.config.padding_side = "left"

prompter_ = prompter.build_report_generation_instruction_from_findings
dataset_ = dataset.get_in_distribution_report_generation_test_set(tokenizer, prompter_)
# NOTE(review): the scoring cells pair each batch with `records` purely by position, so
# this dataloader must iterate in a fixed order matching the JSON file (no shuffling).
dataloader = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset_,
    batch_size=BATCH_SIZE,
    padding_tokenizer=padding_tokenizer,
    num_workers=8,
)

records = dataset.read_records_from_json_file(RECORDS_JSON_PATH)

Adding LoRA adapters to the model for SFT training or inference from Radialog Lora Weights path: /home/guests/deniz_gueler/repos/RewardingVisualDoubt/data/RaDialog_adapter_model.bin
Loading LLaVA model with the base LLM and with RaDialog finetuned vision modules...


Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

Model will be loaded at precision: 4bit
Loading LLaVA from base liuhaotian/llava-v1.5-7b


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading additional LLaVA weights...
Merging model with vision tower weights...
Using downloaded and verified file: /tmp/biovil_t_image_model_proj_size_128.pt
Adding LoRA adapters to the model...
Loading mimic_cxr_df from cache




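# SEQUENCE PROBABILITY

Each generated report, with its verbalized confidence stripped, is appended to its prompt and re-scored in a single forward pass. The confidence is the mean probability the model assigns to the report's own tokens, multiplied by 10 and rounded to an integer.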

In [None]:
# Sequence-probability confidence: re-feed prompt + generated report (teacher forcing)
# and use the mean probability of the report tokens, scaled to 0-10, as confidence.
SEQUENCE_PROBABILITY_JSON_PATH = "sequence_probability.json"
# Id of the leading token the tokenizer prepends (stripped from report ids below) and
# used to locate where the left-padded prompt actually starts.
BOS_TOKEN_ID = 1

# Start from a fresh file so re-running this cell does not append a second copy of
# every record to the results of a previous run.
path.Path(SEQUENCE_PROBABILITY_JSON_PATH).unlink(missing_ok=True)
# Loop-invariant: load the padding tokenizer once instead of once per batch.
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    for_use_in_padding=True
)

for i, batch in enumerate(tqdm(dataloader)):
    batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
    input_ids, images, attention_mask, batch_metadata_list = (
        dataset.typical_unpacking_for_report_generation(device, batch)
    )
    # Records are matched to samples by position only; this relies on the dataloader
    # iterating in the same order as `records`.
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(records))
    batch_records = records[start_idx:end_idx]
    assert len(batch_records) == len(input_ids), (
        f"batch {i}: {len(input_ids)} samples but {len(batch_records)} records"
    )
    confidence_stripped_reports = training.remove_confidence_part_from_generated_responses(
        [record["generated_report"] for record in batch_records]
    )
    generated_reports_input_ids = tokenizer(confidence_stripped_reports).input_ids

    concatenated_sequences = []
    generation_lengths = []
    for original_input_ids, generated_input_ids in zip(input_ids, generated_reports_input_ids):
        # Drop the tokenizer's leading BOS before appending the report to the prompt.
        report_ids = torch.tensor(generated_input_ids[1:], dtype=torch.long, device=device)
        concatenated = torch.cat([original_input_ids, report_ids])
        # Strip the prompt's left padding: keep everything from the first BOS onwards.
        sequence_start = (concatenated == BOS_TOKEN_ID).nonzero()[0].item()
        concatenated_sequences.append(concatenated[sequence_start:])
        generation_lengths.append(len(report_ids))

    final_input_ids, final_attention_masks = dataset.pad_batch_text_sequences(
        concatenated_sequences,
        padding_tokenizer=padding_tokenizer,
    )
    with torch.no_grad():
        outputs = model(
            input_ids=final_input_ids,
            images=images,
            attention_mask=final_attention_masks,
        )
    logits = outputs.logits

    scored_records = []
    for record, logit, gen_len, gen_input_ids in zip(
        batch_records, logits, generation_lengths, generated_reports_input_ids
    ):
        # Logits at position p predict token p + 1, so the last `gen_len` tokens (the
        # report, assuming left padding) are predicted by positions -(gen_len + 1)..-2.
        # NOTE(review): an empty report (gen_len == 0) gives an empty slice whose mean
        # is NaN, and round() below then raises ValueError.
        report_logits = logit[-(gen_len + 1) : -1]
        token_probs = torch.softmax(report_logits, dim=-1)
        rows = torch.arange(token_probs.size(0), device=token_probs.device)
        mean_prob = token_probs[rows, gen_input_ids[1:]].mean()
        # Write a copy so the shared `records` list is not mutated across cells.
        scored_records.append({**record, "confidence": round(mean_prob.item() * 10)})

    dataset.append_records_to_json_file(scored_records, SEQUENCE_PROBABILITY_JSON_PATH)

100%|██████████| 46/46 [09:27<00:00, 12.34s/it]


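# P(TRUE)

After each generated report (with its verbalized confidence stripped), the model is asked whether its previous answer was correct, with a Yes/No answer. The confidence is the normalized probability of "Yes" returned by `evaluation.extract_yes_no_probs_from_logits`, multiplied by 10 and rounded to an integer.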

In [218]:
# P(true) confidence: append a yes/no follow-up question after each generated report and
# use the model's normalized probability of "Yes" (scaled to 0-10) as confidence.
FOLLOWUP_PROMPT = (
    " USER: Was your previous answer correct? Answer with a single word: Yes or No. ASSISTANT:"
)
PTRUE_JSON_PATH = "ptrue.json"
# Id of the leading token the tokenizer prepends (stripped from report ids below) and
# used to locate where the left-padded prompt actually starts.
BOS_TOKEN_ID = 1

# Start from a fresh file so re-running this cell does not append duplicate records.
path.Path(PTRUE_JSON_PATH).unlink(missing_ok=True)
# Loop-invariant: load the padding tokenizer once instead of once per batch.
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    for_use_in_padding=True
)
# Tokenize the follow-up once. Drop the leading BOS the tokenizer prepends (exactly as is
# done for the generated reports) so that no BOS token lands mid-sequence.
_followup_ids = tokenizer(FOLLOWUP_PROMPT).input_ids
if _followup_ids and _followup_ids[0] == BOS_TOKEN_ID:
    _followup_ids = _followup_ids[1:]
followup_input_ids = torch.tensor(_followup_ids, dtype=torch.long, device=device)

for i, batch in enumerate(tqdm(dataloader)):
    batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
    input_ids, images, attention_mask, batch_metadata_list = (
        dataset.typical_unpacking_for_report_generation(device, batch)
    )
    # Records are matched to samples by position only; this relies on the dataloader
    # iterating in the same order as `records`.
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(records))
    batch_records = records[start_idx:end_idx]
    assert len(batch_records) == len(input_ids), (
        f"batch {i}: {len(input_ids)} samples but {len(batch_records)} records"
    )
    confidence_stripped_reports = training.remove_confidence_part_from_generated_responses(
        [record["generated_report"] for record in batch_records]
    )
    generated_reports_input_ids = tokenizer(confidence_stripped_reports).input_ids

    concatenated_sequences = []
    for original_input_ids, generated_input_ids in zip(input_ids, generated_reports_input_ids):
        # Drop the tokenizer's leading BOS before appending the report to the prompt.
        report_ids = torch.tensor(generated_input_ids[1:], dtype=torch.long, device=device)
        concatenated = torch.cat([original_input_ids, report_ids])
        # Strip the prompt's left padding: keep everything from the first BOS onwards.
        sequence_start = (concatenated == BOS_TOKEN_ID).nonzero()[0].item()
        concatenated_sequences.append(
            torch.cat([concatenated[sequence_start:], followup_input_ids])
        )

    final_input_ids, final_attention_masks = dataset.pad_batch_text_sequences(
        concatenated_sequences,
        padding_tokenizer=padding_tokenizer,
    )
    with torch.no_grad():
        outputs = model(
            input_ids=final_input_ids,
            images=images,
            attention_mask=final_attention_masks,
        )
    logits = outputs.logits
    normalized_yes, normalized_no = evaluation.extract_yes_no_probs_from_logits(logits)
    # Write copies so the shared `records` list is not mutated across cells.
    scored_records = [
        {**record, "confidence": round(norm_yes * 10)}
        for record, norm_yes in zip(batch_records, normalized_yes)
    ]
    dataset.append_records_to_json_file(scored_records, PTRUE_JSON_PATH)

100%|██████████| 46/46 [09:56<00:00, 12.97s/it]
