-
Notifications
You must be signed in to change notification settings - Fork 5
度量学习
Mr.Li edited this page Jun 20, 2022
·
7 revisions
pip install -r ./Package/requirements.txt
根目录已包含
CatDog
数据集,且Config/
已生成dataset.txt
。
格式一 : 统计精确率
- 原理:训练集的某类所有图片特征的均值作为该类的特征中心,测试样本与某类的类中心距离最近即判定为该类。
- txt格式为[类型, 类别名, 图像路径],
train, dog, /xxxx/img1.jpg
val, cat, /xxxx/img2.jpg
test, cat, /xxxx/img3.jpg
格式二 : 统计误识率FPR下的通过率TPR
- 训练集:[类型, 类别名, 图像路径],
- 验证集/测试集:[类型,是否为同类,图片1,图片2] 样本对格式
train, dog, img0.jpg
val, true, img1.jpg, img2.jpg
val, false, img3.jpg, img4.jpg
test, true, img5.jpg, img6.jpg
test, false, img7.jpg, img8.jpg
一. 训练
-
Config/config.py
(1)注释常规分类(2)打开度量学习# 常规分类 # Backbone="resnet18" # 主干网络 # Loss="CrossEntropy" # 损失函数 # 度量学习 Backbone = "resnet18" Loss = "ArcFace" Feature_dim = 128
-
开始训练
假设单机2卡,执行如下命令 访问更多详情
torchrun --nproc_per_node 2 metric_train.py
-
测试
查看metric_test.py