diff --git a/cmd/root.go b/cmd/root.go index a6985e9..a0619d3 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -93,18 +93,51 @@ $ dep-tree check`, return root } +func inferLang(files []string) string { + score := struct { + js int + python int + rust int + }{} + top := struct { + lang string + v int + }{} + for _, file := range files { + switch { + case utils.EndsWith(file, js.Extensions): + score.js += 1 + if score.js > top.v { + top.v = score.js + top.lang = "js" + } + case utils.EndsWith(file, rust.Extensions): + score.rust += 1 + if score.rust > top.v { + top.v = score.rust + top.lang = "rust" + } + case utils.EndsWith(file, python.Extensions): + score.python += 1 + if score.python > top.v { + top.v = score.python + top.lang = "python" + } + } + } + return top.lang +} + func makeParserBuilder(files []string, cfg *config.Config) (language.NodeParserBuilder, error) { if len(files) == 0 { return nil, errors.New("at least one file must be provided") } - // TODO: There's smarter ways to check which language are we in than just reading the - // first encountered file extension. - switch { - case utils.EndsWith(files[0], js.Extensions): + switch inferLang(files) { + case "js": return language.ParserBuilder(js.MakeJsLanguage, &cfg.Js, cfg), nil - case utils.EndsWith(files[0], rust.Extensions): + case "rust": return language.ParserBuilder(rust.MakeRustLanguage, &cfg.Rust, cfg), nil - case utils.EndsWith(files[0], python.Extensions): + case "python": return language.ParserBuilder(python.MakePythonLanguage, &cfg.Python, cfg), nil default: return nil, fmt.Errorf("file \"%s\" not supported", files[0]) diff --git a/cmd/root_test.go b/cmd/root_test.go index 75b77ac..ee1e5d8 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -104,3 +104,39 @@ func TestRoot(t *testing.T) { }) } } + +func TestInferLang(t *testing.T) { + tests := []struct { + Name string + Files []string + Expected string + }{ + { + Name: "only 1 file", + Files: []string{"foo.js"}, + Expected: "js", + }, + { + Name: "majority of files", + Files: []string{"foo.js", "bar.js", "foo.rs", "foo.py"}, + Expected: "js", + }, + { + Name: "unrelated files", + Files: []string{"foo.js", "foo.pdf"}, + Expected: "js", + }, + { + Name: "no match", + Files: []string{"foo.pdf", "bar.docx"}, + Expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + a := require.New(t) + a.Equal(tt.Expected, inferLang(tt.Files)) + }) + } +}