Generalized Framework for PyTorch

Welcome to PyTorchNet!

PyTorchNet is a machine learning framework built on top of PyTorch. It uses Visdom and Plotly for visualization.

PyTorchNet can be customized by creating the following classes:

  1. Data Loading: a dataset class is required to load the data.
  2. Model Design: an nn.Module subclass that represents the network model.
  3. Loss Method: an appropriate class for the loss, for example CrossEntropyLoss or MSELoss.
  4. Evaluation Metric: a class to measure the accuracy of the results.

Structure

PyTorchNet consists of the following packages:

Datasets

This package handles loading and transforming datasets.
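
A PyTorch-style dataset only has to implement __len__ and __getitem__; in a real project it would subclass torch.utils.data.Dataset. The stdlib-only sketch below mimics that protocol, including an optional per-item transform, so it runs without PyTorch. The class name ListDataset and its arguments are hypothetical and not part of PyTorchNet.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


class ListDataset:
    """Hypothetical in-memory dataset following the torch Dataset protocol.

    A real PyTorchNet dataset would subclass torch.utils.data.Dataset;
    this sketch only reproduces the __len__/__getitem__ interface.
    """

    def __init__(self, samples: Sequence[Any], labels: Sequence[int],
                 transform: Optional[Callable[[Any], Any]] = None) -> None:
        if len(samples) != len(labels):
            raise ValueError("samples and labels must have the same length")
        self.samples = list(samples)
        self.labels = list(labels)
        self.transform = transform

    def __len__(self) -> int:
        # Number of (sample, label) pairs.
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        sample = self.samples[index]
        # The transform is applied lazily, one item at a time.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.labels[index]


if __name__ == "__main__":
    ds = ListDataset([[1.0, 2.0], [3.0, 4.0]], [0, 1],
                     transform=lambda x: [v / 4 for v in x])
    print(len(ds), ds[1])
```

Applying the transform inside __getitem__ keeps the raw data untouched and lets a data loader apply random augmentations anew on every epoch.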

Models

Network models are kept in this package. It already includes ResNet, PreActResNet, Stacked Hourglass and SphereFace.

Losses

This package provides a number of loss functions for classification and regression. New loss methods can be added here.
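
As an illustration of a loss that could be added, the sketch below computes a triplet margin loss, max(0, d(a, p) - d(a, n) + margin) with Euclidean distance d, averaged over a batch. It uses plain Python lists instead of tensors so that it runs without PyTorch. The function names are hypothetical; a real PyTorchNet loss would be an nn.Module operating on tensors, and PyTorch itself already ships nn.TripletMarginLoss.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def euclidean(x: Vector, y: Vector) -> float:
    # L2 distance between two equal-length vectors.
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def triplet_margin_loss(anchors: Sequence[Vector], positives: Sequence[Vector],
                        negatives: Sequence[Vector], margin: float = 1.0) -> float:
    """Mean hinge loss that is zero once d(a, n) exceeds d(a, p) by `margin`."""
    if not (len(anchors) == len(positives) == len(negatives)) or not anchors:
        raise ValueError("expected non-empty batches of equal size")
    losses = [
        max(0.0, euclidean(a, p) - euclidean(a, n) + margin)
        for a, p, n in zip(anchors, positives, negatives)
    ]
    return sum(losses) / len(losses)
```

A triplet whose negative is already far enough from the anchor contributes zero, so only "hard" triplets produce a gradient in a tensor version of this loss.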

Evaluate

This package provides a number of evaluation metrics for classification and regression. New accuracy metrics can be added here.
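
A classification accuracy metric boils down to comparing the arg-max of each score vector with its target label. The stdlib-only sketch below shows that computation as a callable class. The class name and its (scores, targets) call signature are hypothetical and do not reproduce the interface of PyTorchNet's evaluate package.

```python
from typing import Sequence


class ClassificationAccuracy:
    """Hypothetical metric: fraction of samples whose top-scoring class is the target."""

    def __call__(self, scores: Sequence[Sequence[float]],
                 targets: Sequence[int]) -> float:
        if len(scores) != len(targets) or not targets:
            raise ValueError("expected non-empty scores and targets of equal length")
        correct = 0
        for row, target in zip(scores, targets):
            # Arg-max over class scores; max() keeps the first index on ties.
            predicted = max(range(len(row)), key=row.__getitem__)
            correct += int(predicted == target)
        return correct / len(targets)
```

With PyTorch tensors the same quantity is typically computed as (output.argmax(dim=1) == target).float().mean().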

Plugins

Three plugins are currently available:

  1. Monitor:
  2. Logger:
  3. Visualizer:

Root

  • main
  • dataloader
  • train
  • test

Setup

First, download PyTorchNet (including its git submodules) with the following command:

git clone --recursive https://github.com/human-analysis/pytorchnet.git

PyTorchNet relies on several Python packages. To install them, run the following command in the pytorchnet directory:

pip install -r requirements.txt

Notice

  • If you do not have PyTorch installed, or your version does not meet the requirements, please follow the instructions on the PyTorch website.

Congratulations! You are now ready to use PyTorchNet!

Usage

Before running PyTorchNet, a Visdom server must be up and running. Start one with:

python -m visdom.server -p 8097
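
Since training sends its plots to Visdom, it can help to confirm that the server is reachable before launching main.py. The stdlib-only helper below simply attempts a TCP connection to localhost on port 8097, the port used in the command above. The function name is hypothetical, and a successful connection only shows that something is listening on that port, not that it is Visdom.

```python
import socket


def server_is_listening(host: str = "localhost", port: int = 8097,
                        timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host unreachable.
        return False


if __name__ == "__main__":
    if server_is_listening():
        print("Something is listening on port 8097.")
    else:
        print("Nothing on port 8097; start Visdom with: python -m visdom.server -p 8097")
```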

PyTorchNet comes with a classification example that trains a ResNet model on the CIFAR10 dataset. To run it:

python main.py

Configuration

At startup, PyTorchNet loads its parameters from a config file, the command line, or both.

Config file

By default, PyTorchNet loads all parameters from args.txt at startup, if that file exists. To load a custom config file instead, pass it with the --config parameter:

python main.py --config custom_args.txt

args.txt

[Arguments]

port = 8097
env = main
same_env = Yes
log_type = traditional
save_results = No

# dataset options
dataroot = ./data
dataset_train = CIFAR10
dataset_test = CIFAR10
batch_size = 64
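
The [Arguments] header suggests an INI-style file. Assuming that, the sketch below reads args.txt with Python's standard configparser module. The helper name load_args and the choice of which keys to convert to int or bool are illustrative assumptions; PyTorchNet's own config.py may parse the file differently.

```python
import configparser
from typing import Any, Dict

# Keys converted to typed values; everything else stays a string (assumption).
INT_KEYS = {"port", "batch_size"}
BOOL_KEYS = {"same_env", "save_results"}


def load_args(path: str = "args.txt", section: str = "Arguments") -> Dict[str, Any]:
    """Read one section of an INI-style config file into a plain dict."""
    parser = configparser.ConfigParser()
    # read() silently skips missing files, so a missing args.txt yields {}.
    if not parser.read(path) or not parser.has_section(section):
        return {}
    values: Dict[str, Any] = {}
    for key in parser[section]:
        if key in INT_KEYS:
            values[key] = parser.getint(section, key)
        elif key in BOOL_KEYS:
            # getboolean accepts Yes/No, true/false, on/off and 1/0.
            values[key] = parser.getboolean(section, key)
        else:
            values[key] = parser.get(section, key)
    return values


if __name__ == "__main__":
    print(load_args())
```

Lines starting with # (such as "# dataset options") are treated as comments by configparser's default settings, so the example file above parses as-is.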

Command line

Parameters can also be set on the command line when invoking main.py. Command-line values take precedence over (override) the corresponding values in the configuration file.

python main.py --log-type progressbar
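
A common way to give command-line flags priority over a config file is to read the file first and install its values as argparse defaults, so that any flag passed explicitly overrides them. The sketch below assumes that approach and an INI file with an [Arguments] section. Only --config and --log-type appear in this README; the --port and --batch-size flags, their built-in defaults, and the helper name parse_args are illustrative assumptions, and PyTorchNet's config.py may work differently.

```python
import argparse
import configparser
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # Stage 1: find out which config file to read (default: args.txt).
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument("--config", default="args.txt")
    known, _ = pre.parse_known_args(argv)

    # Stage 2: the full parser; --log-type is stored as args.log_type,
    # matching the log_type key in the config file.
    parser = argparse.ArgumentParser(parents=[pre])
    parser.add_argument("--port", type=int, default=8097)
    parser.add_argument("--log-type", default="traditional")
    parser.add_argument("--batch-size", type=int, default=64)

    # Config-file values replace the built-in defaults. String defaults
    # still go through each flag's type=, so "9000" becomes 9000.
    config = configparser.ConfigParser()
    if config.read(known.config) and config.has_section("Arguments"):
        parser.set_defaults(**dict(config["Arguments"]))

    # Flags given explicitly on the command line override both.
    return parser.parse_args(argv)
```

For example, parse_args(["--log-type", "progressbar"]) yields log_type == "progressbar" even if args.txt sets log_type = traditional, while keys with no matching flag (such as env) are still carried over from the file as strings.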