In [120]:
import torch
import ujson
import gzip
import string
import random
import sys

import numpy as np

from glob import glob
from tqdm import tqdm
from collections import Counter, defaultdict
from itertools import chain, islice

from torch import nn, optim
from torch.nn import functional as F
from torch.nn.utils import rnn
from torch.utils.data import random_split
from torch.utils.data import DataLoader

from headline_parser import parse_headline

In [2]:
def read_json_gz_lines(root):
    """Read a JSON-lines corpus from every ``*.gz`` file directly under `root`.

    Files are visited in sorted path order so the stream of rows (and
    therefore any prefix taken from it, e.g. with `islice`) is the same on
    every run; `glob` on its own makes no ordering guarantee. Lines that
    contain only whitespace are skipped rather than handed to the parser.

    Args:
        root (str): Directory holding gzip-compressed JSON-lines files.

    Yields: dict
    """
    for path in sorted(glob('%s/*.gz' % root)):
        # Binary mode: each `line` is a bytes object ending in a newline.
        with gzip.open(path) as fh:
            for line in fh:
                if line.strip():
                    yield ujson.loads(line)

In [3]:
def group_by_sizes(L, sizes):
    """Cut a flat sequence into consecutive chunks of the requested sizes.

    The first chunk takes the first `sizes[0]` items, the second chunk the
    next `sizes[1]` items, and so on. The sizes are expected to add up to
    the length of `L`; this is not checked.

    Args:
        L (list): Flat sequence (anything that supports slicing).
        sizes (list<int>): Length of each chunk, in order.

    Returns: list<list>
    """
    # First pass: work out the [start, stop) bounds of every chunk.
    bounds = []
    start = 0
    for size in sizes:
        stop = start + size
        bounds.append((start, stop))
        start = stop

    # Second pass: slice.
    return [L[lo:hi] for lo, hi in bounds]

In [115]:
def print_replace(msg):
    """Write `msg` to stdout after a carriage return, then flush.

    No newline is emitted, so on a terminal the message is drawn over the
    current line instead of below it.
    """
    stream = sys.stdout
    stream.write(f'\r{msg}')
    stream.flush()

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [129]:
class Corpus:
    """Class-balanced (span, label) pairs with train / val / test splits.

    `groups` maps each label to the list of text spans seen for it. Every
    label is down-sampled to the size of the smallest group before the
    pairs are split, so each label contributes the same number of pairs.

    Attributes are documented inline in `__init__`.
    """

    @classmethod
    def from_dump(cls, root, skim=None, test_frac=0.1, seed=None):
        """Build a corpus from a directory of gzipped JSON-lines files.

        Each row must carry a 'title' (run through `parse_headline`) and a
        'domain' (used as the label). Every non-empty span text of the
        parsed title becomes one line for that domain.

        Args:
            root (str): Directory passed to `read_json_gz_lines`.
            skim (int): Read at most this many rows; None reads everything.
            test_frac (float): Forwarded to the constructor.
            seed (int): Forwarded to the constructor.

        Returns: Corpus
        """
        rows_iter = islice(read_json_gz_lines(root), skim)

        # Label -> [line1, line2, ...]
        # TODO: Parallelize.
        groups = defaultdict(list)
        for row in tqdm(rows_iter):
            doc = parse_headline(row['title'])
            for span in doc._.spans:
                text = span.text
                if text:
                    groups[row['domain']].append(text)

        # Hand over a plain dict so a later lookup of an unknown label
        # cannot silently create an empty group.
        return cls(dict(groups), test_frac=test_frac, seed=seed)

    def __init__(self, groups, test_frac=0.1, seed=None):
        """Store the groups and draw the splits.

        Args:
            groups (dict<str, list<str>>): Label -> lines.
            test_frac (float): Fraction of the balanced pairs used for the
                test split; the validation split gets the same share.
                Must lie in [0, 0.5].
            seed (int): When given, both the per-label sampling and the
                split are drawn from generators seeded with this value, so
                the same groups always give the same splits. None keeps
                the global random state in charge.

        Raises:
            ValueError: If `test_frac` is outside [0, 0.5] or `groups` is
                empty.
        """
        if not 0 <= test_frac <= 0.5:
            raise ValueError(
                'test_frac must be between 0 and 0.5, got %r' % (test_frac,)
            )

        # Label -> list of lines.
        self.groups = groups
        # Share of pairs in the test split (and in the validation split).
        self.test_frac = test_frac
        # Optional seed for reproducible sampling / splitting.
        self.seed = seed
        self.set_splits()

    def labels(self):
        """Return the labels as a list, in `groups` iteration order."""
        return list(self.groups)

    def min_label_count(self):
        """Return the number of lines in the smallest group.

        Raises:
            ValueError: If there are no groups at all.
        """
        if not self.groups:
            raise ValueError('Corpus has no labels; nothing to split.')

        return min(len(lines) for lines in self.groups.values())

    def set_splits(self):
        """Balance the labels and (re)draw `train`, `val` and `test`.

        Sets `self.train`, `self.val` and `self.test` to the subsets
        returned by `random_split`; each item is a `(line, label)` tuple.
        """
        min_count = self.min_label_count()

        if self.seed is None:
            sampler = random
            split_kwargs = {}
        else:
            sampler = random.Random(self.seed)
            split_kwargs = {
                'generator': torch.Generator().manual_seed(self.seed),
            }

        # Same number of lines for every label.
        pairs = [
            (line, label)
            for label, lines in self.groups.items()
            for line in sampler.sample(lines, min_count)
        ]

        # round() may round a half up, which for small odd totals would
        # leave a negative train size; cap at half of the pairs.
        test_size = min(round(len(pairs) * self.test_frac), len(pairs) // 2)
        train_size = len(pairs) - (test_size * 2)
        sizes = (train_size, test_size, test_size)

        self.train, self.val, self.test = random_split(
            pairs, sizes, **split_kwargs
        )

In [130]:
class CharEmbedding(nn.Embedding):
    """Character embedding over ASCII letters, digits and punctuation.

    Index 0 is reserved for every character outside the vocabulary;
    in-vocabulary characters are numbered 1..len(vocab). Whitespace is not
    part of the vocabulary, so spaces map to index 0 as well.
    """

    def __init__(self, embed_dim=15):
        """Set vocab, map char -> index, allocate the embedding table.

        Args:
            embed_dim (int): Size of each character vector.
        """
        vocab = (
            string.ascii_letters +
            string.digits +
            string.punctuation
        )

        # One extra row for the unknown-character index (0).
        super().__init__(len(vocab)+1, embed_dim)

        self.vocab = vocab

        # Known chars start at 1; <UNK> -> 0 (see `ctoi`).
        self._ctoi = {s: i+1 for i, s in enumerate(self.vocab)}

    @property
    def out_dim(self):
        """Width of the vectors produced by `forward`."""
        return self.weight.shape[1]

    def ctoi(self, c):
        """Return the embedding index of character `c` (0 if unknown)."""
        return self._ctoi.get(c, 0)

    def chars_to_idxs(self, chars):
        """Map characters to embedding indexes.

        Args:
            chars (iterable<str>): Single characters.

        Returns:
            LongTensor of shape (len(chars),), created on the device that
            holds the embedding weights so the lookup never mixes devices.
        """
        idxs = [self.ctoi(c) for c in chars]

        return torch.tensor(idxs, dtype=torch.long, device=self.weight.device)

    def forward(self, chars):
        """Batch-embed chars.

        Args:
            chars (list<str>): Single characters.

        Returns:
            Tensor of shape (len(chars), out_dim).
        """
        return super().forward(self.chars_to_idxs(chars))

In [131]:
class SpanEncoder(nn.Module):
    """Bidirectional LSTM that turns variable-length sequences into one
    fixed-size vector each (top-layer forward + backward final states).
    """

    def __init__(self, input_size, hidden_size=100, num_layers=1):
        """Initialize LSTM.

        Args:
            input_size (int): Width of each input step.
            hidden_size (int): Hidden width per direction.
            num_layers (int): Number of stacked LSTM layers.
        """
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Forward and backward states are concatenated.
        self.out_dim = self.lstm.hidden_size * 2

    def forward(self, xs):
        """Sort, pack, encode, reorder.

        Args:
            xs (list<Tensor>): Variable-length embedding tensors, each of
                shape (length, input_size). Every length must be at least
                1, since zero-length sequences cannot be packed.

        Returns:
            x (Tensor): F/B hidden tops, shape (len(xs), out_dim), in the
            same order as `xs`.
        """
        sizes = [len(x) for x in xs]

        # Indexes to sort descending (required by pack_sequence).
        sort_idxs = np.argsort(sizes)[::-1]

        # Indexes to restore original order.
        unsort_idxs = np.argsort(sort_idxs)

        # Pack in descending-length order, LSTM.
        packed = rnn.pack_sequence([xs[i] for i in sort_idxs])
        _, (hn, _) = self.lstm(packed)

        # hn is laid out layer by layer with (forward, backward) per layer,
        # so the last two entries are the top layer's states. Indexing 0/1
        # would pick the first layer whenever num_layers > 1.
        x = torch.cat([hn[-2], hn[-1]], dim=1)

        # Index on the device the LSTM output lives on.
        return x[torch.from_numpy(unsort_idxs).to(x.device)]

In [132]:
class Classifier(nn.Module):
    """Character-level span classifier: char embedding -> BiLSTM span
    encoder -> dropout -> 2-layer MLP emitting log-probabilities per label.
    """

    def __init__(self, labels, hidden_dim=100):
        """Initialize encoders + clf.

        Args:
            labels (list<str>): Label names; position defines class index.
            hidden_dim (int): Width of the hidden layer in the prediction
                head.
        """
        super().__init__()

        self.labels = labels
        self.ltoi = {label: i for i, label in enumerate(labels)}

        self.embed_chars = CharEmbedding()
        self.encode_spans = SpanEncoder(self.embed_chars.out_dim)

        self.dropout = nn.Dropout()

        self.predict = nn.Sequential(
            nn.Linear(self.encode_spans.out_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, len(labels)),
            nn.LogSoftmax(1),
        )

    def forward(self, spans):
        """Predict outlet.

        Args:
            spans (list<str>): Text spans; each must be non-empty.

        Returns:
            Tensor of shape (len(spans), len(labels)) holding
            log-probabilities.
        """
        sizes = [len(s) for s in spans]

        # Embed all chars in one call, then regroup by span.
        x = self.embed_chars(list(chain.from_iterable(spans)))
        xs = group_by_sizes(x, sizes)

        # Embed spans.
        x = self.encode_spans(xs)
        x = self.dropout(x)

        return self.predict(x)

    def collate_batch(self, batch):
        """Labels -> indexes.

        Args:
            batch (list<(str, str)>): (line, label) pairs.

        Returns:
            (lines, yt): tuple of lines, and a LongTensor of label indexes
            placed on the device holding the model's parameters, so the
            targets always sit next to the predictions.

        Raises:
            KeyError: If a label is not one the model was built with.
        """
        lines, labels = zip(*batch)

        yt_idx = [self.ltoi[label] for label in labels]
        device = next(self.parameters()).device
        yt = torch.tensor(yt_idx, dtype=torch.long, device=device)

        return lines, yt

In [133]:
class ProgressDataLoader(DataLoader):
    """DataLoader that reports how many items it has produced so far."""

    def __iter__(self):
        """Yield (x, y) batches while keeping a running item count.

        The count restarts at zero for every new iteration, is stored on
        `self.n`, and is echoed as "<seen>/<dataset size>" before each
        batch is handed out.
        """
        self.n = 0

        batches = super().__iter__()
        for batch in batches:
            inputs, targets = batch
            self.n = self.n + len(inputs)
            print_replace(f'{self.n}/{len(self.dataset)}\r')
            yield inputs, targets

In [162]:
# Read at most the first 10000 rows of the dump and build the balanced splits.
corpus = Corpus.from_dump('data/cleaning-titles.json/', skim=10000)

10000it [00:04, 2354.99it/s]


In [163]:
# Build the classifier over the corpus' labels. (`c` was never defined in
# this notebook, so the cell failed on a fresh kernel.) Move the model to
# DEVICE so its parameters sit where inputs and targets are placed.
clf = Classifier(corpus.labels()).to(DEVICE)

In [164]:
# Negative log-likelihood; the classifier's head ends in LogSoftmax.
loss_func = nn.NLLLoss()

# Adam over every classifier parameter.
optimizer = optim.Adam(clf.parameters(), lr=1e-4)

In [184]:
def train(split, batch_size=50, shuffle=True):
    """Run one optimization pass of the global `clf` over `split`.

    Uses the notebook-level `clf`, `optimizer` and `loss_func`.

    Args:
        split: Dataset of (line, label) pairs.
        batch_size (int): Number of pairs per optimization step.
        shuffle (bool): Draw the batches in a fresh random order on every
            call, so successive epochs do not replay identical batches.

    Returns:
        list<float>: The loss of every batch, in the order processed.
    """
    loader = DataLoader(
        split,
        collate_fn=clf.collate_batch,
        batch_size=batch_size,
        # An empty split has nothing to shuffle; skip the random sampler.
        shuffle=shuffle and len(split) > 0,
    )

    # Training mode once per pass (predict() leaves the model in eval mode).
    clf.train()

    losses = []
    for spans, yt in tqdm(loader):

        optimizer.zero_grad()

        yp = clf(spans)

        loss = loss_func(yp, yt)
        loss.backward()

        optimizer.step()

        losses.append(loss.item())

    return losses

In [178]:
def predict(split, batch_size=50):
    """Run the global `clf` over `split` in eval mode, without gradients.

    Args:
        split: Dataset of (line, label) pairs.
        batch_size (int): Number of pairs per forward pass.

    Returns:
        (yt, yp): CPU tensors. `yt` is a LongTensor of true label indexes,
        `yp` the model outputs for the same rows, in dataset order. Both
        are empty when `split` is empty.
    """
    clf.eval()

    loader = DataLoader(
        split,
        collate_fn=clf.collate_batch,
        batch_size=batch_size,
    )

    yt, yp = [], []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for lines, yti in loader:
            yp.append(clf(lines).cpu())
            yt.append(yti.cpu())

    if not yt:
        return torch.LongTensor([]), torch.FloatTensor([])

    return torch.cat(yt), torch.cat(yp)

In [181]:
def evaluate(split):
    """Return the loss of the current model over `split`.

    Predictions and targets come from `predict`; the loss is whatever the
    notebook-level `loss_func` computes for them.
    """
    targets, outputs = predict(split)
    loss = loss_func(outputs, targets)
    return loss

In [186]:
# Ten passes over the training split; report the validation loss after each.
for _ in range(10):
    train(corpus.train)
    val_loss = evaluate(corpus.val)
    print(val_loss)

100%|██████████| 53/53 [00:08<00:00,  6.64it/s]
  2%|▏         | 1/53 [00:00<00:10,  5.20it/s]

tensor(2.3036)


100%|██████████| 53/53 [00:08<00:00,  6.75it/s]
  2%|▏         | 1/53 [00:00<00:08,  5.94it/s]

tensor(2.2970)


100%|██████████| 53/53 [00:08<00:00,  6.68it/s]
  2%|▏         | 1/53 [00:00<00:08,  6.04it/s]

tensor(2.2686)


100%|██████████| 53/53 [00:08<00:00,  6.66it/s]
  2%|▏         | 1/53 [00:00<00:08,  5.95it/s]

tensor(2.2712)


100%|██████████| 53/53 [00:08<00:00,  6.63it/s]
  2%|▏         | 1/53 [00:00<00:08,  5.99it/s]

tensor(2.2896)


100%|██████████| 53/53 [00:08<00:00,  6.49it/s]
  2%|▏         | 1/53 [00:00<00:08,  5.90it/s]

tensor(2.2774)


100%|██████████| 53/53 [00:08<00:00,  6.47it/s]
  2%|▏         | 1/53 [00:00<00:08,  5.99it/s]

tensor(2.2643)


100%|██████████| 53/53 [00:08<00:00,  6.44it/s]
  2%|▏         | 1/53 [00:00<00:09,  5.75it/s]

tensor(2.2695)


100%|██████████| 53/53 [00:08<00:00,  6.59it/s]
  2%|▏         | 1/53 [00:00<00:09,  5.69it/s]

tensor(2.2809)


100%|██████████| 53/53 [00:08<00:00,  6.58it/s]


tensor(2.2402)


In [185]:
# Validation loss (NLLLoss over the classifier's log-probabilities) for the
# current state of `clf`.
evaluate(corpus.val)

tensor(2.2858)

In [196]:
# Index of the most probable label for the single span 'opinion'
# (exp() turns the log-probabilities back into probabilities).
clf(['opinion']).exp().argmax()

tensor(6)

In [197]:
# Label name at index 6, the index shown in the previous cell's output.
clf.labels[6]

'cnn.com'