-
Notifications
You must be signed in to change notification settings - Fork 0
/
ngrams.go
137 lines (125 loc) · 2.78 KB
/
ngrams.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package main
import (
	"database/sql"
	"fmt"
	"math/rand"
	"strings"
	"unicode/utf8"
)
type (
	// ngramVariants holds the observed successors of a single ngram.
	ngramVariants struct {
		// weights maps each successor ngram to its weight (add increments it
		// once per observed transition; loadStat reads it from the database).
		weights map[string]int
		// weight is the running total of the values in weights; get draws a
		// random number in [0, weight) and walks weights to pick a successor.
		weight int
	}

	// ngramStat is the whole transition table: every known ngram mapped to
	// the successors recorded for it.
	ngramStat struct {
		ngrams map[string]*ngramVariants
	}
)
// loadStat reads every (ngram, nextNgram, weight) row from the ngrams table
// and builds the in-memory transition statistics from them.
//
// Rows that repeat an (ngram, nextNgram) pair are summed rather than
// overwritten, so each ngramVariants.weight always equals the sum of its
// weights map, which is the invariant get relies on. Rows with a non-positive
// weight are skipped: get can never select them, and an entry whose total
// weight is zero would make rand.Intn panic.
//
// On error the returned ngramStat is the zero value.
func loadStat(db *sql.DB) (ngramStat, error) {
	rows, err := db.Query(`SELECT ngram, nextNgram, weight FROM ngrams`)
	if err != nil {
		return ngramStat{}, fmt.Errorf("querying db: %w", err)
	}
	defer rows.Close()

	stat := ngramStat{ngrams: make(map[string]*ngramVariants)}
	for rows.Next() {
		var ngram, nextNgram string
		var weight int
		if err := rows.Scan(&ngram, &nextNgram, &weight); err != nil {
			return ngramStat{}, fmt.Errorf("scanning the result: %w", err)
		}
		if weight <= 0 {
			continue
		}
		v, ok := stat.ngrams[ngram]
		if !ok {
			v = &ngramVariants{weights: make(map[string]int)}
			stat.ngrams[ngram] = v
		}
		// Accumulate both sides so a duplicate row cannot desynchronise the
		// per-successor weights from the total.
		v.weights[nextNgram] += weight
		v.weight += weight
	}
	if err := rows.Err(); err != nil {
		return ngramStat{}, fmt.Errorf("iterating rows: %w", err)
	}
	return stat, nil
}
// get picks one successor ngram at random, with probability proportional to
// its entry in v.weights.
//
// It requires v.weight to be positive and equal to the sum of v.weights;
// callers that build ngramVariants are expected to maintain that. A broken
// invariant is a programming error and panics.
func (v ngramVariants) get() string {
	if v.weight <= 0 {
		// rand.Intn would panic here anyway; fail with a message that
		// points at the real cause.
		panic("invalid weights: non-positive total weight")
	}
	n := rand.Intn(v.weight)
	// Distinct names so the loop does not shadow the receiver v.
	for next, w := range v.weights {
		n -= w
		if n < 0 {
			return next
		}
	}
	panic("invalid weights?")
}
// getRandom returns a uniformly chosen ngram from s, or "" if s has none.
// generate uses it as the starting point, and getNext falls back to it when an
// ngram has no recorded successors.
//
// Go map iteration order is unspecified and not guaranteed to be uniformly
// random, so returning the first key of a range loop would favour some
// ngrams. Instead pick an index with rand.Intn and walk to it; this costs
// O(len(s.ngrams)) per call.
func (s ngramStat) getRandom() string {
	if len(s.ngrams) == 0 {
		return ""
	}
	i := rand.Intn(len(s.ngrams))
	for k := range s.ngrams {
		if i == 0 {
			return k
		}
		i--
	}
	return "" // unreachable: i < len(s.ngrams)
}
// add records one observed transition from ngram to nextNgram. A new ngram
// gets a fresh successor table seeded with nextNgram; an existing one has
// both the successor's count and its running total bumped by one.
func (s *ngramStat) add(ngram, nextNgram string) {
	variants, found := s.ngrams[ngram]
	if !found {
		s.ngrams[ngram] = &ngramVariants{
			weights: map[string]int{nextNgram: 1},
			weight:  1,
		}
		return
	}
	variants.weight++
	variants.weights[nextNgram]++
}
// getNext chooses the ngram that follows ngram, weighted by the recorded
// transitions. If ngram has no entry in the table, it falls back to
// s.getRandom().
func (s ngramStat) getNext(ngram string) string {
	options := s.ngrams[ngram]
	if options != nil {
		return options.get()
	}
	return s.getRandom()
}
// generate produces text from a random walk over the ngram transitions.
//
// It starts from s.getRandom() and then takes length steps. Each step moves
// to s.getNext(current) and appends the last rune of that ngram. The result
// is therefore the starting ngram followed by one rune per step. A stat with
// no ngrams yields "".
func (s ngramStat) generate(length int) string {
	if len(s.ngrams) == 0 {
		// Without this guard every step would decode an empty string and
		// append utf8.RuneError, producing length copies of U+FFFD.
		return ""
	}
	current := s.getRandom()

	// strings.Builder avoids the quadratic cost of += in the loop.
	var b strings.Builder
	b.WriteString(current)
	for i := 0; i < length; i++ {
		next := s.getNext(current)
		r, _ := utf8.DecodeLastRuneInString(next)
		b.WriteRune(r)
		current = next
	}
	return b.String()
}
// save writes every transition in s to the ngrams table inside a single
// transaction. It uses INSERT OR REPLACE (SQLite syntax), so an existing row
// that conflicts on the table's unique key is replaced with the current
// weight. Presumably that key is (ngram, nextNgram); verify against the schema.
//
// If any step fails, the transaction is rolled back and the error is returned.
func (s ngramStat) save(db *sql.DB) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("Begin(): %w", err)
	}
	// Roll back on every early return so the transaction does not keep its
	// pooled connection (and any database lock) forever. After a successful
	// Commit this returns sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	stmt, err := tx.Prepare("INSERT OR REPLACE INTO ngrams(ngram, nextNgram, weight) VALUES (?, ?, ?)")
	if err != nil {
		return fmt.Errorf("Prepare(): %w", err)
	}
	defer stmt.Close()

	for ngram, variants := range s.ngrams {
		for nextNgram, weight := range variants.weights {
			if _, err := stmt.Exec(ngram, nextNgram, weight); err != nil {
				return fmt.Errorf("inserting into db: %w", err)
			}
		}
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("Commit(): %w", err)
	}
	return nil
}
// dumpNgrams prints each ngram in stat, followed by its successors, one per
// tab-indented "next -> weight" line, and a blank line. It ends with a
// summary of how many ngrams there are. Lines follow map iteration order, so
// their order can differ between runs.
func dumpNgrams(stat ngramStat) {
	printSuccessors := func(variants *ngramVariants) {
		for next, count := range variants.weights {
			fmt.Printf("\t%s -> %d\n", next, count)
		}
	}
	for ngram, variants := range stat.ngrams {
		fmt.Printf("%s:\n", ngram)
		printSuccessors(variants)
		fmt.Println()
	}
	fmt.Printf("%d ngrams collected\n", len(stat.ngrams))
}