In [3]:
import argparse
import math
import sys
from collections import OrderedDict

import torch
from torch.utils.data import DataLoader, IterableDataset

from fast_transformers.masking import LengthMask, TriangularCausalMask
from fast_transformers.builders import TransformerEncoderBuilder


In [4]:
class EpochStats(object):
    def __init__(self, metric_names=[], freq=1, out=sys.stdout):
        self._start = time.time()
        self._samples = 0
        self._loss = 0
        self._metrics = [0]*len(metric_names)
        self._metric_names = metric_names
        self._out = out
        self._freq = freq
        self._max_line = 0

    def update(self, n_samples, loss, metrics=[]):
        self._samples += n_samples
        self._loss += loss*n_samples
        for i, m in enumerate(metrics):
            self._metrics[i] += m*n_samples

    def _get_progress_text(self):
        time_per_sample = (time.time()-self._start) / self._samples
        loss = self._loss / self._samples
        metrics = [
            m/self._samples
            for m in self._metrics
        ]
        text = "Loss: {} ".format(loss)
        text += " ".join(
            "{}: {}".format(mn, m)
            for mn, m in zip(self._metric_names, metrics)
        )
        if self._out.isatty():
            to_add = " [{} sec/sample]".format(time_per_sample)
            if len(text) + len(to_add) > self._max_line:
                self._max_line = len(text) + len(to_add)
            text += " " * (self._max_line-len(text)-len(to_add)) + to_add
        else:
            text += " time: {}".format(time_per_sample)
        return text

    def progress(self):
        if self._samples < self._freq:
            return
        text = self._get_progress_text()
        if self._out.isatty():
            print("\r" + text, end="", file=self._out)
        else:
            print(text, file=self._out, flush=True)
        self._loss = 0
        self._samples = 0
        self._last_progress = 0
        for i in range(len(self._metrics)):
            self._metrics[i] = 0
        self._start = time.time()

    def finalize(self):
        self._freq = 1
        self.progress()
        if self._out.isatty():
            print("", file=self._out)


In [5]:
def load_model(saved_file, model, optimizer, device):
    data = torch.load(saved_file, map_location=device)
    model.load_state_dict(data["model_state"])
    optimizer.load_state_dict(data["optimizer_state"])
    epoch = data["epoch"]

    return epoch

In [6]:
def save_model(save_file, model, optimizer, epoch):
    torch.save(
        dict(
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict(),
            epoch=epoch
        ),
        save_file.format(epoch)
    )

In [7]:
def print_transformer_arguments(args):
    print((
        "Transformer Config:\n"
        "    Attention type: {attention_type}\n"
        "    Number of layers: {n_layers}\n"
        "    Number of heads: {n_heads}\n"
        "    Key/Query/Value dimension: {d_query}\n"
        "    Transformer layer dropout: {dropout}\n"
        "    Softmax temperature: {softmax_temp}\n"
        "    Attention dropout: {attention_dropout}\n"
        "    Number of hashing planes: {bits}\n"
        "    Chunk Size: {chunk_size}\n"
        "    Rounds: {rounds}\n"
        "    Masked: {masked}"
    ).format(**vars(args)))

In [8]:
def add_transformer_arguments(parser):
    parser.add_argument(
        "--attention_type",
        type=str,
        choices=["full", "causal-linear", "reformer"],
        default="causal-linear",
        help="Attention model to be used"
    )
    parser.add_argument(
        "--n_layers",
        type=int,
        default=4,
        help="Number of self-attention layers"
    )
    parser.add_argument(
        "--n_heads",
        type=int,
        default=8,
        help="Number of attention heads"
    )
    parser.add_argument(
        "--d_query",
        type=int,
        default=32,
        help="Dimension of the query, key, and value embedding"
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.1,
        help="Dropout to be used for transformer layers"
    )
    parser.add_argument(
        "--softmax_temp",
        type=float,
        default=None,
        help=("Softmax temperature to be used for training "
              "(default: 1/sqrt(d_query))")
    )
    parser.add_argument(
        "--attention_dropout",
        type=float,
        default=0.1,
        help="Dropout to be used for attention layers"
    )
    parser.add_argument(
        "--bits",
        type=int,
        default=32,
        help="Number of planes to use for hashing for reformer"
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=32,
        help="Number of queries in each block for reformer"
    )
    parser.add_argument(
        "--rounds",
        type=int,
        default=4,
        help="Number of rounds of hashing for reformer"
    )
    parser.add_argument(
        "--unmasked_reformer",
        action="store_false",
        dest="masked",
        help="If set the query can attend to itsself for reformer"
    )

    return parser

In [9]:
def get_optimizer(params, args):
    if args.optimizer == "adam":
        return torch.optim.Adam(params, lr=args.lr)
    elif args.optimizer == "radam":
        return RAdam(params, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError("Optimizer {} not available".format(args.optimizer))

In [10]:
def add_optimizer_arguments(parser):
    parser.add_argument(
        "--optimizer",
        choices=["radam", "adam"],
        default="radam",
        help="Choose the optimizer"
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-3,
        help="Set the learning rate"
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Set the weight decay"
    )

In [11]:
class CopyTask(IterableDataset):
    def __init__(self, max_sequence, n_classes):
        self._max_sequence = max_sequence
        self._n_classes = n_classes
        self._i = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Make some local copies
        max_seq = self._max_sequence
        n_classes = self._n_classes

        # Generate the random sequence
        n = torch.randint(max_seq//4, (max_seq-1)//2, tuple())
        random_sequence = (torch.rand(n)*n_classes).long() + 1

        # Generate the input, target and loss mask
        x = torch.zeros(max_seq, dtype=torch.long)
        y = torch.zeros(max_seq, dtype=torch.long)
        mask = torch.zeros(max_seq)
        x[:n] = random_sequence
        x[n+1:2*n+1] = random_sequence
        y[:-1] = x[1:]
        mask[n-1:2*n] = 1

        return x, y, mask


In [12]:
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout=0.0, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        pos_embedding =  self.pe[:, :x.size(1), :]
        pos_embedding = torch.repeat_interleave(pos_embedding, x.shape[0], dim=0)
        x =  torch.cat([x, pos_embedding], dim=2)
        return self.dropout(x)

In [13]:
class SequencePredictor(torch.nn.Module):
    def __init__(self, d_model, sequence_length, n_classes,
                 attention_type="full", n_layers=4, n_heads=4,
                 d_query=32, dropout=0.1, softmax_temp=None,
                 attention_dropout=0.1,
                 bits=32, rounds=4,
                 chunk_size=32, masked=True):
        super(SequencePredictor, self).__init__()

        self.pos_embedding = PositionalEncoding(
            d_model//2,
            max_len=sequence_length
        )
        self.value_embedding = torch.nn.Embedding(
            n_classes+1,
            d_model//2
        )
        self.builder_dict = OrderedDict({
            "attention_type": attention_type,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "feed_forward_dimensions": n_heads*d_query*4,
            "query_dimensions": d_query,
            "value_dimensions": d_query,
            "dropout": dropout,
            "softmax_temp": softmax_temp,
            "attention_dropout": attention_dropout,
            "bits": bits,
            "rounds": rounds,
            "chunk_size": chunk_size,
            "masked": masked
        })

        self.transformer = TransformerEncoderBuilder.from_dictionary(
            self.builder_dict,
            strict=True
        ).get()

        hidden_size = n_heads*d_query
        self.predictor = torch.nn.Linear(
            hidden_size,
            n_classes+1
        )

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.value_embedding(x).transpose(1, 0)
        x = self.pos_embedding(x)
        triangular_mask = TriangularCausalMask(x.shape[1], device=x.device) 
        y_hat = self.transformer(x, attn_mask=triangular_mask)
        y_hat = self.predictor(y_hat)

        return y_hat

In [14]:
def loss(y, y_hat, loss_mask):
    y_hat = y_hat.transpose(1, 0).contiguous()
    L, N, C = y_hat.shape
    l = torch.nn.functional.cross_entropy(
        y_hat.view(L*N, C),
        y.contiguous().view(L*N),
        reduction="none"
    ).view(L, N)
    # this means longer sequences have higher weight but it sounds ok
    l = (loss_mask * l).mean() / loss_mask.mean()
    accuracy = ((y == y_hat.argmax(dim=-1)).float() * loss_mask).mean() / loss_mask.mean()

    return l, accuracy.item()

In [15]:
def train(model, optimizer, dataloader, device):
    model.train()
    stats = EpochStats(["accuracy"])
    for i, (x, y, m) in zip(range(100), dataloader):
        x = x.to(device).t()
        y = y.to(device).t()
        m = m.to(device).t()
        optimizer.zero_grad()
        y_hat = model(x)
        l, acc = loss(y, y_hat, m)
        l.backward()
        optimizer.step()
        stats.update(x.shape[1], l.item(), [acc])
        stats.progress()
    stats.finalize()

In [16]:
def evaluate(model, dataloader, device):
    model.eval()
    total_loss = 0
    total_acc = 0
    total_samples = 0
    with torch.no_grad():
        for i, (x, y, m) in zip(range(20), dataloader):
            x = x.to(device).t()
            y = y.to(device).t()
            m = m.to(device).t()
            y_hat = model(x)
            l, acc = loss(y, y_hat, m)
            total_loss += x.shape[1] * l.item()
            total_acc += x.shape[1] * acc
            total_samples += x.shape[1]
    print(
        "Testing =>",
        "Loss:",
        total_loss/total_samples,
        "Accuracy:",
        total_acc/total_samples
    )

    return total_loss/total_samples


In [18]:
def main(argv=None):
    parser = argparse.ArgumentParser(
        description="Train a transformer for a copy task"
    )

    add_optimizer_arguments(parser)
    add_transformer_arguments(parser)

    parser.add_argument(
        "--sequence_length",
        type=int,
        default=128,
        help="Set the maximum sequence length"
    )
    parser.add_argument(
        "--n_classes",
        type=int,
        default=10,
        help="Set the number of classes"
    )

    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="How many epochs to train for"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="How many samples to use together"
    )
    parser.add_argument(
        "--reduce_lr_at",
        type=int,
        default=30,
        help="At this epoch divide the lr by 10"
    )

    parser.add_argument(
        "--save_to",
        default=None,
        help="Set a file to save the models to."
    )
    parser.add_argument(
        "--continue_from",
        default=None,
        help="Load the model from a file"
    )
    parser.add_argument(
        "--save_frequency",
        default=1,
        type=int,
        help="Save every that many epochs"
    )

    args = parser.parse_args(argv)
    print_transformer_arguments(args)

    # Make the dataset and the model
    train_set = CopyTask(args.sequence_length, args.n_classes)
    test_set = CopyTask(args.sequence_length, args.n_classes)
    model = SequencePredictor(
        args.d_query*args.n_heads, args.sequence_length, args.n_classes,
        attention_type=args.attention_type,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        d_query=args.d_query,
        dropout=args.dropout,
        softmax_temp=None,
        attention_dropout=args.attention_dropout,
        bits=args.bits,
        rounds=args.rounds,
        chunk_size=args.chunk_size,
        masked=args.masked
    )

    # Choose a device and move everything there
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Running on {}".format(device))
    model.to(device)
    # Start training
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        pin_memory=device=="cuda"
    )
    test_loader = DataLoader(
        test_set,
        batch_size=args.batch_size,
        pin_memory=device=="cuda"
    )
    optimizer = get_optimizer(model.parameters(), args)
    start_epoch = 1
    if args.continue_from:
        start_epoch = load_model(
            args.continue_from,
            model,
            optimizer,
            device
        )
    lr_schedule = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda e: 1. if e < args.reduce_lr_at else 0.1
    )
    for e in range(start_epoch, args.epochs+1):
        train(model, optimizer, train_loader, device)
        evaluate(model, test_loader, device)
        if (e % args.save_frequency) == 0 and args.save_to:
            save_model(args.save_to, model, optimizer, e)
        lr_schedule.step()


if __name__ == "__main__":
    main()

usage: ipykernel_launcher.py [-h] [--optimizer {radam,adam}] [--lr LR]
                             [--weight_decay WEIGHT_DECAY]
                             [--attention_type {full,causal-linear,reformer}]
                             [--n_layers N_LAYERS] [--n_heads N_HEADS]
                             [--d_query D_QUERY] [--dropout DROPOUT]
                             [--softmax_temp SOFTMAX_TEMP]
                             [--attention_dropout ATTENTION_DROPOUT]
                             [--bits BITS] [--chunk_size CHUNK_SIZE]
                             [--rounds ROUNDS] [--unmasked_reformer]
                             [--sequence_length SEQUENCE_LENGTH]
                             [--n_classes N_CLASSES] [--epochs EPOCHS]
                             [--batch_size BATCH_SIZE]
                             [--reduce_lr_at REDUCE_LR_AT] [--save_to SAVE_TO]
                             [--continue_from CONTINUE_FROM]
                             [--save_frequency SAVE_FRE

SystemExit: 2