In [1]:
from RewardingVisualDoubt import infrastructure

# Called before the remaining project imports below. Presumably this enables
# IPython autoreload so edits to the RewardingVisualDoubt package are picked up
# without restarting the kernel — TODO confirm against the helper's implementation.
infrastructure.make_ipython_reactive_to_changing_codebase()


from RewardingVisualDoubt import (
    dataset,
    green,
    evaluation,
    training,
    vllm,
    prompter,
    inference,
    response,
    reward,
    shared,
)

import accelerate
import dataclasses
import torch
import functools
import pathlib as path
import math
import os
import itertools
import time

from tqdm import tqdm

import typing as t

from torch.utils.data import DataLoader, Subset

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

In [2]:
# Setup cell: loads the model, tokenizer, test dataset/dataloader and the previously
# generated report records that the inference loop in the next cell consumes.

# Number of examples per dataloader batch. The inference loop also uses this value to
# slice `records` by position, so both must stay in sync.
BATCH_SIZE = 12

device, device_str = shared.get_device_and_device_str()
# Per the load log below, this attaches the RaDialog LoRA adapter weights to the LLaVA
# base model and loads it at 4bit precision.
model = vllm.shortcut_load_the_original_radialog_model()
tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)

# NOTE(review): the padding side is set on the model config only; the tokenizer's own
# padding side is not touched here — confirm the padding helpers honour this setting.
model.config.padding_side = "left"

# Prompt builder and dataset for report generation on the in-distribution test set.
prompter_ = prompter.build_report_generation_instruction_from_findings
dataset_ = dataset.get_in_distribution_report_generation_test_set(tokenizer, prompter_)
# A second tokenizer instance (for_use_in_padding=True) is loaded solely for padding
# inside the dataloader.
dataloader = dataset.get_mimic_cxr_llava_model_input_dataloader(
    dataset_,
    batch_size=BATCH_SIZE,
    padding_tokenizer=vllm.load_pretrained_llava_tokenizer_with_image_support(
        for_use_in_padding=True
    ),
    num_workers=8,
)


# Disabled alternative: manual construction of a 988-example test subset and its
# DataLoader. Kept for reference only; none of it runs.
# dataset_ = dataset.get_report_generation_prompted_mimic_cxr_llava_model_input_dataset(
#     split=dataset.DatasetSplit.TEST, tokenizer=tokenizer, prompter=prompter_
# )
# datapoint_indexes = list(range(988))
# dataset_ = t.cast(
#     dataset.ReportGenerationPromptedMimicCxrLlavaModelInputDataset,
#     Subset(dataset_, datapoint_indexes),
# )
# dataloader = DataLoader(
#     dataset=dataset_,
#     batch_size=BATCH_SIZE,
#     collate_fn=lambda x: dataset.prompted_mimic_cxr_llava_model_input_collate_fn(
#         x, vllm.load_pretrained_llava_tokenizer_with_image_support(for_use_in_padding=True)
#     ),
#     shuffle=False,
#     num_workers=8,
#     pin_memory=True,
#     drop_last=False,
#     persistent_workers=True,
# )


# NOTE(review): despite the name, this is the path of a single JSON file, not a
# directory, and it is an absolute, machine-specific path — the notebook will not run
# elsewhere without editing it.
BASE_JSON_DIR = "/home/guests/deniz_gueler/repos/StudyingVisualDoubt/data/inference/report_generation/testset_difficulty_balanced_from_validation_set_idx_505-2129/base_sft_model/generated_reports_with_confidence_and_green_score.json"
# Previously generated reports. The inference loop reads record["generated_report"]
# and indexes this sequence by batch position.
records = dataset.read_records_from_json_file(BASE_JSON_DIR)

Adding LoRA adapters to the model for SFT training or inference from Radialog Lora Weights path: /home/guests/deniz_gueler/repos/RewardingVisualDoubt/data/RaDialog_adapter_model.bin
Loading LLaVA model with the base LLM and with RaDialog finetuned vision modules...


Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

Model will be loaded at precision: 4bit
Loading LLaVA from base liuhaotian/llava-v1.5-7b




Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading additional LLaVA weights...
Merging model with vision tower weights...
Using downloaded and verified file: /tmp/biovil_t_image_model_proj_size_128.pt
Adding LoRA adapters to the model...
Loading mimic_cxr_df from cache




In [3]:
# Re-query the model for a verbalized confidence on each previously generated report.
# Per batch: strip the old confidence part from the stored report, append the
# post-generation confidence request, concatenate that text onto the original prompt
# tokens, generate a few tokens, parse them into a confidence value, write it into
# the record and append the batch's records to `json_path`.
#
# NOTE: records are appended on every run; re-running this cell without removing the
# output file first will add the same records again.

# Loop-invariant setup, done once instead of once per batch.
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    for_use_in_padding=True
)
json_path = "original_radialog_vanilla_verbalize.json"

for i, batch in enumerate(tqdm(dataloader)):
    batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
    input_ids, images, stopping_criteria, attention_mask, batch_metadata_list = (
        dataset.unpack_report_generation_batch_with_attention_mask_and_metadata(
            device, tokenizer, batch
        )
    )

    # Records are paired with dataloader examples purely by position.
    # Assumes the dataloader yields examples in the same order as `records` — TODO confirm.
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(records))
    batch_records = records[start_idx:end_idx]
    # Fail fast on a count mismatch: the zip() below would otherwise silently truncate
    # and pair the remaining prompts/images with the wrong (or no) reports.
    if len(batch_records) != len(input_ids):
        raise ValueError(
            f"Batch {i} has {len(input_ids)} examples but {len(batch_records)} records "
            f"(records[{start_idx}:{end_idx}]); records and dataset are misaligned."
        )

    confidence_stripped_reports = training.remove_confidence_part_from_generated_responses(
        [record["generated_report"] for record in batch_records]
    )
    confidence_stripped_reports_with_post_conf_req = [
        report + " " + prompter.build_post_generation_user_confidence_request()
        for report in confidence_stripped_reports
    ]
    generated_reports_input_ids = tokenizer(
        confidence_stripped_reports_with_post_conf_req
    ).input_ids

    concatenated_sequences = []
    for original_input_ids, generated_input_ids in zip(input_ids, generated_reports_input_ids):
        # [1:] drops the first token of the tokenized report text — presumably the BOS
        # token the tokenizer prepends; TODO confirm.
        report_ids = torch.tensor(generated_input_ids, dtype=torch.long, device=device)[1:]
        concatenated = torch.cat([original_input_ids, report_ids])
        # Keep everything from the first token with id 1 onwards. Presumably id 1 is BOS
        # and this strips the batch's left padding so the sequences can be re-padded
        # below — TODO confirm. Raises IndexError if no such token is present.
        sequence_start = (concatenated == 1).nonzero()[0]
        concatenated_sequences.append(concatenated[sequence_start:])

    final_input_ids, final_attention_masks = dataset.pad_batch_text_sequences(
        concatenated_sequences,
        padding_tokenizer=padding_tokenizer,
    )
    final_input_ids = final_input_ids.to(device)
    final_attention_masks = final_attention_masks.to(device)

    # The last position of every padded sequence is withheld from the prompt.
    output_ids = model.generate(
        input_ids=final_input_ids[:, :-1],
        images=images,
        attention_mask=final_attention_masks[:, :-1],
        do_sample=False,
        use_cache=True,
        max_new_tokens=10,
        stopping_criteria=[stopping_criteria],
        pad_token_id=tokenizer.pad_token_id,
    )

    # NOTE(review): decoding starts at the full padded length although the prompt passed
    # to generate() was one token shorter, so the first position after the prompt is
    # skipped. Presumably the model is expected to reproduce the withheld token there —
    # confirm this offset is intended.
    confidence_texts = [
        tokenizer.decode(
            output[final_input_ids[input_idx].shape[0] :], skip_special_tokens=True
        )
        for input_idx, output in enumerate(output_ids)
    ]
    confidence_values = response.parse_confidences(confidence_texts, granular_confidence=False)

    # Overwrites the "confidence" field of the in-memory records before persisting them.
    for idx, conf_val in enumerate(confidence_values):
        batch_records[idx]["confidence"] = conf_val

    dataset.append_records_to_json_file(batch_records, json_path)

100%|██████████| 92/92 [14:14<00:00,  9.29s/it]
