In [1]:
import os
from glob import glob
import torch
from PIL import Image
import torch.nn as nn
from scipy.cluster.hierarchy import weighted
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, MultiStepLR
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import pandas as pd
import numpy as np
import torch.nn.functional as F
import json
from torchvision.models.mobilenet import mobilenet_v2, MobileNet_V2_Weights
from torchvision.models.resnet import resnet50, resnet18, resnet101, ResNet18_Weights, ResNet50_Weights, \
    ResNet101_Weights
from torchsummary import summary
import random
from code.DigitsDataset import DigitsDataset
from code.jessie_utils import data_dir
from code.Config import Config

# Echo the key settings of the imported Config class so the run log records them.
print("load config:")
print("batch_size: ", Config.batch_size)
print("start_epoch", Config.start_epoch)

load config:
batch_size:  64
start_epoch 0


In [2]:
# Hyper-parameter settings (note: the Config class lives in a separate module, imported from code.Config above)
# class Config:
# ....


# Module-level instance read throughout the notebook as `config`.
config = Config()

In [3]:
# # 定义数据 , Windows 可能会导致一些错误，需要将它封装在单独的文件中
# class DigitsDataset(Dataset):
#     """
#
#     DigitsDataset
#
#     Params:
#       data_dir(string): data directory
#
#       label_path(string): label path
#
#       aug(bool): whether to apply image augmentation, default: True
#     """


In [4]:
# 自定义网络
# 以ResNet101 为主干网络（下方另有 ResNet50 / ResNet18 / MobileNetV2 版本）
class DigitsResNet101(nn.Module):
    """Four-head digit-sequence classifier on an ImageNet-pretrained ResNet-101 backbone.

    The backbone is torchvision's ResNet-101 with its last child (the final fully
    connected layer) removed. Every character position has its own head:
    Linear(2048 -> 128) -> Dropout(0.25) -> Linear(128 -> config.class_num).
    forward() returns a tuple of four logit tensors, one per position.
    """

    def __init__(self):
        super().__init__()
        feature_dim, hidden_dim, drop_p = 2048, 128, 0.25

        # Shared feature extractor: all children of ResNet-101 except the last one.
        backbone = resnet101(weights=ResNet101_Weights.IMAGENET1K_V2)
        self.net = nn.Sequential(*list(backbone.children())[:-1])
        # The same module registered under a second name; forward() calls it as `cnn`.
        # NOTE(review): saved checkpoints likely hold parameters under both names - keep both attributes.
        self.cnn = self.net

        # Per-position bottleneck layers (one per character slot).
        self.hd_fc1 = nn.Linear(feature_dim, hidden_dim)
        self.hd_fc2 = nn.Linear(feature_dim, hidden_dim)
        self.hd_fc3 = nn.Linear(feature_dim, hidden_dim)
        self.hd_fc4 = nn.Linear(feature_dim, hidden_dim)

        # Per-position dropout to reduce overfitting.
        self.dropout_1 = nn.Dropout(drop_p)
        self.dropout_2 = nn.Dropout(drop_p)
        self.dropout_3 = nn.Dropout(drop_p)
        self.dropout_4 = nn.Dropout(drop_p)

        # Per-position classifiers; config.class_num is expected to be 11 (digits 0-9 plus a blank symbol).
        self.fc1 = nn.Linear(hidden_dim, config.class_num)
        self.fc2 = nn.Linear(hidden_dim, config.class_num)
        self.fc3 = nn.Linear(hidden_dim, config.class_num)
        self.fc4 = nn.Linear(hidden_dim, config.class_num)

    def forward(self, img):
        """Return (c1, c2, c3, c4): one (batch, config.class_num) logit tensor per position."""
        pooled = self.cnn(img)
        flat = pooled.view(pooled.shape[0], -1)  # flatten to (batch, 2048)

        branches = (
            (self.hd_fc1, self.dropout_1, self.fc1),
            (self.hd_fc2, self.dropout_2, self.fc2),
            (self.hd_fc3, self.dropout_3, self.fc3),
            (self.hd_fc4, self.dropout_4, self.fc4),
        )
        # Branches are evaluated in position order 1..4.
        return tuple(classify(drop(bottleneck(flat))) for bottleneck, drop, classify in branches)


class DigitsResNet50(nn.Module):
    """Four-head digit-sequence classifier on an ImageNet-pretrained ResNet-50 backbone.

    The backbone is torchvision's ResNet-50 with its last child (the final fully
    connected layer) removed. Every character position has its own head:
    Linear(2048 -> 128) -> Dropout(0.25) -> Linear(128 -> config.class_num).
    forward() returns a tuple of four logit tensors, one per position.
    """

    def __init__(self):
        super().__init__()
        feature_dim, hidden_dim, drop_p = 2048, 128, 0.25

        # ResNet-50 weights via the `weights=` API (`pretrained=True` is deprecated in torchvision).
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Keep every child except the last one (the original classification layer).
        self.net = nn.Sequential(*list(backbone.children())[:-1])
        # Registered but never used in forward(); kept so the module structure is unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # The same module under a second name; forward() calls it as `cnn`.
        self.cnn = self.net

        # Per-position bottleneck layers (one per character slot).
        self.hd_fc1 = nn.Linear(feature_dim, hidden_dim)
        self.hd_fc2 = nn.Linear(feature_dim, hidden_dim)
        self.hd_fc3 = nn.Linear(feature_dim, hidden_dim)
        self.hd_fc4 = nn.Linear(feature_dim, hidden_dim)

        # Per-position dropout to reduce overfitting.
        self.dropout_1 = nn.Dropout(drop_p)
        self.dropout_2 = nn.Dropout(drop_p)
        self.dropout_3 = nn.Dropout(drop_p)
        self.dropout_4 = nn.Dropout(drop_p)

        # Per-position classifiers; config.class_num is expected to be 11 (digits 0-9 plus a blank symbol).
        self.fc1 = nn.Linear(hidden_dim, config.class_num)
        self.fc2 = nn.Linear(hidden_dim, config.class_num)
        self.fc3 = nn.Linear(hidden_dim, config.class_num)
        self.fc4 = nn.Linear(hidden_dim, config.class_num)

    def forward(self, img):
        """Return (c1, c2, c3, c4): one (batch, config.class_num) logit tensor per position."""
        pooled = self.cnn(img)
        flat = pooled.view(pooled.shape[0], -1)  # flatten to (batch, 2048)

        branches = (
            (self.hd_fc1, self.dropout_1, self.fc1),
            (self.hd_fc2, self.dropout_2, self.fc2),
            (self.hd_fc3, self.dropout_3, self.fc3),
            (self.hd_fc4, self.dropout_4, self.fc4),
        )
        # Branches are evaluated in position order 1..4.
        return tuple(classify(drop(bottleneck(flat))) for bottleneck, drop, classify in branches)


class DigitsResNet18(nn.Module):
    """Four-head digit-sequence classifier on an ImageNet-pretrained ResNet-18 backbone.

    ResNet-18's final fully connected layer is replaced by nn.Identity. The resulting
    512-dimensional features are batch-normalized, and four independent linear
    classifiers (one per character position) produce the logits.
    """

    def __init__(self):
        super().__init__()
        feature_dim = 512

        backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # Drop the ImageNet classifier so the network emits its feature vector directly.
        backbone.fc = nn.Identity()
        self.net = backbone
        # Batch normalization of the feature vector before the classifiers.
        self.bn = nn.BatchNorm1d(feature_dim)

        # One independent classifier per character position.
        self.fc1 = nn.Linear(feature_dim, config.class_num)
        self.fc2 = nn.Linear(feature_dim, config.class_num)
        self.fc3 = nn.Linear(feature_dim, config.class_num)
        self.fc4 = nn.Linear(feature_dim, config.class_num)

    def forward(self, img):
        """Return (fc1, fc2, fc3, fc4): one logit tensor per character position."""
        embedding = self.net(img)
        embedding = self.bn(embedding.view(embedding.shape[0], -1))
        return tuple(head(embedding) for head in (self.fc1, self.fc2, self.fc3, self.fc4))


# 以MobileNetV2为主干网络
class DigitsMobileNet(nn.Module):
    """Four-head digit-sequence classifier on an ImageNet-pretrained MobileNetV2 backbone.

    Args:
        class_num: number of classes per character position (default 11).
    """

    def __init__(self, class_num=11):
        super().__init__()
        channels = 1280

        # Convolutional feature extractor of MobileNetV2 (its classifier head is not used).
        self.net = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1).features
        # Global average pooling down to a 1x1 spatial map.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Batch normalization of the pooled feature vector.
        self.bn = nn.BatchNorm1d(channels)

        # One independent classifier per character position, all reading the same features.
        self.fc1 = nn.Linear(channels, class_num)
        self.fc2 = nn.Linear(channels, class_num)
        self.fc3 = nn.Linear(channels, class_num)
        self.fc4 = nn.Linear(channels, class_num)

    def forward(self, img):
        """Return (fc1, fc2, fc3, fc4): one (batch, class_num) logit tensor per position."""
        pooled = self.avgpool(self.net(img))        # (batch, 1280, 1, 1)
        embedding = self.bn(pooled.view(-1, 1280))  # flatten to (batch, 1280), then normalize
        return tuple(head(embedding) for head in (self.fc1, self.fc2, self.fc3, self.fc4))

In [5]:
# Instantiate the ResNet-101 model once as a construction check (this builds the pretrained backbone).
# NOTE(review): other cells may reference the global `net`; check before removing this line.
net = DigitsResNet101()
# print(net)

In [6]:
# 标签平滑
# ----------------------------------- LabelSmoothEntropy ----------------------------------- #
class LabelSmoothEntropyLoss(nn.Module):
    """Cross-entropy against label-smoothed targets.

    The target distribution puts ``1 - smooth`` on the true class and spreads
    ``smooth`` evenly over the remaining ``num_classes - 1`` classes, so every row
    sums to 1.

    Args:
        smooth: smoothing amount (0.1 is the usual choice).
        class_weights: optional per-class weights of length num_classes (tensor or
            sequence). Stored as a buffer so ``.to(device)`` moves it with the module.
        size_average: batch reduction, 'mean' or 'sum'.
    """

    def __init__(self, smooth=0.1, class_weights=None, size_average='mean'):
        super().__init__()
        self.size_average = size_average
        self.smooth = smooth
        if class_weights is not None and not isinstance(class_weights, torch.Tensor):
            class_weights = torch.tensor(class_weights, dtype=torch.float32)
        # A buffer (possibly None) follows the module across devices.
        self.register_buffer('class_weights', class_weights)

    def forward(self, preds, targets):
        """
        Args:
            preds: raw (unnormalized) logits of shape (batch_size, num_classes).
            targets: integer class indices of shape (batch_size,).

        Returns:
            Scalar loss reduced over the batch according to ``size_average``.

        Raises:
            NotImplementedError: if size_average is neither 'mean' nor 'sum'.
        """
        # Classes live on dim 1; dim 0 is the batch.
        num_classes = preds.shape[1]
        lb_pos = 1 - self.smooth
        # max(..., 1) guards the degenerate single-class case against division by zero.
        lb_neg = self.smooth / max(num_classes - 1, 1)

        # Smoothed target distribution: lb_neg everywhere, lb_pos at the true class.
        smoothed_lb = torch.full_like(preds, lb_neg)
        smoothed_lb.scatter_(1, targets[:, None], lb_pos)

        # log_softmax is the numerically stable log-probability.
        log_soft = F.log_softmax(preds, dim=1)
        loss = -log_soft * smoothed_lb
        if self.class_weights is not None:
            # Broadcast the (num_classes,) weights over the batch.
            loss = loss * self.class_weights[None, :]

        # Per-sample loss is the sum over classes.
        loss = loss.sum(1)

        if self.size_average == 'mean':
            return loss.mean()
        elif self.size_average == 'sum':
            return loss.sum()
        else:
            raise NotImplementedError("size_average must be 'mean' or 'sum'")

In [7]:

# 定义训练模型
class Trainer:
    """Training / validation driver for the four-position digit recognizer.

    Everything is configured from the module-level ``config``: data loaders, the
    DigitsResNet101 model, the label-smoothing loss, Adam and a
    CosineAnnealingWarmRestarts scheduler. Accuracy is sequence accuracy: a sample
    counts as correct only when all four predicted positions match its label.

    Args:
        val: build a validation loader and select checkpoints by validation accuracy.
            When False, the epoch's training accuracy is used for checkpoint selection.
    """

    def __init__(self, val=True):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.train_set = DigitsDataset(mode='train')
        # pin_memory speeds up host-to-GPU copies; drop_last discards the final incomplete batch;
        # collate_fn is the dataset's own batch-merging function.
        self.train_loader = DataLoader(self.train_set, batch_size=config.batch_size, shuffle=True, num_workers=6,
                                       pin_memory=True, drop_last=True, collate_fn=self.train_set.collect_fn)

        if val:
            self.val_loader = DataLoader(DigitsDataset(mode='val', aug=False), batch_size=config.batch_size,
                                         num_workers=6, pin_memory=True, drop_last=False)
        else:
            self.val_loader = None

        # Other backbones defined above: DigitsResNet50(), DigitsResNet18(), DigitsMobileNet().
        self.model = DigitsResNet101().to(self.device)

        # Label-smoothed cross-entropy, applied to each of the four positions.
        self.criterion = LabelSmoothEntropyLoss().to(self.device)

        # Alternative: SGD(self.model.parameters(), lr=config.lr, momentum=config.momentum,
        #                  weight_decay=config.weights_decay, nesterov=True)
        # Adam: betas are the decay rates of the running first moment (gradient mean) and second
        # moment (squared gradient); eps keeps the update denominator away from zero; amsgrad=False
        # selects plain Adam rather than the AMSGrad variant.
        self.optimizer = Adam(self.model.parameters(), lr=config.lr, betas=(0.9, 0.999), eps=1e-08,
                              weight_decay=config.weights_decay, amsgrad=False)

        # Cosine annealing with warm restarts: the first cycle lasts T_0=10 scheduler steps, each later
        # cycle is T_mult=2 times longer, and eta_min is the learning-rate floor. train_epoch() calls
        # step() once every config.print_interval batches, so cycle lengths are in those units.
        # Alternative: MultiStepLR(self.optimizer, [10, 20, 30], 0.5)
        self.lr_scheduler = CosineAnnealingWarmRestarts(self.optimizer, T_0=10, T_mult=2, eta_min=0)
        self.best_acc = 0

        # Optionally resume from a checkpoint.
        if config.pretrained is not None:
            self.load_model(config.pretrained)
            if self.val_loader is not None:
                acc = self.eval()
                self.best_acc = acc
                print('Load model from %s, Eval Acc: %.2f' % (config.pretrained, acc * 100))
            else:
                # No validation loader, so there is no accuracy to report.
                print('Load model from %s' % config.pretrained)

    def train(self):
        """Main loop over epochs [config.start_epoch, config.epoches).

        Every config.eval_interval epochs, the current accuracy is compared with best_acc.
        That is validation accuracy when a validation loader exists, otherwise the epoch's
        training accuracy. An improved model is saved under config.checkpoints.
        """
        for epoch in range(config.start_epoch, config.epoches):
            print("================================Start Training at epoch %d ================================" % (
                epoch + 1))
            # train_epoch() reports a percentage, while eval() and best_acc are fractions in [0, 1].
            acc = self.train_epoch(epoch) / 100
            if (epoch + 1) % config.eval_interval == 0:
                print("Start Evaluation at epoch %d" % (epoch + 1))
                if self.val_loader is not None:
                    acc = self.eval()
                if acc > self.best_acc:
                    os.makedirs(config.checkpoints, exist_ok=True)
                    # Tag the file with the model class so checkpoints of different backbones are not mislabeled.
                    file_name = 'epoch-%s-%d-bn-acc-%.2f.pth' % (type(self.model).__name__, epoch + 1, acc * 100)
                    save_path = os.path.join(config.checkpoints, file_name)
                    self.save_model(save_path)
                    print('%s saved successfully...' % save_path)
                    self.best_acc = acc

    def _multi_position_loss(self, pred, label):
        """Sum of the criterion over every character position (one logit tensor per position)."""
        return sum(self.criterion(pred[k], label[:, k]) for k in range(len(pred)))

    @staticmethod
    def _count_sequence_correct(pred, label):
        """Number of samples whose predicted characters match the label at every position."""
        hits = torch.stack([pred[k].argmax(1) == label[:, k] for k in range(len(pred))], dim=1)
        return torch.all(hits, dim=1).sum().item()

    def train_epoch(self, epoch):
        """Train for one epoch.

        Returns:
            Training sequence accuracy in percent (0.0 if the loader yielded no batches).
        """
        total_loss = 0
        corrects = 0
        seen = 0
        self.model.train()
        tbar = tqdm(self.train_loader)
        for i, (img, label) in enumerate(tbar):
            img, label = img.to(self.device), label.to(self.device)
            self.optimizer.zero_grad()
            # Forward pass: a tuple of four per-position logit tensors.
            pred = self.model(img)
            loss = self._multi_position_loss(pred, label)
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()

            corrects += self._count_sequence_correct(pred, label)
            seen += label.shape[0]
            tbar.set_description('loss: %.3f, Train Acc: %.3f' % (total_loss / (i + 1), corrects * 100 / seen))

            # Advance the cosine scheduler once every config.print_interval batches.
            if (i + 1) % config.print_interval == 0:
                self.lr_scheduler.step()

        return corrects * 100 / seen if seen else 0.0

    def eval(self):
        """Evaluate on the validation set.

        Returns:
            Fraction of validation samples with all positions correct. The actual sample
            count is used as the denominator, so a partial final batch is handled correctly.
        """
        self.model.eval()
        corrects = 0
        seen = 0
        total_loss = 0
        with torch.no_grad():
            tbar = tqdm(self.val_loader)
            for i, (img, label) in enumerate(tbar):
                img, label = img.to(self.device), label.to(self.device)
                pred = self.model(img)
                total_loss += self._multi_position_loss(pred, label).item()
                corrects += self._count_sequence_correct(pred, label)
                seen += label.shape[0]
                tbar.set_description('loss: %.3f, Val Acc: %.2f' % (
                    total_loss / (i + 1), corrects * 100 / seen))
        # Evaluation is called mid-training, so return to training mode.
        self.model.train()
        return corrects / seen if seen else 0.0

    def save_model(self, save_path, save_opt=False, save_config=False):
        """Save the model state_dict, optionally with optimizer state and a config snapshot.

        The file holds a dict with key 'model'. It also has 'opt' when save_opt is set, and
        'config' (all public attributes of config) when save_config is set.
        """
        checkpoint = {'model': self.model.state_dict()}
        if save_opt:
            checkpoint['opt'] = self.optimizer.state_dict()
        if save_config:
            checkpoint['config'] = {name: getattr(config, name) for name in dir(config) if not name.startswith('_')}
        torch.save(checkpoint, save_path)

    def load_model(self, load_path, changed=False, save_opt=False, save_config=False):
        """Load a checkpoint written by save_model.

        Args:
            load_path: checkpoint path; tensors are mapped onto self.device.
            changed: parameter names differ from the current model. Saved tensors are then matched
                to self.model's state_dict entries purely by position.
            save_opt: also restore the optimizer state (the checkpoint must contain 'opt').
            save_config: overwrite attributes of the global config with the saved 'config' entries.

        Raises:
            RuntimeError: if changed=True and the number of saved tensors differs from the model's.
        """
        checkpoint = torch.load(load_path, map_location=self.device)
        state = checkpoint['model']
        if changed:
            # Current parameter names paired, in order, with the saved tensors.
            keys = list(self.model.state_dict().keys())
            values = list(state.values())
            if len(keys) != len(values):
                raise RuntimeError('checkpoint has %d tensors but the model expects %d' % (len(values), len(keys)))
            state = dict(zip(keys, values))
        self.model.load_state_dict(state)

        if save_opt:
            # Restores learning rate, moments, etc. for resuming training.
            self.optimizer.load_state_dict(checkpoint['opt'])

        if save_config:
            # Saved values override the current config; make sure they are still compatible.
            for k, v in checkpoint['config'].items():
                setattr(config, k, v)

**执行训练和验证**


In [None]:
# Build the trainer from `config` (data, model, loss, optimizer, scheduler) and run the training loop.
trainer = Trainer()
trainer.train()


**单模型预测**

In [8]:
# 定义预测模型
def predict(model_path, model_cls=None):
    """Predict the test set with a single checkpoint and write the results CSV.

    Args:
        model_path: checkpoint written by Trainer.save_model (a dict with a 'model' state_dict).
        model_cls: network class the checkpoint was trained with. None means DigitsMobileNet,
            which matches the checkpoint used in this notebook.

    Returns:
        list of [image_name, predicted_code] pairs sorted by image name.
    """
    if model_cls is None:
        model_cls = DigitsMobileNet
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    test_loader = DataLoader(DigitsDataset(mode='test'), batch_size=config.batch_size, shuffle=False, num_workers=8,
                             pin_memory=True, drop_last=False)

    # Load the checkpoint that was actually requested (map_location allows CPU-only machines).
    model = model_cls().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device)['model'])
    print('Load model from %s successfully' % model_path)

    model.eval()
    results = []
    with torch.no_grad():
        # In test mode the loader yields image names in place of labels.
        for img, img_names in tqdm(test_loader):
            pred = model(img.to(device))
            results += [[name, code] for name, code in zip(img_names, parse2class(pred))]

    results = sorted(results, key=lambda x: x[0])

    write2csv(results)
    return results


def parse2class(prediction):
    """Decode the four per-position logit tensors into digit strings.

    Params:
        prediction (tuple of 4 tensors): per-position logits; the argmax over dim 1
            indexes an 11-entry alphabet where 0-9 are the digits '0'-'9' and 10 is
            the blank (empty string).

    Returns:
        list[str]: one decoded string per sample (blank positions contribute nothing).
    """
    first, second, third, fourth = prediction
    alphabet = [str(digit) for digit in range(10)] + ['']

    # One list of decoded characters per position.
    decoded = [
        [alphabet[index.item()] for index in logits.argmax(1)]
        for logits in (first, second, third, fourth)
    ]
    # Re-assemble per sample by reading the four position lists in lockstep.
    return [''.join(chars) for chars in zip(*decoded)]


def write2csv(results, save_name='./prediction_result/results-resnet101.csv'):
    """Write [file_name, file_code] rows to a submission CSV.

    Only the base file name is kept. Both '\\' and '/' separators are recognised, so
    paths produced on Windows and on POSIX systems yield the same name.

    Args:
        results (list): [image_path, predicted_code] pairs.
        save_name: output CSV path; its parent directory is created if missing.
    """
    df = pd.DataFrame(results, columns=['file_name', 'file_code'])
    df['file_name'] = df['file_name'].apply(lambda x: x.replace('\\', '/').rsplit('/', 1)[-1])
    out_dir = os.path.dirname(save_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(save_name, sep=',', index=False)
    print('Results.saved to %s' % save_name)


In [14]:
# Single-model prediction on the test set from the saved MobileNet checkpoint, then release cached GPU memory.
predict('./user_data/model_data/checkpoints/epoch-mobilenet-33-bn-acc-73.56.pth')
torch.cuda.empty_cache()

Load model from ./user_data/model_data/checkpoints/epoch-mobilenet-33-bn-acc-73.56.pth successfully


  0%|          | 0/625 [00:39<?, ?it/s]

Results.saved to ./prediction_result/results-resnet101.csv


[['./tcdata/mchar_test_a\\000000.png', '199'],
 ['./tcdata/mchar_test_a\\000001.png', '290'],
 ['./tcdata/mchar_test_a\\000002.png', '113'],
 ['./tcdata/mchar_test_a\\000003.png', '97'],
 ['./tcdata/mchar_test_a\\000004.png', '63'],
 ['./tcdata/mchar_test_a\\000005.png', '639'],
 ['./tcdata/mchar_test_a\\000006.png', '126'],
 ['./tcdata/mchar_test_a\\000007.png', '1475'],
 ['./tcdata/mchar_test_a\\000008.png', '48'],
 ['./tcdata/mchar_test_a\\000009.png', '118'],
 ['./tcdata/mchar_test_a\\000010.png', '281'],
 ['./tcdata/mchar_test_a\\000011.png', '610'],
 ['./tcdata/mchar_test_a\\000012.png', '60'],
 ['./tcdata/mchar_test_a\\000013.png', '772'],
 ['./tcdata/mchar_test_a\\000014.png', '836'],
 ['./tcdata/mchar_test_a\\000015.png', '40'],
 ['./tcdata/mchar_test_a\\000016.png', '29'],
 ['./tcdata/mchar_test_a\\000017.png', '60'],
 ['./tcdata/mchar_test_a\\000018.png', '15'],
 ['./tcdata/mchar_test_a\\000019.png', '204'],
 ['./tcdata/mchar_test_a\\000020.png', '284'],
 ['./tcdata/mchar_te

## 模型融合


In [9]:
def stack_eval(mb_path, res_path, mb_weight=0.4, res_weight=0.6):
    """Validation accuracy of a weighted MobileNetV2 + ResNet-101 ensemble.

    The per-position logits are blended as ``mb_weight * mobilenet + res_weight * resnet101``.
    A sample counts as correct only when all four positions match its label.

    Args:
        mb_path: DigitsMobileNet checkpoint (dict with a 'model' state_dict).
        res_path: DigitsResNet101 checkpoint (same format).
        mb_weight, res_weight: ensemble weights; the defaults give the 0.4 / 0.6 blend.

    Returns:
        Fraction of validation samples with every position correct (0.0 for an empty set).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    mb_net = DigitsMobileNet().to(device)
    mb_net.load_state_dict(torch.load(mb_path, map_location=device)['model'])
    res_net = DigitsResNet101().to(device)
    res_net.load_state_dict(torch.load(res_path, map_location=device)['model'])
    mb_net.eval()
    res_net.eval()

    val_loader = DataLoader(DigitsDataset(mode='val', aug=False), batch_size=config.batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=False)
    corrects = 0
    seen = 0
    with torch.no_grad():
        tbar = tqdm(val_loader)
        for img, label in tbar:
            img, label = img.to(device), label.to(device)
            # Weighted average of the two models' logits, position by position.
            pred = [mb_weight * a + res_weight * b for a, b in zip(mb_net(img), res_net(img))]
            hits = torch.stack([pred[k].argmax(1) == label[:, k] for k in range(4)], dim=1)
            corrects += torch.all(hits, dim=1).sum().item()
            # Count real samples: drop_last=False makes the last batch possibly smaller than batch_size.
            seen += label.shape[0]
            tbar.set_description('Val Acc: %.2f' % (corrects * 100 / seen))
    return corrects / seen if seen else 0.0

# resnet101和mobilenet融合
def stack_predict(mb_path, res_path, mb_weight=0.4, res_weight=0.6):
    """Predict the test set with a weighted MobileNetV2 + ResNet-101 ensemble and write the CSV.

    Args:
        mb_path: DigitsMobileNet checkpoint (dict with a 'model' state_dict).
        res_path: DigitsResNet101 checkpoint (same format).
        mb_weight, res_weight: ensemble weights; the defaults give the 0.4 / 0.6 blend.

    Returns:
        list of [image_name, predicted_code] pairs sorted by image name.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_loader = DataLoader(DigitsDataset(mode='test'), batch_size=config.batch_size, shuffle=False, num_workers=8,
                             pin_memory=True, drop_last=False)

    mb_net = DigitsMobileNet().to(device)
    mb_net.load_state_dict(torch.load(mb_path, map_location=device)['model'])
    res_net = DigitsResNet101().to(device)
    res_net.load_state_dict(torch.load(res_path, map_location=device)['model'])
    mb_net.eval()
    res_net.eval()

    results = []
    with torch.no_grad():
        # In test mode the loader yields image names in place of labels.
        for img, img_names in tqdm(test_loader):
            img = img.to(device)
            # Weighted average of the two models' logits, position by position.
            pred = [mb_weight * a + res_weight * b for a, b in zip(mb_net(img), res_net(img))]
            results += [[name, code] for name, code in zip(img_names, parse2class(pred))]

    results = sorted(results, key=lambda x: x[0])

    write2csv_stack(results)
    return results


def write2csv_stack(results, save_name='./prediction_result/results-stack-res101_mb.csv'):
    """Write the ensemble's [file_name, file_code] rows to a submission CSV.

    Only the base file name is kept. Both '\\' and '/' separators are recognised, so
    paths produced on Windows and on POSIX systems yield the same name.

    Args:
        results (list): [image_path, predicted_code] pairs.
        save_name: output CSV path; its parent directory is created if missing.
    """
    df = pd.DataFrame(results, columns=['file_name', 'file_code'])
    df['file_name'] = df['file_name'].apply(lambda x: x.replace('\\', '/').rsplit('/', 1)[-1])
    out_dir = os.path.dirname(save_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(save_name, sep=',', index=False)
    print('Results.saved to %s' % save_name)

In [10]:
# stack_eval('./user_data/model_data/checkpoints/epoch-mobilenet-33-bn-acc-73.56.pth','./user_data/model_data/checkpoints/epoch-DigitsResnet101-47-bn-acc-76.16.pth')

In [11]:
# Ensemble prediction (MobileNet + ResNet-101 checkpoints) on the test set, then release cached GPU memory.
stack_predict('./user_data/model_data/checkpoints/epoch-mobilenet-33-bn-acc-73.56.pth','./user_data/model_data/checkpoints/epoch-DigitsResnet101-47-bn-acc-76.16.pth')
torch.cuda.empty_cache()


  0%|          | 0/625 [00:41<?, ?it/s]

Results.saved to ./prediction_result/results-stack-res101+mb.csv
