In [None]:
import findspark
findspark.init('/usr/lib/spark2')
import wmfdata.spark as wmfspark

import pyspark
import re
import pyspark.sql
from pyspark.sql import *
import pandas as pd
import glob
import matplotlib.pyplot as plt
import hashlib
import random
import os.path
import shutil
import json
from pyspark.sql.functions import *
from datetime import timedelta, date
import uuid
from pyspark.sql.types import *
pd.set_option('display.float_format', lambda x: '%.5f' % x)
import copy

In [None]:
# Seed Python's global RNG so any later use of `random` is reproducible.
SEED = 42
random.seed(SEED)

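## Link Prediction Task

For each language edition, candidate links are labelled using `{lang}wiki_new_links_2021-04.csv` together with the logged navigation paths (`real_nav`). The candidates are ranked, separately for each path type, by the share of a source's paths that lead to the target, and each ranking is scored with precision@k.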

In [None]:
# Wikipedia language editions to evaluate.
langlist = ['en', 'ru', 'ja', 'de', 'fr', 'it', 'pl', 'fa']
# Path sources to compare; 'real_nav' (logged navigation) is also used to label links.
path_types = [
    'real_nav',
    'gen_clickstream_private',
    'gen_clickstream_public',
    'gen_graph',
]

# Inputs are read from <parent of cwd>/data; outputs are written under <parent of cwd>.
root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
PATH_OUT = root_dir
PATH_IN = os.path.join(PATH_OUT, 'data')

In [None]:
def get_s_counts(fname):
    """Read per-source path counts from a CSV file.

    The first line is treated as a header and skipped. Blank lines (e.g. a
    trailing empty line) are ignored. Every other row must start with
    ``<src_id>,<count>``; any further columns are ignored.

    Args:
        fname: Path to the CSV file.

    Returns:
        dict mapping source id (int) -> count (int). If a source id occurs
        more than once, the last row wins.
    """
    s_count = {}
    # `with` guarantees the handle is closed; the original leaked it.
    with open(fname) as fh:
        fh.readline()  # skip header row
        for line in fh:
            fields = line.strip().split(',')
            if fields == ['']:
                continue  # blank line: int('') would raise ValueError
            s_count[int(fields[0])] = int(fields[1])
    return s_count

In [None]:
def get_st_counts(fname):
    """Read per-(source, target) path counts from a CSV file.

    The first line is treated as a header and skipped. Blank lines (e.g. a
    trailing empty line) are ignored. Every other row must start with
    ``<src_id>,<trgt_id>,<count>``; any further columns are ignored.

    Args:
        fname: Path to the CSV file.

    Returns:
        dict mapping (src_id, trgt_id) tuple of ints -> count (int). If a pair
        occurs more than once, the last row wins.
    """
    st_count = {}
    # `with` guarantees the handle is closed; the original leaked it.
    with open(fname) as fh:
        fh.readline()  # skip header row
        for line in fh:
            fields = line.strip().split(',')
            if fields == ['']:
                continue  # blank line: int('') would raise ValueError
            st_count[(int(fields[0]), int(fields[1]))] = int(fields[2])
    return st_count

In [None]:
def get_new_links(fname, st_count):
    """Label candidate links as positive (1) or negative (0).

    Positives are links listed in ``fname`` that also appear as keys of
    ``st_count``. ``fname`` is a CSV file with a header row whose first two
    columns are ``src_id,trgt_id``. Negatives are the remaining
    ``(src, trgt)`` keys of ``st_count`` whose source is the source of some
    positive link AND whose target is the target of some positive link.

    Duplicate rows in ``fname`` are counted once, so the returned counts
    always equal the number of 1- and 0-labels in ``new_links``.

    Args:
        fname: Path to the new-links CSV file.
        st_count: dict mapping (src_id, trgt_id) -> count.

    Returns:
        (new_links, num_pos_links, num_neg_links), where new_links maps
        (src_id, trgt_id) -> label. Positives come first in insertion order,
        then negatives in ``st_count`` order.
    """
    new_links = {}
    positive_sources = set()
    positive_targets = set()
    with open(fname) as fh:
        fh.readline()  # skip header row
        for line in fh:
            fields = line.strip().split(',')
            if fields == ['']:
                continue  # tolerate blank (e.g. trailing) lines
            link = (int(fields[0]), int(fields[1]))
            # Skip duplicates so num_pos_links matches the labels actually stored.
            if link in st_count and link not in new_links:
                positive_sources.add(link[0])
                positive_targets.add(link[1])
                new_links[link] = 1
    num_pos_links = len(new_links)

    # Negatives are restricted to pairs whose source and target both take part
    # in at least one positive link.
    for link in st_count:
        src_id, trgt_id = link
        if link not in new_links and src_id in positive_sources and trgt_id in positive_targets:
            new_links[link] = 0
    num_neg_links = len(new_links) - num_pos_links

    print(f'Found {num_pos_links} positive and {num_neg_links} negative links')
    return new_links, num_pos_links, num_neg_links

In [None]:
def get_pst_new_links(new_links, st_count, s_count, s_count_full):
    """Attach transition-probability features to every labelled candidate link.

    For a link (s, t) present in ``st_count``:
        pst     = st_count[(s, t)] / s_count[s]
        pst_all = st_count[(s, t)] / s_count_full[s]
    Links absent from ``st_count`` get 0 for both features.

    Returns:
        dict mapping each key of ``new_links`` -> (pst_all, pst, label),
        in the same order as ``new_links``.
    """
    features = {}
    for link, label in new_links.items():
        if link not in st_count:
            features[link] = (0, 0, label)
            continue
        source = link[0]
        transitions = st_count[link]
        pst = transitions / s_count[source]
        pst_all = transitions / s_count_full[source]
        features[link] = (pst_all, pst, label)
    return features

In [None]:
def precision_at_k(labels):
    """Return the precision at every rank of a ranked list of labels.

    Args:
        labels: labels ordered by predicted rank; an entry equal to 1 is a hit.

    Returns:
        list whose (k-1)-th element is the fraction of hits among the first k labels.
    """
    hits = 0
    precisions = []
    for rank, label in enumerate(labels, start=1):
        hits += label == 1  # bool adds as 0/1
        precisions.append(hits / rank)
    return precisions

In [None]:
# Evaluate link prediction per language. For each path type, the labelled
# candidate links are ranked by pst_all (element 0 of get_pst_new_links'
# tuples), and precision@k is recorded.
K_VALUES = [1, 5, 10, 50, 100, 250, 500, 1000, 5000, 10000, 20000]  # ranks written to the TSV
MAX_RANK = 50000  # precision@k curves are computed for the top MAX_RANK links only
RESULTS_DIR = os.path.join(PATH_OUT, 'downstream_tasks', 'link_prediction_results')
os.makedirs(RESULTS_DIR, exist_ok=True)

# Candidate links and labels are derived from the 'real_nav' counts, so that
# path type must be processed first. Make this explicit instead of relying on
# the order of `path_types`.
ordered_path_types = ['real_nav'] + [p for p in path_types if p != 'real_nav']


def _evaluate_language(lang):
    """Compute precision@k curves for every path type of one language edition.

    Writes ``{lang}wiki_link_stats.tsv`` (#positive / #negative labelled links)
    and ``{lang}wiki_precision_at_k.tsv`` (one row per path type, one column per
    k in K_VALUES) into RESULTS_DIR.

    Returns:
        dict mapping path type -> list of precision@k for k = 1..min(#links, MAX_RANK).
    """
    new_links_fname = os.path.join(PATH_IN, 'graphs', lang, f'{lang}wiki_new_links_2021-04.csv')
    counts_dir = os.path.join(PATH_IN, 'link_prediction', lang)
    stats_fname = os.path.join(RESULTS_DIR, f'{lang}wiki_link_stats.tsv')
    precision_fname = os.path.join(RESULTS_DIR, f'{lang}wiki_precision_at_k.tsv')

    lang_results = {}
    new_links = None
    # `with` closes both output files even if a step below raises.
    with open(stats_fname, "w") as stats_fout, open(precision_fname, "w") as precision_fout:
        stats_fout.write('#PositiveLinks\t#NegativeLinks\n')
        for path_type in ordered_path_types:
            print(f'Paths: {path_type}')
            s_count = get_s_counts(os.path.join(counts_dir, f's_counts_atleast10paths_{path_type}.csv'))
            s_count_full = get_s_counts(os.path.join(counts_dir, f's_counts_{path_type}.csv'))
            st_count = get_st_counts(os.path.join(counts_dir, f'st_counts_atleast10paths_{path_type}.csv'))

            if path_type == 'real_nav':
                # Build the labelled candidate set once; it is reused for all path types.
                new_links, num_pos_links, num_neg_links = get_new_links(new_links_fname, st_count)
                stats_fout.write(f'{num_pos_links}\t{num_neg_links}\n')
                header_ks = [k for k in K_VALUES if k <= len(new_links)]
                precision_fout.write('PathType\t' + ''.join(f'P@{k}\t' for k in header_ks) + '\n')

            print(f'{len(new_links)} labelled candidate links')
            links_pst = get_pst_new_links(new_links, st_count, s_count, s_count_full)

            # Rank by pst_all, descending; sorted() is stable, so ties keep insertion order.
            ranked = sorted(links_pst.values(), key=lambda feats: feats[0], reverse=True)
            p_at_k = precision_at_k([feats[2] for feats in ranked][:MAX_RANK])
            lang_results[path_type] = p_at_k

            precision_fout.write(f'{path_type}\t')
            for k in K_VALUES:
                if k > len(p_at_k):
                    break
                precision_fout.write(f'{p_at_k[k-1]}\t')
                print(f'P@{k} = {p_at_k[k-1]}')
            precision_fout.write('\n')
            precision_fout.flush()
    return lang_results


precision_results = {}
for lang in langlist:
    print(f'{lang}wiki')
    precision_results[lang] = _evaluate_language(lang)

In [None]:
import seaborn as sns
# Apply the colour-blind-friendly palette globally. A bare
# `sns.color_palette('colorblind')` only returns a palette object (which was
# discarded here), so it never changed the colours actually used.
sns.set_theme(style="white", palette="colorblind")
import matplotlib.pyplot as plt

In [None]:
def subplot_plot(ax, lang, char_code, results):
    """Draw the precision@k curves of one language on ``ax`` (log-scaled x axis).

    Args:
        ax: Matplotlib Axes to draw on.
        lang: Wikipedia language code; used in the panel title and to choose
            the title's vertical offset.
        char_code: Panel tag placed before the title, e.g. '(a)'.
        results: dict mapping path type -> list of precision values, where the
            element at index i is precision at rank i + 1. Curves are drawn in
            the dict's order.
    """
    labels = {'real_nav': 'Logs', 'gen_clickstream_private': 'Clickstream-Priv', 'gen_clickstream_public': 'Clickstream-Pub', 'gen_graph': 'Graph'}
    for path_type, precisions in results.items():
        # Plot against ranks 1..n. Plotting by list index would put P@1 at x=0,
        # which a log axis cannot show, and shift every point by one rank.
        ranks = range(1, len(precisions) + 1)
        ax.plot(ranks, precisions, label=labels.get(path_type, path_type), linewidth=5)

    ax.set_xscale('log')
    # Languages in the top row of the figure use a smaller title offset.
    if lang in ['en', 'ru', 'ja', 'de']:
        yval = -0.32
    else:
        yval = -0.45
    ax.set_title(f'{char_code} {lang.upper()}', y=yval, fontsize=24)
    ax.tick_params(axis='both', which='major', labelsize=20)

In [None]:
# Figure: precision@k curves, one panel per language on a 2x4 grid.
fig, axes = plt.subplots(2, 4, figsize=(28,9), gridspec_kw={'hspace': 0.37, 'wspace': 0.1})
axis_map = {'en': (0,0), 'ja': (0,1), 'de': (0,2), 'ru': (0,3), 'fr': (1,0), 'it': (1,1), 'pl': (1,2), 'fa': (1,3)}
char_code = {'en': '(a)', 'ja': '(b)', 'de': '(c)', 'ru': '(d)', 'fr': '(e)', 'it': '(f)', 'pl': '(g)', 'fa': '(h)'}
for lang in langlist:
    idx, idy = axis_map[lang]
    subplot_plot(axes[idx,idy], lang, char_code[lang], precision_results[lang])

# Label only the outer panels: the x label goes on the bottom row, and the
# y label and y tick labels go on the first column. The row/column index comes
# from the grid position because Axes.is_last_row()/is_first_col() were
# deprecated in Matplotlib 3.4 and removed in 3.6.
# NOTE(review): the panels do not share y, so hiding inner y tick labels assumes
# comparable y ranges; consider `sharey=True` in plt.subplots.
nrows, ncols = axes.shape
for row in range(nrows):
    for col in range(ncols):
        ax = axes[row, col]
        if row == nrows - 1:
            ax.set_xlabel('Rank k', fontsize=22)
        if col == 0:
            ax.set_ylabel('Precision@k', fontsize=22)
        else:
            # tick_params persists across redraws, unlike toggling individual tick Text objects.
            ax.tick_params(axis='y', which='both', labelleft=False)
            ax.get_yaxis().get_offset_text().set_visible(False)

lines, labels = axes[1,3].get_legend_handles_labels()
fig.legend(lines, labels, fontsize=24, bbox_to_anchor=(0.5, 0.98), frameon=False, ncol=4, loc = 'upper center')
out_dir = os.path.join(PATH_OUT, 'downstream_tasks', 'link_prediction_results')
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, 'link_pred_precision_plots.png'), dpi=300, bbox_inches = "tight")