diff --git a/virt/kvm/iommu.c b/virt/kvm/iommu.c index fec1723de9b4eb..e9fff9830bf0bf 100644 --- a/virt/kvm/iommu.c +++ b/virt/kvm/iommu.c @@ -240,9 +240,13 @@ int kvm_iommu_map_guest(struct kvm *kvm) return -ENODEV; } + mutex_lock(&kvm->slots_lock); + kvm->arch.iommu_domain = iommu_domain_alloc(&pci_bus_type); - if (!kvm->arch.iommu_domain) - return -ENOMEM; + if (!kvm->arch.iommu_domain) { + r = -ENOMEM; + goto out_unlock; + } if (!allow_unsafe_assigned_interrupts && !iommu_domain_has_cap(kvm->arch.iommu_domain, @@ -253,17 +257,16 @@ int kvm_iommu_map_guest(struct kvm *kvm) " module option.\n", __func__); iommu_domain_free(kvm->arch.iommu_domain); kvm->arch.iommu_domain = NULL; - return -EPERM; + r = -EPERM; + goto out_unlock; } r = kvm_iommu_map_memslots(kvm); if (r) - goto out_unmap; - - return 0; + kvm_iommu_unmap_memslots(kvm); -out_unmap: - kvm_iommu_unmap_memslots(kvm); +out_unlock: + mutex_unlock(&kvm->slots_lock); return r; } @@ -340,7 +343,11 @@ int kvm_iommu_unmap_guest(struct kvm *kvm) if (!domain) return 0; + mutex_lock(&kvm->slots_lock); kvm_iommu_unmap_memslots(kvm); + kvm->arch.iommu_domain = NULL; + mutex_unlock(&kvm->slots_lock); + iommu_domain_free(domain); return 0; }