In [1]:
import string
import gzip
import ujson
import math

import numpy as np

from tqdm import tqdm
from collections import Counter
from glob import glob
from boltons.iterutils import chunked_iter
from itertools import islice, chain

import torch
from torch import nn, optim
from torchtext.vocab import Vocab
from torch.nn.utils import rnn
from torch.nn import functional as F
from sklearn.model_selection import train_test_split

from news_vec.cuda import itype, ftype
from news_vec import logger
from news_vec.utils import group_by_sizes

In [2]:
def read_json_lines(root, lower=True):
    """Stream `Line` objects from a directory of gzipped JSON-lines files.

    Every `*.gz` file directly under `root` is read row by row. Each row is
    parsed as a JSON object; its 'title' and 'label' keys are removed and the
    remaining keys are handed to `Line` as metadata.

    Args:
        root (str): Directory containing the `*.gz` files.
        lower (bool): Forwarded to `Line`; lowercases the title when true.

    Yields: Line
    """
    # glob() makes no ordering guarantee. Sort the paths so that taking a
    # prefix of this stream (e.g. a `skim`) gives the same rows on every run.
    for path in sorted(glob('%s/*.gz' % root)):
        with gzip.open(path) as fh:
            for raw in fh:

                # Tolerate blank rows (e.g. a trailing newline at EOF)
                # instead of letting the JSON parser fail on them.
                if not raw.strip():
                    continue

                data = ujson.loads(raw)

                title = data.pop('title')
                label = data.pop('label')

                yield Line(title, label, data, lower)

In [3]:
class Line:
    """A piece of text with an optional label and free-form metadata."""

    def __init__(self, text, label=None, metadata=None, lower=True):
        """Store the text (lowercased unless `lower` is false) and its label.

        Args:
            text (str): Raw text.
            label: Optional label attached to the text.
            metadata (dict): Optional extra fields; defaults to an empty dict.
            lower (bool): Lowercase `text` before storing it.
        """
        if lower:
            text = text.lower()

        self.text = text
        self.label = label
        self.metadata = metadata if metadata else {}

    def __repr__(self):
        """Render as `<ClassName><<N> chars -> <label>>`."""
        fields = {
            'cls_name': type(self).__name__,
            'char_count': len(self.text),
            'label': self.label,
        }

        return '{cls_name}<{char_count} chars -> {label}>'.format(**fields)

In [4]:
class Corpus:
    """In-memory list of `Line` objects read from a corpus directory."""

    def __init__(self, root, skim=None, lower=True):
        """Read lines from `root` into memory.

        Args:
            root (str): Corpus directory, passed to `read_json_lines`.
            skim (int): Stop after this many lines; None reads everything.
            lower (bool): Passed to `read_json_lines`.
        """
        logger.info('Parsing line corpus.')

        reader = read_json_lines(root, lower)

        self.lines = list(tqdm(islice(reader, skim)))

    def __repr__(self):
        """Render as `<ClassName><<N> lines>`."""
        return '{cls_name}<{line_count} lines>'.format(
            cls_name=type(self).__name__,
            line_count=len(self),
        )

    def __len__(self):
        """Number of lines held."""
        return len(self.lines)

    def __iter__(self):
        """Iterate over the stored lines in read order."""
        return iter(self.lines)

    def char_counts(self):
        """Tally how often each character occurs across all line texts.

        Returns: Counter of char -> count
        """
        logger.info('Gathering char counts.')

        tally = Counter()
        for entry in tqdm(self):
            # Updating a Counter with a string counts its characters.
            tally.update(entry.text)

        return tally

    def label_counts(self):
        """Tally how many lines carry each label.

        Returns: Counter of label -> count
        """
        logger.info('Gathering label counts.')

        return Counter(entry.label for entry in tqdm(self))

    def labels(self):
        """List the distinct labels, most frequent first."""
        ranked = self.label_counts().most_common()
        return [pair[0] for pair in ranked]

In [5]:
class CharEmbedding(nn.Embedding):
    """Embedding table with one row per entry of a character vocabulary."""

    def __init__(self, vocab, embed_dim=15):
        """Set vocab, map s->i.

        Args:
            vocab: Vocabulary object; must support `len()` and expose a
                `stoi` mapping of char -> index.
            embed_dim (int): Size of each character vector.
        """
        # Run nn.Module's initialiser before assigning any attribute:
        # Module refuses to register submodules/parameters that are assigned
        # before __init__, which would break if `vocab` is itself a Module.
        super().__init__(len(vocab), embed_dim)
        self.vocab = vocab

    def chars_to_idxs(self, chars):
        """Map characters to embedding indexes.

        Args:
            chars (iterable<str>): Single characters.

        Returns: 1-D index tensor cast to `itype`.
        """
        idxs = [self.vocab.stoi[c] for c in chars]

        return torch.LongTensor(idxs).type(itype)

    def forward(self, texts):
        """Batch-embed token chars.

        Args:
            texts (list<str>)

        Returns: the embedded characters regrouped by `group_by_sizes`
            using each text's length (presumably one tensor per text --
            confirm against `news_vec.utils.group_by_sizes`).
        """
        sizes = [len(t) for t in texts]

        # Map chars -> indexes with a single tensor for the whole batch,
        # instead of allocating one tensor per character and concatenating.
        # This also avoids torch.cat failing on an empty list.
        x = self.chars_to_idxs(chain(*texts))

        # Embed.
        x = super().forward(x)

        return group_by_sizes(x, sizes)

In [6]:
class LineEncoder(nn.Module):
    """Unidirectional LSTM over a batch of variable-length sequences."""

    def __init__(self, input_size, hidden_size=1024, num_layers=2):
        """Initialize LSTM.

        Args:
            input_size (int): Feature size of each input timestep.
            hidden_size (int): LSTM hidden state size.
            num_layers (int): Number of stacked LSTM layers.
        """
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        # Dropout with the default drop probability, applied to LSTM outputs.
        self.dropout = nn.Dropout()

    @property
    def out_dim(self):
        """Feature size of each timestep returned by `forward`.

        Derived from the LSTM's configuration so it always matches the real
        output width (the LSTM here is unidirectional, so this equals
        `hidden_size`, not `hidden_size * 2`).
        """
        directions = 2 if self.lstm.bidirectional else 1
        return self.lstm.hidden_size * directions

    def forward(self, x):
        """Pad, encode, apply dropout, then trim back to the input lengths.

        The sequences are right-padded rather than packed. Because the LSTM
        runs left to right, padding appended after a sequence cannot affect
        the states at its real timesteps, so trimming recovers them exactly.

        Args:
            x (list<Tensor>): Variable-length embedding tensors.

        Returns:
            list<Tensor>: One tensor per input, with the same length as that
            input and `out_dim` features per timestep.
        """
        sizes = [len(xi) for xi in x]

        # Pad + LSTM.
        x = rnn.pad_sequence(x, batch_first=True)
        x, _ = self.lstm(x)
        x = self.dropout(x)

        # Unpad.
        return [seq[:size] for seq, size in zip(x, sizes)]

In [7]:
class CharLM(nn.Module):
    """Character language model built from a forward and a backward LSTM.

    `encode` produces one vector per character from the states around it,
    and `predict` maps those vectors to log-probabilities over the vocab.
    """

    def __init__(self, char_counts, lstm_dim=200, embed_dim=100):
        """Initialize encoders + clf.

        Args:
            char_counts (Counter): char -> count, used to build the vocab.
            lstm_dim (int): Hidden size of each directional encoder.
            embed_dim (int): Size of the merged per-character vector.
        """
        super().__init__()

        self.vocab = Vocab(char_counts)

        self.embed_chars = CharEmbedding(self.vocab)

        # Both encoders consume the character embeddings.
        char_dim = self.embed_chars.embedding_dim
        self.encode_f = LineEncoder(char_dim, lstm_dim)
        self.encode_b = LineEncoder(char_dim, lstm_dim)

        # Projects the concatenated forward + backward states.
        self.merge = nn.Linear(lstm_dim*2, embed_dim)

        self.predict = nn.Sequential(
            nn.Linear(embed_dim, len(self.vocab)),
            nn.LogSoftmax(1),
        )

    def batch_iter(self, lines_iter, size=50):
        """Generate batches of line -> targets.

        Args:
            lines_iter (iterable<Line>)
            size (int): Lines per batch.

        Yields: (lines, yt), where `yt` holds the vocab index of every
            character of every line in the batch, in order.
        """
        for batch in chunked_iter(lines_iter, size):

            targets = []
            for line in batch:
                targets.extend(self.vocab.stoi[c] for c in line.text)

            yield batch, torch.LongTensor(targets).type(itype)

    def encode(self, lines):
        """Embed lines.

        Args:
            lines (list<Line>): Objects exposing a `text` string.

        Returns: 2-D tensor of merged states, all lines concatenated along
            the first dimension.
        """
        # Add start/end spaces.
        texts = [f' {line.text} ' for line in lines]

        embedded = self.embed_chars(texts)

        # Forward LSTM.
        fwd = self.encode_f(embedded)

        # Backward LSTM: run over reversed inputs, un-reverse outputs below.
        bwd = self.encode_b([e.flip(0) for e in embedded])

        # Cat [forward n-1, backward n+1] states for each token.
        contexts = []
        for f_states, b_states in zip(fwd, bwd):
            b_states = b_states.flip(0)
            contexts.append(torch.cat([f_states[:-2], b_states[2:]], dim=1))

        return self.merge(torch.cat(contexts, dim=0))

In [8]:
class Trainer:
    """Fit a `CharLM` on a corpus, logging train and validation loss."""

    def __init__(self, corpus_root, lr=1e-4, batch_size=50, test_size=10000,
        eval_every=100000, corpus_kwargs=None, model_kwargs=None):
        """Load the corpus, build the model and optimizer, split the data.

        Args:
            corpus_root (str): Corpus directory, passed to `Corpus`.
            lr (float): Adam learning rate.
            batch_size (int): Lines per batch.
            test_size: Forwarded to `train_test_split` to size the
                validation set.
            eval_every (int): Log metrics each time this many training
                lines have been consumed.
            corpus_kwargs (dict): Extra keyword arguments for `Corpus`.
            model_kwargs (dict): Extra keyword arguments for `CharLM`.
        """
        self.corpus = Corpus(corpus_root, **(corpus_kwargs or {}))

        char_counts = self.corpus.char_counts()

        self.model = CharLM(char_counts, **(model_kwargs or {}))

        # Move the model to the GPU *before* constructing the optimizer, as
        # the torch.optim documentation requires.
        if torch.cuda.is_available():
            self.model.cuda()

        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.batch_size = batch_size

        self.eval_every = eval_every

        self.train_lines, self.val_lines = train_test_split(
            self.corpus.lines, test_size=test_size)

    def train(self, epochs=10):
        """Train for N epochs.
        """
        for epoch in range(epochs):
            self.train_epoch(epoch)

    def train_epoch(self, epoch):
        """Run one pass over the training lines.

        Metrics are logged every `eval_every` lines and once more at the
        end of the epoch.
        """
        logger.info('Epoch %d' % epoch)

        lines_iter = tqdm(self.train_lines)

        batches = self.model.batch_iter(lines_iter, self.batch_size)

        batch_losses = []
        eval_n = 0
        for lines, yt in batches:

            # Validation switches the model to eval mode, so re-enable
            # training mode on every step.
            self.model.train()

            self.optimizer.zero_grad()

            embeds = self.model.encode(lines)
            yp = self.model.predict(embeds)

            loss = F.nll_loss(yp, yt)
            loss.backward()

            self.optimizer.step()

            batch_losses.append(loss.item())

            # Number of `eval_every` boundaries crossed so far, based on
            # the progress bar's count of consumed lines.
            n = math.floor(lines_iter.n / self.eval_every)

            if n > eval_n:
                self.log_metrics(batch_losses)
                eval_n = n

        self.log_metrics(batch_losses)

    def log_metrics(self, batch_losses, n=100):
        """Log the mean of the last `n` batch losses, then validation loss.
        """
        logger.info('Train loss: %f' % np.mean(batch_losses[-n:]))
        self.log_val_metrics()

    def log_val_metrics(self):
        """Log the mean per-batch loss over the validation lines.
        """
        self.model.eval()

        lines_iter = tqdm(self.val_lines)

        batches = self.model.batch_iter(lines_iter, self.batch_size)

        losses = []

        # No gradients are needed for evaluation; skipping autograd
        # bookkeeping saves memory and time.
        with torch.no_grad():
            for lines, yt in batches:

                embeds = self.model.encode(lines)
                yp = self.model.predict(embeds)

                losses.append(F.nll_loss(yp, yt).item())

        logger.info('Val loss: %f' % np.mean(losses))

In [9]:
# Small trial run: read only the first 10,000 lines, hold out 100 of them
# for validation, and log metrics every 1,000 training lines.
t = Trainer(
    '../data/b13-texts.json/',
    test_size=100,
    eval_every=1000,
    corpus_kwargs={'skim': 10000},
)

2018-12-10 01:11:48,436 | INFO : Parsing line corpus.
10000it [00:00, 165472.75it/s]
2018-12-10 01:11:48,499 | INFO : Gathering char counts.
100%|██████████| 10000/10000 [00:00<00:00, 153541.90it/s]


In [None]:
# Train with the default number of epochs (10); train and validation loss
# are logged as the run progresses.
t.train()

2018-12-10 01:11:52,118 | INFO : Epoch 0
 10%|█         | 1000/9900 [00:28<03:55, 37.77it/s]2018-12-10 01:12:21,916 | INFO : Train loss: 4.418758

  0%|          | 0/100 [00:00<?, ?it/s][A
 50%|█████     | 50/100 [00:00<00:00, 124.67it/s][A
100%|██████████| 100/100 [00:00<00:00, 115.26it/s][A
[A2018-12-10 01:12:22,833 | INFO : Val loss: 4.376865
 20%|██        | 2000/9900 [00:57<03:43, 35.28it/s]2018-12-10 01:12:51,306 | INFO : Train loss: 4.349924

  0%|          | 0/100 [00:00<?, ?it/s][A
 50%|█████     | 50/100 [00:00<00:00, 129.98it/s][A
100%|██████████| 100/100 [00:00<00:00, 122.14it/s][A
[A2018-12-10 01:12:52,165 | INFO : Val loss: 4.056062
 30%|███       | 3000/9900 [01:27<03:20, 34.45it/s]2018-12-10 01:13:21,246 | INFO : Train loss: 4.085155

  0%|          | 0/100 [00:00<?, ?it/s][A
 50%|█████     | 50/100 [00:00<00:00, 128.71it/s][A
100%|██████████| 100/100 [00:00<00:00, 121.93it/s][A
[A2018-12-10 01:13:22,102 | INFO : Val loss: 3.224514
 40%|████      | 4000/9900