In [1]:
from RewardingVisualDoubt import infrastructure

infrastructure.make_ipython_reactive_to_changing_codebase()


from RewardingVisualDoubt import (
    dataset,
    green,
    evaluation,
    training,
    vllm,
    prompter,
    inference,
    response,
    reward,
    shared,
)

import accelerate
import dataclasses
import torch
import functools
import pathlib as path
import math
import os
import itertools
import time

from tqdm import tqdm
import json

import typing as t
from torch.utils.data import Subset

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

In [2]:
# Alternative input: records for the difficulty-balanced training set scored with the
# original RaDialog model (swap in when collecting hidden states for probe training).
# records = dataset.read_records_from_json_file(
#     "/home/guests/deniz_gueler/repos/StudyingVisualDoubt/data/inference/report_generation/difficulty_balanced_training_set/original_radialog_model/original_radialog_model_green_scores.json"
# )

# Records read here are later indexed for their "generated_report" and "green_score" fields.
generated_reports_json = "/home/guests/deniz_gueler/repos/StudyingVisualDoubt/data/inference/report_generation/testset_2040_randomly_sampled/base_sft_model/generated_reports_with_confidence_and_green_score.json"
records = dataset.read_records_from_json_file(generated_reports_json)

In [3]:
BATCH_SIZE = 24
# Number of leading datapoints of the chosen split that are fed through the model.
NUM_SELECTED_DATAPOINTS = 988

device, device_str = shared.get_device_and_device_str()
model = vllm.shortcut_load_the_original_radialog_model()
tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    model_base=vllm.LLAVA_BASE_MODEL_NAME
)
# Built once so that the collate function does not reload a tokenizer for every batch.
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    for_use_in_padding=True
)
selected_prompter_fn = prompter.build_report_generation_instruction_from_findings

# Alternative splits (enable exactly one `dataset_ = ...` assignment):
# dataset_ = dataset.get_report_generation_prompted_mimic_cxr_llava_model_input_dataset(
#     split=dataset.DatasetSplit.TRAIN, tokenizer=tokenizer, prompter=selected_prompter_fn
# )
# dataset_ = dataset.get_report_generation_prompted_mimic_cxr_llava_model_input_dataset(
#     split=dataset.DatasetSplit.VALIDATION, tokenizer=tokenizer, prompter=selected_prompter_fn
# )
dataset_ = dataset.get_report_generation_prompted_mimic_cxr_llava_model_input_dataset(
    split=dataset.DatasetSplit.TEST, tokenizer=tokenizer, prompter=selected_prompter_fn
)

# Alternative datapoint selections read from JSON index lists:
# selected_datapoints_json = "/home/guests/deniz_gueler/repos/RewardingVisualDoubt/workflows/report_generation/selected_datapoints/1476_training_datapoints_balanced_difficulty.json"
# selected_datapoints_json = "/home/guests/deniz_gueler/repos/RewardingVisualDoubt/workflows/report_generation/selected_datapoints/988_test_datapoints_balanced_difficulty_sampled_from_validation_set_idx505-2120.json"
# with open(selected_datapoints_json, "r") as f:
#     datapoint_indexes: list[int] = json.load(f)

datapoint_indexes = list(range(NUM_SELECTED_DATAPOINTS))
# Subset is cast back to the project dataset type purely for static type checkers.
dataset_ = t.cast(
    dataset.ReportGenerationPromptedMimicCxrLlavaModelInputDataset,
    Subset(dataset_, datapoint_indexes),
)
# NOTE(review): this only sets an attribute on the model config; whether the padding
# helpers used below honour it is not visible from this notebook.
model.config.padding_side = "left"

# Project helper that builds an equivalent dataloader:
# dataloader = dataset.get_mimic_cxr_llava_model_input_dataloader(
#     dataset_,
#     batch_size=BATCH_SIZE,
#     padding_tokenizer=vllm.load_pretrained_llava_tokenizer_with_image_support(
#         for_use_in_padding=True
#     ),
#     num_workers=8,
# )

from torch.utils.data import DataLoader

# shuffle=False keeps the batches in dataset order; the hidden-state collection cell
# relies on that order to pair batches with `records` by position.
dataloader = DataLoader(
    dataset=dataset_,
    batch_size=BATCH_SIZE,
    collate_fn=lambda samples: dataset.prompted_mimic_cxr_llava_model_input_collate_fn(
        samples, padding_tokenizer
    ),
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True,
)

Adding LoRA adapters to the model for SFT training or inference from Radialog Lora Weights path: /home/guests/deniz_gueler/repos/RewardingVisualDoubt/data/RaDialog_adapter_model.bin
Loading LLaVA model with the base LLM and with RaDialog finetuned vision modules...


Fetching 69 files:   0%|          | 0/69 [00:00<?, ?it/s]

Model will be loaded at precision: 4bit
Loading LLaVA from base liuhaotian/llava-v1.5-7b




Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading additional LLaVA weights...
Merging model with vision tower weights...
Using downloaded and verified file: /tmp/biovil_t_image_model_proj_size_128.pt
Adding LoRA adapters to the model...
Loading mimic_cxr_df from cache


# Hidden state collection

In [4]:
SELECTED_HIDDEN_LAYER_IDX = 16
# Every concatenated sequence is cut so that it starts at the first occurrence of this id.
# Presumably the tokenizer's BOS id — TODO confirm against the tokenizer config.
SEQUENCE_START_TOKEN_ID = 1

# Loop-invariant: load the padding tokenizer once instead of once per batch.
padding_tokenizer = vllm.load_pretrained_llava_tokenizer_with_image_support(
    for_use_in_padding=True
)

# One {"hidden": np.ndarray, "target": float} entry per datapoint.
final_stats = []
for i, batch in enumerate(tqdm(dataloader)):
    batch = t.cast(dataset.MimicCxrLlavaModelInputBatchDict, batch)
    input_ids, images, _attention_mask, _batch_metadata_list = (
        dataset.typical_unpacking_for_report_generation(device, batch)
    )

    # Records are paired with dataloader batches purely by position, so only as many
    # records as there are prompts in this batch are taken (the last batch may be short).
    start_idx = i * BATCH_SIZE
    batch_records = records[start_idx : start_idx + len(input_ids)]
    assert len(batch_records) == len(input_ids), (
        f"batch {i}: {len(input_ids)} prompts but only {len(batch_records)} records; "
        "`records` does not cover the dataloader"
    )

    confidence_stripped_reports = training.remove_confidence_part_from_generated_responses(
        [record["generated_report"] for record in batch_records]
    )
    generated_reports_input_ids = tokenizer(confidence_stripped_reports).input_ids

    # Build prompt + generated report for every sample.
    concatenated_sequences = []
    for original_input_ids, generated_input_ids in zip(input_ids, generated_reports_input_ids):
        # The first token id of the tokenized report is dropped before appending it
        # to the prompt (presumably the BOS the tokenizer prepends — TODO confirm).
        report_ids = torch.tensor(generated_input_ids[1:], dtype=torch.long, device=device)
        concatenated = torch.cat([original_input_ids, report_ids])
        # Discard everything in front of the first start token.
        sequence_start = (concatenated == SEQUENCE_START_TOKEN_ID).nonzero()[0]
        concatenated_sequences.append(concatenated[sequence_start:])

    final_input_ids, final_attention_masks = dataset.pad_batch_text_sequences(
        concatenated_sequences,
        padding_tokenizer=padding_tokenizer,
    )
    with torch.no_grad():
        outputs = model(
            input_ids=final_input_ids,
            images=images,
            attention_mask=final_attention_masks,
            return_dict=True,
            output_hidden_states=True,
        )

    hidden_states = outputs["hidden_states"][SELECTED_HIDDEN_LAYER_IDX]
    # The state at the last sequence position is stored for every sample.
    # NOTE(review): that is the final report token only if the batch is left-padded —
    # confirm the padding side used by `pad_batch_text_sequences` and the model.
    for state, batch_record in zip(hidden_states, batch_records):
        final_stats.append(
            {
                "hidden": state[-1, :].view(-1).cpu().float().detach().numpy(),
                "target": float(batch_record["green_score"]),
            }
        )

100%|██████████| 42/42 [10:17<00:00, 14.70s/it]


In [5]:
# The layer index in the file name is taken from SELECTED_HIDDEN_LAYER_IDX so that the
# file cannot be mislabeled when a different layer is collected.
# NOTE(review): the "_test_ood" suffix describes the split/records combination selected
# in the cells above and has to be adjusted by hand when those are changed.
torch.save(
    final_stats,
    f"hidden_states_and_green_scores_layer_{SELECTED_HIDDEN_LAYER_IDX}_test_ood.pt",
)

# Training the probe

In [6]:
from pathlib import Path

import torch
from torch import nn
from datasets import Dataset
import pandas as pd
from torch.utils.data import DataLoader


class MLP2(nn.Module):
    """
    Four-layer perceptron probe: hidden_size -> 256 -> 128 -> 64 -> 1 with ReLU in between.

    The network returns raw logits; `score` applies the sigmoid.
    """

    def __init__(self, hidden_size):
        super().__init__()
        widths = (hidden_size, 256, 128, 64)
        stack = []
        # Linear layers are created in input-to-output order, each followed by a ReLU.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # Final projection to a single logit (no activation).
        stack.append(nn.Linear(widths[-1], 1))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits of the probe for `x`."""
        logits = self.layers(x)
        return logits

    def score(self, o, hidden_layer_idx):
        """
        Return sigmoid scores computed from a model output `o`.

        Takes `o.hidden_states[hidden_layer_idx]`, selects the last position along
        dim 1, moves it to the dtype and device of the first linear layer and
        returns one score per row.
        """
        first_weight = self.layers[0].weight
        last_position = o.hidden_states[hidden_layer_idx][:, -1, :]
        last_position = last_position.to(first_weight.dtype).to(first_weight.device)
        logits = self.layers(last_position).squeeze(1)
        return torch.sigmoid(logits)

In [28]:
HIDDEN_PATH = "hidden_states_and_green_scores_layer_16_training.pt"

# The following cell loads the same file again before building the DataLoader, so this
# cell is only needed to inspect `data` on its own.
# The file is a pickled list of dicts holding numpy arrays, which the tensor-only
# loader cannot read, hence the explicit weights_only=False.
# Only load files produced by this project: unpickling executes arbitrary code.
data = torch.load(HIDDEN_PATH, map_location="cpu", weights_only=False)

In [None]:
HIDDEN_PATH = "hidden_states_and_green_scores_layer_16_training.pt"
PROBE_SEED = 42

# Probe initialisation and DataLoader shuffling both draw from torch's global RNG;
# seeding it makes the training run repeatable.
torch.manual_seed(PROBE_SEED)

# The file is a pickled list of {"hidden": np.ndarray, "target": float} dicts, which the
# tensor-only loader cannot read, hence the explicit weights_only=False.
# Only load files produced by this project: unpickling executes arbitrary code.
data = torch.load(HIDDEN_PATH, map_location="cpu", weights_only=False)

# Deliberately NOT named `dataset`: that name is the RewardingVisualDoubt.dataset module
# imported at the top of the notebook and is still needed by other cells.
probe_train_dataset = Dataset.from_pandas(pd.DataFrame(data=data))
trainloader = DataLoader(
    probe_train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the MLP; the input width is the length of one stored hidden-state vector.
mlp = MLP2(hidden_size=len(probe_train_dataset[0]["hidden"]))

mlp = mlp.to(device)

# Define the loss function and optimizer.
# BCEWithLogitsLoss applies the sigmoid itself, so the MLP outputs raw logits.
loss_function = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

In [71]:
out_name = "trained_probes/trained_probe_layer16_balanced_difficulty.pt"
NUM_EPOCHS = 5

# torch.save does not create missing directories.
Path(out_name).parent.mkdir(parents=True, exist_ok=True)

# Run the training loop
for epoch in range(NUM_EPOCHS):
    print(f"Starting epoch {epoch}")

    # Per-epoch accumulators: summed batch loss plus all logits/targets for the metric.
    current_loss = 0.0
    collected_outputs = []
    collected_targets = []

    for batch in tqdm(
        trainloader,
        desc=f"Taking training steps... Epoch {epoch+1}/{NUM_EPOCHS}",
    ):
        # batch["hidden"] is stacked along dim=1 to obtain one row per sample
        # (presumably the default collate yields one tensor per feature — TODO confirm).
        inputs = torch.stack(batch["hidden"], dim=1).type(torch.FloatTensor)

        # Prepare targets as a float column vector matching the (batch, 1) logits.
        targets = batch["target"]
        if not isinstance(targets, torch.Tensor):
            targets = torch.tensor(targets)
        targets = targets.type(torch.FloatTensor).reshape((-1, 1))
        collected_targets.append(targets.detach().cpu())

        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard optimisation step: reset grads, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = mlp(inputs)
        collected_outputs.append(outputs.detach().cpu())
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()

        current_loss += loss.item()

    concat_collected_outputs = torch.cat(collected_outputs).view(-1)
    concat_collected_targets = torch.cat(collected_targets).view(-1)

    print("Loss after epoch %5d: %.3f" % (epoch, current_loss / len(trainloader)))
    # Mean absolute difference between predicted probability and target over the epoch.
    # This is an error measure (lower is better), not an accuracy.
    mean_absolute_error = (
        torch.abs(torch.sigmoid(concat_collected_outputs) - concat_collected_targets)
        .mean()
        .item()
    )
    print("mean absolute error: ", mean_absolute_error)

    # Save the whole module (not just the state dict) after every epoch; the evaluation
    # cell loads one of these files with torch.load.
    out_epoch_name = f"{out_name}_epoch{epoch}"
    torch.save(mlp, out_epoch_name)

# Process is complete.
print("Training process has finished.")

Starting epoch 0


Taking training steps... Epoch 1/5: 23it [00:27,  1.19s/it]

Loss after epoch     0: 0.000
accuracy:  0.21988248564506488
Starting epoch 1



Taking training steps... Epoch 2/5: 23it [00:27,  1.21s/it]

Loss after epoch     1: 0.000
accuracy:  0.17122304113836237
Starting epoch 2



Taking training steps... Epoch 3/5: 23it [00:26,  1.16s/it]

Loss after epoch     2: 0.000
accuracy:  0.15479655865111638
Starting epoch 3



Taking training steps... Epoch 4/5: 23it [00:27,  1.18s/it]

Loss after epoch     3: 0.000
accuracy:  0.14891273206700392
Starting epoch 4



Taking training steps... Epoch 5/5: 23it [00:26,  1.17s/it]

Loss after epoch     4: 0.000
accuracy:  0.14530304872273095
Training process has finished.





# Use the trained probe

In [None]:
# Hidden states and GREEN-score targets, as written by the hidden-state collection cell.
HIDDEN_PATH_TEST = "hidden_states_and_green_scores_layer_16_test_ood.pt"

# JSON file to which the records annotated with the probe's confidence are appended.
RESULTS_FILE_NAME = "trained_probe_results_ood.json"
from pathlib import Path

import torch
from torch import nn
from datasets import Dataset
import pandas as pd
from torch.utils.data import DataLoader


# The probe checkpoint loaded below was written with torch.save(mlp, ...), i.e. as a
# pickled module, so a class named MLP2 has to be defined before torch.load is called.
# This definition mirrors the one used for training.
class MLP2(nn.Module):
    """
    Probe network: hidden_size -> 256 -> 128 -> 64 -> 1, ReLU between the linear layers.

    `forward` yields logits, `score` yields sigmoid probabilities.
    """

    def __init__(self, hidden_size):
        super().__init__()
        dims = [hidden_size, 256, 128, 64, 1]
        modules = []
        for position, (n_in, n_out) in enumerate(zip(dims, dims[1:])):
            # A ReLU separates consecutive linear layers; none follows the last one.
            if position > 0:
                modules.append(nn.ReLU())
            modules.append(nn.Linear(n_in, n_out))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run `x` through the layer stack and return the logits."""
        return self.layers(x)

    def score(self, o, hidden_layer_idx):
        """
        Score a model output `o` using the hidden states of layer `hidden_layer_idx`.

        The last position along dim 1 is selected, cast and moved to match the first
        linear layer, passed through the network and squashed with a sigmoid.
        """
        reference = self.layers[0].weight
        features = o.hidden_states[hidden_layer_idx][:, -1, :]
        features = features.to(reference.dtype)
        features = features.to(reference.device)
        return torch.sigmoid(self.layers(features).squeeze(1))


device = "cuda" if torch.cuda.is_available() else "cpu"

# The checkpoint is a pickled MLP2 module, so it needs weights_only=False and the class
# definition above. map_location lets a probe trained on GPU be loaded on a CPU machine.
# Only load checkpoints produced by this project: unpickling executes arbitrary code.
mlp = torch.load(
    "/home/guests/deniz_gueler/repos/RewardingVisualDoubt/workflows/report_generation/evaluations/trained_probes/trained_probe_layer16_balanced_difficulty.pt_epoch3",
    map_location=device,
    weights_only=False,
)
mlp = mlp.to(device)
mlp.eval()

# Pickled list of {"hidden": np.ndarray, "target": float} dicts.
data_test = torch.load(HIDDEN_PATH_TEST, map_location="cpu", weights_only=False)
dataset_test = Dataset.from_pandas(pd.DataFrame(data=data_test))
# shuffle=False: the outputs are matched to `records` by position further down.
testloader = DataLoader(dataset_test, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

collected_outputs = []
collected_targets = []

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for data_ in tqdm(
        testloader,
        desc=f"Testing... (total steps: {len(testloader)})",
    ):
        # data_["hidden"] is stacked along dim=1 to obtain one row per sample
        # (presumably the default collate yields one tensor per feature — TODO confirm).
        inputs = torch.stack(data_["hidden"], dim=1).type(torch.FloatTensor)

        targets = data_["target"]
        if not isinstance(targets, torch.Tensor):
            targets = torch.tensor(targets)
        collected_targets.append(targets.type(torch.FloatTensor).reshape(-1))

        outputs = mlp(inputs.to(device))
        collected_outputs.append(outputs.cpu().reshape(-1))

concat_collected_outputs = torch.cat(collected_outputs)
concat_collected_targets = torch.cat(collected_targets)

# Probe probability for every datapoint, in dataloader order.
predicted_probabilities = torch.sigmoid(concat_collected_outputs).tolist()

# `records` is paired with the probe outputs by position; zip stops at the shorter of
# the two, and the assertion guards against a misaligned records file.
for record, probability, current_target in zip(
    records, predicted_probabilities, concat_collected_targets.tolist()
):
    assert round(current_target, 2) == round(record["green_score"], 2), (
        "records and stored hidden states are not aligned: "
        f"target {current_target} vs record green_score {record['green_score']}"
    )
    # The record is updated in place: probability mapped to an integer from 0 to 10.
    record["confidence"] = round(probability * 10)
    dataset.append_records_to_json_file([record], RESULTS_FILE_NAME)

Testing... (total steps: 16): 16it [00:18,  1.18s/it]
