# Training of SST Model

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext as tt

from tqdm import tqdm
from pytorch_extras import RAdam, SingleCycleScheduler
from pytorch_transformers import GPT2Model, GPT2Tokenizer
from deps.torch_train_test_loop.torch_train_test_loop import LoopComponent, TrainTestLoop

from heinsen_routing import Routing

In [2]:
DEVICE = 'cuda:0'

## Load pretrained transformer and tokenizer

In [3]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large', do_lower_case=False)
lang_model = GPT2Model.from_pretrained('gpt2-large', output_hidden_states=True, output_attentions=False)
lang_model.cuda(device=DEVICE)
lang_model.eval()
print('Pretrained transformer loaded.')

Pretrained transformer loaded.


In [4]:
def tokenized_texts_to_embs(tokenized_texts, pad_token=tokenizer.eos_token):
    tokenized_texts = [[*tok_seq, tokenizer.eos_token] for tok_seq in tokenized_texts]
    lengths = [len(tok_seq) for tok_seq in tokenized_texts]

    max_length = max(lengths)
    input_toks = [t + [pad_token] * (max_length - l) for t, l in zip(tokenized_texts, lengths)]

    input_ids = [tokenizer.convert_tokens_to_ids(tok_seq) for tok_seq in input_toks]
    input_ids = torch.tensor(input_ids).to(device=DEVICE)

    mask = [[1.0] * length + [0.0] * (max_length - length) for length in lengths]
    mask = torch.tensor(mask).to(device=DEVICE)  # [batch sz, num toks]
    
    with torch.no_grad():
        outputs = lang_model(input_ids=input_ids)
        embs = torch.stack(outputs[-1], -2)  # [batch sz, n toks, n layers, d emb]

    return mask, embs

## Prepare datasets

In [5]:
fine_grained = False  # set to False for binary task
train_on_subtrees = True  # set to False to train only on root sentences

In [6]:
tt.datasets.SST.download(root='.data')  # download if not present

TEXT = tt.data.RawField(preprocessing=tokenizer.tokenize, postprocessing=tokenized_texts_to_embs, is_target=False)
LABEL = tt.data.LabelField()

class SSTFilter():

    def __init__(self, remove_dupes=False, remove_neutral=False, min_n_toks=0):
        self.remove_dupes, self.remove_neutral, self.min_n_toks = (remove_dupes, remove_neutral, min_n_toks)
        self.prev_seen = {}

    def __call__(self, sample):
        if self.remove_neutral and (sample.label == 'neutral'):
            return False
        if len(sample.text) < self.min_n_toks:
            return False
        hashable = ''.join(sample.text)
        if self.remove_dupes and (hashable in self.prev_seen):
            return False
        self.prev_seen[hashable] = True
        return True

trn_ds = tt.datasets.SST(
    '.data/sst/trees/train.txt', TEXT, LABEL, fine_grained=fine_grained, subtrees=train_on_subtrees,
    filter_pred=SSTFilter(remove_neutral=False if fine_grained else True, remove_dupes=True))

val_ds = tt.datasets.SST(
    '.data/sst/trees/dev.txt', TEXT, LABEL, fine_grained=fine_grained, subtrees=False,
    filter_pred=SSTFilter(remove_neutral=False if fine_grained else True, remove_dupes=False))

tst_ds = tt.datasets.SST(
    '.data/sst/trees/test.txt', TEXT, LABEL, fine_grained=fine_grained, subtrees=False,
    filter_pred=SSTFilter(remove_neutral=False if fine_grained else True, remove_dupes=False))

LABEL.build_vocab(trn_ds)

print('Datasets ready.\nLabels:', *enumerate(LABEL.vocab.stoi))
print('Number of samples: {:,} trn, {:,} val, {:,} tst.'.format(len(trn_ds), len(val_ds), len(tst_ds)))

Datasets ready.
Labels: (0, 'positive') (1, 'negative')
Number of samples: 77,616 trn, 872 val, 1,821 tst.


## Define model

In [7]:
class SequenceRoutingClassifier(nn.Module):

    def __init__(self, d_depth, d_emb, d_caps, n_caps, **kwargs):
        super().__init__()
        zipped_cap_args = zip(d_caps[1:], n_caps[1:], d_caps[:-1], n_caps[:-1])
        self.depth_emb = nn.Parameter(torch.zeros(d_depth, d_emb))
        self.routings = nn.Sequential(
            Routing(1, d_out=d_caps[0], n_out=n_caps[0], d_inp=d_emb, **kwargs),
            *[Routing(1, *cap_args, **kwargs) for cap_args in zipped_cap_args],
        )

    def forward(self, mask, embs):
        embs = embs + self.depth_emb  # [bs, n toks, d_depth, d_emb]

        a = torch.log(mask / (1.0 - mask))  # -inf to inf (PyTorch handles this nicely)
        a = a.unsqueeze(-1).expand(-1, -1, embs.shape[-2]).contiguous()  # [bs, n toks, d_depth]
        a = a.view(a.shape[0], -1)  # [bs, n toks * d_depth]
        mu = embs.view(embs.shape[0], -1, 1, embs.shape[-1])  # [bs, n toks * d_depth, 1, d_depth]

        for routing in self.routings:
            a, mu, sig2 = routing(a, mu)

        return a, mu, sig2

## Training Pipeline

In [8]:
class LoopMain(LoopComponent):

    def __init__(self, n_classes, device, min_prob=0.0, pct_warmup=0.1, mixup=(0.2, 0.2)):
        self.n_classes, self.device, self.min_prob, self.pct_warmup = (n_classes, device, min_prob, pct_warmup)
        self.mixup_dist = torch.distributions.Beta(torch.tensor(mixup[0]), torch.tensor(mixup[1]))

        self.onehot = torch.eye(self.n_classes, device=self.device)
        self.smooth = (self.onehot * (1.0 - (self.min_prob * self.n_classes))) + self.min_prob

    def on_train_begin(self, loop):
        n_iters = len(loop.train_data) * loop.n_epochs
        loop.optimizer = RAdam([{ 'params': loop.model.parameters(), 'lr': 5e-4 }])
        loop.scheduler = SingleCycleScheduler(loop.optimizer, loop.n_optim_steps, frac=self.pct_warmup, min_lr=1e-5)
        
    def on_grads_reset(self, loop):
        loop.model.zero_grad()

    def on_forward_pass(self, loop):
        model, batch = (loop.model, loop.batch)
        mask, embs = batch.text

        if loop.is_training:
            r = self.mixup_dist.sample([len(mask)]).to(device=mask.device)
            idx = torch.randperm(len(mask))
            mask = mask.lerp(mask[idx], r[:, None])
            embs = embs.lerp(embs[idx], r[:, None, None, None])
            target_probs = self.smooth[batch.label]
            target_probs = target_probs.lerp(target_probs[idx], r[:, None])
        else:
            target_probs = self.onehot[batch.label]

        # Classify inputs.
        pred_scores, _, _ = model(mask, embs)
        _, pred_ids = pred_scores.max(-1)
        accuracy = (pred_ids == batch.label).float().mean()

        # Save results as loop attrs.
        loop.pred_scores, loop.target_probs, loop.accuracy = (pred_scores, target_probs, accuracy)

    def on_loss_compute(self, loop):
        losses = -loop.target_probs * F.log_softmax(loop.pred_scores, dim=-1)  # smooth cross entropy
        loop.loss = losses.sum(dim=-1).mean()  # sum across classes, then mean of batch

    def on_backward_pass(self, loop):
        loop.loss.backward()

    def on_optim_step(self, loop):
        loop.optimizer.step()
        loop.scheduler.step()

In [9]:
class LoopStats(LoopComponent):

    def __init__(self):
        self.data = []

    def on_batch_end(self, loop):
        self.data.append({
            'n_samples': len(loop.batch),
            'n_toks': loop.batch.text[1].shape[1],
            'epoch_desc': loop.epoch_desc,
            'epoch_num': loop.epoch_num,
            'epoch_frac': loop.epoch_num + loop.batch_num / loop.n_batches,
            'accuracy': loop.accuracy.item(),
            'loss': loop.loss.item(),
            'lr': loop.optimizer.param_groups[0]['lr'],
            'momentum': loop.optimizer.param_groups[0]['betas'][0],
        })

    def plot(self, item_name='loss', epoch_desc='train', groupby='epoch_frac', **kwargs):
        df = pd.DataFrame(self.data)
        df = df[df.epoch_desc == epoch_desc]
        df[item_name] = df[item_name] * df.n_samples
        series = df.groupby(groupby)[item_name].sum() / df.groupby(groupby).n_samples.sum()
        series.plot(label=f"{epoch_desc}_{item_name}", **kwargs)

In [10]:
class LoopProgressBar(LoopComponent):

    def __init__(self, item_names=['loss', 'accuracy'], show_train_stats=False):
        self.item_names, self.show_train_stats = (item_names, show_train_stats)

    def on_epoch_begin(self, loop):
        self.total, self.count = ({ name: 0.0 for name in self.item_names }, 0)
        self.pbar = tqdm(total=loop.n_batches, desc=f"{loop.epoch_desc} epoch {loop.epoch_num}")

    def on_batch_end(self, loop):
        n = len(loop.batch)
        self.count += n
        for name in self.item_names:
            self.total[name] += getattr(loop, name).item() * n
        self.pbar.update(1)

        if (not loop.is_training) or (loop.is_training and self.show_train_stats):
            self.pbar.set_postfix(self.mean)

    def on_epoch_end(self, loop):
        self.pbar.close()

    def on_train_end(self, loop):
        self.pbar.close()  # in case of early stop

    @property
    def mean(self): return { f'mean_{name}': self.total[name] / self.count for name in self.item_names }

## Initialize and train model

In [11]:
# Seed RNG for replicability.
torch.manual_seed(1)

# Make iterators for each split, with random shuffling.
trn_itr, val_itr, tst_itr = tt.data.BucketIterator.splits(
    (trn_ds, val_ds, tst_ds),
    shuffle=True,
    batch_size=4,
    device='cuda:0')

# Initialize model.
model = SequenceRoutingClassifier(
    d_depth=lang_model.config.n_layer + 1,
    d_emb=lang_model.config.hidden_size,
    d_caps=[2, 2, 2],
    n_caps=[64, 64, len(LABEL.vocab)],
)
model = model.cuda(device=DEVICE)

print('Total number of parameters: {:,}.\n'.format(sum(int(np.prod(p.shape)) for p in model.parameters())))

Total number of parameters: 245,248.



In [None]:
loop_components = [
    LoopMain(len(LABEL.vocab), DEVICE, pct_warmup=0.1, mixup=(0.8, 0.8)),
    LoopStats(),
    LoopProgressBar(show_train_stats=True),
]
loop = TrainTestLoop(model, loop_components, trn_itr, val_itr)
loop.train(n_epochs=5)

train epoch 0: 100%|██████████| 19404/19404 [45:47<00:00,  7.06it/s, mean_loss=0.427, mean_accuracy=0.743] 
valid epoch 0: 100%|██████████| 218/218 [00:20<00:00, 10.74it/s, mean_loss=0.229, mean_accuracy=0.914]
train epoch 1: 100%|██████████| 19404/19404 [45:55<00:00,  7.04it/s, mean_loss=0.367, mean_accuracy=0.768] 
valid epoch 1: 100%|██████████| 218/218 [00:20<00:00, 10.71it/s, mean_loss=0.2, mean_accuracy=0.924]  
train epoch 2: 100%|██████████| 19404/19404 [46:01<00:00,  7.03it/s, mean_loss=0.343, mean_accuracy=0.776] 
valid epoch 2: 100%|██████████| 218/218 [00:20<00:00, 10.60it/s, mean_loss=0.198, mean_accuracy=0.922]
train epoch 3:   2%|▏         | 462/19404 [01:04<43:17,  7.29it/s, mean_loss=0.317, mean_accuracy=0.79]   

In [None]:
loop.components[-1].pbar.close()

# Visualize training stats

In [None]:
import pandas as pd
fig, axes = plt.subplots(ncols=2, figsize=(16, 5))
stats = loop.components[-2]
for axis, item_name in zip(axes, ['loss', 'accuracy']):
    #stats.plot(item_name, 'train', ax=axis, legend=True, alpha=0.5, color='red')
    stats.plot(item_name, 'valid', ax=axis, legend=True, alpha=0.5, color='black')

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(16, 5))
stats = loop.components[-2]
for axis, item_name in zip(axes, ['loss', 'accuracy']):
    stats.plot(item_name, 'train', groupby='epoch_num', ax=axis, legend=True, alpha=0.5, color='red')
    stats.plot(item_name, 'valid', groupby='epoch_num', ax=axis, legend=True, alpha=0.5, color='black')

## Test

In [None]:
loop.test(tst_itr)