| 
 | 1 | +package httpmw_test  | 
 | 2 | + | 
 | 3 | +import (  | 
 | 4 | +	"net/http"  | 
 | 5 | +	"net/http/httptest"  | 
 | 6 | +	"testing"  | 
 | 7 | + | 
 | 8 | +	"github.com/stretchr/testify/require"  | 
 | 9 | + | 
 | 10 | +	"github.com/coder/code-marketplace/api/httpmw"  | 
 | 11 | +)  | 
 | 12 | + | 
 | 13 | +func TestCors(t *testing.T) {  | 
 | 14 | +	t.Parallel()  | 
 | 15 | + | 
 | 16 | +	methods := []string{  | 
 | 17 | +		http.MethodOptions,  | 
 | 18 | +		http.MethodHead,  | 
 | 19 | +		http.MethodGet,  | 
 | 20 | +		http.MethodPost,  | 
 | 21 | +		http.MethodPut,  | 
 | 22 | +		http.MethodPatch,  | 
 | 23 | +		http.MethodDelete,  | 
 | 24 | +	}  | 
 | 25 | + | 
 | 26 | +	tests := []struct {  | 
 | 27 | +		name           string  | 
 | 28 | +		origin         string  | 
 | 29 | +		allowedOrigin  string  | 
 | 30 | +		headers        string  | 
 | 31 | +		allowedHeaders string  | 
 | 32 | +	}{  | 
 | 33 | +		{  | 
 | 34 | +			name:          "LocalHTTP",  | 
 | 35 | +			origin:        "http://localhost:3000",  | 
 | 36 | +			allowedOrigin: "*",  | 
 | 37 | +		},  | 
 | 38 | +		{  | 
 | 39 | +			name:          "LocalHTTPS",  | 
 | 40 | +			origin:        "https://localhost:3000",  | 
 | 41 | +			allowedOrigin: "*",  | 
 | 42 | +		},  | 
 | 43 | +		{  | 
 | 44 | +			name:          "HTTP",  | 
 | 45 | +			origin:        "http://code-server.domain.tld",  | 
 | 46 | +			allowedOrigin: "*",  | 
 | 47 | +		},  | 
 | 48 | +		{  | 
 | 49 | +			name:          "HTTPS",  | 
 | 50 | +			origin:        "https://code-server.domain.tld",  | 
 | 51 | +			allowedOrigin: "*",  | 
 | 52 | +		},  | 
 | 53 | +		{  | 
 | 54 | +			// VS Code appears to use this origin.  | 
 | 55 | +			name:          "VSCode",  | 
 | 56 | +			origin:        "vscode-file://vscode-app",  | 
 | 57 | +			allowedOrigin: "*",  | 
 | 58 | +		},  | 
 | 59 | +		{  | 
 | 60 | +			name:          "NoOrigin",  | 
 | 61 | +			allowedOrigin: "",  | 
 | 62 | +		},  | 
 | 63 | +		{  | 
 | 64 | +			name:           "Headers",  | 
 | 65 | +			origin:         "foobar",  | 
 | 66 | +			allowedOrigin:  "*",  | 
 | 67 | +			headers:        "X-TEST,X-TEST2",  | 
 | 68 | +			allowedHeaders: "X-Test, X-Test2",  | 
 | 69 | +		},  | 
 | 70 | +	}  | 
 | 71 | + | 
 | 72 | +	for _, test := range tests {  | 
 | 73 | +		test := test  | 
 | 74 | +		t.Run(test.name, func(t *testing.T) {  | 
 | 75 | +			t.Parallel()  | 
 | 76 | + | 
 | 77 | +			for _, method := range methods {  | 
 | 78 | +				method := method  | 
 | 79 | +				t.Run(method, func(t *testing.T) {  | 
 | 80 | +					t.Parallel()  | 
 | 81 | + | 
 | 82 | +					r := httptest.NewRequest(method, "http://dev.coder.com", nil)  | 
 | 83 | +					if test.origin != "" {  | 
 | 84 | +						r.Header.Set(httpmw.OriginHeader, test.origin)  | 
 | 85 | +					}  | 
 | 86 | + | 
 | 87 | +					// OPTIONS requests need to know what method will be requested, or  | 
 | 88 | +					// go-chi/cors will error.  Both request headers and methods should be  | 
 | 89 | +					// ignored for regular requests even if they are set, although that is  | 
 | 90 | +					// not tested here.  | 
 | 91 | +					if method == http.MethodOptions {  | 
 | 92 | +						r.Header.Set(httpmw.AccessControlRequestMethodHeader, http.MethodGet)  | 
 | 93 | +						if test.headers != "" {  | 
 | 94 | +							r.Header.Set(httpmw.AccessControlRequestHeadersHeader, test.headers)  | 
 | 95 | +						}  | 
 | 96 | +					}  | 
 | 97 | + | 
 | 98 | +					rw := httptest.NewRecorder()  | 
 | 99 | +					handler := httpmw.Cors()(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {  | 
 | 100 | +						rw.WriteHeader(http.StatusNoContent)  | 
 | 101 | +					}))  | 
 | 102 | +					handler.ServeHTTP(rw, r)  | 
 | 103 | + | 
 | 104 | +					// Should always set some kind of allowed origin, if allowed.  | 
 | 105 | +					require.Equal(t, test.allowedOrigin, rw.Header().Get(httpmw.AccessControlAllowOriginHeader))  | 
 | 106 | + | 
 | 107 | +					// OPTIONS should echo back the request method and headers and we  | 
 | 108 | +					// should never get to our handler as the middleware short-circuits  | 
 | 109 | +					// with a 200.  | 
 | 110 | +					if method == http.MethodOptions {  | 
 | 111 | +						require.Equal(t, http.MethodGet, rw.Header().Get(httpmw.AccessControlAllowMethodsHeader))  | 
 | 112 | +						require.Equal(t, test.allowedHeaders, rw.Header().Get(httpmw.AccessControlAllowHeadersHeader))  | 
 | 113 | +						require.Equal(t, http.StatusOK, rw.Code)  | 
 | 114 | +					} else {  | 
 | 115 | +						require.Equal(t, "", rw.Header().Get(httpmw.AccessControlAllowMethodsHeader))  | 
 | 116 | +						require.Equal(t, "", rw.Header().Get(httpmw.AccessControlAllowHeadersHeader))  | 
 | 117 | +						require.Equal(t, http.StatusNoContent, rw.Code)  | 
 | 118 | +					}  | 
 | 119 | +				})  | 
 | 120 | +			}  | 
 | 121 | +		})  | 
 | 122 | +	}  | 
 | 123 | +}  | 
0 commit comments