In [1]:
!pip install datasets &> /dev/null
!pip install transformers &> /dev/null

In [2]:
import gc
import math
import numpy as np
import random

import datasets
import tokenizers 
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers


# Dataset & Tokenizer

In [None]:
raw_dataset = datasets.load_dataset('imdb', split=['train', 'test[:5000]'])

Downloading:   0%|          | 0.00/1.92k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

Downloading and preparing dataset imdb/plain_text (download: 80.23 MiB, generated: 127.02 MiB, post-processed: Unknown size, total: 207.25 MiB) to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a...


Downloading:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
print(tokenizer.vocab_size)

28996


In [None]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

train_dataset = raw_dataset[0].map(tokenize_function, batched=True)
test_dataset = raw_dataset[0].map(tokenize_function, batched=True)

  0%|          | 0/25 [00:00<?, ?ba/s]

  0%|          | 0/25 [00:00<?, ?ba/s]

# Models

In [3]:
class MHA(nn.Module):
    """Heart of https://arxiv.org/pdf/2006.16236.pdf
    """

    def __init__(self, d_model, n_heads, use_cos, kernel, 
                 dropout, denom_eps, bias):
        super(MHA, self).__init__()
        assert d_model % n_heads == 0, 'd_model must be a multiple of n_heads'
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.denom_eps = denom_eps

        if kernel == 'relu':
            self.kernel = self.relu_kernel
        elif kernel == 'elu':
            self.kernel = self.elu_kernel
        else:
            raise NotImplementedError(
                "The only options for 'kernel' are 'relu and 'elu'.")
            
        if use_cos:
            self.attention_func = self.cos_linear_attention
        else:
            self.attention_func = self.linear_attention

        self.w_qkv = nn.Linear(d_model, 3 * d_model, bias=bias)
        self.w_o = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def apply_mask(self, x, mask):
        # x -> [batch_size, seq_len, d_model]
        # mask -> [batch_size, seq_len, 1] or None
        if not mask is None:
            #x.masked_fill_(~mask, 0)
            x = x.masked_fill(mask, 0)
        return x

    def split_heads(self, x):
        batch_size, seq_len = x.shape[:2]
        # x -> [batch_size, seq_len, d_model]
        x = x.view(batch_size, seq_len, self.n_heads, self.d_head)
        # x -> [batch_size, seq_len, n_heads, d_head]
        return x

    def join_heads(self, x):
        batch_size, seq_len = x.shape[:2]
        # x -> [batch_size, seq_len, n_heads, d_head]
        x = x.view(batch_size, seq_len, self.d_model).contiguous()
        # x -> [batch_size, seq_len, d_model]
        return x

    def elu_kernel(self, x):
        return F.elu(x) + 1

    def relu_kernel(self, x):
        return F.relu(x)

    def linear_attention(self, q, k, v, weights=None):
        # stolen from 
        # https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/kernel_attention.py
        # q, k, v -> [batch_size, seq_len, n_heads, d_head]
        kv = torch.einsum('bsnx,bsnz->bnxz', k, v)
        # kv -> [batch_size, n_heads, d_head, d_head]
        # add dropout here
        denominator = 1.0 / (torch.einsum('bsnd,bnd->bsn', q, k.sum(axis=1)) + self.denom_eps)
        # denominator -> [batch_size, seq_len, n_heads]

        output = torch.einsum('bsnx,bnxz,bsn->bsnz', q, kv, denominator).contiguous()
        # output -> [batch_size, seq_len, n_heads, d_head]

        return output

    def cos_linear_attention(self, q, k, v, weights):
        # q, k, v -> [batch_size, seq_len, n_heads, d_head]
        cos, sin = weights
        # cos, sin -> [batch_size, seq_len]
        q_cos = torch.einsum('bsnd,bs->bsnd', q, cos)
        q_sin = torch.einsum('bsnd,bs->bsnd', q, sin)
        k_cos = torch.einsum('bsnd,bs->bsnd', k, cos)
        k_sin = torch.einsum('bsnd,bs->bsnd', k, sin)
        # q_cos, q_sin, k_cos, k_sin -> [batch_size, seq_len, n_heads, d_head]

        kv_cos = torch.einsum('bsnx,bsnz->bnxz', k_cos, v)
        # kv_cos -> [batch_size, n_heads, d_head, d_head]
        qkv_cos = torch.einsum('bsnx,bnxz->bsnz', q_cos, kv_cos)
        # qkv_cos -> [batch_size, seq_len, n_heads, d_head]

        kv_sin = torch.einsum('bsnx,bsnz->bnxz', k_sin, v)
        # kv_sin -> [batch_size, n_heads, d_head, d_head]
        qkv_sin = torch.einsum('bsnx,bnxz->bsnz', q_sin, kv_sin)
        # qkv_sin -> [batch_size, seq_len, n_heads, d_head]

        # denominator
        denominator = 1.0 / (torch.einsum('bsnd,bnd->bsn', q_cos, k_cos.sum(axis=1)) 
            + torch.einsum('bsnd,bnd->bsn', q_sin, k_sin.sum(axis=1))
            + self.denom_eps)
        # denominator -> [batch_size, seq_len, n_heads]
        
        output = torch.einsum('bsnz,bsn->bsnz', qkv_cos + qkv_sin, denominator).contiguous()
        # output -> [batch_size, seq_len, n_heads, d_head]
        return output        

    def forward(self, x, mask, weights):
        # x -> [batch_size, seq_len, d_model]
        # mask -> [batch_size, seq_len, 1] or None
        q, k, v = torch.chunk(self.w_qkv(x), 3, -1) 
        # q, k, v -> [batch_size, seq_len, d_model]

        q = self.split_heads(self.kernel(q))
        k = self.split_heads(self.apply_mask(self.kernel(k), mask))
        v = self.split_heads(v)
        # q, k, v -> [batch_size, seq_len, n_heads, d_head]

        x = self.attention_func(q, k, v, weights)
        # x -> [batch_size, seq_len, n_heads, d_head]
        x = self.join_heads(x)
        x = self.dropout(self.w_o(x))
        # x -> [batch_size, seq_len, d_model]

        return x


class FFN(nn.Module):

    def __init__(self, d_model, ffn_ratio, dropout, bias):
        super(FFN, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_model, ffn_ratio * d_model, bias=bias),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_ratio * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x -> [batch_size, seq_len, d_model]
        x = self.layers(x)
        # x -> [batch_size, seq_len, d_model]
        return x


class MHA_block(nn.Module):
    """
    Implements the Pre-LN Architecture as suggested here:
    https://arxiv.org/pdf/2002.04745.pdf
    """

    def __init__(self, d_model, n_heads, use_cos, kernel, dropout, 
                 ffn_ratio, ln_eps, denom_eps, bias):

        super(MHA_block, self).__init__()
        self.ln1 = nn.LayerNorm(d_model, eps=ln_eps)
        self.ln2 = nn.LayerNorm(d_model, eps=ln_eps)

        self.mha = MHA(
            d_model, n_heads, use_cos, kernel, dropout, denom_eps, bias)  
        self.ffn = FFN(d_model, ffn_ratio, dropout, bias)

    def forward(self, x, mask, weights):

        # x -> [batch_size, seq_len, d_model]
        fx = self.mha(self.ln1(x), mask, weights)
        x = x + fx

        fx = self.ffn(self.ln2(x))
        x = x + fx

        return x

class MHA_block_rezero(nn.Module):
    """
    Implements the ReZero Architecture as suggested here:
    https://arxiv.org/pdf/2003.04887.pdf
    https://github.com/majumderb/rezero
    """

    def __init__(self, d_model, n_heads, use_cos, kernel, dropout, 
                 ffn_ratio, ln_eps, denom_eps, bias):

        super(MHA_block_rezero, self).__init__()

        self.mha = MHA(
            d_model, n_heads, use_cos, kernel, dropout, denom_eps, bias)  
        self.ffn = FFN(d_model, ffn_ratio, dropout, bias)

        self.alpha = nn.Parameter(torch.Tensor([0]))
        

    def forward(self, x, mask, weights):

        # x -> [batch_size, seq_len, d_model]
        fx = self.alpha * self.mha(x, mask, weights)
        x = x + fx

        fx = self.alpha * self.ffn(x)
        x = x + fx

        return x


class Positional_embeddings(nn.Module):
    """
    Stolen from:
    https://github.com/lucidrains/linear-attention-transformer/blob/master/linear_attention_transformer/linear_attention_transformer.py
    """
    def __init__(self, d_model, max_len):
        super(Positional_embeddings, self).__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        position = torch.arange(0, max_len, dtype=torch.float)
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        self.register_buffer('emb', emb)

    def forward(self, x):
        return self.emb[None, :x.shape[1], :].to(x).detach()

class Segment_embeddings(nn.Module):
    """
    Stolen from:
    https://github.com/dreamgonfly/BERT-pytorch/blob/master/bert/train/model/embeddings.py
    """
    def __init__(self, d_model):
        super(Segment_embeddings, self).__init__()
        self.emb = nn.Embedding(2, d_model)

    def forward(self, x):
        return self.emb(x)


class Kernel_transformer(nn.Module):

    def __init__(self, d_model, n_heads, use_cos, kernel, dropout, 
                 ffn_ratio, n_layers, n_emb, tie_emb, ln_eps, denom_eps, 
                 bias, rezero, for_clf, n_classes=None, max_len=1024, xavier=True):
        super(Kernel_transformer, self).__init__()

        self.for_clf = for_clf
        self.use_cos = use_cos
        self.n_classes = n_classes

        self.emb_in = nn.Embedding(n_emb, d_model)
        if self.for_clf:
            if self.n_classes == 2: self.n_classes = 1
            self.emb_out = nn.Linear(d_model, self.n_classes)
        else:
            self.emb_out = nn.Linear(d_model, n_emb)
        # Tie input & output embeddings as in https://arxiv.org/abs/1608.05859
        if not self.for_clf and tie_emb:
            self.emb_out.weight = self.emb_in.weight

        #if not self.use_cos:
        self.emb_pos = Positional_embeddings(d_model, max_len)
        #self.emb_seg = Segment_embeddings(d_model)

        Block_class = MHA_block_rezero if rezero else MHA_block

        self.mha_blocks = nn.ModuleList(
            [Block_class(
                d_model, n_heads, use_cos, kernel, dropout, 
                ffn_ratio, ln_eps, denom_eps, bias
                ) for _ in range(n_layers)]
        )

        if self.for_clf and  self.n_classes == 1:
            self.loss_fn = F.binary_cross_entropy_with_logits
        else:
            self.loss_fn = F.cross_entropy
        
        # Trick to get model device. Stolen from:
        # https://stackoverflow.com/questions/58926054/how-to-get-the-device-type-of-a-pytorch-module-conveniently
        self.dummy_param = nn.Parameter(torch.empty(0))

        if xavier:
            self.init_xavier_uniform()

    def dev(self):
        """Returns the device where the model is stored"""
        return self.dummy_param.device

    def init_xavier_uniform(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def get_mask(self, lens, max_len):
        # lens -> [batch_size]
        mask = torch.arange(max_len)[None, :].to(lens) < lens[:, None]
        # mask -> [batch_size, max_len]
        return mask

    def get_cos_weights(self, lengths, max_len=None):
        # lengths -> [batch_size]
        if max_len is None:
            max_len = lengths.max()
        # For each sample x in the batch, calculate M(x) = len(x)
        M = lengths
        # M -> [batch_size]
        idxs = math.pi / 2 * torch.arange(max_len).to(lengths)
        # idxs -> [max_len]
        idxs = torch.outer(1.0 / M, idxs)#[..., None, None]
        # idxs -> [batch_size, max_len]

        cos = torch.cos(idxs).detach()
        sin = torch.sin(idxs).detach()
        # cos, sin -> [batch_size, max_len]

        return cos, sin

    def forward(self, input_ids, labels, lengths, attention_mask=None):
        # input_idxs -> [batch_size, seq_len]
        # labels -> [batch_size]
        # attention_mask -> [batch_size, max_len] or None
        # lengths -> [batch_size]
        # weights ->  (tuple 2 X [batch_size, seq_len]) or None

        """
        if lengths is None:
            lengths = torch.full(
                [input_ids.shape[0]], input_ids.shape[0], device=input_ids.device)
        """
        # if not lengths is None:   
        input_ids = input_ids[:, :lengths.max()]

        if not attention_mask is None:
            # if not lengths is None:
            attention_mask = attention_mask[:, :lengths.max()]
            attention_mask = torch.logical_not(attention_mask[..., None].bool())
        # attention_mask -> [batch_size, max_len, 1] or None
            
        if self.use_cos:
            cos_weights = self.get_cos_weights(lengths, lengths.max())
        else:
            cos_weights = None

        x = self.emb_in(input_ids)
        if not self.use_cos:
            x += self.emb_pos(x)
        # x += self.emb_seg(seg_idxs)

        for block in self.mha_blocks:
            x = block(x, attention_mask, cos_weights)

        if self.for_clf: x = x[:, 0, :]
        x = self.emb_out(x)
        if self.n_classes == 1: x = x.squeeze(-1)
        loss = self.loss_fn(x, labels)
        return loss, x





In [4]:
class Baseline_transformer(nn.Module):

    def __init__(self, **kwargs):
        super(Baseline_transformer, self).__init__()

        self.for_clf = kwargs['for_clf']
        self.n_classes = kwargs['n_classes']

        self.emb_in = nn.Embedding(kwargs['n_emb'], kwargs['d_model'])
        if self.for_clf:
            if self.n_classes == 2: self.n_classes = 1
            self.emb_out = nn.Linear(kwargs['d_model'], self.n_classes)
        else:
            self.emb_out = nn.Linear(kwargs['d_model'], kwargs['n_emb'])
        # Tie input & output embeddings as in https://arxiv.org/abs/1608.05859
        if not self.for_clf and kwargs['tie_emb']:
            self.emb_out.weight = self.emb_in.weight

        self.emb_pos = Positional_embeddings(kwargs['d_model'], kwargs['max_len'])

        self.mha_blocks = nn.ModuleList([])
        for _ in range(kwargs['n_layers']):
            block = nn.TransformerEncoderLayer(
                d_model=kwargs['d_model'], 
                nhead=kwargs['n_heads'], 
                dim_feedforward=kwargs['d_model'] * kwargs['ffn_ratio'], 
                dropout=kwargs['dropout'], 
                activation='gelu', 
                layer_norm_eps=kwargs['ln_eps'], 
                batch_first=True, 
                #norm_first=True,
            )
            self.mha_blocks.append(block)

        if self.for_clf and  self.n_classes == 1:
            self.loss_fn = F.binary_cross_entropy_with_logits
        else:
            self.loss_fn = F.cross_entropy

        if kwargs['xavier']:
            self.init_xavier_uniform()

    def init_xavier_uniform(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, input_ids, labels, attention_mask=None, lengths=None):
        # input_idxs -> [batch_size, seq_len]
        # labels -> [batch_size]
        # attention_mask -> [batch_size, max_len] or None
        # lengths -> [batch_size]
        # weights ->  (tuple 2 X [batch_size, seq_len]) or None

        if lengths is None:
            lengths = torch.full(
                [input_ids.shape[0]], input_ids.shape[0], device=input_ids.device)
            
        input_ids = input_ids[:, :lengths.max()]

        if not attention_mask is None:
            attention_mask = ~attention_mask[:, :lengths.max()]
            # attention_mask = attention_mask[..., None]
        # attention_mask -> [batch_size, max_len, 1] or None

        x = self.emb_in(input_ids)
        x += self.emb_pos(x)

        for block in self.mha_blocks:
            x = block(x, src_key_padding_mask=attention_mask)

        if self.for_clf: x = x[:, 0, :]
        x = self.emb_out(x)
        if self.n_classes == 1: x = x.squeeze(-1)
        loss = self.loss_fn(x, labels)
        return loss, x


In [None]:
model_args = {
    'd_model': 16,
    'n_heads': 2,
    'use_cos': True,
    'kernel': 'relu',
    'dropout': 0.2,
    'ffn_ratio': 4,
    'n_layers': 2,
    'n_emb': 1000,
    'tie_emb': True,
    'ln_eps': 1e-5,
    'denom_eps': 1e-6,
    'bias': False,
    'rezero': True,
    'for_clf': True,
    'n_classes': 1,
    'max_len': 1024,
    'xavier': True,
}

seq_len = 420
batch_size = 4



In [None]:
model = Kernel_transformer(**model_args)


In [None]:
input_ids = torch.randint(0, model_args['n_emb'], [batch_size, seq_len])
labels = torch.randint(0, 2, [batch_size]).float()
lengths = torch.randint(1, seq_len, [batch_size])
attention_mask = model.get_mask(lengths, seq_len)
with torch.no_grad():
    output = model(input_ids=input_ids, labels=labels, 
                  attention_mask=attention_mask, lengths=lengths)

In [None]:
output

(tensor(0.5428, grad_fn=<BinaryCrossEntropyWithLogitsBackward>),
 tensor([2.8745, 1.0506, 1.4905, 2.0364], grad_fn=<SqueezeBackward1>))

In [None]:
baseline_model = Baseline_transformer(**model_args)

In [None]:
input_ids = torch.randint(0, model_args['n_emb'], [batch_size, seq_len])
labels = torch.randint(0, 2, [batch_size]).float()
lengths = torch.randint(1, seq_len, [batch_size])
attention_mask = model.get_mask(lengths, seq_len)
with torch.no_grad():
    output = baseline_model(input_ids=input_ids, labels=labels, 
                  attention_mask=attention_mask, lengths=lengths)

In [None]:
output

(tensor(0.3837), tensor([-0.5415, -0.5875, -3.3127, -2.8392]))

# Dataset & Tokenizer



In [5]:
raw_dataset = datasets.load_dataset('imdb', split=['train', 'test[:5000]'])

Downloading:   0%|          | 0.00/1.92k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

Downloading and preparing dataset imdb/plain_text (download: 80.23 MiB, generated: 127.02 MiB, post-processed: Unknown size, total: 207.25 MiB) to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a...


Downloading:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
print(tokenizer.vocab_size)

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/426k [00:00<?, ?B/s]

28996


In [7]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

In [8]:
train_dataset = raw_dataset[0].map(tokenize_function, batched=True)
test_dataset = raw_dataset[0].map(tokenize_function, batched=True)

  0%|          | 0/25 [00:00<?, ?ba/s]

  0%|          | 0/25 [00:00<?, ?ba/s]

# Training

In [9]:
def set_seed(seed_val):
    """Sets seed for reproducibility.
    Args:
      seed_val: (int) Seed for rng.
    """
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    transformers.set_seed(seed_val)


def get_device():
    """Returns Cuda device if it is available.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print(f'Available GPU: {torch.cuda.get_device_name(0)}.')
    else:
        print('GPU unavailable.')

    return device


def padding_collator(batch):
    """Dynamically pads a batch, creates mask and masks out padded labels.
    Given batch_size sequences of different lengths, computes
    k = log2(batch_size). All sequences are padded/truncated to the length 
    of the kth longest sequence. Additionally, masks are constructed based 
    on the new dynamic lengths and padded tokens are masked out within labels.
    Code Based on:
    https://gist.github.com/pommedeterresautee/1a334b665710bec9bb65965f662c94c8
    https://huggingface.co/transformers/_modules/transformers/data/data_collator.html#default_data_collator
    Args:
      batch: (list) Each element is a dict {'input_ids': tokens} where tokens
        is a list of token indices coreesponding to a single sample.
    Returns:
      batch: (dict) Contains 3 elements: 
        input_ids: (torch.tensor) Padded input indices.
        attention_mask: (torch.tensor) Mask that takes the value 1 for valid
          tokens and 0 for padded ones.
        labels: (torch.tensor) Same as input_ids but padded tokens are replaced
          with -100.
    """
    lengths = [sum(sample['attention_mask']) for sample in batch]
    max_len = max(lengths)
    input_ids = [sample['input_ids'][:max_len] for sample in batch]
    input_ids = torch.tensor(input_ids).long()
    attention_mask = [sample['attention_mask'][:max_len] for sample in batch]
    attention_mask = torch.tensor(attention_mask).long()
    labels = [sample['label'] for sample in batch]
    labels = torch.tensor(labels).float()
    lengths = torch.tensor(lengths).long()

    batch = {
        'input_ids': input_ids, 
        'attention_mask': attention_mask, 
        'lengths': lengths,
        'labels': labels
    }
    return batch

def free_memory():
    """(Maybe) prevents Cuda running out of memory
    """
    gc.collect()
    torch.cuda.empty_cache()


class Garbage_collector_callback(transformers.TrainerCallback):
    """Custom callback that (maybe) prevents Cuda running out of memory.
    I have absolutely no idea if this actually helps. However, Cuda on Colab
    is prone to memory leaks, especially in case of Ctrl + C interrupts. 
    After using this callback the issue kinda disappeared. Code based on 
    https://huggingface.co/transformers/main_classes/callback.html
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Called every time the Trainer logs data.
        """
        res_before = torch.cuda.memory_reserved(0)
        free_memory()
        res_after = torch.cuda.memory_reserved(0)
        freed = res_before - res_after
        print(f'Freed {freed}.')

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

In [11]:
seed_val = 42
set_seed(seed_val)
device = get_device()


Available GPU: Tesla K80.


In [22]:
model_args = {
    'd_model': 384,
    'n_heads': 6,
    'use_cos': False,
    'kernel': 'elu',
    'dropout': 0.2,
    'ffn_ratio': 4,
    'n_layers': 5,
    'n_emb': tokenizer.vocab_size,
    'tie_emb': True,
    'ln_eps': 1e-5,
    'denom_eps': 1e-5,
    'bias': False,
    'rezero': True,
    'for_clf': True,
    'n_classes': 1,
    'max_len': 512,
    'xavier': True,
}


training_args = {
    # Dirs
    'output_dir': 'results',          
    'logging_dir': 'logs',            
    'num_train_epochs': 10,              
    'per_device_train_batch_size': 8,  
    'per_device_eval_batch_size': 8,   
    # Strategies
    'evaluation_strategy': 'no',     
    'logging_strategy': 'steps',
    'save_strategy': 'epoch',
    # steps
    'logging_steps': int(1e3),               
    'eval_steps': int(1e3),       
    'warmup_steps': 300,                
    'learning_rate': 2e-4,
    'log_level': 'info',
    'seed': seed_val,
    'disable_tqdm': False,
    # Optimizations
    'group_by_length': True, 
}

training_args_ = transformers.TrainingArguments(**training_args)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [None]:
dat = iter(train_dataset)
a = next(dat)
b = next(dat)
c = next(dat)
d = next(dat)

In [None]:
test = padding_collator([a, b, c, d])

In [None]:
test

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[  101,   139, 16071,  ...,     0,     0,     0],
         [  101,  3341, 18984,  ...,  1137,  1113,   102],
         [  101,   139, 11071,  ...,     0,     0,     0],
         [  101,  1188,  1110,  ...,     0,     0,     0]]),
 'labels': tensor([1., 1., 1., 1.]),
 'lengths': tensor([211, 512, 201, 154])}

In [None]:
test['input_ids'].shape

torch.Size([4, 512])

In [None]:
a = np.array(train_dataset['label'])

In [None]:
a.sum()

12500

## Kernel Transformer

In [23]:
free_memory()

In [24]:
kernel_model = Kernel_transformer(**model_args).to(device)
print(f'{count_parameters(kernel_model)} params')

kernel_trainer_args = {
        'model': kernel_model,
        'args': training_args_,
        'train_dataset': train_dataset,
        'eval_dataset': test_dataset,
        'data_collator': padding_collator,
        'callbacks': [Garbage_collector_callback],
}

kernel_trainer = transformers.Trainer(**kernel_trainer_args)

19984134 params


In [None]:
kernel_trainer.train()

The following columns in the training set  don't have a corresponding argument in `Kernel_transformer.forward` and have been ignored: token_type_ids, text.


In [None]:
[b.alpha for b in kernel_model.mha_blocks]

[tensor([1.], device='cuda:0'),
 tensor([1.], device='cuda:0'),
 tensor([1.], device='cuda:0'),
 tensor([1.], device='cuda:0'),
 tensor([1.], device='cuda:0')]

with cos, without rezero\
1000	0.950200\
2000	0.629000\
3000	0.512600\
4000	0.443600\
5000	0.419500\
6000	0.399400\
7000	0.356900\
8000	0.337900\


with cos, with rezero\
1000	0.604900\
2000	0.451500\
3000	0.388500\
4000	0.361400\
5000	0.350500\
6000	0.346500\


## Baseline Model

In [None]:
free_memory()

In [None]:
baseline_model = Baseline_transformer(**model_args).to(device)
print(f'{count_parameters(baseline_model)} params')

baseline_trainer_args = {
        'model': baseline_model,
        'args': training_args_,
        'train_dataset': train_dataset,
        'eval_dataset': test_dataset,
        'data_collator': padding_collator,
        'callbacks': [Garbage_collector_callback],
}

baseline_trainer = transformers.Trainer(**baseline_trainer_args)

20007169 params


In [None]:
baseline_trainer.train()

The following columns in the training set  don't have a corresponding argument in `Baseline_transformer.forward` and have been ignored: token_type_ids, text.
***** Running training *****
  Num examples = 25000
  Num Epochs = 10
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 31250


Step,Training Loss
1000,0.7469
2000,0.7068
3000,0.7026
4000,0.7063


Freed 1549795328.
Freed 1558183936.
Freed 2353004544.


Saving model checkpoint to results/checkpoint-3125
Trainer.model is not a `PreTrainedModel`, only saving its state dict.


Freed 2042626048.


KeyboardInterrupt: ignored