In [None]:
import os
import io
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import time
import torch, gc
import glob
import transformers
import tokenizers
import random
from torch.utils.data import Dataset
from PIL import Image, ImageFile
from datasets import load_dataset, concatenate_datasets
from pathlib import Path
from datasets.utils.logging import set_verbosity_info
from transformers import logging as tf_logging
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoProcessor
from functools import partial
import webdataset as wds

In [None]:
input_data_url = "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/00000.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D"

bad_indices = []
threshold_bytes = 2000
index = 0

def url_logger(urls):
    for url in urls:
        print(f"Processing tar file: {url}")
        yield url

dataset = wds.WebDataset(input_data_url).map_dict(jpg=lambda x: x if x else None, png=lambda x: x if x else None)

for sample in dataset:
    print(sample)
    image_data = None
    if "jpg" in sample:
        image_data = sample["jpg"]
    elif "png" in sample:
        image_data = sample["png"]
    else:
        index += 1
        continue

    try:
        # 获取 caption 
        caption = "No caption"
        if "json" in sample:
            meta = json.loads(sample["json"])
            caption = meta.get("caption", "No caption")

        file_size = len(image_data)
        img = Image.open(io.BytesIO(image_data))
        width, height = img.size

        print(f"[{index}] Caption: {caption}")
        print(f"File size: {file_size} bytes | Image size: {width}x{height}")

        if file_size < threshold_bytes:
            bad_indices.append(index)
            print("⚠️ Warning: This image may be corrupted (too small).\n")
    except Exception as e:
        print(f"[{index}] ❌ Failed to open image: {e}\n")
        bad_indices.append(index)

    index += 1

print("Bad indices:", bad_indices)

In [None]:
input_data_url = "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/00000.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D"

dataset = wds.WebDataset(input_data_url)

bad_indices = []
threshold_bytes = 2000
index = 0

for sample in dataset:
    print(sample)
    image_data = None
    if "jpg" in sample:
        image_data = sample["jpg"]
    elif "png" in sample:
        image_data = sample["png"]
    else:
        index += 1
        continue

    # 解析 JSON 元数据 
    caption = "No caption"
    if "json" in sample:
        try:
            meta = json.loads(sample["json"])  # 关键修复点：把 bytes 转为 dict
            caption = meta.get("caption", "No caption")
        except json.JSONDecodeError:
            caption = "Invalid JSON"

    file_size = len(image_data)

    try:
        img = Image.open(io.BytesIO(image_data))
        width, height = img.size

        print(f"[{index}] Caption: {caption}")
        print(f"File size: {file_size} bytes | Image size: {width}x{height}")

        if file_size < threshold_bytes:
            bad_indices.append(index)
            print("⚠️ Warning: This image may be corrupted (too small).\n")
    except Exception as e:
        print(f"[{index}] ❌ Failed to open image: {e}\n")
        bad_indices.append(index)

    index += 1

print("\n--- Summary ---")
print(f"Total images processed: {index}")
print(f"Images with file size < {threshold_bytes} bytes: {len(bad_indices)}")
print(f"Bad indices: {bad_indices}")

In [None]:
import io
from PIL import Image
import webdataset as wds
from tqdm import tqdm

# 替换为你自己的 tar 地址
input_data_url = "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/00000.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D"

# 创建 WebDataset 
dataset = wds.WebDataset("/home/v-haodongli/t2isft/cleaned_dataset.tar")

threshold_bytes = 4000

# 使用 tqdm 显示进度条
for i, sample in enumerate(tqdm(dataset, desc="Processing dataset")):
    # 提取图像字段（jpg 或 png）
    image_data = None
    if "jpg" in sample:
        image_data = sample["jpg"]
    elif "png" in sample:
        image_data = sample["png"]

    # 提取文本字段（通常是 txt）
    text_data = sample.get("txt", None)
    
    # 如果没有图像或文本，跳过
    if image_data is None or text_data is None:
        continue

    file_size = len(image_data)

    # 如果文件太小，尝试打开并显示图像及对应的文本
    if file_size < threshold_bytes:
        print(f"\n[Found small image] Index: {i}, File size: {file_size} bytes")

        # 打印 prompt（假设是 utf-8 编码）
        try:
            prompt = text_data.decode("utf-8")
            print("Prompt (Text):", prompt)
        except Exception as e:
            print("Failed to decode text:", str(e))

        # 尝试显示图像
        try:
            img = Image.open(io.BytesIO(image_data))
            print("Image size:", img.size)
            print("Image format:", img.format)
            print("Showing the image...")
            img.show()  # 调用系统默认图片查看器
        except Exception as e:
            print("Failed to open image:", str(e))

        # 可选：只处理第一个异常图像后退出循环
        # break

In [None]:
import io
from PIL import Image
import webdataset as wds
from tqdm import tqdm

# 输入 tar 文件路径
input_data_url = "/home/v-haodongli/mnt/v-haodongli-container/cot_output_test"

# 创建 WebDataset 
dataset = wds.WebDataset(input_data_url)

threshold_bytes = 4000
bad_keys = []

# 使用 tqdm 显示进度条
for i, sample in enumerate(tqdm(dataset, desc="Detecting corrupted or small images")):
    # 提取图像字段
    image_data = None
    if "jpg" in sample:
        image_data = sample["jpg"]
    elif "png" in sample:
        image_data = sample["png"]

    text_data = sample.get("txt", None)

    if image_data is None or text_data is None:
        bad_keys.append(sample["__key__"])
        continue

    file_size = len(image_data)
    if file_size < threshold_bytes:
        bad_keys.append(sample["__key__"])

print(f"\nFound {len(bad_keys)} samples to remove.")

import io
import webdataset as wds
from tqdm import tqdm

# 下载输入的 tar 文件流并过滤
input_url = input_data_url
output_file = "/home/v-haodongli/mnt/v-haodongli-container/cot_output_test"

bad_keys_set = set(bad_keys)  # 来自上一步的结果

with wds.TarWriter(output_file) as sink:
    with wds.WebDataset(input_url) as dataset:
        for sample in tqdm(dataset, desc="Filtering and saving clean dataset"):
            if sample["__key__"] not in bad_keys_set:
                sink.write(sample)

print(f"Cleaned dataset saved to {output_file}")

In [None]:
def process_sample(sample):
    try:
        metadata = sample["json"]
        return {
            "caption": metadata.get("caption"),
            "cot": metadata.get("cot"),
            "aspect_ratio": metadata.get("aspect_ratio"),
            "img_index": metadata.get("img_index")
        }
    except Exception as e:
        print(f"Error processing sample: {e}")
        return None

In [None]:
from torch.utils.data import Dataset
from datasets import concatenate_datasets
import glob
from datasets import load_dataset
import torch
import json
import os

class LazySupervisedMixDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        processor: AutoProcessor,
    ):
        super().__init__()
        self.data_files = glob.glob(os.path.join(data_path, "*.tar"))  # 保存 tar 文件列表
        train_datasets = []
        self.offsets = [0]  # 记录每个 tar 文件的起始索引

        # 逐个加载并处理 tar 文件
        for data_file in self.data_files:
            raw_dataset = load_dataset("webdataset", data_files=[data_file], split="train", num_proc=128)
            train_dataset = raw_dataset.map(process_sample).filter(lambda x: x is not None)
            train_datasets.append(train_dataset)
            self.offsets.append(self.offsets[-1] + len(train_dataset))  # 累积样本数

        # 合并数据集
        if len(train_datasets) > 1:
            self.list_data_dict = concatenate_datasets(train_datasets)
        else:
            self.list_data_dict = train_datasets[0]

        self.processor = processor

    # def process_sample(self, sample):
    #     # 示例处理逻辑，确保返回有效样本
    #     if 'caption' not in sample or 'cot' not in sample:
    #         return None
    #     return sample

    def get_tar_info(self, index: int):
        """根据全局索引定位到 tar 文件和文件内偏移量"""
        for i in range(len(self.offsets) - 1):
            if self.offsets[i] <= index < self.offsets[i + 1]:
                tar_index = i  # tar 文件索引
                file_index = index - self.offsets[i]  # 文件内偏移量
                return self.data_files[tar_index], file_index
        raise IndexError(f"Index {index} out of range")

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]

        conversation = [
            {"role": "<|User|>", "content": sources['caption']},
            {"role": "<|Assistant|>", "content": f"{sources['cot']}<begin_of_image><end_of_image>"},
        ]
        system_prompt = "You are an assistant that creates images from descriptions. First, describe the image in detail, then generate it."
        prompt = self.processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=self.processor.sft_format,
            system_prompt=system_prompt,
        )

        # Tokenize prompt
        text_ids = self.processor.tokenizer.encode(prompt)
        all_ids = text_ids[:-2] + sources['img_index'] + text_ids[-2:]
        all_ids = torch.LongTensor(all_ids)

        # 构建图像 token 的 mask
        all_image_ids_mask = torch.zeros(all_ids.shape, dtype=torch.bool)
        all_image_ids_mask[-len(sources['img_index'])-2:-2] = True

        # 找到 Assistant 回答开始的位置
        try:
            assistant_start_token_id = self.processor.tokenizer.encode("<|Assistant|>")[0]
            assistant_start_index = text_ids.index(assistant_start_token_id)
        except (ValueError, IndexError):
            assistant_start_index = 0

        assistant_ids_mask = torch.zeros(all_ids.shape, dtype=torch.bool)
        assistant_ids_mask[assistant_start_index:] = True

        # 构造输入和标签
        input_ids = all_ids[:-1]
        text_ids_mask = (all_image_ids_mask[:-1] == False)
        image_ids_mask = all_image_ids_mask[:-1]
        label_ids = all_ids[1:]
        label_text_ids_mask = assistant_ids_mask[1:] & (all_image_ids_mask[1:] == False)
        label_image_ids_mask = assistant_ids_mask[1:] & all_image_ids_mask[1:]

        return {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "text_ids_mask": text_ids_mask,
            "image_ids_mask": image_ids_mask,
            "label_text_ids_mask": label_text_ids_mask,
            "label_image_ids_mask": label_image_ids_mask,
        }

In [None]:
from janus.models.processing_vlm import VLChatProcessor
processor: VLChatProcessor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-Pro-7B")
tokenizer = processor.tokenizer
print(tokenizer.model_max_length)
padding_id = tokenizer.pad_token_id
print(f"Padding ID: {padding_id}")


In [None]:
data_files = glob.glob(os.path.join("/home/v-haodongli/mnt/v-haodongli-container/cot_output_test_train", "*.tar"))
# train_dataset = load_dataset("webdataset", data_files=data_files, split="train", streaming=True ,num_proc=8)
train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=8)

In [None]:
from dataclasses import dataclass
from typing import Sequence, Dict, Any
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizer

@dataclass
class DataCollatorForSupervisedDataset:
    """Collate examples for supervised fine-tuning."""
    tokenizer: PreTrainedTokenizer
    processor: Any  # 替换为你的具体 processor 类型（如 VLMProcessor）
    max_length: int = 1024
    IGNORE_INDEX: int = -100

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids_list = []
        labels_list = []
        text_ids_mask_list = []
        image_ids_mask_list = []
        label_text_ids_mask_list = []
        label_image_ids_mask_list = []

        for instance in instances:
            # 提取 caption 和 img_index
            try:
                json_data = instance['json']
                caption = json_data['caption']
                cot = json_data['cot']  # 注意这里新增了 cot 字段
                img_index = json_data['img_index']  # list of int 或者 tensor
            except KeyError as e:
                raise ValueError(f"Missing key in instance: {e}")

            # 构造 conversation
            conversation = [
                {"role": "<|User|>", "content": caption},
                {"role": "<|Assistant|>", "content": f"{cot}<begin_of_image><end_of_image>"},
            ]
            system_prompt = "You are an assistant that creates images from descriptions. First, describe the image in detail, then generate it."

            # 使用 self.processor 来生成 prompt
            prompt = self.processor.apply_sft_template_for_multi_turn_prompts(
                conversations=conversation,
                sft_format=self.processor.sft_format,
                system_prompt=system_prompt,
            )

            # Tokenize prompt
            text_ids = self.tokenizer.encode(prompt)

            # 插入图像 token ID
            all_ids = text_ids[:-2] + img_index + text_ids[-2:]
            all_ids = torch.LongTensor(all_ids)

            # 构建图像 token 的 mask
            all_image_ids_mask = torch.zeros(len(all_ids), dtype=torch.bool)
            all_image_ids_mask[-len(img_index)-2:-2] = True

            # 找到 Assistant 回答开始的位置
            try:
                assistant_start_token_id = self.tokenizer.encode("<|Assistant|>")[0]
                assistant_start_index = (all_ids == assistant_start_token_id).nonzero(as_tuple=True)[0][0].item()
            except Exception:
                assistant_start_index = 0

            # 构造各类 mask
            assistant_mask = torch.zeros(len(all_ids), dtype=torch.bool)
            assistant_mask[assistant_start_index:] = True

            # 构造 input 和 label
            input_ids = all_ids[:-1]
            label_ids = all_ids[1:]

            text_mask = (all_image_ids_mask[:-1] == False)
            image_mask = all_image_ids_mask[:-1]

            label_text_mask = assistant_mask[1:] & (all_image_ids_mask[1:] == False)
            label_image_mask = assistant_mask[1:] & all_image_ids_mask[1:]

            # 只保留 label 中需要的部分，其他设为 IGNORE_INDEX
            label_ids[~label_text_mask] = self.IGNORE_INDEX

            # 添加进列表
            input_ids_list.append(input_ids)
            labels_list.append(label_ids)
            text_ids_mask_list.append(text_mask)
            image_ids_mask_list.append(image_mask)
            label_text_ids_mask_list.append(label_text_mask)
            label_image_ids_mask_list.append(label_image_mask)

        # Padding 处理
        input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels_list, batch_first=True, padding_value=self.IGNORE_INDEX)
        text_ids_mask = pad_sequence(text_ids_mask_list, batch_first=True, padding_value=False)
        image_ids_mask = pad_sequence(image_ids_mask_list, batch_first=True, padding_value=False)
        label_text_ids_mask = pad_sequence(label_text_ids_mask_list, batch_first=True, padding_value=False)
        label_image_ids_mask = pad_sequence(label_image_ids_mask_list, batch_first=True, padding_value=False)

        # 截断处理
        if input_ids.size(1) > self.max_length:
            input_ids = input_ids[:, :self.max_length]
            labels = labels[:, :self.max_length]
            text_ids_mask = text_ids_mask[:, :self.max_length]
            image_ids_mask = image_ids_mask[:, :self.max_length]
            label_text_ids_mask = label_text_ids_mask[:, :self.max_length]
            label_image_ids_mask = label_image_ids_mask[:, :self.max_length]

        return dict(
            input_ids=input_ids,
            label_ids=labels,
            attention_mask=(input_ids != self.tokenizer.pad_token_id),
            text_id_mask=text_ids_mask,
            image_id_mask=image_ids_mask,
            label_text_id_mask=label_text_ids_mask,
            label_image_id_mask=label_image_ids_mask,
        )

In [None]:
# 假设你已经有一个 processor 实例
collator = DataCollatorForSupervisedDataset(
    tokenizer=processor.tokenizer,
    processor=processor,
    max_length=2048
)

# 测试一下
batch = collator([train_dataset[1], train_dataset[2], train_dataset[3]])
batch["input_ids"]
batch["text_id_mask"]
batch["image_id_mask"]
batch["label_ids"] 
batch["label_text_id_mask"]
batch["label_image_id_mask"]

In [None]:
batches_per_epoch = len(train_dataset) // 8
batches_per_epoch

In [None]:
tar_file, file_index = train_dataset.get_tar_info(2132052)
print(f"索引 2132052 的数据位于 tar 文件: {tar_file}，文件内第 {file_index} 个样本")

In [None]:
import webdataset as wds

shard_path = '/mnt/v-haodongli/cot_output_test_train/02166.tar'
target_key = "02166_00028"  # 想看的 key

dataset = wds.WebDataset(shard_path).decode().to_tuple("__key__", "json")

for key, label in dataset:
    if key == target_key:
        print("Key:", key)
        print("Label:", label)
        break  # 找到后退出循环   

In [None]:
processor: VLChatProcessor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-Pro-7B")

In [None]:
@dataclass
class DataCollatorForSupervisedDataset:
    """Collate examples for supervised fine-tuning."""
    tokenizer: PreTrainedTokenizer
    processor: Any  # 替换为你的具体 processor 类型（如 VLMProcessor）
    max_length: int = 1024
    IGNORE_INDEX: int = -100

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids_list = []
        labels_list = []
        text_ids_mask_list = []
        image_ids_mask_list = []
        label_text_ids_mask_list = []
        label_image_ids_mask_list = []

        for instance in instances:
            # 提取 caption 和 img_index
            try:
                json_data = instance['json']
                caption = json_data['caption']
                cot = json_data['cot']  # 注意这里新增了 cot 字段
                img_index = json_data['img_index']  # list of int 或者 tensor
            except KeyError as e:
                raise ValueError(f"Missing key in instance: {e}")

            # 构造 conversation
            conversation = [
                {"role": "<|User|>", "content": caption},
                {"role": "<|Assistant|>", "content": f"{cot}<begin_of_image><end_of_image>"},
            ]
            system_prompt = "You are an assistant that creates images from descriptions. First, describe the image in detail, then generate it."

            # 使用 self.processor 来生成 prompt
            prompt = self.processor.apply_sft_template_for_multi_turn_prompts(
                conversations=conversation,
                sft_format=self.processor.sft_format,
                system_prompt=system_prompt,
            )

            # Tokenize prompt
            text_ids = self.tokenizer.encode(prompt)

            # 插入图像 token ID
            all_ids = text_ids[:-2] + img_index + text_ids[-2:]
            all_ids = torch.LongTensor(all_ids)

            # 构建图像 token 的 mask
            all_image_ids_mask = torch.zeros(len(all_ids), dtype=torch.bool)
            all_image_ids_mask[-len(img_index)-2:-2] = True

            # 找到 Assistant 回答开始的位置
            try:
                assistant_start_token_id = self.tokenizer.encode("<|Assistant|>")[0]
                assistant_start_index = (all_ids == assistant_start_token_id).nonzero(as_tuple=True)[0][0].item()
            except Exception:
                assistant_start_index = 0

            # 构造各类 mask
            assistant_mask = torch.zeros(len(all_ids), dtype=torch.bool)
            assistant_mask[assistant_start_index:] = True

            # 构造 input 和 label
            input_ids = all_ids[:-1]
            label_ids = all_ids[1:]

            text_mask = (all_image_ids_mask[:-1] == False)
            image_mask = all_image_ids_mask[:-1]

            label_text_mask = assistant_mask[1:] & (all_image_ids_mask[1:] == False)
            label_image_mask = assistant_mask[1:] & all_image_ids_mask[1:]

            # 只保留 label 中需要的部分，其他设为 IGNORE_INDEX
            label_ids[~label_text_mask] = self.IGNORE_INDEX

            # 添加进列表
            input_ids_list.append(input_ids)
            labels_list.append(label_ids)
            text_ids_mask_list.append(text_mask)
            image_ids_mask_list.append(image_mask)
            label_text_ids_mask_list.append(label_text_mask)
            label_image_ids_mask_list.append(label_image_mask)

        # Padding 处理
        input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels_list, batch_first=True, padding_value=self.IGNORE_INDEX)
        text_ids_mask = pad_sequence(text_ids_mask_list, batch_first=True, padding_value=False)
        image_ids_mask = pad_sequence(image_ids_mask_list, batch_first=True, padding_value=False)
        label_text_ids_mask = pad_sequence(label_text_ids_mask_list, batch_first=True, padding_value=False)
        label_image_ids_mask = pad_sequence(label_image_ids_mask_list, batch_first=True, padding_value=False)

        # 截断处理
        if input_ids.size(1) > self.max_length:
            input_ids = input_ids[:, :self.max_length]
            labels = labels[:, :self.max_length]
            text_ids_mask = text_ids_mask[:, :self.max_length]
            image_ids_mask = image_ids_mask[:, :self.max_length]
            label_text_ids_mask = label_text_ids_mask[:, :self.max_length]
            label_image_ids_mask = label_image_ids_mask[:, :self.max_length]

        return dict(
            input_ids=input_ids,
            label_ids=labels,
            attention_mask=(input_ids != self.tokenizer.pad_token_id),
            text_id_mask=text_ids_mask,
            image_id_mask=image_ids_mask,
            label_text_id_mask=label_text_ids_mask,
            label_image_id_mask=label_image_ids_mask,
        )

In [None]:
dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    collate_fn=DataCollatorForSupervisedDataset(),
    num_workers=8,
    pin_memory=True
)

In [None]:
for i in range(len(train_dataset)):
    try:
        sample = train_dataset[i]
        if len(sample['input_ids'])>2000:
            print(f"Sample index {i}: input_ids length = {len(sample['input_ids'])}")
    except Exception as e:
        print(f"Error at index {i}: {e}")
        # 打印原始数据源信息
        print("Raw data:", train_dataset.data[i])
        continue

In [None]:
start_index = 2010347
for i in range(start_index, len(train_dataset)):
    try:
        sample = train_dataset[i]
        if len(sample['input_ids']) > 2000:  # 检查 input_ids 长度
            print(f"Sample index {i}: input_ids length = {len(sample['input_ids'])}")
    except Exception as e:
        print(f"Error at index {i}: {e}")  # 打印错误信息
        print("Raw data:", train_dataset.data[i])  # 打印原始数据（需确保 data 属性存在）
        continue  # 跳过当前错误样本，继续循环

In [None]:
start_index = 2132053
for i in range(start_index, len(train_dataset)):
    try:
        sample = train_dataset[i]
        if len(sample['input_ids']) > 2000:  # 检查 input_ids 长度
            print(f"Sample index {i}: input_ids length = {len(sample['input_ids'])}")
    except Exception as e:
        print(f"Error at index {i}: {e}")  # 打印错误信息
        print("Raw data:", train_dataset.data[i])  # 打印原始数据（需确保 data 属性存在）
        continue  # 跳过当前错误样本，继续循环

In [None]:
import webdataset as wds
import os

# 输入路径（同时也是目标输出路径）
input_shard = '/mnt/v-haodongli/cot_output_test_train/02166.tar'

# 临时中间文件路径
temp_shard = input_shard + ".tmp"

# 想要删除的 key
target_key_to_remove = "02166_00027"

# 第一步：读取原始 tar，过滤后写入临时文件
with wds.TarWriter(temp_shard) as sink:
    with wds.WebDataset(input_shard) as dataset:
        for sample in dataset:
            if sample["__key__"] == target_key_to_remove:
                print(f"Skipping key: {target_key_to_remove}")
                continue
            sink.write(sample)

# 第二步：将临时文件替换回原文件名（覆盖原文件）
os.replace(temp_shard, input_shard)

print(f"Done. Removed key '{target_key_to_remove}' and overwritten the original file.")

In [None]:
import torch

# 文件路径
file_path = "/scratch/amlt_code/debug_batch_step_327_rank_3.pt"

# 加载文件
data = torch.load(file_path, map_location='cpu')  # 建议先加载到 CPU 上

# 打印所有 key
print("Keys in the saved file:")
print(data.keys())

# 打印每个 key 对应的数据形状或内容
print("\nData details:")
for key, value in data.items():
    if isinstance(value, torch.Tensor):
        print(f"{key}: {value.shape} | dtype: {value.dtype}")
    else:
        print(f"{key}: {type(value)}")
        if isinstance(value, dict):
            for k, v in value.items():
                if isinstance(v, torch.Tensor):
                    print(f"  {k}: {v.shape} | dtype: {v.dtype}")
                else:
                    print(f"  {k}: {type(v)}")
        else:
            print(f"  Value: {value}")

In [None]:
label_ids = data["label_ids"]
label_text_ids_mask = data["label_text_id_mask"]
input_ids = data["input_ids"]
image_ids_mask = data["image_id_mask"]


In [None]:
text = processor.tokenizer.decode(label_ids[label_text_ids_mask], skip_special_tokens=False)

In [None]:
text

In [None]:
import numpy as np
def decode_to_pil(vq_list, vl_gpt, shape=(1, 8, 24, 24)):
    # 将列表转为张量并移动到GPU
    vq_tensor = torch.tensor(vq_list, dtype=torch.int, device="cuda")
    print(vq_tensor.shape)
    # 解码图像数据（假设vl_gpt已加载）
    with torch.no_grad():
        dec = vl_gpt.gen_vision_model.decode_code(vq_tensor, shape=shape)
    
    # 后处理：张量转图像
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(dec[0])

In [None]:
from janus.models.modeling_vlm import MultiModalityCausalLM
model: MultiModalityCausalLM = MultiModalityCausalLM.from_pretrained(
        "deepseek-ai/Janus-Pro-7B",
        trust_remote_code=True).to("cuda")
image = decode_to_pil(input_ids[image_ids_mask].tolist(), model)

In [None]:
import torch

# 文件路径
file_path = "/scratch/amlt_code/debug_batch_step_2_rank_3.pt"

# 加载文件
data = torch.load(file_path, map_location='cpu')  # 建议先加载到 CPU 上

# 打印所有 key
print("Keys in the saved file:")
print(data.keys())

# 打印每个 key 对应的数据形状或内容
print("\nData details:")
for key, value in data.items():
    if isinstance(value, torch.Tensor):
        print(f"{key}: {value.shape} | dtype: {value.dtype}")
    else:
        print(f"{key}: {type(value)}")
        if isinstance(value, dict):
            for k, v in value.items():
                if isinstance(v, torch.Tensor):
                    print(f"  {k}: {v.shape} | dtype: {v.dtype}")
                else:
                    print(f"  {k}: {type(v)}")
        else:
            print(f"  Value: {value}")

In [None]:
label_ids = data["label_ids"]
label_text_ids_mask = data["label_text_id_mask"]
input_ids = data["input_ids"]
image_ids_mask = data["image_id_mask"]


In [None]:
from janus.models.modeling_vlm import MultiModalityCausalLM
model: MultiModalityCausalLM = MultiModalityCausalLM.from_pretrained(
        "deepseek-ai/Janus-Pro-7B",
        trust_remote_code=True).to("cuda")
image = decode_to_pil(input_ids[image_ids_mask].tolist(), model)

In [None]:
from janus.models.processing_vlm import VLChatProcessor
processor: VLChatProcessor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-Pro-7B")
tokenizer = processor.tokenizer

In [None]:
import webdataset as wds
import json
from PIL import Image
import io

input_data_url = "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/{00000..05247}.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D"

bad_indices = []
threshold_bytes = 2000
index = 0

def url_logger(urls):
    for url in urls:
        print(f"Processing tar file: {url}")
        yield url

dataset = wds.WebDataset(input_data_url).map_dict(jpg=lambda x: x if x else None, png=lambda x: x if x else None)

for sample in dataset:
    print(sample.keys())    
    image_data = None
    if "jpg" in sample:
        image_data = sample["jpg"]
    elif "png" in sample:
        image_data = sample["png"]
    else:
        index += 1
        continue

    try:
        # 获取 caption 
        caption = "No caption"
        if "json" in sample:
            meta = json.loads(sample["json"])
            caption = meta.get("caption", "No caption")

        file_size = len(image_data)
        img = Image.open(io.BytesIO(image_data))
        width, height = img.size

        print(f"[{index}] Caption: {caption}")
        print(f"File size: {file_size} bytes | Image size: {width}x{height}")

        if file_size < threshold_bytes:
            bad_indices.append(index)
            print("⚠️ Warning: This image may be corrupted (too small).\n")
    except Exception as e:
        print(f"[{index}] ❌ Failed to open image: {e}\n")
        bad_indices.append(index)

    index += 1

print("Bad indices:", bad_indices)

In [None]:
import io
from PIL import Image
import webdataset as wds
from tqdm import tqdm
import os
from braceexpand import braceexpand

# 原始 URL 模板（含 braceexpand 语法）
input_data_url_template = "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/{00000..05247}.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D"

# 使用 braceexpand 展开为完整的 URL 列表 
input_data_urls = list(braceexpand(input_data_url_template))

threshold_bytes = 4000  # 图像大小阈值  

# 用于保存异常数据的索引
bad_samples = []  # 格式: (tar_file_name, local_index_in_tar)

# 遍历所有 tar 文件
for url in input_data_urls:
    print(f"🔍 Processing {url}...")

    try:
        # 创建 WebDataset 并解包
        dataset = wds.WebDataset(url)

        # 使用 enumerate 加上计数器来记录当前是第几个样本
        for local_index, sample in enumerate(tqdm(dataset, desc=f"Scanning {os.path.basename(url)}")):
            # 提取图像字段（jpg 或 png）
            image_data = None
            if "jpg" in sample:
                image_data = sample["jpg"]
            elif "png" in sample:
                image_data = sample["png"]

            # 提取文本字段（通常是 txt）
            text_data = sample.get("txt", None)

            # 如果没有图像或文本，跳过
            if image_data is None or text_data is None:
                continue

            file_size = len(image_data)

            # 判断是否小于阈值
            if file_size < threshold_bytes:
                bad_samples.append((url, local_index))
                
                # 🔥 实时打印异常信息
                print(f"\n⚠️ Found small image:")
                print(f"   Tar File: {os.path.basename(url)}")
                print(f"   Sample Index: {local_index}")
                print(f"   Image Size: {file_size} bytes")

    except Exception as e:
        print(f"❌ Error processing {url}: {e}")
        continue

# 最后输出所有异常样本的位置（可选）
print("\n🔚 Bad samples summary:")
for tar_file, idx in bad_samples:
    print(f"[{os.path.basename(tar_file)}] Sample index: {idx}")

In [7]:
import os
import io
import json
import base64
from PIL import Image
import webdataset as wds
from urllib.parse import urlparse

# 假设你已经有了 bad_samples
bad_samples = [
    (
        "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/00001.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D",
        493
    ),
    (
        "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/00001.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D",
        612
    )
]

# 本地目录    
local_base_path = "/home/v-haodongli/mnt/v-haodongli-container/cot_output_test"

# 先测试 00001.tar 这个文件
test_tar_name = "00001.tar"

# 过滤出你要测试的样本
test_samples = []
for url, index in bad_samples:
    filename = os.path.basename(urlparse(url).path)
    if filename == test_tar_name:
        test_samples.append(index)

if not test_samples:
    print(f"No bad samples found for {test_tar_name}")
    exit()

local_tar_path = os.path.join(local_base_path, test_tar_name)
if not os.path.exists(local_tar_path):
    print(f"Local file not found: {local_tar_path}")
    exit()

print(f"Processing local file: {local_tar_path}")

# 创建 WebDataset
dataset = wds.WebDataset(local_tar_path)

# 遍历数据
for local_index, sample in enumerate(dataset):
    if local_index in test_samples:
        print(f"\n[Found small image] Index: {local_index}")
        print(f"[DEBUG] Sample keys: {sample.keys()}")

        # 提取 JSON 数据
        json_data = sample.get("json")
        if not json_data:
            print("⚠️ No 'json' field found.")
            continue

        try:
            data = json.loads(json_data)
        except Exception as e:
            print("Failed to parse JSON:", str(e))
            continue

        # 假设图像字段是 base64 编码的字符串，存在 "image" 或 "img" 键中
        image_b64 = data.get("image") or data.get("img")
        if not image_b64:
            print("⚠️ No image data found in JSON.")
            continue

        # 解码 base64 图像
        try:
            image_data = base64.b64decode(image_b64)
            img = Image.open(io.BytesIO(image_data))
            print("Image size:", img.size)
            print("Image format:", img.format)
            print("Showing the image...")
            img.show()
        except Exception as e:
            print("Failed to decode or open image:", str(e))

        # 提取 prompt（假设在 JSON 的 "caption" 或 "text" 字段中）
        prompt = data.get("caption") or data.get("text") or data.get("prompt")
        if prompt:
            print("Prompt (Text):", prompt)
        else:
            print("⚠️ No prompt found in JSON.")

Processing local file: /home/v-haodongli/mnt/v-haodongli-container/cot_output_test/00001.tar

[Found small image] Index: 493
[DEBUG] Sample keys: dict_keys(['__key__', '__url__', 'json', '__local_path__'])
⚠️ No image data found in JSON.

[Found small image] Index: 612
[DEBUG] Sample keys: dict_keys(['__key__', '__url__', 'json', '__local_path__'])
⚠️ No image data found in JSON.


In [8]:
import os
import io
import json
import base64
from PIL import Image
import webdataset as wds
from urllib.parse import urlparse

# 假设你已经有了 bad_samples
bad_samples = [
    (
        "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/00001.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D",
        493
    ),
    (
        "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/00001.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D",
        612
    )
]

# 本地目录    
local_base_path = "/home/v-haodongli/mnt/v-haodongli-container/cot_output_test"

# 先测试 00001.tar 这个文件
test_tar_name = "00001.tar"

# 过滤出你要测试的样本
test_samples = []
for url, index in bad_samples:
    filename = os.path.basename(urlparse(url).path)
    if filename == test_tar_name:
        test_samples.append(index)

if not test_samples:
    print(f"No bad samples found for {test_tar_name}")
    exit()

local_tar_path = os.path.join(local_base_path, test_tar_name)
if not os.path.exists(local_tar_path):
    print(f"Local file not found: {local_tar_path}")
    exit()

print(f"Processing local file: {local_tar_path}")

# 创建 WebDataset
dataset = wds.WebDataset(local_tar_path)

# 遍历数据
for local_index, sample in enumerate(dataset):
    if local_index in test_samples:
        print(f"\n[Found small image] Index: {local_index}")
        print(f"[DEBUG] Sample keys: {sample.keys()}")

        # 提取 JSON 数据
        json_data = sample.get("json")
        if not json_data:
            print("⚠️ No 'json' field found.")
            continue

        try:
            data = json.loads(json_data)
            print("\n📄 Full JSON content:")
            print(json.dumps(data, indent=2))  # 打印完整 JSON 内容便于分析
        except Exception as e:
            print("Failed to parse JSON:", str(e))
            continue

Processing local file: /home/v-haodongli/mnt/v-haodongli-container/cot_output_test/00001.tar

[Found small image] Index: 493
[DEBUG] Sample keys: dict_keys(['__key__', '__url__', 'json', '__local_path__'])

📄 Full JSON content:
{
  "caption": "NOTE-WORTHY CHRISTMAS - Enesco miniature ornament - beaver playing xylophone",
  "cot": "The image shows a miniature Christmas ornament by Enesco. The ornament features a beaver playing a xylophone. The beaver is brown with a white belly, and the xylophone has multiple keys. The beaver is positioned in the center of the ornament, actively playing the xylophone. The ornament is a small, detailed figurine, likely made of ceramic or a similar material, with a glossy finish. The background is white, emphasizing the ornament. The style is festive and whimsical, typical of holiday decorations.",
  "aspect_ratio": "1:1",
  "img_index": [
    922,
    1653,
    1653,
    307,
    307,
    1085,
    8053,
    1653,
    4514,
    1653,
    1653,
    1653,


In [2]:
import io
from PIL import Image
import webdataset as wds
from tqdm import tqdm
import os
from braceexpand import braceexpand

# 原始 URL 模板（含 braceexpand 语法）
input_data_url_template = "https://facevcstandard.blob.core.windows.net/doch/data/laion2B-en-aesthetic/{00000..05247}.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D"

# 使用 braceexpand 展开为完整的 URL 列表 
input_data_urls = list(braceexpand(input_data_url_template))

threshold_bytes = 4000  # 图像大小阈值  

# 用于保存异常数据的索引
bad_samples = []  # 格式: (tar_file_name, local_index_in_tar)

# 遍历所有 tar 文件
for url in input_data_urls:

    try:
        # 创建 WebDataset 并解包
        dataset = wds.WebDataset(url)

        # 使用 enumerate 加上计数器来记录当前是第几个样本
        for local_index, sample in enumerate(tqdm(dataset, desc=f"Scanning {os.path.basename(url)}")):
            # 提取图像字段（jpg 或 png）
            image_data = None
            if "jpg" in sample:
                image_data = sample["jpg"]
            elif "png" in sample:
                image_data = sample["png"]

            # 提取文本字段（通常是 txt）
            text_data = sample.get("txt", None)

            # 如果没有图像或文本，跳过
            if image_data is None or text_data is None:
                continue

            file_size = len(image_data)

            # 判断是否小于阈值
            if file_size < threshold_bytes:
                bad_samples.append((url, local_index))
                
                # 🔥 实时

    except Exception as e:
        print(f"❌ Error processing {url}: {e}")
        continue

# 最后输出所有异常样本的位置（可选）
for tar_file, idx in bad_samples:
    print(f"[{os.path.basename(tar_file)}] Sample index: {idx}")

Scanning 00000.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D: 0it [00:00, ?it/s]

Scanning 00000.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D: 4688it [00:16, 292.45it/s]
Scanning 00001.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoXHZptZzaCg29hwFFdwrAlA%3D: 4557it [00:12, 370.11it/s]
Scanning 00002.tar?sv=2023-01-03&st=2025-06-05T05%3A41%3A53Z&se=2025-06-12T05%3A41%3A00Z&skoid=1ff6eda0-bcb1-4b77-9ff2-64bae2665820&sktid=72f988bf-86f1-41af-91ab-2d7cd011db47&skt=2025-06-05T05%3A41%3A53Z&ske=2025-06-12T05%3A41%3A00Z&sks=b&skv=2023-01-03&sr=c&sp=racwdxltf&sig=C31qRQqAV3QeH26BHrCzoX

KeyboardInterrupt: 