In [32]:
using DataFrames
using DataFramesMeta
using GLM
using CSV
using StatsFuns
using Statistics

In [44]:
# TODO
# figure out how to read in types
#data = vec(readdlm("coltypes.csv", ',', String))
#add type validations df::DataFrame
# add exceptions for unique ids, etc
# deal with missings
# replacement
# exact covariates
# iptw

In [34]:
"""
    make_dataset(df, _id_col_name, label_col_name, case, control, X)

Return a DataFrame containing only the identifier column, the label column, and the
covariate columns, restricted to rows labelled `case` or `control`.

Note that `df` itself is filtered in place: rows whose label is neither `case` nor
`control` are removed from the caller's DataFrame. The returned DataFrame is a new
object holding copies of the selected columns.

# Arguments
- `df::DataFrame`: DataFrame to be manipulated (rows are dropped in place)
- `_id_col_name::String`: Name of column containing _id's
- `label_col_name::String`: Name of column containing labels
- `case::String`: Label which represents case status
- `control::String`: Label which represents control status
- `X`: Name(s) of the covariate column(s) to keep
"""
function make_dataset(df, _id_col_name, label_col_name, case, control, X)
    # Column-wise predicate: true for rows labelled as case or as control.
    is_case_or_control(labels) = labels .== case .|| labels .== control
    subset!(df, label_col_name => is_case_or_control)

    keep_cols = vcat(_id_col_name, label_col_name, X)
    return df[:, keep_cols]
end

make_dataset

In [35]:
"""
    fit_logit(df, _id_col_name, label_col_name, case)

Fit a logistic regression of case status on every column of `df` other than the
identifier and label columns, and return the fitted model.

The label column is recoded to `true` where it equals `case` and `false` otherwise.
The recoding is done on a copy; `df` is not modified.

# Arguments
- `df::DataFrame`: DataFrame holding the identifier, label, and covariate columns
- `_id_col_name::String`: Name of column containing _id's (excluded from the predictors)
- `label_col_name::String`: Name of column containing labels (the response)
- `case::String`: Label which represents case status
"""
function fit_logit(df, _id_col_name, label_col_name, case)
    is_case(labels) = labels .== case
    recoded = transform(df, label_col_name => is_case => label_col_name)

    # Every remaining column is used as a predictor.
    covariate_names = names(recoded, Not([_id_col_name, label_col_name]))
    rhs = foldl(+, term.(covariate_names))
    formula = term(label_col_name) ~ rhs

    return glm(formula, recoded, Binomial(), LogitLink())
end

fit_logit

In [36]:
"""
    propensity_scores(df, model, _id_col_name, label_col_name; ps_col_name="propensityScore")

Return a new DataFrame with the identifier column, the label column, and a column of
propensity scores obtained from `predict(model)`.

# Arguments
- `df::DataFrame`: DataFrame the model was fitted on; supplies the identifier and label columns
- `model`: Fitted model whose predictions are used as propensity scores
- `_id_col_name::String`: Name of column containing _id's
- `label_col_name::String`: Name of column containing labels

# Keywords
- `ps_col_name`: Name given to the propensity score column
"""
function propensity_scores(df, model, _id_col_name, label_col_name; ps_col_name="propensityScore")
    scores = DataFrame(ps_col_name => predict(model))
    id_and_label = df[:, [_id_col_name, label_col_name]]
    return hcat(id_and_label, scores)
end

propensity_scores

In [37]:
"""
    greedy_match(df, n, exact, replacement, _id_col_name, label_col_name, ps_col_name,
                 case, control; caliper="calc")

Greedily match each case to controls whose propensity score lies within `caliper` of
the case's score. Return the retained cases stacked on top of the matched controls.

Cases are processed in row order. For each case, the controls still available are
ranked by absolute score difference and the `n` closest ones within the caliper are
taken. A case that cannot be matched is dropped from the result.

# Arguments
- `df::DataFrame`: DataFrame holding the label and propensity score columns
- `n`: Number of controls to match to each case
- `exact`: If `true`, a case with fewer than `n` candidates is dropped; if `false`, such
  a case keeps all of its candidates as long as there is at least one
- `replacement`: If `true`, a control may be matched to several cases and appears in the
  result once per match; if `false`, a matched control is removed from the pool
- `_id_col_name`: Name of column containing _id's (not used by the matching itself)
- `label_col_name`: Name of column containing labels
- `ps_col_name`: Name of column containing propensity scores
- `case`: Label which represents case status
- `control`: Label which represents control status

# Keywords
- `caliper`: Largest allowed absolute score difference. The default `"calc"` uses 0.2
  times the standard deviation of the logit of all scores in `df`.
"""
function greedy_match(
    df, n, exact, replacement, _id_col_name, label_col_name, ps_col_name, case, control;
    caliper="calc"
)
    if caliper == "calc"
        # NOTE(review): this caliper is computed on the logit scale, while the distances
        # below are differences of raw scores -- confirm this mix of scales is intended.
        caliper = 0.2 * std(logit.(df[:, ps_col_name]))
        if caliper isa Real && isnan(caliper)
            # Happens e.g. when a score is exactly 0 or 1, or when `df` has a single row.
            # Every comparison against NaN is false, so all cases would be dropped.
            @warn "Automatic caliper is NaN; no case can be matched" ps_col_name
        end
    end

    cases = @subset(df, $label_col_name .== case)
    controls = @subset(df, $label_col_name .== control)

    # Controls still available, as (row number in `controls`, score) pairs.
    pool = collect(enumerate(controls[:, ps_col_name]))
    dropped_cases = Int[]
    matched_controls = Int[]  # rows of `controls`; only filled when matching with replacement

    for (case_row, case_score) in enumerate(cases[:, ps_col_name])
        diffs = abs.(last.(pool) .- case_score)
        # Positions in the *current* pool, in ascending order.
        candidates = findall(<=(caliper), diffs)

        if length(candidates) >= n
            ranked = sort!(collect(zip(candidates, diffs[candidates])); by=last)
            chosen = first.(ranked[1:n])
        elseif !isempty(candidates) && !exact
            chosen = candidates
        else
            push!(dropped_cases, case_row)
            continue
        end

        if replacement
            # The pool never shrinks here, so pool positions equal rows of `controls`.
            append!(matched_controls, chosen)
        else
            # deleteat! requires sorted positions.
            deleteat!(pool, sort(chosen))
        end
    end

    kept_cases = cases[Not(dropped_cases), :]
    if replacement
        return vcat(kept_cases, controls[matched_controls, :])
    else
        # Whatever is left in the pool was never matched; keep everything else.
        return vcat(kept_cases, controls[Not(first.(pool)), :])
    end
end

greedy_match (generic function with 1 method)

In [38]:
"""
    merge_propensity_scores(df, ps_df, on)

Left-join `ps_df` onto `df` using the key column(s) `on` and return the joined DataFrame.
"""
merge_propensity_scores(df, ps_df, on) = leftjoin(df, ps_df; on=on)

merge_propensity_scores (generic function with 1 method)

In [39]:
"""
    add_matches(df, match_df, _id_col_name, label_col_name)

Append a relabelled copy of every matched row to `df` and return the result.

Rows of `df` whose identifier appears in `match_df` are copied, the text " Matched" is
appended to their label, and the copies are stacked below the original rows. `df` itself
is not modified.

# Arguments
- `df::DataFrame`: Full dataset
- `match_df::DataFrame`: Matched rows, as returned by `greedy_match`
- `_id_col_name`: Name of column containing _id's
- `label_col_name`: Name of column containing labels
"""
function add_matches(df, match_df, _id_col_name, label_col_name)
    matched_ids = match_df[:, _id_col_name]
    matched_rows = findall(in(matched_ids), df[:, _id_col_name])
    to_add = df[matched_rows, :]

    # Replace the column instead of writing into it in place: an in-place write must
    # convert each new label to the existing element type, which fails when that type
    # cannot hold the longer string (for example fixed-width inline strings or numbers).
    to_add[!, label_col_name] = string.(to_add[!, label_col_name], " Matched")

    return vcat(df, to_add)
end

add_matches (generic function with 1 method)

In [40]:
# types = [Float64, String, String, Float64, Float64,Float64, Float64, Float64]
# file = "example_data/crabs.txt"
# df = CSV.read(file, DataFrame; delim = '\t', header = true, types = types)
# ps_df = make_dataset(df, "id", "sp","B","O",["sex","FL","RW","BD"])
# mod = fit_logit(ps_df,"id","sp", "B")
# ps_df = propensity_scores(ps_df, mod, "id", "sp")
# match_df = greedy_match(ps_df, 1, true, false, "id", "sp", "propensityScore", "B", "O", caliper=1)
# df = merge_propensity_scores(df, ps_df, ["id","sp"])
# df = add_matches(df, match_df, "id", "sp")
# return

In [41]:
"""
    main(file, types, _id_col_name, label_col_name, case, control, covariates,
         n, n_exact, replacement; ps_col_name="propensityScore", caliper="calc", delim)

Run the whole matching pipeline on a delimited text file and return the resulting
DataFrame.

The file is read, a logistic model of case status on `covariates` is fitted, propensity
scores are predicted, cases are greedily matched to controls, the scores are joined back
onto the data, and a relabelled copy of every matched row is appended.

# Arguments
- `file`: Path of the delimited text file to read (a header row is expected)
- `types`: Column types passed to `CSV.read`
- `_id_col_name`: Name of column containing _id's
- `label_col_name`: Name of column containing labels
- `case`: Label which represents case status
- `control`: Label which represents control status
- `covariates`: Names of the columns used to model the propensity score
- `n`: Number of controls to match to each case
- `n_exact`: Passed to `greedy_match` as `exact`
- `replacement`: Passed to `greedy_match` as `replacement`

# Keywords
- `ps_col_name`: Name given to the propensity score column
- `caliper`: Caliper passed to `greedy_match`
- `delim`: Field delimiter of `file`; defaults to a tab character
"""
function main(file,types,
              _id_col_name,label_col_name,case,control,covariates,
              n,n_exact,replacement;
              ps_col_name="propensityScore",caliper="calc",delim='\t')
    df = CSV.read(file, DataFrame; delim = delim, header = true, types = types)
    ps_df = make_dataset(df, _id_col_name, label_col_name, case, control, covariates)
    model = fit_logit(ps_df, _id_col_name, label_col_name, case)
    ps_df = propensity_scores(ps_df, model, _id_col_name, label_col_name, ps_col_name=ps_col_name)
    match_df = greedy_match(ps_df, n, n_exact, replacement, _id_col_name, label_col_name, ps_col_name, case, control, caliper=caliper)
    df = merge_propensity_scores(df, ps_df, [_id_col_name, label_col_name])
    return add_matches(df, match_df, _id_col_name, label_col_name)
end

main (generic function with 1 method)

In [42]:
# Example run on the crabs data.
# Column types in file order: id, sp, sex, FL, RW, CL, CW, BD.
types = [Float64, String, String, Float64, Float64,Float64, Float64, Float64]
file = "example_data/crabs.txt"
# Covariates used to model the propensity score.
X = ["sex","FL","RW","BD"]
# Cases are "B", controls are "O"; one control per case (n = 1, n_exact = true),
# matched without replacement (replacement = false), with the default caliper.
main(file,types,"id","sp","B","O",X,1,true,false)

Unnamed: 0_level_0,id,sp,sex,FL,RW,CL,CW,BD,propensityScore
Unnamed: 0_level_1,Float64,String,String,Float64,Float64,Float64,Float64,Float64,Float64?
1,1.0,B,M,8.1,6.7,16.1,19.0,7.0,0.533401
2,2.0,B,M,8.8,7.7,18.1,20.8,7.4,0.708732
3,3.0,B,M,9.2,7.8,19.0,22.4,7.7,0.604978
4,4.0,B,M,9.6,7.9,20.1,23.1,8.2,0.476071
5,5.0,B,M,9.8,8.0,20.3,23.0,8.2,0.452558
6,6.0,B,M,10.8,9.0,23.0,26.5,9.8,0.447934
7,7.0,B,M,11.1,9.9,23.8,27.1,9.8,0.738642
8,8.0,B,M,11.6,9.1,24.5,28.4,10.4,0.21077
9,9.0,B,M,11.8,9.6,24.2,27.8,9.7,0.385342
10,10.0,B,M,11.8,10.5,25.2,29.3,10.3,0.735677
