mm/hmm: do not write the output pfn array when faulting
The pfn array is an input/output array. If it is written to then
hmm_vma_walk->last must be updated so that the next pass around the loop
does not read back an entry that has already been converted to output.

hmm_vma_walk_hole_() has a confusing dual purpose: on some of its flows it
will destroy the input pfns and return -EBUSY.

The only case where it is called with a potentially zero required_fault is
in hmm_vma_walk_hole(), so move the HMM_PFN_NONE fill directly there, and
simplify hmm_vma_walk_hole_() into a function with a clear purpose:
hmm_vma_fault(), which always tries to fault and always causes an -EBUSY
return back to the main loop.

The call tree for hmm_vma_fault() now always has a vma (as we can't fault
without one), so simplify the flow. All the callbacks besides
hmm_vma_walk_hole() have the vma guaranteed by the page walker code.

Fixes: 2aee09d ("mm/hmm: change hmm_vma_fault() to allow write fault on page basis")
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
jgunthorpe committed Mar 11, 2020
1 parent 4048f7f commit 078e10c
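
To make the pfn-array contract concrete, here is a minimal sketch of the invariant the message describes. It is not code from the patch; example_walker_step() and need_to_stop_walk() are hypothetical names standing in for any walker step that may have to bail out with -EBUSY:

/*
 * Hypothetical walker step illustrating the hmm_vma_walk->last rule:
 * once a pfns[] entry has been converted from input flags to an output
 * value, 'last' must advance past it before returning -EBUSY, or the
 * restarted walk would misread output as input.
 * need_to_stop_walk() is a stand-in, not a real kernel helper.
 */
static int example_walker_step(struct hmm_vma_walk *hmm_vma_walk,
			       unsigned long addr, uint64_t out_pfn)
{
	struct hmm_range *range = hmm_vma_walk->range;
	unsigned long i = (addr - range->start) >> PAGE_SHIFT;

	range->pfns[i] = out_pfn;	/* entry i is now output */
	if (need_to_stop_walk()) {
		/* point past the entry we just stored */
		hmm_vma_walk->last = addr + PAGE_SIZE;
		return -EBUSY;
	}
	return 0;
}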
Showing 1 changed file with 48 additions and 42 deletions.

mm/hmm.c
@@ -59,9 +59,6 @@ static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
 	struct vm_area_struct *vma = walk->vma;
 	vm_fault_t ret;
 
-	if (!vma)
-		goto err;
-
 	if (hmm_vma_walk->flags & HMM_FAULT_ALLOW_RETRY)
 		flags |= FAULT_FLAG_ALLOW_RETRY;
 	if (required_fault & NEED_WRITE_FAULT)
@@ -75,7 +72,7 @@ static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
 	if (ret & VM_FAULT_ERROR)
 		goto err;
 
-	return -EBUSY;
+	return 0;
 
 err:
 	*pfn = range->values[HMM_PFN_ERROR];
@@ -96,18 +93,21 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
 }
 
 /*
- * hmm_vma_walk_hole_() - handle a range lacking valid pmd or pte(s)
+ * hmm_vma_fault() - Execute a page fault so the range matches required_fault
  * @addr: range virtual start address (inclusive)
 * @end: range virtual end address (exclusive)
- * @required_fault: NEED_FAULT_* flags
+ * @required_fault: NEED_FAULT_* flags, can not be 0
 * @walk: mm_walk structure
- * Return: 0 on success, -EBUSY after page fault, or page fault error
+ * Return: -EBUSY after page fault, or page fault error. Does not return 0.
 *
- * This function will be called whenever pmd_none() or pte_none() returns true,
- * or whenever there is no page directory covering the virtual address range.
+ * This function will be called whenever the pte/pmd/etc flags indicate there
+ * is no mapping and hmm_range_need_fault() or hmm_pte_need_fault() show the
+ * caller requested the page to be valid.
+ *
+ * After faulting is completed the walk has to be started again.
 */
-static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
-			      unsigned int required_fault, struct mm_walk *walk)
+static int hmm_vma_fault(unsigned long addr, unsigned long end,
+			 unsigned int required_fault, struct mm_walk *walk)
 {
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
 	struct hmm_range *range = hmm_vma_walk->range;
@@ -117,23 +117,18 @@ static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
 	hmm_vma_walk->last = addr;
 	i = (addr - range->start) >> PAGE_SHIFT;
 
-	if ((required_fault & NEED_WRITE_FAULT) && walk->vma &&
+	if ((required_fault & NEED_WRITE_FAULT) &&
 	    !(walk->vma->vm_flags & VM_WRITE))
 		return -EPERM;
 
 	for (; addr < end; addr += PAGE_SIZE, i++) {
-		pfns[i] = range->values[HMM_PFN_NONE];
-		if (required_fault) {
-			int ret;
-
-			ret = hmm_vma_do_fault(walk, addr, required_fault,
-					       &pfns[i]);
-			if (ret != -EBUSY)
-				return ret;
-		}
-	}
-
-	return required_fault ? -EBUSY : 0;
+		int ret;
+
+		ret = hmm_vma_do_fault(walk, addr, required_fault, &pfns[i]);
+		if (ret)
+			return ret;
+	}
+	return -EBUSY;
 }
 
 static unsigned int hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
@@ -208,7 +203,7 @@ hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 	return required_fault;
 }
 
-static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
+static int hmm_vma_walk_hole(unsigned long start, unsigned long end,
 			     __always_unused int depth, struct mm_walk *walk)
 {
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
@@ -217,11 +212,18 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
 	unsigned long i, npages;
 	uint64_t *pfns;
 
-	i = (addr - range->start) >> PAGE_SHIFT;
-	npages = (end - addr) >> PAGE_SHIFT;
+	i = (start - range->start) >> PAGE_SHIFT;
+	npages = (end - start) >> PAGE_SHIFT;
 	pfns = &range->pfns[i];
 	required_fault = hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0);
-	return hmm_vma_walk_hole_(addr, end, required_fault, walk);
+	if (!walk->vma) {
+		if (required_fault)
+			return -EFAULT;
+		return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
+	}
+	if (required_fault)
+		return hmm_vma_fault(start, end, required_fault, walk);
+	return hmm_pfns_fill(start, end, range, HMM_PFN_NONE);
 }
 
 static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
@@ -247,19 +249,20 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
 	required_fault =
 		hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags);
 	if (required_fault)
-		return hmm_vma_walk_hole_(addr, end, required_fault, walk);
+		return hmm_vma_fault(addr, end, required_fault, walk);
 
 	pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
 	for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
 		if (pmd_devmap(pmd)) {
 			hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
 					      hmm_vma_walk->pgmap);
-			if (unlikely(!hmm_vma_walk->pgmap))
+			if (unlikely(!hmm_vma_walk->pgmap)) {
+				hmm_vma_walk->last = addr;
 				return -EBUSY;
+			}
 		}
 		pfns[i] = hmm_device_entry_from_pfn(range, pfn) | cpu_flags;
 	}
-	hmm_vma_walk->last = end;
 	return 0;
 }

@@ -344,6 +347,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 					       hmm_vma_walk->pgmap);
 			if (unlikely(!hmm_vma_walk->pgmap)) {
 				pte_unmap(ptep);
+				hmm_vma_walk->last = addr;
 				return -EBUSY;
 			}
 		}
@@ -367,7 +371,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 fault:
 	pte_unmap(ptep);
 	/* Fault any virtual address we were asked to fault */
-	return hmm_vma_walk_hole_(addr, end, required_fault, walk);
+	return hmm_vma_fault(addr, end, required_fault, walk);
 }
 
 static int hmm_vma_walk_pmd(pmd_t *pmdp,
@@ -450,13 +454,10 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
 		r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, pfns);
 		if (r) {
 			/* hmm_vma_handle_pte() did pte_unmap() */
-			hmm_vma_walk->last = addr;
 			return r;
 		}
 	}
 	pte_unmap(ptep - 1);
-
-	hmm_vma_walk->last = addr;
 	return 0;
 }

@@ -507,19 +508,20 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
 	required_fault =
 		hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags);
 	if (required_fault)
-		return hmm_vma_walk_hole_(start, end, required_fault, walk);
+		return hmm_vma_fault(start, end, required_fault, walk);
 
 	pfn = pud_pfn(pud) + ((start & ~PUD_MASK) >> PAGE_SHIFT);
 	npages = (end - start) >> PAGE_SHIFT;
 
 	for (i = 0; i < npages; ++i, ++pfn) {
 		hmm_vma_walk->pgmap = get_dev_pagemap(pfn, hmm_vma_walk->pgmap);
-		if (unlikely(!hmm_vma_walk->pgmap))
+		if (unlikely(!hmm_vma_walk->pgmap)) {
+			hmm_vma_walk->last = end;
 			return -EBUSY;
+		}
 		pfns[i] = hmm_device_entry_from_pfn(range, pfn) | cpu_flags;
 	}
-	hmm_vma_walk->last = end;
 
 	/* Do not split the pud */
 	walk->action = ACTION_CONTINUE;
 	return 0;
@@ -552,15 +554,13 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
 	required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags);
 	if (required_fault) {
 		spin_unlock(ptl);
-		return hmm_vma_walk_hole_(addr, end, required_fault, walk);
+		return hmm_vma_fault(addr, end, required_fault, walk);
 	}
 
 	pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);
 	for (; addr < end; addr += PAGE_SIZE, i++, pfn++)
 		range->pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
 				 cpu_flags;
-	hmm_vma_walk->last = end;
-
 	spin_unlock(ptl);
 	return 0;
 }
@@ -597,7 +597,6 @@ static int hmm_vma_walk_test(unsigned long start, unsigned long end,
 		return -EFAULT;
 
 	hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
-	hmm_vma_walk->last = end;
 
 	/* Skip this vma and continue processing the next vma. */
 	return 1;
@@ -670,6 +669,13 @@ long hmm_range_fault(struct hmm_range *range, unsigned int flags)
 			put_dev_pagemap(hmm_vma_walk.pgmap);
 			hmm_vma_walk.pgmap = NULL;
 		}
+
+		/*
+		 * When -EBUSY is returned the loop restarts with
+		 * hmm_vma_walk.last set to an address that has not been stored
+		 * in pfns. All entries < last in the pfn array are set to their
+		 * output, and all >= are still at their input values.
+		 */
 	} while (ret == -EBUSY);
 
 	if (ret)
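For context, the do/while shown in the last hunk sits inside hmm_range_fault(). The following is a condensed sketch of that loop, renamed hmm_range_fault_sketch() to make clear it is not the verbatim kernel source: the range->notifier field and the hmm_walk_ops wiring are assumptions that vary across kernel trees, and locking and notifier retry checks are omitted.

/*
 * Condensed sketch of the hmm_range_fault() retry loop documented by the
 * new comment. On -EBUSY the walk resumes at hmm_vma_walk.last, the first
 * pfns[] entry still holding input flags.
 */
long hmm_range_fault_sketch(struct hmm_range *range, unsigned int flags)
{
	struct hmm_vma_walk hmm_vma_walk = {
		.range = range,
		.last = range->start,
		.flags = flags,
	};
	int ret;

	do {
		ret = walk_page_range(range->notifier->mm, hmm_vma_walk.last,
				      range->end, &hmm_walk_ops,
				      &hmm_vma_walk);
		if (hmm_vma_walk.pgmap) {
			put_dev_pagemap(hmm_vma_walk.pgmap);
			hmm_vma_walk.pgmap = NULL;
		}
	} while (ret == -EBUSY);

	if (ret)
		return ret;
	return (range->end - range->start) >> PAGE_SHIFT;
}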
