# Train Model

In [13]:
import sys
sys.path.append("..")
%reload_ext autoreload
%autoreload 1

In [14]:
import pandas as pd
from sklearn_pandas import cross_val_score, DataFrameMapper
from pandas import Categorical
from sklearn.ensemble import RandomForestClassifier

In [15]:
from ares.models import featurize

In [16]:

# Labeled revisions with precomputed features; the 'wp10' column is used below as the target.
input_file = "../data/enwiki.labeling_revisions.w_features.nettrom_30k.csv.gz"
# NOTE(review): `load_revisions` is not imported or defined anywhere in this notebook,
# so this line raises NameError on a fresh kernel (Restart & Run All). It presumably
# lives in `ares.models` next to `featurize` — confirm and add it to the import cell.
revisions = load_revisions(input_file)

# Create a pipeline

In [17]:
import pandas as pd
import numpy as np
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.preprocessing import FunctionTransformer, PolynomialFeatures

In [18]:
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegressionCV
from sklearn.base import BaseEstimator, TransformerMixin


# Count-like features: square-root transformed by the mapper below.
sqrt_cols = ['words',
             'headings',
             'sub_headings',
             'images',
             'categories',
             'wikilinks',

             'who_templates',
             'main_templates',
             'cite_templates',
             # infobox as a binary (see binarized_cols)
             'citation_needed',
             'other_templates',

             'ref',
             'smartlists',
             # NOTE(review): 'coordinates' is also listed in binarized_cols, so it enters
             # the model twice (sqrt-transformed and as a boolean). Confirm this is intended.
             'coordinates']

# Presence/absence features: any non-zero value becomes True.
binarized_cols = ['coordinates', 'infoboxes']


def _to_bool(x):
    """Return ``x`` cast to boolean (non-zero -> True).

    A named function is used instead of a lambda because lambdas cannot be
    pickled by the standard ``pickle``/``joblib``, which would make the fitted
    pipeline impossible to save. Note that a function defined in a notebook is
    pickled by reference, so the loading process must also define ``_to_bool``.
    """
    return x.astype(bool)


mapper = DataFrameMapper([
    (sqrt_cols, FunctionTransformer(func=np.sqrt)),
    (binarized_cols, FunctionTransformer(func=_to_bool))
])

# Multinomial, L2-penalized logistic regression; LogisticRegressionCV chooses the
# regularization strength by internal cross-validation.
clf = LogisticRegressionCV(multi_class='multinomial',
                           random_state=1234,
                           penalty='l2',
                           fit_intercept=True, n_jobs=-1)

#clf = RandomForestClassifier(n_estimators=200, max_depth=None,
#                            random_state=1234, bootstrap=True,
#                            n_jobs = -1, oob_score=True)

In [19]:
# Full model: column transforms -> degree-2 expansion -> whitened PCA -> classifier.
pipe = Pipeline(steps=[
    ("mapper", mapper),                      # per-column sqrt / boolean transforms
    ("poly", PolynomialFeatures(degree=2)),  # squares and pairwise interactions
    ("PCA", PCA(whiten=True)),               # decorrelate and rescale the expanded features
    ("clf", clf),                            # classifier configured in the previous cell
])

In [20]:
# The mapper selects only the listed feature columns, so passing the whole frame
# (including the 'wp10' target) does not leak the label into the features.
# Pipeline.fit returns the pipeline itself, so `fitted` is the same object as `pipe`.
fitted = pipe.fit(revisions.copy(), revisions["wp10"])

In [21]:
import dill


In [22]:
import mwapi
# API session for English Wikipedia; passed to get_current_page when fetching the page below.
# NOTE(review): the user agent embeds a personal email address (presumably as the
# contact info Wikimedia's User-Agent policy asks for). Consider reading it from an
# environment variable so it is not committed with the notebook.
session = mwapi.Session("https://en.wikipedia.org/", user_agent="<jeffrey.arnold@gmail.com>")

def get_current_page(session: "mwapi.Session", title: str) -> dict:
    """Fetch the latest revision of a Wikipedia page.

    Parameters
    ----------
    session : mwapi.Session
        Session used to query the MediaWiki Action API.
    title : str
        Page title to look up.

    Returns
    -------
    dict
        The revision record requested with ``rvprop=ids|content``. The wikitext of
        the main slot is moved to the ``'content'`` key and ``'slots'`` is removed.

    Raises
    ------
    KeyError
        If the API returns the page without any revisions (e.g. the title does not exist).
    """
    params = {'action': 'query', 'titles': title, 'prop': "revisions",
              'rvprop': 'ids|content', "rvslots": "main"}
    response = session.get(**params)
    # Only one title is requested, so take the single page entry.
    page = next(iter(response['query']['pages'].values()))
    if 'revisions' not in page:
        # Nonexistent/invalid titles come back without a 'revisions' key; fail with
        # a readable message instead of a bare KeyError: 'revisions'.
        raise KeyError(f"No revisions returned for page {title!r}")
    rev = page['revisions'][0]
    # Flatten the main-slot wikitext onto the revision record.
    rev['content'] = rev.pop('slots')['main']['*']
    return rev
    

Perturbations to try on a single page, one at a time:

- Add a paragraph of text
- Add an external link (ignore its character contribution)
- Add an internal link (ignore its character contribution)
- Add an image
- Add headings
- Add subheadings

In [23]:
class WikiPage:
    """Single-row feature frame for one page, with helpers that return perturbed copies.

    Both helpers copy ``self.data`` before modifying it, so each perturbation is
    applied to the original page independently and ``self.data`` never changes.
    """

    def __init__(self, content):
        # `featurize` (from ares.models) returns one record of features; wrap it in
        # a single-row DataFrame that the pipeline can predict on.
        self.data = pd.DataFrame.from_records([featurize(content)])

    def add_count(self, variable, value):
        """Return a copy of the data with ``value`` added to column ``variable``.

        The result is floored at zero so negative deltas (e.g. removing a template)
        cannot produce negative counts. All other columns are left unchanged.
        """
        df = self.data.copy()
        # Clip only the perturbed column; assigning through a row mask alone
        # (df.loc[mask] = 0) would overwrite every column of the affected row.
        df[variable] = (df[variable] + value).clip(lower=0)
        return df

    def set_value(self, variable, value):
        """Return a copy of the data with every value of column ``variable`` set to ``value``.

        If ``variable`` is not an existing column, a new column is created, which
        downstream column selectors may ignore.
        """
        df = self.data.copy()
        df[variable] = value
        return df
   

In [24]:
# Featurize the current revision of one article for the sensitivity checks below.
# NOTE(review): this fetches the live page, so results change as the article is edited.
# NOTE(review): WikiPage passes the whole revision dict from get_current_page (ids plus
# 'content') to `featurize`; confirm featurize expects that rather than raw wikitext.
page = WikiPage(get_current_page(session, "Ahvaz_military_parade_attack"))

In [25]:
# Baseline prediction for the unmodified page.
pipe.predict(page.data)

array(['Stub'], dtype=object)

In [26]:
# Count perturbations: (column, delta) pairs added to the page's current value.
# Negative deltas model removing an item; add_count floors the result at zero.
count_variables = (("words", 50),
                   ("headings", 1),
                   ("sub_headings", 1),
                   ("images", 1),
                   ("categories", 1),
                   ("wikilinks", 1),
                   ("cite_templates", 1),
                   ("citation_needed", -1),
                   ("who_templates", -1),
                   ("smartlists", 1),
                   ("ref", 1),
                   ("coordinates", 1))

# Binary perturbations: (column, value) pairs that overwrite the column.
# Names must match page.data.columns exactly: the featurized column is "infoboxes";
# "infobox" would silently create a new column that the model never reads.
binary_variables = (("coordinates", 1), ("infoboxes", 1))

baseline_prediction = pipe.predict(page.data)[0]

# Collect each prediction instead of discarding it, so the cell shows the results.
perturbation_records = []
for variable, value in count_variables:
    perturbation_records.append({
        "change": "add_count",
        "variable": variable,
        "value": value,
        "prediction": pipe.predict(page.add_count(variable, value))[0],
    })
for variable, value in binary_variables:
    perturbation_records.append({
        "change": "set_value",
        "variable": variable,
        "value": value,
        "prediction": pipe.predict(page.set_value(variable, value))[0],
    })

perturbation_results = pd.DataFrame(perturbation_records)
perturbation_results["changed"] = perturbation_results["prediction"] != baseline_prediction
perturbation_results

In [27]:
# Columns produced by featurize; perturbation variable names must match these exactly.
page.data.columns

Index(['categories', 'citation_needed', 'cite_templates', 'coordinates',
       'external_links', 'headings', 'images', 'infoboxes', 'main_templates',
       'other_templates', 'ref', 'smartlists', 'sub_headings', 'text',
       'who_templates', 'wikilinks', 'words'],
      dtype='object')