Training with GPU on M1 error #563
Comments
@lydiakatsis Since the NotImplementedError occurs in the torch.logit() activation layer, which only post-processes the outputs, we might be able to avoid it by not calling torch.logit during the validation step. I can give this a try later this week.
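As a minimal sketch of that idea (not the patch that was eventually merged), the `softmax_and_logit` step can be written without `torch.logit` by using the identity logit(p) = log(p) - log(1 - p). The function name `softmax_then_logit` and the `eps` clamp are illustrative assumptions, and whether `softmax` and `torch.log` run natively on MPS depends on the PyTorch version, so this has not been verified on an M1.

```python
import torch
import torch.nn.functional as F


def softmax_then_logit(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Hypothetical stand-in for torch.logit(F.softmax(x, 1)) that avoids torch.logit.

    Uses logit(p) = log(p) - log(1 - p). Probabilities are clamped to
    [eps, 1 - eps] so the result stays finite, matching torch.logit(..., eps=eps).
    """
    p = F.softmax(x, dim=1).clamp(eps, 1 - eps)
    return torch.log(p) - torch.log(1 - p)
```

Because it only uses elementwise `log` and arithmetic, the output should agree with `torch.logit(F.softmax(x, 1), eps=eps)` up to floating-point rounding.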
I pushed a patch for this to
From looking at the post above, most of these are still not implemented. However, the CPU
I added a patch that falls back to the CPU if logit fails (see #626). This will be merged into master with the next release.
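A sketch of the CPU-fallback approach described above might look like the following. The helper names are hypothetical and this is not necessarily how #626 implements it: it tries `torch.logit` on the tensor's own device and, if the backend raises `NotImplementedError`, computes the op on the CPU and moves the result back.

```python
import torch
import torch.nn.functional as F


def logit_with_cpu_fallback(p: torch.Tensor) -> torch.Tensor:
    """Apply torch.logit on p's device; if the backend lacks the op, use the CPU."""
    try:
        return torch.logit(p)
    except NotImplementedError:
        # e.g. 'aten::logit' missing on MPS: compute on CPU, then return to p's device
        return torch.logit(p.cpu()).to(p.device)


def softmax_and_logit(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the "softmax_and_logit" branch shown in the traceback (softmax over dim 1)
    return logit_with_cpu_fallback(F.softmax(x, dim=1))
```

The fallback only copies the (small) score tensor to the CPU, so the forward pass itself still runs on the GPU.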
I get the following error when trying to train a model using the M1 GPU:
```
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Input In [43], in <cell line: 3>()
1 start_time = time.time()
----> 3 model.train(
4 train_df=train_df,
5 validation_df=validation_df,
6 save_path= output +'/binary_train/', #where to save the trained model
7 epochs=50,
8 batch_size=64,
9 save_interval=10, #save model every 5 epochs (the best model is always saved in addition)
10 num_workers=4, #specify 4 if you have 4 CPU processes, eg; 0 means only the root process
11 )
13 end_time = time.time()
14 print("Finished model training " + str(int((end_time - start_time)/60)) + ' minutes!\n\n')
File ~/opt/anaconda3/envs/opso7.1/lib/python3.9/site-packages/opensoundscape/torch/models/cnn.py:463, in CNN.train(self, train_df, validation_df, epochs, batch_size, num_workers, save_path, save_interval, log_interval, validation_interval, unsafe_samples_log)
461 if validation_df is not None:
462 self._log("\nValidation.")
--> 463 validation_scores, _, unsafe_val_samples = self.predict(
464 validation_df,
465 batch_size=batch_size,
466 num_workers=num_workers,
467 activation_layer="softmax_and_logit"
468 if self.single_target
469 else None,
470 split_files_into_clips=False,
471 )
472 validation_targets = validation_df.values
473 validation_scores = validation_scores.values
File ~/opt/anaconda3/envs/opso7.1/lib/python3.9/site-packages/opensoundscape/torch/models/cnn.py:769, in CNN.predict(self, samples, batch_size, num_workers, activation_layer, binary_preds, threshold, split_files_into_clips, overlap_fraction, final_clip, bypass_augmentations, unsafe_samples_log)
766 logits = self.network.forward(batch_tensors)
768 ### Activation layer ###
--> 769 scores = apply_activation_layer(logits, activation_layer)
771 ### Binary predictions ###
772 batch_preds = tensor_binary_predictions(
773 scores=scores, mode=binary_preds, threshold=threshold
774 )
File ~/opt/anaconda3/envs/opso7.1/lib/python3.9/site-packages/opensoundscape/torch/models/utils.py:180, in apply_activation_layer(x, activation_layer)
177 x = torch.sigmoid(x)
178 elif activation_layer == "softmax_and_logit":
179 # softmax, then remap scores from [0,1] to [-inf,inf]
--> 180 x = torch.logit(softmax(x, 1))
181 else:
182 raise ValueError(f"invalid option for activation_layer: {activation_layer}")
NotImplementedError: The operator 'aten::logit' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on pytorch/pytorch#77764. As a temporary fix, you can set the environment variable
PYTORCH_ENABLE_MPS_FALLBACK=1
to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```
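The environment-variable workaround suggested in the error message can also be applied from inside Python. This is a minimal sketch, assuming the variable is set before `torch` is first imported in the process (setting it after the import may not take effect); the sample tensor is illustrative only. From a shell, the equivalent is prefixing the launch command, e.g. `PYTORCH_ENABLE_MPS_FALLBACK=1 jupyter notebook`.

```python
import os

# Set the flag before importing torch so the MPS backend sees it.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402

p = torch.tensor([0.25, 0.5, 0.75])
if torch.backends.mps.is_available():
    p = p.to("mps")

# With the fallback enabled, an op missing on MPS runs on the CPU instead of raising.
scores = torch.logit(p)
print(scores.cpu())
```

On a machine without MPS the tensor simply stays on the CPU, so the snippet runs either way.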
The code I used was:
```python
from opensoundscape.torch.models.cnn import load_model

# train_df, validation_df and output are defined earlier in the notebook
model = load_model('model.model')
model.device = 'mps'
model.train(
    train_df=train_df,
    validation_df=validation_df,
    save_path=output + '/binary_train/',  # where to save the trained model
    epochs=50,
    batch_size=64,
    save_interval=10,  # save the model every 10 epochs (the best model is always saved in addition)
    num_workers=4,  # number of worker processes for data loading; 0 means only the root process
)
```
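Rather than hard-coding `model.device = 'mps'`, one option is to choose the device at runtime so the same notebook also runs on machines without MPS. This is a sketch: `pick_device` is a hypothetical helper, and assigning a `torch.device` object (instead of the string `'mps'`) to the model's `device` attribute is assumed to behave the same way in OpenSoundscape.

```python
import torch


def pick_device() -> torch.device:
    """Return the preferred available device: Apple MPS, then CUDA, then CPU."""
    # torch.backends.mps is available in PyTorch >= 1.12
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = pick_device()
print(f"Using device: {device}")
# model.device = device  # then call model.train(...) as before
```

Falling back to the CPU keeps training possible on any machine, at the cost of speed.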