Skip to content

Commit

Permalink
refactor: SamplerSelector has no access to logits ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
gsuuon committed Sep 12, 2023
1 parent e6aeb97 commit f755f93
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/sample.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export type CpuNDArray = {
/**
* Select relevant tokens -- called once for each token generation
*/
type SamplerSelector = (cpuLogits: CpuNDArray, tokens: number[], completion: string) => number[]
type SamplerSelector = (tokens: number[], completion: string) => number[]
// NOTE that since we return relevant tokens as an array, duplicates will have their bias applied twice (no change with gates)
// I'm considering this a feature for now as a way to weigh relevant tokens, but consider changing number[] to a Set.

Expand Down Expand Up @@ -74,7 +74,7 @@ export const buildBias = (model: Model): Bias => {
return (cpuLogits, tokens, completion) => {
const logits = cpuLogits.data

const relevantTokens = selector(cpuLogits, tokens, completion)
const relevantTokens = selector(tokens, completion)

if (relevantTokens.length > 0) {
console.debug('penalize', {
Expand Down Expand Up @@ -109,7 +109,7 @@ export const buildBias = (model: Model): Bias => {
return (cpuLogits, tokens, completion) => {
const logits = cpuLogits.data

const relevantTokens = selector(cpuLogits, tokens, completion)
const relevantTokens = selector(tokens, completion)

if (relevantTokens.length > 0) {
const modified = logits.toArray().map(adjust(relevantTokens))
Expand Down Expand Up @@ -195,7 +195,7 @@ export const oneOf: (items: string[]) => CreateSamplerTemplate = items => model
}
)

return (_logits, tokens, _completions) => {
return (tokens, _completions) => {
const filtered = encoded.filter(item => arrayStartsWith(tokens, item) && item.length > tokens.length)
const nextRelevantTokens = filtered.map(x => x[tokens.length])

Expand Down Expand Up @@ -253,7 +253,7 @@ export const consistsOf = (chars: string[]) => (model: Model) => (priorCompletio

const encoded = Array.from(new Set(encodedTokens))

return (_, tokens, completion) => {
return (tokens, completion) => {
console.debug('consistsOf', {
tokens: [...tokens],
tokensChars: model.tokenizer.decode(new Int32Array(tokens)),
Expand Down

0 comments on commit f755f93

Please sign in to comment.