In [None]:
# Run-wide knobs for the QA-generation sweep:
#   TRAIN_BATCH_SIZE - samples per DataLoader batch
#   NUM_WORKERS      - DataLoader worker processes
#   MAX_TOKENS       - max_new_tokens cap passed to every generate() call
TRAIN_BATCH_SIZE, NUM_WORKERS, MAX_TOKENS = 1, 8, 1300

In [None]:
import json
import os
import re
from IPython.display import display
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

import datasets
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

from CodaDatasets import CodaDataset

In [None]:
# Disable HF tokenizers' internal parallelism; the DataLoader below uses worker processes.
os.environ.update(TOKENIZERS_PARALLELISM='false')
model_id, device = 'llava-hf/llava-1.5-7b-hf', 'cuda'
# LLaVA-1.5 single-turn chat format; the user message is substituted for `{}`.
prompt_template = 'USER: {} ASSISTANT:'

In [None]:
# Shared tokenizer + image processor for both models (they use the same base checkpoint).
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id)

In [None]:
# Answering model: the base checkpoint quantized to 4-bit with bitsandbytes, computing
# in fp16. A fine-tuned adapter is loaded onto it afterwards via `load_adapter`.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    ),
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

In [None]:
# Attach the fine-tuned adapter weights stored in this local directory to `model`.
model.load_adapter(peft_model_id='models/lora_r64_5e-5_ep1')

In [None]:
# Judge model: a second, adapter-free copy of the same 4-bit base checkpoint. It is used
# to compare the fine-tuned model's answer with the reference and to propose Q/A pairs.
# NOTE(review): this holds a second full copy of the base weights in memory; toggling
# adapters on `model` (disable_adapters()/enable_adapters()) might avoid that — confirm.
model_pretrained = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(
        bnb_4bit_compute_dtype=torch.float16,
        load_in_4bit=True,
    ),
)

In [None]:
# Raw Hugging Face split; only 'train' is used in this notebook.
hf_dataset = dict(
    train=datasets.load_dataset('ntudlcv/dlcv_2024_final1', split='train'),
)

In [None]:
# Wrap the raw split; has_answer=True because the reference answers are needed below.
dataset = dict(
    train=CodaDataset(hf_dataset['train'], has_answer=True),
)

In [None]:
def custom_collate_fn(batch):
    """Transpose a list of per-sample tuples into per-field tuples.

    Example: ``[(id0, q0), (id1, q1)]`` -> ``[(id0, id1), (q0, q1)]``.

    Args:
        batch: list of equal-length sample tuples, as produced by the dataset.

    Returns:
        A list with one tuple per field. It is a list rather than a bare ``zip``
        iterator, so a batch can be iterated more than once; a ``zip`` is
        silently empty after its first pass.
    """
    return list(zip(*batch))

In [None]:
# Seed the shuffle so the preview sample and the order of rows written to the
# output file are reproducible across Restart-and-Run-All.
SHUFFLE_SEED = 0

dataloader = {
    'train': DataLoader(
        dataset['train'],
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        collate_fn=custom_collate_fn,
        generator=torch.Generator().manual_seed(SHUFFLE_SEED),
    )
}

In [None]:
# Sanity check: pull one shuffled batch and show its first sample.
batch = next(iter(dataloader['train']))
# The collated batch is field-major; transpose it back to rows and take row 0.
data_id, question_type, image, question, answer = [*zip(*batch)][0]

# Text fields are shown via repr so embedded newlines/quotes stay visible.
print(f'data_id: {data_id!r}')
print(f'question_type: {question_type!r}')
display(image.resize((200, 200)))  # resized to 200x200 for display; aspect ratio not preserved
print(f'question: {question!r}')
print(f'answer: {answer!r}')

In [None]:
# Instruction for the judge model. The two `{}` slots are filled with the reference
# answer and the fine-tuned model's answer; the reply is parsed for "Q: ... A: ..." pairs.
gen_qa_prompt_template = (
    "Given a reference text and a predicted text from an autonomous driving AI assistant, "
    "perform the following steps:\n"
    "\n"
    "1. Evaluate the semantic similarity and relevance of the predicted text compared to the reference text.\n"
    "2. Identify key details present in the reference text but missing or inaccurately represented "
    "in the predicted text.\n"
    "3. Formulate a question-answer pair starting with 'Q:' for the question and 'A:' for the answer, "
    "addressing the missing or misrepresented details. "
    "This question-answer pair will be used as additional training data for the AI assistant "
    "to improve its reasoning and contextual accuracy.\n"
    "\n"
    "Reference text:\n"
    "{}\n"
    "\n"
    "Predicted text:\n"
    "{}"
)
print(gen_qa_prompt_template)

In [None]:
# Sweep the training split:
#   1. answer each question with the adapter-equipped `model`;
#   2. ask the adapter-free `model_pretrained` to contrast that answer with the reference;
#   3. harvest the "Q: ... A: ..." pairs it proposes as extra training data.
# Progress is checkpointed to QA_PAIRS_PATH after every sample.
QA_PAIRS_PATH = 'qa_pairs.json'
VERBOSE = False  # set True to dump every prompt / raw generation (very long output)
QA_PATTERN = re.compile(r'Q:\s*(.*?)\s*A:\s*(.*?)(?=Q:|\Z)', flags=re.DOTALL)


def _generate_answer(image, question):
    """Return the adapter-equipped `model`'s greedy answer to `question` about `image`.

    Only tokens after the prompt are decoded. Splitting the full decoded text on
    'ASSISTANT: ' instead raises IndexError when nothing is generated (the text then
    ends in 'ASSISTANT:' with no space) and truncates answers containing that marker.
    """
    prompt = prompt_template.format(question)
    inputs = processor(images=image, text=prompt, return_tensors='pt').to(device)
    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS, do_sample=False)
    prompt_len = inputs['input_ids'].shape[1]
    return processor.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


def _generate_qa_text(reference, predicted):
    """Ask the adapter-free `model_pretrained` for Q/A pairs covering what `predicted` misses."""
    new_prompt = prompt_template.format(gen_qa_prompt_template.format(repr(reference), repr(predicted)))
    if VERBOSE:
        tqdm.write(repr(new_prompt))
    inputs = processor(text=new_prompt, return_tensors='pt').to(device)
    # Text-only prompt fed as embeddings; generate() given inputs_embeds without
    # input_ids returns only the new tokens, so the whole sequence is decoded.
    input_embeds = model_pretrained.get_input_embeddings()(inputs['input_ids'])
    outputs = model_pretrained.generate(
        inputs_embeds=input_embeds,
        attention_mask=inputs['attention_mask'],
        max_new_tokens=MAX_TOKENS,
        do_sample=False,
    )
    return processor.decode(outputs[0], skip_special_tokens=True)


def _parse_qa_pairs(text):
    """Extract stripped (question, answer) pairs from 'Q: ... A: ...' formatted text."""
    return [(q.strip(), a.strip()) for q, a in QA_PATTERN.findall(text)]


def _save_checkpoint(records, path=QA_PAIRS_PATH):
    """Rewrite `path` with `records` via a temp file + os.replace, so a crash mid-write
    never leaves a truncated/corrupt checkpoint behind."""
    tmp_path = f'{path}.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(records, f, indent=4)
    os.replace(tmp_path, path)


all_qa_pairs = []

pbar = tqdm(dataloader['train'])
for data in pbar:
    for data_id, question_type, image, question, answer in zip(*data):
        # No autograd needed anywhere here (this also covers the manual embedding lookup).
        with torch.inference_mode():
            generated_answer = _generate_answer(image, question)
            generated_qa = _generate_qa_text(answer, generated_answer)
        if VERBOSE:
            tqdm.write(generated_qa)

        qa_pairs = [(data_id, q, a) for q, a in _parse_qa_pairs(generated_qa)]
        if VERBOSE:
            tqdm.write(str(qa_pairs))

        all_qa_pairs.extend(qa_pairs)
        pbar.set_postfix_str(f'completed={len(all_qa_pairs)}')

        # Checkpoint after every sample so a crash loses at most one sample's work.
        _save_checkpoint(all_qa_pairs)