Skip to content

Commit

Permalink
nvproxy: Expand HasRMCtrlFD idea to frontend ioctl and control commands.
Browse files Browse the repository at this point in the history
This is helpful for handling parameter types that have one field for frontend
FD that needs to be translated (and are simple apart from that). Avoids
repetitive code.

Rename HasRMCtrlFD->HasFrontendFD so it can have a broader meaning.
Implement generic handlers for frontend ioctl and control commands.

Updates #10413.

PiperOrigin-RevId: 633238248
  • Loading branch information
ayushr2 authored and gvisor-bot committed May 14, 2024
1 parent ea2e93a commit 0cb437d
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 100 deletions.
10 changes: 10 additions & 0 deletions pkg/abi/nvgpu/ctrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ type NV0000_CTRL_OS_UNIX_EXPORT_OBJECT_TO_FD_PARAMS struct {
Flags uint32
}

// GetFrontendFD implements HasFrontendFD.GetFrontendFD.
func (p *NV0000_CTRL_OS_UNIX_EXPORT_OBJECT_TO_FD_PARAMS) GetFrontendFD() int32 {
return p.FD
}

// SetFrontendFD implements HasFrontendFD.SetFrontendFD.
func (p *NV0000_CTRL_OS_UNIX_EXPORT_OBJECT_TO_FD_PARAMS) SetFrontendFD(fd int32) {
p.FD = fd
}

// +marshal
type NV0000_CTRL_SYSTEM_GET_BUILD_VERSION_PARAMS struct {
SizeOfStrings uint32
Expand Down
28 changes: 28 additions & 0 deletions pkg/abi/nvgpu/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ type IoctlAllocOSEvent struct {
Status uint32
}

// GetFrontendFD implements HasFrontendFD.GetFrontendFD.
func (p *IoctlAllocOSEvent) GetFrontendFD() int32 {
return int32(p.FD)
}

// SetFrontendFD implements HasFrontendFD.SetFrontendFD.
func (p *IoctlAllocOSEvent) SetFrontendFD(fd int32) {
p.FD = uint32(fd)
}

// IoctlFreeOSEvent is nv_ioctl_free_os_event_t, the parameter type for
// NV_ESC_FREE_OS_EVENT.
//
Expand All @@ -85,6 +95,16 @@ type IoctlFreeOSEvent struct {
Status uint32
}

// GetFrontendFD implements HasFrontendFD.GetFrontendFD.
func (p *IoctlFreeOSEvent) GetFrontendFD() int32 {
return int32(p.FD)
}

// SetFrontendFD implements HasFrontendFD.SetFrontendFD.
func (p *IoctlFreeOSEvent) SetFrontendFD(fd int32) {
p.FD = uint32(fd)
}

// RMAPIVersion is nv_rm_api_version_t, the parameter type for
// NV_ESC_CHECK_VERSION_STR.
//
Expand Down Expand Up @@ -405,6 +425,14 @@ func (n *NVOS64Parameters) FromOS64(other NVOS64Parameters) { *n = other }
// ToOS64 implements RmAllocParamType.ToOS64.
func (n *NVOS64Parameters) ToOS64() NVOS64Parameters { return *n }

// HasFrontendFD is a type constraint for parameter structs containing a
// frontend FD field. This is necessary because, as of this writing (Go 1.20),
// there is no way to enable field access using a Go type constraint.
type HasFrontendFD interface {
GetFrontendFD() int32
SetFrontendFD(int32)
}

// Frontend ioctl parameter struct sizes.
var (
SizeofIoctlRegisterFD = uint32((*IoctlRegisterFD)(nil).SizeBytes())
Expand Down
38 changes: 20 additions & 18 deletions pkg/abi/nvgpu/uvm.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@

package nvgpu

// HasRMCtrlFD is a type constraint for UVM parameter structs containing a
// RMCtrlFD field. This is necessary because, as of this writing (Go 1.20),
// there is no way to enable field access using a Go type constraint.
type HasRMCtrlFD interface {
GetRMCtrlFD() int32
SetRMCtrlFD(int32)
}

// UVM ioctl commands.
const (
// From kernel-open/nvidia-uvm/uvm_linux_ioctl.h:
Expand Down Expand Up @@ -86,11 +78,13 @@ type UVM_REGISTER_GPU_VASPACE_PARAMS struct {
RMStatus uint32
}

func (p *UVM_REGISTER_GPU_VASPACE_PARAMS) GetRMCtrlFD() int32 {
// GetFrontendFD implements HasFrontendFD.GetFrontendFD.
func (p *UVM_REGISTER_GPU_VASPACE_PARAMS) GetFrontendFD() int32 {
return p.RMCtrlFD
}

func (p *UVM_REGISTER_GPU_VASPACE_PARAMS) SetRMCtrlFD(fd int32) {
// SetFrontendFD implements HasFrontendFD.SetFrontendFD.
func (p *UVM_REGISTER_GPU_VASPACE_PARAMS) SetFrontendFD(fd int32) {
p.RMCtrlFD = fd
}

Expand All @@ -113,11 +107,13 @@ type UVM_REGISTER_CHANNEL_PARAMS struct {
Pad0 [4]byte
}

func (p *UVM_REGISTER_CHANNEL_PARAMS) GetRMCtrlFD() int32 {
// GetFrontendFD implements HasFrontendFD.GetFrontendFD.
func (p *UVM_REGISTER_CHANNEL_PARAMS) GetFrontendFD() int32 {
return p.RMCtrlFD
}

func (p *UVM_REGISTER_CHANNEL_PARAMS) SetRMCtrlFD(fd int32) {
// SetFrontendFD implements HasFrontendFD.SetFrontendFD.
func (p *UVM_REGISTER_CHANNEL_PARAMS) SetFrontendFD(fd int32) {
p.RMCtrlFD = fd
}

Expand All @@ -142,11 +138,13 @@ type UVM_MAP_EXTERNAL_ALLOCATION_PARAMS struct {
RMStatus uint32
}

func (p *UVM_MAP_EXTERNAL_ALLOCATION_PARAMS) GetRMCtrlFD() int32 {
// GetFrontendFD implements HasFrontendFD.GetFrontendFD.
func (p *UVM_MAP_EXTERNAL_ALLOCATION_PARAMS) GetFrontendFD() int32 {
return p.RMCtrlFD
}

func (p *UVM_MAP_EXTERNAL_ALLOCATION_PARAMS) SetRMCtrlFD(fd int32) {
// SetFrontendFD implements HasFrontendFD.SetFrontendFD.
func (p *UVM_MAP_EXTERNAL_ALLOCATION_PARAMS) SetFrontendFD(fd int32) {
p.RMCtrlFD = fd
}

Expand All @@ -163,11 +161,13 @@ type UVM_MAP_EXTERNAL_ALLOCATION_PARAMS_V550 struct {
RMStatus uint32
}

func (p *UVM_MAP_EXTERNAL_ALLOCATION_PARAMS_V550) GetRMCtrlFD() int32 {
// GetFrontendFD implements HasFrontendFD.GetFrontendFD.
func (p *UVM_MAP_EXTERNAL_ALLOCATION_PARAMS_V550) GetFrontendFD() int32 {
return p.RMCtrlFD
}

func (p *UVM_MAP_EXTERNAL_ALLOCATION_PARAMS_V550) SetRMCtrlFD(fd int32) {
// SetFrontendFD implements HasFrontendFD.SetFrontendFD.
func (p *UVM_MAP_EXTERNAL_ALLOCATION_PARAMS_V550) SetFrontendFD(fd int32) {
p.RMCtrlFD = fd
}

Expand All @@ -191,11 +191,13 @@ type UVM_REGISTER_GPU_PARAMS struct {
RMStatus uint32
}

func (p *UVM_REGISTER_GPU_PARAMS) GetRMCtrlFD() int32 {
// GetFrontendFD implements HasFrontendFD.GetFrontendFD.
func (p *UVM_REGISTER_GPU_PARAMS) GetFrontendFD() int32 {
return p.RMCtrlFD
}

func (p *UVM_REGISTER_GPU_PARAMS) SetRMCtrlFD(fd int32) {
// SetFrontendFD implements HasFrontendFD.SetFrontendFD.
func (p *UVM_REGISTER_GPU_PARAMS) SetFrontendFD(fd int32) {
p.RMCtrlFD = fd
}

Expand Down
81 changes: 24 additions & 57 deletions pkg/sentry/devices/nvproxy/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,48 +268,17 @@ func frontendRegisterFD(fi *frontendIoctlState) (uintptr, error) {
return frontendIoctlInvoke(fi, &ioctlParams)
}

func rmAllocOSEvent(fi *frontendIoctlState) (uintptr, error) {
var ioctlParams nvgpu.IoctlAllocOSEvent
if fi.ioctlParamsSize != nvgpu.SizeofIoctlAllocOSEvent {
func frontendIoctHasFD[Params any, PtrParams hasFrontendFDPtr[Params]](fi *frontendIoctlState) (uintptr, error) {
var ioctlParams Params
if int(fi.ioctlParamsSize) != (PtrParams)(&ioctlParams).SizeBytes() {
return 0, linuxerr.EINVAL
}
if _, err := ioctlParams.CopyIn(fi.t, fi.ioctlParamsAddr); err != nil {
if _, err := (PtrParams)(&ioctlParams).CopyIn(fi.t, fi.ioctlParamsAddr); err != nil {
return 0, err
}
eventFileGeneric, _ := fi.t.FDTable().Get(int32(ioctlParams.FD))
if eventFileGeneric == nil {
return 0, linuxerr.EINVAL
}
defer eventFileGeneric.DecRef(fi.ctx)
eventFile, ok := eventFileGeneric.Impl().(*frontendFD)
if !ok {
return 0, linuxerr.EINVAL
}

origFD := ioctlParams.FD
ioctlParams.FD = uint32(eventFile.hostFD)
n, err := frontendIoctlInvoke(fi, &ioctlParams)
ioctlParams.FD = origFD
if err != nil {
return n, err
}

if _, err := ioctlParams.CopyOut(fi.t, fi.ioctlParamsAddr); err != nil {
return n, err
}

return n, nil
}

func rmFreeOSEvent(fi *frontendIoctlState) (uintptr, error) {
var ioctlParams nvgpu.IoctlFreeOSEvent
if fi.ioctlParamsSize != nvgpu.SizeofIoctlFreeOSEvent {
return 0, linuxerr.EINVAL
}
if _, err := ioctlParams.CopyIn(fi.t, fi.ioctlParamsAddr); err != nil {
return 0, err
}
eventFileGeneric, _ := fi.t.FDTable().Get(int32(ioctlParams.FD))
origFD := (PtrParams)(&ioctlParams).GetFrontendFD()
eventFileGeneric, _ := fi.t.FDTable().Get(origFD)
if eventFileGeneric == nil {
return 0, linuxerr.EINVAL
}
Expand All @@ -319,18 +288,15 @@ func rmFreeOSEvent(fi *frontendIoctlState) (uintptr, error) {
return 0, linuxerr.EINVAL
}

origFD := ioctlParams.FD
ioctlParams.FD = uint32(eventFile.hostFD)
(PtrParams)(&ioctlParams).SetFrontendFD(eventFile.hostFD)
n, err := frontendIoctlInvoke(fi, &ioctlParams)
ioctlParams.FD = origFD
(PtrParams)(&ioctlParams).SetFrontendFD(origFD)
if err != nil {
return n, err
}

if _, err := ioctlParams.CopyOut(fi.t, fi.ioctlParamsAddr); err != nil {
if _, err := (PtrParams)(&ioctlParams).CopyOut(fi.t, fi.ioctlParamsAddr); err != nil {
return n, err
}

return n, nil
}

Expand Down Expand Up @@ -561,15 +527,17 @@ func ctrlCmdFailWithStatus(fi *frontendIoctlState, ioctlParams *nvgpu.NVOS54Para
return err
}

func ctrlExportObjectToFD(fi *frontendIoctlState, ioctlParams *nvgpu.NVOS54Parameters) (uintptr, error) {
var ctrlParams nvgpu.NV0000_CTRL_OS_UNIX_EXPORT_OBJECT_TO_FD_PARAMS
if ctrlParams.SizeBytes() != int(ioctlParams.ParamsSize) {
func ctrlHasFrontendFD[Params any, PtrParams hasFrontendFDPtr[Params]](fi *frontendIoctlState, ioctlParams *nvgpu.NVOS54Parameters) (uintptr, error) {
var ctrlParams Params
if (PtrParams)(&ctrlParams).SizeBytes() != int(ioctlParams.ParamsSize) {
return 0, linuxerr.EINVAL
}
if _, err := ctrlParams.CopyIn(fi.t, addrFromP64(ioctlParams.Params)); err != nil {
if _, err := (PtrParams)(&ctrlParams).CopyIn(fi.t, addrFromP64(ioctlParams.Params)); err != nil {
return 0, err
}
ctlFileGeneric, _ := fi.t.FDTable().Get(ctrlParams.FD)

origFD := (PtrParams)(&ctrlParams).GetFrontendFD()
ctlFileGeneric, _ := fi.t.FDTable().Get(origFD)
if ctlFileGeneric == nil {
return 0, linuxerr.EINVAL
}
Expand All @@ -579,14 +547,13 @@ func ctrlExportObjectToFD(fi *frontendIoctlState, ioctlParams *nvgpu.NVOS54Param
return 0, linuxerr.EINVAL
}

origFD := ctrlParams.FD
ctrlParams.FD = ctlFile.hostFD
(PtrParams)(&ctrlParams).SetFrontendFD(ctlFile.hostFD)
n, err := rmControlInvoke(fi, ioctlParams, &ctrlParams)
ctrlParams.FD = origFD
(PtrParams)(&ctrlParams).SetFrontendFD(origFD)
if err != nil {
return n, err
}
if _, err := ctrlParams.CopyOut(fi.t, addrFromP64(ioctlParams.Params)); err != nil {
if _, err := (PtrParams)(&ctrlParams).CopyOut(fi.t, addrFromP64(ioctlParams.Params)); err != nil {
return n, err
}
return n, nil
Expand Down Expand Up @@ -767,8 +734,8 @@ func rmAlloc(fi *frontendIoctlState) (uintptr, error) {
//
// Unlike frontendIoctlSimple and rmControlSimple, rmAllocSimple requires the
// parameter type since the parameter's size is otherwise unknown.
func rmAllocSimple[Params any, PParams marshalPtr[Params]](fi *frontendIoctlState, ioctlParams *nvgpu.NVOS64Parameters, isNVOS64 bool) (uintptr, error) {
return rmAllocSimpleParams[Params, PParams](fi, ioctlParams, isNVOS64, addSimpleObjDepParentLocked)
func rmAllocSimple[Params any, PtrParams marshalPtr[Params]](fi *frontendIoctlState, ioctlParams *nvgpu.NVOS64Parameters, isNVOS64 bool) (uintptr, error) {
return rmAllocSimpleParams[Params, PtrParams](fi, ioctlParams, isNVOS64, addSimpleObjDepParentLocked)
}

// addSimpleObjDepParentLocked implements rmAllocInvoke.addObjLocked for
Expand All @@ -777,20 +744,20 @@ func addSimpleObjDepParentLocked[Params any](fi *frontendIoctlState, ioctlParams
fi.fd.dev.nvp.objAdd(fi.ctx, ioctlParams.HRoot, ioctlParams.HObjectNew, ioctlParams.HClass, newRmAllocObject(fi.fd, ioctlParams, rightsRequested, allocParams), ioctlParams.HObjectParent)
}

func rmAllocSimpleParams[Params any, PParams marshalPtr[Params]](fi *frontendIoctlState, ioctlParams *nvgpu.NVOS64Parameters, isNVOS64 bool, objAddLocked func(fi *frontendIoctlState, ioctlParams *nvgpu.NVOS64Parameters, rightsRequested nvgpu.RS_ACCESS_MASK, allocParams *Params)) (uintptr, error) {
func rmAllocSimpleParams[Params any, PtrParams marshalPtr[Params]](fi *frontendIoctlState, ioctlParams *nvgpu.NVOS64Parameters, isNVOS64 bool, objAddLocked func(fi *frontendIoctlState, ioctlParams *nvgpu.NVOS64Parameters, rightsRequested nvgpu.RS_ACCESS_MASK, allocParams *Params)) (uintptr, error) {
if ioctlParams.PAllocParms == 0 {
return rmAllocInvoke[Params](fi, ioctlParams, nil, isNVOS64, objAddLocked)
}

var allocParams Params
if _, err := (PParams)(&allocParams).CopyIn(fi.t, addrFromP64(ioctlParams.PAllocParms)); err != nil {
if _, err := (PtrParams)(&allocParams).CopyIn(fi.t, addrFromP64(ioctlParams.PAllocParms)); err != nil {
return 0, err
}
n, err := rmAllocInvoke(fi, ioctlParams, &allocParams, isNVOS64, objAddLocked)
if err != nil {
return n, err
}
if _, err := (PParams)(&allocParams).CopyOut(fi.t, addrFromP64(ioctlParams.PAllocParms)); err != nil {
if _, err := (PtrParams)(&allocParams).CopyOut(fi.t, addrFromP64(ioctlParams.PAllocParms)); err != nil {
return n, err
}
return n, nil
Expand Down
5 changes: 5 additions & 0 deletions pkg/sentry/devices/nvproxy/nvproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,8 @@ type marshalPtr[T any] interface {
func addrFromP64(p nvgpu.P64) hostarch.Addr {
return hostarch.Addr(uintptr(uint64(p)))
}

type hasFrontendFDPtr[T any] interface {
marshalPtr[T]
nvgpu.HasFrontendFD
}
27 changes: 10 additions & 17 deletions pkg/sentry/devices/nvproxy/uvm.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
Expand Down Expand Up @@ -165,16 +164,16 @@ func uvmIoctlNoParams(ui *uvmIoctlState) (uintptr, error) {
return uvmIoctlInvoke[byte](ui, nil)
}

func uvmIoctlSimple[Params any, PParams marshalPtr[Params]](ui *uvmIoctlState) (uintptr, error) {
func uvmIoctlSimple[Params any, PtrParams marshalPtr[Params]](ui *uvmIoctlState) (uintptr, error) {
var ioctlParams Params
if _, err := (PParams)(&ioctlParams).CopyIn(ui.t, ui.ioctlParamsAddr); err != nil {
if _, err := (PtrParams)(&ioctlParams).CopyIn(ui.t, ui.ioctlParamsAddr); err != nil {
return 0, err
}
n, err := uvmIoctlInvoke(ui, &ioctlParams)
if err != nil {
return n, err
}
if _, err := (PParams)(&ioctlParams).CopyOut(ui.t, ui.ioctlParamsAddr); err != nil {
if _, err := (PtrParams)(&ioctlParams).CopyOut(ui.t, ui.ioctlParamsAddr); err != nil {
return n, err
}
return n, nil
Expand Down Expand Up @@ -241,25 +240,19 @@ func uvmMMInitialize(ui *uvmIoctlState) (uintptr, error) {
return n, nil
}

type hasRMCtrlFDPtr[T any] interface {
*T
marshal.Marshallable
nvgpu.HasRMCtrlFD
}

func uvmIoctlHasRMCtrlFD[Params any, PParams hasRMCtrlFDPtr[Params]](ui *uvmIoctlState) (uintptr, error) {
func uvmIoctlHasFrontendFD[Params any, PtrParams hasFrontendFDPtr[Params]](ui *uvmIoctlState) (uintptr, error) {
var ioctlParams Params
if _, err := (PParams)(&ioctlParams).CopyIn(ui.t, ui.ioctlParamsAddr); err != nil {
if _, err := (PtrParams)(&ioctlParams).CopyIn(ui.t, ui.ioctlParamsAddr); err != nil {
return 0, err
}

rmCtrlFD := (PParams)(&ioctlParams).GetRMCtrlFD()
rmCtrlFD := (PtrParams)(&ioctlParams).GetFrontendFD()
if rmCtrlFD < 0 {
n, err := uvmIoctlInvoke(ui, &ioctlParams)
if err != nil {
return n, err
}
if _, err := (PParams)(&ioctlParams).CopyOut(ui.t, ui.ioctlParamsAddr); err != nil {
if _, err := (PtrParams)(&ioctlParams).CopyOut(ui.t, ui.ioctlParamsAddr); err != nil {
return n, err
}
return n, nil
Expand All @@ -276,15 +269,15 @@ func uvmIoctlHasRMCtrlFD[Params any, PParams hasRMCtrlFDPtr[Params]](ui *uvmIoct
}

sentryIoctlParams := ioctlParams
(PParams)(&sentryIoctlParams).SetRMCtrlFD(ctlFile.hostFD)
(PtrParams)(&sentryIoctlParams).SetFrontendFD(ctlFile.hostFD)
n, err := uvmIoctlInvoke(ui, &sentryIoctlParams)
if err != nil {
return n, err
}

outIoctlParams := sentryIoctlParams
(PParams)(&outIoctlParams).SetRMCtrlFD(rmCtrlFD)
if _, err := (PParams)(&outIoctlParams).CopyOut(ui.t, ui.ioctlParamsAddr); err != nil {
(PtrParams)(&outIoctlParams).SetFrontendFD(rmCtrlFD)
if _, err := (PtrParams)(&outIoctlParams).CopyOut(ui.t, ui.ioctlParamsAddr); err != nil {
return n, err
}

Expand Down

0 comments on commit 0cb437d

Please sign in to comment.