In [22]:
import networkx as nx
import networkit as nk

from panricci import RicciFlow
# from panricci.distributions.variation_graph import DistributionNodes
from panricci.distributions.sequence_graph import DistributionNodes
from panricci.utils import GFALoader

___
## Apply Ricci-Flow to a Variation Graph

In [23]:
# Input graph, output directory and run settings for this Ricci-flow experiment
GFA_PATH = "../data/test1.gfa"
RICCI_FLOW_DIR = "../output/test3/ricci-flow"
RUN_NAME = "test3"
N_ITERATIONS = 5

# Load the GFA file as a directed graph (undirected=False)
gfa_loader = GFALoader(undirected=False)
G = gfa_loader(GFA_PATH)

# Node distributions consumed by the Ricci flow.
# NOTE(review): alpha=0.5 is passed straight to DistributionNodes — confirm its meaning in panricci's docs
distribution = DistributionNodes(G, alpha=0.5)

# Run the flow; save_intermediate_graphs=True asks RicciFlow to write the
# per-iteration graphs under RICCI_FLOW_DIR, tagged with RUN_NAME
ricci_flow = RicciFlow(G, distribution, dirsave_graphs=RICCI_FLOW_DIR)
G_ricci = ricci_flow.run(
    iterations=N_ITERATIONS,
    save_last=False,
    save_intermediate_graphs=True,
    name=RUN_NAME,
)

RicciFlow: 100%|██████████| 5/5 [00:00<00:00, 756.96it/s]


___

## Results

**Graphs saved during the run can be reloaded from their edgelist checkpoints with `networkx`**

In [None]:
# Edgelist checkpoint written for iteration 5 of run "test3"
checkpoint_path = "../output/test3/ricci-flow/test3-ricciflow-5.edgelist"

# Rebuild it as a directed graph with integer node labels
G_chkpt = nx.read_edgelist(checkpoint_path, nodetype=int, create_using=nx.DiGraph)

In [None]:
# Summarise the checkpoint instead of dumping every node and edge, and look up
# the attributes of edge (1, 2) only if that edge exists (avoids a KeyError)
(
    G_chkpt.number_of_nodes(),
    G_chkpt.number_of_edges(),
    G_chkpt.edges[1, 2] if G_chkpt.has_edge(1, 2) else None,
)

### Inspecting the Ricci-flow graph with networkit

In [None]:
# networkit copy of the Ricci-flow graph, with edge weights read from the "weight" attribute
G_ricci_nk = nk.nxadapter.nx2nk(
    G_ricci,
    weightAttr="weight",
)

In [None]:
# Sum of all edge weights after the Ricci flow
total_edge_weight = G_ricci_nk.totalEdgeWeight()
total_edge_weight

In [None]:
# networkit's overview of the converted graph
nk.overview(G_ricci_nk);

In [None]:
import pandas as pd
import seaborn as sns
# One row per edge of the Ricci-flow output; "feats" holds each edge's attribute dict.
# Built from G_ricci (the result of ricci_flow.run), not the input graph G.
df = pd.DataFrame(
    list(G_ricci.edges(data=True)),
    columns=["node1", "node2", "feats"],
)

In [None]:
# Expand the per-edge attribute dicts into columns. Keys are selected by name so
# the result does not depend on the order (or number) of keys in each dict.
# NOTE(review): assumes the edge attributes are named "curvature" and "weight"
# (the latter matches weightAttr="weight" used for networkit) — confirm.
edge_feats = pd.DataFrame(df["feats"].tolist(), index=df.index)
df["curvature"] = edge_feats["curvature"]
df["weights"] = edge_feats["weight"]

In [None]:
# Histogram of edge curvatures: 100 bins, raw counts
sns.histplot(
    df,
    x="curvature",
    bins=100,
    stat="count",
    discrete=False,
);

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
%matplotlib notebook

In [None]:
def animate(
    i,
    path_template="/home/avila/panricci/output/DQA1-3117-nodes-alpha/nodes-a5-ricciflow-{i}.edgelist",
    ax=None,
):
    """Draw the edge-curvature histogram of Ricci-flow checkpoint ``i``.

    Used as the frame callback of ``matplotlib.animation.FuncAnimation``.

    Parameters
    ----------
    i : int
        Frame index; substituted for ``{i}`` in ``path_template``.
    path_template : str
        Path of the edgelist checkpoint to load, containing an ``{i}``
        placeholder. The default is an absolute path on the author's machine.
        TODO: point this at a project-relative output directory.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).
    """
    if ax is None:
        ax = plt.gca()

    graph = nx.read_edgelist(
        path_template.format(i=i),
        nodetype=int,
        create_using=nx.DiGraph,
    )

    # Edge attributes come back as dicts; pull "curvature" into its own column
    # so seaborn can plot it by name
    curvatures = pd.DataFrame(
        {"curvature": [feats["curvature"] for _, _, feats in graph.edges(data=True)]}
    )

    # Clear the previous frame, otherwise each frame's bars are drawn on top of the last
    ax.clear()
    sns.histplot(
        data=curvatures, x="curvature", stat="count", bins=100, discrete=False, ax=ax
    )
    ax.set_title(f"iteration {i}")

In [None]:
# In a notebook, `display` is the display *function*, not the IPython.display
# module, so `display.HTML` fails; import the names explicitly
from IPython.display import HTML, display

fig = plt.figure()
fig.suptitle('Histogram of curvatures per epoch', fontsize=14)

# Call `animate` for frames 0..19, 700 ms apart, looping
anim = animation.FuncAnimation(fig, animate, frames=20, interval=700, repeat=True)

# Saving to MP4 and converting to HTML5 video both require a movie writer such as ffmpeg
anim.save("../ricciflow-curvatures.mp4")

# Embed the animation in the notebook as an HTML5 <video> element
video = anim.to_html5_video()
html = HTML(video)
display(html)

# Close this specific figure; the animation is shown through the embedded video
plt.close(fig)