In [11]:

import (
	"encoding/json"
	"errors"
	"io/ioutil"
	"strings"
	"unicode/utf8"
)

const GPT2_EOT int32 = 50256

type Tokenizer struct {
	Encoder     map[string]int32 `json:"encoder"`
	BpeRanks    map[string]int   `json:"bpe_ranks"`
	SpecialTokens map[string]int32 `json:"special_tokens"`
	decoder     map[int32]string
}

func NewTokenizer(filename string) (*Tokenizer, error) {
	data, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, err
	}

	var t Tokenizer
	err = json.Unmarshal(data, &t)
	if err != nil {
		return nil, err
	}

	t.decoder = make(map[int32]string)
	for token, id := range t.Encoder {
		t.decoder[id] = token
	}

	return &t, nil
}

func (t *Tokenizer) Encode(text string) ([]int32, error) {
	if t.Encoder == nil {
		return nil, errors.New("tokenizer not initialized")
	}

	var tokens []int32
	for len(text) > 0 {
		i := len(text)
		for i > 0 && !utf8.ValidString(text[:i]) {
			i--
		}
		if i == 0 {
			return nil, errors.New("invalid utf-8 string")
		}
		token := text[:i]
		text = text[i:]

		if id, ok := t.Encoder[token]; ok {
			tokens = append(tokens, id)
		} else {
			bpeToken := t.bpe(token)
			for _, bpeSubToken := range strings.Split(bpeToken, " ") {
				if id, ok := t.Encoder[bpeSubToken]; ok {
					tokens = append(tokens, id)
				} else {
					return nil, errors.New("unknown token: " + bpeSubToken)
				}
			}
		}
	}

	return tokens, nil
}

func (t *Tokenizer) Decode(tokens []int32) (string, error) {
	if t.decoder == nil {
		return "", errors.New("tokenizer not initialized")
	}

	var text strings.Builder
	for _, token := range tokens {
		if token == GPT2_EOT {
			continue
		}
		if str, ok := t.decoder[token]; ok {
			text.WriteString(str)
		} else {
			return "", errors.New("unknown token ID")
		}
	}

	return text.String(), nil
}

func (t *Tokenizer) bpe(token string) string {
	pairs := getPairs(token)
	if len(pairs) == 0 {
		return token
	}

	for {
		minPair := ""
		minRank := int(^uint(0) >> 1) // Max int

		for _, pair := range pairs {
			if rank, ok := t.BpeRanks[pair]; ok {
				if rank < minRank {
					minPair = pair
					minRank = rank
				}
			}
		}

		if minPair == "" {
			break
		}

		parts := strings.Split(minPair, ",")
		if len(parts) != 2 {
			break
		}
		first, second := parts[0], parts[1]
		newToken := strings.ReplaceAll(token, first+second, first+"\u0000"+second)
		token = strings.ReplaceAll(newToken, "\u0000", "")

		if !strings.Contains(token, " ") {
			break
		}

		pairs = getPairs(token)
	}

	return token
}

func getPairs(word string) []string {
	pairs := []string{}
	chars := strings.Split(word, "")
	for i := 0; i < len(chars)-1; i++ {
		pairs = append(pairs, chars[i]+","+chars[i+1])
	}
	return pairs
}

In [19]:
func main(){
tokenizer, err := NewTokenizer("./tokenizer.json")
if err != nil {
    panic(err)
}
gonbui.RequestInput("Tokenize some text: ", false)
// reader := bufio.NewReader(os.Stdin)
// _, err := reader.ReadString('\n')
// if err != nil {
//     panic(err)
// }
if err != nil { panic(err) }
encoded, err := tokenizer.Encode("hello there")
fmt.Println("encoded: ", encoded)
decoded, err := tokenizer.Decode(encoded)
fmt.Println("decoded: ", decoded)
}

encoded:  []
decoded:  


Tokenize some text:  hwllo
