Skip to content

Commit

Permalink
Add Apple Silicon GPU Acceleration (#1985)
Browse files Browse the repository at this point in the history
* Add device selection logic based on availability

* fix formatting and slightly refactor

---------

Co-authored-by: Adeel Hassan <ahassan@element84.com>
  • Loading branch information
NripeshN and AdeelH committed Nov 20, 2023
1 parent a2d7337 commit 14d2d98
Showing 1 changed file with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,15 @@ def __init__(self,
self._tmp_dir = get_tmp_dir()
tmp_dir = self._tmp_dir.name
self.tmp_dir = tmp_dir
self.device = torch.device('cuda'
if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'

self.device = torch.device(device)

self.train_ds = train_ds
self.valid_ds = valid_ds
Expand Down

0 comments on commit 14d2d98

Please sign in to comment.