In [None]:
import copy
import csv
import datetime
import math
import random

from matplotlib import pyplot as plt
import numpy as np
from numpy.linalg import norm
import pandas as pd
import os,sys
import json
import bz2
import re
import glob
import time
import h5py

from tokenizer import _tokenize
import pickle
import utils
import nltk
from nltk.corpus import stopwords
from sklearn.decomposition import PCA

import mwxml
import mwparserfromhell as mwph
import fasttext
import fasttext.util
forbidden_link_prefixes = ['category', 'image', 'file'] ## for english
from mwtext.wikitext_preprocessor import WikitextPreprocessor

In [None]:
import torch

# Report whether a CUDA device is visible to PyTorch.
print(torch.cuda.is_available())
# Only query/select a device when CUDA is usable: current_device() and
# set_device() raise on CPU-only machines, which would abort the notebook
# even though the fastText pipeline below does not need a GPU.
if torch.cuda.is_available():
    print(torch.cuda.current_device())
    torch.cuda.set_device('cuda:0')

## Parameters

In [None]:
# Project root: the parent of the directory the notebook is run from.
root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Input folder (XML dumps) and output folder (article embeddings), both under <root>/data.
PATH_IN, PATH_OUT = (
    os.path.join(root_dir, 'data', subdir)
    for subdir in ('xml_dumps', 'article_embeddings')
)

# Language codes to process; the wiki name is built as '<lang>wiki'.
langlist = [
    'en', 'ru', 'ja', 'de',
    'fr', 'it', 'pl', 'fa',
]

# Dump snapshot identifier used in input and output file names.
snapshot = '20210401'

# Maximum number of articles to parse (-1 means no limit).
N_articles_max = -1

# Number of dimensions for word vectors (default 300, can be reduced).
# NOTE(review): not referenced by the cells visible in this notebook.
N_dim = 300

# Upper bound on the number of threads used for parallel dump parsing.
N_cores_max = 20

## Loading the Models

In [None]:
# Pretrained fastText word vectors, keyed by language code.
vectors_fasttext = {}
print('loading model')

# Iterate over the shared `langlist` from the Parameters cell instead of a
# second hard-coded copy, so the set of languages is configured in one place.
for lang in langlist:
    vectors_fasttext[lang] = utils.loadWordVectors(
        os.path.join(root_dir, 'data', 'pretrained_embeddings', f'cc.{lang}.300.vec')
    )

print("Embeddings Loaded")

In [None]:
# Sanity check: number of entries loaded for each language in `langlist`
# (the same list that drove the loading loop).
for lang in langlist:
    print(len(vectors_fasttext[lang]))

In [None]:
from sentence_transformers import SentenceTransformer

# Sentence encoder; the embedding loop passes it as the embedding model
# whenever the embedding type is not 'fasttext'.
model = SentenceTransformer(
    'xlm-r-distilroberta-base-paraphrase-v1'
)

print('ready')

In [None]:
def get_sentence_embedding_xlm(sentences, emb_model):
    """Encode sentences with ``emb_model`` and scale every vector to unit length.

    Parameters
    ----------
    sentences : list of str
        Texts to embed (the caller passes one string per paragraph).
    emb_model : object
        Any model exposing ``encode(sentences)`` that returns one vector per
        sentence, e.g. a ``SentenceTransformer``.

    Returns
    -------
    numpy.ndarray
        ``float64`` array with one L2-normalised row per sentence. The array
        is empty (``size == 0``) when the model returned no vectors.
    """
    embeddings = np.asarray(emb_model.encode(sentences))
    if embeddings.size == 0:
        # Nothing to normalise; callers detect this case via ``.size == 0``.
        return np.array([], dtype='float64')

    # Normalise all rows at once instead of round-tripping through lists.
    norms = norm(embeddings, axis=-1, keepdims=True)
    # An all-zero vector has norm 0; dividing by it would produce NaNs that
    # poison the average computed downstream, so such rows are left as zeros.
    norms[norms == 0] = 1
    return (embeddings / norms).astype('float64')

## Getting XML dump chunks downloaded from the Internet Archive

In [None]:
def get_xml_chunks_local(lang, wiki, snapshot):
    """Return the local ``pages-articles`` dump files to parse for one wiki.

    Searches ``PATH_IN/<lang>_<snapshot>/`` for files matching
    ``<wiki>-<snapshot>-pages-articles*.xml*.bz2``. The single full dump and
    any ``multistream`` files are skipped so that only chunk files remain.
    When no chunk file is found, the path of the single full dump is returned
    instead (without checking that it exists).

    Parameters
    ----------
    lang : str
        Language code used in the folder name.
    wiki : str
        Wiki name used in the file names (e.g. ``'enwiki'``).
    snapshot : str
        Dump snapshot identifier used in folder and file names.

    Returns
    -------
    list of str
        Sorted chunk paths, or a one-element list with the full-dump path.
    """
    dump_dir = os.path.join(PATH_IN, f'{lang}_{snapshot}')
    dump_fn = os.path.join(dump_dir, f'{wiki}-{snapshot}-pages-articles.xml.bz2')
    # Build the pattern with os.path.join so every match uses the same path
    # separators as dump_fn; otherwise the equality test below can miss the
    # full dump on platforms whose separator is not '/'.
    pattern = os.path.join(dump_dir, f'{wiki}-{snapshot}-pages-articles*.xml*.bz2')

    # glob returns matches in arbitrary order; sort for reproducible runs.
    # 'multistream' is tested against the file name only, so a parent folder
    # that happens to contain the word does not filter out every file.
    paths = sorted(
        infile
        for infile in glob.glob(pattern)
        if infile != dump_fn and 'multistream' not in os.path.basename(infile)
    )

    # Fall back to the single full dump when there are no chunk files.
    return paths or [dump_fn]

In [None]:
def get_description_embeddings(text_tokens_processed, text_tokens, embeddingType, emb_model):
    """Compute embeddings with the backend selected by ``embeddingType``.

    For ``'xlm'`` the unprocessed ``text_tokens`` are passed to
    ``get_sentence_embedding_xlm``; for every other value the
    ``text_tokens_processed`` list is passed to
    ``utils.get_sentence_embedding``. The backend's result is returned as is.
    """
    if embeddingType != 'xlm':
        return utils.get_sentence_embedding(text_tokens_processed, emb_model)
    return get_sentence_embedding_xlm(text_tokens, emb_model)

In [None]:
# Link prefixes handed to the wikitext preprocessor (English names only).
forbidden_link_prefixes = [
    'category',
    'image',
    'file',
]
# Shared preprocessor instance used by page_to_vector below.
wtpp = WikitextPreprocessor(forbidden_link_prefixes)

def page_to_vector(dump, path):
    """Yield the lead-section text of every non-redirect article in a dump.

    Used as the per-file processor handed to ``mwxml.map``. For each page in
    namespace 0 that is not a redirect, the last revision is taken, its
    wikitext is cut at the first section heading, and the remainder is run
    through the module-level ``wtpp`` preprocessor. Stopwords are read from
    the module-level ``all_stopwords``.

    Parameters
    ----------
    dump : iterable
        Iterable of pages; each page is iterable over its revisions and
        exposes ``id``, ``title``, ``namespace`` and ``redirect``.
    path : str
        Unused; present because the function is passed to ``mwxml.map``.

    Yields
    ------
    dict
        ``page_id``, ``rev_id``, ``page_title``, ``text_tokens`` (one joined
        string per paragraph) and ``text_tokens_processed`` (flat list of
        tokens with stopwords removed).
    """
    # Build the lookup set once instead of scanning the list for every token.
    stopword_set = set(all_stopwords)

    for page in dump:
        # Keep only main-namespace pages (namespace 0) that are not redirects.
        if page.namespace != 0 or page.redirect is not None:
            continue

        # Walk to the last revision. Reset `rev` first so a page without any
        # revision cannot silently reuse the previous page's revision.
        rev = None
        for rev in page:
            pass
        if rev is None:
            continue

        # Treat missing revision text as empty so the regex search cannot
        # fail on None (presumably the case for suppressed text -- TODO confirm).
        wikitext = rev.text or ''

        # Keep only the text before the first section heading ("== ... ==").
        first_section = re.search('={2,}.+?={2,}', wikitext)
        if first_section:
            wikitext = wikitext[:first_section.start()]

        # Collect paragraphs as strings, plus one flat list of filtered tokens.
        text_tokens = []
        text_tokens_processed = []
        for paragraph in wtpp.process(wikitext):
            paragraph_text = " ".join(paragraph)
            text_tokens.append(paragraph_text)
            text_tokens_processed += [
                word for word in _tokenize(paragraph_text) if word not in stopword_set
            ]

        yield {
            'page_id': page.id,
            'rev_id': rev.id,
            'page_title': page.title,
            'text_tokens': text_tokens,
            'text_tokens_processed': text_tokens_processed,
        }

In [None]:
def save_article_descriptions(files, wiki, snapshot):
    """Parse dump files in parallel and write one JSON line per article.

    Every dict produced by ``page_to_vector`` (run through ``mwxml.map``) is
    serialised as one line of
    ``PATH_OUT/article-descriptions_<wiki>-<snapshot>.jsonl.bz2``. Progress is
    printed every 100000 articles. Processing stops early once
    ``N_articles_max`` articles have been written; with ``-1`` the limit is
    never reached.

    Parameters
    ----------
    files : list of str
        Dump files to parse.
    wiki : str
        Wiki name used in log messages and the output file name.
    snapshot : str
        Snapshot identifier used in log messages and the output file name.

    Returns
    -------
    int
        Number of articles written.
    """
    N_articles_kept = 0
    # Never start more workers than there are files to parse.
    threads = min([N_cores_max, len(files)])
    print(f'Threads = {threads}')

    print(f'processing dump {wiki}-{snapshot}')
    t1 = time.time()

    # Create the output folder on first use instead of failing on open().
    os.makedirs(PATH_OUT, exist_ok=True)
    out_path = os.path.join(PATH_OUT, f'article-descriptions_{wiki}-{snapshot}.jsonl.bz2')

    # The context manager closes (and flushes) the compressed stream even
    # when parsing raises part-way through.
    with bz2.open(out_path, 'wt') as fout:
        for dict_page in mwxml.map(page_to_vector, files, threads=threads):
            fout.write(json.dumps(dict_page) + '\n')
            N_articles_kept += 1
            if N_articles_kept%100000==0:
                print('... processed %s articles in %.2f'%(N_articles_kept,time.time()-t1))
            if N_articles_kept==N_articles_max:
                break
    t2 = time.time()
    print('done in %s seconds'%( t2-t1))

    return N_articles_kept

In [None]:
def get_article_embeddings(wiki, snapshot, embeddingType, emb_model):
    """Compute one unit-length embedding per article description.

    Reads ``PATH_OUT/article-descriptions_<wiki>-<snapshot>.jsonl.bz2`` line
    by line, embeds each article with ``get_description_embeddings``,
    averages the resulting vectors and L2-normalises the average. Articles
    for which no embedding could be computed are counted and logged to
    ``PATH_OUT/<wiki>-<snapshot>_article_description_<embeddingType>_embeddings.error``.
    Reading stops early once ``N_articles_max`` articles have been embedded;
    with ``-1`` the limit is never reached.

    Parameters
    ----------
    wiki : str
        Wiki name used in file names and log messages.
    snapshot : str
        Snapshot identifier used in file names and log messages.
    embeddingType : str
        Embedding backend name, forwarded to ``get_description_embeddings``.
    emb_model : object
        Embedding model, forwarded to ``get_description_embeddings``.

    Returns
    -------
    tuple
        ``(num_missing_articles, entity_embeddings, xlm_embeddings_for_pca)``:
        the number of articles without an embedding, a dict mapping page id
        to its embedding vector, and a list with a random sample of the
        embeddings (filled only when ``embeddingType == 'xlm'``).
    """
    N_articles_read = 0
    num_missing_articles = 0
    entity_embeddings = {}
    xlm_embeddings_for_pca = []

    err_path = os.path.join(PATH_OUT, f'{wiki}-{snapshot}_article_description_{embeddingType}_embeddings.error')
    in_path = os.path.join(PATH_OUT, f'article-descriptions_{wiki}-{snapshot}.jsonl.bz2')

    print(f'Computing {embeddingType} embeddings for {wiki}-{snapshot}')
    t1 = time.time()

    # Both files are managed by the `with` statement so they are closed even
    # when embedding raises part-way through.
    with open(err_path, "w") as ferr, bz2.open(in_path, 'rt') as fin:
        for line in fin:
            article_obj = json.loads(line)
            page_id = article_obj['page_id']
            rev_id = article_obj['rev_id']
            page_title = article_obj['page_title']
            text_tokens_processed = article_obj['text_tokens_processed']
            text_tokens = article_obj['text_tokens']

            embedding_array = get_description_embeddings(text_tokens_processed, text_tokens, embeddingType, emb_model)
            if embedding_array.size == 0: # None of the description words had an embedding
                ferr.write(f'No embeddings in {embeddingType} for {page_id}, {rev_id}, {page_title}, with text: {text_tokens} \n')
                num_missing_articles += 1
                continue

            # Article vector = normalised mean of its individual embeddings.
            sent_embedding = np.average(embedding_array, axis=0)
            sent_embedding /= norm(sent_embedding)
            entity_embeddings[page_id] = sent_embedding

            # Keep roughly 16% of the xlm embeddings as a random sample
            # (presumably for fitting a PCA later, judging by the name -- TODO confirm).
            if embeddingType == 'xlm' and random.uniform(0,1)<0.16:
                xlm_embeddings_for_pca.append(sent_embedding.tolist())

            N_articles_read += 1
            if N_articles_read%100000==0:
                print('... processed %s articles in %.2f'%(N_articles_read,time.time()-t1))
            if N_articles_read==N_articles_max:
                break

    t2 = time.time()
    print('done in %s seconds'%( t2-t1))

    return num_missing_articles, entity_embeddings, xlm_embeddings_for_pca

In [None]:
# `langlist` and `N_articles_max` are taken from the Parameters cell; they are
# deliberately not re-assigned here so that edits made there are not silently
# overridden by this cell.

# NLTK stopword corpus name per language code. Languages without an entry
# are processed without stopword removal.
STOPWORD_CORPUS_BY_LANG = {
    'en': 'english',
    'de': 'german',
    'fr': 'french',
    'ru': 'russian',
    'it': 'italian',
}

for lang in langlist:
    wiki = f'{lang}wiki'

    # `all_stopwords` is a module-level name that page_to_vector reads while
    # the dump of the current language is being parsed.
    corpus_name = STOPWORD_CORPUS_BY_LANG.get(lang)
    all_stopwords = stopwords.words(corpus_name) if corpus_name else []
    if lang == 'en':
        # Keep 'not' in the English text: it is removed from the stopword list.
        all_stopwords.remove('not')

    paths = get_xml_chunks_local(lang, wiki, snapshot)
    print(wiki, len(paths))
    num_articles = save_article_descriptions(paths, wiki, snapshot)
    print(f'There are {num_articles} articles in {wiki}-{snapshot}')

In [None]:
# Compute and store article-description embeddings for every language.
for embeddingType in ['fasttext']:
    for lang in langlist:
        wiki = f'{lang}wiki'

        # fastText uses the per-language word-vector table; any other
        # embedding type uses the shared sentence-transformer `model`.
        if embeddingType == 'fasttext':
            model_dict = vectors_fasttext[lang]
        else:
            model_dict = model
        num_missing_articles, entity_embeddings, xlm_embeddings_for_pca = get_article_embeddings(wiki, snapshot, embeddingType, model_dict)
        print(f'There are {num_missing_articles} articles in {wiki} with missing embeddings for {embeddingType}')

        # Write the embeddings through a context manager so the file handle
        # is closed (and the data flushed) as soon as the dump is done.
        pickle_path = os.path.join(PATH_OUT, f'article-description-embeddings_{wiki}-{snapshot}-{embeddingType}.pickle')
        with open(pickle_path, "wb") as fpickle:
            pickle.dump(entity_embeddings, fpickle)
        print(len(xlm_embeddings_for_pca))