In [77]:
!pip install transformers datasets sentencepiece sacremoses

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m880.6/880.6 KB[0m [31m70.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895260 sha256=b410c6b898343ab7495455a1ed999f2d67d3a560d1f63c1b023aa1d7d089cdf1
  Stored in directory: /root/.cache/pip/wheels/82/ab/9b/c15899bf659ba74f623ac776e861cf2eb8608c1825ddec66a4
Successfully built

# Setup

## Imports

In [1]:
import numpy as np
from datasets import load_dataset
import torch
import torch.nn as nn
from torch import nn, Tensor
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Optional
import pandas as pd
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
import time
from IPython.display import clear_output
from tqdm import tqdm
from transformers import BertTokenizerFast
import gc
import math

  from .autonotebook import tqdm as notebook_tqdm


## Constants

In [2]:
RANDOM_STATE = 42
BATCH_SIZE = 32

BERT_MODEL = 'bert-base-uncased'
HF_MODEL_HUB = 'huggingface/pytorch-transformers'

## Hardware configuration

In [3]:
gc.collect()
torch.cuda.empty_cache()

[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]

['NVIDIA GeForce GTX 1660 Ti with Max-Q Design']

In [4]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
device = torch.device("cpu")
device

device(type='cpu')

## Functions

In [3]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [6]:
def tokenize(texts, tokenizer):
    res = tokenizer(
        texts, 
        return_tensors="pt",
        padding='max_length',
        max_length=512,
        truncation=True
    )
    return res['input_ids'], res['attention_mask']

In [7]:
def train(device, model, iterator, optimizer, criterion, clip, train_history=None, valid_history=None, n_step=100):
    model.train()
    epoch_loss = 0
    history = []
    for i, batch in enumerate(iterator):
        X = batch[0].to(device)
        y = batch[1].to(device)
        optimizer.zero_grad()

        output = model(X).view(-1)
        loss = criterion(output, y)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

        history.append(loss.cpu().data.numpy())
        if (i+1)%n_step==0:
            fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

            clear_output(True)
            ax[0].plot(history, label='train loss')
            ax[0].set_xlabel('Batch')
            ax[0].set_title('Train loss')
            if train_history is not None:
                ax[1].plot(train_history, label='general train history')
                ax[1].set_xlabel('Epoch')
            if valid_history is not None:
                ax[1].plot(valid_history, label='general valid history')
            plt.legend()
            plt.show()
    return epoch_loss / len(iterator)

In [8]:
def evaluate(device, model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            X = batch[0].to(device)
            y = batch[1].to(device)
            output = model(X).view(-1)
            loss = criterion(output, y)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)

In [9]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [10]:
def predict(device, model, iterator):
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for i, batch in tqdm(enumerate(iterator)):
            X = batch[0].to(device)
            y = batch[1].to(device)
            output = model(X).view(-1)
            y_pred += output.cpu().numpy().tolist()
            y_true += y.cpu().numpy().tolist()
    return y_true, y_pred

In [9]:
# def train_transformer(device, model, iterator, optimizer, criterion, clip, train_history=None, valid_history=None, n_step=100):
#     model.train()
#     epoch_loss = 0
#     history = []
#     for i, batch in enumerate(iterator):
#         X = batch[0]
#         X_mask = batch[1]
#         y = batch[2]
#         optimizer.zero_grad()

#         output = model(X, X_mask).view(-1)
#         loss = criterion(output, y)
#         loss.backward()

#         torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
#         optimizer.step()
#         epoch_loss += loss.item()

#         history.append(loss.cpu().data.numpy())
#         if (i+1)%n_step==0:
#             fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

#             clear_output(True)
#             ax[0].plot(history, label='train loss')
#             ax[0].set_xlabel('Batch')
#             ax[0].set_title('Train loss')
#             if train_history is not None:
#                 ax[1].plot(train_history, label='general train history')
#                 ax[1].set_xlabel('Epoch')
#             if valid_history is not None:
#                 ax[1].plot(valid_history, label='general valid history')
#             plt.legend()
#             plt.show()
#     return epoch_loss / len(iterator)

In [11]:
# def evaluate_transformer(model, iterator, criterion):
#     model.eval()
#     epoch_loss = 0
#     with torch.no_grad():
#         for i, batch in enumerate(iterator):
#             X = batch[0]
#             X_mask = batch[1]
#             y = batch[2]
#             output = model(X, X_mask).view(-1)
#             loss = criterion(output, y)
#             epoch_loss += loss.item()
#     return epoch_loss / len(iterator)

In [14]:
# def predict_transformer(model, iterator):
#     model.eval()
#     y_true = []
#     y_pred = []
#     with torch.no_grad():
#         for i, batch in enumerate(iterator):
#             X = batch[0]
#             X_mask = batch[1]
#             y = batch[2]
#             output = model(X, X_mask).view(-1)
#             y_pred += output.cpu().numpy().tolist()
#             y_true += y.cpu().numpy().tolist()
#     return y_true, y_pred

# Data

## Load

In [11]:
IMDB_DATASET = load_dataset('imdb')

Found cached dataset imdb (C:/Users/yaram/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
100%|██████████| 3/3 [00:00<00:00, 18.71it/s]


In [12]:
IMDB_DATASET

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [13]:
IMDB_DATASET['train'][0]

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

In [14]:
df_train = pd.DataFrame(IMDB_DATASET['train']).copy()
df_train, df_val = train_test_split(
    df_train, test_size=0.2, 
    random_state=RANDOM_STATE, 
    stratify=df_train['label']
)

df_test = pd.DataFrame(IMDB_DATASET['test']).copy()

## Data Preprocessing

In [15]:
tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL)

In [16]:
%%time
df_train_inputs, df_train_mask = tokenize(list(df_train['text']), tokenizer)
df_val_inputs, df_val_mask = tokenize(list(df_val['text']), tokenizer)
df_test_inputs, df_test_mask = tokenize(list(df_test['text']), tokenizer)

CPU times: total: 2min 28s
Wall time: 43 s


### ToTensor

In [17]:
%%time

# convert the data to torch tensors
train_labels = torch.tensor(df_train['label'].to_numpy(), dtype=torch.float32)
valid_labels = torch.tensor(df_val['label'].to_numpy(), dtype=torch.float32)
test_labels = torch.tensor(df_test['label'].to_numpy(), dtype=torch.float32)

# create TensorDataset
# train_dataset = TensorDataset(df_train_inputs, df_train_mask, train_labels)
# valid_dataset = TensorDataset(df_val_inputs, df_val_mask, valid_labels)
# test_dataset = TensorDataset(df_test_inputs, df_test_mask, test_labels)
train_dataset = TensorDataset(df_train_inputs, train_labels)
valid_dataset = TensorDataset(df_val_inputs, valid_labels)
test_dataset = TensorDataset(df_test_inputs, test_labels)

# create dataloader
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

loaders = {
    "train": train_dataloader,
    "val": valid_dataloader,
}

CPU times: total: 0 ns
Wall time: 48.2 ms


# Transformer implementation

In [18]:
bert_model = torch.hub.load(HF_MODEL_HUB, 'model', BERT_MODEL)

Using cache found in C:\Users\yaram/.cache\torch\hub\huggingface_pytorch-transformers_main
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [23]:
bert_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

## Transformer Modules

In [34]:
class PositionalEmbedding(nn.Module):
    def __init__(self, max_seq_len, embed_model_dim):
        super(PositionalEmbedding, self).__init__()
        self.embed_dim = embed_model_dim

        pe = torch.zeros(max_seq_len, self.embed_dim)
        for pos in range(max_seq_len):
            for i in range(0, self.embed_dim, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/self.embed_dim)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/self.embed_dim)))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x * math.sqrt(self.embed_dim)
        seq_len = x.size(1)
        x = x + torch.autograd.Variable(self.pe[:,:seq_len], requires_grad=False)
        return x

In [22]:
class ScaleDotProductAttention(nn.Module):
    """
    Compute scale dot product attention

    Query : given sentence that we focused on (decoder)
    Key : every sentence to check relationship with Qeury(encoder)
    Value : every sentence same with Key (encoder)
    """
    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        batch_size, head, length, d_tensor = k.size()
        k_t = k.transpose(2, 3)
        score = (q @ k_t) / torch.sqrt(torch.tensor(d_tensor))
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)
        score = self.softmax(score)
        v = score @ v
        return v, score

In [23]:
class MultiHeadAttention(nn.Module):
    def __init__(self, input_size, hidden_size, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(self.input_size, self.hidden_size)
        self.w_k = nn.Linear(self.input_size, self.hidden_size)
        self.w_v = nn.Linear(self.input_size, self.hidden_size)
        self.w_concat = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, q, k, v, mask=None):
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        q, k, v = self.split(q), self.split(k), self.split(v)
        out, attention = self.attention(q, k, v, mask=mask)
        out = self.concat(out)
        out = self.w_concat(out)
        return out

    def split(self, tensor):
        """
        split tensor by number of head

        :param tensor: [batch_size, length, d_model]
        :return: [batch_size, head, length, d_tensor]
        """
        batch_size, length, d_model = tensor.size()

        d_tensor = d_model // self.n_heads
        tensor = tensor.view(batch_size, length, self.n_heads, d_tensor).transpose(1, 2)
        # it is similar with group convolution (split by number of heads)
        return tensor

    def concat(self, tensor):
        """
        inverse function of self.split(tensor : torch.Tensor)

        :param tensor: [batch_size, head, length, d_tensor]
        :return: [batch_size, length, d_model]
        """
        batch_size, head, length, d_tensor = tensor.size()
        d_model = head * d_tensor

        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
        return tensor

In [24]:
class EncoderLayer(nn.Module):
    def __init__(self, input_size, hidden_size, n_heads, drop_prob=0.1):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(input_size, hidden_size, n_heads)
        self.norm1 = nn.LayerNorm(input_size)
        self.dropout = nn.Dropout(p=drop_prob)

        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, input_size)
        self.relu = nn.ReLU()

        self.norm2 = nn.LayerNorm(input_size)

    def forward(self, x, src_mask):
        _x = x

        x = self.attention(q=x, k=x, v=x, mask=src_mask)
        x = self.dropout(x)
        
        x = self.norm1(x + _x)
        _x = x

        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)

        x = self.norm2(x + _x)
        return x

In [25]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, n_heads, n_layers, drop_prob=0.1):
        super().__init__()
        encoder_layers = []
        for _ in range(n_layers):
            layer = EncoderLayer(
                input_size=input_size,
                hidden_size=hidden_size,
                n_heads=n_heads,
                drop_prob=drop_prob
            )
            encoder_layers.append(layer)
        self.layers = nn.ModuleList(encoder_layers)

    def forward(self, x, src_mask):
        for layer in self.layers:
            x = layer(x, src_mask)
        return x

In [26]:
class BinaryClassificationTransformerModel(nn.Module):
    def __init__(self, ntoken: int, model_size: int = 128, n_heads: int = 4, 
    nlayers: int = 1, dropout: float = 0.1, maxlen: int = 512):
        super().__init__()
        self.model_size = model_size
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEmbedding(maxlen, model_size)
        self.emb = nn.Embedding(ntoken, model_size)
        self.transformer_encoder = Encoder(
            input_size=self.model_size, 
            hidden_size=self.model_size, 
            n_heads=n_heads, 
            n_layers=nlayers, 
            drop_prob=dropout
        )
        self.decoder = nn.Linear(model_size, 1)  # Bin classifier
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.emb.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        src = self.emb(src)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        pooled = output.mean(dim=1)
        output = self.decoder(pooled)
        return torch.sigmoid(output)

## Base

In [27]:
hidden_dim = 128
# hidden_dim = 768

model = BinaryClassificationTransformerModel(
    ntoken=30522, 
    model_size=hidden_dim, 
    n_heads=4, 
    # d_hid=4*hidden_dim,
    nlayers=6
)
# model.emb = bert_model.embeddings.word_embeddings
model = model.to(device)

In [28]:
count_parameters(model)

4504449

In [29]:
count_parameters(model.emb)

3906816

In [30]:
train_history = []
valid_history = []

N_EPOCHS = 10
CLIP = 1
learning_rate = 0.0001

best_valid_loss = float('inf')
early_stopping_counter = 0
early_stopping_criteria = 3
lr_on_plateau_update = 0.2
min_lr = 1e-6

# criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
%%time
st = time.time()
for epoch in range(N_EPOCHS):
    start_time = time.time()
    # train_loss = train_transformer(model, train_dataloader, optimizer, criterion, CLIP, train_history, valid_history)
    # valid_loss = evaluate_transformer(model, valid_dataloader, criterion)
    train_loss = train(device, model, train_dataloader, optimizer, criterion, CLIP, train_history, valid_history)
    valid_loss = evaluate(device, model, valid_dataloader, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'transformer.pt')
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        learning_rate = max(min_lr, learning_rate * lr_on_plateau_update)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        print(f"LR on pleataeu update. New LR: {learning_rate}")
    train_history.append(train_loss)
    valid_history.append(valid_loss)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    if early_stopping_counter >= early_stopping_criteria:
        print(f"Early stopping, reached limit: {early_stopping_criteria}")
        break
tt_m, tt_s = epoch_time(st, time.time())

print(f'Total training time: {tt_m}m {tt_s}s')
model.load_state_dict(torch.load('transformer.pt'));

In [None]:
y_true, y_pred = predict(device, model, test_dataloader)
y_pred = list(map(round, y_pred))

acc = accuracy_score(y_true, y_pred)
f1_micro = f1_score(y_true, y_pred, average='micro')
f1_macro = f1_score(y_true, y_pred, average='macro')

print(f"Accuracy: {acc}\nF1(micro): {f1_micro}\nF1(macro): {f1_macro}")

807it [22:59,  1.71s/it]

Accuracy: 0.5
F1(micro): 0.5
F1(macro): 0.3333333333333333





## Reflection layer

In [None]:
class ReflectionLayer(nn.Module):
    def __init__(self, inut_size, drop_prob=0.1):
        super(ReflectionLayer, self).__init__()
        self.linear = nn.Linear(input_size, input_size)
        self.sigm = nn.Sigmoid()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear(x)
        gate = self.sigm(x)
        x = x * gate
        return x

In [None]:
class EncoderReflectionLayer(nn.Module):
    def __init__(self, input_size, hidden_size, n_heads, drop_prob=0.1):
        super(EncoderReflectionLayer, self).__init__()
        self.attention = MultiHeadAttention(input_size, hidden_size, n_heads)
        self.reflection = ReflectionLayer(input_size)
        self.norm1 = nn.LayerNorm(input_size)
        self.dropout = nn.Dropout(p=drop_prob)

        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, input_size)
        self.relu = nn.ReLU()

        self.norm2 = nn.LayerNorm(input_size)

    def forward(self, x, src_mask):
        _x = x

        x = self.reflection(x)
        x = self.attention(q=x, k=x, v=x, mask=src_mask)
        x = self.dropout(x)
        
        x = self.norm1(x + _x)
        _x = x

        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)

        x = self.norm2(x + _x)
        return x

In [None]:
class ReflectionEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, n_heads, n_layers, drop_prob=0.1):
        super().__init__()
        encoder_layers = []
        for _ in range(n_layers):
            layer = EncoderReflectionLayer(
                input_size=input_size,
                hidden_size=hidden_size,
                n_heads=n_heads,
                drop_prob=drop_prob
            )
            encoder_layers.append(layer)
        self.layers = nn.ModuleList(encoder_layers)

    def forward(self, x, src_mask):
        for layer in self.layers:
            x = layer(x, src_mask)
        return x

In [None]:
class BinaryClassificationReflectionTransformerModel(nn.Module):
    def __init__(self, ntoken: int, model_size: int = 128, n_heads: int = 4, 
    nlayers: int = 1, dropout: float = 0.1, maxlen: int = 512):
        super().__init__()
        self.model_size = model_size
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEmbedding(maxlen, model_size)
        self.emb = nn.Embedding(ntoken, model_size)
        self.transformer_encoder = ReflectionEncoder(
            input_size=self.model_size, 
            hidden_size=self.model_size, 
            n_heads=n_heads, 
            n_layers=nlayers, 
            drop_prob=dropout
        )
        self.decoder = nn.Linear(model_size, 1)  # Bin classifier
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.emb.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        src = self.emb(src)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        pooled = output.mean(dim=1)
        output = self.decoder(pooled)
        return torch.sigmoid(output)

In [None]:
hidden_dim = 128
# hidden_dim = 768

model = BinaryClassificationReflectionTransformerModel(
    ntoken=30522, 
    model_size=hidden_dim, 
    n_heads=4, 
    # d_hid=4*hidden_dim,
    nlayers=6
)
# model.emb = bert_model.embeddings.word_embeddings
model = model.to(device)

In [None]:
count_parameters(model)

4504449

In [None]:
count_parameters(model.emb)

3906816

In [None]:
train_history = []
valid_history = []

N_EPOCHS = 10
CLIP = 1
learning_rate = 0.0001

best_valid_loss = float('inf')
early_stopping_counter = 0
early_stopping_criteria = 3
lr_on_plateau_update = 0.2
min_lr = 1e-6

# criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
%%time
st = time.time()
for epoch in range(N_EPOCHS):
    start_time = time.time()
    # train_loss = train_transformer(model, train_dataloader, optimizer, criterion, CLIP, train_history, valid_history)
    # valid_loss = evaluate_transformer(model, valid_dataloader, criterion)
    train_loss = train(device, model, train_dataloader, optimizer, criterion, CLIP, train_history, valid_history)
    valid_loss = evaluate(device, model, valid_dataloader, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'transformer.pt')
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        learning_rate = max(min_lr, learning_rate * lr_on_plateau_update)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        print(f"LR on pleataeu update. New LR: {learning_rate}")
    train_history.append(train_loss)
    valid_history.append(valid_loss)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    if early_stopping_counter >= early_stopping_criteria:
        print(f"Early stopping, reached limit: {early_stopping_criteria}")
        break
tt_m, tt_s = epoch_time(st, time.time())

print(f'Total training time: {tt_m}m {tt_s}s')
model.load_state_dict(torch.load('transformer.pt'));

In [None]:
y_true, y_pred = predict(device, model, test_dataloader)
y_pred = list(map(round, y_pred))

acc = accuracy_score(y_true, y_pred)
f1_micro = f1_score(y_true, y_pred, average='micro')
f1_macro = f1_score(y_true, y_pred, average='macro')

print(f"Accuracy: {acc}\nF1(micro): {f1_micro}\nF1(macro): {f1_macro}")

807it [22:59,  1.71s/it]

Accuracy: 0.5
F1(micro): 0.5
F1(macro): 0.3333333333333333





## HyperCube Layer

In [130]:
class HyperCubeLayer(nn.Module):
    __constants__ = ['in_features', 'out_sqrt_features']
    in_features: int
    out_sqrt_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_sqrt_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        hc_input_size = np.sqrt(in_features)
        assert hc_input_size % 1 == 0
        self.hc_input_size = hc_input_size = int(hc_input_size)
        self.in_features = in_features
        self.out_sqrt_features = out_sqrt_features  # No. of output features = out_sqrt_features * sqrt(in_features)
        self.weight = nn.Parameter(torch.empty((out_sqrt_features*hc_input_size, hc_input_size), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty((out_sqrt_features*hc_input_size,), **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        return 'in_features={}, hc_input_size={}, out_sqrt_features={}, bias={}'.format(
            self.in_features, self.hc_input_size, self.out_sqrt_features, self.bias is not None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view((*x.shape[:-1], self.hc_input_size, self.hc_input_size))
        # For the 1st hc should probably transpose, but I'm not sure
        x = x.expand((self.out_sqrt_features, *x.shape))
        x = x.movedim(0,-3)
        x = x.flatten(start_dim=-3, end_dim=-2)
        x = x * self.weight
        x = torch.sum(x, axis=-1) + self.bias
        return x

In [131]:
class HyperCubeBlock(nn.Module):
    def __init__(self, input_size, out_sqrt_features=None):
        if out_sqrt_features is None:
            out_sqrt_features = input_size
        super(HyperCubeBlock, self).__init__()
        self.hc_layers_1 = HyperCubeLayer(input_size, int(np.sqrt(input_size)))  # TODO: fix
        self.hc_layers_2 = HyperCubeLayer(input_size, out_sqrt_features)
            
    def forward(self, x):
        x = self.hc_layers_1(x)
        # x = x.transpose(1,2)  # !Check if needed
        x = self.hc_layers_2(x)
        return x

In [132]:
class MultiHeadAttentionHC(nn.Module):
    def __init__(self, input_size, hidden_size, n_heads):
        super(MultiHeadAttentionHC, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.attention = ScaleDotProductAttention()
        self.hidden_size = self.input_size  # TMP second param is also input_size instead of hidden_size
        # self.w_q = HyperCubeBlock(self.input_size, self.hidden_size)
        # self.w_k = HyperCubeBlock(self.input_size, self.hidden_size)
        # self.w_v = HyperCubeBlock(self.input_size, self.hidden_size)
        self.w_q = HyperCubeBlock(self.input_size, int(np.sqrt(self.hidden_size)))
        self.w_k = HyperCubeBlock(self.input_size, int(np.sqrt(self.hidden_size)))
        self.w_v = HyperCubeBlock(self.input_size, int(np.sqrt(self.hidden_size)))
        self.w_concat = HyperCubeBlock(self.hidden_size, int(np.sqrt(self.hidden_size)))

    def forward(self, q, k, v, mask=None):
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        q, k, v = self.split(q), self.split(k), self.split(v)
        out, attention = self.attention(q, k, v, mask=mask)
        out = self.concat(out)
        out = self.w_concat(out)
        return out

    def split(self, tensor):
        """
        split tensor by number of head

        :param tensor: [batch_size, length, d_model]
        :return: [batch_size, head, length, d_tensor]
        """
        batch_size, length, d_model = tensor.size()

        d_tensor = d_model // self.n_heads
        tensor = tensor.view(batch_size, length, self.n_heads, d_tensor).transpose(1, 2)
        # it is similar with group convolution (split by number of heads)
        return tensor

    def concat(self, tensor):
        """
        inverse function of self.split(tensor : torch.Tensor)

        :param tensor: [batch_size, head, length, d_tensor]
        :return: [batch_size, length, d_model]
        """
        batch_size, head, length, d_tensor = tensor.size()
        d_model = head * d_tensor

        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
        return tensor

In [133]:
class EncoderLayerHC(nn.Module):
    def __init__(self, input_size, hidden_size, n_heads, drop_prob=0.1):
        super(EncoderLayerHC, self).__init__()
        self.attention = MultiHeadAttentionHC(input_size, hidden_size, n_heads)
        self.norm1 = nn.LayerNorm(input_size)
        self.dropout = nn.Dropout(p=drop_prob)

        hidden_size = input_size  # TMP second param is also input_size instead of hidden_size
        # self.linear1 = HyperCubeBlock(input_size, hidden_size)
        # self.linear2 = HyperCubeBlock(hidden_size, input_size)
        self.linear1 = HyperCubeBlock(input_size, int(np.sqrt(hidden_size)))
        self.linear2 = HyperCubeBlock(hidden_size, int(np.sqrt(input_size)))
        self.relu = nn.ReLU()

        self.norm2 = nn.LayerNorm(input_size)

    def forward(self, x, src_mask):
        _x = x

        x = self.attention(q=x, k=x, v=x, mask=src_mask)
        x = self.dropout(x)
        
        x = self.norm1(x + _x)
        _x = x

        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)

        x = self.norm2(x + _x)
        return x

In [134]:
class EncoderHC(nn.Module):
    def __init__(self, input_size, hidden_size, n_heads, n_layers, drop_prob=0.1):
        super().__init__()
        encoder_layers = []
        for _ in range(n_layers):
            layer = EncoderLayerHC(
                input_size=input_size,
                hidden_size=hidden_size,
                n_heads=n_heads,
                drop_prob=drop_prob
            )
            encoder_layers.append(layer)
        self.layers = nn.ModuleList(encoder_layers)

    def forward(self, x, src_mask):
        for layer in self.layers:
            x = layer(x, src_mask)
        return x

In [135]:
class BinaryClassificationTransformerModelHC(nn.Module):
    def __init__(self, ntoken: int, model_size: int = 128, n_heads: int = 4, 
    nlayers: int = 1, dropout: float = 0.1, maxlen: int = 512):
        super().__init__()
        self.model_size = model_size
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEmbedding(maxlen, model_size)
        self.emb = nn.Embedding(ntoken, model_size)
        self.transformer_encoder = EncoderHC(
            input_size=self.model_size, 
            hidden_size=self.model_size, 
            n_heads=n_heads, 
            n_layers=nlayers, 
            drop_prob=dropout
        )
        self.decoder = nn.Linear(model_size, 1)  # Bin classifier
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.emb.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        src = self.emb(src)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        pooled = output.mean(dim=1)
        output = self.decoder(pooled)
        return torch.sigmoid(output)

In [136]:
hidden_dim = 256
# hidden_dim = 768

model = BinaryClassificationTransformerModelHC(
    ntoken=30522, 
    model_size=hidden_dim, 
    n_heads=4, 
    # d_hid=4*hidden_dim,
    nlayers=6
)
# model.emb = bert_model.embeddings.word_embeddings
model = model.to(device)

In [137]:
count_parameters(model)

39998465

In [138]:
count_parameters(model.emb)

31254528

In [None]:
train_history = []
valid_history = []

N_EPOCHS = 10
CLIP = 1
learning_rate = 0.0001

best_valid_loss = float('inf')
early_stopping_counter = 0
early_stopping_criteria = 3
lr_on_plateau_update = 0.2
min_lr = 1e-6

# criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
%%time
st = time.time()
for epoch in range(N_EPOCHS):
    start_time = time.time()
    # train_loss = train_transformer(model, train_dataloader, optimizer, criterion, CLIP, train_history, valid_history)
    # valid_loss = evaluate_transformer(model, valid_dataloader, criterion)
    train_loss = train(device, model, train_dataloader, optimizer, criterion, CLIP, train_history, valid_history)
    valid_loss = evaluate(device, model, valid_dataloader, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'transformer.pt')
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        learning_rate = max(min_lr, learning_rate * lr_on_plateau_update)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        print(f"LR on pleataeu update. New LR: {learning_rate}")
    train_history.append(train_loss)
    valid_history.append(valid_loss)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    if early_stopping_counter >= early_stopping_criteria:
        print(f"Early stopping, reached limit: {early_stopping_criteria}")
        break
tt_m, tt_s = epoch_time(st, time.time())

print(f'Total training time: {tt_m}m {tt_s}s')
model.load_state_dict(torch.load('transformer.pt'));

In [None]:
y_true, y_pred = predict(device, model, test_dataloader)
y_pred = list(map(round, y_pred))

acc = accuracy_score(y_true, y_pred)
f1_micro = f1_score(y_true, y_pred, average='micro')
f1_macro = f1_score(y_true, y_pred, average='macro')

print(f"Accuracy: {acc}\nF1(micro): {f1_micro}\nF1(macro): {f1_macro}")

807it [22:59,  1.71s/it]

Accuracy: 0.5
F1(micro): 0.5
F1(macro): 0.3333333333333333





## HyperCube Layer V2

Optimised computation.

In [5]:
class HyperCubeLayer(nn.Module):
    __constants__ = ['in_features', 'out_sqrt_features']
    in_features: int
    out_sqrt_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_sqrt_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        hc_input_size = np.sqrt(in_features)
        assert hc_input_size % 1 == 0
        self.hc_input_size = hc_input_size = int(hc_input_size)
        self.in_features = in_features
        self.out_sqrt_features = out_sqrt_features  # No. of output features = out_sqrt_features * sqrt(in_features)
        self.weight = nn.Parameter(torch.empty((out_sqrt_features, hc_input_size, hc_input_size), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty((out_sqrt_features,), **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        return 'in_features={}, hc_input_size={}, out_sqrt_features={}, bias={}'.format(
            self.in_features, self.hc_input_size, self.out_sqrt_features, self.bias is not None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view((*x.shape[:-1], self.hc_input_size, self.hc_input_size))
        x = (x.movedim(1,2) @ self.weight).movedim(2,1) + self.bias
        x = x.flatten(start_dim=-2)
        return x

In [1]:
from transformers import AutoImageProcessor, ViTModel

model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

  from .autonotebook import tqdm as notebook_tqdm
Downloading: 100%|██████████| 502/502 [00:00<00:00, 167kB/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading: 100%|██████████| 346M/346M [00:56<00:00, 6.17MB/s]   


In [2]:
model

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
      

In [5]:
count_parameters(model.embeddings)

742656

In [15]:
from transformers import MobileViTForImageClassification

In [20]:
model_ckpt = "apple/mobilevit-small"
mmodel = MobileViTForImageClassification.from_pretrained(model_ckpt)

Downloading: 100%|██████████| 70.0k/70.0k [00:00<00:00, 121kB/s] 
Downloading: 100%|██████████| 22.5M/22.5M [00:06<00:00, 3.38MB/s]


In [21]:
mmodel

MobileViTForImageClassification(
  (mobilevit): MobileViTModel(
    (conv_stem): MobileViTConvLayer(
      (convolution): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (normalization): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): SiLUActivation()
    )
    (encoder): MobileViTEncoder(
      (layer): ModuleList(
        (0): MobileViTMobileNetLayer(
          (layer): ModuleList(
            (0): MobileViTInvertedResidual(
              (expand_1x1): MobileViTConvLayer(
                (convolution): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (activation): SiLUActivation()
              )
              (conv_3x3): MobileViTConvLayer(
                (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)

In [22]:
count_parameters(mmodel)

5578632

# 