Skip to content

Commit

Permalink
Restrict characters that can be passed to X-Islandora-Args (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall committed May 19, 2024
1 parent 895957d commit e4107eb
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 129 deletions.
32 changes: 27 additions & 5 deletions internal/config/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"mime"
"os"
"os/exec"
"regexp"
"strings"

"github.com/google/shlex"
Expand Down Expand Up @@ -142,16 +143,16 @@ func BuildExecCommand(message api.Payload, c *ServerConfig) (*exec.Cmd, error) {
// replace it with the args passed by the event
if a == "%args" {
if message.Attachment.Content.Args != "" {
passedArgs, err := shlex.Split(message.Attachment.Content.Args)
passedArgs, err := GetPassedArgs(message.Attachment.Content.Args)
if err != nil {
return nil, fmt.Errorf("Error parsing args %s: %v", message.Attachment.Content.Args, err)
return nil, fmt.Errorf("could not parse args: %v", err)
}
args = append(args, passedArgs...)
}
// if we have the special value of %source-mime-ext
// replace it with the source mimetype extension
} else if a == "%source-mime-ext" {
a, err := getMimeTypeExtension(message.Attachment.Content.SourceMimeType)
a, err := GetMimeTypeExtension(message.Attachment.Content.SourceMimeType)
if err != nil {
return nil, fmt.Errorf("unknown mime extension: %s", message.Attachment.Content.SourceMimeType)
}
Expand All @@ -160,7 +161,7 @@ func BuildExecCommand(message api.Payload, c *ServerConfig) (*exec.Cmd, error) {
// if we have the special value of %destination-mime-ext
// replace it with the source mimetype extension
} else if a == "%destination-mime-ext" {
a, err := getMimeTypeExtension(message.Attachment.Content.DestinationMimeType)
a, err := GetMimeTypeExtension(message.Attachment.Content.DestinationMimeType)
if err != nil {
return nil, fmt.Errorf("unknown mime extension: %s", message.Attachment.Content.DestinationMimeType)
}
Expand Down Expand Up @@ -197,7 +198,7 @@ func BuildExecCommand(message api.Payload, c *ServerConfig) (*exec.Cmd, error) {
return cmd, nil
}

func getMimeTypeExtension(mimeType string) (string, error) {
func GetMimeTypeExtension(mimeType string) (string, error) {
// since the std mimetype -> extension conversion returns a list
// we need to override the default extension to use
// it also is missing some mimetypes
Expand Down Expand Up @@ -230,6 +231,7 @@ func getMimeTypeExtension(mimeType string) (string, error) {
"audio/x-m4a": "m4a",
"audio/x-realaudio": "ra",
"audio/midi": "mid",
"audio/x-wav": "wav",
}
cleanMimeType := strings.TrimSpace(strings.ToLower(mimeType))
if ext, ok := mimeToExtension[cleanMimeType]; ok {
Expand All @@ -243,3 +245,23 @@ func getMimeTypeExtension(mimeType string) (string, error) {

return strings.TrimPrefix(extensions[len(extensions)-1], "."), nil
}

func GetPassedArgs(args string) ([]string, error) {
passedArgs, err := shlex.Split(args)
if err != nil {
return nil, fmt.Errorf("error splitting args %s: %v", args, err)
}

// make sure args are OK
regex, err := regexp.Compile(`^[a-zA-Z0-9._\-:\/@ ]+$`)
if err != nil {
return nil, fmt.Errorf("failed to compile regex: %v", err)
}
for _, value := range passedArgs {
if !regex.MatchString(value) {
return nil, fmt.Errorf("invalid input for passed arg: %s", value)
}
}

return passedArgs, nil
}
91 changes: 91 additions & 0 deletions internal/config/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package config

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestBadCmdArgs(t *testing.T) {
payloads := []string{
`"any;thing`,
`"any&thing`,
`"any|thing`,
`"any$thing`,
`"any\"thing`,
`"any\thing`,
`"any*thing`,
`"any?thing`,
`"any[thing`,
`"any]thing`,
`"any{thing`,
`"any}thing`,
`"any(thing`,
`"any)thing`,
`"any<thing`,
`"any>thing`,
`"anything!`,
"\"any`thing\"",
}
for _, payload := range payloads {
_, err := GetPassedArgs(payload)
assert.Error(t, err)
}

}

func TestMimeTypes(t *testing.T) {
mimeTypes := map[string]string{
"application/msword": "doc",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
"application/vnd.ms-excel": "xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
"application/vnd.ms-powerpoint": "ppt",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": "pptx",

"image/jpeg": "jpg",
"image/jp2": "jp2",
"image/png": "png",
"image/gif": "gif",
"image/bmp": "bmp",
"image/svg+xml": "svg",
"image/tiff": "tiff",
"image/webp": "webp",

"audio/mpeg": "mp3",
"audio/x-wav": "wav",
"audio/ogg": "ogg",
"audio/aac": "m4a",
"audio/webm": "webm",
"audio/flac": "flac",
"audio/midi": "mid",
"audio/x-m4a": "m4a",
"audio/x-realaudio": "ra",

"video/mp4": "mp4",
"video/x-msvideo": "avi",
"video/x-ms-wmv": "wmv",
"video/mpeg": "mpg",
"video/webm": "webm",
"video/quicktime": "mov",
"application/vnd.apple.mpegurl": "m3u8",
"video/3gpp": "3gp",
"video/mp2t": "ts",
"video/x-flv": "flv",
"video/x-m4v": "m4v",
"video/x-mng": "mng",
"video/x-ms-asf": "asx",
"video/ogg": "ogg",

"text/plain": "txt",
"text/html": "html",
"application/pdf": "pdf",
"text/csv": "csv",
}

for mimeType, extension := range mimeTypes {
ext, err := GetMimeTypeExtension(mimeType)
assert.Equal(t, nil, err)
assert.Equal(t, extension, ext)
}
}
124 changes: 0 additions & 124 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"fmt"
"io"
"log/slog"
"net"
Expand Down Expand Up @@ -376,129 +375,6 @@ cmdByMimeType:
}
}

func TestMimeTypes(t *testing.T) {
mimeTypes := map[string]string{
"application/msword": "doc",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
"application/vnd.ms-excel": "xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
"application/vnd.ms-powerpoint": "ppt",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": "pptx",

"image/jpeg": "jpg",
"image/jp2": "jp2",
"image/png": "png",
"image/gif": "gif",
"image/bmp": "bmp",
"image/svg+xml": "svg",
"image/tiff": "tiff",
"image/webp": "webp",

"audio/mpeg": "mp3",
"audio/x-wav": "wav",
"audio/ogg": "ogg",
"audio/aac": "m4a",
"audio/webm": "webm",
"audio/flac": "flac",
"audio/midi": "mid",
"audio/x-m4a": "m4a",
"audio/x-realaudio": "ra",

"video/mp4": "mp4",
"video/x-msvideo": "avi",
"video/x-ms-wmv": "wmv",
"video/mpeg": "mpg",
"video/webm": "webm",
"video/quicktime": "mov",
"application/vnd.apple.mpegurl": "m3u8",
"video/3gpp": "3gp",
"video/mp2t": "ts",
"video/x-flv": "flv",
"video/x-m4v": "m4v",
"video/x-mng": "mng",
"video/x-ms-asf": "asx",
"video/ogg": "ogg",

"text/plain": "txt",
"text/html": "html",
"application/pdf": "pdf",
"text/csv": "csv",
}
test := Test{
authHeader: "pass",
requestAuth: "pass",
expectedStatus: http.StatusOK,
expectedBody: "%s txt\n",
returnedBody: "",
expectMismatch: false,
destinationMimeType: "text/plain",
yml: `
forwardAuth: false
allowedMimeTypes:
- "*"
cmdByMimeType:
default:
cmd: echo
args:
- "%source-mime-ext"
- "%destination-mime-ext"
`,
}
for mimeType, extension := range mimeTypes {
test.name = fmt.Sprintf("test %s to %s conversion", mimeType, extension)
test.mimetype = mimeType
test.expectedBody = fmt.Sprintf("%s txt\n", extension)
t.Run(test.name, func(t *testing.T) {
var err error
destinationServer := createMockDestinationServer(t, test.returnedBody)
defer destinationServer.Close()

sourceServer := createMockSourceServer(t, test.mimetype, test.authHeader, destinationServer.URL)
defer sourceServer.Close()

os.Setenv("SCYLLARIDAE_YML", test.yml)
// set the config based on test.yml
config, err = scyllaridae.ReadConfig("")
if err != nil {
t.Fatalf("Could not read YML: %v", err)
os.Exit(1)
}

// Configure and start the main server
setupServer := httptest.NewServer(http.HandlerFunc(MessageHandler))
defer setupServer.Close()

// Send the mock message to the main server
req, err := http.NewRequest("GET", setupServer.URL, nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("X-Islandora-Args", destinationServer.URL)
// set the mimetype to send to the destination server in the Accept header
req.Header.Set("Accept", test.destinationMimeType)
req.Header.Set("Authorization", test.requestAuth)
req.Header.Set("Apix-Ldp-Resource", sourceServer.URL)

// Capture the response
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
assert.Equal(t, test.expectedStatus, resp.StatusCode)
if !test.expectMismatch {
// if we're setesting up the destination server as the cURL target
// make sure it returned
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Unable to read source uri resp body: %v", err)
}
assert.Equal(t, test.expectedBody, string(body))
}
})
}
}

func createMockDestinationServer(t *testing.T, content string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := w.Write([]byte(content)); err != nil {
Expand Down

0 comments on commit e4107eb

Please sign in to comment.