From e7af9f04586a8d4ec9d765180abd6dba284dc275 Mon Sep 17 00:00:00 2001 From: Jake Scott Date: Tue, 27 Aug 2024 09:36:03 -0400 Subject: [PATCH 1/3] changes --- cred.go | 2 +- names.go | 40 +++++++++++++++++++++++++++++++++++++++- provider.go | 5 ++--- seccontext.go | 16 ++++++++++++---- 4 files changed, 54 insertions(+), 9 deletions(-) diff --git a/cred.go b/cred.go index d915972..0f11a7e 100644 --- a/cred.go +++ b/cred.go @@ -36,7 +36,7 @@ type Credential struct { id C.gss_cred_id_t } -func (library) AcquireCredential(name g.GssName, mechs []g.GssMech, usage g.CredUsage, lifetime time.Duration) (g.Credential, error) { +func (provider) AcquireCredential(name g.GssName, mechs []g.GssMech, usage g.CredUsage, lifetime time.Duration) (g.Credential, error) { // turn the mechs into an array of OIDs gssOidSet := gssOidSetFromOids(mechsToOids(mechs)) gssOidSet.Pin() diff --git a/names.go b/names.go index 8c0bf7e..556c42e 100644 --- a/names.go +++ b/names.go @@ -40,7 +40,7 @@ func nameFromGssInternal(name C.gss_name_t) GssName { return GssName{name} } -func (library) ImportName(name string, nameType g.GssNameType) (g.GssName, error) { +func (provider) ImportName(name string, nameType g.GssNameType) (g.GssName, error) { nameOid := nameType.Oid() var minor C.OM_uint32 var cGssName C.gss_name_t @@ -55,6 +55,44 @@ func (library) ImportName(name string, nameType g.GssNameType) (g.GssName, error }, nil } +func (provider) InquireNamesForMech(m g.GssMech) ([]g.GssNameType, error) { + cMechOid := oid2Coid(m.Oid()) + + var minor C.OM_uint32 + var cNameTypes C.gss_OID_set // cNameTypes.elements allocated by GSSAPI; released by *1 + major := C.gss_inquire_names_for_mech(&minor, cMechOid, &cNameTypes) + + if major != 0 { + return nil, makeStatus(major, minor) + } + + defer C.gss_release_oid_set(&minor, &cNameTypes) + + nameTypeOids := oidsFromGssOidSet(cNameTypes) + ret := make([]g.GssNameType, 0, len(nameTypeOids)) + + seen := make(map[string]bool) + + for _, oid := range nameTypeOids { + nt, err := g.NameFromOid(oid) + switch { + default: + ntStr := nt.String() + if _, ok := seen[ntStr]; !ok { + ret = append(ret, nt) + seen[nt.String()] = true + } + case errors.Is(err, g.ErrBadNameType): + // warn + continue + case err != nil: + return nil, err + } + } + + return ret, nil +} + func (n *GssName) Compare(other g.GssName) (bool, error) { // other must be our type, not one from a different GSSAPI impl // .. but this method needs to implement gsscommon.GssName.Compare() diff --git a/provider.go b/provider.go index 8b06b34..bbee854 100644 --- a/provider.go +++ b/provider.go @@ -9,19 +9,18 @@ import ( // #cgo LDFLAGS: -lgssapi_krb5 import "C" - const LIBID = "GSSAPI-C" func init() { g.RegisterProvider(LIBID, New) } -type library struct { +type provider struct { name string } func New() g.Provider { - return &library{ + return &provider{ name: LIBID, } } diff --git a/seccontext.go b/seccontext.go index 583e9a8..88e2512 100644 --- a/seccontext.go +++ b/seccontext.go @@ -19,6 +19,7 @@ type SecContext struct { continueNeeded bool isInitiator bool targetName *GssName + mech g.GssMech } func oid2Coid(oid g.Oid) C.gss_OID { @@ -32,7 +33,7 @@ func oid2Coid(oid g.Oid) C.gss_OID { } } -func (library) InitSecContext(name g.GssName, opts ...g.InitSecContextOption) (g.SecContext, []byte, error) { +func (provider) InitSecContext(name g.GssName, opts ...g.InitSecContextOption) (g.SecContext, []byte, error) { o := g.InitSecContextOptions{} for _, opt := range opts { opt(&o) @@ -86,10 +87,11 @@ func (library) InitSecContext(name g.GssName, opts ...g.InitSecContextOption) (g continueNeeded: major == C.GSS_S_CONTINUE_NEEDED, isInitiator: true, targetName: lName, + mech: o.Mech, }, outToken, nil } -func (library) AcceptSecContext(cred g.Credential, inputToken []byte) (g.SecContext, []byte, error) { +func (provider) AcceptSecContext(cred g.Credential, inputToken []byte) (g.SecContext, []byte, error) { // get the C cred ID and name var cGssCred C.gss_cred_id_t = C.GSS_C_NO_CREDENTIAL if cred != nil { @@ -124,7 +126,7 @@ func (library) AcceptSecContext(cred g.Credential, inputToken []byte) (g.SecCont }, outToken, nil } -func (library) ImportSecContext(token []byte) (g.SecContext, error) { +func (provider) ImportSecContext(token []byte) (g.SecContext, error) { var minor C.OM_uint32 var cGssCtxId C.gss_ctx_id_t @@ -147,8 +149,14 @@ func (c *SecContext) Continue(inputToken []byte) ([]byte, error) { cInputToken, pinner := bytesToCBuffer(inputToken) defer pinner.Unpin() + mech := g.Oid{} + if c.mech != nil { + mech = c.mech.Oid() + } + cMechOid := oid2Coid(mech) + if c.isInitiator { - major = C.gss_init_sec_context(&minor, C.GSS_C_NO_CREDENTIAL, &c.id, c.targetName.name, nil, 0, 0, nil, &cInputToken, nil, &cOutToken, nil, nil) + major = C.gss_init_sec_context(&minor, C.GSS_C_NO_CREDENTIAL, &c.id, c.targetName.name, cMechOid, 0, 0, nil, &cInputToken, nil, &cOutToken, nil, nil) } else { major = C.gss_accept_sec_context(&minor, &c.id, C.GSS_C_NO_CREDENTIAL, &cInputToken, nil, nil, nil, &cOutToken, nil, nil, nil) } From 955ba90d45114308ea3de03cf4df86146408f31f Mon Sep 17 00:00:00 2001 From: Jake Scott Date: Wed, 28 Aug 2024 15:41:35 -0400 Subject: [PATCH 2/3] Implement the QoP interface changes --- cred.go | 26 +++---------------- go.mod | 8 ++---- go.sum | 2 ++ helpers.go | 26 ++++++++----------- names.go | 37 ++++++++------------------ provider_test.go | 13 ++++++++++ seccontext.go | 65 ++++++++++++++++++++++++++-------------------- seccontext_test.go | 18 +++++++++---- status.go | 13 +++------- 9 files changed, 96 insertions(+), 112 deletions(-) create mode 100644 provider_test.go diff --git a/cred.go b/cred.go index 0f11a7e..8eea63d 100644 --- a/cred.go +++ b/cred.go @@ -3,24 +3,6 @@ package gssapi /* #include -gss_OID_desc GoStringToGssOID(_GoString_ s); - -OM_uint32 inquire_cred_by_mech (OM_uint32 *minor, const gss_cred_id_t cred_handle, _GoString_ mechOid, - gss_name_t *output_name, OM_uint32 *init_life, OM_uint32 *accept_life, gss_cred_usage_t *usage) { - gss_OID_desc oid = GoStringToGssOID(mechOid); - - return gss_inquire_cred_by_mech(minor, cred_handle, &oid, output_name, init_life, accept_life, usage); -} - -OM_uint32 add_cred(OM_uint32 *minor, const gss_cred_id_t cred_handle, const gss_name_t name, _GoString_ mechOid, - gss_cred_usage_t usage, OM_uint32 initiator_lifetime, OM_uint32 acceptor_lifetime, - gss_OID_set *actual_mechs, OM_uint32 *initiator_rec, OM_uint32 *acceptor_rec) { - gss_OID_desc oid = GoStringToGssOID(mechOid); - - return gss_add_cred(minor, cred_handle, name, &oid, usage, initiator_lifetime, acceptor_lifetime, NULL, - actual_mechs, initiator_rec, acceptor_rec ); -} - */ import "C" @@ -142,13 +124,13 @@ func (c *Credential) Inquire() (*g.CredInfo, error) { } func (c *Credential) InquireByMech(mech g.GssMech) (*g.CredInfo, error) { - mechOid := mech.Oid() + cMechOid := oid2Coid(mech.Oid()) var minor C.OM_uint32 var cGssName C.gss_name_t // cGssName allocated by GSSAPI; releaseed by *1 var cTimeRecInit, cTimeRecAcc C.OM_uint32 var cCredUsage C.gss_cred_usage_t - major := C.inquire_cred_by_mech(&minor, c.id, string(mechOid), &cGssName, &cTimeRecInit, &cTimeRecAcc, &cCredUsage) + major := C.gss_inquire_cred_by_mech(&minor, c.id, cMechOid, &cGssName, &cTimeRecInit, &cTimeRecAcc, &cCredUsage) if major != 0 { return nil, makeMechStatus(major, minor, mech) @@ -190,7 +172,7 @@ func (c *Credential) InquireByMech(mech g.GssMech) (*g.CredInfo, error) { } func (c *Credential) Add(name g.GssName, mech g.GssMech, usage g.CredUsage, initiatorLifetime time.Duration, acceptorLifetime time.Duration) error { - mechOid := mech.Oid() + cMechOid := oid2Coid(mech.Oid()) var cGssName C.gss_name_t if name != nil { @@ -205,7 +187,7 @@ func (c *Credential) Add(name g.GssName, mech g.GssMech, usage g.CredUsage, init var minor C.OM_uint32 var cTimeRecInit, cTimeRecAcc C.OM_uint32 var cActualMechs C.gss_OID_set // cActualMechs.elements allocated by GSSAPI; released by *1 - major := C.add_cred(&minor, c.id, cGssName, string(mechOid), C.int(usage), C.OM_uint32(initiatorLifetime.Seconds()), C.OM_uint32(acceptorLifetime.Seconds()), &cActualMechs, &cTimeRecInit, &cTimeRecAcc) + major := C.gss_add_cred(&minor, c.id, cGssName, cMechOid, C.int(usage), C.OM_uint32(initiatorLifetime.Seconds()), C.OM_uint32(acceptorLifetime.Seconds()), nil, &cActualMechs, &cTimeRecInit, &cTimeRecAcc) if major != 0 { return makeMechStatus(major, minor, mech) } diff --git a/go.mod b/go.mod index 9eb6bbe..7361177 100644 --- a/go.mod +++ b/go.mod @@ -2,15 +2,11 @@ module github.com/golang-auth/go-gssapi-c go 1.18 -replace github.com/golang-auth/go-gssapi/v3 => ../go-gssapi/v3 - -require ( - github.com/golang-auth/go-gssapi/v3 v3.0.0-00010101000000-000000000000 - github.com/stretchr/testify v1.9.0 -) +require github.com/stretchr/testify v1.9.0 require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-auth/go-gssapi/v3 v3.0.0-alpha // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 60ce688..fea3d57 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-auth/go-gssapi/v3 v3.0.0-alpha h1:vzRDfKlo9OFYaxUOYSrdZ1JiC0TSTWbLtO1y2mNq7Vg= +github.com/golang-auth/go-gssapi/v3 v3.0.0-alpha/go.mod h1:xNotWZQDADAqcBR4A7AKn+p4tSxQE4m6KA06J41U0cY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/helpers.go b/helpers.go index ba1535c..9323a6b 100644 --- a/helpers.go +++ b/helpers.go @@ -2,21 +2,6 @@ package gssapi /* #include - -gss_OID_desc GoStringToGssOID(_GoString_ s) { - size_t l = _GoStringLen(s); - void *elms = (void*)_GoStringPtr(s); - gss_OID_desc oid = {l, elms}; - return oid; -} - -gss_buffer_desc GoStringToGssBuffer(_GoString_ s) { - size_t l = _GoStringLen(s); - void *value = (void*)_GoStringPtr(s); - gss_buffer_desc buf = {l, value}; - return buf; -} - */ import "C" @@ -100,3 +85,14 @@ func bytesToCBuffer(b []byte) (C.gss_buffer_desc, runtime.Pinner) { return ret, pinner } + +func oid2Coid(oid g.Oid) C.gss_OID { + if len(oid) > 0 { + return &C.gss_OID_desc{ + length: C.OM_uint32(len(oid)), + elements: unsafe.Pointer(&oid[0]), + } + } else { + return C.GSS_C_NO_OID + } +} diff --git a/names.go b/names.go index 556c42e..81fbff1 100644 --- a/names.go +++ b/names.go @@ -2,26 +2,6 @@ package gssapi /* #include - -gss_OID_desc GoStringToGssOID(_GoString_ s); -gss_buffer_desc GoStringToGssBuffer(_GoString_ s); - -// _GoString_ is really a convenient []byte here.. -OM_uint32 import_name(_GoString_ name, _GoString_ nameOid, OM_uint32 *minor, gss_name_t *output_name) { - gss_buffer_desc nameBuf = GoStringToGssBuffer(name); - gss_OID_desc oid = GoStringToGssOID(nameOid); - gss_OID pOid = oid.length > 0 ? &oid : GSS_C_NO_OID; - - return gss_import_name(minor, &nameBuf, pOid, output_name); -} - -OM_uint32 canonicalize_name(const gss_name_t name, _GoString_ mechOid, OM_uint32 *minor, gss_name_t *output_name) { - gss_OID_desc oid = GoStringToGssOID(mechOid); - - return gss_canonicalize_name(minor, name, &oid, output_name); -} - - */ import "C" @@ -41,10 +21,14 @@ func nameFromGssInternal(name C.gss_name_t) GssName { } func (provider) ImportName(name string, nameType g.GssNameType) (g.GssName, error) { - nameOid := nameType.Oid() + cNameOid := oid2Coid(nameType.Oid()) + + cNameBuf, pinner := bytesToCBuffer([]byte(name)) + defer pinner.Unpin() + var minor C.OM_uint32 var cGssName C.gss_name_t - major := C.import_name(name, string(nameOid), &minor, &cGssName) + major := C.gss_import_name(&minor, &cNameBuf, cNameOid, &cGssName) if major != 0 { return nil, makeStatus(major, minor) @@ -55,8 +39,8 @@ func (provider) ImportName(name string, nameType g.GssNameType) (g.GssName, erro }, nil } -func (provider) InquireNamesForMech(m g.GssMech) ([]g.GssNameType, error) { - cMechOid := oid2Coid(m.Oid()) +func (provider) InquireNamesForMech(mech g.GssMech) ([]g.GssNameType, error) { + cMechOid := oid2Coid(mech.Oid()) var minor C.OM_uint32 var cNameTypes C.gss_OID_set // cNameTypes.elements allocated by GSSAPI; released by *1 @@ -179,10 +163,11 @@ func (n *GssName) InquireMechs() ([]g.GssMech, error) { } func (n *GssName) Canonicalize(mech g.GssMech) (g.GssName, error) { - mechOid := mech.Oid() + cMechOid := oid2Coid(mech.Oid()) + var minor C.OM_uint32 var cOutName C.gss_name_t - major := C.canonicalize_name(n.name, string(mechOid), &minor, &cOutName) + major := C.gss_canonicalize_name(&minor, n.name, cMechOid, &cOutName) if major != 0 { return nil, makeMechStatus(major, minor, mech) } diff --git a/provider_test.go b/provider_test.go new file mode 100644 index 0000000..48355b8 --- /dev/null +++ b/provider_test.go @@ -0,0 +1,13 @@ +package gssapi + +import ( + "testing" + + g "github.com/golang-auth/go-gssapi/v3" + "github.com/stretchr/testify/assert" +) + +func TestProvider(t *testing.T) { + p := g.NewProvider("GSSAPI-C") + assert.IsType(t, &provider{}, p) +} diff --git a/seccontext.go b/seccontext.go index 88e2512..f551a0a 100644 --- a/seccontext.go +++ b/seccontext.go @@ -9,7 +9,6 @@ import ( "fmt" "math" "time" - "unsafe" g "github.com/golang-auth/go-gssapi/v3" ) @@ -22,17 +21,6 @@ type SecContext struct { mech g.GssMech } -func oid2Coid(oid g.Oid) C.gss_OID { - if len(oid) > 0 { - return &C.gss_OID_desc{ - length: C.OM_uint32(len(oid)), - elements: unsafe.Pointer(&oid[0]), - } - } else { - return C.GSS_C_NO_OID - } -} - func (provider) InitSecContext(name g.GssName, opts ...g.InitSecContextOption) (g.SecContext, []byte, error) { o := g.InitSecContextOptions{} for _, opt := range opts { @@ -82,11 +70,16 @@ func (provider) InitSecContext(name g.GssName, opts ...g.InitSecContextOption) ( outToken := C.GoBytes(cOutToken.value, C.int(cOutToken.length)) + savedName, err := lName.Duplicate() + if err != nil { + return nil, nil, makeMechStatus(major, minor, o.Mech) + } + return &SecContext{ id: cGssCtxId, continueNeeded: major == C.GSS_S_CONTINUE_NEEDED, isInitiator: true, - targetName: lName, + targetName: savedName.(*GssName), mech: o.Mech, }, outToken, nil } @@ -178,6 +171,11 @@ func (c *SecContext) ContinueNeeded() bool { } func (c *SecContext) Delete() ([]byte, error) { + if c.targetName != nil { + c.targetName.Release() + c.targetName = nil + } + if c.id == nil { return nil, nil } @@ -294,7 +292,7 @@ func (c *SecContext) Inquire() (*g.SecContextInfo, error) { }, nil } -func (c *SecContext) WrapSizeLimit(confRequired bool, maxWrapSize uint) (uint, error) { +func (c *SecContext) WrapSizeLimit(confRequired bool, maxWrapSize uint, qop g.QoP) (uint, error) { var minor C.OM_uint32 var cConfReq C.int var cMaxInputSize C.OM_uint32 @@ -303,8 +301,11 @@ func (c *SecContext) WrapSizeLimit(confRequired bool, maxWrapSize uint) (uint, e if maxWrapSize > math.MaxUint32 { return 0, ErrTooLarge } + if qop > math.MaxUint32 { + return 0, g.ErrBadQop + } - major := C.gss_wrap_size_limit(&minor, c.id, cConfReq, C.GSS_C_QOP_DEFAULT, C.OM_uint32(maxWrapSize), &cMaxInputSize) + major := C.gss_wrap_size_limit(&minor, c.id, cConfReq, C.gss_qop_t(qop), C.OM_uint32(maxWrapSize), &cMaxInputSize) if major != 0 { return 0, makeStatus(major, minor) } @@ -331,11 +332,14 @@ func (c *SecContext) Export() ([]byte, error) { return outToken, nil } -func (c *SecContext) Wrap(msgIn []byte, confReq bool) ([]byte, bool, error) { +func (c *SecContext) Wrap(msgIn []byte, confReq bool, qop g.QoP) ([]byte, bool, error) { // the C bindings support a 32 bit message size.. if len(msgIn) > math.MaxUint32 { return nil, false, ErrTooLarge } + if qop > math.MaxUint32 { + return nil, false, g.ErrBadQop + } cInputMessage, pinner := bytesToCBuffer(msgIn) defer pinner.Unpin() @@ -347,7 +351,7 @@ func (c *SecContext) Wrap(msgIn []byte, confReq bool) ([]byte, bool, error) { cConfReq = 1 } - major := C.gss_wrap(&minor, c.id, cConfReq, C.GSS_C_QOP_DEFAULT, &cInputMessage, &cConfState, &cOutputMessage) + major := C.gss_wrap(&minor, c.id, cConfReq, C.gss_qop_t(qop), &cInputMessage, &cConfState, &cOutputMessage) if major != 0 { return nil, false, makeStatus(major, minor) } @@ -358,10 +362,10 @@ func (c *SecContext) Wrap(msgIn []byte, confReq bool) ([]byte, bool, error) { return msgOut, cConfState != 0, nil } -func (c *SecContext) Unwrap(msgIn []byte) ([]byte, bool, error) { +func (c *SecContext) Unwrap(msgIn []byte) ([]byte, bool, g.QoP, error) { // the C bindings support a 32 bit message size.. if len(msgIn) > math.MaxUint32 { - return nil, false, ErrTooLarge + return nil, false, 0, ErrTooLarge } cInputMessage, pinner := bytesToCBuffer(msgIn) @@ -370,30 +374,34 @@ func (c *SecContext) Unwrap(msgIn []byte) ([]byte, bool, error) { var minor C.OM_uint32 var cConfState C.int var cOutputMessage C.gss_buffer_desc // allocated by GSSAPI; released by *1 + var cQoP C.gss_qop_t - major := C.gss_unwrap(&minor, c.id, &cInputMessage, &cOutputMessage, &cConfState, nil) + major := C.gss_unwrap(&minor, c.id, &cInputMessage, &cOutputMessage, &cConfState, &cQoP) if major != 0 { - return nil, false, makeStatus(major, minor) + return nil, false, 0, makeStatus(major, minor) } defer C.gss_release_buffer(&minor, &cOutputMessage) // *1 Release GSSAPI allocated buffer msgOut := C.GoBytes(cOutputMessage.value, C.int(cOutputMessage.length)) - return msgOut, cConfState != 0, nil + return msgOut, cConfState != 0, g.QoP(cQoP), nil } -func (c *SecContext) GetMIC(msg []byte) ([]byte, error) { +func (c *SecContext) GetMIC(msg []byte, qop g.QoP) ([]byte, error) { // the C bindings support a 32 bit message size.. if len(msg) > math.MaxUint32 { return nil, ErrTooLarge } + if qop > math.MaxUint32 { + return nil, g.ErrBadQop + } cMessage, pinner := bytesToCBuffer(msg) defer pinner.Unpin() var minor C.OM_uint32 var cMsgToken C.gss_buffer_desc // allocated by GSSAPI; released by *1 - major := C.gss_get_mic(&minor, c.id, C.GSS_C_QOP_DEFAULT, &cMessage, &cMsgToken) + major := C.gss_get_mic(&minor, c.id, C.gss_qop_t(qop), &cMessage, &cMsgToken) if major != 0 { return nil, makeStatus(major, minor) } @@ -404,10 +412,10 @@ func (c *SecContext) GetMIC(msg []byte) ([]byte, error) { return token, nil } -func (c *SecContext) VerifyMIC(msg, token []byte) error { +func (c *SecContext) VerifyMIC(msg, token []byte) (g.QoP, error) { // the C bindings support a 32 bit message size.. if len(msg) > math.MaxUint32 { - return ErrTooLarge + return 0, ErrTooLarge } cMessage, pinnerMsg := bytesToCBuffer(msg) @@ -416,6 +424,7 @@ func (c *SecContext) VerifyMIC(msg, token []byte) error { defer pinnerToken.Unpin() var minor C.OM_uint32 - major := C.gss_verify_mic(&minor, c.id, &cMessage, &cToken, nil) - return makeStatus(major, minor) + var cQoP C.gss_qop_t + major := C.gss_verify_mic(&minor, c.id, &cMessage, &cToken, &cQoP) + return g.QoP(cQoP), makeStatus(major, minor) } diff --git a/seccontext_test.go b/seccontext_test.go index 0690227..a547e17 100644 --- a/seccontext_test.go +++ b/seccontext_test.go @@ -41,6 +41,14 @@ func (ta *testAssets) Free() { os.Remove(ta.ccfile) } +// will prevent compilation if SecContext{} doesn't implement the interface +func TestSecContextInterface(t *testing.T) { + s := SecContext{} + var gsc g.SecContext = &s + + _ = gsc +} + func TestInitSecContext(t *testing.T) { assert := assert.New(t) @@ -203,7 +211,7 @@ func TestContextWrapSizeLimit(t *testing.T) { // the max unwrapped token size would always be less that the max // wrapped token size - tokSize, err := secCtxInitiator.WrapSizeLimit(true, 100) + tokSize, err := secCtxInitiator.WrapSizeLimit(true, 100, 0) assert.NoError(err) assert.Less(tokSize, uint(1000)) } @@ -270,20 +278,20 @@ func TestSecContextEstablishment(t *testing.T) { assert.False(secCtxAcceptor.ContinueNeeded()) msg := []byte("Hello GSSAPI") - wrapped, hasConf, err := secCtxInitiator.Wrap(msg, true) + wrapped, hasConf, err := secCtxInitiator.Wrap(msg, true, 0) assert.NoError(err) assert.True(hasConf) assert.NotEmpty(wrapped) - unwrapped, hasConf, err := secCtxAcceptor.Unwrap(wrapped) + unwrapped, hasConf, _, err := secCtxAcceptor.Unwrap(wrapped) assert.NoError(err) assert.True(hasConf) assert.Equal(msg, unwrapped) - mic, err := secCtxInitiator.GetMIC(msg) + mic, err := secCtxInitiator.GetMIC(msg, 0) assert.NoError(err) assert.NotEmpty(mic) - err = secCtxAcceptor.VerifyMIC(msg, mic) + _, err = secCtxAcceptor.VerifyMIC(msg, mic) assert.NoError(err) } diff --git a/status.go b/status.go index e1573c3..ac22cd1 100644 --- a/status.go +++ b/status.go @@ -9,15 +9,6 @@ import ( /* #include - -gss_OID_desc GoStringToGssOID(_GoString_ s); - -OM_uint32 display_status(OM_uint32 status, int status_type, _GoString_ mechOid, OM_uint32 *minor, OM_uint32 *msgCtx, gss_buffer_desc *status_string) { - gss_OID_desc oid = GoStringToGssOID(mechOid); - gss_OID poid = (oid.length == 0) ? NULL : &oid; - - return gss_display_status(minor, status, status_type, poid, msgCtx, status_string); -} */ import "C" @@ -138,13 +129,15 @@ func gssMinorErrors(minor C.OM_uint32, mech g.GssMech) []error { if mech != nil { mechOid = mech.Oid() } + + cMechOid := oid2Coid(mechOid) var lMinor, msgCtx C.OM_uint32 var statusString C.gss_buffer_desc ret := []error{} for { - major := C.display_status(minor, 2, string(mechOid), &lMinor, &msgCtx, &statusString) + major := C.gss_display_status(&lMinor, minor, 2, cMechOid, &msgCtx, &statusString) if major != 0 { // specifically do not call makeStatus here - we might end up in a loop.. ret = append(ret, fmt.Errorf("got GSS error %d/%d while finding string for minor code %d", major, lMinor, minor)) From b5afad6a698675380e64f4a7f9e8c8dc4ae09616 Mon Sep 17 00:00:00 2001 From: Jake Scott Date: Thu, 5 Sep 2024 16:47:01 -0400 Subject: [PATCH 3/3] New interface --- go.mod | 2 + seccontext.go | 113 ++++++++++++++++---------- seccontext_test.go | 196 ++++++++++++++++++++++++++++----------------- 3 files changed, 194 insertions(+), 117 deletions(-) diff --git a/go.mod b/go.mod index 7361177..ed4c45d 100644 --- a/go.mod +++ b/go.mod @@ -10,3 +10,5 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/golang-auth/go-gssapi/v3 => ../go-gssapi/v3 diff --git a/seccontext.go b/seccontext.go index f551a0a..2f79c45 100644 --- a/seccontext.go +++ b/seccontext.go @@ -19,81 +19,99 @@ type SecContext struct { isInitiator bool targetName *GssName mech g.GssMech + + initOptions *g.InitSecContextOptions } -func (provider) InitSecContext(name g.GssName, opts ...g.InitSecContextOption) (g.SecContext, []byte, error) { +// InitSecContext() is just a constructor for the context -- it does not perform any GSSAPI calls +func (provider) InitSecContext(name g.GssName, opts ...g.InitSecContextOption) (g.SecContext, error) { o := g.InitSecContextOptions{} for _, opt := range opts { opt(&o) } + var nameImpl *GssName // impl not interface + if name != nil { + var ok bool + nameImpl, ok = name.(*GssName) // name must be *our* impl + if !ok { + return nil, fmt.Errorf("bad name type %T, %w", name, g.ErrBadName) + } + } + + savedName, err := nameImpl.Duplicate() + if err != nil { + return nil, fmt.Errorf("%w duplicating name: %w", g.ErrFailure, err) + } + + return &SecContext{ + isInitiator: true, + continueNeeded: true, + targetName: savedName.(*GssName), + mech: o.Mech, + initOptions: &o, + }, nil +} + +func (provider) AcceptSecContext(cred g.Credential) (g.SecContext, error) { + return &SecContext{ + isInitiator: false, + initOptions: &g.InitSecContextOptions{ + Credential: cred, + }, + }, nil +} + +// initSecContext() performs the GSSAPI context initialization using paramers supplied to InitSecContext() +func (c *SecContext) initSecContext() ([]byte, error) { mech := g.Oid{} - if o.Mech != nil { - mech = o.Mech.Oid() + if c.initOptions.Mech != nil { + mech = c.initOptions.Mech.Oid() } cMechOid := oid2Coid(mech) // get the C cred ID and name var cGssCred C.gss_cred_id_t = C.GSS_C_NO_CREDENTIAL - if o.Credential != nil { - lCred, ok := o.Credential.(*Credential) + if c.initOptions.Credential != nil { + credImpl, ok := c.initOptions.Credential.(*Credential) // must be *our* impl if !ok { - return nil, nil, fmt.Errorf("bad credential type %T, %w", lCred, g.ErrDefectiveCredential) + return nil, fmt.Errorf("bad credential type %T, %w", credImpl, g.ErrDefectiveCredential) } - cGssCred = lCred.id + cGssCred = credImpl.id } - var cGssName C.gss_name_t - var lName *GssName - if name != nil { - var ok bool - lName, ok = name.(*GssName) - if !ok { - return nil, nil, fmt.Errorf("bad name type %T, %w", name, g.ErrBadName) - } - - cGssName = lName.name - } + var cGssName C.gss_name_t = c.targetName.name var minor C.OM_uint32 var cGssCtxId C.gss_ctx_id_t var cOutToken C.gss_buffer_desc // cOutToken.value allocated by GSSAPI; released by *1 - major := C.gss_init_sec_context(&minor, cGssCred, &cGssCtxId, cGssName, cMechOid, C.OM_uint32(o.Flags), C.OM_uint32(o.Lifetime.Seconds()), nil, nil, nil, &cOutToken, nil, nil) + major := C.gss_init_sec_context(&minor, cGssCred, &cGssCtxId, cGssName, cMechOid, C.OM_uint32(c.initOptions.Flags), C.OM_uint32(c.initOptions.Lifetime.Seconds()), nil, nil, nil, &cOutToken, nil, nil) if major != 0 && major != C.GSS_S_CONTINUE_NEEDED { - return nil, nil, makeMechStatus(major, minor, o.Mech) + return nil, makeMechStatus(major, minor, c.initOptions.Mech) } // *1 release GSSAPI allocated buffer defer C.gss_release_buffer(&minor, &cOutToken) outToken := C.GoBytes(cOutToken.value, C.int(cOutToken.length)) + c.continueNeeded = major == C.GSS_S_CONTINUE_NEEDED + c.id = cGssCtxId - savedName, err := lName.Duplicate() - if err != nil { - return nil, nil, makeMechStatus(major, minor, o.Mech) - } - - return &SecContext{ - id: cGssCtxId, - continueNeeded: major == C.GSS_S_CONTINUE_NEEDED, - isInitiator: true, - targetName: savedName.(*GssName), - mech: o.Mech, - }, outToken, nil + return outToken, nil } -func (provider) AcceptSecContext(cred g.Credential, inputToken []byte) (g.SecContext, []byte, error) { +func (c *SecContext) acceptSecContext(inputToken []byte) ([]byte, error) { // get the C cred ID and name var cGssCred C.gss_cred_id_t = C.GSS_C_NO_CREDENTIAL - if cred != nil { - lCred, ok := cred.(*Credential) + if c.initOptions.Credential != nil { + credImpl, ok := c.initOptions.Credential.(*Credential) // must be *our* impl if !ok { - return nil, nil, fmt.Errorf("bad credential type %T, %w", lCred, g.ErrDefectiveCredential) + return nil, fmt.Errorf("bad credential type %T, %w", credImpl, g.ErrDefectiveCredential) } - cGssCred = lCred.id + cGssCred = credImpl.id } var minor C.OM_uint32 @@ -105,18 +123,15 @@ func (provider) AcceptSecContext(cred g.Credential, inputToken []byte) (g.SecCon major := C.gss_accept_sec_context(&minor, &cGssCtxId, cGssCred, &cInputToken, nil, nil, nil, &cOutToken, nil, nil, nil) if major != 0 && major != C.GSS_S_CONTINUE_NEEDED { - return nil, nil, makeStatus(major, minor) + return nil, makeStatus(major, minor) } // *1 release GSSAPI allocated buffer defer C.gss_release_buffer(&minor, &cOutToken) outToken := C.GoBytes(cOutToken.value, C.int(cOutToken.length)) - return &SecContext{ - id: cGssCtxId, - continueNeeded: major == C.GSS_S_CONTINUE_NEEDED, - isInitiator: false, - }, outToken, nil + c.id = cGssCtxId + return outToken, nil } func (provider) ImportSecContext(token []byte) (g.SecContext, error) { @@ -134,9 +149,21 @@ func (provider) ImportSecContext(token []byte) (g.SecContext, error) { return &SecContext{ id: cGssCtxId, }, nil + } func (c *SecContext) Continue(inputToken []byte) ([]byte, error) { + // if the context is not yet initialized then do that.. + if c.id == nil { + if c.isInitiator { + return c.initSecContext() + } else { + return c.acceptSecContext(inputToken) + } + } + + // otherwise continue establishing the context.. + // var major, minor C.OM_uint32 var cOutToken C.gss_buffer_desc // cOutToken.value allocated by GSSAPI; released by *1 cInputToken, pinner := bytesToCBuffer(inputToken) diff --git a/seccontext_test.go b/seccontext_test.go index a547e17..4ea47e0 100644 --- a/seccontext_test.go +++ b/seccontext_test.go @@ -1,6 +1,7 @@ package gssapi import ( + "errors" "os" "testing" @@ -35,6 +36,32 @@ func mkTestAssets() *testAssets { return ta } +type testAssetType int + +const ( + testKeytabRack testAssetType = 1 << iota + testKeytabRuin + testCredCache +) + +func (ta *testAssets) useAsset(at testAssetType) { + switch { + default: + os.Unsetenv("KRB5_KTNAME") + case at&testKeytabRack > 0: + os.Setenv("KRB5_KTNAME", ta.ktfileRack) + case at&testKeytabRuin > 0: + os.Setenv("KRB5_KTNAME", ta.ktfileRuin) + } + + switch { + default: + os.Unsetenv("KRB5CCNAME") + case at&testCredCache > 0: + os.Setenv("KRB5CCNAME", "FILE:"+ta.ccfile) + } +} + func (ta *testAssets) Free() { ta.saveVars.Restore() os.Remove(ta.ktfileRack) @@ -49,13 +76,61 @@ func TestSecContextInterface(t *testing.T) { _ = gsc } -func TestInitSecContext(t *testing.T) { - assert := assert.New(t) +func initContextOne(provider g.Provider, name g.GssName, opts ...g.InitSecContextOption) (g.SecContext, []byte, error) { + secCtx, err := provider.InitSecContext(name, opts...) + if err != nil { + return nil, nil, err + } + + if secCtx == nil { + return nil, nil, errors.New("nil sec ctx") + } + + outTok, err := secCtx.Continue(nil) + if err == nil && len(outTok) == 0 { + err = errors.New("Empty first token") + } + + ctx := secCtx.(*SecContext) + if err == nil && ctx.id == nil { + return nil, nil, errors.New("unexpected nil context") + } + + return secCtx, outTok, err +} + +func acceptContextOne(provider g.Provider, cred g.Credential, inTok []byte) (g.SecContext, []byte, error) { + secCtx, err := provider.AcceptSecContext(cred) + if err != nil { + return nil, nil, err + } + + if secCtx == nil { + return nil, nil, errors.New("nil sec ctx") + } + + outTok, err := secCtx.Continue(inTok) + + ctx := secCtx.(*SecContext) + if err == nil && ctx.id == nil { + return nil, nil, errors.New("unexpected nil context") + } + + return secCtx, outTok, err +} + +var ta *testAssets + +func TestMain(m *testing.M) { + ta = mkTestAssets() + defer ta.Free() - ta := mkTestAssets() + m.Run() +} - os.Setenv("KRB5_KTNAME", ta.ktfileRack) - os.Setenv("KRB5CCNAME", "FILE:"+ta.ccfile) +func TestInitSecContext(t *testing.T) { + assert := assert.New(t) + ta.useAsset(testCredCache) // InitSecContext with this name should work because the cred-cache has a ticket // for rack/foo.golang-auth.io@GOLANG-AUTH.IO @@ -63,17 +138,17 @@ func TestInitSecContext(t *testing.T) { assert.NoError(err) // no continue should be needed when we don't request mutual auth - secCtx, outTok, err := ta.lib.InitSecContext(name) + secCtx, outTok, err := initContextOne(ta.lib, name) assert.NoError(err) - assert.NotEmpty(outTok) assert.NotNil(secCtx) + assert.NotEmpty(outTok) assert.False(secCtx.ContinueNeeded()) // .. but should be needed if we do request mutual auth - secCtx, outTok, err = ta.lib.InitSecContext(name, g.WithInitiatorFlags(g.ContextFlagMutual)) + secCtx, outTok, err = initContextOne(ta.lib, name, g.WithInitiatorFlags(g.ContextFlagMutual)) assert.NoError(err) - assert.NotEmpty(outTok) assert.NotNil(secCtx) + assert.NotEmpty(outTok) assert.True(secCtx.ContinueNeeded()) // This one should not work because the CC doesn't have a ticket for ruin/bar.golang-auth.io@GOLANG-AUTH.IO @@ -81,34 +156,29 @@ func TestInitSecContext(t *testing.T) { name, err = ta.lib.ImportName("ruin@bar.golang-auth.io", g.GSS_NT_HOSTBASED_SERVICE) assert.NoError(err) - _, _, err = ta.lib.InitSecContext(name) + _, _, err = initContextOne(ta.lib, name) assert.Error(err) - assert.Contains(err.Error(), "Cannot find KDC") + if err != nil { + assert.Contains(err.Error(), "Cannot find KDC") + } } func TestAcceptSecContext(t *testing.T) { assert := assert.New(t) - - ta := mkTestAssets() - - os.Setenv("KRB5_KTNAME", ta.ktfileRack) - os.Setenv("KRB5CCNAME", "FILE:"+ta.ccfile) + ta.useAsset(testCredCache | testKeytabRack) // InitSecContext with this name should work because the cred-cache has a ticket // for rack/foo.golang-auth.io@GOLANG-AUTH.IO name, err := ta.lib.ImportName("rack@foo.golang-auth.io", g.GSS_NT_HOSTBASED_SERVICE) assert.NoError(err) - secCtxInitiator, initiatorTok, err := ta.lib.InitSecContext(name) + _, initiatorTok, err := initContextOne(ta.lib, name) assert.NoError(err) - assert.NotEmpty(initiatorTok) - assert.NotNil(secCtxInitiator) - assert.False(secCtxInitiator.ContinueNeeded()) // the initiator token should be accepted by AcceptSecContext because we have a keytab // for the service princ. The output token should be empty because the initiator // didn't request mutual auth - secCtxAcceptor, acceptorTok, err := ta.lib.AcceptSecContext(nil, initiatorTok) + secCtxAcceptor, acceptorTok, err := acceptContextOne(ta.lib, nil, initiatorTok) assert.NoError(err) assert.Empty(acceptorTok) assert.NotNil(secCtxAcceptor) @@ -116,13 +186,10 @@ func TestAcceptSecContext(t *testing.T) { // if we're doing mutual auth we should get an output token from the acceptor but it // should not need another one back from the initiator - secCtxInitiator, initiatorTok, err = ta.lib.InitSecContext(name, g.WithInitiatorFlags(g.ContextFlagMutual)) + _, initiatorTok, err = initContextOne(ta.lib, name, g.WithInitiatorFlags(g.ContextFlagMutual)) assert.NoError(err) - assert.NotEmpty(initiatorTok) - assert.NotNil(secCtxInitiator) - assert.True(secCtxInitiator.ContinueNeeded()) - secCtxAcceptor, acceptorTok, err = ta.lib.AcceptSecContext(nil, initiatorTok) + secCtxAcceptor, acceptorTok, err = acceptContextOne(ta.lib, nil, initiatorTok) assert.NoError(err) assert.NotEmpty(acceptorTok) assert.NotNil(secCtxAcceptor) @@ -131,20 +198,14 @@ func TestAcceptSecContext(t *testing.T) { func TestDeleteSecContext(t *testing.T) { assert := assert.New(t) - - ta := mkTestAssets() - - os.Setenv("KRB5_KTNAME", ta.ktfileRack) - os.Setenv("KRB5CCNAME", "FILE:"+ta.ccfile) + ta.useAsset(testCredCache) // This should work because the cred-cache has a ticket for rack/foo.golang-auth.io@GOLANG-AUTH.IO name, err := ta.lib.ImportName("rack@foo.golang-auth.io", g.GSS_NT_HOSTBASED_SERVICE) assert.NoError(err) - secCtx, outTok, err := ta.lib.InitSecContext(name) + secCtx, _, err := initContextOne(ta.lib, name) assert.NoError(err) - assert.NotEmpty(outTok) - assert.NotNil(secCtx) // deleting a live or a deleted context should not return errors _, err = secCtx.Delete() @@ -158,29 +219,19 @@ func TestDeleteSecContext(t *testing.T) { func TestContextExpiresAt(t *testing.T) { assert := assert.New(t) - - ta := mkTestAssets() - - os.Setenv("KRB5_KTNAME", ta.ktfileRack) - os.Setenv("KRB5CCNAME", "FILE:"+ta.ccfile) + ta.useAsset(testCredCache | testKeytabRack) // This should work because the cred-cache has a ticket for rack/foo.golang-auth.io@GOLANG-AUTH.IO name, err := ta.lib.ImportName("rack@foo.golang-auth.io", g.GSS_NT_HOSTBASED_SERVICE) assert.NoError(err) - secCtxInitiator, initiatorTok, err := ta.lib.InitSecContext(name) + secCtxInitiator, initiatorTok, err := initContextOne(ta.lib, name) assert.NoError(err) - assert.NotEmpty(initiatorTok) - assert.NotNil(secCtxInitiator) - assert.False(secCtxInitiator.ContinueNeeded()) - secCtxAcceptor, acceptorTok, err := ta.lib.AcceptSecContext(nil, initiatorTok) + secCtxAcceptor, _, err := acceptContextOne(ta.lib, nil, initiatorTok) assert.NoError(err) - assert.Empty(acceptorTok) - assert.NotNil(secCtxAcceptor) - assert.False(secCtxAcceptor.ContinueNeeded()) - // both the initiator and the acceptor should know about the expiry time + // both the initiator and the acceptor should know about the expiry time of the kerberos creds tm, err := secCtxInitiator.ExpiresAt() assert.NoError(err) assert.Equal(2051, tm.Year()) @@ -192,10 +243,7 @@ func TestContextExpiresAt(t *testing.T) { func TestContextWrapSizeLimit(t *testing.T) { assert := assert.New(t) - - ta := mkTestAssets() - - os.Setenv("KRB5CCNAME", "FILE:"+ta.ccfile) + ta.useAsset(testCredCache) // This should work because the cred-cache has a ticket for rack/foo.golang-auth.io@GOLANG-AUTH.IO name, err := ta.lib.ImportName("rack@foo.golang-auth.io", g.GSS_NT_HOSTBASED_SERVICE) @@ -203,11 +251,8 @@ func TestContextWrapSizeLimit(t *testing.T) { o := g.WithInitiatorFlags(g.ContextFlagInteg | g.ContextFlagConf) - secCtxInitiator, initiatorTok, err := ta.lib.InitSecContext(name, o) + secCtxInitiator, _, err := initContextOne(ta.lib, name, o) assert.NoError(err) - assert.NotEmpty(initiatorTok) - assert.NotNil(secCtxInitiator) - assert.False(secCtxInitiator.ContinueNeeded()) // the max unwrapped token size would always be less that the max // wrapped token size @@ -218,17 +263,14 @@ func TestContextWrapSizeLimit(t *testing.T) { func TestExportImportSecContext(t *testing.T) { assert := assert.New(t) + ta.useAsset(testCredCache) - ta := mkTestAssets() - os.Setenv("KRB5CCNAME", "FILE:"+ta.ccfile) // This should work because the cred-cache has a ticket for rack/foo.golang-auth.io@GOLANG-AUTH.IO name, err := ta.lib.ImportName("rack@foo.golang-auth.io", g.GSS_NT_HOSTBASED_SERVICE) assert.NoError(err) - secCtx, initiatorTok, err := ta.lib.InitSecContext(name) + + secCtx, _, err := initContextOne(ta.lib, name) assert.NoError(err) - assert.NotEmpty(initiatorTok) - assert.NotNil(secCtx) - assert.False(secCtx.ContinueNeeded()) _, err = secCtx.Inquire() // should work the first time assert.NoError(err) @@ -250,31 +292,37 @@ func TestExportImportSecContext(t *testing.T) { func TestSecContextEstablishment(t *testing.T) { assert := assert.New(t) - - ta := mkTestAssets() - - os.Setenv("KRB5_KTNAME", ta.ktfileRack) - os.Setenv("KRB5CCNAME", "FILE:"+ta.ccfile) + ta.useAsset(testCredCache | testKeytabRack) name, err := ta.lib.ImportName("rack@foo.golang-auth.io", g.GSS_NT_HOSTBASED_SERVICE) assert.NoError(err) - secCtxInitiator, initiatorTok, err := ta.lib.InitSecContext(name, g.WithInitiatorFlags(g.ContextFlagMutual)) + secCtxInitiator, err := ta.lib.InitSecContext(name, g.WithInitiatorFlags(g.ContextFlagMutual)) assert.NoError(err) - secCtxAcceptor, acceptorTok, err := ta.lib.AcceptSecContext(nil, initiatorTok) + secCtxAcceptor, err := ta.lib.AcceptSecContext(nil) assert.NoError(err) + var initiatorTok, acceptorTok []byte for secCtxInitiator.ContinueNeeded() { - initiatorTok, err = secCtxInitiator.Continue(acceptorTok) - assert.NoError(err) + acceptorTok, err = secCtxInitiator.Continue(initiatorTok) + if err != nil { + break + } - if len(initiatorTok) > 0 { - acceptorTok, err = secCtxAcceptor.Continue(initiatorTok) - assert.NoError(err) + if len(acceptorTok) > 0 { + initiatorTok, err = secCtxAcceptor.Continue(acceptorTok) + if err != nil { + break + } } } + assert.NoError(err) + if err != nil { + return + } + assert.False(secCtxAcceptor.ContinueNeeded()) msg := []byte("Hello GSSAPI")