Skip to content

Commit

Permalink
mm: userfaultfd: don't separate addr + len arguments
Browse files Browse the repository at this point in the history
We have a lot of functions which take an address + length pair,
currently passed as separate arguments. However, in our userspace API we
already have struct uffdio_range, which is exactly this pair, and this
is what we get from userspace when ioctls are called.

Instead of splitting the struct up into two separate arguments, just
plumb the struct through to the functions which use it (once we get to
the mfill_atomic_pte level, we're dealing with single (huge)pages, so we
don't need both parts).

Relatedly, for waking, just re-use this existing structure instead of
defining a new "struct uffdio_wake_range".

Signed-off-by: Axel Rasmussen <axelrasmussen@google.com>
  • Loading branch information
CmdrMoozy authored and intel-lab-lkp committed Mar 6, 2023
1 parent 0f740f4 commit cee642b
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 120 deletions.
107 changes: 42 additions & 65 deletions fs/userfaultfd.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ struct userfaultfd_wait_queue {
bool waken;
};

struct userfaultfd_wake_range {
unsigned long start;
unsigned long len;
};

/* internal indication that UFFD_API ioctl was successfully executed */
#define UFFD_FEATURE_INITIALIZED (1u << 31)

Expand All @@ -126,7 +121,7 @@ static void userfaultfd_set_vm_flags(struct vm_area_struct *vma,
static int userfaultfd_wake_function(wait_queue_entry_t *wq, unsigned mode,
int wake_flags, void *key)
{
struct userfaultfd_wake_range *range = key;
struct uffdio_range *range = key;
int ret;
struct userfaultfd_wait_queue *uwq;
unsigned long start, len;
Expand Down Expand Up @@ -881,7 +876,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
struct mm_struct *mm = ctx->mm;
struct vm_area_struct *vma, *prev;
/* len == 0 means wake all */
struct userfaultfd_wake_range range = { .len = 0, };
struct uffdio_range range = {0};
unsigned long new_flags;
VMA_ITERATOR(vmi, mm, 0);

Expand Down Expand Up @@ -1226,7 +1221,7 @@ static ssize_t userfaultfd_read(struct file *file, char __user *buf,
}

static void __wake_userfault(struct userfaultfd_ctx *ctx,
struct userfaultfd_wake_range *range)
struct uffdio_range *range)
{
spin_lock_irq(&ctx->fault_pending_wqh.lock);
/* wake all in the range and autoremove */
Expand All @@ -1239,7 +1234,7 @@ static void __wake_userfault(struct userfaultfd_ctx *ctx,
}

static __always_inline void wake_userfault(struct userfaultfd_ctx *ctx,
struct userfaultfd_wake_range *range)
struct uffdio_range *range)
{
unsigned seq;
bool need_wakeup;
Expand Down Expand Up @@ -1270,21 +1265,21 @@ static __always_inline void wake_userfault(struct userfaultfd_ctx *ctx,
}

static __always_inline int validate_range(struct mm_struct *mm,
__u64 start, __u64 len)
const struct uffdio_range *range)
{
__u64 task_size = mm->task_size;

if (start & ~PAGE_MASK)
if (range->start & ~PAGE_MASK)
return -EINVAL;
if (len & ~PAGE_MASK)
if (range->len & ~PAGE_MASK)
return -EINVAL;
if (!len)
if (!range->len)
return -EINVAL;
if (start < mmap_min_addr)
if (range->start < mmap_min_addr)
return -EINVAL;
if (start >= task_size)
if (range->start >= task_size)
return -EINVAL;
if (len > task_size - start)
if (range->len > task_size - range->start)
return -EINVAL;
return 0;
}
Expand Down Expand Up @@ -1331,8 +1326,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
vm_flags |= VM_UFFD_MINOR;
}

ret = validate_range(mm, uffdio_register.range.start,
uffdio_register.range.len);
ret = validate_range(mm, &uffdio_register.range);
if (ret)
goto out;

Expand Down Expand Up @@ -1538,11 +1532,11 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
goto out;

ret = validate_range(mm, uffdio_unregister.start,
uffdio_unregister.len);
ret = validate_range(mm, &uffdio_unregister);
if (ret)
goto out;

/* Get rid of start + end in favor of range *? */
start = uffdio_unregister.start;
end = start + uffdio_unregister.len;

Expand Down Expand Up @@ -1597,6 +1591,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
prev = vma_prev(&vmi);
ret = 0;
for_each_vma_range(vmi, vma, end) {
struct uffdio_range range;
cond_resched();

BUG_ON(!vma_can_userfault(vma, vma->vm_flags));
Expand All @@ -1614,22 +1609,21 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
start = vma->vm_start;
vma_end = min(end, vma->vm_end);

range.start = start;
range.len = vma_end - start;
if (userfaultfd_missing(vma)) {
/*
* Wake any concurrent pending userfault while
* we unregister, so they will not hang
* permanently and it avoids userland to call
* UFFDIO_WAKE explicitly.
*/
struct userfaultfd_wake_range range;
range.start = start;
range.len = vma_end - start;
wake_userfault(vma->vm_userfaultfd_ctx.ctx, &range);
}

/* Reset ptes for the whole vma range if wr-protected */
if (userfaultfd_wp(vma))
uffd_wp_range(vma, start, vma_end - start, false);
uffd_wp_range(vma, &range, false);

new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
Expand Down Expand Up @@ -1680,27 +1674,23 @@ static int userfaultfd_wake(struct userfaultfd_ctx *ctx,
{
int ret;
struct uffdio_range uffdio_wake;
struct userfaultfd_wake_range range;
const void __user *buf = (void __user *)arg;

ret = -EFAULT;
if (copy_from_user(&uffdio_wake, buf, sizeof(uffdio_wake)))
goto out;

ret = validate_range(ctx->mm, uffdio_wake.start, uffdio_wake.len);
ret = validate_range(ctx->mm, &uffdio_wake);
if (ret)
goto out;

range.start = uffdio_wake.start;
range.len = uffdio_wake.len;

/*
* len == 0 means wake all and we don't want to wake all here,
* so check it again to be sure.
*/
VM_BUG_ON(!range.len);
VM_BUG_ON(!uffdio_wake.len);

wake_userfault(ctx, &range);
wake_userfault(ctx, &uffdio_wake);
ret = 0;

out:
Expand All @@ -1713,7 +1703,7 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
__s64 ret;
struct uffdio_copy uffdio_copy;
struct uffdio_copy __user *user_uffdio_copy;
struct userfaultfd_wake_range range;
struct uffdio_range range;
int flags = 0;

user_uffdio_copy = (struct uffdio_copy __user *) arg;
Expand All @@ -1728,7 +1718,9 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
sizeof(uffdio_copy)-sizeof(__s64)))
goto out;

ret = validate_range(ctx->mm, uffdio_copy.dst, uffdio_copy.len);
range.start = uffdio_copy.dst;
range.len = uffdio_copy.len;
ret = validate_range(ctx->mm, &range);
if (ret)
goto out;
/*
Expand All @@ -1744,9 +1736,8 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
if (uffdio_copy.mode & UFFDIO_COPY_MODE_WP)
flags |= MFILL_ATOMIC_WP;
if (mmget_not_zero(ctx->mm)) {
ret = mfill_atomic_copy(ctx->mm, uffdio_copy.dst, uffdio_copy.src,
uffdio_copy.len, &ctx->mmap_changing,
flags);
ret = mfill_atomic_copy(ctx->mm, uffdio_copy.src, &range,
&ctx->mmap_changing, flags);
mmput(ctx->mm);
} else {
return -ESRCH;
Expand All @@ -1758,10 +1749,8 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
BUG_ON(!ret);
/* len == 0 would wake all */
range.len = ret;
if (!(uffdio_copy.mode & UFFDIO_COPY_MODE_DONTWAKE)) {
range.start = uffdio_copy.dst;
if (!(uffdio_copy.mode & UFFDIO_COPY_MODE_DONTWAKE))
wake_userfault(ctx, &range);
}
ret = range.len == uffdio_copy.len ? 0 : -EAGAIN;
out:
return ret;
Expand All @@ -1773,7 +1762,7 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx,
__s64 ret;
struct uffdio_zeropage uffdio_zeropage;
struct uffdio_zeropage __user *user_uffdio_zeropage;
struct userfaultfd_wake_range range;
struct uffdio_range range;

user_uffdio_zeropage = (struct uffdio_zeropage __user *) arg;

Expand All @@ -1787,17 +1776,16 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx,
sizeof(uffdio_zeropage)-sizeof(__s64)))
goto out;

ret = validate_range(ctx->mm, uffdio_zeropage.range.start,
uffdio_zeropage.range.len);
range = uffdio_zeropage.range;
ret = validate_range(ctx->mm, &range);
if (ret)
goto out;
ret = -EINVAL;
if (uffdio_zeropage.mode & ~UFFDIO_ZEROPAGE_MODE_DONTWAKE)
goto out;

if (mmget_not_zero(ctx->mm)) {
ret = mfill_atomic_zeropage(ctx->mm, uffdio_zeropage.range.start,
uffdio_zeropage.range.len,
ret = mfill_atomic_zeropage(ctx->mm, &uffdio_zeropage.range,
&ctx->mmap_changing);
mmput(ctx->mm);
} else {
Expand All @@ -1811,7 +1799,6 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx,
BUG_ON(!ret);
range.len = ret;
if (!(uffdio_zeropage.mode & UFFDIO_ZEROPAGE_MODE_DONTWAKE)) {
range.start = uffdio_zeropage.range.start;
wake_userfault(ctx, &range);
}
ret = range.len == uffdio_zeropage.range.len ? 0 : -EAGAIN;
Expand All @@ -1825,7 +1812,6 @@ static int userfaultfd_writeprotect(struct userfaultfd_ctx *ctx,
int ret;
struct uffdio_writeprotect uffdio_wp;
struct uffdio_writeprotect __user *user_uffdio_wp;
struct userfaultfd_wake_range range;
bool mode_wp, mode_dontwake;

if (atomic_read(&ctx->mmap_changing))
Expand All @@ -1837,8 +1823,7 @@ static int userfaultfd_writeprotect(struct userfaultfd_ctx *ctx,
sizeof(struct uffdio_writeprotect)))
return -EFAULT;

ret = validate_range(ctx->mm, uffdio_wp.range.start,
uffdio_wp.range.len);
ret = validate_range(ctx->mm, &uffdio_wp.range);
if (ret)
return ret;

Expand All @@ -1853,9 +1838,8 @@ static int userfaultfd_writeprotect(struct userfaultfd_ctx *ctx,
return -EINVAL;

if (mmget_not_zero(ctx->mm)) {
ret = mwriteprotect_range(ctx->mm, uffdio_wp.range.start,
uffdio_wp.range.len, mode_wp,
&ctx->mmap_changing);
ret = mwriteprotect_range(ctx->mm, &uffdio_wp.range,
mode_wp, &ctx->mmap_changing);
mmput(ctx->mm);
} else {
return -ESRCH;
Expand All @@ -1864,11 +1848,8 @@ static int userfaultfd_writeprotect(struct userfaultfd_ctx *ctx,
if (ret)
return ret;

if (!mode_wp && !mode_dontwake) {
range.start = uffdio_wp.range.start;
range.len = uffdio_wp.range.len;
wake_userfault(ctx, &range);
}
if (!mode_wp && !mode_dontwake)
wake_userfault(ctx, &uffdio_wp.range);
return ret;
}

Expand All @@ -1877,7 +1858,7 @@ static int userfaultfd_continue(struct userfaultfd_ctx *ctx, unsigned long arg)
__s64 ret;
struct uffdio_continue uffdio_continue;
struct uffdio_continue __user *user_uffdio_continue;
struct userfaultfd_wake_range range;
struct uffdio_range range;

user_uffdio_continue = (struct uffdio_continue __user *)arg;

Expand All @@ -1891,23 +1872,20 @@ static int userfaultfd_continue(struct userfaultfd_ctx *ctx, unsigned long arg)
sizeof(uffdio_continue) - (sizeof(__s64))))
goto out;

ret = validate_range(ctx->mm, uffdio_continue.range.start,
uffdio_continue.range.len);
range = uffdio_continue.range;
ret = validate_range(ctx->mm, &range);
if (ret)
goto out;

ret = -EINVAL;
/* double check for wraparound just in case. */
if (uffdio_continue.range.start + uffdio_continue.range.len <=
uffdio_continue.range.start) {
if (range.start + range.len <= range.start)
goto out;
}
if (uffdio_continue.mode & ~UFFDIO_CONTINUE_MODE_DONTWAKE)
goto out;

if (mmget_not_zero(ctx->mm)) {
ret = mfill_atomic_continue(ctx->mm, uffdio_continue.range.start,
uffdio_continue.range.len,
ret = mfill_atomic_continue(ctx->mm, &range,
&ctx->mmap_changing);
mmput(ctx->mm);
} else {
Expand All @@ -1923,7 +1901,6 @@ static int userfaultfd_continue(struct userfaultfd_ctx *ctx, unsigned long arg)
BUG_ON(!ret);
range.len = ret;
if (!(uffdio_continue.mode & UFFDIO_CONTINUE_MODE_DONTWAKE)) {
range.start = uffdio_continue.range.start;
wake_userfault(ctx, &range);
}
ret = range.len == uffdio_continue.range.len ? 0 : -EAGAIN;
Expand Down
17 changes: 9 additions & 8 deletions include/linux/userfaultfd_k.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,21 @@ extern int mfill_atomic_install_pte(pmd_t *dst_pmd,
unsigned long dst_addr, struct page *page,
bool newly_allocated, uffd_flags_t flags);

extern ssize_t mfill_atomic_copy(struct mm_struct *dst_mm, unsigned long dst_start,
unsigned long src_start, unsigned long len,
extern ssize_t mfill_atomic_copy(struct mm_struct *dst_mm, unsigned long src_start,
const struct uffdio_range *dst,
atomic_t *mmap_changing, uffd_flags_t flags);
extern ssize_t mfill_atomic_zeropage(struct mm_struct *dst_mm,
unsigned long dst_start,
unsigned long len,
const struct uffdio_range *dst,
atomic_t *mmap_changing);
extern ssize_t mfill_atomic_continue(struct mm_struct *dst_mm, unsigned long dst_start,
unsigned long len, atomic_t *mmap_changing);
extern ssize_t mfill_atomic_continue(struct mm_struct *dst_mm,
const struct uffdio_range *dst,
atomic_t *mmap_changing);
extern int mwriteprotect_range(struct mm_struct *dst_mm,
unsigned long start, unsigned long len,
const struct uffdio_range *range,
bool enable_wp, atomic_t *mmap_changing);
extern long uffd_wp_range(struct vm_area_struct *vma,
unsigned long start, unsigned long len, bool enable_wp);
const struct uffdio_range *range,
bool enable_wp);

/* mm helpers */
static inline bool is_mergeable_vm_userfaultfd_ctx(struct vm_area_struct *vma,
Expand Down
Loading

0 comments on commit cee642b

Please sign in to comment.