## My First Cycle Rep
Getting an early cycle rep!!!

The parameters are mostly arbitrary, and this isn't final in any way, but it's cool.

#### Preliminaries
The network threshold is 10x what Jingyi suggests in her paper.

In [1]:
# load some packages
import Gavin.utils.make_network as mn
import plotly.graph_objects as go
from time import time
import networkx as nx
import pandas as pd
import oatpy as oat
import numpy as np

# config
DATA_PATH = 'datasets/concept_network/'
CONCEPT_FILE = 'articles_category_for_2l_abstracts_concepts_processed_v1_EX_102.csv.gz' # Applied Mathematics
article_concept_df = mn.filter_article_concept_file(
        DATA_PATH+CONCEPT_FILE,
        relevance_cutoff=0.7,
        min_article_freq=0.001, # 0.006%
        max_article_freq=0.002, # 0.05%
        normalize_year=True,
        year_min=1920
    ) # use a filtered data file to make the samples

#### Problem Setup
Take the concept-article dataframe and turn it into a network, then a distance matrix.

In [2]:
G = mn.gen_concept_network(article_concept_df) # make the graph
adj = nx.adjacency_matrix(G, weight='norm_year') # adjacency matrix
node_births = np.array(list(nx.get_node_attributes(G, 'norm_year').values())) # node orgin years, these break the cycle reps (idk why)
adj.setdiag(node_births)
# adj.setdiag(0)

#### Homology Calculation
We setup and calculate homology, then do some basic visualizations.

In [3]:
start = time()

# setup the problem
factored = oat.rust.FactoredBoundaryMatrixVr( # two functions that do this, idk what the other one is
        dissimilarity_matrix=adj,
        homology_dimension_max=2
    )

# solve homology
homology = factored.homology( # solve homology
        return_cycle_representatives=True, # These need to be true to be able to make a barcode, makes the problem take ~30% longer (1:30ish)
        return_bounding_chains=True
    )

f'Homology calculation took {time() - start} secs'

'Homology calculation took 7.060551881790161 secs'

In [4]:
# persistance diagram
fig = oat.plot.pd(homology)
fig.update_layout(
        width=600, 
        height=500,
        margin=dict(l=20, r=20, t=20, b=20)
    )
fig.show()

In [5]:
# Barcode diagram
fig = oat.plot.barcode(homology)
fig.update_layout(
        width=1000, 
        height=500,
        margin=dict(l=20, r=20, t=20, b=20)
    )
fig.show()

#### Cycle Rep
Find a cycle rep!

Interesting ones:
| Index | Reason |
|---|---|
| 4623 | Quick 3D Cycle |

In [6]:
homology[homology['dimension'] == 1]

Unnamed: 0_level_0,dimension,birth,death,birth simplex,death simplex,cycle representative,cycle nnz,bounding chain,bounding nnz
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
128,1,1.000000,inf,"[56, 79]",,simplex filtration coefficient 0 [56...,8,,
129,1,1.000000,inf,"[26, 40]",,simplex filtration coefficient 0 [26...,12,,
130,1,0.980198,inf,"[40, 166]",,simplex filtration coefficient 0 [4...,17,,
131,1,0.970297,1.000000,"[22, 63]","[13, 23, 63]",simplex filtration coefficient 0 [22...,15,simplex filtration coefficient 0...,156.0
132,1,0.960396,inf,"[54, 63]",,simplex filtration coefficient 0 [54...,19,,
...,...,...,...,...,...,...,...,...,...
459,1,0.475248,0.495050,"[122, 158]","[10, 122, 158]","simplex filtration coefficient 0 [122,...",5,simplex filtration coefficient 0 ...,3.0
460,1,0.475248,0.544554,"[15, 158]","[50, 77, 159]","simplex filtration coefficient 0 [15, 1...",6,simplex filtration coefficient 0 ...,6.0
461,1,0.475248,0.742574,"[15, 77]","[15, 50, 116]","simplex filtration coefficient 0 [15, ...",10,simplex filtration coefficient 0 ...,38.0
462,1,0.455446,0.584158,"[112, 152]","[27, 73, 132]","simplex filtration coefficient 0 [112,...",8,simplex filtration coefficient 0 ...,6.0


In [7]:
## Representative 2D Hole
# index of cycle to optimize in homology dataframe
i = 461 # i think this is a cool one

# optimization problem
start = time()
optimal = factored.optimize_cycle(
        birth_simplex=homology['birth simplex'][i], 
        problem_type='preserve PH basis'
    )
print(f'Optimizaiton took {time() - start} secs')

dirty_optimal = optimal.loc['optimal cycle', 'chain'] # round all the dumb and bad coefficicents
clean_optimal = dirty_optimal[round(dirty_optimal['coefficient'].astype(float)) != 0]
print(f'Removing {len(dirty_optimal)-len(clean_optimal)}/{len(dirty_optimal)} degenerate simplexes, {len(clean_optimal)} simplicies left')

# rep_cycle_nodes = list(np.array(G.nodes)[pd.to_numeric(clean_optimal['simplex'].explode().drop_duplicates()).tolist()])

unordered_edges = clean_optimal['simplex'].tolist()
rep_cycle_nodes = [unordered_edges.pop()[1]]
while len(unordered_edges) > 0:
    n = rep_cycle_nodes[-1]
    for i, e in enumerate(unordered_edges):
        if n in e:
            break
    if e[-1] == n:
        e.reverse()
    rep_cycle_nodes.append(e[1])
    unordered_edges.remove(e)
rep_cycle_nodes = list(np.array(G.nodes)[rep_cycle_nodes])

rep_cycle_nodes

Optimizaiton took 0.009147882461547852 secs
Finished construcing L1 optimization program.

Removing 0/9 degenerate simplexes, 9 simplicies left
Constraint matrix has 42 nonzero entries.
Passing program to solver.

Done solving.
MINILP solution: Solution { direction: Minimize, num_vars: 47, num_constraints: 70, objective: 3.920792079207921 }


['linear partial differential equation',
 'difference equations',
 'value problem',
 'calculus of variation',
 'continuous time system',
 'linear stochastic system',
 'fokker planck equation',
 'optimal control system',
 'class of problem']

In [8]:
start = time()
optimal = factored.optimize_bounding_chain(
        birth_simplex=homology['birth simplex'][i], 
    )
print(f'Optimizaiton took {time() - start} secs')

dirty_optimal = optimal.loc['optimal bounding chain', 'chain']
clean_optimal = dirty_optimal[round(dirty_optimal['coefficient'].astype(float)) != 0]
print(f'Removing {len(dirty_optimal)-len(clean_optimal)}/{len(dirty_optimal)} degenerate simplexes, {len(clean_optimal)} simplicies left')

bounding_chain_nodes = list(np.array(G.nodes)[pd.to_numeric(clean_optimal['simplex'].explode().drop_duplicates()).tolist()])
bounding_chain_nodes


Finished construcing L1 optimization program.
Constraint matrix has 31123 nonzero entries.
Passing program to solver.
Optimizaiton took 8.059895992279053 secs
Removing 427/430 degenerate simplexes, 3 simplicies left

Done solving.
MINILP solution: Solution { direction: Minimize, num_vars: 6868, num_constraints: 7050, objective: 2.7227722772277008 }
max difference in boundaries: Some(Ratio { numer: 134169738898737676, denom: 3042564363923662609 })


['deep neural network',
 'recurrent neural network',
 'variational inequalities',
 'variational principle']

In [None]:
np.random.seed(10)
MIN_YEAR = 1920
MAX_YEAR = 2021
denorm_year = lambda ny: int(ny * (MAX_YEAR-MIN_YEAR) + MIN_YEAR)

diff_nodes = list(set(bounding_chain_nodes) - set(rep_cycle_nodes)) # nodes not in cycle that help close it
cycle_G = G.subgraph(set(bounding_chain_nodes+rep_cycle_nodes)) # graph of all nodes in the cycle

# birth = denorm_year(homology.loc[i, 'birth']) # animate between birth and death
birth = min(nx.get_node_attributes(cycle_G, 'year').values())-1
death = denorm_year(homology.loc[i, 'death'])

theta = np.linspace(0, 2*np.pi, len(rep_cycle_nodes)+1)[:-1]
x = np.cos(theta)
y = np.sin(theta)

pos = dict()
for i, n in enumerate(rep_cycle_nodes):
    pos[n] = np.array([x[i], y[i]])
for n in diff_nodes:
    pos[n] = np.random.normal(0, 0.1, size=2)
if len(diff_nodes) > 0:
    pos = nx.spring_layout(cycle_G, k=1, pos=pos, fixed=rep_cycle_nodes)

def viz_graph(G, yr):
    # edge locations
    e_x = [] # edge x
    e_y = [] # edge y
    for e in G.edges:
        if cycle_G.edges[e]['year'] <= yr:
            u, v = e # edge goes from u to v
            u_x, u_y = pos[u] # u position
            v_x, v_y = pos[v] # v position
            e_x += [u_x, v_x, None]
            e_y += [u_y, v_y, None]

    edge_trace = go.Scatter(
            x=e_x, y=e_y,
            hoverinfo='none',
            mode='lines',
            line=dict(width=5, color='#888')
        )

    # node locations
    n_x = [] # node x
    n_y = [] # node y
    n_t = [] # node test
    for n in G.nodes:
        if cycle_G.nodes[n]['year'] <= yr:
            x, y = pos[n]
            n_x.append(x)
            n_y.append(y)
            n_t.append(n.title())
        
    node_trace = go.Scatter(
            x=n_x, y=n_y,
            hoverinfo='none',
            mode='markers+text',
            text=n_t,
            marker=dict(
                    size=25,
                    line_width=2
                )
        )
    
    return edge_trace, node_trace

# Add traces
frames = []
for yr in range(birth, death+1):
    edge_trace, node_trace = viz_graph(cycle_G, yr) # viz objects

    # add as frame
    frames.append(go.Frame(
            data=[edge_trace, node_trace],
            name=yr
        ))
    if yr == birth:
        edge_0_trace, node_0_trace = edge_trace, node_trace

# create figure
fig = go.Figure(data=[edge_0_trace, node_0_trace], frames=frames)

## the rest is coped from the plotly documentation example on mri volume slices
def frame_args(duration):
    return {
            'frame': {'duration': duration},
            'mode': 'immediate',
            'fromcurrent': True,
            'transition': {'duration': 0, 'easing': 'linear'},
        }
fig.update_layout(
        showlegend=False,
        width=500, 
        height=550,
        margin=dict(l=20, r=20, t=20, b=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-1.2,1.2]),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-1.2,1.2]),
        updatemenus = [dict(
                buttons=[
                        dict(
                                args=[None, frame_args(200)],
                                label='&#9654;', # play symbol
                                method='animate'
                            ),
                        dict(
                                args=[None, frame_args(0)],
                                label='&#9724;', # play symbol
                                method='animate'
                            )
                    ],
                direction='left',
                pad=dict(l=0, r=0, t=10, b=10),
                type='buttons',
                x=0.1,
                y=0
            )],
        sliders=[
                dict(
                        pad=dict(l=15, r=0, t=10, b=10),
                        len=0.9,
                        x=0.1,
                        y=0,
                        steps=[dict(
                                args=[[f.name], frame_args(0)],
                                label=f.name,
                                method='animate'
                            ) for k, f in enumerate(fig.frames)],
                    )
            ]
)
fig.show()

: 