# Notebook for summarizing Nudibranch phylogenetic factor analysis

# Part 1: Data pre-processing
## 1.1 - Import data from CSV file

In [None]:
using DataFrames, CSV, Statistics, BEASTDataPrep, BeastUtils.MatrixUtils, Clustering
import Random
# Work relative to this file's directory so the relative paths below resolve.
cd(@__DIR__)

# Seed the global RNG with a fixed value (666: "number of the BEAST").
Random.seed!(666)

# Cells containing the text "NA" are read as `missing`.
raw_data_path = "Nelson Bay 2cm raw_manually curated_26.7.22.csv"
raw_data = CSV.read(raw_data_path, DataFrame; missingstring = "NA");

Replace species names.
There are some mismatches between the original data file and the final phylogeny.

In [None]:
# Each pair maps a species label as it appears in the raw data (left) to the label
# that replaces it (right).
replacements = [
    "Theracera pennigera" => "Thecacera pennigera",
    "Chromodoris cf striatella QLD" => "Chromodoris cf striatella",
    "Tenelia sibogae" => "Tenellia sibogae",
    "Dendrodoris gunnamatta" => "Dendrodoris krusensterni",
    "Tayuva lilacina" => "Discodoris sp",
    "Mariona sp" => "Marionia sp",
    "Pleurobranchus peronii" => "Pleurobranchus peroni"
]

# Apply the corrections one after another, in place, to the species column.
foreach(fix -> replace!(raw_data.Species, fix), replacements)

# Keep a copy of the corrected table on disk.
CSV.write("corrected.csv", raw_data)

## 1.2 - Subset to QCPA Data

In [None]:
n_taxa, n_cols = size(raw_data)

# Element type of every column of the raw table.
col_types = eltype.(eachcol(raw_data))

# Indices of the columns whose element type is Float64, optionally with missing values.
numeric_cols = findall(T -> T <: Union{Float64, Missing}, col_types)

# Indices of every other column; their names are the value of this cell.
nonnumeric_cols = setdiff(1:n_cols, numeric_cols)
names(raw_data)[nonnumeric_cols] # non-numeric values in the data

## 1.3 - Data Transformations
We need to figure out appropriate transformations of the data to "normalize" it.

In [None]:
numeric_names = names(raw_data)[numeric_cols]
n_numeric = length(numeric_cols)

# One row of summary statistics per numeric trait; the zeros are placeholders
# that the loop below overwrites.
metadata = DataFrame(
    trait = numeric_names,
    mean = zeros(n_numeric),
    sd = zeros(n_numeric),
    maximum = zeros(n_numeric),
    minimum = zeros(n_numeric),
    perc_missing = zeros(n_numeric),
)

for (row, col_index) in enumerate(numeric_cols)
    column = raw_data[:, col_index]
    observed_mask = .!ismissing.(column)
    observed = column[observed_mask]

    # Statistics are computed over the non-missing entries only.
    metadata.mean[row] = mean(observed)
    metadata.sd[row] = std(observed)
    metadata.maximum[row] = maximum(observed)
    metadata.minimum[row] = minimum(observed)

    # Stored as a fraction of the rows (0 to 1), not as a percentage.
    metadata.perc_missing[row] = 1.0 - count(observed_mask) / n_taxa
end

In [None]:
# Path handed to `plot_transformed_data` as `svg_path`; the next cell reads this file back.
transform_svg = "transform_check.svg"


# NOTE(review): `plot_transformed_data` is defined outside this file. The result is used
# below as a table with a `trait` column — confirm what its other columns contain.
pvals = plot_transformed_data(raw_data[!, numeric_cols], overwrite=true, svg_path = transform_svg);

In [None]:
# Show the SVG written by the previous cell inline.
display("image/svg+xml", read(transform_svg, String))

In [None]:
"""
    find_cat(x::AbstractString)

Return the category of trait name `x`: the text before its first `'.'`, or all of `x`
when it contains no `'.'`.

Accepts any `AbstractString`, so it can be broadcast over string columns whose
element type is not `String`.
"""
function find_cat(x::AbstractString)
    return first(split(x, '.'))
end

# Split the rows of `pvals` by trait category (the text before the first '.' in each
# trait name) and display one table per category.
trait_cats = find_cat.(pvals.trait)
categories = unique(trait_cats)
n_cats = length(categories)
cat_dfs = Vector{DataFrame}(undef, n_cats)
for (i, category) in enumerate(categories)
    # Compare whole category labels rather than name prefixes, so a category whose
    # label is a prefix of another one cannot pick up the other's traits.
    rows = findall(isequal(category), trait_cats)
    cat_dfs[i] = pvals[rows, :]
    display(cat_dfs[i])
end

Based on above, it looks like `Col` and `Lum` should be log-transformed, `GabRat` should be logit-transformed, and `BSA` should be un-transformed.

CAA and VCA are ambiguous. As CAA has a high number of true `0` values that would need to be adjusted for log-transformation, we do not transform CAA. Some of the log-transformed VCA quantities seem to fit normal data well, so we opt to log-transform them.


In [None]:
# Trait categories (the text before the first '.' in a column name), grouped by the
# transformation the loop below applies to them.
log_values = ["Col", "Lum", "VCA"]  # log-transformed
logit_values = ["GabRat"]  # logit-transformed
other_values = ["BSA", "CAA"]  # copied unchanged

"""
    logit(x)

Return the log-odds `log(x / (1 - x))` of `x`.

Accepts any `Real` (not only `Float64`); `missing` is passed through as `missing`.
"""
function logit(x::Real)
    return log(x / (1 - x))
end

function logit(::Missing)
    return missing
end

# Build `transformed_data` with the same columns, in the same order, as `raw_data`.
# Numeric columns are transformed according to their category (the text before the
# first '.' in the column name); every other column is copied through unchanged.
transformed_data = DataFrame()
col_names = names(raw_data)

for (i, col_name) in enumerate(col_names)
    col = raw_data[:, i]  # `[:, i]` already yields a fresh copy of the column
    col_start = split(col_name, '.')[1]

    # Strip `Missing` from the element type before testing for numeric data: a column
    # typed Union{Float64, Missing} is not `<: Real`, so testing `eltype(col)` directly
    # would leave every numeric column that contains a missing value untransformed.
    # An all-missing column has value type Union{} and is treated as non-numeric.
    value_type = nonmissingtype(eltype(col))
    is_numeric = value_type !== Union{} && value_type <: Real

    if is_numeric
        if col_start in log_values
            transformed_col = log.(col)  # log(missing) is missing
        elseif col_start in logit_values
            transformed_col = logit.(col)
        elseif col_start in other_values
            transformed_col = col
        else
            error("unknown transformation")
        end
    else
        transformed_col = col
    end

    transformed_data[!, col_name] = transformed_col
end

## 1.4 - Colinearity

Many QCPA quantities are highly correlated and can be removed from the analysis.

In [None]:
"""
    check_colinear(df::DataFrame; threshold = 0.9)

Group the numeric columns of `df` by how strongly they are correlated.

Columns whose element type is `<: Union{Real, Missing}` are kept, their correlation
matrix is computed with `MatrixUtils.missing_cor`, and the columns are clustered
(complete linkage) on the dissimilarity `1 - |correlation|`, cutting the tree at height
`1 - threshold`.

Returns `(cor_df, clusters)`: `cor_df` has a `trait` column plus one correlation column
per numeric trait, and `clusters[c]` lists the positions (within the numeric columns) of
the traits assigned to cluster `c`. The number of columns and the size of the numeric
sub-table are printed as a side effect.
"""
function check_colinear(df::DataFrame; threshold::Float64 = 0.9)
    types = eltype.(eachcol(df))
    @show length(types)
    sub_df = df[:, findall(T -> T <: Union{Real, Missing}, types)]

    n, p = size(sub_df)
    @show n, p
    cols = names(sub_df)
    C = MatrixUtils.missing_cor(sub_df)

    # Hierarchical clustering on 1 - |r|, cut where the dissimilarity reaches 1 - threshold.
    tree = hclust(ones(p, p) - abs.(C), linkage=:complete)
    assignments = cutree(tree, h = 1.0 - threshold)

    # Positions of the members of each cluster, in ascending order.
    clusters = [findall(isequal(c), assignments) for c in 1:maximum(assignments)]

    # Correlation matrix as a table: one row and one column per numeric trait.
    cor_df = DataFrame(trait = cols)
    for (j, name) in enumerate(cols)
        cor_df[!, name] = C[:, j]
    end

    return cor_df, clusters
end

# Cluster the transformed traits by correlation and list the members of each cluster.
trans_cor_df, trans_clusters = check_colinear(transformed_data)
traits = trans_cor_df.trait

# `enumerate` supplies the cluster number, so no global counter is updated inside the
# loop; that pattern only works under the notebook's interactive scoping rules.
for (ind, clust) in enumerate(trans_clusters)
    println("cluster $ind: $(traits[clust])")
end

There is high co-linearity between most traits that are split by vertical and horizontal patterns. For now, I'm going to just use the base value for those traits. If there are two traits in different categories that are highly correlated, I take both.

In [None]:
# Trait(s) kept for each cluster. The keys are presumably the cluster numbers printed
# by the previous cell — confirm against that output whenever the data changes.
# A vector keeps several traits from the same cluster.
# Key 19 was previously written as a second `18`, which silently replaced "CAA.Qcpl"
# with "CAA.C" and dropped "CAA.Qcpl" from the analysis.
cluster_representatives = Dict{Int, Union{String,Vector{String}}}(
    1 => "Col.mean",
    2 => "Col.sd",
    3 => "Col.CoV",
    4 => ["Col.skew", "Col.kurtosis"],
    5 => "Lum.mean",
    6 => "Lum.sd",
    7 => "Lum.CoV",
    8 => "Lum.skew",
    9 => "Lum.kurtosis",
    10 => "GabRat",
    11 => ["CAA.Sc", "CAA.Hc"],
    12 => "CAA.Jc",
    13 => "CAA.St",
    14 => ["CAA.Jt", "CAA.Scpl"],
    15 => "CAA.Qc",
    16 => "CAA.Ht",
    17 => "CAA.Qt",
    18 => "CAA.Qcpl",
    19 => "CAA.C",
    20 => "CAA.Qc.Hrz",
    21 => "CAA.Qt.Hrz",
    22 => "CAA.PT",
    23 => "CAA.Asp",
    24 => "VCA.ML",
    25 => "VCA.sL",
    26 => "VCA.CVL",
    27 => ["VCA.MDmax", "VCA.MSsat"],
    28 => ["VCA.sDmax", "VCA.sSsat"],
    29 => ["VCA.CVDmax", "VCA.CVSsat"],
    30 => "VCA.MSL",
    31 => "VCA.sSL",
    32 => "VCA.CVSL",
    33 => ["VCA.MS", "VCA.sS"],
    34 => "VCA.CVS",
    35 => ["BSA.BML", "BSA.BMSL"],
    36 => ["BSA.BsL", "BSA.BsSL"],
    37 => ["BSA.BCVL", "BSA.BCVSL"],
    38 => ["BSA.BMDmax", "BSA.BMSsat"],
    39 => ["BSA.BsDmax", "BSA.BsSsat"],
    40 => "BSA.BCVDmax",
    41 => "BSA.BCVSsat",
    42 => "BSA.BMS",
    43 => "BSA.BsS",
    44 => "BSA.BCVS"
)

# Flatten in ascending key order so the trait order (and therefore the column order of
# the data written later) does not depend on the Dict's internal iteration order.
keep_traits = String[]
for cluster in sort!(collect(keys(cluster_representatives)))
    # `vcat` turns a single name into a one-element vector and leaves vectors as they are.
    append!(keep_traits, vcat(cluster_representatives[cluster]))
end

# Every retained trait must be a known trait; fail with the offending names otherwise.
unknown_traits = setdiff(keep_traits, traits)
isempty(unknown_traits) || error("retained traits not found among the clustered traits: $(unknown_traits)")

## 1.5 - Create a new data frame to be used in PFA

In [None]:
# Both tables start with a `taxon` column holding the individual labels.
# `trimmed_data` carries only the retained traits; `full_data` carries every numeric
# column (selected by the positions in `numeric_cols`).
trimmed_data = hcat(DataFrame(taxon = raw_data.Individual), transformed_data[!, keep_traits])
full_data = hcat(DataFrame(taxon = raw_data.Individual), transformed_data[!, numeric_cols]);

## Conforming the tree and trait data

The first step is to replace the species names with individual names.

In [None]:
"""
    expand_species_tips(newick, species, individuals)

Return `newick` with every species tip replaced by a clade of its individuals.

`species[i]` is the species of `individuals[i]`. For each distinct species, whitespace
in the name is replaced by `_` and the text `<label>:` in the tree is replaced by
`(<ind1>:0.0,<ind2>:0.0,...):`, so the individuals become zero-length branches under the
original tip. A warning is logged for species whose label does not occur in the tree.

Matching is a plain substring search for `<label>:`.
NOTE(review): a label that is the tail of another tip label would also match — confirm
this cannot happen for the names in use.
"""
function expand_species_tips(newick::AbstractString, species::AbstractVector,
                             individuals::AbstractVector)
    for sp in unique(species)
        label = join(split(sp), '_')
        if !occursin(label * ":", newick)
            @warn "Species $sp not in tree"
            continue
        end
        members = individuals[findall(isequal(sp), species)]
        clade = "(" * join((string(x, ":0.0") for x in members), ',') * ")"
        newick = replace(newick, label * ":" => clade * ":")
    end
    return newick
end

original_newick_path = "final_tree_13_3_23.tre"

all_species = unique(raw_data.Species)

# Done inside a function so that no global is reassigned from within a top-level loop,
# which only works under the notebook's interactive scoping rules.
newick = expand_species_tips(read(original_newick_path, String), raw_data.Species,
                             raw_data.Individual)

Make sure that the tree and trait data all have the same taxa.

In [None]:
# Output locations.
colinear_csv = "data_12_06.csv"
noColinear_csv = "data_noColinear_12_06.csv"
trimmed_nwk = "newick_12_06.nwk"

# Conform each trait table to the tree. The tree returned for the trimmed table is the
# one written to disk; the one returned for the full table is discarded.
conformed_trimmed_data, conformed_newick = conform_tree_and_data(trimmed_data, newick)
conformed_colinear_data, _ = conform_tree_and_data(full_data, newick)

CSV.write(noColinear_csv, conformed_trimmed_data)
CSV.write(colinear_csv, conformed_colinear_data);

write(trimmed_nwk, conformed_newick)

# Part 2: Phylogenetic Factor Analysis (PFA)

## 2.1 - Run PFA

In [None]:
using BEASTXMLConstructor.NewBEASTXMLConstructor
using BEASTDataPrep
using CSV
using DataFrames
using BeastUtils.RunBeast
using LinearAlgebra

day_df = raw_data[!, ["Individual", "Species", "daytime"]]

@assert sort(unique(day_df.daytime)) == ["D", "N"] # make sure there aren't any other codes I'm not looking for

species = unique(day_df.Species)
n_species = length(species)

# A species is coded 1 when at least one of its individuals has daytime code "D",
# and 0 otherwise.
is_day = [
    Int(any(isequal("D"), view(day_df.daytime, day_df.Species .== s))) for s in species
]

assignments = Dict(s => d for (s, d) in zip(species, is_day))

assignments["Discodoris sp"] = 0 # the only "day" species were found hiding under rocks

# Every individual inherits the code of its species.
day_df.species_day = [assignments[x] for x in day_df.Species]

trimmed_nwk = "newick_12_06.nwk"
activity_csv = "activity.csv"

# Conform the activity table to the tree file, then save it.
df = DataFrame(taxon = day_df.Individual, diel = day_df.species_day)
df, _ = conform_tree_and_data(df, trimmed_nwk)
CSV.write(activity_csv, df);   


In [None]:
"""
    replace_space(s::AbstractString)

Return `s` with its whitespace-separated words joined by single underscores
(for example `"Doris sp"` becomes `"Doris_sp"`).

Accepts any `AbstractString`, so it can be applied to string columns whose element
type is not `String`.
"""
function replace_space(s::AbstractString)
    return join(split(s), '_')
end

"""
    unreplace_space(s::AbstractString)

Return `s` with every underscore replaced by a space (for example `"Doris_sp"`
becomes `"Doris sp"`).

Accepts any `AbstractString`, not only `String`.
"""
function unreplace_space(s::AbstractString)
    return join(split(s, '_'), ' ')
end

"""
    merge_by_species(df::DataFrame, species_dict; take_average = true)

Collapse `df` to one row per species.

`df` must have a `taxon` column, and every later column is treated as data.
`species_dict` maps each taxon label to its species. Species with no rows in `df` are
omitted. With `take_average = true` the row of a species is `missing_mean` of its rows;
otherwise it is the first of its rows. The returned table has the species names (with
whitespace replaced by `_`) in its `taxon` column and the same data columns as `df`.
"""
function merge_by_species(df::DataFrame, species_dict::Dict{String, String}; take_average::Bool = true)
    taxa = df.taxon
    data = Matrix(df[!, 2:end])

    # Rows of `df` belonging to each species, keeping only species that occur at all.
    candidates = unique(values(species_dict))
    rows_by_species = [findall(t -> species_dict[t] == s, taxa) for s in candidates]
    present = findall(!isempty, rows_by_species)
    species = candidates[present]
    rows_by_species = rows_by_species[present]
    n_species = length(species)

    # Output table: species labels plus zero-filled data columns to be overwritten.
    merged = DataFrame(taxon = replace_space.(species))
    for col_name in names(df)[2:end]
        merged[!, col_name] = zeros(n_species)
    end

    for (i, rows) in enumerate(rows_by_species)
        block = @view data[rows, :]
        merged[i, 2:end] .= take_average ? vec(missing_mean(block)) : block[1, :]
    end

    return merged
end

"""
    merge_by_species(input::String, output::String, species_dict; kwargs...)

Read the CSV at `input`, merge it by species (keyword arguments are forwarded to the
`DataFrame` method) and write the result to `output`. Returns the value of `CSV.write`.
"""
function merge_by_species(input::String, output::String, species_dict::Dict{String, String}; kwargs...)
    merged = merge_by_species(CSV.read(input, DataFrame), species_dict; kwargs...)
    return CSV.write(output, merged)
end

"""
    merge_by_species(input::String, species_dict; kwargs...) -> String

Merge the CSV at `input` by species, write the result to `<base>_bySpecies<ext>`
(same location and extension as `input`) and return that output path.
Keyword arguments are forwarded to the other `merge_by_species` methods.
"""
function merge_by_species(input::String, species_dict::Dict{String, String}; kwargs...)
    bn, ext = splitext(input)
    # `splitext` keeps the leading dot in `ext` (e.g. ".csv"), so no extra "." is
    # inserted; adding one produced names such as "data_bySpecies..csv".
    output = bn * "_bySpecies" * ext
    merge_by_species(input, output, species_dict; kwargs...)
    return output
end
    


# Lookup from each individual's label to its species name.
species_dict = Dict(ind => sp for (ind, sp) in zip(raw_data.Individual, raw_data.Species));

In [None]:
"""
    joint_xml(xml_path, color_csv, day_csv, newick_path; factors = 2, chain_length = 10000)

Write a BEAST XML file at `xml_path` for a joint model of colour traits and activity.

The colour traits in `color_csv` enter as a factor model with `factors` factors; the
activity data in `day_csv` (first trait flagged as discrete) enter as a repeated-measures
model built with a diagonal matrix of 10.0. Both are standardized and share the tree read
from `newick_path`. `chain_length` is passed to `MCMCOptions`. Returns the value returned
by `save_xml`.
"""
function joint_xml(xml_path::String, color_csv::String, day_csv::String, newick_path::String; factors::Int = 2, chain_length::Int = 10000)
    color_traits = parse_traitdata(color_csv, trait_name = "color")
    activity_traits = parse_traitdata(day_csv, trait_name = "activity", discrete_traits = [1])

    color_model = FactorModel(color_traits, factors, standardize = true)

    n_activity = size(activity_traits, 2)
    activity_model = RepeatedMeasuresModel(
        activity_traits, Diagonal(fill(10.0, n_activity)), standardize = true
    )

    tree = read(newick_path, String)
    joint_model = JointTraitModel([color_model, activity_model], tree)

    return save_xml(
        joint_model, xml_path;
        loadings_operator = "gibbs",
        mcmc_options = MCMCOptions(chain_length = chain_length),
        blombergs_k = true,
    )
end

"""
    file_name(bn::String, opt::String, k::Int)

Return the extension-less file stem `<bn>_<opt>_<k>factors` used for all files of one run.
"""
function file_name(bn::String, opt::String, k::Int)
    return string(bn, "_", opt, "_", k, "factors")
end
    

# When true, the inputs are collapsed to one row per species before the XMLs are built.
merge_species = false

# Data sets to run: position 1 is "full" (all traits), position 2 is "sub"
# (colinear traits removed).
this_opt = 1:2
opts = ["full", "sub"][this_opt]
csvs = ["data_12_06.csv", "data_noColinear_12_06.csv"][this_opt]

bn = "nudibranch_mean_k_all"
activity_csv = "activity.csv"
trimmed_nwk = "newick_12_06.nwk"

if merge_species
    csvs = map(c -> merge_by_species(c, species_dict), csvs)
    _, new_newick = conform_tree_and_data(CSV.read(first(csvs), DataFrame), original_newick_path)
    trimmed_nwk = first(splitext(trimmed_nwk)) * "_bySpecies.nwk"
    write(trimmed_nwk, new_newick)

    activity_csv = merge_by_species(activity_csv, species_dict, take_average = false)
end

n_opts = length(opts)
ks = collect(1:6);

# The matrices below have one row per entry of `ks` and one column per data set.
# Other cells index them as `[k, i]` with the *value* of k, which is only correct
# while `ks` is exactly 1, 2, ..., length(ks).
k_opts = [(k, opt) for k in ks, opt in 1:n_opts]
fns = [file_name(bn, opt, k) for k in ks, opt in opts]
xml_paths = string.(fns, ".xml")
log_paths = string.(fns, ".log")
rotated_paths = string.(fns, "_processed.log")




# color_csv = "data_noColinear_12_06.csv"
# color_csv = "data_12_06.csv"
# bn = "color_and_activity_full"
# xml_path = "$bn.xml"



In [None]:
# Build one BEAST XML per (number of factors, data set) pair and run BEAST on it,
# spreading the pairs over the available threads.
lk = ReentrantLock() # guards the progress messages printed from the worker threads
Threads.@threads for (k, i) in k_opts
    xml_path = xml_paths[k, i]

    lock(lk) do 
        println("starting $xml_path ...")
    end

    joint_xml(xml_path, csvs[i], activity_csv, trimmed_nwk, factors = k, chain_length = 10_000)

    # Skip BEAST when an .out file for this XML already contains "Operator analysis"
    # (presumably printed only when a run completes — confirm).
    out_path = splitext(xml_path)[1] * ".out"
    if !(isfile(out_path) && occursin("Operator analysis", read(out_path, String)))
        run_beast(xml_path, capture_output = true, beast_jar = "beast.jar")
    end

    # Printed under the same lock as the "starting" message so that messages from
    # different threads cannot interleave.
    lock(lk) do
        println("... finished $xml_path")
    end
end

## 2.2 - Post-processing

In [None]:
using PhylogeneticFactorAnalysis.BEASTPostProcessing
using DataFrames
using CSV



"""
    data_stats(path::String)

Read the CSV at `path` and return `(n, p)`: its number of rows and its number of
columns minus one (the first column is not counted as a trait).
"""
function data_stats(path::String)
    n_rows, n_cols = size(CSV.read(path, DataFrame))
    return n_rows, n_cols - 1
end

# Post-process every BEAST log: for each data set, the runs for the different numbers
# of factors are handled in parallel.
n, p_activity = data_stats(activity_csv)

for i in 1:n_opts
    n, p_color = data_stats(csvs[i])

    Threads.@threads for k in ks
        # As elsewhere, the path matrices are indexed with the value of k itself.
        log_path = log_paths[k, i]
        rotated_path = rotated_paths[k, i]

        post_process(log_path, rotated_path, 
                     ["color" => (k, p_color), "activity" => (p_activity, p_activity)],
                     n,
                     optimization_inds=[[k + 1], Int[]])
    end
end


## 2.3 - Determine the number of factors

In [None]:
using RCall
using BeastUtils.Logs


"""
    collect_first_loadings(log_path::String)

Return the mean over the retained samples of every column of the log at `log_path`
whose name starts with `"color.L.1"`. The log is imported with `burnin = 0.1`.

NOTE(review): the prefix would also match names beginning `color.L.10`, `color.L.11`,
... if such columns exist — confirm against the log's column naming.
"""
function collect_first_loadings(log_path::String)
    samples = import_log(log_path, burnin = 0.1)
    first_factor_cols = filter(startswith("color.L.1"), names(samples))
    return vec(mean(Matrix(samples[!, first_factor_cols]), dims=1))
end



# One table per data set, with one column of mean first-factor loadings per number
# of factors k.
pair_dfs = [DataFrame() for _ = 1:n_opts]

for i in 1:n_opts, k in ks
    pair_dfs[i][!, "k=$k"] = collect_first_loadings(rotated_paths[k, i])
end


# Draw a scatter-plot matrix of the first-factor loadings for each data set (via R)
# and show the resulting SVG files inline.
svg_paths = ["pairs_$(opt).svg" for opt in opts]
@rput pair_dfs
@rput svg_paths

# `pair_dfs` is indexed with `[[i]]` on the R side so that `pairs` receives the i-th
# table itself; `[i]` yields a one-element list wrapping it.
R"""

for (i in seq_along(pair_dfs)) {
    svg(svg_paths[i])
    pairs(pair_dfs[[i]])
    dev.off()
}
"""
for i = 1:n_opts
    println(opts[i])
    open(svg_paths[i]) do f
       display("image/svg+xml", read(f, String))
    end
end

As can be seen from the plots above, the loadings associated with the first factor stabilize after k = 4 factors.
As such, we settle on a model with 4 factors.

## 2.4 - Check Model Equivalence

We double check that the full model with all traits and the smaller model where colinear traits have been removed indeed return equivalent results.

In [None]:
using Gadfly


# Trait labels (every column after the first) of each data set.
labels = [names(CSV.read(path, DataFrame))[2:end] for path in csvs]

# Traits shared by all data sets, and their positions within each data set.
in_all = intersect(labels...)
inds = [[findfirst(==(y), labs) for y in in_all] for labs in labels]

# For every k: one column per data set holding the first-factor loadings of the
# shared traits.
paired = [DataFrame() for _ in ks]
for k in ks, i in 1:n_opts
    paired[k][!, opts[i]] = collect_first_loadings(rotated_paths[k, i])[inds[i]]
end

# One scatter plot per k ("sub" against "full"), stacked vertically into a single SVG.
ps = [plot(paired[k], x = :sub, y = :full, Geom.point, Guide.title("k = $k")) for k in ks]
p = vstack(ps)
img_path = "loadings_correlations.svg"
img = SVG(img_path, 4inch, 20inch)
draw(img, p)
display("image/svg+xml", read(img_path, String))

As can be seen from the plots above, the loadings associated with the first factor are essentially the same between the models.

# Part 3: Visualizing Results

In [None]:
using BeastUtils.Logs
using PhylogeneticFactorAnalysis
using RCall

"""
    extract_species(s::AbstractString)

Return the second `_`-separated piece of `s`. Throws a `BoundsError` when `s`
contains no underscore.
"""
function extract_species(s::AbstractString)
    pieces = split(s, '_')
    return pieces[2]
end

"""
    clean_taxa(taxa, dict)

Pick one representative taxon per species.

`dict` maps every entry of `taxa` to its species. Returns a `DataFrame` with one row per
distinct species, in order of first appearance: `original` is the first taxon of that
species and `new` is the species name.
"""
function clean_taxa(taxa::AbstractArray{<:AbstractString}, dict::Dict{String, String})
    species = [dict[taxon] for taxon in taxa]
    # position of the first taxon belonging to each distinct species
    firsts = map(s -> findfirst(isequal(s), species), unique(species))
    return DataFrame(original = taxa[firsts], new = species[firsts])
end




In [None]:
cd(@__DIR__)

correlation_paths = fns .* "_correlation.pdf"

# Map each tip label to its species name. Tips are individuals by default; when the
# data were merged by species, the tips are the underscore-joined species names.
new_species_dict = if merge_species
    Dict(replace_space(sp) => sp for sp in raw_data.Species)
else
    Dict(ind => sp for (ind, sp) in zip(raw_data.Individual, raw_data.Species))
end


# For every data set and every number of factors k: plot the correlation and loadings
# summaries, then draw the factors on the tree through the R plotting script.
trait_names = ["color", "activity"]
for i = 1:n_opts
    # Trait labels (all columns but the first) and tip labels of this data set.
    color = CSV.read(csvs[i], DataFrame)
    trait_labels = names(color)[2:end]
    all_taxa = color.taxon
    # One representative tip per species and the species name it is relabelled with.
    new_taxa = clean_taxa(all_taxa, new_species_dict)
    for k in ks
        # The path matrices are indexed with the value of k itself.
        rotated_path = rotated_paths[k, i]
        correlation_path = correlation_paths[k, i]

        bl = LazyLog(rotated_path)

        # The comprehension's `i` is local to the comprehension and does not touch the loop's `i`.
        plot_correlation(bl, correlation_path, [["color.$i" for i = 1:k]; "activity"])
        plot_loadings(bl, trait_names, [k, p_activity], file_base = fns[k, i], factor_partitions = [1], original_labels = [trait_labels])

        # `sum` of a single number returns that number, so this passes k + p_activity.
        PhylogeneticFactorAnalysis.prep_factors(bl, "tmp_factors.csv",
                fac_header = "color.activity.joint.",
                k = sum(k + p_activity))



        # NOTE(review): the R code below reads variables that are never @rput in this
        # cell (R_PLOT_SCRIPT, optional_arguments, tree_path, stats_path, class_array,
        # layout, tip_labels, line_width); presumably `prep_r_factors` defines them on
        # the R side — confirm. The order of these statements therefore matters.
        PhylogeneticFactorAnalysis.prep_r_factors(
            "color",
            "tmp_factors.csv",
            trimmed_nwk,
            "",
            layout = "rectangular",
            fac_names = ["color.$i" for i = 1:k],
            tip_labels = true,
            line_width = 1,
            include_only = new_taxa.original,
            relabel = new_taxa
        )

        # Plot all k factors; the output is named after this run's file stem.
        factors = collect(1:k)
        @rput factors
        plot_path = fns[k, i]
        @rput plot_path

        R"""
        source(R_PLOT_SCRIPT)

        fac_names <- optional_arguments[[1]]
        include_only <- optional_arguments[[2]]
        relabel <- optional_arguments[[3]]


        plot_factor_tree(plot_path, tree_path, stats_path, class_path=class_array[[1]],
                         fac_names = fac_names, layout = layout,
                         tip_labels = tip_labels, line_width = line_width,
                         include_only = include_only, relabel = relabel,
                         extra_offset = 0.2,
                         labels_offset = 0.02,
                         width=12,
                         factors = factors)
        """
    end
end

# Remove the scratch file shared by all iterations.
rm("tmp_factors.csv")

### Pretty plot

In [None]:
# Data set and number of factors used for the final figures.
opt = "sub"
opti = findfirst(==(opt), opts)
color_csv = csvs[opti]
k = 4
processed_log = rotated_paths[k, opti]

labels_csv = "color_labels.csv"
color_names = names(CSV.read(color_csv, DataFrame))[2:end]

# Scratch file holding the loadings in the layout expected by the R plotting code.
# NOTE(review): `fac_header` has no trailing "." here, unlike the `prep_factors` calls
# elsewhere in this notebook — confirm that this is intended.
tmp_csv = "tmp.csv"
PhylogeneticFactorAnalysis.prep_loadings(
    processed_log, tmp_csv;
    k = k, n_traits = k + 1, L_header = "color.L.", fac_header = "color.activity.joint",
    original_labels = color_names,
)

# Display category for each trait-name prefix ("GabRat" has none).
cat_dict = Dict(
    "Col" => "LEIA",
    "Lum" => "LEIA",
    "GabRat" => "",
    "CAA" => "CAA",
    "VCA" => "VCA",
    "BSA" => "BSA",
)

# Display label for each statistic name. Every statistic is currently shown under its
# own name; the Dict also acts as the list of statistic names that are recognised.
type_dict = Dict(
    stat => stat for stat in [
        "mean", "sd", "CoV", "skew", "kurtosis",
        "Sc", "Jc", "St", "Jt", "Hc", "Qc", "Qt", "Ht", "Scpl", "Qcpl", "C", "PT", "Asp",
        "ML", "sL", "CVL",
        "MDmax", "sDmax", "CVDmax",
        "MSsat", "sSsat", "CVSsat",
        "MSL", "sSL", "CVSL",
        "MS", "sS", "CVS",
        "BML", "BsL", "BCVL",
        "BMDmax", "BsDmax", "BCVDmax",
        "BMSsat", "BsSsat", "BCVSsat",
        "BMSL", "BsSL", "BCVSL",
        "BMS", "BsS", "BCVS",
    ]
)

# Display label for the orientation suffix, accepting both capitalisations.
subtype_dict = Dict(
    suffix => direction for (direction, suffixes) in
    ["vertical" => ["Vrt", "vrt"], "horizontal" => ["hrz", "Hrz"]] for suffix in suffixes
)



In [None]:

# Build a display label and a display category for every colour trait, save them as a
# CSV and draw the loadings plot with the R plotting script.
p = length(color_names)
cats = fill("", p)
types = fill("", p)

for (i, name) in enumerate(color_names)
    sn = split(name, '.')
    cats[i] = cat_dict[sn[1]]
    if length(sn) == 1
        # No statistic part: the trait name is its own label.
        types[i] = sn[1]
    elseif length(sn) == 2
        # "<prefix>.<statistic>": LEIA traits keep their prefix so that the Col and Lum
        # versions of the same statistic can be told apart.
        prefix = cats[i] == "LEIA" ? sn[1] * "." : ""
        types[i] = prefix * type_dict[sn[2]]
    elseif length(sn) == 3
        # "<prefix>.<statistic>.<orientation>"
        types[i] = type_dict[sn[2]] * " (" * subtype_dict[sn[3]] * ")"
    else        
        error("not implemented")
    end
end

df = DataFrame(trait = color_names, pretty = types, cat = cats)
CSV.write(labels_csv, df)

r_script = PhylogeneticFactorAnalysis.R_PLOT_SCRIPT
@rput r_script
@rput tmp_csv
@rput labels_csv
R"""
source(r_script)
plot_loadings(tmp_csv, "nudibranch_loadings.pdf", labels_path = labels_csv, factors=c(1), width_scale=1.5)
"""
# The loadings scratch file is no longer needed; the labels CSV is kept.
rm(tmp_csv)
# rm(labels_csv)

In [None]:
using PhylogeneticFactorAnalysis
using RCall
using CSV

# Tip labels of the chosen data set and the species each retained tip is relabelled with.
color = CSV.read(csvs[opti], DataFrame)
trait_labels = names(color)[2:end]
all_taxa = color.taxon
new_taxa = clean_taxa(all_taxa, new_species_dict)
CSV.write("taxon_dict.csv", new_taxa)

# Activity class of every tip, written out as readable labels.
activity_df = CSV.read(activity_csv, DataFrame)
activity_dict = Dict(0 => "night only", 1 => "day")
labels_df = DataFrame(taxon = activity_df.taxon, activity = getindex.(Ref(activity_dict), activity_df.diel))
labels_csv = "activity_labels.csv"
CSV.write(labels_csv, labels_df)

# Factor values for the chosen model, written to a scratch CSV.
tmp_csv = "tmp_factors.csv"
PhylogeneticFactorAnalysis.prep_factors(
    rotated_paths[k, opti], tmp_csv;
    fac_header = "color.activity.joint.",
    k = k + p_activity,
)

# Hand the inputs for the tree plot to the R plotting code.
PhylogeneticFactorAnalysis.prep_r_factors(
    "color",
    tmp_csv,
    trimmed_nwk,
    labels_csv;
    layout = "rectangular",
    fac_names = ["color.$i" for i = 1:k],
    tip_labels = true,
    line_width = 1,
    include_only = new_taxa.original,
    relabel = new_taxa,
)



In [None]:
# Draw the final tree figure for the chosen model. Only factor 1 is plotted: the R call
# below hard-codes `factors = c(1)` and does not use the `factors` vector sent here.
factors = collect(1:k)
@rput factors
# Output file stem for the figure.
plot_path = fns[k, opti] * "_final"
@rput plot_path

# NOTE(review): the R code reads variables that are not @rput in this cell
# (R_PLOT_SCRIPT, optional_arguments, tree_path, stats_path, class_array, layout,
# tip_labels, line_width); presumably the preceding `prep_r_factors` call defines them
# on the R side — confirm.
R"""
source(R_PLOT_SCRIPT)

fac_names <- optional_arguments[[1]]
include_only <- optional_arguments[[2]]
relabel <- optional_arguments[[3]]

print(tree_path)
print(relabel)


plot_factor_tree(plot_path, tree_path, stats_path, class_path=class_array[[1]],
                 fac_names = c("factor 1"), layout = layout,
                 tip_labels = tip_labels, line_width = line_width,
                 include_only = include_only, relabel = relabel,
                 extra_offset = 1,
                 labels_offset = 0.02,
                 width=7,
                 factors = c(1),
                 combined=FALSE)
"""

In [None]:
# Collapse the per-individual tables to one row per species: activity takes the first
# row of each species, factor values are averaged over its individuals.
species_activity_csv = merge_by_species(activity_csv, species_dict; take_average = false)
species_factor_csv = merge_by_species("tmp_factors.csv", species_dict; take_average = true)

# Part 4: Predictive Performance

In [None]:
using GLM

# Species-level factor values joined with species-level activity.
facs = CSV.read(species_factor_csv, DataFrame)
acs = CSV.read(species_activity_csv, DataFrame)
logit_df = innerjoin(facs, acs, on = :taxon)

# Logistic regression of diel activity on factor 1, and its fitted probabilities.
fit = glm(@formula(diel ~ 1 + f1), logit_df, Bernoulli(), LogitLink())
ps = predict(fit)

# Share of species classified correctly at each probability cut-off.
thresholds = 0.5:0.01:1.0
n = length(thresholds)
probs = [count((ps .> t) .== logit_df.diel) / nrow(logit_df) for t in thresholds]

@show findmax(probs)
best_threshold = thresholds[argmax(probs)]
best_assigned = ps .> best_threshold

plot(x = ps, y = logit_df.diel)


In [None]:
# Order the species by factor 1 and evaluate every split "rows 1..i versus rows i+1..n".
sp = sortperm(logit_df.f1)
logit_df = logit_df[sp, :]
n = nrow(logit_df)

# Number of day-active species among the first i rows, for every i, and overall.
day_up_to = cumsum(logit_df.diel)
day_total = sum(logit_df.diel)

# Ms[i, :, :] is the 2x2 table for split i: first index = side of the split
# (1 = rows up to i, 2 = rows after i), second index = activity (1 = night, 2 = day).
Ms = zeros(Int, n, 2, 2)
tc_rates = zeros(n)
for i in 1:n
    sb = day_up_to[i]
    sa = day_total - sb
    Ms[i, 1, 1] = i - sb
    Ms[i, 1, 2] = sb
    Ms[i, 2, 1] = n - i - sa
    Ms[i, 2, 2] = sa
    # correct = night species at or below the split + day species above it
    tc_rates[i] = (sa + i - sb) / n
end

# First split reaching the highest correct-classification rate, and its table.
best_rate = maximum(tc_rates)
max_inds = findall(isequal(best_rate), tc_rates)
ind = max_inds[1]
Ms[ind, :, :]

The results above indicate that we can accurately predict activity patterns in 89% of species using only factor 1.

In [None]:
# Classify each species from factor 1 alone using the best split found above, compute
# the ROC curve and its area, and draw both figures with ggplot2.

# Decision threshold: midway between the last row on the low side of the split and the
# first row on the high side. When the best split lies after the last row there is no
# next row, so the threshold is placed just above the largest factor value instead of
# indexing past the end of the table.
threshold = if ind < nrow(logit_df)
    0.5 * (logit_df.f1[ind] + logit_df.f1[ind + 1])
else
    logit_df.f1[end] + 1e-4
end
logit_df.class = logit_df.f1 .< threshold
@rput threshold

# Readable activity labels; the lookup table is built once rather than per row.
diel_labels = Dict(1 => "day", 0 => "night only")
logit_df.pretty_diel = [diel_labels[i] for i in logit_df.diel]

logit_df.class_name = [x ? "night only" : "day" for x in logit_df.class]
logit_df.correct_prediction = [x == y for (x, y) in zip(logit_df.class_name, logit_df.pretty_diel)]
print(logit_df)




@rput logit_df

# ROC curve: sweep 50 thresholds from just above the largest factor value down to just
# below the smallest one; a species is called "day" when its factor 1 exceeds the threshold.
m = logit_df.f1[1] - 1e-4
x = logit_df.f1[end] + 1e-4
n = 50
thresholds = range(x, m, length = n)
f1 = logit_df.f1
diel = logit_df.diel
class_pos = [sum(f1 .> t) for t in thresholds]
true_pos = [sum((f1 .> t) .&& (diel .== 1) ) for t in thresholds]
false_pos = class_pos .- true_pos
tpr = true_pos ./ sum(diel)
fpr = false_pos ./ (nrow(logit_df) - sum(diel))

roc_df = DataFrame(fpr = fpr, tpr = tpr)

# Area under the ROC curve by the trapezoid rule.
aurocs = [0.5 * (roc_df.tpr[i] + roc_df.tpr[i - 1]) * (roc_df.fpr[i] - roc_df.fpr[i - 1]) for i = 2:nrow(roc_df)]
auroc = sum(aurocs)
@show auroc

@rput auroc
@rput roc_df
R"""
library(ggplot2)
ggplot(logit_df, aes(y = f1, x = pretty_diel, color = class, shape = correct_prediction)) +
    geom_point() +
    geom_hline(yintercept = threshold) +
    xlab("True Activity Pattern") + 
    ylab("factor 1") +
    theme_bw() +
    scale_color_discrete(labels = c("TRUE" = "night only", "FALSE" = "day"), 
                         name = "Predicted Activity\nPattern") +
    scale_shape_manual(labels = c("TRUE" = "correct", "FALSE" = "incorrect"),
                         name = "Prediction Accuracy",
                         values = c(4, 19))
ggsave("separation.svg", height = 4, width = 4)

ggplot(roc_df, aes(x = fpr, y = tpr)) +
    geom_path() +
    geom_segment(x = 0, y = 0, xend = 1, yend = 1, linetype="dashed") +
    xlab("false positive rate") +
    ylab("true positive rate") +
    annotate("text", x = 0.75, y = 0.325, label = paste0("AUROC = ", round(auroc, digits = 2))) +
    theme_bw()

ggsave("roc.svg", height=4, width = 4)
"""



<img src="roc.svg">

We achieve an area under the ROC curve (AUROC) of 0.94.
For reference, a completely random classifier has AUROC = 0.5, and a perfect classifier has AUROC = 1. 

In [None]:
using HypothesisTests

# Factor 1 values of the day-active and the night-only species.
f_day = logit_df[logit_df.diel .== 1, :f1]
f_night = logit_df[logit_df.diel .== 0, :f1]

# Two-sample t-test that does not assume equal variances.
test = UnequalVarianceTTest(f_day, f_night)

The results above indicate a highly significant relationship between factor 1 and diel activity (p < 1e-5).