Skip to content

Commit

Permalink
feat: cobra, shell completions, perf and ui improvements (#112)
Browse files Browse the repository at this point in the history
* refactor: use cobra for flags

- Use Cobra
- Clean up
- Refactor

* fix: db path

* feat: completions

* chore: cleanup

* chore: cleanup

* fix: tests

* test: db.Completions

* fix: improvements

* fix: config

* test: fix broken test

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: env, gitignore

* fix: no need to wrap in an error

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: mods with no args

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: --list should print list to stdout (#113)

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: improve error handling (#114)

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: styles & stderr

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: version

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: remove unused examples

* fix: hide completion and help commands

* fix: error handling

- Fix modsError
- Error when continuing without a prompt

* fix: mark flags as mutually exclusive

* fix: create cache directory

* fix: improve completions

* fix: improve completions

* fix: use errors.As

* fix: improve styles

* fix: improve config creation and its error handling

* fix: mkdir cache

* fix: ensure error's reasons ends with .

* perf: indices

* test: fix broken test, add no prompt continue test

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* chore: fmt queries

---------

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
Co-authored-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
  • Loading branch information
aymanbagabas and caarlos0 committed Sep 1, 2023
1 parent dec042d commit 17a19d4
Show file tree
Hide file tree
Showing 10 changed files with 507 additions and 307 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mods
.envrc

completions/
dist/
125 changes: 48 additions & 77 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (

"github.com/adrg/xdg"
"github.com/caarlos0/env/v9"
"github.com/charmbracelet/lipgloss"
"github.com/muesli/termenv"
"github.com/spf13/cobra"
flag "github.com/spf13/pflag"
"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -139,23 +139,28 @@ func (f flagParseError) Flag() string {
}
}

func newConfig() (Config, error) {
func ensureConfig() (Config, error) {
var c Config
sp, err := xdg.ConfigFile(filepath.Join("mods", "mods.yml"))
if err != nil {
return c, fmt.Errorf("can't find settings path: %s", err)
return c, modsError{err, "Could not find settings path."}
}
c.SettingsPath = sp
err = writeConfigFile(sp)
if err != nil {

dir := filepath.Dir(sp)
if err := os.MkdirAll(dir, 0o700); err != nil { //nolint:gomnd
return c, modsError{err, "Could not create cache directory."}
}

if err := writeConfigFile(sp); err != nil {
return c, err
}
content, err := os.ReadFile(sp)
if err != nil {
return c, fmt.Errorf("can't read settings file: %s", err)
return c, modsError{err, "Could not read settings file."}
}
if err := yaml.Unmarshal(content, &c); err != nil {
return c, fmt.Errorf("%s: %w", sp, err)
return c, modsError{err, "Could not parse settings file."}
}
ms := make(map[string]Model)
for _, api := range c.APIs {
Expand All @@ -179,50 +184,16 @@ func newConfig() (Config, error) {

_ = os.Setenv("__MODS_GLAMOUR", fmt.Sprintf("%v", isOutputTTY()))
if err := env.ParseWithOptions(&c, env.Options{Prefix: "MODS_"}); err != nil {
return c, fmt.Errorf("could not parse environment into config: %s", err)
}

flag.StringVarP(&c.Model, "model", "m", c.Model, help["model"])
flag.StringVarP(&c.API, "api", "a", c.API, help["api"])
flag.StringVarP(&c.HTTPProxy, "http-proxy", "x", c.HTTPProxy, help["http-proxy"])
flag.BoolVarP(&c.Format, "format", "f", c.Format, help["format"])
flag.BoolVarP(&c.Glamour, "glamour", "g", c.Glamour, help["glamour"])
flag.IntVarP(&c.IncludePrompt, "prompt", "P", c.IncludePrompt, help["prompt"])
flag.BoolVarP(&c.IncludePromptArgs, "prompt-args", "p", c.IncludePromptArgs, help["prompt-args"])
flag.BoolVarP(&c.Quiet, "quiet", "q", c.Quiet, help["quiet"])
flag.BoolVar(&c.Settings, "settings", false, help["settings"])
flag.BoolVarP(&c.ShowHelp, "help", "h", false, help["help"])
flag.BoolVarP(&c.Version, "version", "v", false, help["version"])
flag.StringVarP(&c.Continue, "continue", "c", "", help["continue"])
flag.BoolVarP(&c.ContinueLast, "continue-last", "C", false, help["continue-last"])
flag.BoolVarP(&c.List, "list", "l", c.List, help["list"])
flag.IntVar(&c.MaxRetries, "max-retries", c.MaxRetries, help["max-retries"])
flag.BoolVar(&c.NoLimit, "no-limit", c.NoLimit, help["no-limit"])
flag.IntVar(&c.MaxTokens, "max-tokens", c.MaxTokens, help["max-tokens"])
flag.Float32Var(&c.Temperature, "temp", c.Temperature, help["temp"])
flag.Float32Var(&c.TopP, "topp", c.TopP, help["topp"])
flag.UintVar(&c.Fanciness, "fanciness", c.Fanciness, help["fanciness"])
flag.StringVar(&c.StatusText, "status-text", c.StatusText, help["status-text"])
flag.BoolVar(&c.ResetSettings, "reset-settings", c.ResetSettings, help["reset-settings"])
flag.StringVarP(&c.Title, "title", "t", c.Title, help["title"])
flag.StringVar(&c.Delete, "delete", c.Delete, help["delete"])
flag.StringVarP(&c.Show, "show", "s", c.Show, help["show"])
flag.BoolVar(&c.NoCache, "no-cache", c.NoCache, help["no-cache"])
flag.Lookup("prompt").NoOptDefVal = "-1"
flag.Usage = usage
flag.CommandLine.SortFlags = false
flag.CommandLine.Init("", flag.ContinueOnError)
if err := flag.CommandLine.Parse(os.Args[1:]); err != nil {
return c, flagParseError{err}
return c, modsError{err, "Could not parse environment into settings file."}
}

if c.Format && c.FormatText == "" {
c.FormatText = "Format the response as markdown without enclosing backticks."
}
if c.CachePath == "" {
c.CachePath = filepath.Join(xdg.DataHome, "mods", "conversations")
}
c.Prefix = strings.Join(flag.Args(), " ")

if err := os.MkdirAll(c.CachePath, 0o700); err != nil { //nolint:gomnd
return c, modsError{err, "Could not create cache directory."}
}

return c, nil
}
Expand All @@ -240,25 +211,18 @@ func writeConfigFile(path string) error {
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
return createConfigFile(path)
} else if err != nil {
return fmt.Errorf("could not stat path '%s': %w", path, err)
return modsError{err, "Could not stat path."}
}
return nil
}

func createConfigFile(path string) error {
var c Config
tmpl, err := template.New("config").Parse(strings.TrimSpace(configTemplate))
if err != nil {
return fmt.Errorf("could not parse config template: %w", err)
}
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o700); err != nil { //nolint:gomnd
return fmt.Errorf("could not create directory '%s': %w", dir, err)
}
tmpl := template.Must(template.New("config").Parse(configTemplate))

var c Config
f, err := os.Create(path)
if err != nil {
return fmt.Errorf("could not create file '%s': %w", path, err)
return modsError{err, "Could not create configuration file."}
}
defer func() { _ = f.Close() }()

Expand All @@ -270,48 +234,55 @@ func createConfigFile(path string) error {
Help: help,
}
if err := tmpl.Execute(f, m); err != nil {
return fmt.Errorf("could not render template: %w", err)
return modsError{err, "Could not render template."}
}
return nil
}

func usage() {
r := lipgloss.DefaultRenderer()
s := makeStyles(r)
func useLine() string {
appName := filepath.Base(os.Args[0])

if r.ColorProfile() == termenv.TrueColor {
appName = makeGradientText(s.AppName, appName)
if stdoutRenderer.ColorProfile() == termenv.TrueColor {
appName = makeGradientText(stdoutStyles.AppName, appName)
}

return fmt.Sprintf(
"%s %s",
appName,
stdoutStyles.CliArgs.Render("[OPTIONS] [PREFIX TERM]"),
)
}

func usageFunc(cmd *cobra.Command) error {
fmt.Printf("GPT on the command line. Built for pipelines.\n\n")
fmt.Printf(
"Usage:\n %s %s\n\n",
appName,
s.CliArgs.Render("[OPTIONS] [PREFIX TERM]"),
"Usage:\n %s\n\n",
useLine(),
)
fmt.Println("Options:")
flag.VisitAll(func(f *flag.Flag) {
cmd.Flags().VisitAll(func(f *flag.Flag) {
if f.Shorthand == "" {
fmt.Printf(
" %-42s %s\n",
s.Flag.Render("--"+f.Name),
s.FlagDesc.Render(f.Usage),
" %-44s %s\n",
stdoutStyles.Flag.Render("--"+f.Name),
stdoutStyles.FlagDesc.Render(f.Usage),
)
} else {
fmt.Printf(
" %s%s %-38s %s\n",
s.Flag.Render("-"+f.Shorthand),
s.FlagComma,
s.Flag.Render("--"+f.Name),
s.FlagDesc.Render(f.Usage),
" %s%s %-40s %s\n",
stdoutStyles.Flag.Render("-"+f.Shorthand),
stdoutStyles.FlagComma,
stdoutStyles.Flag.Render("--"+f.Name),
stdoutStyles.FlagDesc.Render(f.Usage),
)
}
})
desc, example := randomExample()
fmt.Printf(
"\nExample:\n %s\n %s\n",
s.Comment.Render("# "+desc),
cheapHighlighting(s, example),
stdoutStyles.Comment.Render("# "+desc),
cheapHighlighting(stdoutStyles, example),
)

return nil
}
65 changes: 57 additions & 8 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,19 @@ func openDB(ds string) (*convoDB, error) {
updated_at datetime not null default(strftime('%Y-%m-%d %H:%M:%f', 'now')),
check(id <> ''),
check(title <> '')
);
)
`); err != nil {
return nil, fmt.Errorf("could not migrate db: %w", err)
}
if _, err := db.Exec(`
create index if not exists idx_conv_id
on conversations(id)
`); err != nil {
return nil, fmt.Errorf("could not migrate db: %w", err)
}
if _, err := db.Exec(`
create index if not exists idx_conv_title
on conversations(title)
`); err != nil {
return nil, fmt.Errorf("could not migrate db: %w", err)
}
Expand Down Expand Up @@ -110,7 +122,12 @@ func (c *convoDB) Delete(id string) error {

func (c *convoDB) FindHEAD() (*Conversation, error) {
var convo Conversation
if err := c.db.Get(&convo, "select * from conversations order by updated_at desc limit 1"); err != nil {
if err := c.db.Get(&convo, `
select *
from conversations
order by updated_at desc
limit 1
`); err != nil {
return nil, fmt.Errorf("FindHead: %w", err)
}
return &convo, nil
Expand All @@ -131,14 +148,45 @@ func (c *convoDB) findByIDOrTitle(result *[]Conversation, in string) error {
if err := c.db.Select(result, c.db.Rebind(`
select *
from conversations
where id like ?
where id glob ?
or title = ?
`), in+"%", in); err != nil {
`), in+"*", in); err != nil {
return fmt.Errorf("findByIDOrTitle: %w", err)
}
return nil
}

func (c *convoDB) Completions(in string) ([]string, error) {
var result []string
if err := c.db.Select(&result, c.db.Rebind(`
select printf(
'%s%c%s',
case
when length(?) < ? then
substr(id, 1, ?)
else
id
end,
char(9),
title
)
from conversations where id glob ?
union
select
printf(
"%s%c%s",
title,
char(9),
substr(id, 1, ?)
)
from conversations
where title glob ?
`), in, sha1short, sha1short, in+"*", sha1short, in+"*"); err != nil {
return result, fmt.Errorf("Completions: %w", err)
}
return result, nil
}

func (c *convoDB) Find(in string) (*Conversation, error) {
var conversations []Conversation
var err error
Expand All @@ -163,10 +211,11 @@ func (c *convoDB) Find(in string) (*Conversation, error) {

func (c *convoDB) List() ([]Conversation, error) {
var convos []Conversation
if err := c.db.Select(
&convos,
"select * from conversations order by updated_at desc",
); err != nil {
if err := c.db.Select(&convos, `
select *
from conversations
order by updated_at desc
`); err != nil {
return convos, fmt.Errorf("List: %w", err)
}
return convos, nil
Expand Down
25 changes: 25 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -143,4 +144,28 @@ func TestConvoDB(t *testing.T) {
require.NoError(t, err)
require.Empty(t, list)
})

t.Run("completions", func(t *testing.T) {
db := testDB(t)

const testid1 = "fc5012d8c67073ea0a46a3c05488a0e1d87df74b"
const title1 = "some title"
const testid2 = "6c33f71694bf41a18c844a96d1f62f153e5f6f44"
const title2 = "football teams"
require.NoError(t, db.Save(testid1, title1))
require.NoError(t, db.Save(testid2, title2))

results, err := db.Completions("f")
require.NoError(t, err)
require.Equal(t, []string{
fmt.Sprintf("%s\t%s", testid1[:sha1short], title1),
fmt.Sprintf("%s\t%s", title2, testid2[:sha1short]),
}, results)

results, err = db.Completions(testid1[:8])
require.NoError(t, err)
require.Equal(t, []string{
fmt.Sprintf("%s\t%s", testid1, title1),
}, results)
})
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ require (
github.com/mattn/go-isatty v0.0.19
github.com/muesli/termenv v0.15.2
github.com/sashabaranov/go-openai v1.14.2
github.com/spf13/cobra v1.6.1
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.4
gopkg.in/yaml.v3 v3.0.1
Expand All @@ -32,6 +33,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.0.1 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ github.com/hashicorp/serf v0.9.8/go.mod h1:TXZNMjZQijwlDvp+r0b63xZ45H7JmCmgg4gpT
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc=
github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
Expand Down Expand Up @@ -635,6 +636,7 @@ github.com/spf13/afero v1.9.2/go.mod h1:iUV7ddyEEZPO5gA3zD4fJt6iStLlL+Lg4m2cihcD
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU=
github.com/spf13/cobra v1.0.0/go.mod h1:/6GTrnGXV9HjY+aR4k0oJ5tcvakLuG6EuKReYlHNrgE=
github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA=
github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY=
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo=
Expand Down
Loading

0 comments on commit 17a19d4

Please sign in to comment.