Skip to content

oldfemalepig/resnet_exercise

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

9 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

resnet_exercise

文件介绍

flower_data:放置数据集,其中train和val是划分数据集后产生的
test:放置测试的图片
model.py:模型文件
predict.py:进行预测
split_data.py:对数据集进行划分
train.py:训练网络
class_data.json:编号-种类的对应

环境介绍

pytorch:1.6.0
torchvision:0.7.0

过程

一、下载数据集

1、 打开/flower_data/flower_link.txt中的链接,下载数据集
2、随后运行split_data.py,划分训练集和验证集

二、搭建模型

1、构建BasicBlock
先定义结构,由两块组成;再定义前向传播函数
2、构建Bottleneck
同样先定义结构,由三块组成(降维、卷积、升维);再定义前向传播函数
3、构建ResNet类
1)先定义第一阶段,为一个7x7的卷积处理,stride为2,然后经过池化处理
2)定义4个block,这里采用了_make_layer()函数来产生
3)最后采用平均池化,定义全连接层。
4)定义前向传播函数

三、训练模型

1、定义数据处理
训练集:先随机切割再resize到224大小,水平翻转,转化为Tensor,最后正则化
验证集:先resize到256,再进行中心切割,转化为Tensor,最后正则化
2、加载图片
采用datasets.ImageFolder从路径中读取图像数据,并经过transforms变换
构建种类-编号的dict,并写入json文件中
采用torch.utils.data.DataLoader,按照一个bacth_size,分批次加载数据为Tensor类型
3、获取网络
定义Loss和优化器
开始迭代训练

四、预测

可更改预测图片路径,来完成对图片的预测输出

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Languages