In [217]:
import pandas as pd

from collections import UserDict, UserList, Counter
from tqdm import tqdm
from cached_property import cached_property
from itertools import islice, chain
from functools import lru_cache
from torch.utils.data import random_split

from news_vec.utils import read_json_gz_lines
from news_vec import logger

In [40]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import altair as alt
import seaborn as sns

mpl.style.use('seaborn-muted')
sns.set(style="whitegrid")

%matplotlib inline

In [213]:
class Corpus:
    """Headline corpus held in a DataFrame, with balanced sampling helpers.

    Every ``sample_*`` method draws ``min_db_count`` rows from each group.
    ``min_db_count`` is the size of the smallest (domain, ts_bucket) group in
    the WHOLE corpus, so samples for different domain pairs have the same
    per-group size.
    """

    def __init__(self, headline_root, skim=None):
        """Read headline df.

        Args:
            headline_root: Location passed to ``read_json_gz_lines``.
            skim: Read at most this many records; ``None`` reads them all.
        """
        logger.info('Reading headlines.')

        lines = islice(read_json_gz_lines(headline_root), skim)
        self.df = pd.DataFrame(list(tqdm(lines)))

    def __repr__(self):

        pattern = '{cls_name}<{hl_count} headlines>'

        return pattern.format(
            cls_name=self.__class__.__name__,
            hl_count=len(self.df),
        )

    @cached_property
    def min_db_count(self):
        """Size of the smallest (domain, ts_bucket) group in the corpus."""
        return self.df.groupby(['domain', 'ts_bucket']).size().min()

    def _memoize(self, key, build):
        """Return ``build()``, cached on this instance under ``key``.

        Used instead of ``functools.lru_cache`` on methods. That cache lives on
        the class and holds a strong reference to every ``self``, which keeps
        each corpus and its DataFrame alive for the whole session. The cache
        is not invalidated if ``self.df`` is reassigned.
        """
        cache = self.__dict__.setdefault('_group_cache', {})
        if key not in cache:
            cache[key] = build()
        return cache[key]

    def _sample_groups(self, grouped, random_state=None):
        """Draw ``min_db_count`` rows from every group of ``grouped``.

        Args:
            grouped: A pandas GroupBy over ``self.df`` rows.
            random_state: Anything ``DataFrame.sample`` accepts. An integer
                seed is turned into a single RandomState so all groups draw
                from one stream. Passing the bare int to each group's
                ``.sample`` would reseed every group identically.
        """
        import numbers

        import numpy as np

        if isinstance(random_state, numbers.Integral):
            random_state = np.random.RandomState(random_state)
        n = self.min_db_count
        return grouped.apply(lambda group: group.sample(n, random_state=random_state))

    def sample_all_vs_all(self, random_state=None):
        """Sample ``min_db_count`` rows from every (domain, ts_bucket) group."""
        grouped = self.df.groupby(['domain', 'ts_bucket'])
        return self._sample_groups(grouped, random_state)

    def filter_ab(self, d1, d2):
        """Rows of domains ``d1``/``d2`` grouped by (domain, ts_bucket).

        The GroupBy is cached per instance and per (d1, d2).
        """
        def build():
            rows = self.df[self.df.domain.isin([d1, d2])]
            return rows.groupby(['domain', 'ts_bucket'])

        return self._memoize(('ab', d1, d2), build)

    def sample_ab(self, d1, d2, random_state=None):
        """Sample ``min_db_count`` rows per (domain, ts_bucket) for ``d1``/``d2``."""
        return self._sample_groups(self.filter_ab(d1, d2), random_state)

    def filter_ab_ts(self, d1, d2, bucket):
        """Rows of domains ``d1``/``d2`` in ``ts_bucket == bucket``, grouped by domain.

        The GroupBy is cached per instance and per (d1, d2, bucket).
        """
        def build():
            mask = self.df.domain.isin([d1, d2]) & (self.df.ts_bucket == bucket)
            return self.df[mask].groupby('domain')

        return self._memoize(('ab_ts', d1, d2, bucket), build)

    def sample_ab_ts(self, d1, d2, bucket, random_state=None):
        """Sample ``min_db_count`` rows per domain for ``d1``/``d2`` within one bucket."""
        return self._sample_groups(self.filter_ab_ts(d1, d2, bucket), random_state)

In [242]:
class HeadlineDataset(UserList):
    """(headline_record, label) pairs with a random train/val/test split.

    The list contents (``self.data``) are the pairs in train, val, test order,
    so ``len()``, indexing and the other ``UserList`` operations work.
    """

    @classmethod
    def from_df(cls, df, label_col='domain', **kwargs):
        """Build a dataset with one pair per DataFrame row.

        Each row becomes ``(row_dict, row_dict[label_col])``; the DataFrame
        index is not included. Extra ``kwargs`` are forwarded to ``__init__``.
        """
        pairs = [(d, d[label_col]) for d in df.to_dict('records')]
        return cls(pairs, **kwargs)

    def __init__(self, pairs, test_frac=0.1, seed=None):
        """Set train/val/test splits.

        Args:
            pairs: Sequence of (headline_record, label) pairs.
            test_frac: Fraction of pairs (rounded) used for EACH of val and
                test; the remainder goes to train.
            seed: Optional int. When given, the split is reproducible.

        Raises:
            ValueError: If ``test_frac`` yields a negative split size.
        """
        test_size = round(len(pairs) * test_frac)
        train_size = len(pairs) - (test_size * 2)
        if test_size < 0 or train_size < 0:
            raise ValueError(
                'test_frac=%r gives negative split sizes for %d pairs'
                % (test_frac, len(pairs))
            )

        split_kwargs = {}
        if seed is not None:
            import torch
            split_kwargs['generator'] = torch.Generator().manual_seed(seed)

        sizes = (train_size, test_size, test_size)
        self.train, self.val, self.test = random_split(pairs, sizes, **split_kwargs)

        # UserList keeps its contents in self.data. Without this call, len(),
        # indexing, etc. raise AttributeError.
        super().__init__(chain(self.train, self.val, self.test))

    def __repr__(self):

        pattern = '{cls_name}<{train_size}/{val_size}/{test_size}>'

        return pattern.format(
            cls_name=self.__class__.__name__,
            train_size=len(self.train),
            val_size=len(self.val),
            test_size=len(self.test),
        )

    def __iter__(self):
        """Yield all pairs: train, then val, then test."""
        return chain(self.train, self.val, self.test)

    def token_counts(self):
        """Collect all token -> count from each record's ``clf_tokens``.
        """
        logger.info('Gathering token counts.')

        counts = Counter()
        for hl, _ in tqdm(self):
            counts.update(hl['clf_tokens'])

        return counts

    def label_counts(self):
        """Label -> count.
        """
        logger.info('Gathering label counts.')

        counts = Counter()
        for _, label in tqdm(self):
            counts[label] += 1

        return counts

    def labels(self):
        """Distinct labels, most frequent first."""
        counts = self.label_counts()
        return [label for label, _ in counts.most_common()]

In [243]:
# skim=100000 reads only the first 100000 records; pass skim=None for the full corpus.
c = Corpus(headline_root='../data/clf-headlines.json/', skim=100000)

2018-12-27 12:26:15,532 | INFO : Reading headlines.
100000it [00:01, 98935.26it/s]


In [244]:
# Smallest (domain, ts_bucket) group size; the sample_* methods draw this many rows per group.
c.min_db_count

198

In [248]:
# Balanced nytimes.com vs apnews.com sample, labelled by domain and split train/val/test.
ds = HeadlineDataset.from_df(
    c.sample_ab('nytimes.com', 'apnews.com'),
    label_col='domain',
)

In [249]:
# Per-domain counts; expected to be equal when both domains cover the same ts_buckets,
# since sample_ab draws the same number of rows from every (domain, ts_bucket) group.
ds.label_counts()

2018-12-27 12:26:34,935 | INFO : Gathering label counts.
3960it [00:00, 234084.19it/s]


Counter({'nytimes.com': 1980, 'apnews.com': 1980})

In [250]:
# Time a single-bucket (ts_bucket == 0) balanced sample of the same domain pair.
%time c.sample_ab_ts('nytimes.com', 'apnews.com', 0).head(10)

CPU times: user 7.22 ms, sys: 1.3 ms, total: 8.53 ms
Wall time: 7.39 ms


Unnamed: 0_level_0,Unnamed: 1_level_0,article_id,clf_tokens,domain,impressions,tokens,ts_bucket
domain,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
apnews.com,46481,996432433100,"[car, bombing, near, syria, town, captured, fr...",apnews.com,34652,"[Car, bombing, near, Syria, town, captured, fr...",0
apnews.com,21459,893353200152,"[#, die, amid, apparent, winter, tornadoes, ot...",apnews.com,11103,"[18, die, amid, apparent, winter, tornadoes, ,...",0
apnews.com,25125,180388664250,"[black, lawmakers, dismayed, by, trump, s, inv...",apnews.com,22107,"[Black, lawmakers, dismayed, by, Trump, 's, in...",0
apnews.com,33665,816043792265,"[note, in, recycler, s, trash, helps, cops, id...",apnews.com,151638,"[Note, in, recycler, 's, trash, helps, cops, I...",0
apnews.com,2249,618475316784,"[worker, pinned, under, large, pile, of, steel...",apnews.com,9967369,"[Worker, pinned, under, large, pile, of, steel...",0
apnews.com,98076,17179889639,"[washington, state, sues, trump, over, immigra...",apnews.com,165219,"[Washington, state, sues, Trump, over, immigra...",0
apnews.com,95150,919123035430,"[russia, says, it, starts, syrian, drawdown, w...",apnews.com,34509,"[Russia, says, it, starts, Syrian, drawdown, w...",0
apnews.com,64631,919123003883,"[pence, fought, against, releasing, records, a...",apnews.com,10391,"[Pence, fought, against, releasing, records, a...",0
apnews.com,72964,1503238576337,"[harvard, honors, rihanna, s, philanthropy]",apnews.com,1057945,"[Bright, like, a, diamond, :, Harvard, honors,...",0
apnews.com,1737,489626283535,"[oldest, aardvark, in, captivity, in, us, dies...",apnews.com,15030,"[Philly, Zoo, :, Oldest, aardvark, in, captivi...",0


In [251]:
# Time the all-bucket balanced sample, for comparison with the single-bucket timing.
%time c.sample_ab('nytimes.com', 'apnews.com').head(10)

CPU times: user 31.8 ms, sys: 1.78 ms, total: 33.6 ms
Wall time: 31.9 ms


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,article_id,clf_tokens,domain,impressions,tokens,ts_bucket
domain,ts_bucket,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
apnews.com,0,84216,1262720389194,"[signs, from, women, s, marches, being, saved]",apnews.com,11354,"[Sign, of, the, times, :, Signs, from, women, ...",0
apnews.com,0,27969,953482785636,"[stolen, vehicle, s, driver, in, labor, kids, ...",apnews.com,17383,"[Police, :, Stolen, vehicle, 's, driver, in, l...",0
apnews.com,0,90903,1451698962738,"[nasa, displays, apollo, capsule, hatch, #, ye...",apnews.com,12162,"[NASA, displays, Apollo, capsule, hatch, 50, y...",0
apnews.com,0,32121,386547072225,"[#, year, old, gorilla, #st, to, be, born, in,...",apnews.com,13039,"[60, year, old, gorilla, ,, 1st, to, be, born,...",0
apnews.com,0,97605,1606317796549,"[ford, invests, $, #b, in, robotics, startup, ...",apnews.com,26576,"[Ford, invests, $, 1B, in, robotics, startup, ...",0
apnews.com,0,82888,876173341987,"[lebron, james, named, ap, male, athlete, of, ...",apnews.com,109701,"[The, king, of, 2016, :, LeBron, James, named,...",0
apnews.com,0,47119,1176821067769,"[plane, carrying, #, people, hits, australian,...",apnews.com,31846,"[Plane, carrying, 5, people, hits, Australian,...",0
apnews.com,0,78551,1348619747401,"[before, lion, the, story, behind, an, unlikel...",apnews.com,35913,"[Before, Lion, ,, the, story, behind, an, unli...",0
apnews.com,0,41619,1357209692577,"[fbi, deletes, details, about, hacking, effort...",apnews.com,48984,"[FBI, deletes, details, about, hacking, effort...",0
apnews.com,0,37376,146028913919,"[reynolds, fisher, laid, to, rest, at, hollywo...",apnews.com,123655,"[Reynolds, ,, Fisher, laid, to, rest, at, Holl...",0
