In [1]:
%load_ext autoreload
%env CUDA_VISIBLE_DEVICES=0
import sys
sys.path.append('..')

env: CUDA_VISIBLE_DEVICES=0


In [2]:
%autoreload
from graph_augmented_pt.utils.tensorboard_utils import *

from graphlet_atlas import *
from synthetic_datasets import *
from synthetic_runner import *
from simplicial_manfiolds import *

import matplotlib, pandas as pd
pd.options.mode.chained_assignment = None 
%matplotlib inline

from collections import Counter, defaultdict
import copy, itertools, json, logging, math, os, pickle, random, scipy, shutil, time, numpy as np
from tqdm.notebook import tqdm
from scipy.stats import spearmanr

from pathlib import Path
from IPython.display import Image 

In [3]:
def nCr(n, r):
    """Return the binomial coefficient "n choose r" as an exact integer.

    Uses ``math.comb`` (exact integer arithmetic). Dividing factorials with ``/``
    goes through floats: the result can be inexact once it exceeds 2**53, and it
    raises ``OverflowError`` when an intermediate quotient exceeds the float range
    (e.g. ``n=200, r=3``).

    Returns 0 when ``r > n``; raises ``ValueError`` for negative arguments.
    """
    return math.comb(n, r)

# Generating Synthetic Data

In [139]:
RAW_FILEPATH = '/crimea/graph_augmented_pt/synthetic_datasets/wikisent2.txt'
PKL_FILEPATH = '/crimea/graph_augmented_pt/synthetic_datasets/wikisent2_feat.pkl'
SIMPLEX_ORDER_FILEPATH = '/crimea/graph_augmented_pt/synthetic_datasets/topics_to_try.pkl'
TOPIC_CLIQUE_FILEPATH = '/crimea/graph_augmented_pt/synthetic_datasets/topic_clique.json'

# Inputs this cell cannot run without -- fail early with a readable message.
for required_path in (RAW_FILEPATH, PKL_FILEPATH, SIMPLEX_ORDER_FILEPATH):
    assert os.path.isfile(required_path), f"Missing required input file: {required_path}"

# Source: https://www.kaggle.com/mikeortman/wikipedia-sentences
# `readlines` keeps each line's trailing newline.
with open(RAW_FILEPATH, mode='r', encoding='utf-8') as f: text_data = f.readlines()

# NOTE(review): pickle.load can execute arbitrary code -- only load pickles we produced ourselves.
with open(PKL_FILEPATH, mode='rb') as f:
    X, LDA, topics, first_topic, sents_by_topic, topic_correlations = pickle.load(f)
n_sents_by_topic = {t: len(sents) for t, sents in sents_by_topic.items()}

with open(SIMPLEX_ORDER_FILEPATH, mode='rb') as f:
    topics_to_try = pickle.load(f)

# TOPIC_CLIQUE_FILEPATH is *written* by a later cell (after the clique search), so on a
# fresh setup it does not exist yet; fall back to None instead of breaking Restart & Run All.
if os.path.isfile(TOPIC_CLIQUE_FILEPATH):
    with open(TOPIC_CLIQUE_FILEPATH, mode='r') as f:
        topic_clique_hint = json.load(f)
else:
    topic_clique_hint = None

Exception ignored in: <function tqdm.__del__ at 0x7f4dafc060d0>
Traceback (most recent call last):
  File "/crimea/conda_envs/graph_augmented_pt_2/lib/python3.8/site-packages/tqdm/std.py", line 1121, in __del__
    def __del__(self):
KeyboardInterrupt: 


##### A third, different approach

In [132]:
def expand(tset, valid_simplices, base_ts=np.arange(100)):
    """Find every topic that could be added to ``tset`` while keeping it a clique.

    A candidate qualifies when, for *every* pair of topics already in ``tset``, the
    triple (candidate, pair) is present in ``valid_simplices``. If ``tset`` has fewer
    than two members there are no pairs to check, so every candidate qualifies.

    Args:
        tset: Iterable of topic ids forming the current (partial) clique.
        valid_simplices: Container of ``frozenset`` triples of topic ids.
        base_ts: Universe of candidate topic ids (defaults to 0..99).

    Returns:
        A ``set`` of ids from ``base_ts``, not already in ``tset``, that can extend it.
    """
    members = list(tset)
    # The pairs of the current set do not depend on the candidate, so build them once.
    member_pairs = list(itertools.combinations(members, 2))
    candidates = set(base_ts).difference(members)

    return {
        candidate
        for candidate in candidates
        if all(frozenset((candidate, a, b)) in valid_simplices for a, b in member_pairs)
    }
        
def maximally_expand(tset, valid_simplices, base_ts=np.arange(100), depth=0, memoization_dict=None):
    """Return the largest clique containing ``tset``, found by exhaustive depth-first expansion.

    Every topic accepted by ``expand`` is tried in turn and the search recurses on the
    enlarged set; the biggest set reached wins. When ``tset`` is itself a clique (all of
    its triples are in ``valid_simplices``), every set reached is one too, because
    ``expand`` only admits topics forming a valid triple with every pair already present.

    Args:
        tset: Iterable of topic ids to grow (typically a 3-element simplex).
        valid_simplices: Container of ``frozenset`` topic triples considered valid.
        base_ts: Universe of candidate topic ids (defaults to 0..99). Forwarded to every
            recursive call so all levels search the same universe.
        depth: Current recursion depth; only decides whether a tqdm progress bar is shown.
        memoization_dict: Optional dict ``frozenset(tset) -> result``, read and updated in
            place. Keys ignore ``valid_simplices``/``base_ts``, so only share one dict
            across calls that use the same values for both.

    Returns:
        A ``frozenset`` of topic ids (``frozenset(tset)`` itself if nothing can be added).
    """
    if memoization_dict is None: memoization_dict = {}

    key = frozenset(tset)
    if key in memoization_dict: return memoization_dict[key]

    ts = expand(key, valid_simplices, base_ts=base_ts)

    max_opt = key
    # Only show a progress bar near the top of the recursion and for wide fan-outs.
    ts_rng = ts if (depth > 2 or len(ts) < 10) else tqdm(ts, leave=False, desc="Expanding")
    for t in ts_rng:
        query_tset = frozenset([t, *key])
        # The recursive call memoizes its own result under `query_tset`.
        new_tset = maximally_expand(
            query_tset, valid_simplices, base_ts=base_ts, depth=depth+1,
            memoization_dict=memoization_dict,
        )
        if len(new_tset) > len(max_opt): max_opt = new_tset

    # Memoize this level too (including the "nothing to add" case), so repeated
    # top-level queries for the same set are answered immediately.
    memoization_dict[key] = max_opt
    return max_opt

In [133]:
# Sanity check on a toy hypergraph: every triple of {0..5} is valid (a 6-clique), plus a
# few extra triples touching 7, 8, 9 that cannot extend it.
fake_valid_simplices = set(
    frozenset(s) for s in list(itertools.combinations(np.arange(6), 3)) + [(0, 1, 7), (7, 8, 9), (4, 8, 9)]
)

fake_valid_simplices_list = list(fake_valid_simplices)
# Seeded local RNG so the visiting order is reproducible (and global `random` state is untouched).
random.Random(0).shuffle(fake_valid_simplices_list)

o = []
for s in fake_valid_simplices_list:
    m = maximally_expand(s, fake_valid_simplices)
    if len(m) > len(o): o = m

print(o)
assert o == frozenset(range(6)), f"Expected the 6-clique {{0..5}}, got {o}"

frozenset({0, 1, 2, 3, 4, 5})


In [135]:
topics_cp               = copy.deepcopy(topics)
min_sents_per_simplex   = 25    # keep a topic triple (simplex) only if at least this many sentences map to it
topics_thresh           = 2/3   # a sentence's top-3 topics must carry more than this much probability mass

N        = len(topics_cp)
n_topics = topics.shape[1]      # number of topics (columns of `topics`), used for the "total possible" count

# Top-3 topics per sentence via repeated argmax. The first and second winners are zeroed in the
# copy, so afterwards `topics_cp` holds *residual* weights; below it is only used for row counts.
# NOTE(review): if a row has fewer than 3 non-zero weights, argmax can return an index again and
# that sentence's "simplex" collapses to fewer than 3 distinct topics.
top_3_st = time.time()
first_topics  = np.argmax(topics_cp, axis=1)
topics_cp[np.arange(N), first_topics] = 0
second_topics = np.argmax(topics_cp, axis=1)
topics_cp[np.arange(N), second_topics] = 0
third_topics = np.argmax(topics_cp, axis=1)
top_3_end = time.time()

# Keep only sentences whose top-3 topics account for enough probability mass (read from the
# untouched `topics`). `valid_sents_idx` keeps, for each surviving row, its row index in `topics`.
reindex_st = time.time()
top_3 = np.vstack((first_topics, second_topics, third_topics)).T
top_3_probs = topics[np.arange(N), [first_topics, second_topics, third_topics]].T

obs_probability_mass = top_3_probs.sum(axis=1)
valid_sents_mask = (obs_probability_mass > topics_thresh)
valid_sents_idx,  = np.where(valid_sents_mask)

topics_cp   = topics_cp[valid_sents_mask]
top_3       = top_3[valid_sents_mask]
top_3_probs = top_3_probs[valid_sents_mask]
reindex_end = time.time()

print(
    f"Dropping {len(topics) - len(top_3)} sentences as they lack sufficient probability mass in their top-3.\n"
    f"It took {(reindex_end - reindex_st)/60:.1f} minutes to do that check & drop."
)

cnt_assignments_st = time.time()
all_observed_topic_simplices = Counter(frozenset(t) for t in top_3)
cnt_assignments_end = time.time()

print(
    f"Observe {len(all_observed_topic_simplices)} simplices (of {nCr(n_topics, 3)} total possible) "
    f"in total across {len(topics_cp)} sentences.\n"
    f"It took {(top_3_end - top_3_st)/60:.1f} minutes to get the top 3 topics / sent and "
    f"{(cnt_assignments_end - cnt_assignments_st)/60:.1f} minutes to get the counts."
)

first_filtering_st      = time.time()
valid_simplices         = set(k for k, v in all_observed_topic_simplices.items() if v >= min_sents_per_simplex)


def valid_simplex_checker(np_arr):
    """Boolean mask, True for each row of `np_arr` whose topic set is in the global `valid_simplices`.

    `valid_simplices` is looked up at call time, so later rebinding of it is respected.
    Always returns a bool array (also for zero rows, where `np.array([])` would be float).
    """
    return np.fromiter(
        (frozenset(row) in valid_simplices for row in np_arr), dtype=bool, count=len(np_arr)
    )


sufficiently_dense_mask = valid_simplex_checker(top_3)

topics_cp               = topics_cp[sufficiently_dense_mask]
top_3                   = top_3[sufficiently_dense_mask]
top_3_probs             = top_3_probs[sufficiently_dense_mask]
valid_sents_idx         = valid_sents_idx[sufficiently_dense_mask]
first_filtering_end     = time.time()

print(
    f"After filtering out insufficiently dense simplices, we have {len(valid_simplices)}/{len(topics_cp)} "
    f"simplices / sentences, respectively. This process took "
    f"{(first_filtering_end - first_filtering_st)/60:.1f} minutes"
)

# Search state read and updated by the clique-search cell below.
global_selection = []
containing_maximal_cliques = {}

Dropping 2178743 sentences as they lack sufficient probability mass in their top-3.
It took 0.0 minutes to do that check & drop.
Observe 158547 simplices (of 161700 total possible) in total across 5693082 sentences.
It took 0.0 minutes to get the top 3 topics / sent and 0.2 minutes to get the counts.
After filtering out insufficiently dense simplices, we have 46965/4645898 simplices / sentences, respectively. This process took 0.2 minutes


In [None]:
to_complete_subgraph_st = time.time()
valid_simplices_list = list(valid_simplices)
# Seeded local RNG: the visiting order decides which of several equally large cliques is kept,
# so fix it for reproducibility (without touching global `random` state).
random.Random(0).shuffle(valid_simplices_list)

valid_simplices_rng = tqdm(valid_simplices_list, desc="Complete Subgraph: 3 so far")
for tset in valid_simplices_rng:
    max_expansion = maximally_expand(
        tset, valid_simplices, memoization_dict=containing_maximal_cliques,
    )

    if len(max_expansion) <= len(global_selection): continue

    global_selection = max_expansion
    valid_simplices_rng.set_description(f"Complete Subgraph: {len(global_selection)} so far")

    # Given we found a new optimum, re-expand each of its proper sub-cliques (size >= 3): a
    # different clique sharing most of its members may be larger. Repeat until none is.
    # NOTE(review): this enumerates every subset of size 3..k-1 of a size-k clique (~2**k calls).
    fresh_optima = True
    while fresh_optima:
        fresh_optima = False
        for subclique_size in range(len(global_selection)-1, 2, -1):
            for sub_tset in itertools.combinations(global_selection, subclique_size):
                m = maximally_expand(sub_tset, valid_simplices, memoization_dict=containing_maximal_cliques)
                if len(m) > len(global_selection):
                    global_selection = m
                    fresh_optima = True
                    valid_simplices_rng.set_description(f"Complete Subgraph: {len(global_selection)} so far")

# Restrict to simplices inside the chosen clique; every triple of it must be present.
valid_topics = global_selection
valid_simplices = {t_set for t_set in valid_simplices if t_set.issubset(valid_topics)}
assert len(valid_simplices) == nCr(len(valid_topics), 3)

# Always a bool array (np.array([]) of an empty list would be float and break boolean indexing).
simplex_valid_mask = np.fromiter(
    (frozenset(row) in valid_simplices for row in top_3), dtype=bool, count=len(top_3)
)

topics_cp                = topics_cp[simplex_valid_mask]
top_3                    = top_3[simplex_valid_mask]
top_3_probs              = top_3_probs[simplex_valid_mask]
valid_sents_idx          = valid_sents_idx[simplex_valid_mask]
to_complete_subgraph_end = time.time()

print(
    "After filtering out simplices that are not universally compatible, we have "
    f"{len(valid_simplices)}/{len(topics_cp)} "
    f"simplices / sentences, respectively. This process took "
    f"{(to_complete_subgraph_end - to_complete_subgraph_st)/60:.1f} minutes"
)


normalization_st = time.time()
# Renormalize each sentence's top-3 probabilities to sum to 1, then take the Shannon entropy
# (natural log). Zero probabilities contribute 0 (the 0 * log 0 -> 0 convention) instead of NaN.
normalized     = top_3_probs / top_3_probs.sum(axis=1, keepdims=True)
log_normalized = np.log(normalized, out=np.zeros_like(normalized), where=normalized > 0)
entropy        = -(normalized * log_normalized).sum(axis=1)

entropy_per_simplex = defaultdict(list)
for e, ts in zip(entropy, top_3): entropy_per_simplex[frozenset(ts)].append(e)

# Per simplex: (min entropy, max entropy, np.histogram result (counts, bin_edges), default 10 bins).
agg_entropy_per_simplex = {
    k: (np.min(es), np.max(es), np.histogram(es)) for k, es in entropy_per_simplex.items()
}
normalization_end = time.time()

print(f"Normalizing & computing entropy took {(normalization_end - normalization_st)/60:.1f} minutes")

In [140]:
# Persist the selected topic clique as a sorted list of plain ints (numpy ints are not JSON
# serializable). Sorting makes the file deterministic: `global_selection` is typically a set,
# whose iteration order is arbitrary. The loading cell reads this back into `topic_clique_hint`.
with open(TOPIC_CLIQUE_FILEPATH, mode='w') as f:
    json.dump(sorted(int(t) for t in global_selection), f)

In [150]:
# Persist the memo table built by `maximally_expand` (frozenset of topics -> largest clique
# found containing it) so the expensive search can be reused in later sessions.
memo_path = Path('/crimea/graph_augmented_pt/synthetic_datasets/topic_containing_maximal_cliques.pkl')
with memo_path.open(mode='wb') as memo_file:
    pickle.dump(containing_maximal_cliques, memo_file)