In [12]:
using Pkg
Pkg.activate("../../.")
using TestEnv
TestEnv.activate()
using Revise
using Vizagrams
using DataFrames
using Random
using Distributions
using Colors
using StructArrays
using LinearAlgebra

[32m[1m  Activating[22m[39m project at `~/Documents/GitHub/Vizagrams.jl`
[33m[1m│ [22m[39mIt is recommended to `Pkg.resolve()` or consider `Pkg.update()` if necessary.
[33m[1m└ [22m[39m[90m@ Pkg.API ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Pkg/src/API.jl:1807[39m
[33m[1m└ [22m[39m[90m@ TestEnv ~/.julia/packages/TestEnv/shkbW/src/julia-1.9/activate_set.jl:63[39m


In [13]:
"""
    SmoothWire(ptstart, ptend, hstart, hend)
    SmoothWire(; ptstart=[0, 0.5], ptend=[1, 1.5], hstart=1, hend=1)

Mark holding the two end points of a band (`ptstart`, `ptend`) and the band
height at each end (`hstart`, `hend`).

The fields are untyped. The `ζ` method for this mark indexes the points with
`[1]`/`[2]` and adds two-element vectors to them, so two-element numeric
vectors are expected.
"""
struct SmoothWire <: Mark
    ptstart
    ptend
    hstart
    hend
end

# Keyword constructor: every field may be omitted and falls back to the default.
function SmoothWire(; ptstart=[0, 0.5], ptend=[1, 1.5], hstart=1, hend=1)
    return SmoothWire(ptstart, ptend, hstart, hend)
end
# Expand a `SmoothWire` into a `CBezierPolygon` and lift it with `dmlift`.
#
# The four corners sit half a height above/below each end point. The second
# argument lists the corners interleaved with points at the horizontal midpoint
# of the wire; presumably these act as Bézier control points that produce the
# S-shaped edges -- confirm against the Vizagrams `CBezierPolygon` documentation.
function Vizagrams.ζ(w::SmoothWire)::TMark
    p_start, p_end = w.ptstart, w.ptend

    # Corners: end point shifted up/down by half of the local height.
    start_top = p_start + [0, w.hstart / 2]
    start_bot = p_start + [0, -w.hstart / 2]
    end_top = p_end + [0, w.hend / 2]
    end_bot = p_end + [0, -w.hend / 2]

    # x coordinate halfway between the two ends.
    xmid = (p_end[1] + p_start[1]) / 2

    corners = [start_bot, start_top, end_top, end_bot]
    outline = [
        start_bot,
        start_top,
        [xmid, start_top[2]],
        [xmid, end_top[2]],
        end_top,
        end_bot,
        [xmid, end_bot[2]],
        [xmid, start_bot[2]],
    ]
    return dmlift(CBezierPolygon(corners, outline))
end

In [14]:

# Sample from the distribution
# random_value = countries[rand(dist)]
# Random.seed!(4)
"""
    random_student()

Draw one synthetic student record and return it as a 3-tuple
`(under, grad, postgrad)` of country-name strings.

The three stages are sampled in order from weighted categorical distributions:

1. `under`: drawn from the first five countries only (the weight vector has
   five entries, so `"NOTHING"` cannot be drawn at this stage).
2. `grad`: drawn from all six entries, with the weight of the undergraduate
   country raised to `1.0` before renormalising, which favours staying put.
   If `"NOTHING"` is drawn, the function returns early with `"NOTHING"` as the
   post-graduate value too.
3. `postgrad`: drawn from its own six-entry weight vector.

Uses the global RNG, so call `Random.seed!` beforehand for reproducible output.
"""
function random_student()
    countries = ["CHINA", "USA", "AUSTRALIA", "INDIA", "OTHER", "NOTHING"]

    # Undergraduate stage: five weights for six labels, so the last label
    # ("NOTHING") is never selected here.
    w = [5, 3, 1, 1, 1]
    w = w / sum(w)
    under = countries[rand(Distributions.Categorical(w))]
    index = findfirst(==(under), countries)

    # Graduate stage: base weights, then boost the undergraduate country.
    w = [4, 20, 5, 9, 2, 5]
    w = w / sum(w)
    w[index] = 1.0
    w = w / sum(w)
    grad = countries[rand(Distributions.Categorical(w))]
    if grad == "NOTHING"
        return under, grad, "NOTHING"
    end

    # Post-graduate stage: sample from the post-graduate weights. (The draw
    # previously reused the graduate-stage distribution, leaving these weights
    # without any effect.)
    w = [4, 8, 5, 9, 2, 10]
    w = w / sum(w)
    distgrad = Distributions.Categorical(w)
    postgrad = countries[rand(distgrad)]
    return under, grad, postgrad
end

# Fix the global RNG so the synthetic sample is reproducible.
Random.seed!(4)

# 1000 synthetic students, one row each; the columns are named after the stage.
df = DataFrame([random_student() for _ in 1:1000]);
df = rename(df, [:under, :grad, :pos]);

# Unit weight per student, summed later when aggregating.
df[!, :count] = fill(1, nrow(df));

In [15]:
# Columns collected from every stage into `vert`: vertex, country, degree, w.

# Stage 1: (under, grad) pairs, labelled with degree "under".
gdf = combine(groupby(df, [:under, :grad]), :count => sum)
gdf[!, :vertex] = "under_" .* gdf.under .* "_" .* gdf.grad
gdf = rename(gdf, :under => :country, :count_sum => :w)
gdf[!, :degree] = fill("under", nrow(gdf))

vert = gdf[:, [:vertex, :country, :degree, :w]]

# Stage 2: (grad, pos) pairs, labelled with degree "grad"; rows whose graduate
# country is "NOTHING" are dropped.
gdf = combine(groupby(df, [:grad, :pos]), :count => sum)
gdf[!, :vertex] = "grad_" .* gdf.grad .* "_" .* gdf.pos
gdf = rename(gdf, :grad => :country, :count_sum => :w)
gdf[!, :degree] = fill("grad", nrow(gdf))
gdf = filter(:country => !=("NOTHING"), gdf)

vert = vcat(vert, gdf[:, [:vertex, :country, :degree, :w]])

# Stage 3: reuse the stage-2 table, keep rows with a real post-graduate country,
# relabel them as degree "pos" and make the post-graduate country the `country`.
gdf = filter(:pos => !=("NOTHING"), gdf)
gdf[!, :degree] = fill("pos", nrow(gdf))
gdf = rename(gdf, :country => :grad)
gdf[!, :vertex] = "pos_" .* gdf.grad .* "_" .* gdf.pos
gdf = rename(gdf, :pos => :country)

vert = vcat(vert, gdf[:, [:vertex, :country, :degree, :w]])

# Largest weights first.
vert = sort(vert, :w; rev=true);

In [16]:
# Total weight per (country, degree) node; `renamecols=false` keeps the name `w`.
dv = combine(groupby(vert, [:country, :degree]), :w => sum; renamecols=false)

# Node id of the form "<degree>_<country>".
dv[!, :v] = dv.degree .* "_" .* dv.country;

In [17]:
# Bar chart of `dv`: x is the degree ("under", "grad", "pos"), bar height is `w`,
# fill colour is the country. `country` and `degree` are additionally passed
# through `IdScale()` so each bar can be tagged with the id "<degree>_<country>"
# (the same format as `dv.v`).
#
# The inner `∑` combines the bars of one x group with `x↑(T(0,10),y)`,
# ordered by country in descending order -- presumably stacking them vertically
# with a gap of 10; confirm against the Vizagrams documentation.
#
# Fix: the inner reducer was written `op=op=(x,y)->...`, which passed the keyword
# correctly but also assigned a stray global named `op`. It is now a single `op=`.
plt = Plot(
    title="MyPlot",
    figsize=(800,340)./2,
    data=dv,
    encodings=(
        x=(field=:degree,datatype=:n,guide=(tickvalues=["under","grad","pos"],)),
        y=(field=:w,datatype=:q,guide=(lim = (0,1500),)),
        color=(field=:country,datatype=:n),
        country=(field=:country,datatype=:n,scale=IdScale()),
        degree=(field=:degree,datatype=:n,scale=IdScale()),
    ),
    graphic =
        ∑(i=:x,op=+,
            ∑(i=:color,op=(x,y)->x↑(T(0,10),y),orderby=:country, descend=true,
                ∑(row -> begin
                    S(:fill=>row[:color],:id=>row[:degree]*"_"*row[:country])*
                    Bar(h=row[:y],c=[row[:x],0], w =35)
                end
                ))
        )
)
drawsvg(plt,height=500)

In [18]:
# Flatten the second component of the expanded plot into its individual pieces.
gplt = Vizagrams.flatten(ζ(plt)._2);

# For every piece, read back the id/fill stored in its style dictionary and the
# geometry fields `c`, `h`, `w` (set on the `Bar`s when the plot was built).
barpos = StructArray(map(gplt) do piece
    style = piece.s.d
    geom = piece.geom
    (
        v=style[:id],
        xbar=geom.c[1],
        ybar=geom.c[2],
        hbar=geom.h,
        wbar=geom.w,
        color=style[:fill],
    )
end)
sdata = DataFrame(barpos);

In [19]:
# Attach the drawn bar geometry and colour to each node, matched on the id `v`.
dv = leftjoin(dv, sdata; on=:v);

In [20]:
# Build the edge table of the diagram: one row per (source node, target node)
# pair with its weight, the geometry of both end bars, the edge height `h`, and
# the vertical positions `y_src` / cumulative offsets used to place each edge.

# Empty accumulator with the three edge columns (untyped, so element type Any).
edges = DataFrame(:src=>[],:tgt=>[],:weight=>[])

# One pass per pair of consecutive stages: under -> grad, then grad -> pos.
for cols in [[:under,:grad],[:grad,:pos]]
    gdf = combine(groupby(df,cols),:count=>sum);
    # Readable "A-B" label; it is not carried past the column selection below.
    gdf[!,:under_grad] = map(row->row[cols[1]]*"-"*row[cols[2]], eachrow(gdf))
    # Node ids "<column name>_<value>", the same format as `dv.v`.
    gdf[!,:src] = map(row->string(cols[1])*"_"*row[cols[1]], eachrow(gdf))
    gdf[!,:tgt] = map(row->string(cols[2])*"_"*row[cols[2]], eachrow(gdf))
    gdf[!,:weight] = gdf[!,:count_sum]
    gdf = gdf[!,[:src,:tgt,:weight]]
    edges = vcat(edges, gdf)
end
# Attach the source node's bar geometry, total weight `w`, country and degree.
edges = leftjoin(edges,dv[!,[:xbar,:ybar,:hbar,:wbar,:v,:w,:country,:degree]],on=:src=>:v)
edges = rename(edges, :xbar=>:x_src,:ybar=>:y_src,:country=>:src_country,:degree=>:src_degree,:hbar=>:hbar_src)
# Attach the target node's bar geometry, country and degree.
edges = leftjoin(edges,dv[!,[:xbar,:ybar,:v,:country,:degree,:hbar]],on=:tgt=>:v)
edges = rename(edges, :xbar=>:x_tgt,:ybar=>:y_tgt,:country=>:tgt_country,:degree=>:tgt_degree,:hbar=>:hbar_tgt);
# Drop rows with a missing value in any column (e.g. an endpoint without a match in `dv`).
edges = dropmissing(edges);
# Edge height: the source bar height scaled by this edge's share of the source total.
edges[!,:h] = map(row->row[:hbar_src]*row[:weight]/row[:w], eachrow(edges));
# Running total of edge heights within each source node, stored as `hc`.
temp = combine(groupby(edges,[:src_country,:src_degree]),:h=>cumsum,:src=>identity,:tgt=>identity)
temp = rename(temp, :src_identity=>:src,:tgt_identity=>:tgt,:h_cumsum=>:hc)
temp = temp[!,[:src,:tgt,:hc]]
edges = leftjoin(edges, temp, on=[:src,:tgt]);

# Per source node, rewrite `y_src` for every outgoing edge.
# Presumably this stacks the edges from the top of the source bar downward,
# assuming `y_src` starts out as the bar's vertical centre -- TODO confirm.
# NOTE(review): `g[i-1,:hc]` is used as the running total up to the previous
# row, so this relies on the rows of each group being in the same order as when
# the cumulative sum was computed -- confirm the joins preserve that order.
gp = DataFrame()
for g in groupby(edges,[:src_country,:src_degree])
    g = sort(g,:hbar_src,rev=true)
    # First edge: shifted by half of (bar height - its cumulative height).
    g[1,:y_src] = g[1,:y_src]+(g[1,:hbar_src]-g[1,:hc])/2
    # g[1,:y_tgt] = g[1,:y_tgt]+(g[1,:hbar_tgt]-g[1,:hc])/2
    for i in 2:size(g)[1]
        # Later edges: half bar height, minus half own height, minus the
        # cumulative height of the previous row.
        g[i,:y_src] = g[i,:y_src] + g[i,:hbar_src]/2 - g[i,:h]/2 - g[i-1,:hc]
        # g[i,:y_tgt] = g[i,:y_tgt] + g[i,:hbar_tgt]/2 - g[i,:h]/2 - g[i-1,:hc]
    end
    gp = vcat(gp,g)
end
edges = copy(gp)


# Same running total on the target side: sort by source bar height (descending),
# accumulate `h` within each target node, and store the result as `hc_tgt`.
temp = combine(groupby(sort(copy(edges),:hbar_src,rev=true),[:tgt_country,:tgt_degree]),:h=>cumsum,:src=>identity,:tgt=>identity)
temp = rename(temp, :src_identity=>:src,:tgt_identity=>:tgt,:h_cumsum=>:hc_tgt)
temp = temp[!,[:src,:tgt,:hc_tgt]]
edges = leftjoin(edges, temp, on=[:src,:tgt]);

In [21]:
# Per target node, rewrite `y_tgt` for every incoming edge, visiting the edges
# in descending order of source bar height (the order used for `hc_tgt`).
gp = DataFrame()
for g in groupby(sort(edges, :hbar_src; rev=true), [:tgt_country, :tgt_degree])
    # First edge of the group: shift by half of (target bar height - edge height).
    g[1, :y_tgt] += (g[1, :hbar_tgt] - g[1, :h]) / 2
    # Remaining edges: half bar height, minus half own height, minus the
    # cumulative height of the previous row.
    for i in 2:nrow(g)
        g[i, :y_tgt] = g[i, :y_tgt] + g[i, :hbar_tgt] / 2 - g[i, :h] / 2 - g[i - 1, :hc_tgt]
    end
    gp = vcat(gp, g)
end

# Bring in the colour of each edge's source node.
gp = leftjoin(gp, select(dv, :v, :color); on=:src => :v);

In [22]:
# One translucent grey `SmoothWire` per edge row, all combined with `+`.
# Each wire starts half a bar width to the right of the source x position and
# ends half a bar width to the left of the target x position; its height is the
# edge height `h` at both ends.
sankey = ∑(op=+) do row
    half_bar = row[:wbar] / 2
    wire = SmoothWire(;
        ptstart=[row[:x_src] + half_bar, row[:y_src]],
        ptend=[row[:x_tgt] - half_bar, row[:y_tgt]],
        hstart=row[:h],
        hend=row[:h],
    )
    return S(:strokeWidth => 0, :fill => :grey, :fillOpacity => 0.2) * wire
end

# Rows of `gp` as named tuples, then draw the wires together with the bars.
ed = StructArray(NamedTuple.(eachrow(gp)));
drawsvg(sankey(ed) + plt.graphic(scaledata(plt)), height=500)