# Attention Architecture

### Attention model does not currently work with language model training

Uses conv1d - which means it expects 1 input. Did not implement for pretraining  
Also does not use masked-multi-attention for the head. This means the decoder can cheat

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.text import *

In [3]:
path = Path('data/composers/notewise/piano_solo/note_range38/sample_freq12')

In [4]:
bs=512

In [5]:
bptt=250

In [6]:
data = TextLMDataBunch.load(path, bs=bs, bptt=bptt)

In [7]:
t = data.train_ds[0][0]
t.text[:50], t.data

('xxbos wait25 wait25 wait25 wait25 wait25 wait25 wa',
 array([  2,  94,  94,  94, ...,   9,  53,   9, 109]))

In [8]:
vocab = data.train_ds.vocab
vocab_size = len(vocab.itos); vocab_size

110

In [9]:
# data.show_batch()

### Mask

In [10]:
n_ctx=512

In [11]:
pos_emb_mask = torch.zeros(1, 1, vocab_size)
pos_emb_mask[:, :, -n_ctx:] = -1e12

In [12]:
pos_emb_mask

tensor([[[-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
          -1.0000e+12, -1

### Transformer Arch
Paper: https://arxiv.org/abs/1706.03762  
Inspiration: https://github.com/jadore801120/attention-is-all-you-need-pytorch

In [13]:
# import transformer.Constants as Constants
# from dataset import TranslationDataset, paired_collate_fn
# from transformer.Models import Transformer
# from transformer.Optim import ScheduledOptim

In [14]:
import copy
import json
import math
import re
import collections

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    return x * torch.sigmoid(x)


ACT_FNS = {
    'relu': nn.ReLU,
    'swish': swish,
    'gelu': gelu
}


class LayerNorm(nn.Module):
    "Construct a layernorm module in the OpenAI style (epsilon inside the square root)."

    def __init__(self, n_state, e=1e-5):
        super(LayerNorm, self).__init__()
        self.g = nn.Parameter(torch.ones(n_state))
        self.b = nn.Parameter(torch.zeros(n_state))
        self.e = e

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.e)
        return self.g * x + self.b


class Conv1D(nn.Module):
    def __init__(self, nf, rf, nx):
        super(Conv1D, self).__init__()
        self.rf = rf
        self.nf = nf
        if rf == 1:  # faster 1x1 conv
            w = torch.empty(nx, nf)
            nn.init.normal_(w, std=0.02)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(nf))
        else:  # was used to train LM
            raise NotImplementedError

    def forward(self, x):
        if self.rf == 1:
            size_out = x.size()[:-1] + (self.nf,)
            x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
            x = x.view(*size_out)
        else:
            raise NotImplementedError
        return x


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, cfg, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % cfg.n_head == 0
        self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = cfg.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, 1, nx)
        self.c_proj = Conv1D(n_state, 1, nx)
        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
        # XD: self.b may be larger than w, so we need to crop it
        b = self.b[:, :, :w.size(-2), :w.size(-1)]
        w = w * b + -1e9 * (1 - b)

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
        return a


class MLP(nn.Module):
    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = cfg.n_embd
        self.c_fc = Conv1D(n_state, 1, nx)
        self.c_proj = Conv1D(nx, 1, n_state)
        self.act = ACT_FNS[cfg.afn]
        self.dropout = nn.Dropout(cfg.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, cfg, scale=False):
        super(Block, self).__init__()
        nx = cfg.n_embd
        self.attn = Attention(nx, n_ctx, cfg, scale)
        self.ln_1 = LayerNorm(nx)
        self.mlp = MLP(4 * nx, cfg)
        self.ln_2 = LayerNorm(nx)

    def forward(self, x):
        a = self.attn(x)
        n = self.ln_1(x + a)
        m = self.mlp(n)
        h = self.ln_2(n + m)
        return h


class TransformerModel(nn.Module):
    """ Transformer model """

    def __init__(self, cfg, vocab=40990, n_ctx=512):
        super(TransformerModel, self).__init__()
        self.vocab = vocab
        self.embed = nn.Embedding(vocab, cfg.n_embd)
        self.pos_embed = nn.Embedding(vocab, cfg.n_embd)
        self.drop = nn.Dropout(cfg.embd_pdrop)
        block = Block(n_ctx, cfg, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])

        nn.init.normal_(self.embed.weight, std=0.02)
        nn.init.normal_(self.pos_embed.weight, std=0.02)

    def forward(self, x):
        x = x.view(-1, x.size(-2), x.size(-1))
        e = self.embed(x)
        seq_length = bptt
        position_ids = torch.arange(x.shape[-1], dtype=torch.long, device=x.device)
        position_ids = position_ids.unsqueeze(0).expand_as(x)
        
        pos = self.pos_embed(position_ids)
        # Add the position information to the input embeddings
        hid = e + pos
        pdb.set_trace()
        for block in self.h:
            hid = block(hid)
        return hid


class LMHead(nn.Module):
    """ Language Model Head for the transformer """

    def __init__(self, model, cfg, trunc_and_reshape=True):
        super(LMHead, self).__init__()
        self.n_embd = cfg.n_embd
        embed_shape = model.embed.weight.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model.embed.weight # Tied weights
        self.trunc_and_reshape = trunc_and_reshape  # XD

    def forward(self, h):
        # Truncated Language modeling logits (we remove the last token)
        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) \
            if self.trunc_and_reshape else h  # XD
        lm_logits = self.decoder(h_trunc)
        return lm_logits



# XD
class LMModel(nn.Module):
    """ Transformer with language model head only """
    def __init__(self, cfg, vocab=40990, n_ctx=512, return_probs=False):
        super(LMModel, self).__init__()
        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
        self.lm_head = LMHead(self.transformer, cfg, trunc_and_reshape=False)
        self.return_probs = return_probs
        if self.return_probs:
            pos_emb_mask = torch.zeros(1, 1, vocab)
            pos_emb_mask[:, :, -n_ctx:] = -1e12
            self.register_buffer('pos_emb_mask', pos_emb_mask)


    def forward(self, x):
        h = self.transformer(x)
        lm_logits = self.lm_head(h)
        if self.return_probs:
            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
        return lm_logits


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


DEFAULT_CONFIG = dotdict({
#     'n_embd': 256,
    'n_embd': 768,
    'n_head': 12,
    'n_layer': 12,
    'embd_pdrop': 0.1,
    'attn_pdrop': 0.1,
    'resid_pdrop': 0.1,
    'afn': 'gelu',
    'clf_pdrop': 0.1})


In [15]:
lm_transformer = LMModel(DEFAULT_CONFIG, vocab=vocab_size)

In [16]:
lm_transformer.transformer.embed.weight.shape

torch.Size([110, 768])

In [17]:
lm_transformer = lm_transformer.cuda()

In [18]:
lm_transformer.reset = lambda: None

In [19]:
# learn = LanguageLearner(data, model, bptt, split_func=lm_split, **kwargs)

In [20]:
ob = data.one_batch(); ob[0].shape

torch.Size([512, 275])

In [21]:
import pdb

In [23]:
lm_transformer

LMModel(
  (transformer): TransformerModel(
    (embed): Embedding(110, 768)
    (pos_embed): Embedding(110, 768)
    (drop): Dropout(p=0.1)
    (h): ModuleList(
      (0): Block(
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1)
          (resid_dropout): Dropout(p=0.1)
        )
        (ln_1): LayerNorm()
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1)
        )
        (ln_2): LayerNorm()
      )
      (1): Block(
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1)
          (resid_dropout): Dropout(p=0.1)
        )
        (ln_1): LayerNorm()
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1)
        )
        (ln_2): LayerNorm()
      )
      (2): Block(
        (attn): Attention(
          (c_attn): Conv1D()


In [22]:
lm_transformer(ob[0].cuda())

> <ipython-input-14-7dc9d9e5c4e7>(180)forward()
-> for block in self.h:
(Pdb) next
> <ipython-input-14-7dc9d9e5c4e7>(181)forward()
-> hid = block(hid)
(Pdb) next
RuntimeError: CUDA error: device-side assert triggered
> <ipython-input-14-7dc9d9e5c4e7>(181)forward()
-> hid = block(hid)
(Pdb) hid.shape
torch.Size([1, 512, 275, 768])
(Pdb) quit


BdbQuit: 

In [36]:
%debug

> [0;32m<ipython-input-24-9a6223270e33>[0m(61)[0;36mforward[0;34m()[0m
[0;32m     59 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mrf[0m [0;34m==[0m [0;36m1[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m            [0msize_out[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m-[0m[0;36m1[0m[0;34m][0m [0;34m+[0m [0;34m([0m[0mself[0m[0;34m.[0m[0mnf[0m[0;34m,[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 61 [0;31m            [0mx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0maddmm[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mb[0m[0;34m,[0m [0mx[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0mx[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mw[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m            [0mx[0m [0;34m=[0m [0mx

In [21]:
lm_transformer

LMModel(
  (transformer): TransformerModel(
    (embed): Embedding(110, 768)
    (pos_embed): Embedding(110, 768)
    (drop): Dropout(p=0.1)
    (h): ModuleList(
      (0): Block(
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1)
          (resid_dropout): Dropout(p=0.1)
        )
        (ln_1): LayerNorm()
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1)
        )
        (ln_2): LayerNorm()
      )
      (1): Block(
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1)
          (resid_dropout): Dropout(p=0.1)
        )
        (ln_1): LayerNorm()
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1)
        )
        (ln_2): LayerNorm()
      )
      (2): Block(
        (attn): Attention(
          (c_attn): Conv1D()


### Create Language learner

In [22]:
learn = LanguageLearner(data, lm_transformer, bptt)

In [23]:
learn.fit_one_cycle(1)

epoch,train_loss,valid_loss,accuracy


RuntimeError: cublas runtime error : library not initialized at /home/ubuntu/pytorch/aten/src/THC/THCGeneral.cpp:266

In [70]:
%debug

> [0;32m<ipython-input-54-cd8ef95ac38e>[0m(173)[0;36mforward[0;34m()[0m
[0;32m    171 [0;31m        [0mseq_length[0m [0;34m=[0m [0mbptt[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    172 [0;31m        [0mposition_ids[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0marange[0m[0;34m([0m[0mseq_length[0m[0;34m,[0m [0mdtype[0m[0;34m=[0m[0mtorch[0m[0;34m.[0m[0mlong[0m[0;34m,[0m [0mdevice[0m[0;34m=[0m[0mx[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 173 [0;31m        [0mposition_ids[0m [0;34m=[0m [0mposition_ids[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    174 [0;31m[0;34m[0m[0m
[0m[0;32m    175 [0;31m        [0mp[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mpos_embed[0m[0;34m([0m[0mposition_ids[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> x.shape
torch.Size([1, 32, 493])
