In [90]:
import os
import io
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import time
import torch, gc
import glob
import transformers
import tokenizers
import random
from torch.utils.data import Dataset
from PIL import Image, ImageFile
from datasets import load_dataset, concatenate_datasets
from pathlib import Path
from datasets.utils.logging import set_verbosity_info
from transformers import logging as tf_logging
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoProcessor
from functools import partial

In [91]:
@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    shortcaption_image_folder: Optional[str] = field(default=None)
    data_type: Optional[str] = field(default="mix")
    image_aspect_ratio: str = "square"

In [92]:
def process_sample(sample):
    try:
        metadata = sample["json"]
        return {
            "caption": metadata.get("caption"),
            "cot": metadata.get("cot"),
            "aspect_ratio": metadata.get("aspect_ratio"),
            "img_index": metadata.get("img_index")
        }
    except Exception as e:
        print(f"Error processing sample: {e}")
        return None

In [126]:
class LazySupervisedMixDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str, #"/home/v-haodongli/Janus/tmp_script/laion_2b_aesthetic/{00042..00133}.tar"
        processor: AutoProcessor,
        # tokenizer: transformers.PreTrainedTokenizer,
        # data_args: DataArguments,
    ):
        super(LazySupervisedMixDataset, self).__init__()

        # self.data_args = data_args
        list_data_dict = []
        data_files = glob.glob(os.path.join(data_path, "*.tar"))
        train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=128)
        train_dataset = train_dataset.map(process_sample).filter(lambda x: x is not None)
        train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if not col in (
            ["caption", "cot", "aspect_ratio", "img_index"])])
        list_data_dict.append(train_dataset)
        if len(list_data_dict) > 1:
            list_data_dict = concatenate_datasets(list_data_dict)
        else:
            list_data_dict = list_data_dict[0]
        list_data_dict = list_data_dict.shuffle(seed=42)
        
        self.processor = processor
        self.list_data_dict = list_data_dict
        print(self.list_data_dict)
    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if "image" in sample else 0
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            cur_len = cur_len if "image" in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]

        conversation = [
        {"role": "<|User|>", "content": sources['caption']},
        {"role": "<|Assistant|>", "content": f"{sources['cot']}<begin_of_image><end_of_image>"},
        ]
        system_prompt = "You are an assistant that creates images from descriptions. First, describe the image in detail, then generate it."
        prompt = self.processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=self.processor.sft_format,
        system_prompt=system_prompt,
        )

        # tokenize prompt
        text_ids = self.processor.tokenizer.encode(prompt)
        all_ids = text_ids[:-2] + sources['img_index'] + text_ids[-2:]
        all_ids = torch.LongTensor(all_ids)

        # 构建图像 token 的 mask
        all_image_ids_mask = torch.zeros(all_ids.shape, dtype=torch.bool)
        all_image_ids_mask[:] = False
        all_image_ids_mask[-len(sources['img_index'])-2:-2] = True

        # 找到 Assistant 回答开始的位置
        try:
            assistant_start_token_id = self.processor.tokenizer.encode("<|Assistant|>")[0]
            assistant_start_index = text_ids.index(assistant_start_token_id)
        except (ValueError, IndexError):
            assistant_start_index = 0

        assistant_ids_mask = torch.zeros(all_ids.shape, dtype=torch.bool)
        assistant_ids_mask[assistant_start_index:] = True

        # 构造输入和标签
        input_ids = all_ids[:-1]
        text_ids_mask = all_image_ids_mask[:-1] == False
        image_ids_mask = all_image_ids_mask[:-1]
        label_ids = all_ids[1:]
        label_text_ids_mask = assistant_ids_mask[1:] & (all_image_ids_mask[1:] == False)
        label_image_ids_mask = assistant_ids_mask[1:] & all_image_ids_mask[1:]

        data_dict = {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "text_ids_mask": text_ids_mask,
            "image_ids_mask": image_ids_mask,
            "label_text_ids_mask": label_text_ids_mask,
            "label_image_ids_mask": label_image_ids_mask,
        }
        return data_dict
            

In [160]:
from janus.models.processing_vlm import VLChatProcessor
processor: VLChatProcessor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-Pro-7B")
tokenizer = processor.tokenizer
padding_id = tokenizer.pad_token_id
print(f"Padding ID: {padding_id}")
train_dataset = LazySupervisedMixDataset(data_path="/home/v-haodongli/Janus/tmp_script/laion_2b_aesthetic", processor=processor)

Padding ID: 100015
Dataset({
    features: ['caption', 'cot', 'aspect_ratio', 'img_index'],
    num_rows: 202
})


In [187]:
len(train_dataset)

202

In [120]:
input_ids,_,_,_,_,_ = train_dataset[2]
input_ids.shape

torch.Size([839])

In [185]:
from dataclasses import dataclass
from typing import Sequence, Dict
import torch
import transformers
from torch.nn.utils.rnn import pad_sequence

@dataclass
class DataCollatorForSupervisedDataset:
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # 提取所有字段
        input_ids, text_ids_mask, image_ids_mask, label_ids, label_text_ids_mask, label_image_ids_mask = (
            [instance[key] for instance in instances]
            for key in (
                "input_ids",
                "text_ids_mask",
                "image_ids_mask",
                "label_ids",
                "label_text_ids_mask",
                "label_image_ids_mask"
            )
        )

        # 转换为张量
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        label_ids = pad_sequence(label_ids, batch_first=True, padding_value=-100)  # IGNORE_INDEX

        # 对 mask 字段也进行 padding
        text_ids_mask = pad_sequence(text_ids_mask, batch_first=True, padding_value=0)
        image_ids_mask = pad_sequence(image_ids_mask, batch_first=True, padding_value=0)
        label_text_ids_mask = pad_sequence(label_text_ids_mask, batch_first=True, padding_value=0)
        label_image_ids_mask = pad_sequence(label_image_ids_mask, batch_first=True, padding_value=0)

        # 检查长度并截断
        if input_ids.shape[1] > self.tokenizer.model_max_length:
            print(f"Warning: input length {input_ids.shape[1]} exceeds model max length {self.tokenizer.model_max_length}")
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        label_ids = label_ids[:, :self.tokenizer.model_max_length]
        text_ids_mask = text_ids_mask[:, :self.tokenizer.model_max_length]
        image_ids_mask = image_ids_mask[:, :self.tokenizer.model_max_length]
        label_text_ids_mask = label_text_ids_mask[:, :self.tokenizer.model_max_length]
        label_image_ids_mask = label_image_ids_mask[:, :self.tokenizer.model_max_length]

        # 构建最终 batch
        batch = dict(
            input_ids=input_ids,
            label_ids=label_ids,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
            text_id_mask=text_ids_mask,
            image_id_mask=image_ids_mask,
            label_text_id_mask=label_text_ids_mask,
            label_image_id_mask=label_image_ids_mask,
        )

        return batch

In [186]:

data_collator = DataCollatorForSupervisedDataset(tokenizer) 
data_collator([train_dataset[1], train_dataset[2]])

{'input_ids': tensor([[100000,   2054,    418,  ..., 100015, 100015, 100015],
         [100000,   2054,    418,  ...,   1656,   3020, 100593]]),
 'label_ids': tensor([[  2054,    418,    274,  ...,   -100,   -100,   -100],
         [  2054,    418,    274,  ...,   3020, 100593, 100001]]),
 'attention_mask': tensor([[ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ...,  True,  True,  True]]),
 'text_id_mask': tensor([[ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False,  True]]),
 'image_id_mask': tensor([[False, False, False,  ..., False, False, False],
         [False, False, False,  ...,  True,  True, False]]),
 'label_text_id_mask': tensor([[ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False,  True,  True]]),
 'label_image_id_mask': tensor([[False, False, False,  ..., False, False, False],
         [False, False, False,  ...,  True, False, False]])}