In [None]:
# ! pip uninstall -y pyarrow
# ! pip install pyarrow==9.0.0
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch

import transformers

from torch.utils.data import Dataset

from PIL import Image

from datasets import load_dataset

import matplotlib.pyplot as plt


import sys
sys.path.append("/mnt/workdisk/jasmine/LLaVA")
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava import conversation as conversation_lib


In [None]:
def preprocess_multimodal(
    sources: Sequence[List[Dict]],
    data_args: "DataArguments",  # quoted: DataArguments is defined later in this cell
) -> Sequence[List[Dict]]:
    """Normalise the image placeholder in every turn of every conversation.

    For each turn whose ``'value'`` contains ``DEFAULT_IMAGE_TOKEN``, the token
    is moved to the front of the text, followed by a newline. When the default
    conversation template's version contains ``"mmtag"``, the token is also
    wrapped in ``<Image>...</Image>``. If ``mm_use_im_start_end`` is enabled,
    every occurrence of the token in every turn is then surrounded by
    ``DEFAULT_IM_START_TOKEN`` / ``DEFAULT_IM_END_TOKEN``.

    Args:
        sources: Conversations whose turns are dicts holding a ``'value'``
            string. They are modified in place.
        data_args: Options object. ``is_multimodal`` gates all processing.

    Returns:
        ``sources`` itself (unchanged when ``is_multimodal`` is false).
    """
    if not data_args.is_multimodal:
        return sources

    # Read with a default so argument objects without this field don't raise
    # AttributeError. The replacement token is the same for every turn.
    replace_token = DEFAULT_IMAGE_TOKEN
    if getattr(data_args, 'mm_use_im_start_end', False):
        replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN

    for source in sources:
        for sentence in source:
            text = sentence['value']
            if DEFAULT_IMAGE_TOKEN in text:
                # Move the image token to the very start of the turn.
                text = text.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                text = (DEFAULT_IMAGE_TOKEN + '\n' + text).strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    text = text.replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
            sentence['value'] = text.replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources

def preprocess_mpt(
    sources,
    tokenizer: Optional["transformers.PreTrainedTokenizer"] = None,
) -> Dict:
    """Render conversations with the MPT template, tokenize them and build masked labels.

    Args:
        sources: Batch of conversations. Each conversation is a list of turns
            shaped like ``{"from": "human" | "gpt", "value": str}``. A leading
            turn not from ``"human"`` is dropped, and the remaining turns must
            alternate human/gpt.
        tokenizer: Tokenizer handed to ``tokenizer_image_token``. When omitted,
            a global named ``tokenizer`` in this module/notebook namespace is
            used, which is what the function relied on before this parameter
            existed.

    Returns:
        ``dict(input_ids=..., labels=...)``. ``labels`` starts as a clone of
        ``input_ids``. The instruction part of each round and everything after
        the last parsed round are set to ``IGNORE_INDEX``. If the computed
        length is below ``tokenizer.model_max_length`` but differs from the
        non-padding token count, the whole row is masked and a warning is
        printed.

    Raises:
        ValueError: If no tokenizer is given and no global ``tokenizer`` exists.
    """
    # Not brought in by the notebook's import cell, so import it where it is used.
    from llava.mm_utils import tokenizer_image_token

    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
    if tokenizer is None:
        raise ValueError(
            "preprocess_mpt requires a tokenizer: pass `tokenizer=` or define a "
            "global named `tokenizer`."
        )

    conv = conversation_lib.default_conversation.copy()
    # The masking below only understands the MPT separator layout; fail before doing any work.
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = torch.stack(
        [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations],
        dim=0,
    )
    targets = input_ids.clone()

    # Mask targets
    sep = conv.sep + conv.roles[1]
    # Token length of the separator that splitting removed from each round (loop-invariant).
    sep_len = len(tokenizer_image_token(conv.sep, tokenizer))
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 0
        for rou in re_rounds:
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(tokenizer_image_token(rou, tokenizer)) + sep_len
            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
            # Hide the prompt/instruction tokens of this round from the loss.
            target[cur_len:cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_INDEX
            print(
                f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                f" (ignored)"
            )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )

@dataclass
class DataArguments:
    """Data-loading options read by LazySupervisedDataset and preprocess_multimodal."""

    # Path to the JSON annotation file (None until configured).
    data_path: Optional[str] = field(default=None,
                                     metadata={"help": "Path to the training data."})
    # Not read by the code shown here.
    lazy_preprocess: bool = False
    # When False, preprocess_multimodal returns conversations untouched.
    is_multimodal: bool = False
    # Directory that each sample's "image" file name is joined onto.
    image_folder: Optional[str] = field(default=None)
    # 'pad' pads images to a square canvas; other values leave them as loaded.
    image_aspect_ratio: str = 'square'
    # Surround the image token with start/end tokens in preprocess_multimodal.
    # Added last with a default so existing positional construction still works.
    mm_use_im_start_end: bool = False
    
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    The JSON annotation list is loaded once in ``__init__``. Image loading and
    prompt tokenization are deferred to ``__getitem__``.
    """

    def __init__(self, data_path: str,
                 data_args: "DataArguments"):
        super().__init__()
        # Context manager so the annotation file handle is closed promptly.
        with open(data_path, "r", encoding="utf-8") as f:
            list_data_dict = json.load(f)

        print("Formatting inputs...Skip in lazy mode")
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        """Rough per-sample length: whitespace-separated words, plus 128 if the sample has an image."""
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if 'image' in sample else 0
            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        """Per-sample word count, negated for samples without an image."""
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Pad ``pil_img`` to a square canvas of its longer side, centred, filled with ``background_color``."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    def __getitem__(self, i) -> Dict:
        """Load one sample, apply the MPT template and attach its image.

        Returns a dict with ``input_ids``/``labels`` from ``preprocess_mpt``
        (row 0 only when ``i`` is an int), the processed ``sources`` and, where
        applicable, an ``image`` entry.
        """
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            # NOTE(review): the image processor (data_args.image_processor) is
            # disabled, so `image` stays a PIL image rather than a pixel tensor.
            # convert() returns a new image, so closing the file afterwards is safe.
            with Image.open(os.path.join(image_folder, image_file)) as raw_image:
                image = raw_image.convert('RGB')
            if self.data_args.image_aspect_ratio == 'pad':
                image = self._expand2square(image, (255, 255, 255))
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        # NOTE(review): no tokenizer is passed here; preprocess_mpt has to find a
        # module/notebook-level `tokenizer` on its own.
        data_dict = preprocess_mpt(sources)

        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        data_dict['sources'] = sources

        if 'image' in self.list_data_dict[i]:
            # image exists in the data
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # Text-only sample for a multimodal model: supply a blank image.
            # crop_size is a plain int here, so use it for both spatial dims
            # (indexing it like a {'height', 'width'} dict raised TypeError).
            crop_size = 128  # self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size, crop_size)
        return data_dict

In [None]:
data_path = "/mnt/workdisk/jasmine/data/llava/LLaVA-Instruct-150K/v1-ft/llava_instruct_80k.json"
image_folder = '/mnt/workdisk/jasmine/data/llava/data/train2017'
data_args = DataArguments(data_path = data_path, image_folder = image_folder, image_aspect_ratio='pad')


dataset = LazySupervisedDataset(data_path, data_args=data_args)

In [None]:
# Load the local LLaVA-Instruct "v1-ft" directory with `datasets.load_dataset`.
# Sources tried earlier: the hub id "liuhaotian/LLaVA-CC3M-Pretrain-595K" and
# the local ".../LLaVA-Instruct-150K/tmp" directory.
dataset = load_dataset(
    "/mnt/workdisk/jasmine/data/llava/LLaVA-Instruct-150K/v1-ft"
)

# Show the first record of the train split.
print(dataset["train"][0])

In [None]:
# for i in range(50): #len(dataset)):

#     # image = os.path.join(image_folder, dataset['train'][i]['image'])
#     # img = Image.open(image)
#     # img.show()
#     convos = dataset["train"][i]['conversations']
#     print(len(convos))
#     # print(convos[0])
#     # print(convos[1])
#     # print(len(convos))

In [None]:
# Eyeball the first three training samples: open each image in the viewer,
# print every conversation turn, then the number of turns.
for i in range(3):
    image = os.path.join(image_folder, dataset['train'][i]['image'])
    img = Image.open(image)
    img.show()

    convos = dataset['train'][i]['conversations']
    for convo in convos:
        print(convo)
    print(len(convos))

In [None]:
# `sources` is never bound at notebook level (it exists only as a local inside
# the preprocessing functions and in the commented-out cell below), so a bare
# `sources[0]` raised NameError. Build it from the first three training
# conversations, as that commented-out cell does, and show the first one.
sources = [dataset["train"][k]['conversations'] for k in range(3)]
sources[0]

In [None]:
# conv = conversation_lib.default_conversation.copy()
# roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

# # print(roles)

# sources = [dataset["train"][0]['conversations'], dataset["train"][1]['conversations'], dataset["train"][2]['conversations']]
# # print(sources)

# # Apply prompt templates
# conversations = []
# for i, source in enumerate(sources):
#     if roles[source[0]["from"]] != conv.roles[0]:
#         # Skip the first one if it is not from human
#         source = source[1:]

#     conv.messages = []
#     for j, sentence in enumerate(source):
#         role = roles[sentence["from"]]
#         assert role == conv.roles[j % 2], f"{i}"
#         conv.append_message(role, sentence["value"])
#     conversations.append(conv.get_prompt())


# for convo in conversations:
#     print(convo)
#     print('------------------')

# # Tokenize conversations
# input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
# targets = input_ids.clone()
# assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

# # Mask targets
# sep = conv.sep + conv.roles[1]
# for conversation, target in zip(conversations, targets):
#     total_len = int(target.ne(tokenizer.pad_token_id).sum())

#     rounds = conversation.split(conv.sep)
#     re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
#     for conv_idx in range(3, len(rounds), 2):
#         re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))    # user + gpt
#     cur_len = 0
#     target[:cur_len] = IGNORE_INDEX
#     for i, rou in enumerate(re_rounds):
#         if rou == "":
#             break

#         parts = rou.split(sep)
#         if len(parts) != 2:
#             break
#         parts[0] += sep
#         round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
#         instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
#         target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

#         cur_len += round_len
#     target[cur_len:] = IGNORE_INDEX

#     if cur_len < tokenizer.model_max_length:
#         if cur_len != total_len:
#             target[:] = IGNORE_INDEX
#             print(
#                 f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
#                 f" (ignored)"
#             )

# return dict(
#     input_ids=input_ids,
#     labels=targets,
# )

In [None]:
SYSTEM = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.'
tok = transformers.AutoTokenizer.from_pretrained('rajammanabrolu/gpt-4-chat', trust_remote_code=True)

# LLaVA speaker tag -> chat-template role.
_LLAVA_TO_CHAT_ROLE = {'human': 'user', 'gpt': 'assistant'}


def _to_chat_messages(convo, system):
    """Convert one LLaVA conversation into chat-template messages after a system prompt.

    A leading turn that is not from 'human' is dropped. Raises ValueError for
    an unknown 'from' tag (an assert would vanish under ``python -O``).
    """
    if convo and convo[0]['from'] != 'human':
        convo = convo[1:]
    messages = [{'role': 'system', 'content': system}]
    for turn in convo:
        role = _LLAVA_TO_CHAT_ROLE.get(turn['from'])
        if role is None:
            raise ValueError('unrecognized from: %s' % turn['from'])
        messages.append({'role': role, 'content': turn['value']})
    return messages


def _prompt_response_pairs(messages, tokenizer):
    """Yield one {'prompt', 'response'} example per assistant message.

    The prompt is every message before that reply, rendered with the
    tokenizer's chat template plus a generation prompt. Keying on the role
    (rather than on even positions) means a trailing user turn or a
    non-alternating conversation can neither raise IndexError nor turn a user
    message into a "response".
    """
    for idx, message in enumerate(messages):
        if message['role'] == 'assistant':
            prompt = tokenizer.apply_chat_template(
                messages[:idx], tokenize=False, add_generation_prompt=True)
            yield {'prompt': prompt, 'response': message['content']}


# format multi-turn data: one training example per assistant turn
training_data = []
for example in dataset['train']:
    messages = _to_chat_messages(example['conversations'], SYSTEM)
    training_data.extend(_prompt_response_pairs(messages, tok))

print(len(training_data))

In [None]:
# Peek at (up to) the first 20 prompt/response examples. Slicing avoids an
# IndexError when fewer than 20 were produced.
for example in training_data[:20]:
    print(example)

In [None]:
# tok = transformers.AutoTokenizer.from_pretrained('rajammanabrolu/gpt-4-chat', trust_remote_code=True)

# def preprocess_prompt(prompt: str, system=None):
#     if system is None:
#         system = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.'
#     s = tok.apply_chat_template(
#         [
#             {'role': 'system', 'content': system},
#             {'role': 'user', 'content': prompt},
#         ],
#         tokenize=False,
#         add_generation_prompt=True,
#     )
#     return s

In [None]:
SYSTEM = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.'

tok = transformers.AutoTokenizer.from_pretrained(
    'rajammanabrolu/gpt-4-chat',
    trust_remote_code=True,
)

# A toy two-round dialogue used to check how the chat template renders prompts.
convo = [
    {'role': 'system', 'content': SYSTEM},
    {'role': 'user', 'content': 'hello'},
    {'role': 'assistant', 'content': 'hi there'},
    {'role': 'user', 'content': 'tell me a joke'},
    {'role': 'assistant', 'content': 'knock knock...'},
]

# For each assistant turn at index `cut`: the prompt is everything before it,
# rendered with a generation prompt, and the response is that turn's text.
d1, d2 = (
    {
        'prompt': tok.apply_chat_template(convo[:cut], tokenize=False, add_generation_prompt=True),
        'response': convo[cut]['content'],
    }
    for cut in (2, 4)
)

print(d1)
print(d2)

In [None]:
# Eyeball the first three training samples: open each image in the viewer,
# print every conversation turn, then the number of turns.
for i in range(3):
    image = os.path.join(image_folder, dataset['train'][i]['image'])
    img = Image.open(image)
    img.show()

    convos = dataset['train'][i]['conversations']
    for convo in convos:
        print(convo)
    print(len(convos))