In [56]:
import random
import pandas as pd
import numpy as np

from itertools import combinations
from scipy import stats
from collections import Counter
from tqdm import tqdm

from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics
from sklearn.model_selection import train_test_split

from news_vec.corpus import HeadlineDataset
from news_vec.encoder import read_preds

In [8]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import altair as alt
import seaborn as sns

# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names, so 'seaborn-muted' raises OSError on current installs.
# Use whichever spelling this matplotlib provides (old name first, for old versions).
_muted_style = next(
    (name for name in ('seaborn-muted', 'seaborn-v0_8-muted')
     if name in mpl.style.available),
    None,
)
if _muted_style is not None:
    mpl.style.use(_muted_style)
# sns.set applies seaborn's theme on top of the style above, so its settings
# (including its default palette) override any overlapping rc params.
sns.set(style="whitegrid")

In [64]:
def random_subseq(L, size, rng=None):
    """Return a random contiguous slice of ``L`` with length ``size``.

    Every start offset in ``0 .. len(L) - size`` is equally likely.

    Args:
        L: Any sliceable sequence (list, tuple, str, ...).
        size: Length of the window to return; must not exceed ``len(L)``.
        rng: Optional ``random.Random`` instance to draw from, so callers can
            make the sampling reproducible. Defaults to the global ``random``
            module (the original behavior).

    Returns:
        ``L[i:i + size]`` for a uniformly chosen start ``i``.

    Raises:
        ValueError: If ``size`` is larger than ``len(L)``, i.e. no window of
            that length exists.
    """
    if rng is None:
        rng = random
    if size > len(L):
        # Without this guard randint() fails with an opaque "empty range" error.
        raise ValueError(
            "size={} exceeds sequence length {}; no such subsequence exists".format(
                size, len(L)
            )
        )
    start = rng.randint(0, len(L) - size)
    return L[start:start + size]

In [9]:
# Load the pre-built headline dataset; the path is relative to the notebook's working directory.
# NOTE(review): the '.p' extension suggests a pickle — if HeadlineDataset.load unpickles,
# only load trusted files; confirm.
ds = HeadlineDataset.load('../data/ava.p')

In [10]:
# The repr shows three counts that sum to 282120 (the len(df) reported below);
# presumably train/val/test split sizes — confirm against HeadlineDataset.
ds

HeadlineDataset<225696/28212/28212>

In [47]:
# Each item yielded by `ds` unpacks into a pair; the first element becomes one DataFrame row.
df = pd.DataFrame(
    [row for row, _ in ds]
)

In [48]:
# Row count before the token-length filter.
df.shape[0]

282120

In [49]:
# Keep rows with at least 5 classifier tokens; the 5-token windows sampled later need that many.
df = df[df.clf_tokens.map(len) >= 5]

In [50]:
# Row count after the token-length filter.
df.shape[0]

276495

In [65]:
# Seed Python's global RNG so the sampled windows are reproducible across
# Restart & Run All (random_subseq draws from the `random` module by default).
random.seed(0)
# Every row has >= 5 tokens after the length filter above, so a 5-token window
# always exists. `assign` returns a fresh frame instead of writing a column into
# a boolean-filtered slice (which can trigger SettingWithCopyWarning).
df = df.assign(clf_tokens_5=df.clf_tokens.apply(lambda ts: random_subseq(ts, 5)))

In [68]:
# Fixed random_state so the split (and hence every accuracy below) is reproducible;
# test_size=0.25 is scikit-learn's default, spelled out for the reader.
train_df, test_df = train_test_split(df, test_size=0.25, random_state=0)

In [80]:
def train_model(tokens_key='clf_tokens', binary=False, train=None, test=None,
                label_key='domain', random_state=None):
    """Fit a TF-IDF (word 1-3 gram) + LinearSVC classifier and return test accuracy.

    The vectorizer uses identity tokenizer/preprocessor functions, so each cell
    of ``tokens_key`` is treated as an already-tokenized sequence
    (presumably a list of str — TODO confirm).

    Args:
        tokens_key: Column of ``train``/``test`` holding the token sequences.
        binary: Passed to ``TfidfVectorizer``; if True, non-zero term counts
            are set to 1 before IDF weighting.
        train: DataFrame to fit on. Defaults to the notebook global
            ``train_df`` (original behavior); pass explicitly to avoid relying
            on hidden kernel state.
        test: DataFrame to evaluate on. Defaults to the global ``test_df``.
        label_key: Column holding the target label (default ``'domain'``).
        random_state: Forwarded to ``LinearSVC`` so fits can be made
            reproducible; ``None`` keeps scikit-learn's default.

    Returns:
        Accuracy of the classifier's predictions on ``test``.
    """
    if train is None:
        train = train_df
    if test is None:
        test = test_df

    X_train, y_train = train[tokens_key], train[label_key]
    X_test, y_test = test[tokens_key], test[label_key]

    tv = TfidfVectorizer(
        analyzer='word',
        # Inputs are already token sequences: bypass sklearn's own preprocessing/tokenizing.
        tokenizer=lambda x: x,
        preprocessor=lambda x: x,
        ngram_range=(1, 3),
        # token_pattern is unused when a tokenizer is supplied; None silences sklearn's warning.
        token_pattern=None,
        binary=binary,
    )

    # Vocabulary and IDF weights are learned from the training split only.
    X_train = tv.fit_transform(X_train)
    X_test = tv.transform(X_test)

    clf = LinearSVC(random_state=random_state)
    clf.fit(X_train, y_train)

    y_test_pred = clf.predict(X_test)
    return metrics.accuracy_score(y_test, y_test_pred)

In [81]:
# Full token sequences, non-binary term frequencies.
train_model(tokens_key='clf_tokens', binary=False)

0.3670360511544471

In [82]:
# Full token sequences, binary term frequencies.
train_model(tokens_key='clf_tokens', binary=True)

0.3667467160465251

In [83]:
# Random 5-token windows, non-binary term frequencies.
train_model(tokens_key='clf_tokens_5', binary=False)

0.2545136276835831

In [84]:
# Random 5-token windows, binary term frequencies.
train_model(tokens_key='clf_tokens_5', binary=True)

0.25431109310803773