Skip to content

Commit

Permalink
fixed and optimized GPU-enabled codepath. Now accuracy is computed on…
Browse files Browse the repository at this point in the history
… GPU and is transferred back to CPU only for printing. Execution became somewhat faster.
  • Loading branch information
datamove committed Oct 1, 2018
1 parent 2e0bf48 commit 4dd0c48
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions recitation-2/pytorch_example.ipynb
Expand Up @@ -163,18 +163,16 @@
" if i%100==0:\n",
" print(\"At iteration\",i)\n",
" # compute the accuracy of the prediction\n",
" train_prediction = train_output.cpu().detach().argmax(dim=1)\n",
" train_accuracy = (train_prediction.numpy()==train_labels.numpy()).mean() \n",
" train_accuracy = (train_labels.eq(train_output.argmax(dim=1))).float().mean() \n",
" # Now for the validation set\n",
" val_output = net(val_data)\n",
" val_loss = criterion(val_output,val_labels)\n",
" # compute the accuracy of the prediction\n",
" val_prediction = val_output.cpu().detach().argmax(dim=1)\n",
" val_accuracy = (val_prediction.numpy()==val_labels.numpy()).mean() \n",
" val_accuracy = (val_labels.eq(val_output.argmax(dim=1))).float().mean() \n",
" print(\"Training loss :\",train_loss.cpu().detach().numpy())\n",
" print(\"Training accuracy :\",train_accuracy)\n",
" print(\"Training accuracy :\",train_accuracy.cpu().detach().numpy())\n",
" print(\"Validation loss :\",val_loss.cpu().detach().numpy())\n",
" print(\"Validation accuracy :\",val_accuracy)\n",
" print(\"Validation accuracy :\",val_accuracy.cpu().detach().numpy())\n",
" \n",
" net = net.cpu()"
]
Expand Down

0 comments on commit 4dd0c48

Please sign in to comment.