-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.go
104 lines (91 loc) · 2.79 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// This example shows how to use every public method of the HostedModels object
// using a text-generation model (GPT-2). You should create a Hosted Model on RunwayML
// before running this example. You can train your own text-generation model or use
// one of the ones publicly available on the platform.
// See https://learn.runwayml.com/#/how-to/hosted-models for details.
//
// Usage: ./build/bin/text-generation
// --prompt string
// An optional prompt to use when querying the model. (default "Four score and seven years ago")
// --token string
// The hosted model token. Required if model is private.
// --url string
// A text-generation (GPT-2) hosted model url (e.g. https://my-text-model.hosted-models.runwayml.cloud/v1)
package main
import (
"encoding/json"
"fmt"
"math/rand"
"os"
"time"
runway "github.com/brannondorsey/go-runway"
flag "github.com/spf13/pflag"
)
// main connects to a text-generation (GPT-2) hosted model, waits for it to
// wake up, prints the model's Info() metadata, and issues a single Query()
// with the user-supplied prompt. This is a demo program, so every error is
// fatal and handled with panic.
func main() {
	args := parseArgs()

	model, err := runway.NewHostedModel(args.Url, args.Token)
	if err != nil {
		panic(err)
	}

	fmt.Println("[INFO] Waiting for model to wake up...")
	const pollIntervalMillis = 1000 // check if the model is awake every second
	if err := model.WaitUntilAwake(pollIntervalMillis); err != nil {
		panic(err)
	}

	// You can also check if the model is awake via the model.IsAwake() method.
	// This is unnecessary when combined with WaitUntilAwake() above and used
	// here simply for demonstration.
	awake, err := model.IsAwake()
	if err != nil {
		panic(err)
	}
	if awake {
		fmt.Println("[INFO] Model is awake")
	}

	fmt.Println("[INFO] Calling Model.Info()...")
	info, err := model.Info()
	if err != nil {
		panic(err)
	}
	pretty, _ := json.MarshalIndent(info, "", " ")
	fmt.Printf("[INFO] Received response from model.info(): \n%v\n\n", string(pretty))

	// Query: prompt from the CLI, a random seed (see init), and a cap on the
	// amount of generated text.
	input := runway.JSONObject{
		"prompt":         args.Prompt,
		"seed":           rand.Intn(1000),
		"max_characters": 512,
	}
	fmt.Println("[INFO] Calling Model.Query()...")
	output, err := model.Query(input)
	if err != nil {
		panic(err)
	}
	pretty, _ = json.MarshalIndent(output, "", " ")
	fmt.Printf("[INFO] Received response from model.query():\n%v\n\n", string(pretty))
}
// Args holds the command-line options for this example, as parsed by
// parseArgs.
type Args struct {
	// Url is the hosted model URL
	// (e.g. https://my-text-model.hosted-models.runwayml.cloud/v1). Required.
	Url string
	// Token is the hosted model token. Required only if the model is private.
	Token string
	// Prompt is the text used to seed the text-generation query.
	Prompt string
}
// parseArgs defines and parses the command-line flags for this example and
// returns them bundled in an Args value. The --url flag is mandatory: when it
// is absent, usage is printed and the process exits with status 1.
func parseArgs() Args {
	var (
		urlFlag    = flag.StringP("url", "u", "", "A text-generation (GPT-2) hosted model url (e.g. https://my-text-model.hosted-models.runwayml.cloud/v1)")
		tokenFlag  = flag.StringP("token", "t", "", "The hosted model token. Required if model is private.")
		promptFlag = flag.StringP("prompt", "p", "Four score and seven years ago", "An optional prompt to use when querying the model.")
	)
	flag.Parse()

	// A model URL is the one thing we cannot do without.
	if len(*urlFlag) == 0 {
		flag.Usage()
		os.Exit(1)
	}

	return Args{Url: *urlFlag, Token: *tokenFlag, Prompt: *promptFlag}
}
// init seeds the package-global math/rand source so each run sends a
// different "seed" value to the model in main's Query input.
// NOTE(review): rand.Seed is deprecated as of Go 1.20, where the global
// source is auto-seeded; presumably this file targets an older toolchain —
// confirm before removing (doing so would also orphan the "time" import).
func init() {
	rand.Seed(time.Now().UnixNano())
}