数据加载与预处理

In [None]:
import os
import cv2
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

class DoorplateDataset(Dataset):
    """House-number image dataset backed by an image folder and a JSON annotation file.

    Train/val mode (``is_test=False``): ``json_file`` maps each image file name to
    a dict with per-character lists ``label``, ``left``, ``top``, ``width`` and
    ``height``. Items are ``(image, seq, seq_len, bbox)`` where

    * ``seq`` is an int64 array of length ``max_seq_length`` holding the first
      ``seq_len`` labels, zero-padded;
    * ``bbox`` is a float32 ``[x, y, w, h]`` box enclosing *all* characters,
      clipped to the image and normalized by the image width/height.

    Test mode (``is_test=True``): every ``.png``/``.jpg`` file in ``data_dir`` is
    used and items are ``(image, img_name)``.

    ``transform`` (if given) is called as ``transform(image=...)['image']`` on the
    RGB image.
    """

    def __init__(self, data_dir, json_file, transform=None, max_seq_length=10, is_test=False):
        self.data_dir = data_dir
        self.transform = transform
        self.max_seq_length = max_seq_length
        self.is_test = is_test

        if not self.is_test:
            with open(json_file, 'r', encoding='utf-8') as f:
                self.annotations = json.load(f)
            self.image_files = list(self.annotations.keys())
        else:
            self.image_files = [f for f in os.listdir(data_dir) if f.endswith(('.png', '.jpg'))]

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _normalized_union_bbox(lefts, tops, widths, heights, img_w, img_h):
        """Return the box enclosing all character boxes as normalized float32 ``[x, y, w, h]``.

        The box is clipped to the image so every component lies in [0, 1], the
        range the model's sigmoid box head can produce. Returns zeros when
        there are no character boxes.
        """
        if not lefts:
            return np.zeros(4, dtype=np.float32)
        x0 = max(min(lefts), 0)
        y0 = max(min(tops), 0)
        x1 = min(max(left + bw for left, bw in zip(lefts, widths)), img_w)
        y1 = min(max(top + bh for top, bh in zip(tops, heights)), img_h)
        return np.array(
            [x0 / img_w, y0 / img_h, max(x1 - x0, 0) / img_w, max(y1 - y0, 0) / img_h],
            dtype=np.float32,
        )

    def _apply_transform(self, image):
        """Apply the optional albumentations-style transform to an RGB image."""
        if self.transform:
            return self.transform(image=image)['image']
        return image

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.data_dir, img_name)

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None; fail with a clear
            # message instead of an opaque cvtColor assertion.
            raise FileNotFoundError(f"Unable to read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.is_test:
            return self._apply_transform(image), img_name

        ann = self.annotations[img_name]
        labels = ann['label']

        # Zero-padded target sequence; CTC only reads the first seq_len entries.
        seq = np.zeros(self.max_seq_length, dtype=np.int64)
        seq_len = min(len(labels), self.max_seq_length)
        seq[:seq_len] = labels[:seq_len]

        # Normalize by the original size: the transform only resizes, so
        # normalized coordinates stay valid after it.
        h, w = image.shape[:2]
        bbox = self._normalized_union_bbox(ann['left'], ann['top'], ann['width'], ann['height'], w, h)

        return self._apply_transform(image), seq, seq_len, bbox

def get_transforms(is_train=True):
    """Return the albumentations pipeline used for training or evaluation.

    Both variants resize to 128x384 (height x width), normalize with the
    mean/std below and convert to a tensor via ``ToTensorV2``. The training
    variant additionally applies, with probability 0.5, one randomly chosen
    photometric augmentation (brightness/contrast, gamma or Gaussian noise)
    right after resizing.
    """
    pipeline = [A.Resize(height=128, width=384)]

    if is_train:
        photometric = [
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
            A.RandomGamma(),
            A.GaussNoise(),
        ]
        pipeline.append(A.OneOf(photometric, p=0.5))

    pipeline.extend([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    return A.Compose(pipeline)

def get_dataloader(data_dir, json_file, batch_size, is_train=True, is_test=False, num_workers=4):
    """Build a ``DataLoader`` over a ``DoorplateDataset``.

    Args:
        data_dir: folder containing the images.
        json_file: annotation file (ignored by the dataset when ``is_test``).
        batch_size: samples per batch.
        is_train: use training augmentations and shuffle the data.
        is_test: unlabeled test mode; always uses the deterministic evaluation
            transforms and keeps file order, regardless of ``is_train``.
        num_workers: worker processes for loading (default 4, as before).
    """
    # Random augmentation/shuffling only make sense for labeled training data;
    # previously a test loader built with the default is_train=True got both.
    augment = is_train and not is_test
    transform = get_transforms(augment)
    dataset = DoorplateDataset(data_dir, json_file, transform, is_test=is_test)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=augment,
        num_workers=num_workers,
        pin_memory=True
    )

工具函数

In [None]:
import os
import time
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

def _ctc_greedy_decode(path, blank):
    """Best-path CTC decoding of one argmax sequence.

    Consecutive repeats are merged first and blanks removed afterwards, so a
    blank between two equal symbols keeps both (``[1, blank, 1] -> [1, 1]``).
    """
    decoded = []
    prev = None
    for token in path:
        if token != prev and token != blank:
            decoded.append(token)
        prev = token
    return decoded


def _count_exact_matches(logits, targets, target_lengths, blank):
    """Count samples whose greedy CTC decoding equals their unpadded target.

    ``logits`` is indexed as (batch, time, classes). ``targets`` is the padded
    (batch, max_len) label matrix and ``target_lengths`` holds the true lengths.
    """
    paths = logits.argmax(dim=2).tolist()
    rows = targets.tolist()
    lengths = torch.as_tensor(target_lengths).tolist()
    return sum(
        _ctc_greedy_decode(path, blank) == row[:length]
        for path, row, length in zip(paths, rows, lengths)
    )


def _compute_loss(model, batch, criterion, device):
    """Forward one batch and return ``(loss, logits, targets, target_lengths)``.

    The loss is the CTC sequence loss plus an MSE loss on the predicted box.
    """
    images, targets, target_lengths, bbox_targets = batch
    images = images.to(device)
    targets = targets.to(device)
    bbox_targets = bbox_targets.to(device)

    logits, bbox_pred = model(images)

    batch_size, seq_len, _ = logits.size()
    # Every sample uses the full model output length as its CTC input length.
    input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)

    # nn.CTCLoss expects log-probabilities laid out as (time, batch, classes).
    log_probs = F.log_softmax(logits, dim=2).transpose(0, 1)
    ctc_loss = criterion(log_probs, targets, input_lengths, target_lengths)

    bbox_loss = F.mse_loss(bbox_pred, bbox_targets)
    return ctc_loss + bbox_loss, logits, targets, target_lengths


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one epoch.

    Args:
        model: returns ``(logits, bbox_pred)`` for a batch of images.
        dataloader: yields ``(images, targets, target_lengths, bbox_targets)``.
        criterion: CTC loss. Its ``blank`` attribute (as on ``nn.CTCLoss``) is
            also used when decoding predictions for accuracy (0 if absent).
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.

    Returns:
        ``(mean loss per batch, exact-match sequence accuracy)``. Both are 0.0
        for an empty dataloader.
    """
    model.train()
    # Decode with the blank index the loss was trained with. Class 0 is a real
    # digit, so it must not be stripped when the blank is, e.g., 10.
    blank = getattr(criterion, 'blank', 0)
    total_loss = 0.0
    correct = 0
    total_samples = 0

    pbar = tqdm(dataloader, desc="Training")
    for batch in pbar:
        loss, logits, targets, target_lengths = _compute_loss(model, batch, criterion, device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += _count_exact_matches(logits, targets, target_lengths, blank)
        total_samples += logits.size(0)

        pbar.set_postfix({"loss": loss.item(), "accuracy": correct / max(total_samples, 1)})

    return total_loss / max(len(dataloader), 1), correct / max(total_samples, 1)


def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` without gradient tracking.

    Takes the same batch format and ``criterion.blank`` convention as
    ``train_epoch``. Returns ``(mean loss per batch, exact-match sequence
    accuracy)``; both are 0.0 for an empty dataloader.
    """
    model.eval()
    blank = getattr(criterion, 'blank', 0)
    total_loss = 0.0
    correct = 0
    total_samples = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Evaluating")
        for batch in pbar:
            loss, logits, targets, target_lengths = _compute_loss(model, batch, criterion, device)

            total_loss += loss.item()
            correct += _count_exact_matches(logits, targets, target_lengths, blank)
            total_samples += logits.size(0)

            pbar.set_postfix({"val_loss": loss.item(), "val_accuracy": correct / max(total_samples, 1)})

    return total_loss / max(len(dataloader), 1), correct / max(total_samples, 1)

def decode_predictions(logits, image_size, blank=10):
    """Greedy (best-path) CTC decoding of a batch of logits.

    Args:
        logits: tensor indexed as (batch, seq_len, num_classes).
        image_size: unused; kept so existing callers keep working.
        blank: class index of the CTC blank. Defaults to 10 to match the
            ``nn.CTCLoss(blank=10)`` used for training. The old hard-coded 0
            is a real digit here, so it both dropped zeros and kept blanks.

    Returns:
        One list of Python ints per sample. Repeats are merged first, then
        blanks are removed.
    """
    decoded_preds = []
    for path in logits.argmax(dim=2).cpu().tolist():
        decoded = []
        prev_char = None
        for p in path:
            if p != blank and p != prev_char:
                decoded.append(p)
            prev_char = p
        decoded_preds.append(decoded)

    return decoded_preds

def predict(model, dataloader, device, output_dir):
    """Run inference over an unlabeled dataloader and save ``predictions.json``.

    Args:
        model: returns ``(logits, bbox_pred)``, with ``bbox_pred`` holding
            normalized ``[x, y, w, h]`` per image.
        dataloader: yields ``(images, img_names)``. Its ``dataset.data_dir``
            is used to re-read each original image for its size.
        device: device the images are moved to.
        output_dir: created if needed; ``predictions.json`` is written there.

    Returns:
        Dict mapping image name to ``{"label", "top", "left", "height",
        "width"}`` lists, with the box scaled to original pixel coordinates.

    Raises:
        FileNotFoundError: if an original image cannot be read.
    """
    # This utils module does not import json/cv2 at module level; import them
    # here so the function works when the file is used as a module.
    import json
    import cv2

    model.eval()
    predictions = {}

    with torch.no_grad():
        for images, img_names in tqdm(dataloader, desc="Predicting"):
            images = images.to(device)

            logits, bbox_pred = model(images)
            decoded = decode_predictions(logits, images.shape[-2:])
            boxes = bbox_pred.cpu().numpy()

            for i, img_name in enumerate(img_names):
                img_path = os.path.join(dataloader.dataset.data_dir, img_name)
                orig_img = cv2.imread(img_path)
                if orig_img is None:
                    raise FileNotFoundError(f"Unable to read image: {img_path}")
                orig_h, orig_w = orig_img.shape[:2]

                # Scale the normalized box back to original image pixels.
                x, y, width, height = boxes[i]
                x = int(x * orig_w)
                y = int(y * orig_h)
                width = int(width * orig_w)
                height = int(height * orig_h)

                # A repeated image name accumulates extra labels/boxes.
                entry = predictions.setdefault(
                    img_name, {"label": [], "top": [], "left": [], "height": [], "width": []}
                )
                entry["label"].extend(int(c) for c in decoded[i])
                entry["top"].append(y)
                entry["left"].append(x)
                entry["height"].append(height)
                entry["width"].append(width)

    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "predictions.json"), "w", encoding="utf-8") as f:
        json.dump(predictions, f)

    return predictions

模型定义

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class CRNN(nn.Module):
    """CRNN digit-sequence recognizer with an auxiliary bounding-box head.

    A torchvision ResNet-18 trunk (pretrained weights, final avgpool/fc
    removed) yields a 512-channel feature map that feeds two heads:

    * localization: conv blocks + MLP -> 4 sigmoid outputs (x, y, w, h) in [0, 1];
    * recognition: features averaged over height, read left-to-right by a
      bidirectional LSTM and projected to ``num_classes`` logits per step.

    ``forward`` returns ``(logits, bbox)``: logits are
    (batch, feature_width, num_classes) and bbox is (batch, 4).
    """

    def __init__(self, num_classes=10, rnn_hidden_size=256):
        super(CRNN, self).__init__()

        # ResNet-18 feature extractor without the final pooling and FC layers.
        # NOTE(review): torchvision>=0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``; kept for compatibility with older torchvision versions.
        resnet = torchvision.models.resnet18(pretrained=True)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Localization head. The last pool is adaptive so the flatten always sees
        # a 128x1x3 map. For 128x384 inputs (4x12 features -> 2x6 after the first
        # pool) it is identical to the former MaxPool2d(2, 2), but other input
        # sizes now work too (height >= 64 needed for the first pool).
        # Pooling has no parameters, so Sequential indices and checkpoint
        # state_dict keys are unchanged.
        self.localization = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool2d((1, 3)),
            nn.Flatten(),
            nn.Linear(128 * 1 * 3, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 4)  # x, y, width, height
        )

        # Collapse height to 1 while keeping the feature-map width as time axis.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, None))

        self.rnn = nn.LSTM(512, rnn_hidden_size, bidirectional=True, batch_first=True)

        # Bidirectional LSTM concatenates both directions -> 2 * hidden features.
        self.classifier = nn.Linear(rnn_hidden_size * 2, num_classes)

    def forward(self, x):
        features = self.backbone(x)

        # Box head; sigmoid keeps coordinates in [0, 1] (normalized).
        bbox = self.localization(features)
        bbox = torch.sigmoid(bbox)

        # (B, 512, H', W') -> (B, 512, 1, W') -> (B, W', 512)
        seq_features = self.avg_pool(features)
        seq_features = seq_features.squeeze(2)
        seq_features = seq_features.permute(0, 2, 1)

        rnn_out, _ = self.rnn(seq_features)

        # Per-time-step class logits: (B, W', num_classes)
        logits = self.classifier(rnn_out)

        return logits, bbox

模型训练

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model import CRNN
from data_loader import get_dataloader
from utils import train_epoch, evaluate

def _save_training_checkpoint(path, epoch, model, optimizer, val_loss, val_acc):
    """Save model/optimizer state and this epoch's validation metrics to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'val_acc': val_acc
    }, path)


def main():
    """Train the CRNN on the train split, validating after every epoch.

    Uses CTC loss with blank index 10 (classes 0-9 are digits), Adam with
    weight decay, and ReduceLROnPlateau on the validation loss. Stops early
    after ``early_stop_patience`` epochs without validation-loss improvement.
    ``best_model.pth`` is written whenever the validation loss improves;
    ``last_model.pth`` is written after every epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyper-parameters.
    num_epochs = 100
    batch_size = 32
    learning_rate = 0.0005
    early_stop_patience = 15

    # Dataset locations.
    data_root = '../tcdata'
    train_data_dir = os.path.join(data_root, 'mchar_train')
    train_json = os.path.join(data_root, 'train.json')
    val_data_dir = os.path.join(data_root, 'mchar_val')
    val_json = os.path.join(data_root, 'mchar_val.json')

    # Checkpoint locations.
    model_save_dir = '../user_data/model_data'
    os.makedirs(model_save_dir, exist_ok=True)
    best_ckpt = os.path.join(model_save_dir, 'best_model.pth')
    last_ckpt = os.path.join(model_save_dir, 'last_model.pth')

    train_loader = get_dataloader(train_data_dir, train_json, batch_size, is_train=True)
    val_loader = get_dataloader(val_data_dir, val_json, batch_size, is_train=False)

    # 11 classes: digits 0-9 plus the CTC blank.
    model = CRNN(num_classes=11).to(device)

    criterion = nn.CTCLoss(blank=10, reduction='mean')
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    best_val_loss = float('inf')
    stale_epochs = 0  # consecutive epochs without validation-loss improvement

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        scheduler.step(val_loss)

        if val_loss >= best_val_loss:
            stale_epochs += 1
        else:
            best_val_loss = val_loss
            stale_epochs = 0
            _save_training_checkpoint(best_ckpt, epoch, model, optimizer, val_loss, val_acc)
            print("保存最佳模型！")

        _save_training_checkpoint(last_ckpt, epoch, model, optimizer, val_loss, val_acc)

        if stale_epochs >= early_stop_patience:
            print(f"\n早停：验证损失连续{early_stop_patience}轮没有改善")
            break

    print("\n训练完成！")
    print(f"最佳验证损失: {best_val_loss:.4f}")
    print(f"最后模型保存在: {last_ckpt}")
    print(f"最佳模型保存在: {best_ckpt}")

if __name__ == "__main__":
    main()

模型推理与测试

In [None]:
import os
import json
import cv2
import torch
import numpy as np
from model import CRNN
from data_loader import get_transforms
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

def check_image_quality(image):
    """Run a fixed sequence of quality gates on an RGB image.

    The gates are, in order: readable (not None), minimum size of 16 px per
    side, sharpness (variance of the Laplacian >= 30), brightness (mean gray
    level within [20, 230]) and contrast (gray-level std >= 20). The image is
    converted with ``COLOR_RGB2GRAY``, so a 3-channel RGB array is expected.

    Returns:
        ``(passed, reason)``: ``reason`` names the first failing gate, or
        says the image is fine.
    """
    if image is None:
        return False, "图像读取失败"

    height, width = image.shape[:2]
    if min(height, width) < 16:
        return False, f"图像尺寸过小: {height}x{width}"

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()
    if sharpness < 30:
        return False, f"图像模糊: {sharpness:.2f}"

    mean_level = np.mean(gray)
    if mean_level < 20:
        return False, f"图像过暗: {mean_level:.2f}"
    if mean_level > 230:
        return False, f"图像过亮: {mean_level:.2f}"

    spread = np.std(gray)
    if spread < 20:
        return False, f"图像对比度过低: {spread:.2f}"

    return True, "图像质量正常"

def preprocess_image(image):
    """Apply light corrections to an RGB image before quality checks.

    * Images smaller than 16 px on a side are upscaled (aspect ratio kept) so
      the shorter side reaches 16 px.
    * Mean gray level < 20: brightened with ``alpha=1.5, beta=30``;
      > 230: darkened with ``alpha=0.8``.
    * Gray-level std < 20 (measured before the brightness change): replaced by
      its histogram-equalized grayscale version, converted back to 3 channels.

    Returns ``None`` for a ``None`` input, otherwise the corrected image.
    """
    if image is None:
        return None

    rows, cols = image.shape[:2]
    if min(rows, cols) < 16:
        factor = max(16 / rows, 16 / cols)
        image = cv2.resize(image, (int(cols * factor), int(rows * factor)))

    luminance = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    mean_level, spread = np.mean(luminance), np.std(luminance)

    if mean_level < 20:
        image = cv2.convertScaleAbs(image, alpha=1.5, beta=30)
    elif mean_level > 230:
        image = cv2.convertScaleAbs(image, alpha=0.8, beta=0)

    if spread < 20:
        equalized = cv2.equalizeHist(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY))
        image = cv2.cvtColor(equalized, cv2.COLOR_GRAY2RGB)

    return image

class TestDataset(Dataset):
    """Inference dataset over numerically named ``.png``/``.jpg`` test images.

    Files whose stem is not an integer are skipped with a warning. The rest
    are ordered by that integer; ``max_samples`` optionally truncates the list.
    Images that fail the quality check are saved to a ``debug_images`` folder
    next to ``data_dir``.

    Items are ``(image, img_name, (orig_h, orig_w), is_valid)``:

    * ``image`` has been passed through ``transform``;
    * ``(orig_h, orig_w)`` is the size of the file on disk;
    * ``is_valid`` is the quality-check result.

    Unreadable or failing images yield a transformed black placeholder with
    ``is_valid=False``. It has the same type as real items, so batches can
    still be stacked.
    """

    # Placeholder size; matches the Resize(128, 384) used by the transforms here.
    PLACEHOLDER_SHAPE = (128, 384, 3)

    def __init__(self, data_dir, transform=None, max_samples=None):
        self.data_dir = data_dir
        self.transform = transform

        self.image_files = []
        for f in os.listdir(data_dir):
            if not f.endswith(('.png', '.jpg')):
                continue
            try:
                int(f.split('.')[0])  # validate the numeric file-name format
            except ValueError:
                print(f"警告：跳过无效的文件名 {f}")
                continue
            self.image_files.append(f)
        self.image_files.sort(key=lambda x: int(x.split('.')[0]))

        if max_samples is not None:
            self.image_files = self.image_files[:max_samples]

        # normpath drops a trailing separator so dirname() really is the parent.
        self.debug_dir = os.path.join(os.path.dirname(os.path.normpath(data_dir)), 'debug_images')
        os.makedirs(self.debug_dir, exist_ok=True)

        print(f"成功加载 {len(self.image_files)} 张图片")

    def __len__(self):
        return len(self.image_files)

    def _placeholder(self, img_name):
        """Return a black, invalid item transformed exactly like real images."""
        image = np.zeros(self.PLACEHOLDER_SHAPE, dtype=np.uint8)
        if self.transform:
            # Without this, the raw HWC array could not be stacked with the
            # transformed tensors of the rest of the batch.
            image = self.transform(image=image)['image']
        height, width = self.PLACEHOLDER_SHAPE[:2]
        return image, img_name, (height, width), False

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.data_dir, img_name)

        try:
            image = cv2.imread(img_path)
            if image is None:
                print(f"警告：无法读取图像 {img_name}")
                return self._placeholder(img_name)

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Record the on-disk size before preprocessing, which may upscale
            # tiny images, so boxes are scaled back to the original file.
            orig_h, orig_w = image.shape[:2]

            image = preprocess_image(image)
            if image is None:
                print(f"警告：图像 {img_name} 预处理失败")
                return self._placeholder(img_name)

            is_valid, reason = check_image_quality(image)
            if not is_valid:
                print(f"警告：图像 {img_name} 质量检查未通过 - {reason}")
                debug_path = os.path.join(self.debug_dir, f"invalid_{img_name}")
                cv2.imwrite(debug_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
                # Second preprocessing pass, then re-check.
                image = preprocess_image(image)
                if image is not None:
                    is_valid, _ = check_image_quality(image)

            if self.transform:
                image = self.transform(image=image)['image']

            return image, img_name, (orig_h, orig_w), is_valid

        except Exception as e:
            print(f"处理图像 {img_name} 时出错: {str(e)}")
            return self._placeholder(img_name)

def decode_predictions(pred, blank=10):
    """Best-path CTC decoding of a batch of logits.

    For every sample, take the argmax class at each time step, keep a class
    only when it differs from the previous step's class, and drop the
    ``blank`` class. ``pred`` is indexed as (batch, time, classes).

    Returns one list of Python ints per sample.
    """
    best_path = pred.max(2)[1].cpu().numpy()

    decoded_batch = []
    for row in best_path.tolist():
        kept = [
            token
            for pos, token in enumerate(row)
            if token != blank and (pos == 0 or token != row[pos - 1])
        ]
        decoded_batch.append(kept)

    return decoded_batch

def _empty_prediction():
    """Return the placeholder entry stored for images without a usable prediction."""
    return {
        "label": [],
        "top": [0],
        "left": [0],
        "height": [0],
        "width": [0]
    }


def _numeric_sort_key(name):
    """Order file names by numeric stem; non-numeric names sort last, lexically."""
    stem = name.split('.')[0]
    try:
        return (0, int(stem), name)
    except ValueError:
        return (1, 0, name)


def _collate_test_batch(items):
    """Collate TestDataset items into (image batch, names, original sizes, validity flags)."""
    return (
        torch.stack([item[0] for item in items]),  # images
        [item[1] for item in items],  # file names
        [item[2] for item in items],  # original (h, w)
        torch.tensor([item[3] for item in items]),  # quality-check flags
    )


def main():
    """Run the best checkpoint over the test folder and write predictions + stats.

    Writes to ``../prediction_result``:

    * ``predictions.json``: per-image ``label``/``top``/``left``/``height``/``width``.
      Every ``.png`` in the test folder gets an entry; images that could not be
      predicted get an empty placeholder.
    * ``stats.json``: counts and file lists for valid, low-quality, failed and
      unprocessed images.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_data_dir = '../tcdata/mchar_test_a'
    model_path = '../user_data/model_data/best_model.pth'
    output_dir = '../prediction_result'
    os.makedirs(output_dir, exist_ok=True)

    # All test images; a non-numeric file name no longer crashes the sort.
    all_image_files = sorted(
        [f for f in os.listdir(test_data_dir) if f.endswith('.png')],
        key=_numeric_sort_key
    )
    print(f"找到 {len(all_image_files)} 张测试图片")

    if not os.path.exists(model_path):
        print(f"错误：模型文件 {model_path} 不存在")
        return

    # Inference must be deterministic, so use the evaluation pipeline
    # (Resize + Normalize + ToTensorV2). The previous random brightness/noise/
    # blur augmentations made predictions change from run to run.
    transform = get_transforms(is_train=False)

    print("加载测试数据...")
    test_dataset = TestDataset(test_data_dir, transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=0,  # 0 to avoid multiprocessing issues
        pin_memory=device.type == 'cuda',  # pinning only helps host->GPU copies
        drop_last=False,  # keep the final partial batch
        collate_fn=_collate_test_batch
    )

    print("加载模型...")
    model = CRNN(num_classes=11)  # digits 0-9 plus the CTC blank
    try:
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(device)
        model.eval()
        print(f"成功加载最佳模型，验证损失: {checkpoint['val_loss']:.4f}, 验证准确率: {checkpoint['val_acc']:.4f}")
    except Exception as e:
        print(f"加载模型时出错: {str(e)}")
        return

    predictions = {}
    processed_files = set()
    invalid_files = set()
    error_files = set()

    print("开始预测...")
    with torch.no_grad():
        for batch_idx, (images, img_names, orig_sizes, is_valid) in enumerate(tqdm(test_loader, desc="Predicting")):
            try:
                images = images.to(device)
                logits, bbox_preds = model(images)

                # Sanity-check that outputs line up with the inputs.
                batch_size = images.size(0)
                if logits.size(0) != batch_size or bbox_preds.size(0) != batch_size:
                    print(f"警告：模型输出批次大小不一致，输入: {batch_size}, logits: {logits.size(0)}, bbox: {bbox_preds.size(0)}")
                    continue

                batch_labels = decode_predictions(logits)
                if len(batch_labels) != batch_size:
                    print(f"警告：解码后的标签数量不一致，预期: {batch_size}, 实际: {len(batch_labels)}")
                    continue

                boxes = bbox_preds.cpu().numpy()
                for i in range(batch_size):
                    img_name = img_names[i]
                    try:
                        orig_h, orig_w = orig_sizes[i][0], orig_sizes[i][1]

                        # Boxes are normalized; scale back to original pixels.
                        bx, by, bw, bh = boxes[i]
                        predictions[img_name] = {
                            "label": batch_labels[i],
                            "top": [int(by * orig_h)],
                            "left": [int(bx * orig_w)],
                            "height": [int(bh * orig_h)],
                            "width": [int(bw * orig_w)]
                        }

                        if bool(is_valid[i]):
                            processed_files.add(img_name)
                        else:
                            invalid_files.add(img_name)
                            print(f"图像 {img_name} 质量检查未通过")

                    except Exception as e:
                        print(f"处理图像 {img_name} 时出错: {str(e)}")
                        error_files.add(img_name)
                        predictions[img_name] = _empty_prediction()

            except Exception as e:
                print(f"批次 {batch_idx} 处理出错: {str(e)}")
                for img_name in img_names:
                    error_files.add(img_name)
                    predictions[img_name] = _empty_prediction()

    # Every test image must appear in the output, even if never predicted.
    missing_files = set(all_image_files) - processed_files - invalid_files - error_files
    if missing_files:
        print(f"\n警告：有 {len(missing_files)} 张图片未被处理：")
        for f in sorted(missing_files, key=_numeric_sort_key):
            print(f"  {f}")
            predictions[f] = _empty_prediction()

    print("保存预测结果...")
    with open(os.path.join(output_dir, "predictions.json"), "w", encoding="utf-8") as f:
        json.dump(predictions, f)

    print("\n测试完成！统计信息：")
    print(f"总图片数: {len(all_image_files)}")
    print(f"成功预测: {len(processed_files)}")
    print(f"质量不合格: {len(invalid_files)}")
    print(f"处理出错: {len(error_files)}")
    print(f"未处理: {len(missing_files)}")

    # Sorted file lists keep stats.json deterministic across runs.
    stats = {
        "total_images": len(all_image_files),
        "successful_predictions": len(processed_files),
        "invalid_quality": len(invalid_files),
        "processing_errors": len(error_files),
        "missing_files": len(missing_files),
        "invalid_files": sorted(invalid_files, key=_numeric_sort_key),
        "error_files": sorted(error_files, key=_numeric_sort_key),
        "missing_files_list": sorted(missing_files, key=_numeric_sort_key)
    }
    with open(os.path.join(output_dir, "stats.json"), "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=4)

if __name__ == "__main__":
    main()

生成csv

In [None]:
import json
import csv
import os

def generate_csv(pred_file, test_dir, output_file):
    """Write a ``file_name,file_code`` submission CSV from a predictions JSON.

    Every ``.png`` in ``test_dir`` gets exactly one row. Rows are ordered by the
    numeric value of the file stem; non-numeric names come last, in lexical
    order. ``file_code`` is the concatenated ``label`` digits, or empty when
    the image has no prediction.
    """
    with open(pred_file, 'r', encoding='utf-8') as f:
        predictions = json.load(f)

    def sort_key(name):
        stem = name.split('.')[0]
        try:
            return (0, int(stem), name)
        except ValueError:
            return (1, 0, name)

    all_images = sorted(
        (f for f in os.listdir(test_dir) if f.endswith('.png')),
        key=sort_key
    )

    # os.path.dirname('result.csv') is '' and os.makedirs('') raises, so only
    # create a parent directory when the path actually has one.
    output_parent = os.path.dirname(output_file)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)

    processed_count = 0
    empty_count = 0
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['file_name', 'file_code'])

        for img_name in all_images:
            pred = predictions.get(img_name)
            if pred is not None:
                code = ''.join(str(digit) for digit in (pred.get('label') or []))
                processed_count += 1
            else:
                code = ''
                empty_count += 1
            writer.writerow([img_name, code])

    print(f'CSV文件已生成：{output_file}')
    print(f'总图片数：{len(all_images)}')
    print(f'有预测结果：{processed_count} 张')
    print(f'无预测结果：{empty_count} 张')

if __name__ == '__main__':
    pred_file = '../prediction_result/predictions.json'
    test_dir = '../tcdata/mchar_test_a'
    output_file = '../prediction_result/result.csv'

    generate_csv(pred_file, test_dir, output_file)

预测结果检查

In [None]:
import json

# Load the saved predictions and print a short preview of them.
with open('../prediction_result/predictions.json', 'r', encoding='utf-8') as f:
    predictions = json.load(f)

if not predictions:
    # An empty file would otherwise fail with IndexError on the first lookup.
    print("预测结果为空")
else:
    # Full entry for the first image.
    first_img = next(iter(predictions))
    print(f"第一张图片 {first_img} 的预测结果：")
    print(json.dumps(predictions[first_img], indent=2, ensure_ascii=False))

    # Box count and labels for the first five images.
    print("\n前5张图片的边界框数量：")
    for img_name, pred in list(predictions.items())[:5]:
        print(f"{img_name}: {len(pred['left'])} 个边界框，标签：{pred['label']}")

可视化预测结果

In [None]:
import os
import cv2
import json
import numpy as np
from pathlib import Path

def visualize_predictions(test_dir, pred_file, output_dir, num_samples=10):
    """Draw predicted boxes and digit strings on the first ``num_samples`` predictions.

    Reads the predictions JSON (``left``, ``top``, ``width``, ``height`` and
    ``label`` per image; only the first box is drawn). Annotated copies are
    written as ``vis_<img_name>`` into ``output_dir``. Images that cannot be
    read from ``test_dir`` are skipped with a message.
    """
    with open(pred_file, 'r', encoding='utf-8') as f:
        predictions = json.load(f)

    os.makedirs(output_dir, exist_ok=True)

    box_color = (0, 255, 0)  # green box and label background
    text_color = (255, 255, 255)  # white text
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.9
    thickness = 2

    samples = list(predictions.items())[:num_samples]
    for i, (img_name, pred) in enumerate(samples):
        img_path = os.path.join(test_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None for missing/unreadable files; drawing would raise.
            print(f'跳过无法读取的图片: {img_path}')
            continue

        x = int(pred['left'][0])
        y = int(pred['top'][0])
        w = int(pred['width'][0])
        h = int(pred['height'][0])
        label_str = ''.join(str(digit) for digit in pred['label'])

        cv2.rectangle(image, (x, y), (x + w, y + h), box_color, 2)

        (text_w, text_h), _ = cv2.getTextSize(label_str, font, font_scale, thickness)

        # The label sits on a filled strip just above the box. Clamp the strip
        # to the image so the text stays visible when the box touches the top
        # or left edge; otherwise the layout is the same as before.
        strip_x = max(x, 0)
        strip_top = max(y - text_h - 10, 0)
        strip_bottom = strip_top + text_h + 10
        cv2.rectangle(image, (strip_x, strip_top), (strip_x + text_w, strip_bottom), box_color, -1)
        cv2.putText(image, label_str, (strip_x, strip_bottom - 5), font, font_scale, text_color, thickness)

        output_path = os.path.join(output_dir, f'vis_{img_name}')
        cv2.imwrite(output_path, image)

        print(f'已处理图片 {i+1}/{len(samples)}: {img_name}，检测到数字：{label_str}')

if __name__ == '__main__':
    test_dir = '../tcdata/mchar_test_a'
    pred_file = '../prediction_result/predictions.json'
    output_dir = '../visualization_result'

    visualize_predictions(test_dir, pred_file, output_dir, num_samples=10)
    print('可视化完成！结果保存在:', output_dir)