In [1]:
from data.imagenet_c import get_imagenet_c_loader
from torchvision import transforms
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from torch.utils.data import DataLoader

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

MLLM_transform = build_transform(224)

In [2]:
from data.dataloader import TTALoRADataset
dataset =  TTALoRADataset(dataset_folder=r'D:\code\dataset\mini-imagenet-c', download_dataset=False, target_model='MLLM', transform=MLLM_transform)


从D:\code\dataset\mini-imagenet-c\brightness加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\contrast加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\defocus_blur加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\elastic_transform加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\fog加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\frost加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\gaussian_noise加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\glass_blur加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\impulse_noise加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\jpeg_compression加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\motion_blur加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\pixelate加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\shot_noise加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\snow加载了 50000 张图片
从D:\code\dataset\mini-imagenet-c\zoom_blur加载了 50000 张图片


In [9]:
import torch
import numpy as np
import random

def set_seed(seed=7600):
    random.seed(seed)     
    np.random.seed(seed)     
    torch.manual_seed(seed)   
    torch.cuda.manual_seed(seed)      
    torch.cuda.manual_seed_all(seed)   
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def worker_init_fn(worker_id):
    seed = 7600
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

set_seed()

In [None]:
# model_name  = 'OpenGVLab/InternVL3_5-1B'
model_name  = 'OpenGVLab/InternVL3-1B'

In [10]:
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, worker_init_fn=worker_init_fn)

In [11]:
# test dataloader
for i, (images, labels) in enumerate(dataloader):
    print(f"Batch {i+1}:")
    print(f"  Images shape: {images.shape}")
    print(f"  Images dtype: {images.dtype}")  
    print(f"  Labels: {labels[:3]}")  # 显示前3个标签 [corruption_type，class_index]
    print(f"  Number of samples in batch: {len(labels)}")
    break


Batch 1:
  Images shape: torch.Size([32, 3, 224, 224])
  Images dtype: torch.float32
  Labels: [tensor([12,  1, 13, 14,  0,  9,  9,  5,  8, 11,  6, 13,  1,  8,  2, 10,  8, 11,
        13,  8, 11, 12, 11,  2, 12,  3,  0,  2,  7, 13,  1, 10]), tensor([ 69, 417, 671, 236, 911, 703, 118, 236, 434, 148,  51, 559, 630, 338,
        949, 790, 267, 793, 120, 886, 192, 100, 460, 657, 339, 218, 802, 863,
        328, 939, 119, 786])]
  Number of samples in batch: 2


In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
# path = "OpenGVLab/InternVL3-1B"
path = model_name
model = AutoModel.from_pretrained(
    path,
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval().cuda()

In [13]:
import json
class_index = dataset.class_index
with open(class_index, 'r') as f:
    classes = json.load(f)

# 将类别信息转换为列表格式
class_names = []
for key, value in classes.items():
    class_names.append(value[1])  # 只取类别名称，如 'tench'

print(class_names)

['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'robin', 'bulbul', 'jay', 'magpie', 'chickadee', 'water ouzel', 'kite', 'bald eagle', 'vulture', 'great grey owl', 'European fire salamander', 'common newt', 'eft', 'spotted salamander', 'axolotl', 'bullfrog', 'tree frog', 'tailed frog', 'loggerhead', 'leatherback turtle', 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'common iguana', 'American chameleon', 'whiptail', 'agama', 'frilled lizard', 'alligator lizard', 'Gila monster', 'green lizard', 'African chameleon', 'Komodo dragon', 'African crocodile', 'American alligator', 'triceratops', 'thunder snake', 'ringneck snake', 'hognose snake', 'green snake', 'king snake', 'garter snake', 'water snake', 'vine snake', 'night snake', 'boa constrictor', 'rock python', 'Indian cobra', 'green mamba', 'sea snake', 'horned viper', 'diamondback', 

In [14]:
corruption_types = []
corruption_types_index_path = dataset.corruption_index
with open(corruption_types_index_path, 'r') as f:
    corruption_types = json.load(f)
print(corruption_types)

{'brightness': 0, 'contrast': 1, 'defocus_blur': 2, 'elastic_transform': 3, 'fog': 4, 'frost': 5, 'gaussian_noise': 6, 'glass_blur': 7, 'impulse_noise': 8, 'jpeg_compression': 9, 'motion_blur': 10, 'pixelate': 11, 'shot_noise': 12, 'snow': 13, 'zoom_blur': 14}


In [None]:
# 设置参与测试的数据数量
n = 5000

import math
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer
import requests
from tqdm import tqdm

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # 计算现有图像的宽高比
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # 找到最接近目标的宽高比
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # 计算目标宽度和高度
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # 调整图像大小
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # 分割图像
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    if isinstance(image_file, str):
        if image_file.startswith('http'):
            image = Image.open(requests.get(image_file, stream=True).raw).convert('RGB')
        else:
            image = Image.open(image_file).convert('RGB')
    else:
        image = image_file.convert('RGB')
    
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

# 加载tokenizer并设置pad_token_id
# tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL3-1B", trust_remote_code=True, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
# 设置pad_token_id为eos_token_id以避免警告
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# 确保模型的所有参数都在GPU上
model = model.cuda()

# 构建包含所有类别名称的提示词
def create_class_prompt(class_names):
    """创建包含所有类别名称的提示词"""
    # 将类别名称列表转换为字符串格式
    class_list_str = ", ".join(class_names)
    prompt = f"<image>\n请从以下1000个ImageNet类别中选择一个最符合图片内容的词汇：[{class_list_str}]。请只回答一个词汇，确保词汇在提供的词汇中，不要添加任何解释。"
    return prompt

# 批量预测函数 - 一次性处理n张图片
def batch_predict_n(images_batch, class_names, question_template=None):
    """
    对一个batch的图像进行批量预测，从class_names中选择最相关的类别
    Args:
        images_batch: torch.Tensor, shape为[batch_size, 3, 224, 224]的图像张量
        class_names: list, 包含1000个ImageNet类别名称的列表
        question_template: str, 问题模板（可选）
    Returns:
        list: 每张图像的预测结果
    """
    results = []
    # 设置生成配置，包括pad_token_id
    generation_config = dict(
        max_new_tokens=50, 
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id  # 明确设置pad_token_id
    )
    
    # 创建包含所有类别的提示词
    if question_template is None:
        question = create_class_prompt(class_names)
    else:
        class_list_str = ", ".join(class_names)
        question = question_template.format(class_list=class_list_str)
    
    # 准备所有图像的pixel_values
    all_pixel_values = []
    
    print(f"准备处理 {images_batch.shape[0]} 张图像...")
    print(f"将从 {len(class_names)} 个ImageNet类别中进行选择")
    
    # 预处理所有图像
    for i in tqdm(range(images_batch.shape[0]), desc="预处理图像"):
        # 获取单张图像
        single_image = images_batch[i]  # shape: [3, 224, 224]
        
        # 反归一化图像以转换回PIL格式
        # 使用ImageNet的均值和标准差进行反归一化
        mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
        
        # 反归一化
        denorm_image = single_image * std + mean
        denorm_image = torch.clamp(denorm_image, 0, 1)
        
        # 转换为PIL图像
        pil_image = T.ToPILImage()(denorm_image)
        
        # 使用InternVL的预处理方法
        pixel_values = load_image(pil_image, max_num=12).to(torch.bfloat16).cuda()
        all_pixel_values.append(pixel_values)
    
    print("开始批量推理...")
    
    # 批量推理 - 一次性处理所有图像
    with torch.no_grad():
        try:
            # 由于InternVL模型的限制，我们需要分批处理，但可以并行处理多个
            batch_size = min(n, len(all_pixel_values))  # 最多n张图片一批
            
            for batch_start in range(0, len(all_pixel_values), batch_size):
                batch_end = min(batch_start + batch_size, len(all_pixel_values))
                current_batch_size = batch_end - batch_start
                
                # print(f"处理第 {batch_start//batch_size + 1} 批，包含 {current_batch_size} 张图像")
                
                # 并行处理当前批次的所有图像
                batch_results = []
                for i in tqdm(range(batch_start, batch_end), desc=f"推理第 {batch_start//batch_size + 1} 批"):
                    try:
                        response = model.chat(tokenizer, all_pixel_values[i], question, generation_config)
                        # 清理响应，只保留预测的类别名称
                        cleaned_response = response.strip()
                        batch_results.append(cleaned_response)
                        # print(f"图像 {i+1}: {cleaned_response}")
                    except Exception as e:
                        print(f"预测图像 {i+1} 时出错: {e}")
                        batch_results.append(f"预测失败: {str(e)}")
                
                results.extend(batch_results)
                
        except Exception as e:
            print(f"批量预测时出错: {e}")
            # 如果批量预测失败，回退到逐个预测
            print("回退到逐个预测模式...")
            results = []
            for i, pixel_values in enumerate(tqdm(all_pixel_values, desc="逐个预测")):
                try:
                    response = model.chat(tokenizer, pixel_values, question, generation_config)
                    cleaned_response = response.strip()
                    results.append(cleaned_response)
                    # print(f"图像 {i+1}: {cleaned_response}")
                except Exception as e:
                    print(f"预测图像 {i+1} 时出错: {e}")
                    results.append(f"预测失败: {str(e)}")
    
    return results

# 从dataloader获取多个batch并进行预测
print(f"开始批量预测{n}张图片...")
all_images = []
all_labels = []

# 收集n张图片
for i, (images, labels) in enumerate(tqdm(dataloader, desc="收集图像数据")):
    all_images.append(images)
    all_labels.append(labels)
    
    # 计算已收集的图片总数
    total_images = sum(img.shape[0] for img in all_images)
    # print(f"已收集 {total_images} 张图片")
    
    if total_images >= n:
        break

# 合并所有图片
if all_images:
    combined_images = torch.cat(all_images, dim=0)[:n]  # 取前n张
    combined_labels = [torch.cat([label[j] for label in all_labels], dim=0)[:n] for j in range(len(all_labels[0]))]
    
    print(f"最终处理 {combined_images.shape[0]} 张图像")
    print(f"图像形状: {combined_images.shape}")
    print(f"标签信息: corruption_types={combined_labels[0][:5]}, class_indices={combined_labels[1][:5]}")
    
    # 进行批量预测，使用class_names列表
    predictions = batch_predict_n(combined_images, class_names)
    
    print(f"\n预测完成，共处理了 {len(predictions)} 张图像")
    
    # 显示预测结果和真实标签的对比
    print("\n预测结果对比:")
    correct_count = 0
    total_count = 0
    
    for i in range(min(10, len(predictions))):  # 显示前10个结果
        true_class_idx = combined_labels[1][i].item()
        true_class_name = class_names[true_class_idx] if true_class_idx < len(class_names) else "未知类别"
        predicted_class = predictions[i]
        corruption_type = combined_labels[0][i].item()
        
        # 去除空格后比较字符串
        true_class_clean = true_class_name.replace(" ", "").lower()
        predicted_class_clean = predicted_class.replace(" ", "").lower()
        is_correct = true_class_clean == predicted_class_clean
        
        if is_correct:
            correct_count += 1
        total_count += 1
        
        print(f"图像 {i+1}:")
        print(f"  真实类别: {true_class_name} (索引: {true_class_idx})")
        print(f"  预测类别: {predicted_class}")
        print(f"  损坏类型: {corruption_type}")
        print(f"  预测正确: {'是' if is_correct else '否'}")
        print()
    
    # 计算所有n张图片的准确率
    total_correct = 0
    for i in range(len(predictions)):
        true_class_idx = combined_labels[1][i].item()
        true_class_name = class_names[true_class_idx] if true_class_idx < len(class_names) else "未知类别"
        predicted_class = predictions[i]
        
        # 去除空格后比较字符串
        true_class_clean = true_class_name.replace(" ", "").lower()
        predicted_class_clean = predicted_class.replace(" ", "").lower()
        
        if true_class_clean == predicted_class_clean:
            total_correct += 1
    
    accuracy = total_correct / len(predictions) * 100
    print(f"总体准确率: {total_correct}/{len(predictions)} = {accuracy:.2f}%")
        
else:
    print("没有找到足够的图像数据")



开始批量预测5000张图片...


收集图像数据:   1%|          | 156/23438 [25:37<63:43:47,  9.85s/it]


最终处理 5000 张图像
图像形状: torch.Size([5000, 3, 224, 224])
标签信息: corruption_types=tensor([6, 1, 8, 4, 8]), class_indices=tensor([620, 641, 306,  50, 430])
准备处理 5000 张图像...
将从 1000 个ImageNet类别中进行选择


预处理图像: 100%|██████████| 5000/5000 [00:12<00:00, 398.20it/s]


开始批量推理...


推理第 1 批: 100%|██████████| 5000/5000 [32:49<00:00,  2.54it/s]


预测完成，共处理了 5000 张图像

预测结果对比:
图像 1:
  真实类别: laptop (索引: 620)
  预测类别: box turtle
  损坏类型: 6
  预测正确: 否

图像 2:
  真实类别: maraca (索引: 641)
  预测类别: axolotl
  损坏类型: 1
  预测正确: 否

图像 3:
  真实类别: rhinoceros beetle (索引: 306)
  预测类别: hammerhead
  损坏类型: 8
  预测正确: 否

图像 4:
  真实类别: American alligator (索引: 50)
  预测类别: hammerhead
  损坏类型: 4
  预测正确: 否

图像 5:
  真实类别: basketball (索引: 430)
  预测类别: box turtle
  损坏类型: 8
  预测正确: 否

图像 6:
  真实类别: ptarmigan (索引: 81)
  预测类别: sparrow
  损坏类型: 4
  预测正确: 否

图像 7:
  真实类别: suit (索引: 834)
  预测类别: house finch
  损坏类型: 8
  预测正确: 否

图像 8:
  真实类别: langur (索引: 374)
  预测类别: axolotl
  损坏类型: 7
  预测正确: 否

图像 9:
  真实类别: cellular telephone (索引: 487)
  预测类别: box turtle
  损坏类型: 9
  预测正确: 否

图像 10:
  真实类别: cowboy hat (索引: 515)
  预测类别: house finch
  损坏类型: 6
  预测正确: 否

总体准确率: 59/5000 = 1.18%





In [None]:
# 设置参与测试的数据数量
n = 5000

import math
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer
import requests
import json
from tqdm import tqdm

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # 计算现有图像的宽高比
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # 找到最接近目标的宽高比
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # 计算目标宽度和高度
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # 调整图像大小
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # 分割图像
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    if isinstance(image_file, str):
        if image_file.startswith('http'):
            image = Image.open(requests.get(image_file, stream=True).raw).convert('RGB')
        else:
            image = Image.open(image_file).convert('RGB')
    else:
        image = image_file.convert('RGB')
    
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

# 加载tokenizer并设置pad_token_id
# tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL3-1B", trust_remote_code=True, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
# 设置pad_token_id为eos_token_id以避免警告
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# 确保模型的所有参数都在GPU上
model = model.cuda()

# 加载corruption类型映射
with open('data/corruption_index.json', 'r') as f:
    corruption_index_data = json.load(f)

# 获取所有corruption类型名称
corruption_names = list(corruption_index_data.keys())
print(f"可用的corruption类型: {corruption_names}")

# 构建包含所有corruption类型的提示词
def create_corruption_prompt(corruption_names):
    """创建包含所有corruption类型的提示词"""
    # 将corruption类型列表转换为字符串格式
    corruption_list_str = ", ".join(corruption_names)
    prompt = f"<image>\n请从以下corruption类型中选择一个最符合图片污染的类型：[{corruption_list_str}]。请只回答一个corruption类型名称，确保名称在提供的列表中，不要添加任何解释。"
    return prompt

# 批量预测函数 - 一次性处理n张图片
def batch_predict_corruption_n(images_batch, corruption_names, question_template=None):
    """
    对一个batch的图像进行批量预测，从corruption_names中选择最相关的corruption类型
    Args:
        images_batch: torch.Tensor, shape为[batch_size, 3, 224, 224]的图像张量
        corruption_names: list, 包含所有corruption类型名称的列表
        question_template: str, 问题模板（可选）
    Returns:
        list: 每张图像的预测结果
    """
    results = []
    # 设置生成配置，包括pad_token_id
    generation_config = dict(
        max_new_tokens=50, 
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id  # 明确设置pad_token_id
    )
    
    # 创建包含所有corruption类型的提示词
    if question_template is None:
        question = create_corruption_prompt(corruption_names)
    else:
        corruption_list_str = ", ".join(corruption_names)
        question = question_template.format(corruption_list=corruption_list_str)
    
    # 准备所有图像的pixel_values
    all_pixel_values = []
    
    print(f"准备处理 {images_batch.shape[0]} 张图像...")
    print(f"将从 {len(corruption_names)} 个corruption类型中进行选择")
    
    # 预处理所有图像
    for i in tqdm(range(images_batch.shape[0]), desc="预处理图像"):
        # 获取单张图像
        single_image = images_batch[i]  # shape: [3, 224, 224]
        
        # 反归一化图像以转换回PIL格式
        # 使用ImageNet的均值和标准差进行反归一化
        mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
        
        # 反归一化
        denorm_image = single_image * std + mean
        denorm_image = torch.clamp(denorm_image, 0, 1)
        
        # 转换为PIL图像
        pil_image = T.ToPILImage()(denorm_image)
        
        # 使用InternVL的预处理方法
        pixel_values = load_image(pil_image, max_num=12).to(torch.bfloat16).cuda()
        all_pixel_values.append(pixel_values)
    
    print("开始批量推理...")
    
    # 批量推理 - 一次性处理所有图像
    with torch.no_grad():
        try:
            # 由于InternVL模型的限制，我们需要分批处理，但可以并行处理多个
            batch_size = min(n, len(all_pixel_values))  # 最多n张图片一批
            
            for batch_start in range(0, len(all_pixel_values), batch_size):
                batch_end = min(batch_start + batch_size, len(all_pixel_values))
                current_batch_size = batch_end - batch_start
                
                print(f"处理第 {batch_start//batch_size + 1} 批，包含 {current_batch_size} 张图像")
                
                # 并行处理当前批次的所有图像
                batch_results = []
                for i in tqdm(range(batch_start, batch_end), desc=f"推理第 {batch_start//batch_size + 1} 批"):
                    try:
                        response = model.chat(tokenizer, all_pixel_values[i], question, generation_config)
                        # 清理响应，只保留预测的corruption类型名称
                        cleaned_response = response.strip()
                        batch_results.append(cleaned_response)
                    except Exception as e:
                        batch_results.append(f"预测失败: {str(e)}")
                
                results.extend(batch_results)
                
        except Exception as e:
            print(f"批量预测时出错: {e}")
            # 如果批量预测失败，回退到逐个预测
            print("回退到逐个预测模式...")
            results = []
            for i, pixel_values in enumerate(tqdm(all_pixel_values, desc="逐个预测")):
                try:
                    response = model.chat(tokenizer, pixel_values, question, generation_config)
                    cleaned_response = response.strip()
                    results.append(cleaned_response)
                except Exception as e:
                    results.append(f"预测失败: {str(e)}")
    
    return results

# 从dataloader获取多个batch并进行预测
print(f"开始批量预测{n}张图片的corruption类型...")
all_images = []
all_labels = []

# 收集n张图片
total_images = 0
for i, (images, labels) in enumerate(tqdm(dataloader, desc="收集图像数据")):
    all_images.append(images)
    all_labels.append(labels)
    
    # 计算已收集的图片总数
    total_images += images.shape[0]
    
    if total_images >= n:
        break

# 合并所有图片
if all_images:
    combined_images = torch.cat(all_images, dim=0)[:n]  # 取前n张
    combined_labels = [torch.cat([label[j] for label in all_labels], dim=0)[:n] for j in range(len(all_labels[0]))]
    
    print(f"最终处理 {combined_images.shape[0]} 张图像")
    print(f"图像形状: {combined_images.shape}")
    print(f"标签信息: corruption_types={combined_labels[0][:5]}, class_indices={combined_labels[1][:5]}")
    
    # 进行批量预测，使用corruption_names列表
    predictions = batch_predict_corruption_n(combined_images, corruption_names)
    
    print(f"\n预测完成，共处理了 {len(predictions)} 张图像")
    
    # 显示预测结果和真实标签的对比
    print("\nCorruption类型预测结果对比:")
    correct_count = 0
    total_count = 0
    
    # 创建corruption索引到名称的映射
    corruption_idx_to_name = {v: k for k, v in corruption_index_data.items()}
    
    for i in range(min(10, len(predictions))):  # 显示前10个结果
        true_corruption_idx = combined_labels[0][i].item()
        true_corruption_name = corruption_idx_to_name.get(true_corruption_idx, "未知corruption类型")
        predicted_corruption = predictions[i]
        class_idx = combined_labels[1][i].item()
        
        # 去除空格后比较字符串
        true_corruption_clean = true_corruption_name.replace(" ", "").lower()
        predicted_corruption_clean = predicted_corruption.replace(" ", "").lower()
        is_correct = true_corruption_clean == predicted_corruption_clean
        
        if is_correct:
            correct_count += 1
        total_count += 1
        
        print(f"图像 {i+1}:")
        print(f"  真实corruption类型: {true_corruption_name} (索引: {true_corruption_idx})")
        print(f"  预测corruption类型: {predicted_corruption}")
        print(f"  图像类别索引: {class_idx}")
        print(f"  预测正确: {'是' if is_correct else '否'}")
        print()
    
    # 计算所有n张图片的准确率
    total_correct = 0
    for i in range(len(predictions)):
        true_corruption_idx = combined_labels[0][i].item()
        true_corruption_name = corruption_idx_to_name.get(true_corruption_idx, "未知corruption类型")
        predicted_corruption = predictions[i]
        
        # 去除空格后比较字符串
        true_corruption_clean = true_corruption_name.replace(" ", "").lower()
        predicted_corruption_clean = predicted_corruption.replace(" ", "").lower()
        
        if true_corruption_clean == predicted_corruption_clean:
            total_correct += 1
    
    accuracy = total_correct / len(predictions) * 100
    print(f"Corruption类型预测总体准确率: {total_correct}/{len(predictions)} = {accuracy:.2f}%")
        
else:
    print("没有找到足够的图像数据")


可用的corruption类型: ['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 'frost', 'gaussian_noise', 'glass_blur', 'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate', 'shot_noise', 'snow', 'zoom_blur']
开始批量预测5000张图片的corruption类型...


收集图像数据:   1%|          | 156/23438 [24:13<60:16:20,  9.32s/it]


最终处理 5000 张图像
图像形状: torch.Size([5000, 3, 224, 224])
标签信息: corruption_types=tensor([14,  6,  9,  1,  2]), class_indices=tensor([988, 755, 515, 806,  77])
准备处理 5000 张图像...
将从 15 个corruption类型中进行选择


预处理图像: 100%|██████████| 5000/5000 [00:12<00:00, 392.34it/s]


开始批量推理...
处理第 1 批，包含 5000 张图像


推理第 1 批: 100%|██████████| 5000/5000 [06:14<00:00, 13.36it/s]


预测完成，共处理了 5000 张图像

Corruption类型预测结果对比:
图像 1:
  真实corruption类型: zoom_blur (索引: 14)
  预测corruption类型: defocus_blur
  图像类别索引: 988
  预测正确: 否

图像 2:
  真实corruption类型: gaussian_noise (索引: 6)
  预测corruption类型: fog
  图像类别索引: 755
  预测正确: 否

图像 3:
  真实corruption类型: jpeg_compression (索引: 9)
  预测corruption类型: fog
  图像类别索引: 515
  预测正确: 否

图像 4:
  真实corruption类型: contrast (索引: 1)
  预测corruption类型: frost
  图像类别索引: 806
  预测正确: 否

图像 5:
  真实corruption类型: defocus_blur (索引: 2)
  预测corruption类型: frost
  图像类别索引: 77
  预测正确: 否

图像 6:
  真实corruption类型: zoom_blur (索引: 14)
  预测corruption类型: motion_blur
  图像类别索引: 234
  预测正确: 否

图像 7:
  真实corruption类型: shot_noise (索引: 12)
  预测corruption类型: fog
  图像类别索引: 809
  预测正确: 否

图像 8:
  真实corruption类型: shot_noise (索引: 12)
  预测corruption类型: fog
  图像类别索引: 901
  预测正确: 否

图像 9:
  真实corruption类型: contrast (索引: 1)
  预测corruption类型: fog
  图像类别索引: 31
  预测正确: 否

图像 10:
  真实corruption类型: impulse_noise (索引: 8)
  预测corruption类型: fog
  图像类别索引: 546
  预测正确: 否

Corruption类型预测总体准确率: 1072/


