In [1]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

import pickle

from random import sample, seed
from tqdm import tqdm
from utils.propagators.directed import PropagateDirected
from utils.metrics import accuracy

seed(42)

In [2]:
# Load Graph.
G = nx.read_gexf("../../state_files/PyPi Network V4.gexf")
# Keep only giant component.
gc_nodes = sorted(nx.connected_components(G.to_undirected()), key = lambda x: len(x), reverse=True)[0]
not_gc_nodes = set(G.nodes()) - gc_nodes
G.remove_nodes_from(not_gc_nodes)

# Load labels.
with open("../../state_files/PyPi Dataframe V4.pickle", "rb") as f:
    df = pickle.load(f)
labels = df["Topic"].dropna().to_dict()
# Remove 'Other' topic.
for n, l in labels.items():
    if 'Other/Nonlisted Topic' in l:
        l.remove('Other/Nonlisted Topic')
# Filter nodes with no topics.
labels = dict(filter(lambda x: len(x[1]) > 0, labels.items()))
# Keep nodes from giant component.
labels = dict(filter(lambda x: x[0] in gc_nodes, labels.items()))

# Split train and test set.
test_size = 0.1

test_nodes = sample(list(labels.keys()), int(test_size*len(labels.keys())))
train_nodes = list(filter(lambda x: x not in test_nodes, labels.keys()))

train_labels = dict(filter(lambda x: x[0] in train_nodes, labels.items()))
test_labels = dict(filter(lambda x: x[0] in test_nodes, labels.items()))

In [None]:
pl = PropagateDirected(G, train_labels, method = "global")
final_labels = pl.propagate_all()
acc = accuracy(test_labels, final_labels)

100%|█████████████████████████████████████████████████████████████████████████████| 54335/54335 [25:15<00:00, 35.86it/s]
  1%|█▏                                                                            | 1426/96960 [00:32<31:57, 49.83it/s]

In [16]:
from collections import Counter

label_collector = []
for n, l in train_labels.items():
    label_collector += l

# Pesos globales -> El peso para elegir cada tópico va a ser peso local / peso global.
# Esto debería calcularlo el Propagador en su clase Padre.
train_fracs = {t: v/len(label_collector) for t, v in Counter(label_collector).items()}
train_fracs = dict(sorted(train_fracs.items(), key = lambda x: x[1], reverse=True))

In [17]:
def balanced_accuracy(test_labels: dict, pred_labels: dict) -> float:
    """
    
    """
    
    # Get topics and percentage of each in test set.
    label_collector = []
    for n, l in test_labels.items():
        label_collector += l

    topics = set(label_collector)
    test_fracs = {t: v/len(label_collector) for t, v in Counter(label_collector).items()}
    test_fracs = dict(sorted(test_fracs.items(), key = lambda x: x[1], reverse=True))
    
    # Init dicts.
    times_seen = {t: 0 for t in topics}
    matches = {t: 0 for t in topics}
    
    # Find nodes that were predicted.
    found_test_labels = {n: l for n, l in test_labels.items() if n in pred_labels.keys()}

    for n, l in found_test_labels.items():
        if not isinstance(l, list):
            raise Exception("No debería de haber etiquetas de testeo que no sean listas.")          

        # Revisamos que la la etiqueta encontrada esté en la lista de las de testeo.
        if isinstance(pred_labels[n], str):
            for _l in l:
                times_seen[_l] += 1
            if pred_labels[n] in l: 
                matches[pred_labels[n]] += 1

        # En caso de que sea una lista la encontrada, nos fijamos que haya una coincidencia por lo menos.
        elif isinstance(pred_labels[n], list):
            for _l in l:
                times_seen[_l] += 1
            if pred_labels[n] in l: 
                matches[pred_labels[n]] += 1

        #No debería de haber etiquetas propagadas que no sean str o list.
        else:
            raise Exception("No debería de haber etiquetas finales que no sean lista o str.")
    
    accuracy_per_topic = {t: matches[t]/times_seen[t] for t in topics}
    bal_acc = sum(accuracy_per_topic.values())/len(accuracy_per_topic.values())
    return bal_acc, times_seen
    
acc_x_t, ts = balanced_accuracy(test_labels, final_labels)

In [18]:
print("Balanced Accuracy: Modelo de propagación aleatoria.")
print(f"{acc_x_t*100:.0f}%")

print("Balanced Accuracy: Modelo de tirar un dado de 23 caras")
label_collector = []
for n, l in test_labels.items():
    label_collector += l

topics = set(label_collector)
matches = {t: int(ts[t]/len(ts)) for t in topics}
accuracy_per_topic = {t: matches[t]/ts[t] for t in topics}
print(f"{sum(accuracy_per_topic.values())/len(accuracy_per_topic.values())*100:.0f}%")

print("Balanced Accuracy: Modelo de etiquetar TODO con la etiqueta mas común")
print(f'{1 * (ts["Software Development"]/sum(ts.values())) / len(ts)*100:.0f}%')

Balanced Accuracy: Modelo de propagación aleatoria.
13%
Balanced Accuracy: Modelo de tirar un dado de 23 caras
3%
Balanced Accuracy: Modelo de etiquetar TODO con la etiqueta mas común
2%
