diff --git a/internal/config/server.go b/internal/config/server.go index 635b941..b6f7f74 100644 --- a/internal/config/server.go +++ b/internal/config/server.go @@ -5,6 +5,7 @@ import ( "mime" "os" "os/exec" + "regexp" "strings" "github.com/google/shlex" @@ -142,16 +143,16 @@ func BuildExecCommand(message api.Payload, c *ServerConfig) (*exec.Cmd, error) { // replace it with the args passed by the event if a == "%args" { if message.Attachment.Content.Args != "" { - passedArgs, err := shlex.Split(message.Attachment.Content.Args) + passedArgs, err := GetPassedArgs(message.Attachment.Content.Args) if err != nil { - return nil, fmt.Errorf("Error parsing args %s: %v", message.Attachment.Content.Args, err) + return nil, fmt.Errorf("could not parse args: %v", err) } args = append(args, passedArgs...) } // if we have the special value of %source-mime-ext // replace it with the source mimetype extension } else if a == "%source-mime-ext" { - a, err := getMimeTypeExtension(message.Attachment.Content.SourceMimeType) + a, err := GetMimeTypeExtension(message.Attachment.Content.SourceMimeType) if err != nil { return nil, fmt.Errorf("unknown mime extension: %s", message.Attachment.Content.SourceMimeType) } @@ -160,7 +161,7 @@ func BuildExecCommand(message api.Payload, c *ServerConfig) (*exec.Cmd, error) { // if we have the special value of %destination-mime-ext // replace it with the source mimetype extension } else if a == "%destination-mime-ext" { - a, err := getMimeTypeExtension(message.Attachment.Content.DestinationMimeType) + a, err := GetMimeTypeExtension(message.Attachment.Content.DestinationMimeType) if err != nil { return nil, fmt.Errorf("unknown mime extension: %s", message.Attachment.Content.DestinationMimeType) } @@ -197,7 +198,7 @@ func BuildExecCommand(message api.Payload, c *ServerConfig) (*exec.Cmd, error) { return cmd, nil } -func getMimeTypeExtension(mimeType string) (string, error) { +func GetMimeTypeExtension(mimeType string) (string, error) { // since the std mimetype -> extension conversion returns a list // we need to override the default extension to use // it also is missing some mimetypes @@ -230,6 +231,7 @@ func getMimeTypeExtension(mimeType string) (string, error) { "audio/x-m4a": "m4a", "audio/x-realaudio": "ra", "audio/midi": "mid", + "audio/x-wav": "wav", } cleanMimeType := strings.TrimSpace(strings.ToLower(mimeType)) if ext, ok := mimeToExtension[cleanMimeType]; ok { @@ -243,3 +245,23 @@ func getMimeTypeExtension(mimeType string) (string, error) { return strings.TrimPrefix(extensions[len(extensions)-1], "."), nil } + +func GetPassedArgs(args string) ([]string, error) { + passedArgs, err := shlex.Split(args) + if err != nil { + return nil, fmt.Errorf("error splitting args %s: %v", args, err) + } + + // make sure args are OK + regex, err := regexp.Compile(`^[a-zA-Z0-9._\-:\/@ ]+$`) + if err != nil { + return nil, fmt.Errorf("failed to compile regex: %v", err) + } + for _, value := range passedArgs { + if !regex.MatchString(value) { + return nil, fmt.Errorf("invalid input for passed arg: %s", value) + } + } + + return passedArgs, nil +} diff --git a/internal/config/server_test.go b/internal/config/server_test.go new file mode 100644 index 0000000..7e29bfd --- /dev/null +++ b/internal/config/server_test.go @@ -0,0 +1,91 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBadCmdArgs(t *testing.T) { + payloads := []string{ + `"any;thing`, + `"any&thing`, + `"any|thing`, + `"any$thing`, + `"any\"thing`, + `"any\thing`, + `"any*thing`, + `"any?thing`, + `"any[thing`, + `"any]thing`, + `"any{thing`, + `"any}thing`, + `"any(thing`, + `"any)thing`, + `"anything`, + `"anything!`, + "\"any`thing\"", + } + for _, payload := range payloads { + _, err := GetPassedArgs(payload) + assert.Error(t, err) + } + +} + +func TestMimeTypes(t *testing.T) { + mimeTypes := map[string]string{ + "application/msword": "doc", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx", + "application/vnd.ms-excel": "xls", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx", + "application/vnd.ms-powerpoint": "ppt", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": "pptx", + + "image/jpeg": "jpg", + "image/jp2": "jp2", + "image/png": "png", + "image/gif": "gif", + "image/bmp": "bmp", + "image/svg+xml": "svg", + "image/tiff": "tiff", + "image/webp": "webp", + + "audio/mpeg": "mp3", + "audio/x-wav": "wav", + "audio/ogg": "ogg", + "audio/aac": "m4a", + "audio/webm": "webm", + "audio/flac": "flac", + "audio/midi": "mid", + "audio/x-m4a": "m4a", + "audio/x-realaudio": "ra", + + "video/mp4": "mp4", + "video/x-msvideo": "avi", + "video/x-ms-wmv": "wmv", + "video/mpeg": "mpg", + "video/webm": "webm", + "video/quicktime": "mov", + "application/vnd.apple.mpegurl": "m3u8", + "video/3gpp": "3gp", + "video/mp2t": "ts", + "video/x-flv": "flv", + "video/x-m4v": "m4v", + "video/x-mng": "mng", + "video/x-ms-asf": "asx", + "video/ogg": "ogg", + + "text/plain": "txt", + "text/html": "html", + "application/pdf": "pdf", + "text/csv": "csv", + } + + for mimeType, extension := range mimeTypes { + ext, err := GetMimeTypeExtension(mimeType) + assert.Equal(t, nil, err) + assert.Equal(t, extension, ext) + } +} diff --git a/main_test.go b/main_test.go index 20529d7..8cc0c8c 100644 --- a/main_test.go +++ b/main_test.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "io" "log/slog" "net" @@ -376,129 +375,6 @@ cmdByMimeType: } } -func TestMimeTypes(t *testing.T) { - mimeTypes := map[string]string{ - "application/msword": "doc", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx", - "application/vnd.ms-excel": "xls", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx", - "application/vnd.ms-powerpoint": "ppt", - "application/vnd.openxmlformats-officedocument.presentationml.presentation": "pptx", - - "image/jpeg": "jpg", - "image/jp2": "jp2", - "image/png": "png", - "image/gif": "gif", - "image/bmp": "bmp", - "image/svg+xml": "svg", - "image/tiff": "tiff", - "image/webp": "webp", - - "audio/mpeg": "mp3", - "audio/x-wav": "wav", - "audio/ogg": "ogg", - "audio/aac": "m4a", - "audio/webm": "webm", - "audio/flac": "flac", - "audio/midi": "mid", - "audio/x-m4a": "m4a", - "audio/x-realaudio": "ra", - - "video/mp4": "mp4", - "video/x-msvideo": "avi", - "video/x-ms-wmv": "wmv", - "video/mpeg": "mpg", - "video/webm": "webm", - "video/quicktime": "mov", - "application/vnd.apple.mpegurl": "m3u8", - "video/3gpp": "3gp", - "video/mp2t": "ts", - "video/x-flv": "flv", - "video/x-m4v": "m4v", - "video/x-mng": "mng", - "video/x-ms-asf": "asx", - "video/ogg": "ogg", - - "text/plain": "txt", - "text/html": "html", - "application/pdf": "pdf", - "text/csv": "csv", - } - test := Test{ - authHeader: "pass", - requestAuth: "pass", - expectedStatus: http.StatusOK, - expectedBody: "%s txt\n", - returnedBody: "", - expectMismatch: false, - destinationMimeType: "text/plain", - yml: ` -forwardAuth: false -allowedMimeTypes: - - "*" -cmdByMimeType: - default: - cmd: echo - args: - - "%source-mime-ext" - - "%destination-mime-ext" -`, - } - for mimeType, extension := range mimeTypes { - test.name = fmt.Sprintf("test %s to %s conversion", mimeType, extension) - test.mimetype = mimeType - test.expectedBody = fmt.Sprintf("%s txt\n", extension) - t.Run(test.name, func(t *testing.T) { - var err error - destinationServer := createMockDestinationServer(t, test.returnedBody) - defer destinationServer.Close() - - sourceServer := createMockSourceServer(t, test.mimetype, test.authHeader, destinationServer.URL) - defer sourceServer.Close() - - os.Setenv("SCYLLARIDAE_YML", test.yml) - // set the config based on test.yml - config, err = scyllaridae.ReadConfig("") - if err != nil { - t.Fatalf("Could not read YML: %v", err) - os.Exit(1) - } - - // Configure and start the main server - setupServer := httptest.NewServer(http.HandlerFunc(MessageHandler)) - defer setupServer.Close() - - // Send the mock message to the main server - req, err := http.NewRequest("GET", setupServer.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("X-Islandora-Args", destinationServer.URL) - // set the mimetype to send to the destination server in the Accept header - req.Header.Set("Accept", test.destinationMimeType) - req.Header.Set("Authorization", test.requestAuth) - req.Header.Set("Apix-Ldp-Resource", sourceServer.URL) - - // Capture the response - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - assert.Equal(t, test.expectedStatus, resp.StatusCode) - if !test.expectMismatch { - // if we're setesting up the destination server as the cURL target - // make sure it returned - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("Unable to read source uri resp body: %v", err) - } - assert.Equal(t, test.expectedBody, string(body)) - } - }) - } -} - func createMockDestinationServer(t *testing.T, content string) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte(content)); err != nil {