@@ -550,49 +550,61 @@ static int __io_uring_register(struct io_ring_ctx *ctx, unsigned opcode,
550550 return ret ;
551551}
552552
553- SYSCALL_DEFINE4 (io_uring_register , unsigned int , fd , unsigned int , opcode ,
554- void __user * , arg , unsigned int , nr_args )
553+ /*
554+ * Given an 'fd' value, return the ctx associated with if. If 'registered' is
555+ * true, then the registered index is used. Otherwise, the normal fd table.
556+ * Caller must call fput() on the returned file, unless it's an ERR_PTR.
557+ */
558+ struct file * io_uring_register_get_file (int fd , bool registered )
555559{
556- struct io_ring_ctx * ctx ;
557- long ret = - EBADF ;
558560 struct file * file ;
559- bool use_registered_ring ;
560561
561- use_registered_ring = !!(opcode & IORING_REGISTER_USE_REGISTERED_RING );
562- opcode &= ~IORING_REGISTER_USE_REGISTERED_RING ;
563-
564- if (opcode >= IORING_REGISTER_LAST )
565- return - EINVAL ;
566-
567- if (use_registered_ring ) {
562+ if (registered ) {
568563 /*
569564 * Ring fd has been registered via IORING_REGISTER_RING_FDS, we
570565 * need only dereference our task private array to find it.
571566 */
572567 struct io_uring_task * tctx = current -> io_uring ;
573568
574569 if (unlikely (!tctx || fd >= IO_RINGFD_REG_MAX ))
575- return - EINVAL ;
570+ return ERR_PTR ( - EINVAL ) ;
576571 fd = array_index_nospec (fd , IO_RINGFD_REG_MAX );
577572 file = tctx -> registered_rings [fd ];
578- if (unlikely (!file ))
579- return - EBADF ;
580573 } else {
581574 file = fget (fd );
582- if (unlikely (!file ))
583- return - EBADF ;
584- ret = - EOPNOTSUPP ;
585- if (!io_is_uring_fops (file ))
586- goto out_fput ;
587575 }
588576
577+ if (unlikely (!file ))
578+ return ERR_PTR (- EBADF );
579+ if (io_is_uring_fops (file ))
580+ return file ;
581+ fput (file );
582+ return ERR_PTR (- EOPNOTSUPP );
583+ }
584+
585+ SYSCALL_DEFINE4 (io_uring_register , unsigned int , fd , unsigned int , opcode ,
586+ void __user * , arg , unsigned int , nr_args )
587+ {
588+ struct io_ring_ctx * ctx ;
589+ long ret = - EBADF ;
590+ struct file * file ;
591+ bool use_registered_ring ;
592+
593+ use_registered_ring = !!(opcode & IORING_REGISTER_USE_REGISTERED_RING );
594+ opcode &= ~IORING_REGISTER_USE_REGISTERED_RING ;
595+
596+ if (opcode >= IORING_REGISTER_LAST )
597+ return - EINVAL ;
598+
599+ file = io_uring_register_get_file (fd , use_registered_ring );
600+ if (IS_ERR (file ))
601+ return PTR_ERR (file );
589602 ctx = file -> private_data ;
590603
591604 mutex_lock (& ctx -> uring_lock );
592605 ret = __io_uring_register (ctx , opcode , arg , nr_args );
593606 mutex_unlock (& ctx -> uring_lock );
594607 trace_io_uring_register (ctx , opcode , ctx -> nr_user_files , ctx -> nr_user_bufs , ret );
595- out_fput :
596608 if (!use_registered_ring )
597609 fput (file );
598610 return ret ;
0 commit comments