In [1]:
# 第一格：环境设置和依赖安装
import os
import sys
import json
import random
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# 设置随机种子
def set_seed(seed=42):
    """Seed every RNG used in this notebook so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), then forces cuDNN into deterministic, non-benchmark mode.
    ``torch`` is resolved at call time, so this may be defined before the
    torch import as long as it is called afterwards.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed,
                       torch.cuda.manual_seed_all):
        apply_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# 检查GPU
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA是否可用: {torch.cuda.is_available()}")
# Prefer the first visible CUDA GPU; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"GPU设备: {torch.cuda.get_device_name(0)}")

set_seed(42)

PyTorch版本: 2.5.1+cu121
CUDA是否可用: True
GPU设备: NVIDIA GeForce RTX 3050 Laptop GPU


In [2]:
# 第二格：数据预处理类
from transformers import BertTokenizer, BertModel
import torchvision.transforms as T

class TextPreprocessor:
    """Read a text file, normalise it and encode it with a BERT tokenizer.

    ``preprocess`` takes a *file path* (not raw text) and returns a dict with
    1-D ``input_ids`` and ``attention_mask`` tensors of length ``max_length``.
    """

    def __init__(self, bert_path='./bert_model/', max_length=128):
        self.tokenizer = BertTokenizer.from_pretrained(bert_path)
        self.max_length = max_length

    def preprocess(self, text_path):
        """Tokenize the contents of ``text_path``.

        Unreadable files are reported and treated as empty text.
        """
        import re

        # Decode as GBK; errors='ignore' drops undecodable bytes, so only I/O
        # problems (or an invalid path value) can fail here. A bare `except`
        # would also swallow KeyboardInterrupt/SystemExit.
        try:
            with open(text_path, 'r', encoding='gbk', errors='ignore') as f:
                text = f.read().strip()
        except (OSError, ValueError):
            text = ""
            print(f"无法读取文本文件: {text_path}")

        # Lower-case, drop punctuation except '#'/'@', collapse whitespace.
        text = text.lower().strip()
        text = re.sub(r'[^\w\s#@]', '', text)
        text = re.sub(r'\s+', ' ', text)

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' yields shape (1, max_length); flatten to 1-D so
        # the DataLoader can stack samples into (batch, max_length).
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten()
        }

class ImagePreprocessor:
    """Turn image files into normalised tensors for the ResNet encoder.

    Keeps a deterministic pipeline for evaluation and a randomised one for
    training. Both normalise with the ImageNet channel mean/std.
    """

    def __init__(self, img_size=224):
        self.img_size = img_size
        self.transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])

        # Training-time augmentation. Resize slightly larger than the crop so
        # RandomCrop has room to move. Scaling by 256/224 reproduces the
        # original 256 -> 224 setup and keeps the crop valid when
        # img_size > 256; a fixed 256 resize made RandomCrop fail there.
        resize_size = int(round(img_size * 256 / 224))
        self.augmentation = T.Compose([
            T.Resize((resize_size, resize_size)),
            T.RandomCrop((img_size, img_size)),
            T.RandomHorizontalFlip(p=0.5),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])

    def preprocess(self, image_path, augment=False):
        """Load ``image_path`` as RGB and apply the eval or training pipeline.

        Errors from opening/decoding the image propagate to the caller.
        """
        # convert() returns an independent image, so the file handle can be
        # closed deterministically before transforming.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if augment:
            return self.augmentation(image)
        return self.transform(image)

In [3]:
# 第三格：数据集类
class MultiModalDataset(Dataset):
    """Labelled text+image sentiment dataset built from a ``guid,tag`` file.

    Each labelled row ``guid,tag`` maps to ``{data_dir}/{guid}.txt`` and
    ``{data_dir}/{guid}.jpg``. Rows tagged 'null', rows with an unknown tag,
    and rows whose files are missing are skipped. Items are dicts with
    ``input_ids``, ``attention_mask``, ``image``, ``label`` (long tensor) and
    ``guid``.
    """

    def __init__(self, data_dir, label_file, text_processor, image_processor,
                 split='train', augment=False):
        self.data_dir = data_dir
        self.text_processor = text_processor
        self.image_processor = image_processor
        # Augmentation is only ever applied to the training split.
        self.augment = augment and split == 'train'

        self.data = []
        label_mapping = {'positive': 0, 'neutral': 1, 'negative': 2}

        with open(label_file, 'r', encoding='utf-8') as f:
            next(f, None)  # skip the header row
            for line in f:
                parts = [p.strip() for p in line.strip().split(',')]
                if len(parts) != 2:
                    continue
                guid, label = parts[0], parts[1].lower()
                if label == 'null':
                    continue
                # Previously an unexpected tag raised KeyError and aborted
                # loading; skip it with a warning instead.
                if label not in label_mapping:
                    print(f"未知标签，已跳过: guid={guid}, tag={label}")
                    continue

                text_path = os.path.join(data_dir, f'{guid}.txt')
                image_path = os.path.join(data_dir, f'{guid}.jpg')
                if os.path.exists(text_path) and os.path.exists(image_path):
                    self.data.append({
                        'guid': guid,
                        'label': label_mapping[label],
                        'text_path': text_path,
                        'image_path': image_path
                    })
                else:
                    print(f"文件不存在: {text_path} 或 {image_path}")

        print(f"Loaded {len(self.data)} samples for {split} split")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        text_data = self.text_processor.preprocess(item['text_path'])

        # An unreadable image becomes an all-zero tensor of the processor's
        # output size, so the batch still collates. A bare `except` would also
        # have swallowed KeyboardInterrupt.
        try:
            image = self.image_processor.preprocess(item['image_path'], self.augment)
        except Exception:
            size = getattr(self.image_processor, 'img_size', 224)
            image = torch.zeros(3, size, size)

        return {
            'input_ids': text_data['input_ids'],
            'attention_mask': text_data['attention_mask'],
            'image': image,
            'label': torch.tensor(item['label'], dtype=torch.long),
            'guid': item['guid']
        }

In [4]:
# 第四格：多模态融合模型
class MultiModalModel(nn.Module):
    """BERT + ResNet-50 sentiment classifier with a selectable fusion strategy.

    ``forward`` returns ``(fused_logits, text_logits, image_logits)``. The
    unimodal logits serve as auxiliary training targets and for ablation.

    Args:
        bert_path: directory holding the pretrained BERT weights.
        num_classes: number of sentiment classes.
        fusion_method: 'early' (feature concatenation), 'late' (learned
            linear combination of unimodal logits) or 'middle'
            (cross-attention over projected features).

    Raises:
        ValueError: if ``fusion_method`` is not one of ``FUSION_METHODS``.
    """

    FUSION_METHODS = ('early', 'late', 'middle')

    def __init__(self, bert_path='./bert_model/', num_classes=3, fusion_method='late'):
        super(MultiModalModel, self).__init__()
        # Fail fast. An unknown method used to surface only at forward() time,
        # as an UnboundLocalError on `fused_logits`.
        if fusion_method not in self.FUSION_METHODS:
            raise ValueError(
                f"fusion_method must be one of {self.FUSION_METHODS}, got {fusion_method!r}"
            )
        self.fusion_method = fusion_method

        # Text encoder (BERT); forward() uses its pooler_output.
        self.text_encoder = BertModel.from_pretrained(bert_path)
        text_hidden_size = self.text_encoder.config.hidden_size

        # Image encoder (ResNet-50) with the final classification layer
        # replaced by Identity, so it outputs the 2048-d pooled feature.
        # `weights=` supersedes the deprecated `pretrained=True`, which selected
        # the IMAGENET1K_V1 checkpoint.
        self.image_encoder = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.image_encoder.fc = nn.Identity()
        image_hidden_size = 2048

        # Unimodal classification heads.
        self.text_classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(text_hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
        self.image_classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(image_hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

        # Fusion head.
        if fusion_method == 'early':
            # Concatenate raw text and image features.
            fusion_dim = text_hidden_size + image_hidden_size
            self.fusion_classifier = nn.Sequential(
                nn.Dropout(0.3),
                nn.Linear(fusion_dim, 512),
                nn.ReLU(),
                nn.Linear(512, num_classes)
            )
        elif fusion_method == 'late':
            # Learned linear combination of the two unimodal logit vectors.
            self.fusion_classifier = nn.Linear(num_classes * 2, num_classes)
        else:  # 'middle'
            self.cross_attention = nn.MultiheadAttention(
                embed_dim=512, num_heads=8, batch_first=True
            )
            self.fusion_classifier = nn.Sequential(
                nn.Dropout(0.3),
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Linear(512, num_classes)
            )

        # Projections to a shared 512-d space. Only 'middle' uses them, but
        # they are always created so previously saved checkpoints still load.
        self.text_projection = nn.Linear(text_hidden_size, 512)
        self.image_projection = nn.Linear(image_hidden_size, 512)

    def forward(self, input_ids, attention_mask, image):
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.pooler_output

        image_features = self.image_encoder(image)

        text_logits = self.text_classifier(text_features)
        image_logits = self.image_classifier(image_features)

        if self.fusion_method == 'early':
            fused_features = torch.cat([text_features, image_features], dim=1)
            fused_logits = self.fusion_classifier(fused_features)

        elif self.fusion_method == 'late':
            fused_logits = self.fusion_classifier(torch.cat([text_logits, image_logits], dim=1))

        elif self.fusion_method == 'middle':
            text_proj = self.text_projection(text_features).unsqueeze(1)
            image_proj = self.image_projection(image_features).unsqueeze(1)

            # NOTE: query, key and value each contain a single token, so the
            # softmax attention weight is always 1. attended_features is just
            # the attention's value/output projection of image_proj; the
            # attention adds no selectivity here.
            attended_features, _ = self.cross_attention(
                text_proj, image_proj, image_proj
            )
            fused_features = torch.cat([
                text_proj.squeeze(1),
                attended_features.squeeze(1)
            ], dim=1)
            fused_logits = self.fusion_classifier(fused_features)

        else:
            raise ValueError(f"unsupported fusion_method: {self.fusion_method!r}")

        return fused_logits, text_logits, image_logits

In [5]:
# 第五格：训练函数
def train_epoch(model, dataloader, criterion, optimizer, device, aux_weight=0.3, max_grad_norm=1.0):
    """Run one training epoch with auxiliary unimodal losses.

    The loss per batch is ``criterion(fused) + aux_weight * criterion(text)
    + aux_weight * criterion(image)``. Gradients are clipped to
    ``max_grad_norm``.

    Returns:
        (mean loss per batch, fused-prediction accuracy in percent). Both are
        0 for an empty dataloader instead of raising ZeroDivisionError.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc='Training')
    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        fused_logits, text_logits, image_logits = model(input_ids, attention_mask, images)

        # Multi-task loss: the fused head is the main objective; the unimodal
        # heads act as auxiliary supervision.
        loss_fused = criterion(fused_logits, labels)
        loss_text = criterion(text_logits, labels)
        loss_image = criterion(image_logits, labels)
        loss = loss_fused + aux_weight * loss_text + aux_weight * loss_image

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(fused_logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'loss': loss.item(),
            'acc': 100 * correct / total
        })

    return total_loss / max(len(dataloader), 1), 100 * correct / max(total, 1)

def validate(model, dataloader, criterion, device):
    """Evaluate the fused head on ``dataloader`` and print a classification report.

    Returns:
        (mean loss per batch, accuracy in percent, predictions, labels). The
        predictions and labels are lists of class indices.
    """
    from sklearn.metrics import classification_report

    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            fused_logits, _, _ = model(input_ids, attention_mask, images)
            loss = criterion(fused_logits, labels)

            total_loss += loss.item()
            _, predicted = torch.max(fused_logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    print("\n分类报告:")
    if all_labels:
        # Pin `labels` so the report still works when a class is absent from
        # both predictions and labels. Without it sklearn raises because the
        # number of observed classes no longer matches `target_names`.
        print(classification_report(all_labels, all_preds,
                                    labels=[0, 1, 2],
                                    target_names=['positive', 'neutral', 'negative'],
                                    zero_division=0))

    return (total_loss / max(len(dataloader), 1), 100 * correct / max(total, 1),
            all_preds, all_labels)

In [6]:
# 第六格：主训练流程
def main_training(num_epochs=20, seed=42, num_workers=4, checkpoint_path='best_model.pth'):
    """Train the late-fusion model on an 80/20 split of ``./data/train.txt``.

    Uses ReduceLROnPlateau on validation accuracy and early stopping after 5
    epochs without improvement. The best model (by validation accuracy) is
    saved to ``checkpoint_path``.

    Args:
        num_epochs: maximum number of training epochs.
        seed: seed for the train/validation split, so the split does not
            depend on how much global RNG state earlier cells consumed.
        num_workers: DataLoader worker processes. Set 0 if workers cannot
            import notebook-defined classes (e.g. spawn-based platforms).
        checkpoint_path: where the best checkpoint is written.

    Returns:
        The model with the best validation weights restored. Previously the
        last-epoch weights were returned, which after early stopping are up to
        `patience` epochs past the best.
    """
    text_processor = TextPreprocessor(bert_path='./bert_model/', max_length=128)
    image_processor = ImagePreprocessor(img_size=224)

    full_dataset = MultiModalDataset(
        data_dir='./data/data',
        label_file='./data/train.txt',
        text_processor=text_processor,
        image_processor=image_processor,
        split='train'
    )

    # 80% train / 20% validation with a dedicated, seeded generator.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    split_generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size], generator=split_generator
    )

    train_loader = DataLoader(
        train_dataset, batch_size=32, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=32, shuffle=False, num_workers=num_workers
    )

    model = MultiModalModel(
        bert_path='./bert_model/',
        num_classes=3,
        fusion_method='late'  # 'early' or 'middle' are also supported
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-5,
        weight_decay=0.01
    )

    # Halve the LR when validation accuracy plateaus for 3 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3, verbose=True
    )

    best_acc = 0
    best_saved = False
    patience = 5  # early-stopping patience (separate from the scheduler's)
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc, val_preds, val_labels = validate(
            model, val_loader, criterion, device
        )

        print(f"\n训练集 - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"验证集 - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")

        scheduler.step(val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
            }, checkpoint_path)
            best_saved = True
            print(f"保存新的最佳模型，验证准确率: {val_acc:.2f}%")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("早停触发")
                break

    # Restore the best weights written during *this* run. The flag avoids
    # picking up a stale file from an earlier run.
    if best_saved:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"已恢复最佳模型 (epoch {checkpoint['epoch'] + 1}, 验证准确率: {checkpoint['val_acc']:.2f}%)")

    return model

In [7]:
# 第七格：测试集预测
class TestDataset(Dataset):
    """Unlabelled test split: yields model inputs plus each sample's ``guid``.

    ``test_file`` is a CSV with a header row whose first column is the guid.
    Inputs are read from ``{data_dir}/{guid}.txt`` and ``{data_dir}/{guid}.jpg``.
    """

    def __init__(self, test_file, data_dir, text_processor, image_processor):
        self.data_dir = data_dir
        self.text_processor = text_processor
        self.image_processor = image_processor

        self.guids = []
        with open(test_file, 'r', encoding='utf-8') as f:
            next(f, None)  # skip the header row
            for line in f:
                guid = line.strip().split(',')[0].strip()
                if guid:  # ignore blank lines
                    self.guids.append(guid)

    def __len__(self):
        return len(self.guids)

    def __getitem__(self, idx):
        guid = self.guids[idx]

        text_path = os.path.join(self.data_dir, f'{guid}.txt')
        image_path = os.path.join(self.data_dir, f'{guid}.jpg')

        # The text preprocessor takes a *file path* and does its own decoding.
        # Passing already-read text made it fail to open the "path" and
        # tokenize every test sample as an empty string.
        text_data = self.text_processor.preprocess(text_path)

        # Unreadable/missing images become an all-zero tensor.
        try:
            image = self.image_processor.preprocess(image_path, augment=False)
        except Exception:
            image = torch.zeros(3, 224, 224)

        return {
            'input_ids': text_data['input_ids'],
            'attention_mask': text_data['attention_mask'],
            'image': image,
            'guid': guid
        }

def predict_test_set(model, test_file, output_file, data_dir='./data/data',
                     batch_size=32, num_workers=4):
    """Predict sentiment for every guid in ``test_file`` and write ``guid,tag`` rows.

    Args:
        model: trained model returning ``(fused_logits, _, _)``.
        test_file: CSV with a header row; the first column is the guid.
        output_file: destination CSV, written in input order.
        data_dir: folder holding ``{guid}.txt`` / ``{guid}.jpg``. The default
            matches the folder used for training. The previous hard-coded
            './data' pointed one level too high, so every test input silently
            fell back to blank text / black images.
        batch_size, num_workers: DataLoader settings.

    Returns:
        dict mapping guid -> predicted tag. Guids without a prediction are
        written as 'neutral'.
    """
    model.eval()

    text_processor = TextPreprocessor(bert_path='./bert_model/', max_length=128)
    image_processor = ImagePreprocessor(img_size=224)

    test_dataset = TestDataset(
        test_file=test_file,
        data_dir=data_dir,
        text_processor=text_processor,
        image_processor=image_processor
    )

    # Surface path mistakes loudly instead of predicting on placeholder inputs.
    missing = sum(
        1 for g in test_dataset.guids
        if not os.path.exists(os.path.join(data_dir, f'{g}.jpg'))
    )
    if missing:
        print(f"⚠️ {missing} 个测试样本在 {data_dir} 中缺少图像文件")

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    predictions = {}
    label_mapping = {0: 'positive', 1: 'neutral', 2: 'negative'}

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Predicting'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            guids = batch['guid']

            fused_logits, _, _ = model(input_ids, attention_mask, images)
            _, preds = torch.max(fused_logits, 1)

            for guid, pred in zip(guids, preds.cpu().tolist()):
                predictions[guid] = label_mapping[pred]

    # Write one row per guid, in the same order as test_file.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("guid,tag\n")
        for guid in test_dataset.guids:
            tag = predictions.get(guid, 'neutral')  # default to neutral
            f.write(f"{guid},{tag}\n")

    print(f"预测结果已保存到: {output_file}")
    return predictions

In [8]:
# 第八格：消融实验
def ablation_study(model, dataloader, device):
    """Compare text-only, image-only and fused accuracy of a trained model.

    Runs one pass over ``dataloader`` without gradients and scores each of
    the three logit outputs against the labels. Prints the accuracies and the
    gain of fusion over the better single modality (in percentage points).

    Returns:
        dict with keys 'text_only', 'image_only', 'multimodal' (percent).
    """
    model.eval()
    hits = {'text': 0, 'image': 0, 'fusion': 0}
    seen = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Ablation Study'):
            labels = batch['label'].to(device)
            fused_logits, text_logits, image_logits = model(
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
                batch['image'].to(device),
            )
            scored = (('text', text_logits), ('image', image_logits),
                      ('fusion', fused_logits))
            for key, logits in scored:
                hits[key] += (logits.max(dim=1).indices == labels).sum().item()
            seen += labels.size(0)

    text_acc = 100 * hits['text'] / seen
    image_acc = 100 * hits['image'] / seen
    fusion_acc = 100 * hits['fusion'] / seen

    print("\n消融实验结果:")
    print(f"仅文本模态准确率: {text_acc:.2f}%")
    print(f"仅图像模态准确率: {image_acc:.2f}%")
    print(f"多模态融合准确率: {fusion_acc:.2f}%")
    print(f"多模态相对提升: {fusion_acc - max(text_acc, image_acc):.2f}%")

    return {
        'text_only': text_acc,
        'image_only': image_acc,
        'multimodal': fusion_acc
    }

In [10]:
# 第九格：完整可运行的多模态情感分类实验
if __name__ == "__main__":
    import os
    import sys
    import time
    from tqdm import tqdm

    print("=" * 60)
    print("多模态情感分类实验 - 完整可运行版本")
    print("=" * 60)

    # ========== 1. 环境设置 ==========
    print("\n1. 环境设置...")

    # 禁用tokenizer的多线程警告
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    # 设置随机种子
    def set_seed(seed=42):
        """Seed Python, NumPy and PyTorch RNGs; CUDA RNGs only when CUDA is present."""
        seeders = [random.seed, np.random.seed, torch.manual_seed]
        if torch.cuda.is_available():
            seeders.append(torch.cuda.manual_seed_all)
        for apply_seed in seeders:
            apply_seed(seed)

    set_seed(42)

    # ========== 2. 修复BERT模型加载 ==========
    print("\n2. 初始化BERT模型...")

    class FixedTextPreprocessor:
        """Text preprocessor with encoding and tokenizer fallbacks.

        Loads the BERT tokenizer from ``./bert_model/`` when that directory
        exists, otherwise fetches ``bert-base-uncased``. If both fail, a crude
        character-level encoding is used. ``preprocess`` takes a file path and
        returns 1-D ``input_ids``/``attention_mask`` tensors of length
        ``max_length``.
        """

        # Encodings tried in order when reading a file. latin-1 decodes any
        # byte sequence, so it acts as the last-resort decoder.
        ENCODINGS = ('utf-8', 'gbk', 'latin-1', 'cp1252')

        def __init__(self, max_length=128):
            self.max_length = max_length
            self.tokenizer = None

            try:
                if os.path.exists('./bert_model/'):
                    print("  尝试从本地加载BERT tokenizer...")
                    self.tokenizer = BertTokenizer.from_pretrained('./bert_model/')
                    print("  ✅ 本地BERT tokenizer加载成功")
                else:
                    print("  本地BERT不存在，尝试在线加载...")
                    # NOTE(review): this disables HTTPS certificate
                    # verification for every later stdlib (urllib /
                    # http.client) connection in this process, not just this
                    # download. Kept to preserve the original workaround;
                    # prefer fixing the CA bundle instead.
                    import ssl
                    ssl._create_default_https_context = ssl._create_unverified_context

                    self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                    print("  ✅ 在线BERT tokenizer加载成功")
            except Exception as e:
                print(f"  ⚠️ BERT tokenizer加载失败: {e}")
                print("  使用简单的字符级tokenizer作为备用")
                self.tokenizer = None

        def _read_text(self, text_path):
            """Return stripped file contents from the first encoding that yields text ('' if none)."""
            text = ""
            for encoding in self.ENCODINGS:
                try:
                    with open(text_path, 'r', encoding=encoding) as f:
                        text = f.read().strip()
                except (OSError, UnicodeError):
                    continue
                if text:
                    break
            return text

        def preprocess(self, text_path):
            """Read and encode ``text_path`` with BERT, or the char-level fallback."""
            import re

            text = self._read_text(text_path) or "empty text"
            text = re.sub(r'\s+', ' ', text.lower().strip())

            if self.tokenizer is not None:
                try:
                    encoding = self.tokenizer.encode_plus(
                        text,
                        add_special_tokens=True,
                        max_length=self.max_length,
                        padding='max_length',
                        truncation=True,
                        return_tensors='pt'
                    )
                    return {
                        'input_ids': encoding['input_ids'].flatten(),
                        'attention_mask': encoding['attention_mask'].flatten()
                    }
                except Exception as e:
                    print(f"  ⚠️ BERT tokenization失败: {e}")

            # Fallback: character-level ids wrapped in 101/102 markers
            # ([CLS]/[SEP] in the BERT vocabulary), then zero-padded.
            print(f"  使用备用tokenizer处理: {os.path.basename(text_path)}")
            body = [ord(c) % 1000 for c in text[:max(self.max_length - 2, 0)]]
            char_ids = ([101] + body + [102])[:self.max_length]
            n_real = len(char_ids)
            padding = self.max_length - n_real
            input_ids = char_ids + [0] * padding
            # Build the mask from the real length, not from `id != 0`:
            # ord(c) % 1000 is 0 for real characters such as U+0BB8, which the
            # old rule wrongly masked out as padding.
            attention_mask = [1] * n_real + [0] * padding

            return {
                'input_ids': torch.tensor(input_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
            }

    # ========== 3. 修复图像处理器 ==========
    print("\n3. 初始化图像处理器...")

    class FixedImagePreprocessor:
        """Image preprocessor that never raises: unreadable images become black tensors.

        Uses a [-1, 1] normalisation (mean/std 0.5) for both pipelines.
        """
        def __init__(self, img_size=224):
            self.img_size = img_size
            # Training augmentation. Resize a bit larger than the crop (256 for
            # the default 224) so RandomCrop has room to move. Scaling with
            # img_size keeps the crop valid for img_size > 256; a fixed 256
            # resize made RandomCrop raise there, and the except below then
            # turned every training image into a black tensor.
            resize_size = int(round(img_size * 256 / 224))
            self.train_transform = transforms.Compose([
                transforms.Resize((resize_size, resize_size)),
                transforms.RandomCrop((img_size, img_size)),
                transforms.RandomHorizontalFlip(p=0.3),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])

            # Deterministic pipeline for validation / test.
            self.val_transform = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])

        def preprocess(self, image_path, augment=False):
            try:
                # convert() returns an independent image, so the file can be
                # closed before transforming.
                with Image.open(image_path) as img:
                    image = img.convert('RGB')
                transform = self.train_transform if augment else self.val_transform
                return transform(image)
            except Exception as e:
                print(f"  图像加载失败 {image_path}: {e}")
                # Same shape as a real sample so batches still collate.
                return torch.zeros(3, self.img_size, self.img_size)

    # ========== 4. 修复数据集类 ==========
    print("\n4. 创建数据集...")

    class FixedMultiModalDataset(Dataset):
        """Labelled text+image dataset that tolerates missing files and bad rows.

        Rows whose (stripped, lower-cased) tag is not positive/neutral/negative
        (e.g. 'null') are skipped, as are rows whose ``.txt`` or ``.jpg`` file
        is missing. ``max_samples`` caps how many label-file rows are read.
        """

        # How many individual missing-file warnings to print.
        MAX_MISSING_WARNINGS = 5

        def __init__(self, data_dir, label_file, text_processor, image_processor,
                     split='train', augment=False, max_samples=None):
            self.data_dir = data_dir
            self.text_processor = text_processor
            self.image_processor = image_processor
            # Augmentation is only applied to the training split.
            self.augment = augment and split == 'train'

            self.data = []
            label_mapping = {'positive': 0, 'neutral': 1, 'negative': 2}
            missing_count = 0

            try:
                with open(label_file, 'r', encoding='utf-8') as f:
                    lines = f.readlines()[1:]  # skip the header row

                    # Optionally limit the sample count for quick experiments.
                    if max_samples:
                        lines = lines[:max_samples]

                    for line in lines:
                        parts = line.strip().split(',')
                        if len(parts) < 2:
                            continue
                        guid = parts[0].strip()
                        label = parts[1].strip().lower()
                        if label not in label_mapping:
                            continue

                        text_path = os.path.join(data_dir, f'{guid}.txt')
                        image_path = os.path.join(data_dir, f'{guid}.jpg')
                        text_exists = os.path.exists(text_path)
                        img_exists = os.path.exists(image_path)

                        if text_exists and img_exists:
                            self.data.append({
                                'guid': guid,
                                'label': label_mapping[label],
                                'text_path': text_path,
                                'image_path': image_path
                            })
                            continue

                        # Count missing samples. The old check used the row
                        # index (i < 5), so files missing after the 5th row
                        # were never reported.
                        if missing_count < self.MAX_MISSING_WARNINGS:
                            print(f"  ⚠️ 文件缺失: guid={guid}, text={text_exists}, image={img_exists}")
                        missing_count += 1
            except Exception as e:
                print(f"  读取标签文件失败: {e}")

            if missing_count > self.MAX_MISSING_WARNINGS:
                print(f"  ⚠️ 另有 {missing_count - self.MAX_MISSING_WARNINGS} 个缺失样本未逐一显示")
            print(f"  ✅ 加载了 {len(self.data)} 个样本")

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            item = self.data[idx]

            text_data = self.text_processor.preprocess(item['text_path'])
            image = self.image_processor.preprocess(item['image_path'], self.augment)

            return {
                'input_ids': text_data['input_ids'],
                'attention_mask': text_data['attention_mask'],
                'image': image,
                'label': torch.tensor(item['label'], dtype=torch.long),
                'guid': item['guid']
            }

    # ========== 5. 修复模型类 ==========
    print("\n5. 初始化模型...")

    class FixedMultiModalModel(nn.Module):
        """Multimodal sentiment model with fallbacks for missing pretrained weights.

        Text encoder:
            - a local BERT from ``./bert_model/`` if present;
            - otherwise a randomly initialised 4-layer BERT;
            - if construction fails, a small Embedding -> Linear -> Tanh encoder.

        Image encoder:
            - ImageNet-pretrained ResNet-18;
            - otherwise a small CNN.

        ``forward`` returns ``(fused_logits, text_logits, image_logits)``.
        """
        def __init__(self, num_classes=3, fusion_method='late'):
            super(FixedMultiModalModel, self).__init__()
            self.fusion_method = fusion_method

            # ===== Text encoder =====
            print("  初始化文本编码器...")
            try:
                if os.path.exists('./bert_model/'):
                    self.text_encoder = BertModel.from_pretrained('./bert_model/')
                else:
                    # Build a smaller, randomly initialised BERT (no download).
                    from transformers import BertConfig
                    config = BertConfig(
                        hidden_size=768,
                        num_hidden_layers=4,
                        num_attention_heads=8,
                        intermediate_size=3072,
                        hidden_act="gelu"
                    )
                    self.text_encoder = BertModel(config)
                    print("  ⚠️ 使用随机初始化的BERT（本地模型不存在）")
            except Exception as e:
                print(f"  ⚠️ BERT加载失败，使用备用文本编码器: {e}")
                # Backup encoder. It produces per-token vectors that forward()
                # mean-pools.
                self.text_encoder = nn.Sequential(
                    nn.Embedding(1000, 768),
                    nn.Linear(768, 768),
                    nn.Tanh()
                )

            text_hidden_size = 768

            # ===== Image encoder =====
            print("  初始化图像编码器...")
            try:
                # ImageNet-pretrained ResNet-18 (smaller/faster than ResNet-50).
                # `weights=` supersedes the deprecated `pretrained=True`.
                self.image_encoder = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
                self.image_encoder.fc = nn.Identity()
                image_hidden_size = 512
            except Exception:
                print("  ⚠️ ResNet加载失败，使用备用图像编码器")
                self.image_encoder = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(128, 256, 3, padding=1),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(256, 512)
                )
                image_hidden_size = 512

            # ===== Unimodal heads =====
            self.text_classifier = nn.Sequential(
                nn.Dropout(0.3),
                nn.Linear(text_hidden_size, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes)
            )

            self.image_classifier = nn.Sequential(
                nn.Dropout(0.3),
                nn.Linear(image_hidden_size, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes)
            )

            # ===== Fusion head =====
            if fusion_method == 'early':
                fusion_dim = text_hidden_size + image_hidden_size
                self.fusion_classifier = nn.Sequential(
                    nn.Dropout(0.3),
                    nn.Linear(fusion_dim, 256),
                    nn.ReLU(),
                    nn.Linear(256, num_classes)
                )
            elif fusion_method == 'late':
                self.fusion_classifier = nn.Linear(num_classes * 2, num_classes)
            else:  # any other value is treated as middle fusion
                self.text_projection = nn.Linear(text_hidden_size, 256)
                self.image_projection = nn.Linear(image_hidden_size, 256)
                self.fusion_classifier = nn.Sequential(
                    nn.Dropout(0.3),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Linear(256, num_classes)
                )

            print(f"  ✅ 模型初始化完成，融合方法: {fusion_method}")

        def forward(self, input_ids, attention_mask, image):
            if isinstance(self.text_encoder, BertModel):
                # Errors propagate. The old fallback replaced the text features
                # with torch.randn noise, silently training on garbage.
                text_outputs = self.text_encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                text_features = text_outputs.last_hidden_state[:, 0, :]  # [CLS] token
            else:
                # The backup encoder yields one 768-d vector per token. Ids are
                # folded into its vocabulary, then states are mean-pooled over
                # non-padding tokens to give the (batch, 768) feature the heads
                # expect.
                vocab_size = self.text_encoder[0].num_embeddings
                token_states = self.text_encoder(input_ids % vocab_size)
                mask = attention_mask.unsqueeze(-1).to(token_states.dtype)
                text_features = (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

            image_features = self.image_encoder(image)

            text_logits = self.text_classifier(text_features)
            image_logits = self.image_classifier(image_features)

            if self.fusion_method == 'early':
                fused_features = torch.cat([text_features, image_features], dim=1)
                fused_logits = self.fusion_classifier(fused_features)
            elif self.fusion_method == 'late':
                fused_logits = self.fusion_classifier(
                    torch.cat([text_logits, image_logits], dim=1)
                )
            else:  # middle fusion: concatenate projected features
                text_proj = self.text_projection(text_features)
                image_proj = self.image_projection(image_features)
                fused_features = torch.cat([text_proj, image_proj], dim=1)
                fused_logits = self.fusion_classifier(fused_features)

            return fused_logits, text_logits, image_logits

    # ========== 6. Training / validation functions ==========
    print("\n6. 设置训练参数...")

    def fixed_train_epoch(model, dataloader, criterion, optimizer, device,
                          aux_weight=0.2, max_grad_norm=1.0):
        """Train ``model`` for one epoch.

        The model is called as ``model(input_ids, attention_mask, images)`` and
        must return ``(fused_logits, text_logits, image_logits)``. The optimised
        loss is ``L_fused + aux_weight * L_text + aux_weight * L_image``. The
        auxiliary single-modality terms keep each branch's own head trained.

        A batch that raises during forward/backward is reported and skipped.
        The mean loss is averaged only over batches that completed, so skipped
        batches do not bias it downwards.

        Args:
            model: multimodal model (see above).
            dataloader: yields dicts with 'input_ids', 'attention_mask',
                'image' and 'label' tensors.
            criterion: classification loss, e.g. ``nn.CrossEntropyLoss()``.
            optimizer: optimizer over the model's parameters.
            device: device the batch tensors are moved to.
            aux_weight: weight of each auxiliary (text-only / image-only) loss.
            max_grad_norm: gradient-norm clipping threshold.

        Returns:
            tuple: (mean loss over successful batches, fused-head accuracy in %).
        """
        model.train()
        total_loss = 0.0
        n_ok_batches = 0
        correct = 0
        total = 0

        progress_bar = tqdm(dataloader, desc='训练')
        for batch_idx, batch in enumerate(progress_bar):
            # Move the batch to the target device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            # Gradients left over from a failed batch are cleared here too.
            optimizer.zero_grad()
            try:
                fused_logits, text_logits, image_logits = model(input_ids, attention_mask, images)

                loss_fused = criterion(fused_logits, labels)
                loss_text = criterion(text_logits, labels)
                loss_image = criterion(image_logits, labels)
                loss = loss_fused + aux_weight * loss_text + aux_weight * loss_image

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                optimizer.step()

                # Statistics (only for batches that completed)
                total_loss += loss.item()
                n_ok_batches += 1
                _, predicted = torch.max(fused_logits, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                progress_bar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{100 * correct / total:.1f}%'
                })
            except Exception as e:
                print(f"\n⚠️ 训练batch {batch_idx}失败: {e}")
                continue

        return total_loss / max(n_ok_batches, 1), 100 * correct / max(total, 1)

    def fixed_validate(model, dataloader, criterion, device,
                       target_names=('positive', 'neutral', 'negative')):
        """Evaluate the fused head of ``model`` on ``dataloader``.

        A batch that raises is reported and skipped. The mean loss is averaged
        only over batches that completed. When at least one sample was scored,
        a per-class ``classification_report`` is printed. Every class in
        ``target_names`` (label ids ``0..len-1``) is always listed, even if it
        is absent from this split.

        Args:
            model: returns ``(fused_logits, text_logits, image_logits)``.
            dataloader: yields dicts with 'input_ids', 'attention_mask',
                'image' and 'label' tensors.
            criterion: classification loss applied to the fused logits.
            device: device the batch tensors are moved to.
            target_names: class names indexed by label id.

        Returns:
            tuple: (mean loss, accuracy in %, list of predictions, list of labels).
        """
        model.eval()
        total_loss = 0.0
        n_ok_batches = 0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc='验证'):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                images = batch['image'].to(device)
                labels = batch['label'].to(device)

                try:
                    fused_logits, _, _ = model(input_ids, attention_mask, images)
                    loss = criterion(fused_logits, labels)

                    total_loss += loss.item()
                    n_ok_batches += 1
                    _, predicted = torch.max(fused_logits, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                    all_preds.extend(predicted.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())
                except Exception as e:
                    print(f"验证batch失败: {e}")
                    continue

        if total > 0:
            from sklearn.metrics import classification_report
            print("\n分类报告:")
            # Passing `labels` explicitly avoids a ValueError when some class
            # never occurs in this split's labels/predictions.
            print(classification_report(all_labels, all_preds,
                                        labels=list(range(len(target_names))),
                                        target_names=list(target_names),
                                        digits=4,
                                        zero_division=0))

        return total_loss / max(n_ok_batches, 1), 100 * correct / max(total, 1), all_preds, all_labels

    # ========== 7. Main training flow ==========
    print("\n7. 开始主训练流程...")

    # Preprocessors shared by the train/val datasets (and later the test set).
    text_processor = FixedTextPreprocessor(max_length=128)
    image_processor = FixedImagePreprocessor(img_size=224)

    # Labelled data, capped by max_samples for speed.
    print("\n加载训练数据...")
    full_dataset = FixedMultiModalDataset(
        data_dir='./data/data',
        label_file='./data/train.txt',
        text_processor=text_processor,
        image_processor=image_processor,
        split='train',
        max_samples=1000  # only 1000 samples; raise for a full run
    )

    # 80/20 train/validation split (draws from the global torch RNG).
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size]
    )

    print(f"训练集大小: {len(train_dataset)}")
    print(f"验证集大小: {len(val_dataset)}")

    # Single-process loading (num_workers=0 avoids multiprocessing issues on Windows).
    # Pinned memory only speeds up host->GPU copies, so enable it only with CUDA.
    use_pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,  # small batch to fit a laptop GPU
        shuffle=True,
        num_workers=0,
        pin_memory=use_pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=0,
        pin_memory=use_pin_memory
    )

    model = FixedMultiModalModel(
        num_classes=3,
        fusion_method='late'  # one of 'early', 'middle', 'late'
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-5,
        weight_decay=0.01
    )

    # Halve the learning rate every 3 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

    num_epochs = 10
    # -inf (not 0) guarantees the first epoch always writes a checkpoint, so the
    # ablation step cannot pick up a stale 'fixed_best_model.pth' from an earlier
    # run when validation accuracy happens to be 0.
    best_acc = float('-inf')
    patience = 3  # epochs without improvement before early stopping
    patience_counter = 0

    print("\n" + "=" * 60)
    print("开始训练循环")
    print("=" * 60)

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)

        train_loss, train_acc = fixed_train_epoch(
            model, train_loader, criterion, optimizer, device
        )

        val_loss, val_acc, val_preds, val_labels = fixed_validate(
            model, val_loader, criterion, device
        )

        scheduler.step()

        print(f"\n训练集 - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"验证集 - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")

        # Keep the checkpoint with the best validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
            }, 'fixed_best_model.pth')

            print(f"✅ 保存最佳模型，验证准确率: {val_acc:.2f}%")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"⚠️ 早停触发，连续{patience}个epoch验证准确率未提升")
                break

    # ========== 8. Ablation study ==========
    print("\n" + "=" * 60)
    print("消融实验")
    print("=" * 60)

    # Reload the best checkpoint written by the training loop, if present.
    if os.path.exists('fixed_best_model.pth'):
        print("加载最佳模型进行消融实验...")
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine;
        # weights_only avoids unpickling arbitrary Python objects.
        checkpoint = torch.load('fixed_best_model.pth', map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    # Compare the text-only head, the image-only head and the fused head on the val set.
    text_correct = 0
    image_correct = 0
    fusion_correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='消融实验'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            fused_logits, text_logits, image_logits = model(
                input_ids, attention_mask, images
            )

            _, text_preds = torch.max(text_logits, 1)
            text_correct += (text_preds == labels).sum().item()

            _, image_preds = torch.max(image_logits, 1)
            image_correct += (image_preds == labels).sum().item()

            _, fusion_preds = torch.max(fused_logits, 1)
            fusion_correct += (fusion_preds == labels).sum().item()

            total += labels.size(0)

    if total == 0:
        # Guard against ZeroDivisionError when the validation set is empty.
        print("\n⚠️ 验证集为空，无法计算消融实验结果")
    else:
        text_acc = 100 * text_correct / total
        image_acc = 100 * image_correct / total
        fusion_acc = 100 * fusion_correct / total
        print(f"\n消融实验结果:")
        print(f"仅文本模态准确率: {text_acc:.2f}%")
        print(f"仅图像模态准确率: {image_acc:.2f}%")
        print(f"多模态融合准确率: {fusion_acc:.2f}%")
        # Difference in percentage points versus the best single modality.
        print(f"多模态相对提升: {fusion_acc - max(text_acc, image_acc):.2f}%")

    # ========== 9. Test-set prediction ==========
    print("\n" + "=" * 60)
    print("测试集预测")
    print("=" * 60)

    class FixedTestDataset(Dataset):
        """Unlabelled test set: one item per guid listed in ``test_file``.

        ``test_file`` has a header row; the guid is the first comma-separated
        field of each following line. Blank lines are ignored. For guid ``g``
        the text is read from ``<data_dir>/g.txt`` and the image from
        ``<data_dir>/g.jpg`` via the supplied processors.
        """
        def __init__(self, test_file, data_dir, text_processor, image_processor):
            self.data_dir = data_dir
            self.text_processor = text_processor
            self.image_processor = image_processor

            # Read the guid list (all-or-nothing: readlines() decodes the whole file first).
            self.guids = []
            try:
                with open(test_file, 'r', encoding='utf-8') as f:
                    lines = f.readlines()[1:]  # skip the header row
                for line in lines:
                    guid = line.strip().split(',')[0].strip()
                    if guid:  # skip blank / trailing lines instead of yielding ''
                        self.guids.append(guid)
                print(f"  测试集GUID数量: {len(self.guids)}")
            except Exception as e:
                print(f"  读取测试文件失败: {e}")

        def __len__(self):
            return len(self.guids)

        def __getitem__(self, idx):
            """Return a dict with 'input_ids', 'attention_mask', 'image' and 'guid'."""
            guid = self.guids[idx]

            text_path = os.path.join(self.data_dir, f'{guid}.txt')
            image_path = os.path.join(self.data_dir, f'{guid}.jpg')

            text_data = self.text_processor.preprocess(text_path)

            try:
                image = self.image_processor.preprocess(image_path, augment=False)
            except Exception as e:
                # Missing/corrupt image: use a blank image so the text branch can
                # still predict. Size comes from the processor's `img_size` when it
                # exposes an int attribute of that name (assumed), else 224.
                img_size = getattr(self.image_processor, 'img_size', 224)
                if not isinstance(img_size, int):
                    img_size = 224
                print(f"  无法加载图像 {image_path}: {e}")
                image = torch.zeros(3, img_size, img_size)

            return {
                'input_ids': text_data['input_ids'],
                'attention_mask': text_data['attention_mask'],
                'image': image,
                'guid': guid
            }

    def fixed_predict_test_set(model, test_file, output_file,
                               data_dir='./data/data', batch_size=16):
        """Predict a sentiment tag for every guid in ``test_file`` and write a submission.

        Uses the notebook-level ``text_processor``, ``image_processor`` and
        ``device`` defined earlier. The output is a ``guid,tag`` file in the
        order the dataset parsed the guids from ``test_file``. A guid with no
        prediction (e.g. its batch failed) is written as 'neutral', and the
        number of such rows is reported.

        Args:
            model: returns ``(fused_logits, text_logits, image_logits)``.
            test_file: path to the unlabelled guid list (with header row).
            output_file: path of the submission file to write.
            data_dir: directory holding ``<guid>.txt`` / ``<guid>.jpg``.
            batch_size: inference batch size.

        Returns:
            dict: guid -> predicted tag ('positive' / 'neutral' / 'negative').
        """
        model.eval()

        test_dataset = FixedTestDataset(
            test_file=test_file,
            data_dir=data_dir,
            text_processor=text_processor,
            image_processor=image_processor
        )

        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
        )

        predictions = {}
        label_mapping = {0: 'positive', 1: 'neutral', 2: 'negative'}

        with torch.no_grad():
            for batch in tqdm(test_loader, desc='预测'):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                images = batch['image'].to(device)
                guids = batch['guid']

                fused_logits, _, _ = model(input_ids, attention_mask, images)
                _, preds = torch.max(fused_logits, 1)

                # .tolist() yields plain Python ints, matching label_mapping's keys.
                for guid, pred in zip(guids, preds.cpu().tolist()):
                    predictions[guid] = label_mapping.get(pred, 'neutral')

        # Write rows in the dataset's guid order instead of re-parsing test_file,
        # so reading and writing can never disagree about which guids exist.
        try:
            n_missing = 0
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write("guid,tag\n")
                for guid in test_dataset.guids:
                    tag = predictions.get(guid)
                    if tag is None:
                        n_missing += 1
                        tag = 'neutral'
                    f.write(f"{guid},{tag}\n")

            print(f"✅ 预测结果已保存到: {output_file}")
            print(f"预测样本数量: {len(predictions)}")
            if n_missing:
                print(f"⚠️ {n_missing} 个样本缺少预测，已填充为 neutral")
        except Exception as e:
            print(f"写入结果文件失败: {e}")

        return predictions

    # Run test-set prediction only when the unlabelled test list is available.
    test_file_path = './data/test_without_label.txt'
    if not os.path.exists(test_file_path):
        print("\n⚠️ 测试文件不存在，跳过预测步骤")
    else:
        print("\n开始预测测试集...")
        predictions = fixed_predict_test_set(
            model,
            test_file=test_file_path,
            output_file='./fixed_submission.txt',
        )

    # ========== 10. Experiment summary ==========
    print("\n" + "=" * 60)
    print("实验总结")
    print("=" * 60)

    # Report the configuration actually used, read from the live objects rather
    # than repeating literals that silently drift when the training cell changes.
    print(f"\n实验配置:")
    print(f"  - 训练样本: {len(train_dataset)}")
    print(f"  - 验证样本: {len(val_dataset)}")
    print(f"  - Batch Size: {train_loader.batch_size}")
    print(f"  - 学习率: {optimizer.defaults['lr']:g}")
    print(f"  - 融合方法: {model.fusion_method}")
    # `epoch` is the last loop index of the training loop (0-based).
    print(f"  - 训练轮数: {epoch + 1}")

    print(f"\n最佳结果:")
    print(f"  - 最佳验证准确率: {best_acc:.2f}%")

    print(f"\n生成文件:")
    print(f"  - 最佳模型: fixed_best_model.pth")
    print(f"  - 预测结果: fixed_submission.txt")

    print("\n✅ 实验完成！")
    print("=" * 60)

多模态情感分类实验 - 完整可运行版本

1. 环境设置...

2. 初始化BERT模型...

3. 初始化图像处理器...

4. 创建数据集...

5. 初始化模型...

6. 设置训练参数...

7. 开始主训练流程...
  尝试从本地加载BERT tokenizer...
  ✅ 本地BERT tokenizer加载成功

加载训练数据...
  ✅ 加载了 1000 个样本
训练集大小: 800
验证集大小: 200
  初始化文本编码器...
  初始化图像编码器...
  ✅ 模型初始化完成，融合方法: late

开始训练循环

Epoch 1/10
--------------------------------------------------


训练: 100%|██████████| 50/50 [00:28<00:00,  1.75it/s, loss=1.3449, acc=52.1%]
验证: 100%|██████████| 13/13 [00:02<00:00,  4.56it/s]



分类报告:
              precision    recall  f1-score   support

    positive     0.6224    0.9919    0.7649       123
     neutral     0.0000    0.0000    0.0000        21
    negative     0.5000    0.0357    0.0667        56

    accuracy                         0.6200       200
   macro avg     0.3741    0.3425    0.2772       200
weighted avg     0.5228    0.6200    0.4891       200


训练集 - Loss: 1.4091, Acc: 52.12%
验证集 - Loss: 0.8650, Acc: 62.00%
✅ 保存最佳模型，验证准确率: 62.00%

Epoch 2/10
--------------------------------------------------


训练: 100%|██████████| 50/50 [00:28<00:00,  1.76it/s, loss=1.4789, acc=66.6%]
验证: 100%|██████████| 13/13 [00:02<00:00,  4.72it/s]



分类报告:
              precision    recall  f1-score   support

    positive     0.6964    0.9512    0.8041       123
     neutral     0.0000    0.0000    0.0000        21
    negative     0.5312    0.3036    0.3864        56

    accuracy                         0.6700       200
   macro avg     0.4092    0.4183    0.3968       200
weighted avg     0.5771    0.6700    0.6027       200


训练集 - Loss: 1.1675, Acc: 66.62%
验证集 - Loss: 0.7902, Acc: 67.00%
✅ 保存最佳模型，验证准确率: 67.00%

Epoch 3/10
--------------------------------------------------


训练: 100%|██████████| 50/50 [00:27<00:00,  1.80it/s, loss=0.9718, acc=77.6%]
验证: 100%|██████████| 13/13 [00:02<00:00,  4.51it/s]



分类报告:
              precision    recall  f1-score   support

    positive     0.7591    0.8455    0.8000       123
     neutral     0.0000    0.0000    0.0000        21
    negative     0.5238    0.5893    0.5546        56

    accuracy                         0.6850       200
   macro avg     0.4276    0.4783    0.4515       200
weighted avg     0.6135    0.6850    0.6473       200


训练集 - Loss: 0.9458, Acc: 77.62%
验证集 - Loss: 0.7484, Acc: 68.50%
✅ 保存最佳模型，验证准确率: 68.50%

Epoch 4/10
--------------------------------------------------


训练: 100%|██████████| 50/50 [00:29<00:00,  1.69it/s, loss=0.7972, acc=86.2%]
验证: 100%|██████████| 13/13 [00:02<00:00,  4.54it/s]



分类报告:
              precision    recall  f1-score   support

    positive     0.7664    0.8537    0.8077       123
     neutral     0.6667    0.0952    0.1667        21
    negative     0.5333    0.5714    0.5517        56

    accuracy                         0.6950       200
   macro avg     0.6555    0.5068    0.5087       200
weighted avg     0.6907    0.6950    0.6687       200


训练集 - Loss: 0.7502, Acc: 86.25%
验证集 - Loss: 0.7529, Acc: 69.50%
✅ 保存最佳模型，验证准确率: 69.50%

Epoch 5/10
--------------------------------------------------


训练: 100%|██████████| 50/50 [00:27<00:00,  1.81it/s, loss=0.8536, acc=90.2%]
验证: 100%|██████████| 13/13 [00:02<00:00,  4.72it/s]



分类报告:
              precision    recall  f1-score   support

    positive     0.7669    0.8293    0.7969       123
     neutral     0.8000    0.1905    0.3077        21
    negative     0.5000    0.5536    0.5254        56

    accuracy                         0.6850       200
   macro avg     0.6890    0.5244    0.5433       200
weighted avg     0.6957    0.6850    0.6695       200


训练集 - Loss: 0.6298, Acc: 90.25%
验证集 - Loss: 0.7925, Acc: 68.50%

Epoch 6/10
--------------------------------------------------


训练: 100%|██████████| 50/50 [00:27<00:00,  1.81it/s, loss=0.5501, acc=92.6%]
验证: 100%|██████████| 13/13 [00:02<00:00,  4.72it/s]



分类报告:
              precision    recall  f1-score   support

    positive     0.7500    0.8537    0.7985       123
     neutral     0.6667    0.0952    0.1667        21
    negative     0.5263    0.5357    0.5310        56

    accuracy                         0.6850       200
   macro avg     0.6477    0.4949    0.4987       200
weighted avg     0.6786    0.6850    0.6572       200


训练集 - Loss: 0.5549, Acc: 92.62%
验证集 - Loss: 0.7902, Acc: 68.50%

Epoch 7/10
--------------------------------------------------


训练: 100%|██████████| 50/50 [00:27<00:00,  1.81it/s, loss=0.6914, acc=93.5%]
验证: 100%|██████████| 13/13 [00:02<00:00,  4.85it/s]



分类报告:
              precision    recall  f1-score   support

    positive     0.7630    0.8374    0.7984       123
     neutral     0.8333    0.2381    0.3704        21
    negative     0.5254    0.5536    0.5391        56

    accuracy                         0.6950       200
   macro avg     0.7072    0.5430    0.5693       200
weighted avg     0.7038    0.6950    0.6809       200


训练集 - Loss: 0.5031, Acc: 93.50%
验证集 - Loss: 0.7999, Acc: 69.50%
⚠️ 早停触发，连续3个epoch验证准确率未提升

消融实验
加载最佳模型进行消融实验...


消融实验: 100%|██████████| 13/13 [00:02<00:00,  4.75it/s]



消融实验结果:
仅文本模态准确率: 69.00%
仅图像模态准确率: 61.00%
多模态融合准确率: 69.50%
多模态相对提升: 0.50%

测试集预测

开始预测测试集...
  测试集GUID数量: 511


预测: 100%|██████████| 32/32 [00:07<00:00,  4.29it/s]

✅ 预测结果已保存到: ./fixed_submission.txt
预测样本数量: 511

实验总结

实验配置:
  - 训练样本: 800
  - 验证样本: 200
  - Batch Size: 16
  - 学习率: 2e-5
  - 融合方法: late
  - 训练轮数: 7

最佳结果:
  - 最佳验证准确率: 69.50%

生成文件:
  - 最佳模型: fixed_best_model.pth
  - 预测结果: fixed_submission.txt

✅ 实验完成！



