In [16]:
import pandas as pd
import numpy as np
import os
from PIL import Image

In [3]:
# Load the training data
data = pd.read_csv('spatial_hm3d_responses.csv')

In [4]:
# Select image history, question answer column
train_data = data[['question', 'answer', 'episode_history']]

In [5]:
train_data

Unnamed: 0,question,answer,episode_history
0,What is in between the two picture frames on t...,The TV,hm3d-v0/000-hm3d-BFRyYbPCCPE
1,Is there room on the dining table to eat?,Yes,hm3d-v0/000-hm3d-BFRyYbPCCPE
2,What is to the left of the mirror?,A plant in a tall vase,hm3d-v0/001-hm3d-TPhiubUHKcP
3,What is to the left of the staircase?,A storage closet,hm3d-v0/001-hm3d-TPhiubUHKcP
4,What is on the top shelf to the right side of ...,An ice cooler,hm3d-v0/002-hm3d-wcojb4TFT35
...,...,...,...
64,Is there space for another pillow on the back ...,No,hm3d-v0/094-hm3d-Qpor2mEya8F
65,"Where is the ""Be Thankful"" poster?",Above the stairs,hm3d-v0/096-hm3d-uLz9jNga3kC
66,"Where is ""Bless this home"" written?",Above the dinner room entrance,hm3d-v0/096-hm3d-uLz9jNga3kC
67,What is on the opposite wall of the bed?,TV,hm3d-v0/099-hm3d-q5QZSEeHe5g


## Prepare training data

In [17]:
def load_images_to_array(directory):
    # List all PNG files in the directory
    files = [f for f in sorted(os.listdir(directory)) if f.endswith('.png')]
    images = []
    
    # Loop through files and load each image
    for file in files:
        # Construct full file path
        file_path = os.path.join(directory, file)
        # Open the image file
        with Image.open(file_path) as img:
            # Convert the image to RGB (if not already in that format)
            img = img.convert('RGB')
            # Convert the image to a NumPy array and append to list
            images.append(np.array(img))
    
    video = np.stack(images)
    total_frames = video.shape[0]

    indices = np.arange(0, total_frames, total_frames / 8).astype(int)  # np.linspace(0, total_frames - 1, num=8, dtype=int)
    print("Number of frames: ", total_frames)
    print("Selected indices: ", indices)
    images = [video[i] for i in indices]
    # Stack all image arrays into a single numpy array
    return np.stack(images)

In [9]:
def json_item(image, question, answer, index):
    new_item = {
        "id": f"{index:03}",  # Formats index leading zeros
        "image": image,
        "conversations": [
            {
                "from": "human",
                "value": f"<image>\n{question}"
            },
            {
                "from": "gpt",
                "value": answer
            }
        ]
    }
    return new_item

In [18]:
train_json = []

for index, row in train_data.iterrows():
    question = row['question']
    answer = row['answer']
    ep_history = row['episode_history']
    episode_history = f"data/frames/{ep_history}"
    images = load_images_to_array(episode_history)
    train_json.append(json_item(images, question, answer, index))

Number of frames:  99
Selected indices:  [ 0 12 24 37 49 61 74 86]
Number of frames:  99
Selected indices:  [ 0 12 24 37 49 61 74 86]
Number of frames:  120
Selected indices:  [  0  15  30  45  60  75  90 105]
Number of frames:  120
Selected indices:  [  0  15  30  45  60  75  90 105]
Number of frames:  82
Selected indices:  [ 0 10 20 30 41 51 61 71]
Number of frames:  120
Selected indices:  [  0  15  30  45  60  75  90 105]
Number of frames:  120
Selected indices:  [  0  15  30  45  60  75  90 105]
Number of frames:  100
Selected indices:  [ 0 12 25 37 50 62 75 87]
Number of frames:  120
Selected indices:  [  0  15  30  45  60  75  90 105]
Number of frames:  120
Selected indices:  [  0  15  30  45  60  75  90 105]
Number of frames:  120
Selected indices:  [  0  15  30  45  60  75  90 105]
Number of frames:  120
Selected indices:  [  0  15  30  45  60  75  90 105]
Number of frames:  82
Selected indices:  [ 0 10 20 30 41 51 61 71]
Number of frames:  82
Selected indices:  [ 0 10 20 30 41

In [21]:
# Save json
import json
with open('train.json', 'w') as f:
    json.dump(train_json, f)

TypeError: Object of type ndarray is not JSON serializable

## Set up metrics
We'll use Exact Match and F1-score since this is a question answering task

In [1]:
from sklearn.metrics import f1_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    em = sum([1 if p == l else 0 for p, l in zip(preds, labels)]) / len(labels)
    f1 = f1_score(labels, preds, average='macro')
    return {
        'exact_match': em,
        'f1': f1,
    }

## Fine-tune the model

In [None]:
from transformers import Trainer, TrainingArguments, VideoLlavaForConditionalGeneration, VideoLlavaProcessor

# load the model
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", device_map="auto")

# define training arguments
training_args = TrainingArguments(
    output_dir="./finetune_results",
    evaluation_strategy="epoch",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    save_total_limit=2,
    save_steps=500,
    logging_dir="./logs",
    load_best_model_at_end=True,
    push_to_hub=False,
)

# create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset