@@ -228,13 +228,57 @@ static LIST_HEAD(global_svm_list);
228228 list_for_each_entry((sdev), &(svm)->devs, list) \
229229 if ((d) != (sdev)->dev) {} else
230230
231+ static int pasid_to_svm_sdev (struct device * dev , unsigned int pasid ,
232+ struct intel_svm * * rsvm ,
233+ struct intel_svm_dev * * rsdev )
234+ {
235+ struct intel_svm_dev * d , * sdev = NULL ;
236+ struct intel_svm * svm ;
237+
238+ /* The caller should hold the pasid_mutex lock */
239+ if (WARN_ON (!mutex_is_locked (& pasid_mutex )))
240+ return - EINVAL ;
241+
242+ if (pasid == INVALID_IOASID || pasid >= PASID_MAX )
243+ return - EINVAL ;
244+
245+ svm = ioasid_find (NULL , pasid , NULL );
246+ if (IS_ERR (svm ))
247+ return PTR_ERR (svm );
248+
249+ if (!svm )
250+ goto out ;
251+
252+ /*
253+ * If we found svm for the PASID, there must be at least one device
254+ * bond.
255+ */
256+ if (WARN_ON (list_empty (& svm -> devs )))
257+ return - EINVAL ;
258+
259+ rcu_read_lock ();
260+ list_for_each_entry_rcu (d , & svm -> devs , list ) {
261+ if (d -> dev == dev ) {
262+ sdev = d ;
263+ break ;
264+ }
265+ }
266+ rcu_read_unlock ();
267+
268+ out :
269+ * rsvm = svm ;
270+ * rsdev = sdev ;
271+
272+ return 0 ;
273+ }
274+
231275int intel_svm_bind_gpasid (struct iommu_domain * domain , struct device * dev ,
232276 struct iommu_gpasid_bind_data * data )
233277{
234278 struct intel_iommu * iommu = device_to_iommu (dev , NULL , NULL );
279+ struct intel_svm_dev * sdev = NULL ;
235280 struct dmar_domain * dmar_domain ;
236- struct intel_svm_dev * sdev ;
237- struct intel_svm * svm ;
281+ struct intel_svm * svm = NULL ;
238282 int ret = 0 ;
239283
240284 if (WARN_ON (!iommu ) || !data )
@@ -261,35 +305,23 @@ int intel_svm_bind_gpasid(struct iommu_domain *domain, struct device *dev,
261305 dmar_domain = to_dmar_domain (domain );
262306
263307 mutex_lock (& pasid_mutex );
264- svm = ioasid_find (NULL , data -> hpasid , NULL );
265- if (IS_ERR (svm )) {
266- ret = PTR_ERR (svm );
308+ ret = pasid_to_svm_sdev (dev , data -> hpasid , & svm , & sdev );
309+ if (ret )
267310 goto out ;
268- }
269-
270- if (svm ) {
271- /*
272- * If we found svm for the PASID, there must be at
273- * least one device bond, otherwise svm should be freed.
274- */
275- if (WARN_ON (list_empty (& svm -> devs ))) {
276- ret = - EINVAL ;
277- goto out ;
278- }
279311
312+ if (sdev ) {
280313 /*
281314 * Do not allow multiple bindings of the same device-PASID since
282315 * there is only one SL page tables per PASID. We may revisit
283316 * once sharing PGD across domains are supported.
284317 */
285- for_each_svm_dev (sdev , svm , dev ) {
286- dev_warn_ratelimited (dev ,
287- "Already bound with PASID %u\n" ,
288- svm -> pasid );
289- ret = - EBUSY ;
290- goto out ;
291- }
292- } else {
318+ dev_warn_ratelimited (dev , "Already bound with PASID %u\n" ,
319+ svm -> pasid );
320+ ret = - EBUSY ;
321+ goto out ;
322+ }
323+
324+ if (!svm ) {
293325 /* We come here when PASID has never been bond to a device. */
294326 svm = kzalloc (sizeof (* svm ), GFP_KERNEL );
295327 if (!svm ) {
@@ -372,25 +404,17 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
372404 struct intel_iommu * iommu = device_to_iommu (dev , NULL , NULL );
373405 struct intel_svm_dev * sdev ;
374406 struct intel_svm * svm ;
375- int ret = - EINVAL ;
407+ int ret ;
376408
377409 if (WARN_ON (!iommu ))
378410 return - EINVAL ;
379411
380412 mutex_lock (& pasid_mutex );
381- svm = ioasid_find (NULL , pasid , NULL );
382- if (!svm ) {
383- ret = - EINVAL ;
384- goto out ;
385- }
386-
387- if (IS_ERR (svm )) {
388- ret = PTR_ERR (svm );
413+ ret = pasid_to_svm_sdev (dev , pasid , & svm , & sdev );
414+ if (ret )
389415 goto out ;
390- }
391416
392- for_each_svm_dev (sdev , svm , dev ) {
393- ret = 0 ;
417+ if (sdev ) {
394418 if (iommu_dev_feature_enabled (dev , IOMMU_DEV_FEAT_AUX ))
395419 sdev -> users -- ;
396420 if (!sdev -> users ) {
@@ -414,7 +438,6 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
414438 kfree (svm );
415439 }
416440 }
417- break ;
418441 }
419442out :
420443 mutex_unlock (& pasid_mutex );
@@ -592,7 +615,7 @@ intel_svm_bind_mm(struct device *dev, int flags, struct svm_dev_ops *ops,
592615 if (sd )
593616 * sd = sdev ;
594617 ret = 0 ;
595- out :
618+ out :
596619 return ret ;
597620}
598621
@@ -608,17 +631,11 @@ static int intel_svm_unbind_mm(struct device *dev, int pasid)
608631 if (!iommu )
609632 goto out ;
610633
611- svm = ioasid_find (NULL , pasid , NULL );
612- if (!svm )
613- goto out ;
614-
615- if (IS_ERR (svm )) {
616- ret = PTR_ERR (svm );
634+ ret = pasid_to_svm_sdev (dev , pasid , & svm , & sdev );
635+ if (ret )
617636 goto out ;
618- }
619637
620- for_each_svm_dev (sdev , svm , dev ) {
621- ret = 0 ;
638+ if (sdev ) {
622639 sdev -> users -- ;
623640 if (!sdev -> users ) {
624641 list_del_rcu (& sdev -> list );
@@ -647,10 +664,8 @@ static int intel_svm_unbind_mm(struct device *dev, int pasid)
647664 kfree (svm );
648665 }
649666 }
650- break ;
651667 }
652- out :
653-
668+ out :
654669 return ret ;
655670}
656671
0 commit comments