forked from matthewhartstonge/storage
/
request_pkce_request_session.go
144 lines (123 loc) · 3.68 KB
/
request_pkce_request_session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package mongo
import (
// Standard Library Imports
"context"
// External Imports
"github.com/ory/fosite"
"github.com/sirupsen/logrus"
// Internal Imports
"github.com/matthewhartstonge/storage"
)
// CreatePKCERequestSession stores request under the given signature in the
// PKCE sessions collection. It satisfies fosite.PKCERequestStorage.
//
// If ctx does not already carry a mgo session, a copy of r.DB.Session is made
// for the duration of the call, attached to ctx, and closed on return. A
// storage.ErrResourceExists conflict is logged at debug level; any other
// failure is logged at error level. In every failure case the error is
// returned to the caller unchanged.
func (r *RequestManager) CreatePKCERequestSession(ctx context.Context, signature string, request fosite.Requester) error {
	methodLog := logger.WithFields(logrus.Fields{
		"package":    "mongo",
		"collection": storage.EntityPKCESessions,
		"method":     "CreatePKCERequestSession",
	})

	// Reuse the caller's mgo session when present; otherwise take a private
	// copy and release it when this method returns.
	sess, hasSession := ContextToMgoSession(ctx)
	if !hasSession {
		sess = r.DB.Session.Copy()
		ctx = MgoSessionToContext(ctx, sess)
		defer sess.Close()
	}

	// Measure how long the Mongo operation takes.
	traceSpan, tracedCtx := traceMongoCall(ctx, dbTrace{
		Manager: "RequestManager",
		Method:  "CreatePKCERequestSession",
	})
	defer traceSpan.Finish()

	_, err := r.Create(tracedCtx, storage.EntityPKCESessions, toMongo(signature, request))
	switch err {
	case nil:
		return nil
	case storage.ErrResourceExists:
		methodLog.WithError(err).Debug(logConflict)
	default:
		methodLog.WithError(err).Error(logError)
	}
	return err
}
// GetPKCERequestSession looks up the request stored under signature in the
// PKCE sessions collection and converts it into a fosite.Requester, passing
// session and r.Clients to the conversion. It satisfies
// fosite.PKCERequestStorage.
//
// If ctx does not already carry a mgo session, a copy of r.DB.Session is made
// for the duration of the call, attached to ctx, and closed on return.
// fosite.ErrNotFound from either the lookup or the conversion is logged at
// debug level; any other error is logged at error level. On failure the
// returned Requester is nil and the error is passed through unchanged.
func (r *RequestManager) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) {
	methodLog := logger.WithFields(logrus.Fields{
		"package":    "mongo",
		"collection": storage.EntityPKCESessions,
		"method":     "GetPKCERequestSession",
	})

	// Reuse the caller's mgo session when present; otherwise take a private
	// copy and release it when this method returns.
	sess, hasSession := ContextToMgoSession(ctx)
	if !hasSession {
		sess = r.DB.Session.Copy()
		ctx = MgoSessionToContext(ctx, sess)
		defer sess.Close()
	}

	// Measure how long the Mongo operation takes.
	traceSpan, tracedCtx := traceMongoCall(ctx, dbTrace{
		Manager: "RequestManager",
		Method:  "GetPKCERequestSession",
	})
	defer traceSpan.Finish()

	// Both failure points share one logging policy: not-found goes to debug,
	// everything else goes to error.
	reportFailure := func(failure error) {
		if failure == fosite.ErrNotFound {
			methodLog.WithError(failure).Debug(logNotFound)
			return
		}
		methodLog.WithError(failure).Error(logError)
	}

	stored, err := r.GetBySignature(tracedCtx, storage.EntityPKCESessions, signature)
	if err != nil {
		reportFailure(err)
		return nil, err
	}

	request, err := stored.ToRequest(tracedCtx, session, r.Clients)
	if err != nil {
		reportFailure(err)
		return nil, err
	}
	return request, nil
}
// DeletePKCERequestSession removes the request stored under signature from
// the PKCE sessions collection. It satisfies fosite.PKCERequestStorage.
//
// If ctx does not already carry a mgo session, a copy of r.DB.Session is made
// for the duration of the call, attached to ctx, and closed on return.
// fosite.ErrNotFound is logged at debug level; any other failure is logged at
// error level. Either way the error is returned to the caller unchanged.
func (r *RequestManager) DeletePKCERequestSession(ctx context.Context, signature string) error {
	methodLog := logger.WithFields(logrus.Fields{
		"package":    "mongo",
		"collection": storage.EntityPKCESessions,
		"method":     "DeletePKCERequestSession",
	})

	// Reuse the caller's mgo session when present; otherwise take a private
	// copy and release it when this method returns.
	sess, hasSession := ContextToMgoSession(ctx)
	if !hasSession {
		sess = r.DB.Session.Copy()
		ctx = MgoSessionToContext(ctx, sess)
		defer sess.Close()
	}

	// Measure how long the Mongo operation takes.
	traceSpan, tracedCtx := traceMongoCall(ctx, dbTrace{
		Manager: "RequestManager",
		Method:  "DeletePKCERequestSession",
	})
	defer traceSpan.Finish()

	if err := r.DeleteBySignature(tracedCtx, storage.EntityPKCESessions, signature); err != nil {
		if err == fosite.ErrNotFound {
			methodLog.WithError(err).Debug(logNotFound)
		} else {
			methodLog.WithError(err).Error(logError)
		}
		return err
	}
	return nil
}