In [1]:
using Model
using Dictionaries, SplitApplyCombine
using Distributions
using LogExpFunctions: softmax
using FillArrays
using LinearAlgebra
using TransformVariables
using StructArrays
using CairoMakie

In [2]:
# Load the decision records; they are passed to the
# MixedMembershipCategoricalModel constructor further down.
# The trailing `;` suppresses the notebook's echo of the result.
decisions = loaddata("../../../data/processed/json_augmented");

In [3]:
# Data container for a categorical outcome model in which each decision is
# taken by a panel of judges. Instances are callable (see the functor method)
# and return the unnormalised log posterior.
struct MixedMembershipCategoricalModel <: Model.AbstractDecisionModel
    # Outcome of each decision, stored as an index into `labels`.
    ys::Vector{Int}
    # For each decision, the ids of the judges who sat on it; these ids are
    # used directly as column indices into the coefficient matrix β.
    js::Vector{Vector{Int}}
    # Number of distinct judge ids; also the number of columns of β.
    J::Int
    # Outcome labels; `ys[i]` refers to `labels[ys[i]]`.
    labels::Vector{String}
end

"""
    MixedMembershipCategoricalModel(decisions)

Build the model data from a vector of `Decision`s.

Each decision's outcome label is encoded as its position in the fixed list
`["annulled", "claim dismissed", "partially annulled"]`, and the ids of its
judges are collected into `js`. `J` is the number of distinct judge ids.

# Throws
- `ArgumentError` if a decision carries an outcome label that is not in the
  list above (previously this surfaced as an opaque conversion error).
- `ArgumentError` if any judge id lies outside `1:J`. The likelihood indexes
  the coefficient matrix by judge id, so ids must be exactly `1, …, J`.
"""
function MixedMembershipCategoricalModel(decisions::AbstractVector{<:Decision})
    outcomes = ["annulled", "claim dismissed", "partially annulled"]

    ys = map(decisions) do d
        lbl = label(outcome(d))
        idx = findfirst(==(lbl), outcomes)
        if isnothing(idx)
            throw(ArgumentError(
                "unknown outcome label $(repr(lbl)); expected one of $(outcomes)",
            ))
        end
        idx
    end

    js = map(d -> id.(judges(d)), decisions)
    J = length(unique(reduce(vcat, js)))

    # Judge ids double as column indices of the (·, J) coefficient matrix, so
    # fail early with a clear message instead of a BoundsError while sampling.
    if !all(j -> 1 <= j <= J, Iterators.flatten(js))
        throw(ArgumentError(
            "judge ids must be the contiguous range 1:$(J) to index the coefficient matrix",
        ))
    end

    return MixedMembershipCategoricalModel(ys, js, J, outcomes)
end;

In [4]:
"""
    loglikelihood(θ; problem)

Return the log-likelihood of the observed outcomes `problem.ys` given the
coefficient matrix `θ.β`.

For observation `i` the linear predictor is the sum of the columns of `β`
belonging to the judges in `problem.js[i]`. The first outcome is the reference
category with its logit fixed at `0`; the remaining logits are the entries of
that sum. The categorical log-probability is evaluated directly in log space
(log-sum-exp), so widely separated logits yield a finite value instead of
`log(0) = -Inf`, and no distribution object is built per observation.

# Arguments
- `θ`: object with a property `β`, a matrix with one column per judge id.

# Keywords
- `problem`: object with properties `ys` (outcome indices) and `js`
  (judge ids per observation).

# Returns
The sum of the per-observation log-probabilities. An outcome index outside
`1:size(β, 1) + 1` contributes `-Inf`, as before.
"""
function loglikelihood(θ; problem)
    (; ys, js) = problem
    (; β) = θ

    return sum(eachindex(ys)) do i
        # Linear predictor of observation i: sum of its judges' columns.
        ηᵢ = sum(@views β[:, j] for j in js[i])
        _reference_logit_logpdf(ηᵢ, ys[i])
    end
end

# Log-probability of category `y` under a softmax over the logits `[0; η]`,
# computed with the log-sum-exp trick for numerical stability.
function _reference_logit_logpdf(η, y)
    # Shift by the largest logit (including the implicit reference logit 0)
    # so that no exponential can overflow.
    m = max(zero(eltype(η)), maximum(η))
    logZ = m + log(exp(-m) + sum(x -> exp(x - m), η))

    # Outside the support: keep the -Inf that logpdf(Categorical(...), y) gave.
    1 <= y <= length(η) + 1 || return oftype(logZ, -Inf)

    logit = y == 1 ? zero(logZ) : η[y - 1]
    return logit - logZ
end

"""
    logprior(θ)

Return the log prior density of `θ.β`: every column of `β` is given a
multivariate normal prior with zero mean and identity covariance, and the
column log-densities are added up.
"""
function logprior(θ)
    coefs = θ.β
    nrows = size(coefs, 1)
    prior = MvNormal(Zeros(nrows), I)
    return sum(col -> logpdf(prior, col), eachcol(coefs))
end

logprior (generic function with 1 method)

In [5]:
# Calling a model instance on a parameter object `θ` returns the unnormalised
# log posterior: the log-likelihood of the model's data plus the log prior.
function (problem::MixedMembershipCategoricalModel)(θ)
    ll = loglikelihood(θ; problem = problem)
    lp = logprior(θ)
    return ll + lp
end

In [6]:
"""
    Model.transformation(problem::MixedMembershipCategoricalModel)

Return the transformation describing the model's parameters: a named tuple
with a single unconstrained real matrix `β` holding one column per judge and
one row per non-reference outcome.

The row count is derived from `problem.labels` rather than hard-coded, because
the likelihood fixes the first outcome's logit at zero and therefore needs
`length(labels) - 1` free logits. For the three outcome labels set by the
constructor this is the same `(2, J)` shape as before.
"""
function Model.transformation(problem::MixedMembershipCategoricalModel)
    nfree = length(problem.labels) - 1
    return as((β = as(Array, asℝ, (nfree, problem.J)),))
end

In [7]:
# Build the model data (outcome indices, per-decision judge ids, judge count,
# outcome labels) from the loaded decisions.
problem = MixedMembershipCategoricalModel(decisions)

MixedMembershipCategoricalModel([3, 3, 3, 2, 3, 1, 2, 1, 3, 3  …  2, 2, 1, 3, 3, 1, 3, 3, 1, 2], [[54, 184, 61, 110, 200], [54, 6, 36, 76, 200], [168, 184, 151, 6, 61], [168, 184, 61, 110, 151], [169, 6, 36, 76, 54], [168, 151, 6, 36, 76], [169, 6, 36, 76, 54], [169, 184, 61, 54, 110], [100, 84, 9, 34, 80], [51, 139, 37, 173, 110]  …  [33, 123, 188, 4, 115], [33, 188, 4, 115, 180], [33, 188, 166, 3, 29], [33, 123, 188, 4, 115], [140, 49, 190, 167, 158], [140, 179, 158, 30, 199], [139, 7, 158, 30, 68], [140, 139, 7, 30, 68], [140, 158, 35, 190, 167], [140, 139, 35, 190, 167]], 200, ["annulled", "claim dismissed", "partially annulled"])

In [10]:
# Sample from the model's posterior.
# NOTE(review): 500 and 4 are presumably the number of draws and the number of
# chains — confirm against the signature of Model.sample.
post = Model.sample(problem, 500, 4)