Skip to content

度量学习

Mr.Li edited this page Jun 20, 2022 · 7 revisions

一. 安装依赖

pip install -r ./Package/requirements.txt 

二. 准备数据

根目录已包含CatDog数据集,且Config/已生成dataset.txt

格式一 : 统计精确率

  • 原理:训练集的某类所有图片特征的均值作为该类的特征中心,测试样本与某类的类中心距离最近即判定为该类。
  • txt格式为[类型, 类别名, 图像路径],
train, dog,  /xxxx/img1.jpg
val,   cat,  /xxxx/img2.jpg
test,  cat,  /xxxx/img3.jpg

格式二 : 统计误识率FPR下的通过率TPR

  • 训练集:[类型, 类别名, 图像路径],
  • 验证集/测试集:[类型,是否为同类,图片1,图片2] 样本对格式
train, dog,   img0.jpg
val,   true,  img1.jpg,  img2.jpg
val,   false, img3.jpg,  img4.jpg
test,  true,  img5.jpg,  img6.jpg
test,  false, img7.jpg,  img8.jpg

三. 训练

一. 训练

  1. Config/config.py (1)注释常规分类(2)打开度量学习

    # 常规分类
    # Backbone="resnet18"   # 主干网络
    # Loss="CrossEntropy"    # 损失函数
    
    # 度量学习
    Backbone = "resnet18"
    Loss = "ArcFace"
    Feature_dim = 128
  2. 开始训练

    假设单机2卡,执行如下命令 访问更多详情

    torchrun --nproc_per_node 2 metric_train.py
  3. 测试

    查看metric_test.py