In [1]:
using LogProbs
using StatsFuns.RFunctions: gammarand
using SpecialFunctions: logbeta
import Base: +, -, *, /, zero, one, <, ==



#############################
### Dirichlet Multinomial ###
#############################

function categorical_sample(tokens, weights)
    T = eltype(weights)
    x = rand(T) * sum(weights)
    cum_weights = zero(T)
    for (t, w) in zip(tokens, weights)
        cum_weights += w
        if cum_weights > x
            return t
        end
    end
end

categorical_sample(d::Dict) = categorical_sample(keys(d), values(d))
categorical_sample(v::Vector) = categorical_sample(1:length(v), v)

abstract type Distribution{T} end

mutable struct DirCat{T, C} <: Distribution{T}
    counts :: Dict{T, C}
end

DirCat(support, priors) = DirCat(Dict(x => p for (x,p) in zip(support, priors)))
support(dc::DirCat) = keys(dc.counts)

function sample(dc::DirCat)
    weights = [gammarand(c, 1) for c in values(dc.counts)]
    categorical_sample(keys(dc.counts), weights)
end

function logscore(dc::DirCat, obs)
    LogProb(logbeta(sum(values(dc.counts)), 1) - logbeta(dc.counts[obs], 1); islog=true)
end

function add_obs!(dc::DirCat, obs)
    dc.counts[obs] += 1
end

function rm_obs!(dc::DirCat, obs)
    dc.counts[obs] -= 1
end


################################
### Conditional Distribution ###
################################

struct SimpleCond{C, D, S} # context, distribution, support
    dists   :: Dict{C, D}
    support :: S
    SimpleCond(dists::Dict{C, D}, support::S) where {C, D, S} =
        new{C, D, S}(dists, unique(support))
end

function SimpleCond(dists::AbstractDict)
    SimpleCond(
        dists,
        vcat([collect(support(dist)) for dist in values(dists)]...)
    )
end

sample(sc::SimpleCond, context, args...) = sample(sc.dists[context], args...)
logscore(sc::SimpleCond, obs, context) = logscore(sc.dists[context], obs)
rm_obs!(sc::SimpleCond, obs, context) = rm_obs!(sc.dists[context], obs)

score_type(::SimpleCond) = LogProb

function add_obs!(cond::SimpleCond{C,D,S}, obs, context) where {C,D,S}
    if !haskey(cond.dists, context)
        cond.dists[context] = D(cond.support)
    end
    add_obs!(cond.dists[context], obs)
end

###############
### CFRules ###
###############

mutable struct RunningCounter
    n :: Int
end

RunningCounter() = RunningCounter(0)
count!(c::RunningCounter) = c.n += 1

rule_counter = RunningCounter()

struct CFRules{LHS, RHS} # left hand side and right hand side of the rule
    mappings ::Dict{LHS, Vector{RHS}}
    name :: Symbol
end

==(r1::CFRules, r2::CFRules) = r1.name == r2.name
hash(r::CFRules, h::UInt) = hash(hash(CFRules, hash(r.name)), h)

Base.show(io::IO, r::CFRules) = print(io, "CFRules($(r.name))")

CFRules(pairs::Pair...) =
    CFRules(Dict(pairs...), Symbol("rules", count!(rule_counter)))
CFRules(g::Base.Generator) =
    CFRules(Dict(g), Symbol("rules", count!(rule_counter)))
CFRules(f::Function, lhss, name) =
    CFRules(Dict(lhs => f(lhs) for lhs in lhss), name)
CFRules(f::Function, lhss) =
    CFRules(Dict(lhs => f(lhs) for lhs in lhss), Symbol("rules", count!(rule_counter)))

lhss(r::CFRules) = keys(r.mappings) # aka domain
isapplicable(r::CFRules, lhs) = haskey(r.mappings, lhs)
(r::CFRules)(lhs) = r.mappings[lhs]

###############
### CFState ###
###############


mutable struct CompletionAutomaton{Cat,Comp} # category, completion
    transitions :: Vector{Dict{Cat, Int}}
    completions :: Vector{Vector{Comp}}
end

CompletionAutomaton(Cat::Type, Comp::Type) =
    CompletionAutomaton([Dict{Cat, Int}()], [Vector{Comp}()])

number_of_states(ca::CompletionAutomaton) = length(ca.transitions)
isfinal(ca::CompletionAutomaton, s) = isempty(ca.transitions[s])
is_possible_transition(ca::CompletionAutomaton, s, c) = haskey(ca.transitions[s], c)
transition(ca::CompletionAutomaton, s, c) = ca.transitions[s][c]
completions(ca::CompletionAutomaton, s) = ca.completions[s]

#Not sure what is going on here
function add_completion!(ca::CompletionAutomaton{Cat,Comp}, comp, categories) where {Cat,Comp}
    s = 1
    for c in categories
        if is_possible_transition(ca, s, c)
            s = transition(ca, s, c)
        else
            push!(ca.transitions, Dict{Cat,Int}())
            push!(ca.completions, Vector{Comp}())
            s = ca.transitions[s][c] = number_of_states(ca)
        end
    end
    push!(ca.completions[s], comp)
end

function add_rule!(ca::CompletionAutomaton, r::CFRules)
    for lhs in lhss(r)
        add_completion!(ca, (lhs, r), r(lhs))
    end
end

#################
### CFGrammar ###
#################

struct CFGrammar{C, T, Cond, F}
    comp_automtn  :: CompletionAutomaton{C, Tuple{C, CFRules{C, C}}}
    startsymbols  :: Vector{C}
    terminal_dict :: Dict{T, Vector{Tuple{C, CFRules{C, T}}}}
    cond          :: Cond # conditional scoring
    dependent_components::F
end

function CFGrammar(
        category_rules::Vector{CFRules{C, C}},
        terminal_rules::Vector{CFRules{C, T}},
        startsymbols  ::Vector{C},
        dependent_components=identity::Function
        ) where {C, T}
    comp_automtn = CompletionAutomaton(C, Tuple{C, CFRules{C, C}})
    for r in category_rules
        add_rule!(comp_automtn, r)
    end

    terminal_dict = Dict{T, Vector{Tuple{C, CFRules{C, T}}}}()
    for r in terminal_rules
        for lhs in lhss(r)
        t = r(lhs)[1]
            if haskey(terminal_dict, t)
                push!(terminal_dict[t], (lhs, r))
            else
                terminal_dict[t] = [(lhs, r)]
            end
        end
    end

    applicable_rules = Dict{C, Vector{CFRules}}()
    for r in CFRules[category_rules; terminal_rules]
        for c in lhss(r)
            if haskey(applicable_rules, c)
                push!(applicable_rules[c], r)
            else
                applicable_rules[c] = CFRules[r]
            end
        end
    end

    cond = SimpleCond(
        Dict(
            dependent_components(c) => let rules = applicable_rules[c]
                n = length(rules)
                k = count(isa.(rules, CFRules{C, T})) # number terminal rules
                DirCat(rules, [fill(1.0, n-k); fill(1/k, k)])
            end
            for c in keys(applicable_rules)
        )
    )

    CFGrammar(comp_automtn, startsymbols, terminal_dict, cond, dependent_components)
end

dependent_components(g::CFGrammar, c) = g.dependent_components(c)

startstate(g::CFGrammar) = 1
startsymbols(g::CFGrammar) = g.startsymbols

isfinal(g::CFGrammar, s) = isfinal(g.comp_automtn, s)
is_possible_transition(g::CFGrammar, s, c) = is_possible_transition(g.comp_automtn, s, c)
transition(g::CFGrammar, s, c) = transition(g.comp_automtn, s, c)

completions(g::CFGrammar, s::Int) =
    ((c, r, score(g, c, r)) for (c, r) in completions(g.comp_automtn, s))
completions(g::CFGrammar, t) =
    ((c, r, score(g, c, r)) for (c, r) in g.terminal_dict[t])

score(g::CFGrammar, c, r) = logscore(g.cond, r, dependent_components(g, c))

@inline function types(grammar::CFGrammar{C, T, Cond}) where {C, T, Cond}
    C, T, CFRules{C, C}, CFRules{C, T}, Int, LogProb
end



types (generic function with 1 method)

In [2]:
using DataStructures: PriorityQueue
import Base: length, insert!, isempty, convert, getindex, promote_rule, range
import DataStructures: enqueue!, dequeue!, peek

#sing .ParserTypes: Edge, Constituent, EdgeKey, ConstituentKey, IntervalRange, CyclicRange
#using .PCFGGrammar: CFGrammar, is_possible_transition, completions, transition, isfinal
###################
### Completions ###
###################

struct TerminalCompletion{T,TR,S}
    terminal :: T
    rule     :: TR
    score    :: S
end
terminal(comp::TerminalCompletion) = comp.terminal
rule(comp::TerminalCompletion) = comp.rule
score(comp::TerminalCompletion) = comp.score

struct EdgeCompletion{E,CR,S}
    edge   :: E
    rule   :: CR
    score  :: S
    inloop :: Bool
end
edge(comp::EdgeCompletion) = comp.edge
rule(comp::EdgeCompletion) = comp.rule
score(comp::EdgeCompletion) = comp.score
inloop(comp::EdgeCompletion) = comp.inloop
EdgeCompletion(edge, rule, score) = EdgeCompletion(edge, rule, score, false)
@inline function ==(c1::EdgeCompletion, c2::EdgeCompletion)
    c1.edge == c2.edge && c1.rule == c2.rule
end


#################
### Traversal ###
#################

struct Traversal{E,CO,S}
    edge   :: Union{E,Nothing}
    cont   :: CO
    score  :: S
    inloop :: Bool
end
Traversal(edge, cont, score) = Traversal(edge, cont, score, false)
Traversal(edge, cont) = Traversal(edge, cont, score(edge) * score(cont), false)
hasedge(trav::Traversal) = !isnothing(trav.edge)
#edge(trav::Traversal) = get(trav.edge)
edge(trav::Traversal) = trav.edge
cont(trav::Traversal) = trav.cont
score(trav::Traversal) = trav.score
inloop(trav::Traversal) = trav.inloop

@inline function ==(t1::Traversal, t2::Traversal)
    if hasedge(t1)
        if hasedge(t2)
            t1.edge == t2.edge && t1.cont == t2.cont
        else
            false
        end
    else
        if hasedge(t2)
            false
        else
            t1.cont == t2.cont
        end
    end
end


##############
### ModInt ###
##############

struct ModInt{n} <: Number
  val::Int
  ModInt{n}(val) where {n} = new(mod(val,n))
end

show(io::IO, a::ModInt{n}) where n = print(io, "$(a.val) mod $n")

+(a::ModInt{n}, b::ModInt{n}) where n = ModInt{n}(a.val + b.val)
-(a::ModInt{n}) where n = - a.val
-(a::ModInt{n}, b::ModInt{n}) where n = ModInt{n}(a.val - b.val)
*(a::ModInt{n}, b::ModInt{n}) where n = ModInt{n}(a.val * b.val)
/(a::ModInt{n}, b::ModInt{n}) where n = a * invmod(b, n)

<(a::ModInt{n}, b::ModInt{n}) where n = a.val < b.val

one(a::ModInt{n}) where n = ModInt{n}(1)
zero(a::ModInt{n}) where n = ModInt{n}(0)

convert(::Type{ModInt{n}}, x::Int) where n = ModInt{n}(x)
convert(::Type{Int}, x::ModInt) = x.val

getindex(t::Union{Tuple, Array}, i::ModInt) = getindex(t, i.val + 1)

promote_rule(::Type{ModInt{n}}, ::Type{Int}) where n = ModInt{n}


#############
### Range ###
#############

abstract type ItemRange end

ItemRange(s::Int, e::Int, n::Int, cyclic::Bool) =
    cyclic ? CyclicRange(s, e, n) : IntervalRange(s, e)

start(r::ItemRange) = r.start
_end(r::ItemRange)  = r._end

struct IntervalRange <: ItemRange
    start :: Int
    _end  :: Int
end

length(r::IntervalRange) = _end(r) - start(r)
concatenable(r1::IntervalRange, r2::IntervalRange) = _end(r1) == start(r2)

function *(r1::IntervalRange, r2::IntervalRange)
    @assert concatenable(r1, r2)
    IntervalRange(start(r1), _end(r2))
end

struct CyclicRange{n} <: ItemRange
    start  :: ModInt{n}
    _end   :: ModInt{n}
    length :: Int
end
CyclicRange(s::ModInt, e::ModInt) = CyclicRange(s, e, Int(e-s))
CyclicRange(s::Int, e::Int, n::Int) = CyclicRange(ModInt{n}(s), ModInt{n}(e))

length(r::CyclicRange) = r.length

@inline function concatenable(r1::CyclicRange{n}, r2::CyclicRange{n}) where n
    _end(r1) == start(r2) && length(r1) + length(r2) <= n
end

function *(r1::CyclicRange, r2::CyclicRange)
    @assert concatenable(r1, r2)
    CyclicRange(start(r1), _end(r2), length(r1) + length(r2))
end

###############
### ItemKey ###
###############

abstract type ItemKey{R} end

range(k::ItemKey)  =          k.range
start(k::ItemKey)  =  start(range(k))
_end(k::ItemKey)   =   _end(range(k))
length(k::ItemKey) = length(range(k))

struct EdgeKey{R,ST} <: ItemKey{R}
    range :: R
    state :: ST
end
state(k::EdgeKey) = k.state

struct ConstituentKey{R,C} <: ItemKey{R}
    range    :: R
    category :: C
end
category(k::ConstituentKey) = k.category

############
### Item ###
############

abstract type Item end

Item(key, trav::Traversal) = Edge(key, trav)
Item(key, comp::EdgeCompletion) = Constituent(key, comp)

key(item::Item) = item.key
range(item::Item) = range(key(item))
start(item::Item) = start(range(item))
_end(item::Item) = _end(range(item))
length(item::Item) = length(range(item))
isfinished(item::Item) = !(isnothing(item.score))
lastpopscore(item::Item) = item.lastpopscore
insidepopnumber(item::Item) = item.insidepopnumber

############
### Edge ###
############

mutable struct Edge{R,ST,S,CO} <: Item
    key             :: EdgeKey{R,ST}
    score           :: Union{S,Nothing}
    lastpopscore    :: S
    insidepopnumber :: Int
    traversals      :: Vector{Traversal{Edge{R,ST,S,CO},CO,S}}
end

@inline function Edge(key, trav::Traversal{E,CO,S}) where {E,CO,S}
    Edge(key, nothing, zero(S), 0, [trav])
end

state(edge::Edge) = state(key(edge))
traversals(edge::Edge) = edge.traversals

@inline function score(edge::Edge)
    if isfinished(edge)
        get(edge.score)
    else
        sum(score(trav) for trav in edge.traversals)
    end
end

function add!(edge::Edge, trav)
    found = false
    for (i, t) in enumerate(edge.traversals)
        if t==trav
            edge.traversals[i] = trav
            found = true
            break
        end
    end
    if !found
        push!(edge.traversals, trav)
    end
    nothing
end



add! (generic function with 1 method)

In [3]:
###################
### Constituent ###
###################

mutable struct Constituent{R,C,T,CR,TR,ST,S} <: Item
    key                 :: ConstituentKey{R,C}
    score               :: Union{S,Nothing}
    lastpopscore        :: S
    insidepopnumber     :: Int
    terminal_completion :: Union{TerminalCompletion{T,TR,S}, Nothing}
    completions         :: Vector{EdgeCompletion{Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}},CR,S}}
end

@inline function Constituent(
        key::ConstituentKey{R,C}, comp::TerminalCompletion, grammar
    ) where {R,C}
    C_, T, CR, TR, ST, S = types(grammar)
    Constituent(
        key, nothing, zero(S), 0, comp,
        Vector{EdgeCompletion{Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}},CR,S}}()
    )
end

@inline function Constituent(
        key,
        comp :: EdgeCompletion{Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}},CR,S}
    ) where {R,C,T,CR,TR,ST,S}
    Constituent(
        key, nothing, zero(S), 0,
        nothing,
        [comp]
    )
end

category(cont::Constituent) = category(key(cont))
completions(cont::Constituent) = cont.completions

hasterminal(cont::Constituent) = !(isnothing(cont.terminal_completion))
terminal_completion(cont::Constituent) = cont.terminal_completion
terminal(cont::Constituent) = terminal(cont.terminal_completion)
#terminal_completion(cont::Constituent) = get(cont.terminal_completion)
#terminal(cont::Constituent) = get(cont.terminal_completion).terminal

@inline function score(cont::Constituent)
    if isfinished(cont)
        get(cont.score)
    else
        if hasterminal(cont)
            if isempty(completions(cont))
                score(terminal_completion(cont))
            else
                +(
                    score(terminal_completion(cont)),
                    sum(score(comp) for comp in completions(cont))
                )
            end
        else
            sum(score(comp) for comp in completions(cont))
        end
    end
end

function add!(cont::Constituent, comp)
    found = false
    for (i, c) in enumerate(cont.completions)
        if c==comp
            cont.completions[i] = comp
            found = true
            break
        end
    end
    if !found
        push!(cont.completions, comp)
    end
    nothing
end



add! (generic function with 2 methods)

In [4]:
@inline function Traversal(cont::Constituent{R,C,T,CR,TR,ST,S}) where {R,C,T,CR,TR,ST,S}
    Traversal{Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}}, Constituent{R,C,T,CR,TR,ST,S}, S}(nothing, cont, score(cont), false)
end


Traversal

In [5]:

#####################
### ParserLogbook ###
#####################

struct ParserLogbook{R,C,T,CR,TR,ST,S}
    edges :: Dict{EdgeKey{R,ST}, Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}}}
    conts :: Dict{ConstituentKey{R,C}, Constituent{R,C,T,CR,TR,ST,S}}
end

@inline function ParserLogbook(grammar, n::Int, cyclic::Bool)
    R = cyclic ? CyclicRange{n} : IntervalRange
    C,T,CR,TR,ST,S = types(grammar)
    ParserLogbook(
        Dict{EdgeKey{R,ST}, Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}}}(),
        Dict{ConstituentKey{R,C}, Constituent{R,C,T,CR,TR,ST,S}}()
    )
end
discover!(logbook::ParserLogbook, edge::Edge) =
    logbook.edges[key(edge)] = edge
discover!(logbook::ParserLogbook, cont::Constituent) =
    logbook.conts[key(cont)] = cont
isdiscovered(logbook, key::EdgeKey) = haskey(logbook.edges, key)
isdiscovered(logbook, key::ConstituentKey) = haskey(logbook.conts, key)
getitem(logbook, key::EdgeKey) = logbook.edges[key]
getitem(logbook, key::ConstituentKey) = logbook.conts[key]

##################
### ParseChart ###
##################

struct ChartCell{R,C,T,CR,TR,ST,S}
    edges :: Dict{ST, Vector{Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}}}}
    conts :: Dict{C, Vector{Constituent{R,C,T,CR,TR,ST,S}}}
end
@inline function ChartCell(grammar, n::Int, cyclic::Bool)
    R = cyclic ? CyclicRange{n} : IntervalRange
    C,T,CR,TR,ST,S = types(grammar)
    ChartCell(
        Dict{ST, Vector{Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}}}}(),
        Dict{C, Vector{Constituent{R,C,T,CR,TR,ST,S}}}()
    )
end

struct ParseChart{R,C,T,CR,TR,ST,S}
    cells :: Vector{ChartCell{R,C,T,CR,TR,ST,S}}
end
# vector indices begin with 1
# item   indices begin with 0

edges(chart::ParseChart, edge::Edge)        = chart.cells[ _end(edge)+1].edges
edges(chart::ParseChart, cont::Constituent) = chart.cells[start(cont)+1].edges
conts(chart::ParseChart, edge::Edge)        = chart.cells[ _end(edge)+1].conts
conts(chart::ParseChart, cont::Constituent) = chart.cells[start(cont)+1].conts

@inline function push_or_init!(d::Dict, k, v)
    if haskey(d, k)
        push!(d[k], v)
    else
        d[k] = [v]
    end
end
insert!(chart::ParseChart, edge::Edge) =
    push_or_init!(edges(chart, edge), state(edge), edge)
insert!(chart::ParseChart, cont::Constituent) =
    push_or_init!(conts(chart, cont), category(cont), cont)

####################
### InsideAgenda ###
####################

struct InsideAgenda{R,C,T,CR,TR,ST,S}
    edge_queue :: PriorityQueue{Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}}, Int, Base.Order.ForwardOrdering}
    cont_queue :: PriorityQueue{Constituent{R,C,T,CR,TR,ST,S}, Int, Base.Order.ForwardOrdering}
end
function InsideAgenda(grammar, n::Int, cyclic)
    R = cyclic ? CyclicRange{n} : IntervalRange
    C,T,CR,TR,ST,S = types(grammar)
    InsideAgenda(
        PriorityQueue{Edge{R,ST,S,Constituent{R,C,T,CR,TR,ST,S}}, Int}(),
        PriorityQueue{Constituent{R,C,T,CR,TR,ST,S}, Int}()
    )
end
@inline function enqueue!(agenda::InsideAgenda, edge::Edge, just_used)
    agenda.edge_queue[edge] = priority(edge, just_used)
end
@inline function enqueue!(agenda::InsideAgenda, cont::Constituent, just_used)
    agenda.cont_queue[cont] = priority(cont, just_used)
end
@inline function next_is_edge(agenda::InsideAgenda)
    isempty(agenda.cont_queue) || !isempty(agenda.edge_queue) && peek(agenda.edge_queue)[2] < peek(agenda.cont_queue)[2]
end
dequeue_edge!(agenda::InsideAgenda)    = dequeue!(agenda.edge_queue)
dequeue_cont!(agenda::InsideAgenda)    = dequeue!(agenda.cont_queue)
isempty(agenda::InsideAgenda)          = isempty(agenda.edge_queue) && isempty(agenda.cont_queue)
priority(edge::Edge, just_used)        = 4 * length(edge) - 2*!(just_used) - 1
priority(cont::Constituent, just_used) = 4 * length(cont) - 2*!(just_used)




priority (generic function with 2 methods)

In [6]:
######################
### Parser Methods ###
######################

function create_or_update!(key, trav_or_comp, agenda, logbook)
    if isdiscovered(logbook, key)
        item = getitem(logbook, key)
        add!(item, trav_or_comp)
    else
        item = Item(key, trav_or_comp)
        discover!(logbook, item)
    end
    enqueue!(agenda, item, false)
    nothing
end

@noinline function initialize(terminals, grammar, epsilon, cyclic)
    n       = length(terminals)
    chart   = ParseChart([ChartCell(grammar, n, cyclic) for i in 0:(cyclic ? n-1 : n)])
    agenda  = InsideAgenda(grammar, n, cyclic)
    logbook = ParserLogbook(grammar, n, cyclic)

    for (i, terminal) in enumerate(terminals)
        for (category, rule, score) in completions(grammar, terminal)
            cont = Constituent(
                ConstituentKey(ItemRange(i-1, i, n, cyclic), category),
                TerminalCompletion(terminal, rule, score),
                grammar
            )
            discover!(logbook, cont)
            enqueue!(agenda, cont, false)
        end
        if !ismissing(epsilon)
            for (category, rule, score) in completions(grammar, epsilon)
                cont = Constituent(
                    ConstituentKey(ItemRange(i-1, i-1, n, cyclic), category),
                    TerminalCompletion(epsilon, rule, score),
                    grammar
                )
                discover!(logbook, cont)
                enqueue!(agenda, cont, false)
            end
        end
    end
    if !ismissing(epsilon) && !cyclic
        for (category, rule, score) in completions(grammar, epsilon)
            cont = Constituent(
                ConstituentKey(ItemRange(n, n, n, cyclic), category),
                TerminalCompletion(epsilon, rule, score),
                grammar
            )
            discover!(logbook, cont)
            enqueue!(agenda, cont, false)
        end
    end
    chart, agenda, logbook
end

@inline function do_fundamental_rule!(
        edge::Edge, chart, agenda, logbook, grammar, cyclic
    )
    for category in keys(conts(chart, edge))
        if is_possible_transition(grammar, state(edge), category)
            for cont in conts(chart, edge)[category]
                if !cyclic || concatenable(range(edge), range(cont))
                    trav      = Traversal(edge, cont)
                    new_state = transition(grammar, state(edge), category)
                    key       = EdgeKey(range(edge) * range(cont), new_state)
                    create_or_update!(key, trav, agenda, logbook)
                end
            end
        end
    end
    nothing
end

@inline function do_fundamental_rule!(
        cont::Constituent, chart, agenda, logbook, grammar, cyclic
    )
    for state in keys(edges(chart, cont))
        if is_possible_transition(grammar, state, category(cont))
            for edge in edges(chart, cont)[state]
                if !cyclic || concatenable(range(edge), range(cont))
                    trav      = Traversal(edge, cont)
                    new_state = transition(grammar, state, category(cont))
                    key       = EdgeKey(range(edge) * range(cont), new_state)
                    create_or_update!(key, trav, agenda, logbook)
                end
            end
        end
    end
    nothing
end

@inline function introduce_edge!(cont, agenda, logbook, grammar)
    if is_possible_transition(grammar, startstate(grammar), category(cont))
        state = transition(grammar, startstate(grammar), category(cont))
        key   = EdgeKey(range(cont), state)
        create_or_update!(key, Traversal(cont), agenda, logbook)
    end
    nothing
end

@noinline function complete_edge!(edge, agenda, logbook::ParserLogbook, grammar)
    C, T, CR, TR, ST, S = types(grammar) #added
    for (category::C, rule::CR, s::S) in completions(grammar, state(edge))
        key  = ConstituentKey(range(edge), category)
        comp = EdgeCompletion(edge, rule, score(edge) * s)
        create_or_update!(key, comp, agenda, logbook)
    end
    nothing
end

@noinline function process_edge!(
        edge, chart, agenda, logbook, grammar, max_pop_num, cyclic
    )
    s = score(edge)
    edge.insidepopnumber += 1
    if s ≈ lastpopscore(edge) || insidepopnumber(edge) == max_pop_num
        if !isfinal(grammar, state(edge))
            insert!(chart, edge)
        end
        edge.score = nothing
        #edge.score = Nullable(s) # finish the edge
        do_fundamental_rule!(edge, chart, agenda, logbook, grammar, cyclic)
    else
        complete_edge!(edge, agenda, logbook, grammar)
        edge.lastpopscore = s
        enqueue!(agenda, edge, true)
    end
    nothing
end

@noinline function process_cont!(
        cont, chart, agenda, logbook, grammar, max_pop_num, cyclic
    )
    s = score(cont)
    cont.insidepopnumber += 1
    if s ≈ lastpopscore(cont) || insidepopnumber(cont) == max_pop_num
        insert!(chart, cont)
        cont.score = nothing
        #cont.score = Nullable(s) # finish the constituent
        do_fundamental_rule!(cont, chart, agenda, logbook, grammar, cyclic)
    else
        introduce_edge!(cont, agenda, logbook, grammar)
        cont.lastpopscore = s
        enqueue!(agenda, cont, true)
    end
    nothing
end

@noinline function loop!(chart, agenda, args...)
    while !isempty(agenda)
        if next_is_edge(agenda)
            process_edge!(dequeue_edge!(agenda), chart, agenda, args...)
        else
            process_cont!(dequeue_cont!(agenda), chart, agenda, args...)
        end
    end
end


loop! (generic function with 1 method)

In [7]:
ascend = CFRules(1:9) do i
    [i, i+1]
end
double = CFRules(1:10) do i
    [i, i]
end
terminate = CFRules(1:10) do i
    [string(i)]
end

CFRules(rules3)

In [8]:
grammar = CFGrammar([ascend, double], [terminate], [1])

CFGrammar{Int64,String,SimpleCond{Int64,DirCat{CFRules{Int64,RHS} where RHS,Float64},Array{CFRules{Int64,RHS} where RHS,1}},typeof(identity)}(CompletionAutomaton{Int64,Tuple{Int64,CFRules{Int64,Int64}}}(Dict{Int64,Int64}[Dict(7 => 2,4 => 4,9 => 6,10 => 23,2 => 8,3 => 10,5 => 12,8 => 14,6 => 16,1 => 18…), Dict(7 => 20,8 => 3), Dict(), Dict(4 => 21,5 => 5), Dict(), Dict(9 => 22,10 => 7), Dict(), Dict(2 => 25,3 => 9), Dict(), Dict(4 => 11,3 => 26)  …  Dict(), Dict(), Dict(10 => 24), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict()], Array{Tuple{Int64,CFRules{Int64,Int64}},1}[[], [], [(7, CFRules(rules1))], [], [(4, CFRules(rules1))], [], [(9, CFRules(rules1))], [], [(2, CFRules(rules1))], []  …  [(4, CFRules(rules2))], [(9, CFRules(rules2))], [], [(10, CFRules(rules2))], [(2, CFRules(rules2))], [(3, CFRules(rules2))], [(5, CFRules(rules2))], [(8, CFRules(rules2))], [(6, CFRules(rules2))], [(1, CFRules(rules2))]]), [1], Dict("8" => [(8, CFRules(rules3))],"4" => [(4, CFRules(rules3))]

In [9]:
ascend.mappings

Dict{Int64,Array{Int64,1}} with 9 entries:
  7 => [7, 8]
  4 => [4, 5]
  9 => [9, 10]
  2 => [2, 3]
  3 => [3, 4]
  5 => [5, 6]
  8 => [8, 9]
  6 => [6, 7]
  1 => [1, 2]

In [10]:
double.mappings

Dict{Int64,Array{Int64,1}} with 10 entries:
  7  => [7, 7]
  4  => [4, 4]
  9  => [9, 9]
  10 => [10, 10]
  2  => [2, 2]
  3  => [3, 3]
  5  => [5, 5]
  8  => [8, 8]
  6  => [6, 6]
  1  => [1, 1]

In [11]:
terminate.mappings

Dict{Int64,Array{String,1}} with 10 entries:
  7  => ["7"]
  4  => ["4"]
  9  => ["9"]
  10 => ["10"]
  2  => ["2"]
  3  => ["3"]
  5  => ["5"]
  8  => ["8"]
  6  => ["6"]
  1  => ["1"]

In [12]:
#prob(p::LogProb) = p

###################
### ParseForest ###
###################

struct ParseForest{R,C,T,CR,TR,ST,S}
    heads     :: Vector{Constituent{R,C,T,CR,TR,ST,S}}
    terminals :: Vector{T}
end


In [13]:
function ParseForest(chart::ParseChart, terminals, grammar, cyclic)
    if cyclic
        ParseForest(
            vcat(
                [
                    vcat(
                        map(collect(keys(cell.conts))) do category
                            filter(cell.conts[category]) do cont
                                length(cont) == length(terminals) &&
                                    category in startsymbols(grammar)
                            end
                        end...
                    )
                    for cell in chart.cells
                ]...
            )
            ,
            terminals
        )
    else
        ParseForest(
            vcat(
                map(collect(keys(chart.cells[1].conts))) do category
                    filter(chart.cells[1].conts[category]) do cont
                        length(cont) == length(terminals) &&
                            category in startsymbols(grammar)
                    end
                end...
            ),
            terminals
        )
    end
end


ParseForest

In [14]:
heads(forest::ParseForest) = forest.heads
iscomplete(forest::ParseForest) = !isempty(forest.heads)
score(forest::ParseForest) = sum(score(h) for h in forest.heads)

score (generic function with 7 methods)

In [15]:
@noinline function run_chartparser(
        terminals, grammar; epsilon=missing, max_pop_num=4, cyclic=false
    )
    C, T, CR, TR, ST, S = types(grammar)
    chart, agenda, logbook = initialize(
        T.(terminals), grammar, epsilon, cyclic)
    loop!(chart, agenda, logbook, grammar, max_pop_num, cyclic)
    ParseForest(chart, T.(terminals), grammar, cyclic)
end

run_chartparser (generic function with 1 method)

In [18]:
forest = run_chartparser(["1","1", "1"], grammar)

ParseForest{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}(Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}[Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}(ConstituentKey{IntervalRange,Int64}(IntervalRange(0, 3), 1), nothing, LogProb(0.008230452674897129), 2, nothing, EdgeCompletion{Edge{IntervalRange,Int64,LogProb,Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}},CFRules{Int64,Int64},LogProb}[EdgeCompletion{Edge{IntervalRange,Int64,LogProb,Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}},CFRules{Int64,Int64},LogProb}(Edge{IntervalRange,Int64,LogProb,Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}}(EdgeKey{IntervalRange,Int64}(IntervalRange(0, 3), 30), nothing, LogProb(0.024691358024691374), 2, Traversal{Edge{

In [19]:
typeof(forest.heads[1])



Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}

In [20]:
#################
### ParseTree ###
#################

struct ParseTree{C,T,CR,TR,ST,S,R}
    head  :: Constituent{C,T,CR,TR,ST,S}
    rules :: Vector{R}
end

In [21]:


#probscore = prob ∘ score

function sample_tree(forest::ParseForest)
    @assert iscomplete(forest)
    head = categorical_sample(heads(forest), score.(heads(forest)))
    ParseTree(head, sample_tree(head))
end

function sample_conts(edge::Edge)
    trav = categorical_sample(traversals(edge), score.(traversals(edge)))
    if hasedge(trav)
        [sample_conts(trav.edge); cont(trav)]
    else
        [cont(trav)]
    end
end

function sample_tree(cont::Constituent{R,C,T,CR,TR,ST,S}) where {R,C,T,CR,TR,ST,S}
    if hasterminal(cont)
        if isempty(completions(cont))
            comp = terminal_completion(cont)
            [(category(cont), rule(comp))]
        else
            comp = categorical_sample(
                [terminal_completion(cont); completions(cont)],
                [score(terminal_completion(cont)); score.(completions(cont))]
            )
            if comp isa TerminalCompletion
                [(category(cont), rule(comp))]
            else
                conts = sample_conts(edge(comp))
                [(category(cont), rule(comp)); vcat([sample_tree(cont) for cont in conts]...)]
            end
        end
    else
        comp = categorical_sample(completions(cont), score.(completions(cont)))
        conts = sample_conts(edge(comp))
        [(category(cont), rule(comp)); vcat([sample_tree(cont) for cont in conts]...)]
    end
end

best_choice(tokens, weights) = tokens[findmax(weights)[2]]

function best_tree(forest::ParseForest)
    @assert iscomplete(forest)
    head = best_choice(heads(forest), score.(heads(forest)))
    ParseTree(head, best_tree(head))
end

function best_conts(edge::Edge)
    trav = best_choice(traversals(edge), score.(traversals(edge)))
    if hasedge(trav)
        [best_conts(get(trav.edge)); trav.cont]
    else
        [cont(trav)]
    end
end

function best_tree(cont::Constituent{R,C,T,CR,TR,ST,S}) where {R,C,T,CR,TR,ST,S}
    if hasterminal(cont)
        if isempty(completions(cont))
            comp = terminal_completion(cont)
            [(category(cont), rule(comp))]
        else
            comp = best_choice(
                [terminal_completion(cont); completions(cont)],
                [probscore(terminal_completion(cont)); score.(completions(cont))]
            )
            if comp isa TerminalCompletion
                [(category(cont), rule(comp))]
            else
                conts = best_conts(edge(comp))
                [(category(cont), rule(comp)); vcat([best_tree(cont) for cont in conts]...)]
            end
        end
    else
        comp = best_choice(completions(cont), score.(completions(cont)))
        conts = best_conts(edge(comp))
        [(category(cont), rule(comp)); vcat([best_tree(cont) for cont in conts]...)]
    end
end



best_tree (generic function with 2 methods)

In [26]:
t = sample_tree(forest)

ParseTree{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,Tuple{Int64,CFRules{Int64,RHS} where RHS}}(Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}(ConstituentKey{IntervalRange,Int64}(IntervalRange(0, 3), 1), nothing, LogProb(0.008230452674897129), 2, nothing, EdgeCompletion{Edge{IntervalRange,Int64,LogProb,Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}},CFRules{Int64,Int64},LogProb}[EdgeCompletion{Edge{IntervalRange,Int64,LogProb,Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}},CFRules{Int64,Int64},LogProb}(Edge{IntervalRange,Int64,LogProb,Constituent{IntervalRange,Int64,String,CFRules{Int64,Int64},CFRules{Int64,String},Int64,LogProb}}(EdgeKey{IntervalRange,Int64}(IntervalRange(0, 3), 30), nothing, LogProb(0.024691358024691374), 2, Traversal{Edge{IntervalRange,Int64,LogProb,Constituent{IntervalRange,Int64,Strin

In [22]:
include("../generalized-chart-parser/Trees.jl")
using .Trees: Tree, EmptyTree, TreeNode,
       isterminal, insert_child!,
       tree

In [23]:
using LightGraphs: Graph, add_vertex!, add_edge!, nv

import TikzGraphs: plot

In [27]:
t.rules

5-element Array{Tuple{Int64,CFRules{Int64,RHS} where RHS},1}:
 (1, CFRules(rules2))
 (1, CFRules(rules2))
 (1, CFRules(rules3))
 (1, CFRules(rules3))
 (1, CFRules(rules3))

In [25]:
t.head

UndefVarError: UndefVarError: t not defined

In [30]:
root = TreeNode(category(t.head), Any)

TreeNode{Any}(1, EmptyTree{Any}(), TreeNode{Any}[])

In [28]:
root.data[1]

1

In [35]:
function foo!(nodes, rules)
    if isempty(rules)
        nothing
    else
        c, rule = rules[1]
        node = nodes[1]
        category = node.data
        if isapplicable(rule, category)
            for c in rule(category)
                insert_child!(node, c)
            end
            foo!([node.children; nodes[2:end]], rules[2:end])
        else
            foo!(nodes[2:end], rules)
        end
    end
end

foo! (generic function with 1 method)

In [36]:
root = TreeNode(category(t.head), Any)
foo!([root], t.rules)

In [37]:
root

TreeNode{Any}(1, EmptyTree{Any}(), TreeNode{Any}[TreeNode{Any}(1, TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[TreeNode{Any}(1, TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[TreeNode{Any}("1", TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[])]), TreeNode{Any}(1, TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[TreeNode{Any}("1", TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[])])]), TreeNode{Any}(1, TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[TreeNode{Any}("1", TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[])])])

In [33]:
function convert(t::ParseTree{R,C,T,CR,TR,ST,S}) where {R,C,T,CR,TR,ST,S}
    function foo!(nodes, rules)
        if isempty(rules)
            nothing
        else
            c, rule = rules[1]
            node = nodes[1]
            category = node.data
            if isapplicable(rule, category)
                for c in rule(category)
                    insert_child!(node, c)
                end
                foo!([node.children; nodes[2:end]], rules[2:end])
            else
                foo!(nodes[2:end], rules)
            end
        end
    end
    root = TreeNode(category(t.head), Any)
    foo!([root], t.rules)
    root
end

function show(io::IO, t::ParseTree)
    print(io, map(first, convert(t), uniform_type=false))
end

show (generic function with 2 methods)

In [38]:
t_tree = convert(t)

TreeNode{Any}(1, EmptyTree{Any}(), TreeNode{Any}[TreeNode{Any}(1, TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[TreeNode{Any}(1, TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[TreeNode{Any}("1", TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[])]), TreeNode{Any}(1, TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[TreeNode{Any}("1", TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[])])]), TreeNode{Any}(1, TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[TreeNode{Any}("1", TreeNode{Any}(#= circular reference @-3 =#), TreeNode{Any}[])])])

In [45]:
latex_replacements = Dict(
    "^" => "\\^",
    "#" => "\\#")

function latex_replace(str)
    s = str
    for k in enumerate(keys(latex_replacements))
        s = replace(s, latex_replacements[k])
    end
    s
end

latex_replace (generic function with 1 method)

In [46]:


function tree_to_graph_and_labels(tree)
    graph = Graph(1)
    labels = [string('$', latex_replace(string(tree.data)), '$')]
    for (i, c) in enumerate(children(tree))
        add_vertex!(graph)
        add_edge!(graph, 1, nv(graph))
        tree_to_graph_and_labels!(c, graph, labels)
    end
    graph, labels
end

function tree_to_graph_and_labels!(tree, graph, labels)
    n = nv(graph)
    push!(labels, string('$', latex_replace(string(tree.data)), '$'))
    for (i, c) in enumerate(children(tree))
        add_vertex!(graph)
        add_edge!(graph, n, nv(graph))
        tree_to_graph_and_labels!(c, graph, labels)
    end
    graph, labels
end

plot(tree::Tree) = plot(tree_to_graph_and_labels(tree)...)
plot(tree::ParseTree) = plot(map(first, convert(tree), uniform_type=false))


plot (generic function with 6 methods)

In [47]:
graph, labels = tree_to_graph_and_labels(t_tree)

MethodError: MethodError: objects of type Dict{String,String} are not callable