Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 111 additions & 85 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
)

const (
headerSetCookie = "Set-Cookie"
msgInternalServerError = "Internal Server Error"
)

type Config struct {
APIBaseUrl string `json:"apiBaseUrl"`
UserSessionCookieName string `json:"userSessionCookieName"`
Expand Down Expand Up @@ -76,81 +82,71 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}

func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
cookies := p.extractCookies(req)

queryValues := req.URL.Query()
if p.handleSessionExchange(rw, req, queryValues) { // handled via redirect or error
return
}

if sessionRequestValue := queryValues.Get(p.resourceSessionRequestParam); sessionRequestValue != "" {
body := ExchangeSessionBody{
RequestToken: &sessionRequestValue,
RequestHost: &req.Host,
RequestIP: &req.RemoteAddr,
}

jsonData, err := json.Marshal(body)
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}

verifyURL := fmt.Sprintf("%s/badger/exchange-session", p.apiBaseUrl)
resp, err := http.Post(verifyURL, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
defer resp.Body.Close()
cookies := p.extractCookies(req)
originalRequestURL := p.buildOriginalRequestURL(req, queryValues)
p.verifySession(rw, req, originalRequestURL, cookies, queryValues)
}

var result ExchangeSessionResponse
err = json.NewDecoder(resp.Body).Decode(&result)
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
// handleSessionExchange attempts to exchange a request token for a session cookie.
// Returns true if the request was fully handled (redirect or error), false to continue.
func (p *Badger) handleSessionExchange(rw http.ResponseWriter, req *http.Request, queryValues url.Values) bool {
sessionRequestValue := queryValues.Get(p.resourceSessionRequestParam)
if sessionRequestValue == "" {
return false
}

if result.Data.Cookie != nil && *result.Data.Cookie != "" {
rw.Header().Add("Set-Cookie", *result.Data.Cookie)
body := ExchangeSessionBody{
RequestToken: &sessionRequestValue,
RequestHost: &req.Host,
RequestIP: &req.RemoteAddr,
}

queryValues.Del(p.resourceSessionRequestParam)
cleanedQuery := queryValues.Encode()
originalRequestURL := fmt.Sprintf("%s://%s%s", p.getScheme(req), req.Host, req.URL.Path)
if cleanedQuery != "" {
originalRequestURL = fmt.Sprintf("%s?%s", originalRequestURL, cleanedQuery)
}
jsonData, err := json.Marshal(body)
if err != nil {
internalServerError(rw)
return true
}

if result.Data.ResponseHeaders != nil {
for key, value := range result.Data.ResponseHeaders {
rw.Header().Add(key, value)
}
}
url := fmt.Sprintf("%s/badger/exchange-session", p.apiBaseUrl)
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
internalServerError(rw)
return true
}
defer resp.Body.Close()

fmt.Println("Got exchange token, redirecting to", originalRequestURL)
http.Redirect(rw, req, originalRequestURL, http.StatusFound)
return
}
var result ExchangeSessionResponse
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
internalServerError(rw)
return true
}

cleanedQuery := queryValues.Encode()
originalRequestURL := fmt.Sprintf("%s://%s%s", p.getScheme(req), req.Host, req.URL.Path)
if cleanedQuery != "" {
originalRequestURL = fmt.Sprintf("%s?%s", originalRequestURL, cleanedQuery)
if result.Data.Cookie == nil || *result.Data.Cookie == "" { // continue to normal verification
return false
}

verifyURL := fmt.Sprintf("%s/badger/verify-session", p.apiBaseUrl)
rw.Header().Add(headerSetCookie, *result.Data.Cookie)
queryValues.Del(p.resourceSessionRequestParam)

headers := make(map[string]string)
for name, values := range req.Header {
if len(values) > 0 {
headers[name] = values[0] // Send only the first value for simplicity
originalRequestURL := p.buildOriginalRequestURL(req, queryValues)
if result.Data.ResponseHeaders != nil {
for k, v := range result.Data.ResponseHeaders {
rw.Header().Add(k, v)
}
}

queryParams := make(map[string]string)
for key, values := range queryValues {
if len(values) > 0 {
queryParams[key] = values[0]
}
}
fmt.Println("Got exchange token, redirecting to", originalRequestURL)
http.Redirect(rw, req, originalRequestURL, http.StatusFound)
return true
}

func (p *Badger) verifySession(rw http.ResponseWriter, req *http.Request, originalRequestURL string, cookies map[string]string, queryValues url.Values) {
verifyURL := fmt.Sprintf("%s/badger/verify-session", p.apiBaseUrl)

cookieData := VerifyBody{
Sessions: cookies,
Expand All @@ -161,42 +157,41 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
RequestMethod: &req.Method,
TLS: req.TLS != nil,
RequestIP: &req.RemoteAddr,
Headers: headers,
Query: queryParams,
Headers: p.extractHeaders(req),
Query: p.extractQueryParams(queryValues),
}

jsonData, err := json.Marshal(cookieData)
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusInternalServerError) // TODO: redirect to error page
internalServerError(rw) // TODO: redirect to error page
return
}

resp, err := http.Post(verifyURL, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
internalServerError(rw)
return
}
defer resp.Body.Close()

for _, setCookie := range resp.Header["Set-Cookie"] {
rw.Header().Add("Set-Cookie", setCookie)
for _, setCookie := range resp.Header[headerSetCookie] {
rw.Header().Add(headerSetCookie, setCookie)
}

if resp.StatusCode != http.StatusOK {
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
internalServerError(rw)
return
}

var result VerifyResponse
err = json.NewDecoder(resp.Body).Decode(&result)
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
internalServerError(rw)
return
}

if result.Data.ResponseHeaders != nil {
for key, value := range result.Data.ResponseHeaders {
rw.Header().Add(key, value)
for k, v := range result.Data.ResponseHeaders {
rw.Header().Add(k, v)
}
}

Expand All @@ -206,26 +201,57 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}

if result.Data.Valid {
if !result.Data.Valid {
http.Error(rw, "Unauthorized", http.StatusUnauthorized)
return
}

if result.Data.Username != nil {
req.Header.Add("Remote-User", *result.Data.Username)
}
// Attach identity headers
if result.Data.Username != nil {
req.Header.Add("Remote-User", *result.Data.Username)
}
if result.Data.Email != nil {
req.Header.Add("Remote-Email", *result.Data.Email)
}
if result.Data.Name != nil {
req.Header.Add("Remote-Name", *result.Data.Name)
}

if result.Data.Email != nil {
req.Header.Add("Remote-Email", *result.Data.Email)
}
fmt.Println("Badger: Valid session")
p.next.ServeHTTP(rw, req)
}

if result.Data.Name != nil {
req.Header.Add("Remote-Name", *result.Data.Name)
func (p *Badger) buildOriginalRequestURL(req *http.Request, queryValues url.Values) string {
cleanedQuery := queryValues.Encode()
base := fmt.Sprintf("%s://%s%s", p.getScheme(req), req.Host, req.URL.Path)
if cleanedQuery == "" {
return base
}
return fmt.Sprintf("%s?%s", base, cleanedQuery)
}

func (p *Badger) extractHeaders(req *http.Request) map[string]string {
result := make(map[string]string)
for name, values := range req.Header {
if len(values) > 0 {
result[name] = values[0]
}
}
return result
}

fmt.Println("Badger: Valid session")
p.next.ServeHTTP(rw, req)
return
func (p *Badger) extractQueryParams(values url.Values) map[string]string {
result := make(map[string]string)
for k, v := range values {
if len(v) > 0 {
result[k] = v[0]
}
}
return result
}

http.Error(rw, "Unauthorized", http.StatusUnauthorized)
func internalServerError(rw http.ResponseWriter) {
http.Error(rw, msgInternalServerError, http.StatusInternalServerError)
}

func (p *Badger) extractCookies(req *http.Request) map[string]string {
Expand Down