In [None]:
from notebook_prelude import *
import experiments

In [None]:
# Gather the distinct node labels of every cached concept-map dataset,
# producing one row per (dataset, label) pair.
data = collections.defaultdict(list)
cmap_cache_files = dataset_helper.get_all_cached_graph_datasets(graph_type=TYPE_CONCEPT_MAP)
for file in helper.log_progress(cmap_cache_files):
    dataset = filename_utils.get_dataset_from_filename(file)
    X, Y = dataset_helper.get_dataset_cached(file)
    X = graph_helper.get_graphs_only(X)

    # A set removes labels that occur more than once within the dataset.
    all_labels = set(graph_helper.get_all_node_labels(X))
    data['dataset'].extend([dataset] * len(all_labels))
    data['labels'].extend(str(label) for label in all_labels)

# num_words is the number of whitespace-separated tokens in each label.
df = pd.DataFrame(data)
df['num_words'] = df['labels'].str.split().apply(len)
df = df.set_index(['dataset', 'labels'])

In [None]:
# One histogram of words-per-label for each dataset, laid out in a two-column grid.
num_datasets = len(df.reset_index().dataset.unique())
fig, axes = plt.subplots(ncols=2, nrows=int(np.ceil(num_datasets / 2)), sharex=False)

for ax, (dataset, df_) in zip(axes.flatten()[:num_datasets], df.groupby('dataset')):
    df_.reset_index().set_index('labels').num_words.plot(kind='hist', bins=50, ax=ax, title=dataset)
    # Pass a real boolean: a non-empty string such as 'off' is truthy, so
    # matplotlib releases that do not special-case it would turn the grid ON.
    ax.grid(False)
    max_words = int(df_.num_words.max())
    if max_words < 20:
        # For small ranges show one tick per word count. range() excludes its
        # stop value, hence the +1 so the maximum itself also gets a tick.
        labels = list(range(1, max_words + 1))
        ax.set_xticks(labels)
        ax.set_xticklabels(labels)

fig.tight_layout()

In [None]:
# Summary statistics of the words-per-label count, one row per dataset.
df.groupby('dataset')['num_words'].describe()

## Single word distribution

In [None]:
import nltk
from nltk.corpus import stopwords
# NLTK's English stopword list plus a few extra tokens to ignore.
# Note: this rebinds the imported `stopwords` name to the resulting set.
stopwords = set(stopwords.words('english')).union({',', 'one', 'two'})

In [None]:
# For each dataset, collect the individual words of every label that
# consists of more than one word.
all_labels_splitted = collections.defaultdict(list)
flat_labels = df.reset_index()
multi_word_rows = flat_labels[flat_labels['num_words'] > 1]
for dataset, label in zip(multi_word_rows['dataset'], multi_word_rows['labels']):
    all_labels_splitted[dataset].extend(label.split())

In [None]:
# Count how often each word occurs per dataset and flatten the counts into
# parallel columns (dataset, label, occurrences).
data = collections.defaultdict(list)
for dataset, single_words in all_labels_splitted.items():
    word_counts = collections.Counter(single_words)
    data['dataset'].extend([dataset] * len(word_counts))
    keys, vals = zip(*word_counts.items())
    data['label'].extend(keys)
    data['occurrences'].extend(vals)

# Most frequent words first; a second frame drops every word in `stopwords`.
df_single_word_count = pd.DataFrame(data).sort_values('occurrences', ascending=False)
is_stopword = df_single_word_count['label'].isin(stopwords)
df_single_word_count_no_stopwords = df_single_word_count[~is_stopword]

In [None]:
# Per-dataset histogram of word occurrence counts (stopwords excluded),
# drawn on a two-column grid with logarithmic y axes.
num_datasets = len(df_single_word_count_no_stopwords.dataset.unique())
num_rows = int(np.ceil(num_datasets / 2))
fig, axes = plt.subplots(ncols=2, nrows=num_rows)
flat_axes = axes.flatten()
ax = df_single_word_count_no_stopwords.hist(
    column='occurrences',
    by='dataset',
    bins=120,
    log=True,
    ax=flat_axes[:num_datasets],
)
# Style every panel, including a trailing one that may have no data.
for panel in flat_axes:
    panel.set_yscale('log')
    panel.grid(True, which="both", ls="-")
fig.tight_layout()

In [None]:
from preprocessing import preprocessing

# Flatten the (dataset, labels) index into ordinary columns, then store the
# result of preprocessing.preprocess applied to each label in 'label_clean'.
df_ = df.reset_index()
df_['label_clean'] = df_['labels'].apply(preprocessing.preprocess)

In [None]:
# Show the 'ng20' rows whose label has more than 10 words.
is_long_ng20_label = (df_['dataset'] == 'ng20') & (df_['num_words'] > 10)
df_[is_long_ng20_label]

## Splitting labels into new nodes

In [None]:
from transformers.graph_multi_word_label_splitter import GraphMultiWordLabelSplitter

# Load the first cached graph dataset that matches the chosen name and graph type.
dataset = 'review_polarity'
graph_type = TYPE_CONCEPT_MAP
#graph_type = TYPE_COOCCURRENCE
matching_cache_files = dataset_helper.get_all_cached_graph_datasets(
    dataset_name=dataset,
    graph_type=graph_type,
)
cmap_cache_file = matching_cache_files[0]
X, Y = dataset_helper.get_dataset_cached(cmap_cache_file)
# Keep only the graphs from the loaded X.
X = graph_helper.get_graphs_only(X)

In [None]:
# Run the multi-word label splitter over all graphs; X_ holds the transformed
# result and is compared with X in the plot below.
trans = GraphMultiWordLabelSplitter(
    add_self_links=False,
    copy=True,
)
X_ = trans.transform(X)

In [None]:
figsize = (20, 10)
# Choose a random graph with fewer than 10 nodes so the drawing stays readable.
candidates = [position for position, candidate in enumerate(X) if len(candidate.nodes()) < 10]
idx = np.random.choice(candidates)
print(idx)
graph, graph_ = X[idx], X_[idx]

# Draw the original graph and its transformed counterpart side by side.
fig, axes = plt.subplots(ncols=2, figsize=figsize)
panels = zip(axes, (graph, graph_), ('Before', 'After'))
for ax, g, title in panels:
    pos = nx.layout.shell_layout(g)
    nx.draw_networkx(g, pos=pos, node_size=3, ax=ax)
    # Label each edge with its 'name' attribute, keyed by (source, target).
    edges_ = {
        (source, node): attrs['name']
        for source, node, attrs in g.edges(data=True)
    }
    nx.draw_networkx_edge_labels(g, pos=pos, edge_labels=edges_, ax=ax)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_title(title)
fig.suptitle('Splitting multi-word node labels')
fig.tight_layout()
# Leave room at the top so the suptitle does not overlap the panels.
fig.subplots_adjust(top=0.92)