In [14]:
%reload_ext autoreload
%autoreload 2

# Use aligner_v7 kernel

import sys
# Extra directories to place on the import search path for this kernel.
add_paths = [
    "/fsx_0/user/tranx/moe",  # ALIGNER_PARENT_DIR
    "/fsx_0/user/tranx/moe/llm_mm_aligner/replicated",  # ALIGNER_PARENT_DIR/llm_mm_aligner/replicated
    # Earlier per-user conda package directories, kept for reference only:
    # "/data/home/tranx/conda/envs/aligner_20240822_v2/python-packages",  # CONDA_PREFIX/python-packages
    # "/data/home/kapilk/.conda/envs/aligner_20240822_v2/python-packages"
    "/fsx_0/shared/conda/aligner_20241030/python-packages",
]

# Append each directory at most once, so re-running this cell does not keep
# growing sys.path. Appending (rather than prepending) leaves the existing
# search order untouched.
for extra_path in add_paths:
    if extra_path in sys.path:
        continue
    sys.path.append(extra_path)
        
import json
from pprint import pprint
import pickle 
from itertools import chain

import torch
import webdataset as wds
from torch.utils.data import DataLoader, IterableDataset

from transformers import HfArgumentParser
from llm_mm_aligner.lib.configs import (
    DataTrainingArguments,
    ModelArguments,
    TrainingArguments,
)
from llm_mm_aligner.lib.datasets.web_dataset import get_wb_dataset
from llm_mm_aligner.lib.data_collators import get_collator

In [7]:
def print_green(text):
    """Print ``text`` to stdout wrapped in ANSI escape codes for bright green.

    The reset sequence is emitted right after the text so that later output
    returns to the terminal's default colour.
    """
    # "\033[92m" switches to bright green; "\033[0m" resets all attributes.
    print("{}{}{}".format("\033[92m", text, "\033[0m"))
    
    
def get_args_list(args):
    """
    Copied from https://fburl.com/code/3pq3dn99
    Flatten a mapping of argument names to values into a command-line style
    list of strings.

    Each key ``k`` becomes the token ``--k``. It is followed by ``str(v)``
    unless the value is ``None``, in which case only the bare flag is emitted.
    Tokens appear in the mapping's iteration order.
    """
    cli_tokens = []
    for name, value in args.items():
        cli_tokens.append(f"--{name}")
        # A None value means "flag without an argument".
        if value is not None:
            cli_tokens.append(str(value))
    return cli_tokens


def get_local_test_artifacts(
    params_file="/fsx_0/user/tranx/experiments/aligner/tests/pretrain_llama3_8B.json",
    tokenizer_file="/fsx_0/shared/qa/models/HFMetaFormerTokenizer/HFMetaFormerTokenizer.pkl",
    preprocessor_file="/fsx_0/shared/qa/models/LlavaNextImageProcessor/LlavaNextImageProcessor.pkl",
):
    """
    Load the test artifacts (parsed trainer arguments, tokenizer and image
    preprocessor) from local files; the defaults point at the AWS fsx mount.

    Args:
        params_file: Path to a JSON file whose "trainer_args" entry maps
            argument names to values (``None``/null meaning a bare flag).
        tokenizer_file: Path to a pickled tokenizer object.
        preprocessor_file: Path to a pickled image preprocessor object.

    Returns:
        Tuple ``(model_args, data_args, training_args, tokenizer,
        preprocessor)``. The first three are the dataclass instances produced
        by ``HfArgumentParser``; the last two are the unpickled objects.

    Raises:
        ValueError: If ``params_file`` has no (or a null) "trainer_args" entry.
        OSError: If any of the three files cannot be opened.
    """
    with open(params_file, "r") as f:
        params = json.load(f)

    # Validate before the (slower) unpickling so that a bad params file fails
    # fast with a readable message instead of an AttributeError on None.
    trainer_args = params.get("trainer_args")
    if trainer_args is None:
        raise ValueError(
            f'"trainer_args" is missing or null in params file: {params_file}'
        )

    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only point these paths at trusted, access-controlled files.
    with open(tokenizer_file, "rb") as f:
        tokenizer = pickle.load(f)

    with open(preprocessor_file, "rb") as f:
        preprocessor = pickle.load(f)

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )

    # The parser consumes argv-style tokens, so flatten the dict first.
    model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        args=get_args_list(trainer_args)
    )

    print_green(f"data_path: {data_args.wd_data_path}")

    return model_args, data_args, training_args, tokenizer, preprocessor

In [29]:
# Parsed argument dataclasses plus the unpickled tokenizer / image preprocessor.
(
    model_args,
    data_args,
    training_args,
    tokenizer,
    preprocessor,
) = get_local_test_artifacts()

# Dataset built from the parsed arguments and the image preprocessor.
train_dataset = get_wb_dataset(
    preprocessor=preprocessor,
    data_args=data_args,
    model_args=model_args,
    training_args=training_args,
)

# Collator that turns individual samples into a batch.
data_collator = get_collator(data_args, model_args, tokenizer)

# Small batch and in-process loading (no worker subprocesses) for inspection.
loader_kwargs = dict(
    batch_size=2,  # training_args.per_device_train_batch_size
    collate_fn=data_collator,
    num_workers=0,
    pin_memory=training_args.dataloader_pin_memory,  # False
)
train_dataloader = DataLoader(train_dataset, **loader_kwargs)

[92mdata_path: /fsx_0/shared/qa/datasets/sg_mmllm_stage1_m2c2v3_sstk_10x_arxiv_pdf_mix_v6/shards[0m


In [24]:
# Report the length the DataLoader advertises via len().
num_batches = len(train_dataloader)
print(f"number of batches = len(train_dataloader) = {num_batches}")

number of batches = len(train_dataloader) = 10000


In [33]:
# Display the unpickled image preprocessor (returned by
# get_local_test_artifacts) so its configuration can be inspected.
preprocessor

LlavaNextImageProcessor {
  "crop_size": {
    "height": 504,
    "width": 504
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_pad": true,
  "do_rescale": true,
  "do_resize": true,
  "image_grid_pinpoints": [
    [
      504,
      504
    ],
    [
      504,
      1008
    ],
    [
      504,
      1512
    ],
    [
      1008,
      504
    ],
    [
      1008,
      1008
    ],
    [
      1008,
      1512
    ],
    [
      1512,
      504
    ],
    [
      1512,
      1008
    ],
    [
      1512,
      1512
    ]
  ],
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "LlavaNextImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 504
  }
}

In [32]:
# Inspect the preprocessor's `image_grid_pinpoints` attribute.
# NOTE(review): presumably the candidate target resolutions used when tiling
# images; confirm the axis order against the LlavaNextImageProcessor docs.
preprocessor.image_grid_pinpoints

[(504, 504),
 (504, 1008),
 (504, 1512),
 (1008, 504),
 (1008, 1008),
 (1008, 1512),
 (1512, 504),
 (1512, 1008),
 (1512, 1512)]

In [30]:
# Pull the first collated batch from a fresh DataLoader iterator and display it.
# NOTE(review): in the recorded run this cell printed
# "Skipped 1 samples due to max batch size constraints" and then raised an
# AssertionError with no message. The failing assertion is not visible in this
# notebook; TODO: investigate (possibly in the dataset or collator) before
# relying on `batch`.
batch = next(iter(train_dataloader))
batch

Skipped 1 samples due to max batch size constraints


AssertionError: 