
katsamapol/ResNets


Scalable ResNet on CIFAR-10 using PyTorch

Requirements

  • Python 3.6+
  • PyTorch 1.0+
  • TorchVision 0.1+

How to train

# To start a new training run with the default settings:
python project1_model.py

# To override a parameter, pass it as a flag (here, the learning rate):
python project1_model.py --lr 0.01

# To list all configurable parameters:
python project1_model.py -h
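The README does not show how `project1_model.py` parses its flags. The following is a minimal, hypothetical sketch of an `argparse` parser that mirrors the Adjustable parameters table. Only `--lr` is confirmed by the commands shown here; the other flag spellings are assumed from the table's Arguments column with a `--` prefix, and the rule that `--b` has one entry per residual layer is also an assumption.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: flag names come from the README's Arguments column
    # with an assumed "--" prefix. Only --lr is confirmed by the README.
    p = argparse.ArgumentParser(description="Scalable ResNet on CIFAR-10")
    p.add_argument("--o", default="sgd", help="optimizer")
    # None means "fall back to the chosen optimizer's default".
    p.add_argument("--lr", type=float, default=None, help="learning rate")
    p.add_argument("--m", type=float, default=None, help="momentum")
    p.add_argument("--wd", type=float, default=None, help="weight decay")
    p.add_argument("--path", default="./CIFAR10/", help="full path to the dataset")
    p.add_argument("--mp", default="./project1_model.pt", help="full path to the saved model")
    p.add_argument("--e", type=int, default=5, help="number of epochs")
    p.add_argument("--wk", type=int, default=2, help="number of data loader workers")
    p.add_argument("--n", type=int, default=4, help="number of residual layers")
    p.add_argument("--b", type=int, nargs="+", default=[2, 1, 1, 1],
                   help="number of residual blocks in each residual layer")
    p.add_argument("--c", type=int, default=64, help="channels in the first residual layer")
    p.add_argument("--f0", type=int, default=3, help="input-layer kernel size")
    p.add_argument("--f1", type=int, default=3, help="residual-layer kernel size")
    p.add_argument("--k", type=int, default=1, help="skip-connection kernel size")
    p.add_argument("--p0", type=int, default=1, help="input-layer padding")
    p.add_argument("--p1", type=int, default=1, help="residual-layer padding")
    return p


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    # Assumed consistency rule: one block count per residual layer.
    if len(args.b) != args.n:
        parser.error("--b needs exactly --n values")
    return args


args = parse(["--lr", "0.01", "--b", "2", "2", "2", "2"])
print(args.o, args.lr, args.b)  # sgd 0.01 [2, 2, 2, 2]
```

Leaving `--lr`, `--m` and `--wd` at `None` lets the training code pick per-optimizer defaults instead of hard-coding one value for every optimizer.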

How to load and view the saved model

# Start a Python interactive shell and run:
import torch
from project1_model import project1_model

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = project1_model().to(device)
model_path = './project1_model.pt'  # full path to your saved model, e.g., './model.pt'
# strict=False does not raise on missing or unexpected keys in the checkpoint
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
model.eval()  # switch to inference mode before evaluating
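Because the snippet above passes `strict=False`, a checkpoint that does not match the model can load without any error. In recent PyTorch versions `load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`, so a quick check is cheap. The sketch below assumes nothing about `project1_model`; `make_net` is a hypothetical stand-in network used only to demonstrate the check.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checked(model, model_path, device):
    """Load a state dict with strict=False, but report any key mismatches."""
    state = torch.load(model_path, map_location=device)
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print("missing keys:", result.missing_keys)
        print("unexpected keys:", result.unexpected_keys)
    model.eval()  # inference mode, as in the README snippet
    return result


def make_net(extra_layer=False):
    # Hypothetical stand-in for project1_model, for illustration only.
    layers = [nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8)]
    if extra_layer:
        layers.append(nn.Conv2d(8, 8, kernel_size=1))
    return nn.Sequential(*layers)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
path = os.path.join(tempfile.mkdtemp(), "tiny.pt")
torch.save(make_net().state_dict(), path)

ok = load_checked(make_net().to(device), path, device)  # keys match: nothing printed
bad = load_checked(make_net(extra_layer=True).to(device), path, device)  # reports 2.weight, 2.bias
```

An empty `missing_keys` list is a reasonable signal that every parameter of the model was actually restored from the file.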


How to re-test your saved model with our built-in function

# Start a Python interactive shell and run:
import project1_model as p
model_path = './project1_model.pt'  # full path to your saved model, e.g., './model.pt'
p.test_model(model_path)
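The README does not describe what `test_model` reports internally. As a rough illustration only, a typical CIFAR-10 re-test computes top-1 accuracy over the test set; the generic loop below is a sketch of that idea, not the repository's implementation, and it is demonstrated on a tiny synthetic dataset rather than CIFAR-10.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no gradients are needed for evaluation
def evaluate(model, loader, device):
    """Return top-1 accuracy of `model` over all batches in `loader`."""
    model.eval()
    correct, total = 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        preds = model(inputs).argmax(dim=1)  # index of the highest logit
        correct += (preds == targets).sum().item()
        total += targets.size(0)
    return correct / total


# Synthetic check: row i of the identity matrix "predicts" class i,
# and only the first five labels agree with that prediction.
x = torch.eye(10)
y = torch.tensor([0, 1, 2, 3, 4, 0, 0, 0, 0, 0])
loader = DataLoader(TensorDataset(x, y), batch_size=4)
acc = evaluate(nn.Identity(), loader, torch.device("cpu"))
print(acc)  # 0.5
```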


Adjustable parameters

| Description | Type | Argument | Default |
| --- | --- | --- | --- |
| Optimizer | string | o | sgd |
| Learning rate | float | lr | depends on the optimizer |
| Momentum | float | m | depends on the optimizer |
| Weight decay | float | wd | depends on the optimizer |
| Full path to the dataset | string | path | ./CIFAR10/ |
| Full path to the saved model | string | mp | ./project1_model.pt |
| Number of epochs | int | e | 5 |
| Number of data loader workers | int | wk | 2 |
| Number of residual layers | int | n | 4 |
| Number of residual blocks in each residual layer | int (list) | b | 2 1 1 1 |
| Number of channels in the first residual layer | int | c | 64 |
| Kernel size of the input-layer convolution | int | f0 | 3 |
| Kernel size of the residual-layer convolutions | int | f1 | 3 |
| Kernel size of the skip-connection convolution | int | k | 1 |
| Padding of the input-layer convolution | int | p0 | 1 |
| Padding of the residual-layer convolutions | int | p1 | 1 |
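Kernel size and padding must be chosen together: a convolution maps a side length `H` to `(H + 2p - f) // s + 1` for kernel `f`, padding `p` and stride `s`, and the main path and skip connection of a residual block must end at the same size. The sketch below traces feature-map sizes for a given setting. It assumes (not confirmed by this README) 32x32 inputs, stride 1 in the first residual layer and stride 2 in later ones, channels doubling per layer, and a padding-free `k x k` shortcut convolution, as in common CIFAR ResNet layouts.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, padding: int, stride: int = 1) -> int:
    """Output side length of a 2-D convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(n=4, c=64, f0=3, p0=1, f1=3, p1=1, k=1, size=32) -> List[Tuple[str, int, int]]:
    """Return (name, channels, side length) after the input layer and each residual layer.

    Assumed layout: stride 1 in layer 1, stride 2 afterwards, channels c * 2**i.
    """
    size = conv_out(size, f0, p0)  # input-layer convolution
    trace = [("input", c, size)]
    for i in range(n):
        stride = 1 if i == 0 else 2
        main = conv_out(size, f1, p1, stride)  # first conv of the layer's first block
        main = conv_out(main, f1, p1)          # second conv, stride 1
        skip = conv_out(size, k, 0, stride)    # shortcut path, assumed padding 0
        if main != skip:
            raise ValueError(f"layer {i + 1}: main path {main} != shortcut {skip}")
        size = main
        trace.append((f"layer{i + 1}", c * 2 ** i, size))
    return trace


print(trace_shapes())
# [('input', 64, 32), ('layer1', 64, 32), ('layer2', 128, 16), ('layer3', 256, 8), ('layer4', 512, 4)]
```

With the defaults `f1=3, p1=1` every stride-1 convolution preserves the side length; changing `f1` to 5 without raising `p1` to 2 makes the main path shrink while the shortcut does not, which this sketch reports as a `ValueError`.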

Parameter setting explanation


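Default hyperparameters for each optimizer:

| Optimizer | Learning rate | Weight decay | Momentum |
| --- | --- | --- | --- |
| SGD | 0.1 | 0.0005 | 0.9 |
| SGD w/ Nesterov | 0.1 | 0.0005 | 0.9 |
| Adam | 0.001 | 0.0005 | n/a |
| Adadelta | 1.0 | 0.0005 | n/a |
| Adagrad | 0.01 | 0.0005 | n/a |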
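The table above can be read as a lookup from optimizer name to default hyperparameters, with any value given on the command line taking precedence. Below is a minimal sketch using `torch.optim`; apart from `"sgd"`, the option names (`"sgdn"`, `"adam"`, `"adadelta"`, `"adagrad"`) and the helper `make_optimizer` are hypothetical, not the script's confirmed interface.

```python
import torch
import torch.nn as nn

# Defaults from the table; keys other than "sgd" are hypothetical option names.
DEFAULTS = {
    "sgd": dict(lr=0.1, weight_decay=5e-4, momentum=0.9),
    "sgdn": dict(lr=0.1, weight_decay=5e-4, momentum=0.9),
    "adam": dict(lr=1e-3, weight_decay=5e-4),
    "adadelta": dict(lr=1.0, weight_decay=5e-4),
    "adagrad": dict(lr=1e-2, weight_decay=5e-4),
}


def make_optimizer(params, name="sgd", lr=None, momentum=None, weight_decay=None):
    """Build an optimizer, falling back to the table's defaults for unset values."""
    cfg = dict(DEFAULTS[name])  # KeyError for unknown optimizer names
    if lr is not None:
        cfg["lr"] = lr
    if weight_decay is not None:
        cfg["weight_decay"] = weight_decay
    if momentum is not None and "momentum" in cfg:
        cfg["momentum"] = momentum  # only the SGD variants use momentum
    if name == "sgd":
        return torch.optim.SGD(params, **cfg)
    if name == "sgdn":
        return torch.optim.SGD(params, nesterov=True, **cfg)
    if name == "adam":
        return torch.optim.Adam(params, **cfg)
    if name == "adadelta":
        return torch.optim.Adadelta(params, **cfg)
    return torch.optim.Adagrad(params, **cfg)


model = nn.Linear(4, 2)
opt = make_optimizer(model.parameters(), "adam", lr=3e-4)
print(type(opt).__name__, opt.param_groups[0]["lr"])  # Adam 0.0003
```

In this sketch a momentum value passed for Adam, Adadelta or Adagrad is ignored rather than rejected, matching the n/a entries in the table.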

