## Preparation

The setup cells below are written for Google Colaboratory.

In [None]:
# Install tensorboardX, then clone NVIDIA apex and build it from source
# with its C++ and CUDA extensions enabled.
!pip install tensorboardX
!git clone https://github.com/NVIDIA/apex
!pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" apex/

In [None]:
# Mount Google Drive, then copy the data archive and the SimHei font out of it.
from google.colab import drive
drive.mount('/gdrive')
!cp '/gdrive/My Drive/data.7z' ./
!cp '/gdrive/My Drive/address/simhei.ttf' /usr/share/fonts/
# Unpack the archive into the working directory and delete it afterwards.
!7z x data.7z
!rm -f data.7z

### install ParlAI

In [None]:
# Clone ParlAI and install it in development mode.
# setup.py is run from inside the checkout: setuptools resolves package
# discovery (and any relative files setup.py opens) against the current
# working directory, so invoking it from the parent directory is unreliable.
!git clone https://github.com/facebookresearch/ParlAI.git ./ParlAI
!cd ./ParlAI && python setup.py develop

### install Quasi-hyperbolic optimizers

In [None]:
# Install qhoptim straight from its GitHub repository; option.py imports
# QHAdam from it when the package is available.
!pip install git+https://github.com/facebookresearch/qhoptim.git

### imports

In [1]:
%%file common.py
import os
import argparse
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
# Only the two rank flags are declared here; parse_known_args leaves every
# other command-line argument untouched for the importing script.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--rank", type=int, default=0)
args = parser.parse_known_args()[0]


# `opt` is an empty function used purely as a mutable attribute holder for
# the run-wide settings (device, dtype, cuda flag, ...).
def opt():
  pass


if not torch.cuda.is_available():
  # CPU: full precision, no apex, and leave one core free when possible.
  opt.device = torch.device('cpu')
  opt.dtype = torch.float
  opt.cuda = False
  num_threads = torch.multiprocessing.cpu_count() - 1
  if num_threads > 1:
    torch.set_num_threads(num_threads)
  amp = None
else:
  # GPU: half precision on the device selected by --local_rank, with apex amp.
  opt.dtype = torch.half
  opt.device = torch.device('cuda:{}'.format(args.local_rank))
  torch.cuda.set_device(args.local_rank)
  opt.cuda = True
  from apex import amp

print('Using device {}'.format(opt.device))
print('Using default dtype {}'.format(opt.dtype))

Overwriting common.py


In [1]:
%matplotlib inline
from common import *

Using device cpu
Using default dtype torch.float32


## data.py

In [None]:
%%file data.py
from common import *
from torch.utils.data import Dataset, DataLoader

# Number of synthetic samples generated for each split name.
dataLength = {'train': 4096, 'val': 256, 'test': 256}


class Data(Dataset):
    """Synthetic dataset of zero-padded random rows.

    Every sample is a row of 5 uniform random values of which only the first
    `length` (drawn from 1..4) are kept; the rest are zeroed. The label is
    the sum of the row.

    :param path: split name, used only to look up the sample count in
        `dataLength`.
    """

    def __init__(self, path):
        super().__init__()
        n = dataLength[path]
        width = 5
        # randint(4) yields 0..3, so the lengths fall in 1..4.
        self.lens = torch.randint(4, (n,)) + 1
        # 1 where the column index is below the row's length, 0 elsewhere.
        columns = torch.arange(width).unsqueeze(0)
        self.mask = (columns < self.lens.unsqueeze(1)).to(torch.uint8)
        self.data = torch.rand((n, width)) * self.mask.float()
        self.count = n

    def __len__(self):
        return self.count

    def __getitem__(self, ind):
        """Return (input, label, length, mask) for sample `ind`."""
        row = self.data[ind]
        label = row.sum()
        return row, label, self.lens[ind], self.mask[ind]

def newLoader(path, *args, **kwargs):
    """Build a DataLoader over `Data(path)`; extra arguments go to DataLoader."""
    dataset = Data(path)
    return DataLoader(dataset, *args, **kwargs)

## vocab.py

In [None]:
%%file vocab.py
import os
import torch
from torch.nn.utils.rnn import pad_sequence
vocabPath = './char.txt'

def getBatch(data):
  """Encode a list of strings into a padded index tensor.

  Each string becomes [1] + one index per character (0 for characters that
  are not in the vocabulary) + [0]; the sequences are padded with
  pad_sequence, which stacks them along dimension 1.

  :returns: (x, l, mask) where `l[i]` is `len(data[i]) + 1` and `mask` is 1
      on the first `l[i]` rows of column i and 0 below.
  """
  encoded = []
  for s in data:
    ids = [1]
    ids.extend(vocabIndex[t] if t in vocabSet else 0 for t in s)
    ids.append(0)
    encoded.append(torch.tensor(ids, dtype=torch.long))
  x = pad_sequence(encoded)
  l = [len(s) + 1 for s in data]
  mask = torch.ones_like(x)
  for column, end in enumerate(l):
    mask[end:, column] = 0
  return x, l, mask

def initial(path):
  """Load the vocabulary file and rebuild the module-level lookup tables.

  The file is read as UTF-8 and split on NUL characters; two empty-string
  entries are placed in front of the result. `vocab`, `vocabSet` and
  `vocabIndex` are replaced as a side effect.

  :param path: path of the vocabulary file.
  :returns: the new `vocab` list.
  """
  global vocab, vocabSet, vocabIndex
  with open(path, 'r', encoding='utf-8') as f:
    entries = f.read().split('\0')
  vocab = ['', ''] + entries
  vocabSet = set(vocab)
  # For duplicated entries the last position wins.
  vocabIndex = {w: i for i, w in enumerate(vocab)}
  return vocab

# Start with an empty vocabulary; it is loaded at import time only when the
# default vocabulary file exists.
vocab = []
if os.path.exists(vocabPath):
  initial(vocabPath)

In [None]:
%%file -a train.py
# Paths of the external resources (pretrained vectors, initial weights,
# font, vocabulary).
# NOTE(review): this content is appended to train.py (`%%file -a`), but a
# later cell rewrites train.py with `%%file train.py`, which would discard
# it — confirm the intended cell order.
word2vecPath = './vectors.pth'
stateDictPath = './net.init.pth'
fontPath = '/usr/share/fonts/simhei.ttf'
vocabPath = './char.txt'
# Alternative font locations, kept for reference.
#fontPath = '/usr/share/fonts/wqy-microhei/wqy-microhei.ttc'
#fontPath = 'C:/Windows/Fonts/simhei.ttf'
from vocab import vocab, initial
# Reload the vocabulary from the path configured above.
initial(vocabPath)

## model.py

In [None]:
%%file model.py
from common import *

# Constant returned as the second element of Model.forward's result.
Zero = torch.tensor(0.)
# Fixed sequence length the model's flattened layers are sized for.
maxLen = 5


class Model(nn.Module):
    """Per-position linear embedding followed by a flattened linear read-out.

    :param opt: settings object providing `device`, `dtype`, `edim` and
        `dropout`.
    """

    def __init__(self, opt):
        super().__init__()
        self.device, self.dtype, self.edim = opt.device, opt.dtype, opt.edim
        self.dropout = nn.Dropout(opt.dropout)
        # NOTE(review): only `dropout` is registered at this point, so this
        # call does not touch the layers created below — confirm whether
        # that is intentional before moving it.
        self.to(dtype=opt.dtype, device=opt.device)
        flat = opt.edim * maxLen
        self.f0 = nn.Linear(1, opt.edim, bias=True)
        self.act0 = nn.LeakyReLU(.1)
        self.norm = nn.BatchNorm1d(flat, affine=False)
        self.f1 = nn.Linear(flat, 1, bias=True)
        self.act1 = torch.tanh

    def forward(self, x, mask, *_):
        """Run the model on a (batch, length) input.

        :param x: 2-D input tensor; its second dimension must equal `maxLen`
            for the flattened layers to fit.
        :param mask: tensor with one entry per input position; it is cast to
            the input dtype and multiplied into the activations.
        :returns: (tanh output with one value per sample, `Zero` on the
            model device, masked first-layer activations).
        """
        bsz, l = x.shape
        e = self.dropout(x).view(bsz, l, 1)
        gate = mask.to(e.dtype).view(bsz, l, 1)
        x1 = self.act0(self.f0(e)) * gate
        normed = self.norm(x1.view(bsz, -1)).view(bsz, l, -1)
        # Masked again because normalisation makes padded positions non-zero.
        x2 = normed * gate
        score = self.f1(x2.view(bsz, -1)).squeeze(-1)
        return self.act1(score), Zero.to(self.device), x1

def predict(x):
    """Turn model output into a prediction; for this model it is the identity."""
    return x

## train.py

In [4]:
%%file train.py
import torch.optim as optim
import numpy as np
from common import *
from data import newLoader
from model import Model, predict
from option import option
if amp:
    from apex.optimizers import FusedAdam
def getNelement(model):
    """Total number of scalar elements across the model's parameters."""
    return sum(p.nelement() for p in model.parameters())


def l1Reg(acc, cur):
    """Reduction step: add the float32 sum of |cur| to the accumulator."""
    return acc + cur.abs().sum(dtype=torch.float)


def l2Reg(acc, cur):
    """Reduction step: add the float32 sum of cur squared to the accumulator."""
    return acc + (cur * cur).sum(dtype=torch.float)
# Scalar NaN tensor on the target device; trainStep compares every loss
# against it to detect a diverged run.
nan = torch.tensor(float('nan'), device=opt.device)

# Default settings, stored as attributes on the shared `opt` object.
opt.batchsize = 1
opt.epochs = 1
opt.maxgrad = 1. # max gradient (bound passed to clip_grad_value_ in trainStep)
opt.dropout = 0
opt.sdt = 0.001 # initial learning rate
opt.sdt_decay_step = 10 # how often to reduce learning rate (StepLR step size; <= 0 keeps it constant)
opt.criterion = lambda y, out, mask: F.mse_loss(out, y) # criterion for evaluation
opt.loss = lambda opt, model, y, out, *args: F.mse_loss(out, y) # criterion for loss function
# Optimizer factory called as newOptimizer(opt, params, eps): FusedAdam when
# apex amp is available, otherwise Adam with amsgrad.
opt.newOptimizer = (lambda opt, params, _: FusedAdam(params, lr=opt.sdt)) if amp else lambda opt, params, eps: optim.Adam(params, lr=opt.sdt, amsgrad=True, eps=eps)
opt.writer = 0 # TensorBoard writer (0 disables logging in train)
opt.drawVars = 0 # optional callable invoked by evaluate after the validation loop
opt.reset_parameters = 0 # optional callable invoked by initParameters
# Values from option.py override the defaults above.
opt.__dict__.update(option)

def initParameters(opt, model):
    """Apply the project's parameter initialisation to `model` in place.

    Every tensor-valued `bias` is zeroed and every PReLU weight is set to 1.
    If `opt.reset_parameters` is truthy it is called with no arguments.
    When the model has an `nn.Embedding` named `embedding`, rows 2 and up
    are overwritten with the tensor loaded from `word2vecPath`.
    """
    for m in model.modules():
        bias = getattr(m, 'bias', None)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias, 0)
        if isinstance(m, nn.PReLU):
            # PReLU's only parameter is its weight.
            nn.init.constant_(m.weight, 1)
        # NOTE(review): this runs once per visited module and receives no
        # module argument — confirm that is the intended contract.
        if opt.reset_parameters:
            opt.reset_parameters()
    embedding = getattr(model, 'embedding', None)
    if isinstance(embedding, nn.Embedding):
        # NOTE(review): `word2vecPath` is not assigned in this cell; it must
        # already exist at module level when this branch runs.
        model.embedding.weight.data[2:] = torch.load(word2vecPath)

# Parameters to clip: apex's master copies when amp is active, otherwise the
# model's own parameters.
if amp:
    def getParameters(opt, _):
        return amp.master_params(opt.optimizer)
else:
    def getParameters(_, model):
        return model.parameters()

# Backward pass: scaled through amp on CUDA, a plain backward() otherwise.
if torch.cuda.is_available():
    def backward(loss, opt):
        with amp.scale_loss(loss, opt.optimizer) as scaled_loss:
            scaled_loss.backward()
else:
    def backward(loss, _):
        loss.backward()

def trainStep(opt, model, x, y, length, mask):
    """Run one optimisation step on a single batch.

    :param opt: settings object providing `optimizer`, `device`, `loss` and
        `maxgrad`.
    :param model: callable invoked as `model(x, mask)`; its outputs are
        unpacked into `opt.loss` after the label.
    :param x: input batch.
    :param y: labels; converted to float32 on the target device.
    :param length: unused here.
    :param mask: mask batch.
    :returns: the loss of this batch as a Python float.
    :raises Exception: if the loss is NaN.
    """
    opt.optimizer.zero_grad()
    device = opt.device
    inputs = x.to(device, non_blocking=True)
    mask = mask.to(device, non_blocking=True)
    target = y.to(device, dtype=torch.float, non_blocking=True)
    outputs = model(inputs, mask)
    loss = opt.loss(opt, model, target, *outputs)
    # With equal_nan=True the comparison against `nan` holds only for NaN.
    if torch.allclose(loss, nan, equal_nan=True):
        raise Exception('Loss returns NaN')
    backward(loss, opt)
    # Clamp each gradient element to [-maxgrad, maxgrad] before stepping.
    nn.utils.clip_grad_value_(getParameters(opt, model), opt.maxgrad)
    opt.optimizer.step()
    return float(loss)

def evaluateStep(opt, model, x, y, _, mask):
    """Evaluate one batch.

    :param opt: settings object providing `device` and `criterion`.
    :param model: callable invoked as `model(x, mask)`; its first output is
        fed to `predict`.
    :param y: labels; cast to the prediction's dtype/device when the
        prediction is a tensor.
    :returns: (summed error as float, raw error, prediction, *remaining
        model outputs).
    """
    mask = mask.to(opt.device, non_blocking=True)
    inputs = x.to(opt.device, non_blocking=True)
    out, *others = model(inputs, mask)
    pred = predict(out)
    if isinstance(pred, torch.Tensor):
        y = y.to(pred)
    missed = opt.criterion(y, pred, mask)
    total = float(missed.sum())
    return (total, missed, pred, *others)

def evaluate(opt, model):
    """Compute the average validation error of `model`.

    The model is switched to eval mode and run over the 'val' loader with
    gradient tracking disabled. The summed error of every batch is divided
    by the total of the length values yielded by the loader.

    :param opt: settings object providing `batchsize` and `drawVars`.
    :param model: model to evaluate; left in eval mode.
    :returns: total error divided by total length, or NaN when the
        validation loader produced no data.
    """
    model.eval()
    totalErr = 0
    count = 0
    # Validation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for x, y, l, mask in newLoader('val', batch_size=opt.batchsize):
            count += int(l.sum())
            err, _, pred, _, *others = evaluateStep(opt, model, x, y, l, mask)
            totalErr += err
    if count == 0:
        # Nothing was evaluated: the loop variables used below are not
        # meaningful and the average is undefined.
        return float('nan')
    if opt.drawVars:
        # Visualise the first sample of the last batch.
        opt.drawVars(x[0], l[0], *tuple(v[0] for v in others))
        print(pred[0])
    return totalErr / count

def initTrain(opt, model, epoch=None):
    """Create the optimizer and LR scheduler, then restore or seed.

    :param opt: settings object providing `newOptimizer` and
        `sdt_decay_step`; receives `optimizer` and `scheduler`.
    :param model: model whose parameters are optimised.
    :param epoch: when exactly an `int` (booleans do not count), optimizer
        and scheduler state are loaded from 'train.epoch<epoch>.pth';
        otherwise the torch and numpy RNGs are seeded with `args.rank`.
    """
    opt.optimizer = opt.newOptimizer(opt, model.parameters(), 1e-8)
    # A non-positive decay step means a constant learning rate: a huge step
    # size with gamma=1.
    if opt.sdt_decay_step > 0:
        stepSize, gamma = opt.sdt_decay_step, 0.5
    else:
        stepSize, gamma = 1e6, 1
    opt.scheduler = optim.lr_scheduler.StepLR(opt.optimizer, stepSize, gamma=gamma)
    if type(epoch) is not int:
        torch.manual_seed(args.rank)
        np.random.seed(args.rank)
        return
    state = torch.load('train.epoch{}.pth'.format(epoch), map_location='cpu')
    opt.optimizer.load_state_dict(state[0])
    opt.scheduler.load_state_dict(state[1])

def train(opt, model, init=True):
    """Train `model` until the scheduler's epoch counter reaches `opt.epochs`.

    :param opt: settings object providing `device`, `epochs`, `batchsize`
        and `writer`; `optimizer` and `scheduler` are created through
        initTrain when `init` is truthy and must already exist otherwise.
    :param model: model to train.
    :param init: falsy to continue with the existing optimizer/scheduler;
        `True` to initialise from scratch; an `int` epoch number to
        initialise and then resume from 'model.epoch<init>.pth'.
    :returns: the validation error of the last epoch run, or None when no
        epoch was run.
    """
    if init:
        initParameters(opt, model)
        # type() rather than isinstance(): `True` must not count as an epoch.
        if type(init) == int:
            model.load_state_dict(torch.load('model.epoch{}.pth'.format(init), map_location='cpu'))
    model = model.to(opt.device) # need before constructing optimizers
    if init:
        initTrain(opt, model, init)
    if amp:
        model, opt.optimizer = amp.initialize(model, opt.optimizer, opt_level="O2", keep_batchnorm_fp32=False)
    # Defined up front so the return below is valid when the loop is empty.
    valErr = None
    for i in range(opt.scheduler.last_epoch, opt.epochs):
        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # steps; left unchanged because it determines the epoch numbering
        # used for logging and checkpoints.
        opt.scheduler.step()
        count = 0
        totalLoss = 0
        model.train()
        opt.optimizer.zero_grad()
        for x, y, l, mask in newLoader('train', batch_size=opt.batchsize, shuffle=True):
            length = int(l.sum())
            count += length
            loss = trainStep(opt, model, x, y, length, mask)
            totalLoss += loss
        valErr = evaluate(opt, model)
        if opt.writer:
            logBoardStep(opt, model)
        # Prefer get_last_lr() where the scheduler offers it; fall back to
        # get_lr() on schedulers that do not.
        getLr = getattr(opt.scheduler, 'get_last_lr', None) or opt.scheduler.get_lr
        avgLoss = totalLoss / count if count else float('nan')
        print('Epoch #%i | train loss: %.4f | valid error: %.3f | learning rate: %.5f' %
          (opt.scheduler.last_epoch, avgLoss, valErr, getLr()[0]))
        # Checkpoint after every tenth loop iteration.
        if i % 10 == 9:
            saveState(opt, model, opt.scheduler.last_epoch)
    return valErr

def saveState(opt, model, epoch):
    """Write a checkpoint for `epoch` into the working directory.

    'model.epoch<epoch>.pth' receives the model state dict and
    'train.epoch<epoch>.pth' a tuple of (optimizer state, scheduler state).
    """
    modelFile = 'model.epoch{}.pth'.format(epoch)
    trainFile = 'train.epoch{}.pth'.format(epoch)
    torch.save(model.state_dict(), modelFile)
    trainingState = (opt.optimizer.state_dict(), opt.scheduler.state_dict())
    torch.save(trainingState, trainFile)

def logBoardStep(opt, model):
    """Log a histogram of every named model parameter to `opt.writer`.

    The scheduler's `last_epoch` is used as the step. Logging is
    best-effort: a failure for one parameter is reported on stdout and the
    remaining parameters are still logged.
    """
    step = opt.scheduler.last_epoch
    for name, param in model.named_parameters():
        try:
            opt.writer.add_histogram(name, param.data, step)
        except Exception as err:
            # Only ordinary errors are swallowed; KeyboardInterrupt and
            # SystemExit must still be able to stop the run.
            print(name, param)
            print('add_histogram failed: %r' % (err,))

# Script entry point: seed the torch and numpy RNGs with the rank, report
# the untrained model's parameter count and validation error, train, then
# save the final weights under the scheduler's last epoch number.
if __name__ == '__main__':
    torch.manual_seed(args.rank)
    np.random.seed(args.rank)
    model = Model(opt).to(opt.device)
    print('Number of parameters: %i | valid error: %.3f' % (getNelement(model), evaluate(opt, model)))
    train(opt, model)
    torch.save(model.state_dict(), 'model.epoch{}.pth'.format(opt.scheduler.last_epoch))

Overwriting train.py


In [None]:
%%file option.py
# Settings for this experiment.
option = {
    'edim': 16,
    'epochs': 3,
    'maxgrad': 1.,
    'sdt': 1e-2,
    'sdt_decay_step': 1,
    'batchsize': 32,
}
# Use QHAdam as the optimizer factory when qhoptim is installed; otherwise
# no 'newOptimizer' entry is added.
try:
    from qhoptim.pyt import QHAdam
except ImportError:
    pass
else:
    option['newOptimizer'] = lambda opt, params, _: QHAdam(params, lr=opt.sdt, nus=(.7, .8), betas=(0.995, 0.999))

## visualization

In [None]:
%%file -a option.py
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib as mpl
# Font used for the tick labels drawn by drawAttention.
# NOTE(review): `fontPath` is not assigned in this cell or in the option.py
# cell it is appended to — confirm it is in scope when option.py is imported.
zhfont= mpl.font_manager.FontProperties(fname=fontPath)
# Number of subplot columns in the attention figure.
columns = 2

def drawAttention(indices, l, _, att, *args):
  """Draw one heat map per attention head for a single sample.

  :param indices: 1-D tensor of vocabulary indices; the first `l` entries
      are looked up in `vocab` for the tick labels.
  :param l: number of positions to plot (converted with int()).
  :param att: 3-D tensor whose first dimension is treated as the head
      index; anything that is not 3-D makes the function return without
      drawing. Values are plotted on a fixed 0..1 colour scale.
  :returns: the result of `plt.show()`, or None when nothing is drawn.
  """
  if len(att.shape) != 3:
    return
  heads = att.size(0)
  l = int(l)
  # Ceiling division: enough rows to fit every head in `columns` columns.
  rows = (heads + columns - 1) // columns
  indices = indices[:l].tolist()
  labels = [''] + [vocab[i] for i in indices]
  fig = plt.figure(figsize=(16, rows * 16 // columns))
  for t in range(heads):
    ax = fig.add_subplot(rows, columns, t + 1)
    # One more column than rows is plotted; it is labelled 'NA' below.
    data = att[t, :l, :l+1].detach().to(torch.float).cpu().numpy()
    ax.matshow(data, interpolation='nearest', cmap='hot', vmin=0, vmax=1)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.set_xticklabels(labels + ['NA'], fontproperties=zhfont)
    ax.set_yticklabels(labels, fontproperties=zhfont)
  return plt.show()

# Register the plot under 'drawVars'; train.py copies `option` onto `opt`
# and evaluate() calls opt.drawVars when it is set.
option['drawVars'] = drawAttention

## ParlAI agent

In [None]:
%%file my_agent.py
from common import *
from copy import deepcopy
import json
import numpy as np
from parlai.core.torch_agent import TorchAgent, Output
from parlai.core.logs import TensorboardLogger
from model import Model, predict
from train import opt, initParameters, nan

class MyAgent(TorchAgent):
    """ParlAI TorchAgent that trains and evaluates the project's `Model`.

    Loss, evaluation criterion and gradient clipping come from the
    module-level `opt` imported from train.py; the ParlAI options passed to
    the constructor are kept in `self.opt`.
    """

    def __init__(self, optAgent, shared=None):
        """Build the model, or reuse the one published by `share()`.

        :param optAgent: ParlAI option dict; `gradient_clip` is overwritten
            with `opt.maxgrad`.
        :param shared: dict produced by `share()` of a parent instance, or
            None for a primary instance.
        """
        super().__init__(optAgent, shared)
        if optAgent.get('numthreads', 1) > 1:
            torch.set_num_threads(1)
        optAgent['gradient_clip'] = opt.maxgrad
        self.opt = optAgent
        self.criterion = opt.criterion
        self.loss = opt.loss
        torch.manual_seed(args.rank)
        np.random.seed(args.rank)
        # .get(): a missing key means "no TensorBoard logging" rather than
        # a KeyError.
        if optAgent.get('tensorboard_log') is True:
            self.writer = TensorboardLogger(optAgent)
        if not shared:
            model = Model(opt)
            initParameters(opt, model)
            model = model.to(opt.device)
            model.train()
            self.model = model
            if optAgent.get('numthreads', 1) > 1:
                model.share_memory()
        else:
            self.model = shared['model']
            self.dict = shared['dict']
        self.reset()

    def share(self):
        """Share internal states between parent and child instances.

        The model is added to the shared dict because the constructor reads
        it back from `shared['model']`.
        """
        shared = super().share()
        shared['model'] = self.model
        return shared

    def reset(self):
        """Reset episode_done and the metrics; returns self."""
        super().reset()
        self.episode_done = True
        self.reset_metrics()
        return self

    def save(self, path=None):
        """Save model, options, dict.

        Writes `<path>.pth` (model weights), `<path>.states` (the remaining
        state), `<path>.opt.json` (options) and `<path>.dict.txt`.

        :param path: file prefix; falls back to the `model_file` option when
            None. Nothing is written if neither is set.
        """
        path = self.opt.get('model_file', None) if path is None else path
        if not path:
            return
        states = self.state_dict()
        if states:
            torch.save(states['model'], path + '.pth')
            del states['model']
            with open(path + '.states', 'wb') as write:
                torch.save(states, write)
        # Parlai expects options to also be saved
        with open(path + '.opt.json', 'w', encoding='utf-8') as handle:
            if hasattr(self, 'model_version'):
                self.opt['model_version'] = self.model_version()
            saved_opts = deepcopy(self.opt)
            if 'interactive_mode' in saved_opts:
                # We do not save the state of interactive mode, it is only decided
                # by scripts or command line.
                del saved_opts['interactive_mode']
            # Dump the filtered copy, not self.opt, so the removal above
            # takes effect.
            json.dump(saved_opts, handle)
            # for convenience of working with jq, make sure there's a newline
            handle.write('\n')

        # force save the dict
        self.dict.save(path + '.dict.txt', sort=False)

    def load_state_dict(self, state_dict):
        """Load the state dict into model."""
        self.model.load_state_dict(state_dict)
        if self.use_cuda:
            self.model.cuda()

    def load(self, path):
        """Load model, options, dict.

        Each of `<path>.opt.json`, `<path>.dict.txt`, `<path>.states` and
        `<path>.pth` is read only if it exists.

        :returns: the loaded state dict (empty if no state file was found),
            with the model weights under 'model' when `<path>.pth` exists.
        """
        optPath = path + '.opt.json'
        if os.path.isfile(optPath):
            with open(optPath, 'r', encoding='utf-8') as handle:
                self.opt = json.load(handle)
        dictPath = path + '.dict.txt'
        if os.path.isfile(dictPath) and hasattr(self, 'dict'):
            self.dict.load(dictPath)
        statePath = path + '.states'
        states = torch.load(statePath, map_location='cpu') if os.path.isfile(statePath) else {}
        modelPath = path + '.pth'
        if os.path.isfile(modelPath):
            states['model'] = torch.load(modelPath, map_location='cpu')
            self.load_state_dict(states['model'])
        if 'optimizer' in states and hasattr(self, 'optimizer'):
            self.optimizer.load_state_dict(states['optimizer'])
        return states

    def is_valid(self, obs):
        """Override from TorchAgent.
        Check if an observation has no tokens in it."""
        return len(obs.get('text_vec', [])) > 0

    def batchify(self, *args, **kwargs):
        """Create a batch via TorchAgent.batchify and attach a text mask.

        All arguments are forwarded unchanged to the parent implementation.
        When the batch has valid indices and text lengths, `text_lengths`
        is replaced by a tensor and `text_mask` is added: a uint8 tensor
        with one row per length and as many columns as `text_vec`, holding
        1 on the first `length` positions of each row.

        :returns: the batch object produced by the parent implementation.
        """
        batch = super().batchify(*args, **kwargs)
        if not batch.valid_indices or not len(batch.valid_indices):
            return batch

        # presumably a plain list of ints (it is truth-tested and passed to
        # torch.tensor) — TODO confirm against the ParlAI version in use
        lengths = batch.text_lengths
        if lengths:
            bsz = len(lengths)
            batch.text_lengths = torch.tensor(lengths)
            text_mask = torch.zeros((bsz, batch.text_vec.shape[1]), dtype=torch.uint8)
            for i, n in enumerate(lengths):
                text_mask[i, :n].fill_(1)
            batch.text_mask = text_mask.cuda() if self.use_cuda else text_mask
        return batch

    def train_step(self, batch):
        """Process batch of inputs and targets and train on them.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        :raises Exception: if the loss is NaN.
        """
        if batch.text_vec is None:
            return
        self.is_training = True
        self.model.train()
        self.zero_grad()
        output = self.model(batch.text_vec, batch.text_mask)
        # The agent itself is passed in the `opt` slot of the loss callable.
        loss = self.loss(self, self.model, batch.label_vec, *output)
        if torch.allclose(loss, nan, equal_nan=True):
            raise Exception('Loss returns NaN')
        self.backward(loss)
        self.update_params()
        self.count += int(batch.text_lengths.sum())
        self.metrics['loss'] += float(loss)
        pred = predict(output[0])
        return Output(text=[self.dict.vec2txt(y) for y in pred])

    def eval_step(self, batch):
        """Process batch of inputs.

        If the batch includes labels, calculate validation metrics as well.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = False
        self.model.eval()
        output = self.model(batch.text_vec, batch.text_mask)
        # Computed before the metrics, which compare it with the labels.
        pred = predict(output[0])
        if batch.label_vec is not None:
            # Interactive mode won't have a gold label
            missed = self.criterion(batch.label_vec, pred, batch.text_mask)
            self.metrics['error'] += float(missed.sum())
            self.eval_exs += batch.label_vec.shape[0]

        return Output(text=[self.dict.vec2txt(y) for y in pred])

    def report(self):
        """Return metrics calculated by the model.

        Adds 'loss.avg' (loss per counted token) and 'error.avg' (error per
        evaluated example); a zero denominator is treated as 1.
        """
        metrics = super().report()
        metrics['loss.avg'] = metrics['loss'] / (self.count if self.count else 1)
        metrics['error.avg'] = metrics['error'] / (self.eval_exs if self.eval_exs else 1)
        return metrics

    def reset_metrics(self):
        """Reset metrics calculated by the model back to zero."""
        super().reset_metrics()
        self.metrics['loss'] = 0.
        self.metrics['error'] = 0.
        self.count = 0
        self.eval_exs = 0

    def receive_metrics(self, metrics_dict):
        """Forward validation metrics to the parent implementation."""
        super().receive_metrics(metrics_dict)

    @classmethod
    def add_cmdline_args(cls, argparser):
        """Add command-line arguments specifically for this agent."""
        super(MyAgent, cls).add_cmdline_args(argparser)

        agent = argparser.add_argument_group('Arguments')
        agent.add_argument('-esz', '--embeddingsize', type=int, default=100,
                           help='size of the token embeddings')
        agent.add_argument('-dr', '--dropout', type=float, default=0.0,
                           help='dropout rate')
        agent.add_argument('-rf', '--report-freq', type=float, default=0.001,
                           help='Report frequency of prediction during eval.')
        agent.add_argument(
            '--fp16', type='bool', default=True, help='Use fp16 computations.'
        )
        agent.add_argument(
            '--split-lines',
            type='bool',
            default=True,
            help='split the dialogue history on newlines and save in separate '
            'vectors',
        )
        MyAgent.dictionary_class().add_cmdline_args(argparser)
        return agent

In [None]:
%run -i train.py