@@ -27,7 +27,7 @@ static void iommufd_group_release(struct kref *kref)
2727 struct iommufd_group * igroup =
2828 container_of (kref , struct iommufd_group , ref );
2929
30- WARN_ON (igroup -> attach );
30+ WARN_ON (! xa_empty ( & igroup -> pasid_attach ) );
3131
3232 xa_cmpxchg (& igroup -> ictx -> groups , iommu_group_id (igroup -> group ), igroup ,
3333 NULL , GFP_KERNEL );
@@ -94,6 +94,7 @@ static struct iommufd_group *iommufd_get_group(struct iommufd_ctx *ictx,
9494
9595 kref_init (& new_igroup -> ref );
9696 mutex_init (& new_igroup -> lock );
97+ xa_init (& new_igroup -> pasid_attach );
9798 new_igroup -> sw_msi_start = PHYS_ADDR_MAX ;
9899 /* group reference moves into new_igroup */
99100 new_igroup -> group = group ;
@@ -297,16 +298,19 @@ u32 iommufd_device_to_id(struct iommufd_device *idev)
297298}
298299EXPORT_SYMBOL_NS_GPL (iommufd_device_to_id , "IOMMUFD" );
299300
300- static unsigned int iommufd_group_device_num (struct iommufd_group * igroup )
301+ static unsigned int iommufd_group_device_num (struct iommufd_group * igroup ,
302+ ioasid_t pasid )
301303{
304+ struct iommufd_attach * attach ;
302305 struct iommufd_device * idev ;
303306 unsigned int count = 0 ;
304307 unsigned long index ;
305308
306309 lockdep_assert_held (& igroup -> lock );
307310
308- if (igroup -> attach )
309- xa_for_each (& igroup -> attach -> device_array , index , idev )
311+ attach = xa_load (& igroup -> pasid_attach , pasid );
312+ if (attach )
313+ xa_for_each (& attach -> device_array , index , idev )
310314 count ++ ;
311315 return count ;
312316}
@@ -351,7 +355,7 @@ static bool
351355iommufd_group_first_attach (struct iommufd_group * igroup , ioasid_t pasid )
352356{
353357 lockdep_assert_held (& igroup -> lock );
354- return !igroup -> attach ;
358+ return !xa_load ( & igroup -> pasid_attach , pasid ) ;
355359}
356360
357361static int
@@ -382,10 +386,13 @@ iommufd_device_attach_reserved_iova(struct iommufd_device *idev,
382386
383387/* The device attach/detach/replace helpers for attach_handle */
384388
385- /* Check if idev is attached to igroup->hwpt */
386- static bool iommufd_device_is_attached ( struct iommufd_device * idev )
389+ static bool iommufd_device_is_attached ( struct iommufd_device * idev ,
390+ ioasid_t pasid )
387391{
388- return xa_load (& idev -> igroup -> attach -> device_array , idev -> obj .id );
392+ struct iommufd_attach * attach ;
393+
394+ attach = xa_load (& idev -> igroup -> pasid_attach , pasid );
395+ return xa_load (& attach -> device_array , idev -> obj .id );
389396}
390397
391398static int iommufd_hwpt_attach_device (struct iommufd_hw_pagetable * hwpt ,
@@ -512,12 +519,18 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
512519
513520 mutex_lock (& igroup -> lock );
514521
515- attach = igroup -> attach ;
522+ attach = xa_cmpxchg (& igroup -> pasid_attach , pasid , NULL ,
523+ XA_ZERO_ENTRY , GFP_KERNEL );
524+ if (xa_is_err (attach )) {
525+ rc = xa_err (attach );
526+ goto err_unlock ;
527+ }
528+
516529 if (!attach ) {
517530 attach = kzalloc (sizeof (* attach ), GFP_KERNEL );
518531 if (!attach ) {
519532 rc = - ENOMEM ;
520- goto err_unlock ;
533+ goto err_release_pasid ;
521534 }
522535 xa_init (& attach -> device_array );
523536 }
@@ -554,7 +567,8 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
554567 if (rc )
555568 goto err_unresv ;
556569 attach -> hwpt = hwpt ;
557- igroup -> attach = attach ;
570+ WARN_ON (xa_is_err (xa_store (& igroup -> pasid_attach , pasid , attach ,
571+ GFP_KERNEL )));
558572 }
559573 refcount_inc (& hwpt -> obj .users );
560574 WARN_ON (xa_is_err (xa_store (& attach -> device_array , idev -> obj .id ,
@@ -569,6 +583,9 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
569583err_free_attach :
570584 if (iommufd_group_first_attach (igroup , pasid ))
571585 kfree (attach );
586+ err_release_pasid :
587+ if (iommufd_group_first_attach (igroup , pasid ))
588+ xa_release (& igroup -> pasid_attach , pasid );
572589err_unlock :
573590 mutex_unlock (& igroup -> lock );
574591 return rc ;
@@ -583,14 +600,14 @@ iommufd_hw_pagetable_detach(struct iommufd_device *idev, ioasid_t pasid)
583600 struct iommufd_attach * attach ;
584601
585602 mutex_lock (& igroup -> lock );
586- attach = igroup -> attach ;
603+ attach = xa_load ( & igroup -> pasid_attach , pasid ) ;
587604 hwpt = attach -> hwpt ;
588605 hwpt_paging = find_hwpt_paging (hwpt );
589606
590607 xa_erase (& attach -> device_array , idev -> obj .id );
591608 if (xa_empty (& attach -> device_array )) {
592609 iommufd_hwpt_detach_device (hwpt , idev , pasid );
593- igroup -> attach = NULL ;
610+ xa_erase ( & igroup -> pasid_attach , pasid ) ;
594611 kfree (attach );
595612 }
596613 if (hwpt_paging && pasid == IOMMU_NO_PASID )
@@ -617,12 +634,14 @@ static void
617634iommufd_group_remove_reserved_iova (struct iommufd_group * igroup ,
618635 struct iommufd_hwpt_paging * hwpt_paging )
619636{
637+ struct iommufd_attach * attach ;
620638 struct iommufd_device * cur ;
621639 unsigned long index ;
622640
623641 lockdep_assert_held (& igroup -> lock );
624642
625- xa_for_each (& igroup -> attach -> device_array , index , cur )
643+ attach = xa_load (& igroup -> pasid_attach , IOMMU_NO_PASID );
644+ xa_for_each (& attach -> device_array , index , cur )
626645 iopt_remove_reserved_iova (& hwpt_paging -> ioas -> iopt , cur -> dev );
627646}
628647
@@ -631,15 +650,17 @@ iommufd_group_do_replace_reserved_iova(struct iommufd_group *igroup,
631650 struct iommufd_hwpt_paging * hwpt_paging )
632651{
633652 struct iommufd_hwpt_paging * old_hwpt_paging ;
653+ struct iommufd_attach * attach ;
634654 struct iommufd_device * cur ;
635655 unsigned long index ;
636656 int rc ;
637657
638658 lockdep_assert_held (& igroup -> lock );
639659
640- old_hwpt_paging = find_hwpt_paging (igroup -> attach -> hwpt );
660+ attach = xa_load (& igroup -> pasid_attach , IOMMU_NO_PASID );
661+ old_hwpt_paging = find_hwpt_paging (attach -> hwpt );
641662 if (!old_hwpt_paging || hwpt_paging -> ioas != old_hwpt_paging -> ioas ) {
642- xa_for_each (& igroup -> attach -> device_array , index , cur ) {
663+ xa_for_each (& attach -> device_array , index , cur ) {
643664 rc = iopt_table_enforce_dev_resv_regions (
644665 & hwpt_paging -> ioas -> iopt , cur -> dev , NULL );
645666 if (rc )
@@ -672,7 +693,7 @@ iommufd_device_do_replace(struct iommufd_device *idev, ioasid_t pasid,
672693
673694 mutex_lock (& igroup -> lock );
674695
675- attach = igroup -> attach ;
696+ attach = xa_load ( & igroup -> pasid_attach , pasid ) ;
676697 if (!attach ) {
677698 rc = - EINVAL ;
678699 goto err_unlock ;
@@ -682,7 +703,7 @@ iommufd_device_do_replace(struct iommufd_device *idev, ioasid_t pasid,
682703
683704 WARN_ON (!old_hwpt || xa_empty (& attach -> device_array ));
684705
685- if (!iommufd_device_is_attached (idev )) {
706+ if (!iommufd_device_is_attached (idev , pasid )) {
686707 rc = - EINVAL ;
687708 goto err_unlock ;
688709 }
@@ -709,7 +730,7 @@ iommufd_device_do_replace(struct iommufd_device *idev, ioasid_t pasid,
709730
710731 attach -> hwpt = hwpt ;
711732
712- num_devices = iommufd_group_device_num (igroup );
733+ num_devices = iommufd_group_device_num (igroup , pasid );
713734 /*
714735 * Move the refcounts held by the device_array to the new hwpt. Retain a
715736 * refcount for this thread as the caller will free it.
0 commit comments