In [None]:
# Reload edited modules (e.g. the project code under ../src) before each cell executes.
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
import sys
import os

import torch
from llava.mm_utils import process_images
from tqdm import tqdm
from transformers import AutoTokenizer

# Add the src directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..", "src")))

from utils.train_utils import build_dataloader
from dataset.processor import Processor
from model.model import VisionLanguageModel
from utils.config import DatasetConfig, ExperimentConfig


import os
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf
from hydra.core.config_store import ConfigStore

# Register the `ifel` resolver (flag, val_true, val_false) -> val_true if flag else val_false,
# presumably referenced by the Hydra YAML configs in ../conf.
# `replace=True` keeps this cell idempotent: without it, re-running the cell raises
# because a resolver named "ifel" is already registered.
OmegaConf.register_new_resolver(
    "ifel",
    lambda flag, val_true, val_false: val_true if flag else val_false,
    replace=True,
)

## Load config

In [None]:
# Register the structured config classes with Hydra's ConfigStore so they are
# available to `compose` below.
config_store = ConfigStore.instance()
config_store.store(name="ExperimentConfig", node=ExperimentConfig)
config_store.store(name="DatasetConfig", group="dataset", node=DatasetConfig)

# Compose the "train" config from ../conf, adding the local test experiment and
# setting main_dir to ".." (the parent of the notebook directory).
compose_overrides = ["+experiment=train_local_test", "main_dir='..'"]
with initialize(version_base=None, config_path="../conf"):
    config = compose(config_name="train", overrides=compose_overrides)
    print(OmegaConf.to_yaml(config))

## Load processor, tokenizer, val_dataloader, batch

In [None]:
# Checkpoint file (inside ../../checkpoints-trained/) to evaluate.
# Previously evaluated: "last_model_silver-field-126.pt",
# "checkpoint_1_vital-sound-133_1741647312.pt", "last_model_legendary-cloud-125.pt",
# "checkpoint_3_balmy-snow-134_1741766686.pt"
MODEL_NAME = "checkpoint_3_rare-fire-135_1741767317.pt"

# Overrides applied on top of the composed config before building the processor.
config.num_coordinate_bins = 100
config.add_special_tokens = True  # alternative value: False

processor = Processor.from_config(config, add_special_tokens=config.add_special_tokens)
tokenizer = processor.tokenizer

In [None]:
# Small evaluation loader: 10 samples drawn from the *train* dataset config,
# built with is_train=False.
val_dataloader = build_dataloader(
    processor=processor,
    dataset_config=config.train_dataset,
    batch_size=2,  # instead of config.batch_size, to keep the batch small
    is_train=False,
    num_workers=config.num_workers,
    subset_size=10,
)
val_batch = next(iter(val_dataloader))

# Decode the supervised (non -100) label tokens of the first sample.
first_labels = val_batch["labels"][0]
labels = first_labels[first_labels != -100]
print(labels.shape)
print(tokenizer.decode(labels))

# Warn when no position anywhere in the batch carries a supervised label.
if (val_batch["labels"] == -100).all():
    print("All labels are -100")

val_batch["labels"], val_batch["bbox_str"]

In [None]:
from utils.data_utils import show_img_with_bbox
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T
from PIL import Image

# The loader's dataset may be a wrapper (presumably torch.utils.data.Subset,
# since subset_size is set) that keeps the real dataset in `.dataset`; unwrap it
# to reach the category mappings.
dataset = val_dataloader.dataset
dataset = getattr(dataset, "dataset", dataset)

class_id_to_name = dataset.index_to_cat_name

# Draw the ground-truth boxes of the validation batch.
axes = show_img_with_bbox(val_batch, dataset, figsize=(10, 10))


## Load pretrained model and check batch

In [None]:
from model.model import VisionLanguageModel

# Construct the VLM. num_new_tokens / initializers presumably size and initialise
# embeddings for the processor's special tokens (only when add_special_tokens is on).
model = VisionLanguageModel(
    config=config,
    image_token_index=processor.image_token_index,
    num_new_tokens=len(processor.special_tokens),
    do_init=config.add_special_tokens,
    initializers=processor.special_tokens_initializer,
)

In [None]:
# Load the trained weights onto the CPU.
# The checkpoint is a dict whose model weights live under "model_state_dict".
# Index it directly: `.get` would pass None to load_state_dict and hide the real
# problem (wrong file / layout) behind an unrelated error.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
device = torch.device("cpu")
checkpoint_path = os.path.join("../../checkpoints-trained", MODEL_NAME)
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict["model_state_dict"])

### Check batch
 - decode the validation prompt
 - check that special tokens are tokenized correctly
 - check how input bboxes are formatted as XML
 - print the decoded input_ids

In [None]:
# Decode the validation prompts with special tokens stripped.
processor.tokenizer.batch_decode(sequences=val_batch["input_ids"], skip_special_tokens=True)

In [None]:
# Check how the special tokens encode; with add_special_tokens enabled, "<object>"
# and "<x34/>" are expected to map to dedicated ids. The second element echoes
# the current setting.
str_special_tokens = "<object><x34/>"
(tokenizer.encode(str_special_tokens, add_special_tokens=False), config.add_special_tokens)

In [None]:
# Show how the processor formats class names + normalized boxes as XML.
classes_str = ["class1", "class2"]
normalized_bbox = [
    [0.08, 0.1, 0.9, 0.9],
    [0.11, 0.4, 0.6, 0.7],
]
processor.format_bbox_to_xml(classes_str, normalized_bbox)

In [None]:
# Print the first decoded input in full (special tokens kept), then the whole
# batch with special tokens stripped. (printoptions only affects numpy array
# reprs; the decoded string is printed in full regardless.)
first_input_ids = val_batch["input_ids"][0].numpy()
with np.printoptions(threshold=np.inf):
    print(tokenizer.decode(first_input_ids))

tokenizer.batch_decode(val_batch["input_ids"], skip_special_tokens=True)

In [None]:
# Candidate instruction prompts: JSON with float coords, JSON with coordinate
# tokens, and XML.
prompt1 = "Detect all objects in the image and output ONLY a valid JSON array of objects. Each object must have a 'class' (string name) and 'bbox' (normalized coordinates [x_min, y_min, x_max, y_max] between 0 and 1). Format: [{'class': 'person', 'bbox': [0.2, 0.3, 0.5, 0.8]}, {'class': 'car', 'bbox': [0.6, 0.7, 0.9, 0.95]}]. Include all visible objects, even if partially visible. Output nothing but the JSON array."
prompt2 = "Detect all objects in the image and output ONLY a valid JSON array of objects. Each object must have a 'class' (string name) and 'bbox' (list of 4 special coordinate tokens describing [x_min, y_min, x_max, y_max]). Format: [{'class': 'person', 'bbox': ['<coord_2>', '<coord_3>', '<coord_5>', '<coord_8>']}, {'class': 'car', 'bbox': ['<coord_6>', '<coord_7>', '<coord_9>', '<coord_9>']}]. Each <coord_X> token represents a quantized position. Include all visible objects, even if partially visible. Output nothing but the JSON array."

example_xml = "<annotation><object><class>car</class><bbox x0='0.14673' y0='0.36377' x1='0.18527' y1='0.44438'/></object><object><class>surfboard</class><bbox x0='0.0' y0='0.41329' x1='0.86317' y1='0.67906'/></object></annotation>"
prompt3 = f"Detect all objects in the image and output ONLY a valid XML of list of object. Each <object> must have a <class> (string name) and <bbox> (list of 4 special coordinate tokens <x0><y0><x1><y1>). Format: {example_xml}. Include all visible objects, even if partially visible. Output nothing but the XML."

# Token counts per prompt: prompt2 encodes shorter than prompt1 with special
# tokens even though its string is longer.
tuple(len(tokenizer.encode(candidate)) for candidate in (prompt1, prompt2, prompt3))

## Generate output with pretrained model

In [None]:
# Generate predictions for the validation batch with the trained model.
from utils.train_utils import JSONStoppingCriteria

# Sampling is stochastic; seed it so re-running this cell reproduces the output.
GENERATION_SEED = 0
torch.manual_seed(GENERATION_SEED)

model.eval()
# TODO: use the val set, so info about the bbox is not in input_ids; check that
# the image tokens are filled with image info.
with torch.no_grad():  # inference only: no autograd graph needed
    outputs = model.generate(
        input_ids=val_batch["input_ids"].to(device),
        attention_mask=val_batch["attention_mask"].to(device),
        image=val_batch["images"].to(device),
        stopping_criteria=[JSONStoppingCriteria(processor.tokenizer)],
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        top_k=50,
    )

# Decode the generations into text and per-image {"boxes", "labels", ...} dicts.
generated_text, predicted_boxes = processor.postprocess_xml_batch(outputs, dataset, device)
# Length of every generated sequence (works for any batch size, not only 2).
print([len(sequence) for sequence in outputs])

predicted_boxes, generated_text, val_batch["bbox_str"]

In [None]:
# Plot predicted boxes (red), target boxes (green) and class labels on each image.
id_to_cat_name = dataset.index_to_cat_name
print(predicted_boxes)


def _draw_labeled_boxes(ax, boxes, labels, id_to_name, color, size=None):
    """Draw boxes with class-name captions onto a matplotlib Axes.

    Args:
        ax: Axes to draw on.
        boxes: iterable of (x_min, y_min, x_max, y_max), parallel to `labels`.
        labels: iterable of class-id tensors (``.item()`` is called on each).
        id_to_name: mapping from class id to class name; ids missing from it
            are captioned "Unknown".
        color: edge and text colour.
        size: optional (width, height). When given, boxes are treated as
            normalized and scaled to pixels; otherwise they are drawn as-is.
    """
    for cat, bbox in zip(labels, boxes):
        x1, y1, x2, y2 = bbox
        if size is not None:
            width, height = size
            x1, y1, x2, y2 = x1 * width, y1 * height, x2 * width, y2 * height
        ax.add_patch(
            plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor=color, facecolor="none")
        )
        # Resolve the name per box, so an unknown id never reuses the previous box's name.
        class_id = cat.item()
        class_name = id_to_name[class_id] if class_id in id_to_name else "Unknown"
        ax.text(x1, y1 - 5, class_name, fontsize=12, color=color)


for i in range(len(val_batch["images"])):
    fig, ax = plt.subplots()

    # CHW tensor -> HWC array rescaled to [0, 1] for imshow; a constant image is
    # left at zero instead of being divided by zero.
    img = val_batch["images"][i].permute(1, 2, 0).numpy()
    img = img - img.min()
    if img.max() > 0:
        img = img / img.max()
    ax.imshow(img)

    # Predictions are drawn in the coordinates the processor returned.
    _draw_labeled_boxes(ax, predicted_boxes[i]["boxes"], predicted_boxes[i]["labels"], id_to_cat_name, "red")
    # Targets are normalized, so scale them by the image width/height.
    _draw_labeled_boxes(
        ax,
        val_batch["instance_bboxes"][i],
        val_batch["instance_classes_id"][i],
        id_to_cat_name,
        "green",
        size=(img.shape[1], img.shape[0]),
    )

    plt.show()

## Test Train Metrics

In [None]:
# Hand-written prediction fixture in the processor's output format (coordinates
# appear to be pixels), kept for manual metric experiments.
predicted_boxes2 = [
    {
        "boxes": torch.tensor([
            [120.1728, 106.3066, 304.2547, 354.1632],
            [202.0569, 0.0000, 384.0000, 378.8736],
        ]),
        "labels": torch.tensor([17, 70]),
        "scores": torch.tensor([1.0, 1.0]),
    },
    {
        "boxes": torch.tensor([
            [138.8928, 34.3304, 289.1942, 246.1824],
            [142.8058, 246.2093, 223.6800, 284.7475],
            [33.1200, 100.0166, 42.8160, 113.4874],
            [351.2563, 166.5485, 384.0000, 278.6074],
        ]),
        "labels": torch.tensor([1, 41, 2, 15]),
        "scores": torch.tensor([1.0, 1.0, 1.0, 1.0]),
    },
]


In [None]:
# Score this batch's generations with the training metrics.
from utils.train_metrics import TrainMetrics

device = "cpu"
metrics = TrainMetrics(device=device)
target_boxes = processor.postprocess_target_batch(val_batch, device)

metrics.update(
    predicted_boxes=predicted_boxes,
    target_boxes=target_boxes,
    target_texts=val_batch["bbox_str"],
    generated_text=generated_text,
)
metrics.compute()

In [None]:
# Sanity-check the metrics on a fixed pair. In sample 0 the two objects are
# listed in swapped order and the class-17 box's x_min differs (183.768 vs
# 225.768); sample 1 matches its target exactly.
def _box_dict(boxes, labels):
    """Return a fresh {'boxes', 'labels', 'scores'} dict with every score set to 1."""
    return {
        "boxes": torch.tensor(boxes),
        "labels": torch.tensor(labels),
        "scores": torch.ones(len(labels)),
    }


exact_boxes = [[305.1480, 223.1204, 320.6880, 235.3108], [73.9140, 37.1321, 219.9680, 262.9113]]
exact_labels = [15, 1]

pred_boxes_test = [
    _box_dict([[183.7680, 92.0112, 332.6478, 226.5556], [169.6680, 5.6480, 324.7800, 377.7072]], [17, 70]),
    _box_dict(exact_boxes, exact_labels),
]
target_boxes_test = [
    _box_dict([[169.6680, 5.6480, 324.7800, 377.7072], [225.7680, 92.0112, 332.6478, 226.5556]], [70, 17]),
    _box_dict(exact_boxes, exact_labels),
]

test_metrics = TrainMetrics(device=device)
test_metrics.update(
    predicted_boxes=pred_boxes_test,
    target_boxes=target_boxes_test,
    target_texts=val_batch["bbox_str"],
    generated_text=generated_text,
)
test_metrics.compute()

## Additional checks

In [None]:
from utils.train_utils import build_train_dataloader

# Build the full training loader and pull one batch for inspection.
train_dataloader = build_train_dataloader(config, processor)
train_batch_iter = iter(train_dataloader)
batch_train = next(train_batch_iter)

### Check inputs/labels of the train dataset

In [None]:
# Decode the train inputs with special tokens stripped.
processor.tokenizer.batch_decode(sequences=batch_train["input_ids"], skip_special_tokens=True)

In [None]:
# Shape of the train label tensor.
batch_train["labels"].size()

In [None]:
# Warn if nothing in the train batch is supervised, then display the labels.
train_labels = batch_train["labels"]
if (train_labels == -100).all():
    print("All labels are -100")
train_labels

In [None]:
# Token ids of the "<annotation>" root tag, without tokenizer-added special tokens.
tokenizer.encode(text="<annotation>", add_special_tokens=False)

In [None]:
# Print the supervised (non -100) label tokens of every train sample, then
# report where the "<annotation>" token sequence occurs among them.
token = tokenizer.encode("<annotation>", add_special_tokens=False)

for i, label in enumerate(batch_train["labels"]):
    supervised_ids = label[label != -100].tolist()
    for token_id in supervised_ids:
        print(token_id, tokenizer.decode(token_id))
    # Offsets (within the supervised ids) where `token` appears as a contiguous run.
    matches = [
        start
        for start in range(len(supervised_ids) - len(token) + 1)
        if token and supervised_ids[start:start + len(token)] == token
    ]
    print(f"sample {i}: '<annotation>' found at supervised offsets {matches}")

In [None]:
# Boolean mask of supervised label positions; decode all supervised ids at once
# (flattened across the batch).
masked_labels = batch_train["labels"] != -100
print(masked_labels.shape)
supervised_label_ids = batch_train["labels"][masked_labels]
processor.tokenizer.decode(supervised_label_ids)

In [None]:
# Teacher-forced forward pass over the train batch. The logits are only
# inspected, never back-propagated, so skip building the autograd graph.
with torch.no_grad():
    forward_output = model.forward(
        input_ids=batch_train["input_ids"].to(device),
        attention_mask=batch_train["attention_mask"].to(device),
        images=batch_train["images"].to(device),
    )

In [None]:
# Inspect what the model predicts for the supervised label tokens.
# Causal-LM alignment: the logit at position t scores the token at t + 1, so the
# last logit and the first label position are dropped before masking. Row k of
# `logits_masked` then scores the k-th supervised label token. Without the shift,
# every row is off by one and scores the token *after* that label.
label_mask = batch_train["labels"] != -100
shifted_label_mask = label_mask[:, 1:].to(forward_output.logits.device)
logits_masked = forward_output.logits[:, :-1][shifted_label_mask]
print(logits_masked.shape)

# Logit of the "<annotation>" token(s) for the first supervised position.
annotation_tag = tokenizer.encode("<annotation>", add_special_tokens=False)
print(annotation_tag)
first_logits = logits_masked[0]
print(
    "annotation tag logit:", first_logits[annotation_tag],
    "; max logit:", first_logits.max(),
    " ; max logit index:", first_logits.argmax(),
    " ; max decoded:", tokenizer.decode(first_logits.argmax()),
)
print(logits_masked[1][annotation_tag])
print(logits_masked[2][annotation_tag])

# Probability distribution over the vocabulary for the first supervised token.
plt.plot(first_logits.softmax(-1).detach().numpy())

# Greedy prediction for every supervised position, decoded as one string.
print(logits_masked.argmax(-1))
tokenizer.decode(logits_masked.argmax(-1))

In [None]:
# Input vs logits shapes; indexing the logits with the label mask assumes the sequence dims agree.
(batch_train["input_ids"].size(), forward_output.logits.size())

In [None]:
# Greedy (argmax) token at every position of the forward pass, decoded per sample.
predicted_ids = forward_output.logits.argmax(dim=-1)
processor.tokenizer.batch_decode(predicted_ids)

In [None]:
# check if attention mask is correctly set
lab = batch_train["labels"]
in_id = batch_train["input_ids"]
processor.tokenizer.decode(in_id[lab != -100])

### Check postprocessing

In [None]:
# Parse a hand-written coordinate-token model output into boxes.
# An attribute-style sample used earlier, kept for reference:
# "<annotation><object><class>dining table</class><bbox x_min='0.0015' y_min='0.0023944' x_max='1.0' y_max='0.97319'/></object><object><class>spoon</class><bbox x_min='0.70611' y_min='0.0033803' x_max='0.92198' y_max='0.85472'/></object><object><class>cake</class><bbox x_min='0.13873' y_min='0.17587' x_max='0.72175' y_max='1.0'/></object></annotation><|im_end|>"
tt = "<annotation><object><class>person</class><bbox><x27/><y67/><x29/><y71/></bbox></object><object><class>surfboard</class><bbox><x28/><y70/><x32/><y74/></bbox></object><object><class>kite</class><bbox><x75/><y09/><x90/><y36/></bbox></object><object><class>person</class><bbox><x97/><y39/><x98/><y40/></bbox></object></annotation>"
# Unwrap a Subset-style wrapper only when present. A hard-coded `.dataset.dataset`
# crashes when the loader is not built on a Subset.
base_dataset = getattr(val_dataloader.dataset, "dataset", val_dataloader.dataset)
processor._postprocess_xml(tt, base_dataset.cat_name_to_id, "cpu")

In [None]:
# Same sample through the batch API: encode `tt` to ids, then decode and parse it back.
tt_token = [tokenizer.encode(tt, add_special_tokens=False)]
# Pass the unwrapped dataset, as the generation cell does; a torch Subset wrapper
# does not forward attribute access to the dataset it wraps.
base_dataset = getattr(val_dataloader.dataset, "dataset", val_dataloader.dataset)
processor.postprocess_xml_batch(tt_token, base_dataset, "cpu")

### Try generation with a new prompt

In [None]:
# Try generation with a hand-written XML prompt instead of the training prompt.
from utils.train_utils import JSONStoppingCriteria

# Example output. Both objects use the attribute names the instructions ask for
# (x_min, y_min, x_max, y_max); previously the first object used x0/y0/x1/y1.
example_xml = "<annotation><object><class>car</class><bbox x_min='0.14673' y_min='0.36377' x_max='0.18527' y_max='0.44438'/></object><object><class>surfboard</class><bbox x_min='0.0' y_min='0.41329' x_max='0.86317' y_max='0.67906'/></object></annotation>"
prompt = f"Detect all objects in the image and output valid XML of root <annotation> and child <object>. Each <object> must have a <class> (string name) and <bbox (4 attributes with normalized coordinates x_min, y_min, x_max, y_max). Example: {example_xml}. Include all visible objects, even if partially visible. Output nothing but the XML."

chat_prompt, _ = processor.prepare_text_input(config.num_image_tokens, [], [], [], prompt=prompt, train=False)
print(chat_prompt)

# One copy of the prompt per image, so the text batch always matches the image batch.
num_images = len(val_batch["images"])
tokenized = tokenizer(
    [chat_prompt] * num_images,
    padding=True,
    truncation=True,
    max_length=config.max_tokens,
    pad_to_multiple_of=config.pad_to_multiple_of,
    return_tensors="pt",
)

torch.manual_seed(0)  # reproducible sampling
with torch.no_grad():  # inference only
    outputs = model.generate(
        input_ids=tokenized.input_ids.to(device),
        attention_mask=tokenized.attention_mask.to(device),
        image=val_batch["images"].to(device),
        stopping_criteria=[JSONStoppingCriteria(processor.tokenizer)],
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        top_k=50,
    )

print("Output:", processor.tokenizer.batch_decode(outputs, skip_special_tokens=False))

In [None]:
# Generations with special tokens stripped.
tokenizer.batch_decode(sequences=outputs, skip_special_tokens=True)

### Token length plots

In [None]:
# Collect the tokenized length of every training sample. With batch size 1 (and
# max_tokens / pad_to_multiple_of cleared), each batch's sequence length is
# presumably that single sample's length.
model_name = "lmms-lab/llava-onevision-qwen2-0.5b-si"  # NOTE(review): unused below
config.batch_size = 1
config.max_tokens = None
config.pad_to_multiple_of = None

processor = processor.from_config(config, add_special_tokens=None)
# NOTE(review): rebinding `model` here discards the checkpoint-loaded model from
# above, and `model` is not used in this cell; re-running earlier cells afterwards
# uses this new model.
model = VisionLanguageModel(
    config=config,
    image_token_index=processor.image_token_index,
    num_new_tokens=0,
    do_init=False,
    initializers=None,
)
dataloader = build_train_dataloader(config, processor)

token_sizes = []
for batch in tqdm(dataloader, desc="Processing batches"):
    token_sizes.append(batch["input_ids"].shape[1])


In [None]:
# Summary statistics of the per-sample token lengths.
max_size, min_size = max(token_sizes), min(token_sizes)
avg_size = sum(token_sizes) / len(token_sizes)

summary_lines = [
    "Token size statistics:",
    f"Max: {max_size}",
    f"Min: {min_size}",
    f"Average: {avg_size:.2f}",
    f"Number of samples: {len(token_sizes)}",
]
print("\n".join(summary_lines))

In [None]:
# Number and fraction of samples longer than 3200 tokens.
count = sum(1 for size in token_sizes if size > 3200)
count, count / len(token_sizes)

In [None]:
# Horizontal histogram (linear counts) of token lengths, split at 3200 tokens.
token_limit = 3200
within_limit = [size for size in token_sizes if size <= token_limit]
over_limit = [size for size in token_sizes if size > token_limit]
# Widen the default (900, 6200) range when needed, so no sample is silently
# dropped from the plot.
hist_range = (min(900, min(token_sizes)), max(6200, max(token_sizes)))

fig, ax = plt.subplots(figsize=(12, 8))
# Stack both groups on shared bins. A bin straddling the threshold then shows its
# <=3200 and >3200 parts separately, instead of the whole bar being coloured by
# its lower edge.
ax.hist(
    [within_limit, over_limit],
    range=hist_range,
    bins=30,
    stacked=True,
    orientation="horizontal",
    log=False,
    color=["tab:blue", "tab:red"],
    edgecolor="black",
    label=[f"Token size <= {token_limit}", f"Token size > {token_limit}"],
)
ax.set_ylabel("Token size", fontsize=14)
ax.set_xlabel("Frequency", fontsize=14)
ax.set_title("Token size distribution", fontsize=16)
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.legend(fontsize=12)

fig.tight_layout()
plt.show()

### End