In [1]:
using Distributed
# Launch 11 additional worker processes; the `@everywhere` definitions and the
# `pmap` call further down in this notebook run on them.
addprocs(11)
# Report the total number of processes (main process plus workers).
nprocs()

12

In [25]:
@everywhere using AppleAccelerate

In [3]:
@everywhere using LinearAlgebra, Polyhedra, CDDLib
@everywhere using JuMP, Ipopt, Clarabel
@everywhere using Random
@everywhere using PhyloNetworks

In [4]:
# Shift `x` so that its first entry becomes zero.
# The first entry of `x` is subtracted from every entry; `x` itself is not modified.
@everywhere function trop_normalize(x)
    anchor = first(x)
    return x .- anchor
end

In [5]:
# Polyhedral distance between two vectors.
# Each row of `alphas` is a facet normal α, scaled so that the facet lies on α⋅x = 1.
# The result is the largest entry of `alphas * (vec1 - vec2)`.
@everywhere function poly_dist(vec1, vec2, alphas)
    offset = vec1 - vec2
    return maximum(alphas * offset)
end

In [6]:
# Sum, over all points `pt` in `sample`, of `poly_dist(pt, ref, alphas)` raised to `power`.
# Each row of `alphas` is a facet normal α, scaled so that the facet lies on α⋅x = 1.
@everywhere function sum_of_poly_dist(ref, sample, alphas; power=1)
    powered = [poly_dist(point, ref, alphas)^power for point in sample]
    return sum(powered)
end

In [7]:
# Build the JuMP model whose minimisers are the polyhedral Fréchet means of `sample`,
# i.e. the points x minimising the sum over sample points p of
# (max_k α_k⋅(x - p))^power.
#
# Arguments
#   sample : collection of points (vectors of equal length)
#   alphas : matrix whose rows are the facet normals α, scaled so the facet lies on α⋅x = 1
#   power  : exponent applied to each distance; 2 selects Clarabel, 1 is not implemented
#            (raises an error), any other value selects Ipopt
#
# Returns the tuple `(model, x)`; the model has not been optimised yet.
#
# Formulation: every sample point p gets a distance variable d_p with d_p ≥ α_k⋅(x - p)
# for all facets k, and an epigraph variable t_p ≥ d_p^power; the objective is Σ_p t_p.
# Raising each facet expression to `power` on its own would, for even powers, bound t_p by
# max_k |α_k⋅(x - p)|^power instead, which differs from the distance whenever the facet
# normals are not symmetric about the origin.
# The bound d_p ≥ 0 assumes the distance is never negative (bounded polytope containing
# the origin); it also keeps d_p^power well defined for non-integer powers.
@everywhere function poly_frechet_model(sample, alphas; power=2)
    dim = length(first(sample))
    npoints = length(sample)

    # Choose the solver depending on the power
    if power == 1
        error("FW computation not yet implemented")
    elseif power == 2
        model = Model(Clarabel.Optimizer)
    else
        model = Model(Ipopt.Optimizer)
    end

    # Suppress solver output
    set_silent(model)

    @variable(model, x[1:dim])
    @variable(model, d[1:npoints] >= 0)
    @variable(model, t[1:npoints])

    @objective(model, Min, sum(t))

    for (p_idx, p) in enumerate(sample)
        # d[p_idx] dominates every facet expression, hence their maximum
        for expr in alphas * (x - p)
            @constraint(model, d[p_idx] >= expr)
        end

        # One power constraint per sample point
        @constraint(model, t[p_idx] >= d[p_idx]^power)
    end

    return model, x
end

In [8]:
# Compute one polyhedral Fréchet mean of `sample`.
# Each row of `alphas` is a facet normal α, scaled so that the facet lies on α⋅x = 1;
# `power` is the exponent applied to each distance before summing.
# The model is built by `poly_frechet_model`, optimised, and the values of its decision
# variables are returned.
@everywhere function poly_frechet(sample, alphas; power=2)
    problem, point = poly_frechet_model(sample, alphas; power=power)

    @debug "\nOptimising..."

    optimize!(problem)
    return value.(point)
end

In [9]:
# Compute the set of polyhedral Fréchet means of `sample` as a polyhedron.
#
# Arguments
#   T      : Polyhedra library type; the polyhedron is built with `T(:exact)`
#   sample : collection of points (vectors of equal length)
#   alphas : matrix whose rows are the facet normals α, scaled so the facet lies on α⋅x = 1
#   power  : exponent applied to each distance before summing
#   rep    : `:hrep` returns the half-space representation, `:vrep` the vertex
#            representation, any other value the polyhedron object itself
#   tol    : tolerance passed to `rationalize` when converting the inputs to rationals
#
# Method: one Fréchet mean m is found numerically and all data are rationalised. With
# d_j = poly_dist(m, p_j, alphas), the returned polyhedron is
#     { x : α_k⋅x ≤ min_j (α_k⋅p_j + d_j)  for every facet k },
# with redundant half-spaces removed by the polyhedra library.
@everywhere function poly_frechet_set(::Type{T}, sample, alphas; power=2, rep=:polyhedron, tol=1e-3) where {T<:Polyhedra.Library}
    # Compute one Fréchet mean
    one_mean = poly_frechet(sample, alphas, power=power)

    @debug "Frechet mean found: $(one_mean)"

    # Rationalise all coordinates
    rat_alphas = rationalize.(alphas, tol=tol)
    rat_sample = [rationalize.(pt, tol=tol) for pt in sample]
    rat_mean = rationalize.(one_mean, tol=tol)

    # Distances are measured with the rationalised normals so that they are exactly
    # consistent with the half-space bounds computed from the same normals below.
    distances = [poly_dist(rat_mean, pt, rat_alphas) for pt in rat_sample]

    # bounds[k, j] = α_k⋅p_j + d_j (facets along rows, sample points along columns)
    bounds = rat_alphas * reduce(hcat, rat_sample) .+ permutedims(distances)

    # For each facet only the tightest bound over all sample points matters
    num_facets = size(rat_alphas, 1)
    bval = Vector{Rational{Int64}}(undef, num_facets)
    for k in 1:num_facets
        bval[k] = minimum(view(bounds, k, :))
        @debug "Removing redundant half-spaces: $(round(k/num_facets * 100, digits=3))%   \r"
    end

    @debug "\nFinding defining facets..."

    poly = polyhedron(hrep(rat_alphas, bval), T(:exact))
    removehredundancy!(poly)

    if rep == :hrep
        @debug "Finding facets..."
        return hrep(poly)
    elseif rep == :vrep
        @debug "Finding vertices..."
        return vrep(poly)
    else
        @debug "Defaulting to polyhedron."
        return poly
    end
end

# Convenience method: same as the method taking a library type, with `CDDLib.Library`
# as the polyhedra backend. All keyword arguments are forwarded unchanged.
@everywhere function poly_frechet_set(sample, alphas; power=2, rep=:polyhedron, tol=1e-3)
    return poly_frechet_set(CDDLib.Library, sample, alphas; power=power, rep=rep, tol=tol)
end

In [12]:
# Facet normals of a tropical ball in `n` dimensions.
#
# Returns an n(n-1) × n matrix of `Rational{Int64}` with one row e_i - e_j (entry +1 in
# column i, -1 in column j, zeros elsewhere) for every ordered pair i ≠ j. Rows are
# ordered by i first and then by j.
#
# `n` may be any `Integer`, not only `Int64`. Throws `ArgumentError` for negative `n`.
@everywhere function trop_facets(n::Integer)
    n >= 0 || throw(ArgumentError("n must be non-negative, got $n"))

    normals = zeros(Rational{Int64}, n * (n - 1), n)
    row = 0
    for i in 1:n, j in 1:n
        i == j && continue
        row += 1
        normals[row, i] = 1
        normals[row, j] = -1
    end
    return normals
end

In [10]:
# Tropical distance between two vectors: the largest entry of `vec1 - vec2` minus its
# smallest entry.
@everywhere function trop_dist(vec1, vec2)
    lowest, highest = extrema(vec1 - vec2)
    return highest - lowest
end

In [11]:
# Sum, over all points `pt` in `sample`, of `trop_dist(pt, ref)` raised to `power`.
@everywhere function sum_of_trop_dist(ref, sample; power=1)
    powered = [trop_dist(point, ref)^power for point in sample]
    return sum(powered)
end

In [14]:
# Set of Fréchet means of `sample` with respect to the tropical ball.
# The facet normals come from `trop_facets`, using the length of the first sample point as
# the dimension, and the work is delegated to `poly_frechet_set`.
# `power`, `rep` and `tol` are forwarded unchanged.
@everywhere function trop_frechet_set(sample; power=2, rep=:polyhedron, tol=1e-3)
    facets = trop_facets(length(sample[1]))
    return poly_frechet_set(sample, facets; power=power, rep=rep, tol=tol)
end

In [15]:
# One Fréchet mean of `sample` with respect to the tropical ball.
# The facet normals come from `trop_facets`, using the length of the first sample point as
# the dimension, and the work is delegated to `poly_frechet`.
# `power` is the exponent applied to each distance before summing.
@everywhere function trop_frechet(sample; power=2)
    facets = trop_facets(length(sample[1]))
    return poly_frechet(sample, facets; power=power)
end

In [14]:
using JSON3
using DataFrames

"""
    read_and_convert_json(file_path::String)

Read the JSON file at `file_path` and return its values as a vector of square matrices.

The first entry of every top-level item is taken, multiplied by 10000 and rationalised
with `tol=1e-2`. The resulting values are cut into consecutive blocks of 64, and each
block is reshaped (column-major) into an 8×8 matrix. A shorter final block is reshaped
into a smaller square matrix when its length is a perfect square.

# Throws
- `ArgumentError` if a block's length is not a perfect square.
"""
function read_and_convert_json(file_path::String)
    # Read the JSON file
    json_data = JSON3.read(file_path)

    # Extract elements from the nested arrays
    elements = [x[1] for x in json_data]
    elements = rationalize.(10000 * elements, tol=1e-2)

    block_len = 64
    num_elements = length(elements)

    # Typed container instead of an untyped `[]`
    matrices = Matrix{eltype(elements)}[]

    for start in 1:block_len:num_elements
        block = elements[start:min(start + block_len - 1, num_elements)]

        # Exact integer square root; a trailing block must still be a perfect square
        side = isqrt(length(block))
        if side * side != length(block)
            throw(ArgumentError(
                "block starting at element $start has $(length(block)) entries, " *
                "which cannot be reshaped into a square matrix"))
        end

        push!(matrices, reshape(block, side, side))
    end

    return matrices
end

# Read and convert the JSON file (the path is relative to the working directory)
file_path = "all_matrices.json"
matrices = read_and_convert_json(file_path)
# Labels of the eight taxa; presumably in the row/column order of each matrix — TODO confirm
taxa = ["Tg", "Et", "Cp", "Ta", "Bb", "Tt", "Pv", "Pf"]

8-element Vector{String}:
 "Tg"
 "Et"
 "Cp"
 "Ta"
 "Bb"
 "Tt"
 "Pv"
 "Pf"

In [15]:
"""
Take a matrix of pairwise distances between taxa and return its cophenetic vector:
the entries strictly above the diagonal, listed row by row.
"""
function cophenetic_from_distance(pairwise)
    ntaxa = size(pairwise, 1)
    # Index pairs (row, col) with row < col, ordered by row and then by column
    upper_pairs = ((row, col) for row in 1:ntaxa-1 for col in row+1:ntaxa)
    return [pairwise[row, col] for (row, col) in upper_pairs]
end

cophenetic_from_distance

In [16]:
"""
Check if a distance matrix defines a phylogenetic tree
"""
function is_phylogenetic_tree(D)
    # Returns `true` when `D` is symmetric, has no negative entries, and satisfies the
    # four-point condition; returns `false` otherwise. All comparisons are exact, and
    # `size(D, 1)` is used as the number of taxa.
    n = size(D, 1)

    # Check that the matrix is symmetric and non-negative (diagonal included)
    for i in 1:n, j in i:n
        if D[i, j] != D[j, i] || D[i, j] < 0
            return false
        end
    end

    # Four-point condition: for every four indices i < j < k < l, the largest of the
    # three pairwise sums must be attained at least twice.
    for i in 1:n-3, j in i+1:n-2, k in j+1:n-1, l in k+1:n
        D_ij_kl = D[i, j] + D[k, l]
        D_ik_jl = D[i, k] + D[j, l]
        D_il_jk = D[i, l] + D[j, k]

        # The maximum is attained at least twice exactly when no sum strictly exceeds
        # both of the others. (Testing whether some sum is >= the other two is always
        # true and therefore can never reject a matrix.)
        if D_ij_kl > max(D_ik_jl, D_il_jk) ||
           D_ik_jl > max(D_ij_kl, D_il_jk) ||
           D_il_jk > max(D_ij_kl, D_ik_jl)
            return false
        end
    end

    return true
end

is_phylogenetic_tree

In [17]:
"""
Check if a distance matrix defines an ultrametric tree
"""
function is_ultrametric_tree(D)
    # `true` when `D` is symmetric, has no negative entries, and every triple of indices
    # satisfies the ultrametric condition; `false` otherwise.
    ntaxa = size(D, 1)

    # Symmetry and sign of every entry on or above the diagonal
    for row in 1:ntaxa, col in row:ntaxa
        D[row, col] != D[col, row] && return false
        D[row, col] < 0 && return false
    end

    # Ultrametric condition: among the three distances of a triple, none may exceed the
    # larger of the other two (the largest distance is attained at least twice)
    for a in 1:ntaxa-2, b in a+1:ntaxa-1, c in b+1:ntaxa
        d_ab = D[a, b]
        d_ac = D[a, c]
        d_bc = D[b, c]

        longest_twice = d_ab <= max(d_ac, d_bc) &&
                        d_ac <= max(d_ab, d_bc) &&
                        d_bc <= max(d_ab, d_ac)
        longest_twice || return false
    end

    return true
end

is_ultrametric_tree

In [18]:
# One cophenetic vector per distance matrix read from the JSON file
coph_vecs = [cophenetic_from_distance(mat) for mat in matrices]

268-element Vector{Vector{Rational{Int64}}}:
 [3784, 6626, 9906, 6521, 11778, 8750, 7661, 7601, 10881, 7496  …  4901, 15579, 12551, 11462, 12194, 9167, 8078, 14217, 13128, 1089]
 [3485, 8484, 9427, 8865, 25257, 9257, 10300, 8814, 9756, 9195  …  3645, 25010, 9010, 10054, 24449, 8449, 9492, 20033, 21076, 2163]
 [3570, 5592, 6890, 4591, 4849, 5136, 4872, 3905, 6062, 3763  …  6609, 6867, 7154, 6890, 2383, 4367, 4103, 4625, 4361, 274]
 [1696, 4297, 5665, 6408, 4431, 4218, 4248, 4473, 5841, 6584  …  9039, 7062, 6849, 6880, 5542, 6984, 7014, 5006, 5037, 557]
 [1864, 8173, 6779, 10056, 10843, 9010, 9319, 8433, 7039, 10315  …  11759, 12546, 10713, 11022, 15116, 13283, 13592, 12332, 12642, 2722]
 [3801, 11567, 5222, 5725, 5642, 8167, 7548, 12836, 6491, 6994  …  6129, 6046, 8571, 7952, 601, 7410, 6791, 7327, 6708, 4443]
 [751, 29706, 6877, 4376, 5857, 9999, 8966, 29899, 7070, 4569  …  6079, 7560, 11702, 10669, 3359, 8305, 7272, 9787, 8754, 5104]
 [4080, 8200, 9699, 9924, 11942, 9302, 9463, 7689, 

In [19]:
# Fréchet mean set of all cophenetic vectors (default keywords: power=2, polyhedron
# output), with the run time reported by `@time`
@time phylo_frech = trop_frechet_set(coph_vecs)

130.072391 seconds (403.06 M allocations: 9.640 GiB, 1.47% gc time, 3.68% compilation time: 17% of which was recompilation)


Polyhedron CDDLib.Polyhedron{Rational{BigInt}}:
4-element iterator of HyperPlane{Rational{BigInt}, Vector{Rational{BigInt}}}:
 HyperPlane(Rational{BigInt}[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0], 0//1)
 HyperPlane(Rational{BigInt}[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], 0//1)
 HyperPlane(Rational{BigInt}[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0], 0//1)
 HyperPlane(Rational{BigInt}[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], 0//1),
294-element iterator of HalfSpace{Rational{BigInt}, Vector{Rational{BigInt}}}:
 HalfSpace(Rational{BigInt}[0, 1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1192//1)
 HalfSpace(Rational{BigInt}[0, 1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 866//1)
 HalfSpace(Rational{BigInt}[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 

In [20]:
# Split `coph_vecs` into consecutive chunks of at most `chunk_size` vectors each;
# only the last chunk may be shorter.
chunk_size = 20
nchunks = cld(length(coph_vecs), chunk_size)

chunks = [
    coph_vecs[((i - 1) * chunk_size + 1):min(i * chunk_size, length(coph_vecs))]
    for i in 1:nchunks
]

14-element Vector{Vector{Vector{Rational{Int64}}}}:
 [[3784, 6626, 9906, 6521, 11778, 8750, 7661, 7601, 10881, 7496  …  4901, 15579, 12551, 11462, 12194, 9167, 8078, 14217, 13128, 1089], [3485, 8484, 9427, 8865, 25257, 9257, 10300, 8814, 9756, 9195  …  3645, 25010, 9010, 10054, 24449, 8449, 9492, 20033, 21076, 2163], [3570, 5592, 6890, 4591, 4849, 5136, 4872, 3905, 6062, 3763  …  6609, 6867, 7154, 6890, 2383, 4367, 4103, 4625, 4361, 274], [1696, 4297, 5665, 6408, 4431, 4218, 4248, 4473, 5841, 6584  …  9039, 7062, 6849, 6880, 5542, 6984, 7014, 5006, 5037, 557], [1864, 8173, 6779, 10056, 10843, 9010, 9319, 8433, 7039, 10315  …  11759, 12546, 10713, 11022, 15116, 13283, 13592, 12332, 12642, 2722], [3801, 11567, 5222, 5725, 5642, 8167, 7548, 12836, 6491, 6994  …  6129, 6046, 8571, 7952, 601, 7410, 6791, 7327, 6708, 4443], [751, 29706, 6877, 4376, 5857, 9999, 8966, 29899, 7070, 4569  …  6079, 7560, 11702, 10669, 3359, 8305, 7272, 9787, 8754, 5104], [4080, 8200, 9699, 9924, 11942, 9302, 9463

In [21]:
# Number of vectors in each chunk
length.(chunks)

14-element Vector{Int64}:
 20
 20
 20
 20
 20
 20
 20
 20
 20
 20
 20
 20
 20
  8

In [23]:
# Time the Fréchet-set computation on each of the first two chunks
for i in 1:2
    @time trop_frechet_set(chunks[i])
end

 19.072286 seconds (60.57 M allocations: 1.195 GiB, 0.30% gc time)
 72.119848 seconds (177.85 M allocations: 2.067 GiB, 0.08% gc time)


In [None]:
# Compute the Fréchet mean set of every chunk with `pmap`, one call of
# `trop_frechet_set` per chunk
pmap(trop_frechet_set, chunks)