## Preparation

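Run the following cell only when working in Google Colaboratory: it mounts Google Drive, copies the data archive and the SimHei font from Drive, and unpacks the archive.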

In [None]:
# Mount Google Drive so that the files stored there are visible under /gdrive.
from google.colab import drive
drive.mount('/gdrive')
# Copy the data archive from Drive into the current working directory.
!cp '/gdrive/My Drive/data.7z' ./
# Copy the SimHei font into the system font directory.
!cp '/gdrive/My Drive/address/simhei.ttf' /usr/share/fonts/
# Unpack the archive, then delete it.
!7z x data.7z
!rm -f data.7z

## data.py

In [None]:
%%file data.py
import torch
from torch.utils.data import Dataset, DataLoader

# Number of synthetic samples generated for each split name.
dataLength = {'train': 65536, 'val': 256, 'test': 256}

class Data(Dataset):
    """Synthetic dataset of zero-padded random vectors labelled with their sum.

    Every sample is a vector of 5 floats. Its first ``n`` entries are drawn
    from ``torch.rand`` and the remaining ones are zero, where ``n`` is a
    random integer in 1..4 (``torch.randint(4, ...) + 1``). The samples are
    generated afresh from the global torch RNG each time the class is
    instantiated.

    Args:
        path: split name; must be a key of ``dataLength``.

    Raises:
        KeyError: if ``path`` is not a known split name.
    """
    def __init__(self, path):
        super(Data, self).__init__()
        l = dataLength[path]
        self.lens = torch.randint(4, (l,)) + 1
        # Build the whole mask with one broadcast comparison instead of one
        # in-place fill per row: position j of row i is valid iff j < lens[i].
        positions = torch.arange(5).unsqueeze(0)
        self.mask = (positions < self.lens.unsqueeze(1)).to(torch.uint8)
        self.data = torch.rand((l, 5)) * self.mask.float()
        self.count = l
    def __len__(self):
        return self.count
    # input, label, length, mask
    def __getitem__(self, ind):
        x = self.data[ind]
        return x, x.sum(), self.lens[ind], self.mask[ind]

def newLoader(path, *args, **kwargs):
    """Create a ``DataLoader`` over a freshly generated ``Data(path)``.

    All further positional and keyword arguments are forwarded to
    ``DataLoader`` unchanged.
    """
    return DataLoader(Data(path), *args, **kwargs)

### imports

In [None]:
%%file train.py
import os
import argparse
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
# Command-line flags; arguments this parser does not know are ignored.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--rank", type=int, default=0)
args = parser.parse_known_args()[0]


def opt():
    # Empty function object used purely as a mutable attribute container
    # for all run-time options.
    pass


opt.cuda = torch.cuda.is_available()
if opt.cuda:
    # GPU run: half precision on the device selected by --local_rank.
    torch.cuda.set_device(args.local_rank)
    opt.device = torch.device('cuda:{}'.format(args.local_rank))
    opt.dtype = torch.half
else:
    # CPU run: single precision, using all cores but one when that leaves
    # more than one thread.
    opt.device = torch.device('cpu')
    opt.dtype = torch.float
    num_threads = torch.multiprocessing.cpu_count() - 1
    if num_threads > 1:
        torch.set_num_threads(num_threads)
print('Using device ' + str(opt.device))
print('Using default dtype ' + str(opt.dtype))

## vocab.py

In [None]:
%%file vocab.py
import os
import torch
from torch.nn.utils.rnn import pad_sequence
vocabPath = './char.txt'

def getBatch(data):
  """Encode a batch of strings as a padded index tensor, lengths and a mask.

  Each string ``s`` becomes the index sequence ``[1] + indices(s) + [0]``;
  characters that are not in the vocabulary map to index 0. The sequences
  are padded with ``pad_sequence`` using its default layout, so column ``i``
  of the result belongs to ``data[i]``.

  Returns:
    A tuple ``(x, l, mask)``: the padded ``long`` tensor, the list of
    ``len(s) + 1`` for every string, and a tensor shaped like ``x`` that is
    1 in the first ``l[i]`` rows of column ``i`` and 0 below.
  """
  encoded = []
  for s in data:
    ids = [1]
    for t in s:
      ids.append(vocabIndex[t] if t in vocabSet else 0)
    ids.append(0)
    encoded.append(torch.tensor(ids, dtype=torch.long))
  x = pad_sequence(encoded)
  lengths = [len(s) + 1 for s in data]
  mask = torch.ones_like(x)
  for col, n in enumerate(lengths):
    mask[n:, col] = 0
  return x, lengths, mask

def initial(path):
  """Load the vocabulary file and rebuild the module-level lookup tables.

  The file is read as UTF-8 and split on NUL characters. Two empty-string
  entries are placed in front of the file's entries, so the first entry of
  the file gets index 2.

  Rebinds the globals ``vocab`` (list), ``vocabSet`` (set of entries) and
  ``vocabIndex`` (entry -> index; for repeated entries the last index wins).

  Returns:
    The new ``vocab`` list.
  """
  global vocab, vocabSet, vocabIndex
  with open(path, 'r', encoding='utf-8') as f:
    entries = f.read().split('\0')
  vocab = ['', ''] + entries
  vocabSet = set(vocab)
  vocabIndex = {w: i for i, w in enumerate(vocab)}
  return vocab

# Start with an empty vocabulary; it is loaded at import time only when the
# default vocabulary file already exists. Otherwise callers must run
# initial(path) themselves before using getBatch.
vocab = []
if os.path.exists(vocabPath):
  initial(vocabPath)

In [None]:
%%file -a train.py
# File locations used by the training script.
word2vecPath = './vectors.pth'  # tensor that initParameters copies into model.embedding
stateDictPath = './net.init.pth'  # not referenced anywhere else in this notebook
fontPath = '/usr/share/fonts/simhei.ttf'  # font copied there by the Colab setup cell
vocabPath = './char.txt'  # NUL-separated vocabulary file read by vocab.initial
# Alternative font locations for other systems:
#fontPath = '/usr/share/fonts/wqy-microhei/wqy-microhei.ttc'
#fontPath = 'C:/Windows/Fonts/simhei.ttf'
from vocab import vocab, initial
# Load the vocabulary; raises if the vocabulary file is missing.
initial(vocabPath)

## model.py

In [None]:
%%file model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# Constant scalar returned as the second element of Model.forward's result.
Zero = torch.tensor(0.)
# Sequence length the layers are sized for (the input width of f1 and norm
# is edim * maxLen).
maxLen = 5

class Model(nn.Module):
    """Small masked feed-forward regressor over fixed-length sequences.

    Pipeline of ``forward``: every scalar entry is mapped to ``edim``
    features (``f0`` + LeakyReLU), masked, flattened, batch-normalised
    without affine parameters, masked again, reduced to one value by ``f1``
    and squashed with ``tanh``.

    Args:
        opt: option container providing ``device``, ``dtype``, ``edim`` and
            ``dropout``.

    NOTE(review): the output is bounded to (-1, 1) by ``tanh``; confirm that
    this matches the label range of the dataset in use.
    """
    def __init__(self, opt):
        super(Model, self).__init__()
        self.device = opt.device
        self.dtype = opt.dtype
        self.edim = opt.edim
        self.dropout = nn.Dropout(opt.dropout)
        self.f0 = nn.Linear(1, opt.edim, bias=True)
        self.act0 = nn.LeakyReLU(.1)
        self.norm = nn.BatchNorm1d(opt.edim * maxLen, affine=False)
        self.f1 = nn.Linear(opt.edim * maxLen, 1, bias=True)
        self.act1 = torch.tanh
        # Must come last: Module.to only converts parameters and buffers that
        # are already registered, so calling it before the layers exist
        # leaves them on the default device and dtype.
        self.to(dtype=opt.dtype, device=opt.device)

    def forward(self, x, mask, *_):
        """Compute the prediction for a batch.

        Args:
            x: tensor of shape ``(batch, maxLen)``.
            mask: tensor of the same shape, non-zero at valid positions.
            *_: ignored.

        Returns:
            ``(out, Zero, x1)`` where ``out`` has shape ``(batch,)`` and
            ``x1`` holds the masked per-position features of shape
            ``(batch, maxLen, edim)``.
        """
        bsz, l = x.shape
        # Follow the parameters' device/dtype so that batches coming straight
        # from a DataLoader (CPU, float32, uint8 mask) are accepted.
        ref = self.f0.weight
        x = x.to(device=ref.device, dtype=ref.dtype)
        mask = mask.to(device=ref.device, dtype=ref.dtype).view(bsz, l, 1)
        e = self.dropout(x).view(bsz, l, 1)
        x1 = self.act0(self.f0(e)) * mask
        # Normalisation shifts the masked zeros, so the mask is applied again.
        x2 = self.norm(x1.view(bsz, -1)).view(bsz, l, -1) * mask
        return self.act1(self.f1(x2.view(bsz, -1)).squeeze(-1)), Zero, x1

def predict(x):
    """Turn the raw model output into a prediction (identity for this model)."""
    return x

## train.py

In [None]:
%%file -a train.py
import torch.optim as optim
import numpy as np
from data import newLoader
from model import Model, predict
from option import option
# Seed the torch and NumPy RNGs with the process rank.
torch.manual_seed(args.rank)
np.random.seed(args.rank)
# Total number of elements over all parameters of a model.
getNelement = lambda model: sum(map(lambda p: p.nelement(), model.parameters()))
# Accumulator steps (acc, cur) -> acc + penalty(cur), accumulated in float32;
# their shape fits a fold such as functools.reduce over parameters.
l1Reg = lambda acc, cur: acc + cur.abs().sum(dtype=torch.float)
l2Reg = lambda acc, cur: acc + (cur * cur).sum(dtype=torch.float)
# NaN scalar on the active device; trainStep compares the loss against it.
nan = torch.tensor(float('nan'), device=opt.device)

# Default options. Every one of them can be overridden by option.py through
# the update() call at the end of this section.
opt.batchsize = 1
opt.epochs = 1
opt.maxgrad = 1. # max gradient (threshold passed to clip_grad_value_)
opt.dropout = 0
opt.sdt = 0.001 # initial learning rate
opt.sdt_decay_step = 10 # how often to reduce learning rate
opt.criterion = lambda y, out, mask: F.mse_loss(out, y) # criterion for evaluation
opt.loss = lambda opt, model, y, out, *args: F.mse_loss(out, y) # criterion for loss function
# Factory for the optimizer; eps is chosen by initTrain depending on the dtype.
opt.newOptimizer = lambda opt, params, eps: optim.Adam(params, lr=opt.sdt, amsgrad=True, eps=eps)
opt.writer = 0 # TensorBoard writer (falsy value disables logging in train)
opt.drawVars = 0 # optional callback called at the end of evaluate (falsy disables)
opt.reset_parameters = 0 # optional callback called by initParameters (falsy disables)
# Apply the overrides from option.py.
opt.__dict__.update(option)

def initParameters(opt, model):
    """Re-initialise selected parameters of ``model`` in place.

    For every sub-module (including ``model`` itself): a tensor-valued
    ``bias`` is set to 0, the first parameter of an ``nn.PReLU`` is set to
    1, and, when ``opt.reset_parameters`` is truthy, that callback is called
    without arguments (once per sub-module). Finally, if ``model.embedding``
    is an ``nn.Embedding``, its rows from index 2 onwards are overwritten
    with the tensor stored at ``word2vecPath``.
    """
    for module in model.modules():
        bias = getattr(module, 'bias', None)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias, 0)
        if isinstance(module, nn.PReLU):
            slope = next(module.parameters())
            nn.init.constant_(slope, 1)
        if opt.reset_parameters:
            opt.reset_parameters()
    embedding = getattr(model, 'embedding', None)
    if isinstance(embedding, nn.Embedding):
        embedding.weight.data[2:] = torch.load(word2vecPath)

def trainStep(opt, model, x, y, length, mask):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        opt: option container providing ``optimizer``, ``device``, ``loss``
            and ``maxgrad``.
        model: callable as ``model(x, mask)`` returning a tuple whose items
            are forwarded to ``opt.loss`` after the label.
        x, y, mask: batch tensors; they are moved to ``opt.device`` and the
            label is additionally cast to float32.
        length: unused; kept for the call signature.

    Returns:
        The loss of this batch as a Python float.

    Raises:
        Exception: if the loss is NaN (raised before any parameter update).
    """
    opt.optimizer.zero_grad()
    x = x.to(opt.device, non_blocking=True)
    mask = mask.to(opt.device, non_blocking=True)
    label = y.to(opt.device, dtype=torch.float, non_blocking=True)
    loss = opt.loss(opt, model, label, *model(x, mask))
    # torch.isnan works for any loss dtype and device, whereas comparing
    # against a separate NaN tensor requires both to match.
    if bool(torch.isnan(loss).any()):
        raise Exception('Loss returns NaN')
    loss.backward()
    # Clamp every gradient element to [-maxgrad, maxgrad].
    nn.utils.clip_grad_value_(model.parameters(), opt.maxgrad)
    opt.optimizer.step()
    return float(loss)

def evaluateStep(opt, model, x, y, _, mask):
    """Evaluate one batch without tracking gradients.

    The batch is moved to ``opt.device`` first, mirroring ``trainStep``.

    Args:
        opt: option container providing ``device`` and ``criterion``.
        model: callable as ``model(x, mask)`` returning a tuple whose first
            item is the raw output.
        x, y, mask: batch tensors.
        _: ignored (length slot of the batch).

    Returns:
        ``(error, missed, pred, *others)``: the summed criterion value as a
        float, the criterion tensor itself, the prediction, and the
        remaining items of the model's result.
    """
    x = x.to(opt.device, non_blocking=True)
    mask = mask.to(opt.device, non_blocking=True)
    y = y.to(opt.device, non_blocking=True)
    # No backward pass happens during evaluation, so skip building the
    # autograd graph.
    with torch.no_grad():
        out, *others = model(x, mask)
        pred = predict(out)
        missed = opt.criterion(y, pred, mask)
    return (float(missed.sum()), missed, pred, *others)

def evaluate(opt, model):
    """Evaluate ``model`` on a freshly created 'val' loader.

    Puts the model into eval mode, sums the per-batch error reported by
    ``evaluateStep`` and divides it by the sum of all sequence lengths.
    When ``opt.drawVars`` is truthy it is called once after the loop with
    the first sample of the last batch (input, length and the first entry
    of each extra model output), and that sample's prediction is printed.

    Returns:
        Summed error divided by the summed lengths.
    """
    model.eval()
    error_sum = 0
    n_items = 0
    loader = newLoader('val', batch_size=opt.batchsize)
    for x, y, l, mask in loader:
        n_items += int(l.sum())
        step_result = evaluateStep(opt, model, x, y, l, mask)
        err, _missed, pred, _second, *others = step_result
        error_sum += err
    if opt.drawVars:
        first_extras = [v[0] for v in others]
        opt.drawVars(x[0], l[0], *first_extras)
        print(pred[0])
    return error_sum / n_items

def initTrain(opt, model, epoch=None):
    """Create the optimizer and LR scheduler and store them on ``opt``.

    The optimizer comes from ``opt.newOptimizer`` with ``eps`` 1e-4 for
    float16 and 1e-8 otherwise. The scheduler is a ``StepLR`` that halves
    the rate every ``opt.sdt_decay_step`` epochs, or one with step 1e6 and
    gamma 1 when that option is not positive.

    Args:
        epoch: when this is exactly an ``int`` (a ``bool`` does not count),
            optimizer and scheduler state are restored from
            ``train.epoch{epoch}.pth``; otherwise the torch and NumPy RNGs
            are re-seeded with ``args.rank``.
    """
    if opt.dtype == torch.float16:
        eps = 1e-4
    else:
        eps = 1e-8
    opt.optimizer = opt.newOptimizer(opt, model.parameters(), eps)

    decaying = opt.sdt_decay_step > 0
    step_size = opt.sdt_decay_step if decaying else 1e6
    gamma = 0.5 if decaying else 1
    opt.scheduler = optim.lr_scheduler.StepLR(opt.optimizer, step_size, gamma=gamma)

    if type(epoch) != int:
        # Fresh run: make what follows reproducible for this rank.
        torch.manual_seed(args.rank)
        np.random.seed(args.rank)
        return
    checkpoint = torch.load('train.epoch{}.pth'.format(epoch), map_location='cpu')
    opt.optimizer.load_state_dict(checkpoint[0])
    opt.scheduler.load_state_dict(checkpoint[1])

def train(opt, model, init=True):
    """Train ``model`` until the scheduler's epoch counter reaches ``opt.epochs``.

    Args:
        opt: option container (device, dtype, epochs, batchsize, writer and
            the fields used by the helpers called here).
        model: the network to train.
        init: truthy to (re)initialise parameters, optimizer and scheduler.
            When it is exactly an ``int`` it is taken as an epoch number and
            the weights are restored from ``model.epoch{init}.pth`` (the
            optimizer/scheduler state is restored by ``initTrain``). Falsy
            to continue with the optimizer and scheduler already on ``opt``.

    Returns:
        The validation error of the last epoch, or ``None`` when no epoch
        was run.
    """
    if init:
        initParameters(opt, model)
        if type(init) == int:
            # Resume from the checkpoint written for epoch `init`.
            model.load_state_dict(torch.load('model.epoch{}.pth'.format(init), map_location='cpu'))
        # The model has to be on its final device/dtype before the optimizer
        # is constructed, for fresh and resumed runs alike.
        model.to(device=opt.device, dtype=opt.dtype)
        initTrain(opt, model, init)
    else:
        model.to(device=opt.device, dtype=opt.dtype)
    valErr = None
    for i in range(opt.scheduler.last_epoch, opt.epochs):
        # NOTE(review): stepping the scheduler before the optimizer follows
        # the pre-1.1 PyTorch convention; newer releases warn about it.
        # Left unchanged to keep the learning-rate schedule as it was.
        opt.scheduler.step()
        count = 0
        totalLoss = 0
        model.train()
        model.zero_grad()
        for x, y, l, mask in newLoader('train', batch_size=opt.batchsize, shuffle=True):
            length = int(l.sum())
            count += length
            loss = trainStep(opt, model, x, y, length, mask)
            totalLoss += loss
        valErr = evaluate(opt, model)
        if opt.writer:
            logBoardStep(opt, model)
        # get_last_lr is the supported accessor outside scheduler.step();
        # fall back to get_lr on releases that do not provide it.
        lr = getattr(opt.scheduler, 'get_last_lr', opt.scheduler.get_lr)()[0]
        print('Epoch #%i | train loss: %.4f | valid error: %.3f | learning rate: %.5f' %
          (opt.scheduler.last_epoch, totalLoss / count, valErr, lr))
        # Checkpoint every tenth pass of this loop.
        if i % 10 == 9:
            saveState(opt, model, opt.scheduler.last_epoch)
    return valErr

def saveState(opt, model, epoch):
    """Write a checkpoint for ``epoch`` into the current directory.

    Two files are produced: ``model.epoch{epoch}.pth`` with the model's
    state dict and ``train.epoch{epoch}.pth`` with the tuple
    ``(optimizer state dict, scheduler state dict)``.
    """
    model_file = 'model.epoch{}.pth'.format(epoch)
    train_file = 'train.epoch{}.pth'.format(epoch)
    torch.save(model.state_dict(), model_file)
    training_state = (opt.optimizer.state_dict(), opt.scheduler.state_dict())
    torch.save(training_state, train_file)

def logBoardStep(opt, model):
    """Log a histogram of every named parameter to ``opt.writer``.

    The scheduler's ``last_epoch`` is used as the step. Logging is best
    effort: if the writer raises for a parameter, that parameter is printed
    and the loop continues with the next one.
    """
    step = opt.scheduler.last_epoch
    for name, param in model.named_parameters():
        try:
            opt.writer.add_histogram(name, param.data, step)
        except Exception:
            # Not a bare `except`, so KeyboardInterrupt and SystemExit still
            # stop the run.
            print(name, param)

# Re-seed both RNGs with the rank right before the model is built.
torch.manual_seed(args.rank)
np.random.seed(args.rank)
model = Model(opt)
# Report the parameter count and the validation error of the untrained model.
print('Number of parameters: %i | valid error: %.3f' % (getNelement(model), evaluate(opt, model)))
train(opt, model)
# Save the final weights, tagged with the scheduler's epoch counter.
torch.save(model.state_dict(), 'model.epoch{}.pth'.format(opt.scheduler.last_epoch))

In [None]:
%%file option.py
# Overrides merged into train.py's `opt` defaults via opt.__dict__.update(option).
option = dict(edim=16, epochs=10, maxgrad=1., sdt=1e-2, sdt_decay_step=3, batchsize=32)

## visualization

In [None]:
%%file -a option.py
# NOTE: this cell is appended to option.py, which train.py imports as a plain
# Python module. IPython magics such as `%matplotlib inline` are a syntax
# error there; enable inline plotting from a regular notebook cell instead.
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib as mpl
# option.py cannot see train.py's globals, so the font location is defined
# here as well (same path as train.py's fontPath).
fontPath = '/usr/share/fonts/simhei.ttf'
zhfont = mpl.font_manager.FontProperties(fname=fontPath)
# Number of heat-maps per row in drawAttention's figure.
columns = 2

def drawAttention(indices, l, _=None, att=None, *args):
  """Plot one heat-map per attention head for the first ``l`` positions.

  Args:
    indices: 1-D tensor of vocabulary indices of one sample.
    l: number of valid positions (anything ``int()`` accepts).
    _: ignored.
    att: tensor indexed as ``[head, row, column]``. When it is missing or
      not 3-dimensional nothing is drawn.
    *args: ignored.

  Returns:
    ``None`` when nothing is drawn, otherwise the result of ``plt.show()``.
  """
  # `_` and `att` are optional so that a model without attention outputs can
  # still be evaluated with this callback registered.
  if att is None or len(att.shape) != 3:
    return
  # Imported here because option.py is a standalone module that imports
  # neither torch nor the vocabulary at its top.
  import torch
  from vocab import vocab
  heads = att.size(0)
  l = int(l)
  rows = (heads + columns - 1) // columns
  indices = indices[:l].tolist()
  # Leading '' label: presumably MultipleLocator(1) puts its first tick one
  # step before the first cell - TODO confirm.
  labels = [''] + [vocab[i] for i in indices]
  fig = plt.figure(figsize=(16, rows * 16 // columns))
  for t in range(heads):
    ax = fig.add_subplot(rows, columns, t + 1)
    # l rows by l + 1 columns; the extra column is labelled 'NA' below.
    data = att[t, :l, :l+1].detach().to(torch.float).cpu().numpy()
    ax.matshow(data, interpolation='nearest', cmap='hot', vmin=0, vmax=1)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.set_xticklabels(labels + ['NA'], fontproperties=zhfont)
    ax.set_yticklabels(labels, fontproperties=zhfont)
  return plt.show()

# Register the plot as the drawVars callback that train.py's evaluate() calls.
option['drawVars'] = drawAttention

In [None]:
%run -i train.py