Skip to content

Commit

Permalink
update PyTorch version to 0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
jakesnell committed Jun 11, 2018
1 parent c4dc41a commit b4ff043
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -17,7 +17,7 @@ If you use this code, please cite our paper:

### Install dependencies

* This code has been tested on Ubuntu 16.04 with Python 2.7 and 3.5. If you're using [conda](https://conda.io/docs/), you can create a Python environment by running `conda create -n protonets python=2.7` or `conda create -n protonets python=3.5` and then activate it by running `source activate protonets`.
* This code has been tested on Ubuntu 16.04 with Python 3.6 and PyTorch 0.4. If you're using [conda](https://conda.io/docs/), you can create a Python environment by running `conda create -n protonets python=2.7` or `conda create -n protonets python=3.5` and then activate it by running `source activate protonets`.
* Install [PyTorch and torchvision](http://pytorch.org/).
* Install [torchnet](https://github.com/pytorch/tnt) by running `pip install git+https://github.com/pytorch/tnt.git@master`.
* Install the protonets package by running `python setup.py install` or `python setup.py develop`.
Expand Down
6 changes: 3 additions & 3 deletions protonets/models/few_shot.py
Expand Up @@ -47,16 +47,16 @@ def loss(self, sample):

dists = euclidean_dist(zq, z_proto)

log_p_y = F.log_softmax(-dists).view(n_class, n_query, -1)
log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)

loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()

_, y_hat = log_p_y.max(2)
acc_val = torch.eq(y_hat, target_inds.squeeze()).float().mean()

return loss_val, {
'loss': loss_val.data[0],
'acc': acc_val.data[0]
'loss': loss_val.item(),
'acc': acc_val.item()
}

@register_model('protonet_conv')
Expand Down

0 comments on commit b4ff043

Please sign in to comment.