Learning how to use PyTorch
PyTorch is a deep learning framework for fast, flexible experimentation.
English docs: https://pytorch.org/docs/stable/index.html
Chinese docs: https://pytorch-cn.readthedocs.io/zh/latest/
https://ptorch.com/docs/1/
index
Based on Python 3
-
Windows
-
Step1: Install Anaconda: https://www.anaconda.com/download/
Add Anaconda to the PATH environment variable (replace installpath with your Anaconda install directory):
installpath\Anaconda3
installpath\Anaconda3\Scripts
installpath\Anaconda3\Library\bin
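To confirm the PATH change took effect, the short standard-library sketch below looks up conda the way a newly opened shell would (PATH edits only reach terminals started afterwards). `find_conda` is a hypothetical helper name, not part of Anaconda.

```python
import shutil


def find_conda():
    """Return the full path of the `conda` executable found on PATH, or None.

    On Windows, shutil.which also tries the extensions listed in PATHEXT,
    so conda.exe or conda.bat are found as well.
    """
    return shutil.which("conda")


conda_path = find_conda()
if conda_path is None:
    print("conda not found on PATH; re-check the directories added in Step1")
else:
    print("conda found at:", conda_path)
```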
-
Step2: Install CUDA: https://developer.nvidia.com/cuda-downloads
-
Step3: Create a virtual environment for PyTorch
conda create -n pytorch python=3.6
-
Step4: Activate your virtual environment
activate pytorch
PS: when you are done working in the virtual environment and want to leave it, use the following command:
deactivate pytorch
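To check from Python which environment is active (for example before installing packages in Step5), a minimal sketch follows. It assumes the conda activation scripts export the `CONDA_DEFAULT_ENV` environment variable with the environment's name; `active_conda_env` is a hypothetical helper.

```python
import os
import sys


def active_conda_env():
    """Return the active conda environment name, or None if none is reported.

    Assumption: conda's activation scripts export CONDA_DEFAULT_ENV with the
    environment's name (e.g. "pytorch").
    """
    return os.environ.get("CONDA_DEFAULT_ENV")


env = active_conda_env()
print("active conda env:", env)
# With the env active, this path should lie inside an envs/pytorch directory.
print("python executable:", sys.executable)
if env != "pytorch":
    print("hint: run `activate pytorch` before installing packages")
```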
-
Step5: Inside the virtual environment, install PyTorch for your CUDA version:
# for cuda8.0
conda install pytorch cuda80 -c pytorch
# for cuda9.0
conda install pytorch cuda90 -c pytorch
-
Step6: Test the installation with the following Python code:
import torch
print(torch.__version__)
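A slightly fuller check also reports whether the installed build can see the GPU set up in Step2. This is a sketch: it relies on `torch.cuda.is_available()` and `torch.version.cuda`, and `describe_torch_install` is a hypothetical helper name. It returns `{"installed": False}` instead of crashing when torch is missing.

```python
import importlib


def describe_torch_install():
    """Summarise the local PyTorch install as a dict.

    Returns {"installed": False} if torch cannot be imported.
    """
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return {"installed": False}
    return {
        "installed": True,
        "version": torch.__version__,
        # True only if a CUDA build is installed and a usable GPU/driver is found
        "cuda_available": torch.cuda.is_available(),
        # CUDA version this build was compiled against (None for CPU-only builds)
        "cuda_version": torch.version.cuda,
    }


info = describe_torch_install()
print(info)
```

If `cuda_available` is False on a machine with an NVIDIA GPU, the installed package and the CUDA version from Step2 most likely do not match.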
-
Step7: install torchvision
pip install torchvision
-
-
Linux
-
Step1: install Anaconda
-
Step2: Add Tsinghua Open Source Mirror
conda config
Use the above command to generate the configuration file .condarc. Then run:
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --set show_channel_urls yes
Delete the '- defaults' entry; .condarc should then look like this:
channels:
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
show_channel_urls: true
-
Step3: install pytorch and torchvision:
# for cuda8.0
conda install pytorch torchvision -c pytorch
# for cuda9.0
conda install pytorch torchvision cuda90 -c pytorch
-
The directory models/ contains common deep learning network models implemented in PyTorch:
alexnet
VGG
ResNet
DenseNet
BNinception
caffe_resnet
fb_resnet
inception_resnetv2
inceptionv4
nasnet
nasnet_mobile
pnasnet
polynet
vggm
wide_resnet
xception
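PS: The models do not apply softmax to their final output, so use CrossEntropyLoss as the loss function for multi-class classification. If you change the code so that the model's final output goes through log_softmax, use NLLLoss instead (CrossEntropyLoss is equivalent to LogSoftmax + NLLLoss; NLLLoss expects log-probabilities, so a plain softmax output is not enough).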
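The equivalence between CrossEntropyLoss and LogSoftmax + NLLLoss can be checked numerically for a single sample. The sketch below uses only the standard library rather than torch, so the function names are illustrative stand-ins, not PyTorch APIs; the PyTorch losses additionally average over the batch by default.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of raw scores (logits)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class; expects log-probabilities."""
    return -log_probs[target]


def cross_entropy(logits, target):
    """Cross-entropy on raw logits, computed directly from its definition."""
    return -math.log(math.exp(logits[target]) / sum(math.exp(x) for x in logits))


logits = [2.0, 0.5, -1.0]  # raw model output for one sample, 3 classes
target = 0
ce = cross_entropy(logits, target)
ce_via_nll = nll_loss(log_softmax(logits), target)

# Pitfall: feeding plain softmax probabilities to NLL just returns -p[target].
probs = [math.exp(v) for v in log_softmax(logits)]
wrong = nll_loss(probs, target)
print(ce, ce_via_nll, wrong)
```

`ce` and `ce_via_nll` agree up to floating-point rounding, while `wrong` is a negative number that is not a cross-entropy at all.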
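There are three ways to build the input dataset; click here for details.
For how to train, how to test, and how to run on the GPU, click here.
For how to save a trained model, load a model, and fine-tune, click here.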
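Whichever of the three dataset-building methods is used, a map-style PyTorch dataset implements `__len__` and `__getitem__`, and `torch.utils.data.DataLoader` turns it into batches. The stdlib-only sketch below mirrors that protocol; `ToyDataset` and `iterate_batches` are hypothetical names, and `iterate_batches` is a naive stand-in for DataLoader (no tensors, no worker processes).

```python
import random


class ToyDataset:
    """Hypothetical map-style dataset following the same __len__/__getitem__
    protocol that torch.utils.data.Dataset subclasses implement."""

    def __init__(self, samples, labels):
        assert len(samples) == len(labels)
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # A real dataset would typically load an image here and apply transforms.
        return self.samples[index], self.labels[index]


def iterate_batches(dataset, batch_size, shuffle=False, seed=0):
    """Yield (inputs, labels) lists of at most batch_size items each."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        inputs, labels = zip(*batch)
        yield list(inputs), list(labels)


ds = ToyDataset(samples=[[0.1], [0.2], [0.3], [0.4], [0.5]], labels=[0, 1, 0, 1, 0])
batches = list(iterate_batches(ds, batch_size=2))
print([len(inputs) for inputs, _ in batches])
```

With 5 samples and `batch_size=2`, the last batch holds only the single leftover sample, which matches DataLoader's default `drop_last=False` behaviour.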
The directory examples/ contains simple examples of using PyTorch:
mnist.py: a simple CNN trained and tested on the MNIST dataset
FashionAI: Alibaba Tianchi global challenge, apparel attribute label recognition
To be added
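When writing a CNN like the one in mnist.py, the input size of the first fully connected layer has to be worked out from the conv and pooling layers. The sketch below applies the output-size formula from the `nn.Conv2d` documentation (the same formula applies to `nn.MaxPool2d`) to a 28x28 MNIST image. The layer sizes (two 5x5 convolutions, 2x2 max-pooling, 20 output channels) are hypothetical and not taken from mnist.py.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of nn.Conv2d (or nn.MaxPool2d) along one dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


size = 28                                         # MNIST images are 28x28
size = conv2d_out(size, kernel_size=5)            # conv1: 28 -> 24
size = conv2d_out(size, kernel_size=2, stride=2)  # 2x2 max-pool: 24 -> 12
size = conv2d_out(size, kernel_size=5)            # conv2: 12 -> 8
size = conv2d_out(size, kernel_size=2, stride=2)  # 2x2 max-pool: 8 -> 4
out_channels = 20                                 # hypothetical conv2 channel count
flat_features = out_channels * size * size
print(flat_features)  # 320
```

Under these assumptions the first `nn.Linear` layer would take 320 input features; changing any kernel size, stride, or padding changes that number.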