Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
add .to(device) function for trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
Steffen-Wolf committed Oct 1, 2018
1 parent 5aaa63f commit 646f9ce
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions inferno/trainers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,22 @@ def get_current_learning_rate(self):
for _learning_rate in learning_rate]
return pyu.from_iterable(learning_rate)

def to(self, device):
"""
Send trainer to device
----------
device : string or torch.device
Target device where trainer/model should be send to
"""
if device == 'cuda':
return self.cuda()
elif device == 'cpu':
return self.cpu()
elif isinstance(device, torch.torch.device):
self.to(device.type)
else:
raise NotImplementedError("Can not send trainer to device", device)

def cuda(self, devices=None, base_device=None):
"""
Train on the GPU.
Expand Down

0 comments on commit 646f9ce

Please sign in to comment.