In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
import sys
import os

import torch
from llava.mm_utils import process_images
from tqdm import tqdm
from transformers import AutoTokenizer

# Make the project's `src` directory importable. The location is resolved
# relative to the current working directory (`<cwd>/../src`), so the notebook
# must be started from a directory that sits next to `src`.
SRC_DIR = os.path.abspath(os.path.join(os.getcwd(), "..", "src"))
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

from utils.train_utils import build_dataloader, build_train_dataloader, build_val_dataloader
from dataset.processor_fasterrcnn import FastRCNNProcessor
from model.model import VisionLanguageModel
from model.fastrcnn_adapter import FastRCNNAdapter
from utils.config import DatasetConfig, ExperimentConfig
from utils.train_metrics import TrainMetrics


In [None]:
#hydra imports
import os
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf
from hydra.core.config_store import ConfigStore

# Custom `ifel` resolver for config interpolations: evaluates to `val_true`
# when `flag` is truthy and to `val_false` otherwise.
# `replace=True` keeps this cell re-runnable: without it, registering a
# resolver name that is already registered raises a ValueError.
OmegaConf.register_new_resolver(
    "ifel",
    lambda flag, val_true, val_false: val_true if flag else val_false,
    replace=True,
)

## Load config

In [None]:
# Register the structured-config schemas with Hydra's ConfigStore.
cs = ConfigStore.instance()
cs.store(node=ExperimentConfig, name="ExperimentConfig")
cs.store(node=DatasetConfig, name="DatasetConfig", group="dataset")
# OmegaConf.register_new_resolver("models_dir", lambda: MODELS_DIR)

# Overrides applied on top of the `train` config when composing it.
hydra_overrides = ["+experiment=train_local_test", "main_dir='..'"]

# Compose the config inside a Hydra context rooted at ../conf and print it
# as YAML for inspection.
with initialize(version_base=None, config_path="../conf"):
    config = compose(config_name="train", overrides=hydra_overrides)
    print(OmegaConf.to_yaml(config))

## Load processor, tokenizer, val_dataloader, batch

In [None]:
# Checkpoint file name. Alternative checkpoint names kept for reference:
#   "last_model_silver-field-126.pt"
#   "checkpoint_1_vital-sound-133_1741647312.pt"
#   "last_model_legendary-cloud-125.pt"
#   "checkpoint_3_balmy-snow-134_1741766686.pt"
# NOTE(review): MODEL_NAME is not read by any cell visible in this notebook —
# confirm whether it is still needed.
MODEL_NAME = "checkpoint_3_rare-fire-135_1741767317.pt"

# Override the composed config in place for the Faster R-CNN setup used below.
config.num_coordinate_bins = 100
config.model_name = "fasterrcnn-resnet50-fpn"

# Build the processor from the (modified) config and keep a handle on the
# tokenizer it exposes.
processor = FastRCNNProcessor.from_config(config)
tokenizer = processor.tokenizer

In [None]:
# Validation dataloader over a fixed (non-random) 10-sample subset.
val_dataloader = build_val_dataloader(
    config=config,
    processor=processor,
    subset_size=10,
    use_random_subset=False,
)

# Materialise every batch and keep the second one (index 1).
val_batch = list(val_dataloader)[1]

# Show the selected batch as the cell output.
val_batch

In [None]:
# Training dataloader over a 10-sample subset.
train_dataloader = build_train_dataloader(
    config=config,
    processor=processor,
    subset_size=10,
)

# Pull only the first batch from a fresh iterator.
train_batch = next(iter(train_dataloader))

In [None]:
# Instantiate the Faster R-CNN adapter from the composed config; `model` is
# used by the generate/forward test cells further down.
model = FastRCNNAdapter(config)

In [None]:
# test model generate
batch = val_batch
images = batch["images"]

print(images.shape)

# Inference only: disable autograd so no computation graph is recorded for
# this call (no backward pass is run in this cell).
with torch.no_grad():
    output = model.generate(image=images)
print(output)

# Targets for the same batch, post-processed by the processor.
target_boxes = processor.postprocess_target_batch(batch=batch, device=config.device)
print(target_boxes)

# Feed predictions and targets to the metric; the last two positional
# arguments of `update` are passed as None here.
metric = TrainMetrics(config.device, download_nltk=False)
metric.update(output, target_boxes, None, None)
print(metric.compute())

In [None]:
from llava.model.language_model.llava_qwen import LlavaQwenForCausalLM

# Pretrained checkpoint identifier handed to `from_pretrained`.
LLAVA_CHECKPOINT = "lmms-lab/llava-onevision-qwen2-0.5b-si"

# Load the pretrained model and keep only its vision tower.
image_encoder = LlavaQwenForCausalLM.from_pretrained(LLAVA_CHECKPOINT).get_vision_tower()

# Run the encoder on the images taken from the validation batch; the result
# is displayed as the cell output.
image_encoder(images)

In [None]:
# test model forward
# NOTE(review): `input_ids` is not assigned in any cell visible in this
# notebook, so on a fresh kernel this call raises NameError unless the name is
# defined elsewhere — confirm where it should come from.
# `images` and `target_boxes` are the variables set in the generate-test cell.
output = model(input_ids=input_ids, images=images, labels=target_boxes)
print(output)