# Embedding Pipelines (1. VLM captioning; 2. CLIP visual embedding)

## 1. VLM captioning (example: Waymo dataset)

In [None]:
# Load the names of all images to be captioned
# (only front-view images are captioned)

import os

def list_files_in_folder(folder_path):
    """Return the names of the regular files directly inside ``folder_path``.

    Sub-directories are skipped and the listing is not recursive. Names are
    returned without the folder prefix, in whatever order ``os.listdir``
    yields them; sort the result if a stable order is required.

    Args:
        folder_path: Path of the directory to list.

    Returns:
        A list of file names (``str``).

    Raises:
        OSError: If ``folder_path`` does not exist, is not a directory, or
            cannot be read.
    """
    # Let OSError propagate instead of returning an error *string*: a string
    # in place of the list made callers fail later with a confusing error
    # (``str`` has no ``.sort()``) far away from the real cause.
    return [
        name
        for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    ]

# Folder holding the images to be captioned.
folder_path = './image_0/'
files = list_files_in_folder(folder_path)
# Sort in place so the index-based train/validation split that follows
# always sees the file names in the same order.
files.sort()
# Show how many files were found.
print(len(files))

In [None]:
# Fixed-index split of the sorted file list: entries before index 31617 are
# treated as training images, the remainder as validation images.
trfiles = files[:31617]    # training images
vlfiles = files[31617:]    # validation images

# only the training images are captioned

In [None]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle

from PIL import Image
import requests
import copy
import torch

import argparse
import torch
import os
import yaml
import pickle as pkl
import json

from tqdm import tqdm

import pickle

# Caption every training image with LLaVA-NeXT and periodically checkpoint
# the accumulated captions to disk.
pretrained = "lmms-lab/llama3-llava-next-8b"
model_name = "llava_llama3"
device = "cuda"
device_map = "auto"
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map) # Add any other thing you want to pass in llava_model_args

model.eval()
model.tie_weights()

conv_template = "llava_llama_3" # Make sure you use correct chat template for different models

# The question file, the prompt and its token ids are the same for every
# image, so they are built once here instead of on every loop iteration.
# NOTE(review): yaml.FullLoader is kept as in the original; prefer
# yaml.safe_load if this file could ever come from an untrusted source.
with open('./questions/av_driving_condition_description.yaml') as f:
    questions = yaml.load(f, Loader=yaml.FullLoader)
# Prompt:
#
# Generic AV Prompt from [Shen, M., Chang, N., Liu, S., & Alvarez, J. M. (2024). SSE: Multimodal Semantic
# Data Selection and Enrichment for Industrial-scale Data Assimilation. arXiv preprint arXiv:2409.13860.]
#
# The image is taken from inside the ego vehicle looking out through the windshield onto a road and you are
# the driver of the ego vehicle. Describe the driving condition shown in the image in 150 words

# Use the first entry of the "common" question list
# (same element the original `[::-1].pop()` selected).
question = DEFAULT_IMAGE_TOKEN + "\n" + questions["common"][0]
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()

input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)

trresults = []

# `counter` is the 1-based number of images captioned so far.
for counter, i in enumerate(tqdm(trfiles), start=1):
    # Close the file handle as soon as the pixel data has been converted.
    with Image.open('./image_0/' + i) as raw_image:
        image = raw_image.convert('RGB')

    image_tensor = process_images([image], image_processor, model.config)
    image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]

    cont = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image.size],
        do_sample=True,
        temperature=0.2,
        max_new_tokens=512,
    )

    text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
    # Strip any special-token markers that survived decoding.
    answer = text_outputs[0].replace("<s>", "").replace("</s>", "").strip()
    answer = answer.replace("<|startoftext|>", "").replace("<|im_end|>", "").strip()
    print(answer)

    trresults.append(answer)
    # Checkpoint all captions collected so far after every 1000 images.
    if counter % 1000 == 0:
        with open("./trllava16b8l3" + str(counter) + ".pkl", 'wb') as file:
            pickle.dump(trresults, file, protocol=4)

# Final dump with the captions of every training image.
with open("./trllava16b8l3-final.pkl", 'wb') as file:
    pickle.dump(trresults, file, protocol=4)

### Embed the captions with SentenceTransformer

In [None]:
from sentence_transformers import SentenceTransformer

# Sentence encoder for the captions; "all-MiniLM-L6-v2" is an alternative
# checkpoint kept here for reference.
# model = SentenceTransformer("all-MiniLM-L6-v2")
model = SentenceTransformer("all-mpnet-base-v2")  # NOTE: rebinds the global `model` used by the captioning cell

# Captions produced for the training images by the captioning cell.
sentences = trresults

# Calculate the caption embeddings by calling model.encode()
embeddings = model.encode(sentences)

# Persist the caption embeddings.
with open("./capemb.pkl", 'wb') as file:  # 'wb' mode opens the file in binary write mode
    pickle.dump(embeddings, file, protocol=4)

## 2. CLIP visual embedding (example: Waymo dataset)

In [None]:
# only embedding training images

import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import os

from tqdm import tqdm

# Imported here so this cell does not depend on an earlier cell having
# imported pickle.
import pickle

import numpy as np

# Load the pre-trained CLIP model and processor
model_name = "openai/clip-vit-large-patch14"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# `trfiles` holds bare file names, so they must be resolved against the image
# folder (the same folder the captioning cell reads from).
image_dir = './image_0/'

# Load and process images
embeddings = []
for filename in tqdm(trfiles):
    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(os.path.join(image_dir, filename)) as raw_image:
        image = raw_image.convert('RGB')

    # Preprocess image
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Get image embeddings
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)

    # L2-normalize the embedding along the feature dimension
    embedding = image_features / image_features.norm(p=2, dim=-1, keepdim=True)

    # Store in the list (one NumPy array per image)
    embeddings.append(embedding.cpu().numpy())


# The pickle keeps the original format: a list with one array per image.
with open("./trclip.pkl", 'wb') as file:  # 'wb' mode opens the file in binary write mode
    pickle.dump(embeddings, file, protocol=4)

# `embeddings` is a Python list and has no `.shape`; concatenate the per-image
# arrays along axis 0 to inspect the overall shape (presumably one row per
# image, assuming each entry is (1, dim) - TODO confirm).
np.concatenate(embeddings, axis=0).shape