Skip to content

Commit

Permalink
Patches for current PyTorch Release, fixes #3, #4, and #5
Browse files Browse the repository at this point in the history
  • Loading branch information
Kent Sommer committed Apr 18, 2018
1 parent 32d2037 commit 15fefd5
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 14 deletions.
110 changes: 110 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# npz
*.npz

# pth
*.pth
2 changes: 1 addition & 1 deletion dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _process(self, file, train):
"""Data format: A list, [train data, test data]
Each data sample: label, S1, S2, Images, in this order.
"""
with np.load(file) as f:
with np.load(file, mmap_mode='r') as f:
if train:
images = f['arr_0']
S1 = f['arr_1']
Expand Down
2 changes: 1 addition & 1 deletion dataset/make_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def make_data(dom_size, n_domains, max_obs,
return X_f, S1_f, S2_f, Labels_f


def main(dom_size=[8,8], n_domains=15000, max_obs=30, max_obs_size=None,
def main(dom_size=[28,28], n_domains=5000, max_obs=50, max_obs_size=2,
n_traj=7, state_batch_size=1):
# Get path to save dataset
save_path = "dataset/gridworld_{0}x{1}".format(dom_size[0], dom_size[1])
Expand Down
12 changes: 6 additions & 6 deletions download_weights_and_datasets.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
cd trained
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/vin_8x8.pth'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/vin_16x16.pth'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/vin_28x28.pth'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_8x8.pth'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_16x16.pth'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_28x28.pth'
cd ../dataset
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/gridworld_8x8.npz'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/gridworld_16x16.npz'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.0/gridworld_28x28.npz'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_8x8.npz'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_16x16.npz'
wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_28x28.npz'
cd ..
2 changes: 1 addition & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, config):
out_features=8,
bias=False)
self.w = Parameter(torch.zeros(config.l_q,1,3,3), requires_grad=True)
self.sm = nn.Softmax()
self.sm = nn.Softmax(dim=1)


def forward(self, X, S1, S2, config):
Expand Down
8 changes: 6 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ def main(config, n_domains=100, max_obs=30,
correct, total = 0.0, 0.0
# Automatic swith of GPU mode if available
use_GPU = torch.cuda.is_available()
vin = torch.load(config.weights)
# Instantiate a VIN model
vin = VIN(config)
# Load model parameters
vin.load_state_dict(torch.load(config.weights))
# Use GPU if available
if use_GPU:
vin = vin.cuda()
vin = vin.cuda()

for dom in range(n_domains):
# Randomly select goal position
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test(net, testloader, config):
# Unwrap autograd.Variable to Tensor
predicted = predicted.data
# Compute test accuracy
correct += (predicted == labels).sum()
correct += (torch.eq(torch.squeeze(predicted), labels)).sum()
total += labels.size()[0]
print('Test Accuracy: {:.2f}%'.format(100*(correct/total)))

Expand Down Expand Up @@ -147,5 +147,5 @@ def test(net, testloader, config):
train(net, trainloader, config, criterion, optimizer, use_GPU)
# Test accuracy
test(net, testloader, config)
# Save the trained model
torch.save(net, save_path)
# Save the trained model parameters
torch.save(net.state_dict(), save_path)

0 comments on commit 15fefd5

Please sign in to comment.