In [237]:
import math
import copy
import os
from typing import List, Dict, Union
from argparse import Namespace

import pandas as pd
import numpy as np
from einops import rearrange

import torch
import torch.nn as nn
from torch import einsum
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from operations.data import generate_dataset
from operations.data import generate_dataloader
from operations.embeds import Embedding
from operations.model import NewGELU
from operations.utils import generate_splits
from operations.utils import preprocess
from operations.utils import CutMix, Mixup

In [7]:
# Load the raw credit-card transactions table; `Class` is the target column.
data = pd.read_csv('data/creditcard.csv')

# Row-index splits: supervised train / validation / test, plus (presumably)
# a set reserved for self-supervised pre-training.
sup_train_indices, val_indices, test_indices, ssl_train_indices = generate_splits(len(data))

# Preprocess the feature columns and the target column.
df_proc, y_proc, no_num, no_cat, cats = preprocess(
    data.drop('Class', axis=1),  # feature columns only
    data['Class'],               # target labels
    0,                           # NOTE(review): meaning of this positional argument is not visible here
)

In [None]:
# Slice the preprocessed features/labels into the supervised splits; each
# split is a (features, labels) pair selected by position with .iloc.
(train_df, train_y), (val_df, val_y), (test_df, test_y) = [
    (df_proc.iloc[split_idx], y_proc.iloc[split_idx])
    for split_idx in (sup_train_indices, val_indices, test_indices)
]

In [None]:
# Shared run configuration: file paths are set here; model hyperparameters
# are added to the same namespace in a later cell.
config = Namespace()

# where to store our train/val/test sets
config.train_csv_path = 'data/train/target/train_targets.csv'
config.val_csv_path = 'data/val/target/val_targets.csv'
config.test_csv_path = 'data/test/target/test_targets.csv'
config.train_y_csv = 'data/train/label/train_labels.csv'
config.val_y_csv = 'data/val/label/val_labels.csv'
config.test_y_csv = 'data/test/label/test_labels.csv'

# save the preprocessed data; each (frame, path) pair is one output file
_split_outputs = [
    (train_df, config.train_csv_path),
    (train_y, config.train_y_csv),
    (val_df, config.val_csv_path),
    (val_y, config.val_y_csv),
    (test_df, config.test_csv_path),
    (test_y, config.test_y_csv),
]
for _split_frame, _split_path in _split_outputs:
    # to_csv does not create missing parent directories, so create them first.
    os.makedirs(os.path.dirname(_split_path), exist_ok=True)
    _split_frame.to_csv(_split_path, index=False)

In [9]:
# Keyword arguments shared by generate_dataset and generate_dataloader:
# the CSV locations of each split's features and labels.
data_paths = {
    "train_csv_path": config.train_csv_path,
    "val_csv_path": config.val_csv_path,
    "test_csv_path": config.test_csv_path,
    "train_y_csv_path": config.train_y_csv,
    "val_y_csv_path": config.val_y_csv,
    "test_y_csv_path": config.test_y_csv,
}

# Dataset objects built from the CSV paths above.
train_dataset, val_dataset, test_dataset = generate_dataset(**data_paths)

# Batched loaders for training / validation / test.
train_loader, validation_loader, test_loader = generate_dataloader(
    train_bs=16,
    val_bs=16,
    num_workers=0,
    data_paths=data_paths,
)

In [108]:
# initial configuration: model and augmentation hyperparameters
config.n_embd = 10          # embedding width of each feature token
config.no_num = no_num      # presumably the number of numerical features (from preprocess)
config.no_cat = no_cat      # presumably the number of categorical features (from preprocess)
config.cats = cats          # categorical-feature info as returned by preprocess
config.n_head = 2
assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
# NOTE(review): used as nn.Dropout(p) in the MLP; p=0.8 zeroes 80% of
# activations during training, which is unusually high — confirm intent.
config.resid_pdrop = 0.8
config.prob_cutmix = 0.3 # used in paper
config.mixup_alpha = 0.2 # used in paper
config.d_k = config.n_embd // config.n_head  # per-head width
# Scaled dot-product attention divides scores by sqrt(d_k), the per-head width.
config.scale = config.d_k ** -0.5

## Self Supervised Pre-Training
<p align="center">
    <img width="500" height="350" src="media/media.jpg">
</p>

SAINT implements contrastive pre-training, where the distance between two views of the same point is minimized while the distance between views of different points is maximized. This strategy is coupled with denoising to perform pre-training on datasets with varied volumes of labeled data.

In [33]:
# Take the first element (the inputs) of one training batch to trace shapes
# through the modules below; the original note records its shape as (16, 31).
_batch_iter = iter(train_loader)
x = next(_batch_iter)[0]
del _batch_iter

The CutMix regularization strategy is used to augment samples in the input space, and mixup is used to augment samples in the embedding space. Specifically, mixup generates convex combinations of pairs of examples and their labels, which regularizes the network to favor simple linear behaviour in between training examples.

In [41]:
# Input-space (CutMix) and embedding-space (mixup) augmenters, built from the
# probabilities set in the configuration cell.
cut_mix, mix_up = CutMix(config.prob_cutmix), Mixup(config.mixup_alpha)

Continuous and categorical features are projected into a higher-dimensional embedding space before being passed through the transformer blocks. A separate single fully-connected layer with a ReLU nonlinearity is used for each continuous feature to project its 1-dimensional input into the $d$-dimensional space.

In [142]:
# Two embedding modules built with identical arguments: em_1 embeds the clean
# batch, em_2 embeds the CutMix-augmented batch.
# NOTE(review): the SAINT paper applies the same embedding E to both views;
# confirm that separate modules are intended here.
em_1, em_2 = (
    Embedding(config.n_embd, config.no_num, config.no_cat, config.cats)
    for _ in range(2)
)

pi = em_1(x)                    # embedded original batch
pi_prime_em = em_2(cut_mix(x))  # embedded CutMix view
pi_prime = mix_up(pi_prime_em)  # mixup applied in embedding space

## SAINT Architecture

<p align="center">
    <img width="255" height="200" src="media/saint_block.jpg">
</p>
Each layer has two attention blocks: a self-attention block and an intersample attention block. The former is identical to the transformer block proposed by Vaswani et al.: the model takes in a sequence of feature embeddings and outputs contextual representations of the same dimension. The latter uses intersample attention in lieu of self-attention, which is the only architectural difference between the two blocks.

In [219]:
class SelfAttention(nn.Module):
    """Pre-norm transformer block with self-attention within each sample.

    Computes ``x + MHA(LN(x))`` followed by ``x + MLP(LN(x))``, so every token
    of a sample attends to the other tokens of the same sample.

    Args:
        config: namespace providing ``n_embd`` (token width), ``n_head``
            (number of attention heads; must divide ``n_embd``) and
            ``resid_pdrop`` (dropout probability applied to the MLP output).

    Expects batch-first input ``(batch, seq, n_embd)`` and returns a tensor of
    the same shape.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): nn.MultiheadAttention applies its own q/k/v input
        # projections, so this is an extra projection in front of it.
        self.to_qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.ln_1 = nn.LayerNorm(config.n_embd)
        # batch_first=True: with the default (seq, batch, embed) layout the
        # batch-first input would be attended across the batch dimension.
        self.attn = nn.MultiheadAttention(
            config.n_embd, num_heads=config.n_head, batch_first=True
        )
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj  = nn.Linear(4 * config.n_embd, config.n_embd),
            act     = NewGELU(),
            dropout = nn.Dropout(config.resid_pdrop),
        ))

    def mlpf(self, x):
        """MLP forward: c_fc -> act -> c_proj -> dropout."""
        m = self.mlp
        return m.dropout(m.c_proj(m.act(m.c_fc(x))))

    def forward(self, x):
        q, k, v = self.to_qkv(self.ln_1(x)).chunk(3, dim=-1)
        attn_output, _attn_weights = self.attn(q, k, v)

        x = x + attn_output
        x = x + self.mlpf(self.ln_2(x))
        return x
    
# Apply one self-attention block to the augmented (CutMix + mixup) embeddings.
self_attn_block = SelfAttention(config)
z_almost = self_attn_block(pi_prime)

Intersample attention computes attention over samples rather than features: each row (sample) in a batch attends to every other row in the same batch.
<p align="center">
    <img width="1000" height="1000" src="media/intersample_attention.jpg">
</p>


In [236]:
def attention(query, key, value, dropout=None):
    """Scaled dot-product attention over the last two axes.

    Computes ``softmax(Q K^T / sqrt(d_k)) V``, where ``d_k`` is the size of the
    last axis of ``query``; any leading axes act as batch dimensions.

    Args:
        query: tensor ``(..., n, d_k)``.
        key: tensor ``(..., m, d_k)``.
        value: tensor ``(..., m, d_v)``.
        dropout: optional callable applied to the attention weights.

    Returns:
        ``(output, weights)`` with shapes ``(..., n, d_v)`` and ``(..., n, m)``;
        ``weights`` is the (post-dropout) matrix that multiplied ``value``.
    """
    scale = math.sqrt(query.size(-1))
    weights = (query @ key.transpose(-2, -1) / scale).softmax(dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights


def intersample(query, key, value, dropout=None):
    """Attention across the samples of a batch, computed independently per head.

    Each sample's ``(n, d)`` slice of a head is flattened into a single
    ``n * d`` vector, so inside every head the ``b`` samples attend to each
    other (rather than features attending to features).

    Args:
        query, key, value: tensors of shape ``(b, h, n, d)``.
        dropout: optional module applied to the attention weights.

    Returns:
        Tensor of shape ``(b, h, n, d)``.
    """
    b, h, n, d = query.shape

    def to_rows(t):
        # (b, h, n, d) -> (h, b, n*d): the sample axis becomes the axis that
        # attention mixes over, with one independent problem per head.
        return t.permute(1, 0, 2, 3).reshape(h, b, n * d)

    output, _ = attention(to_rows(query), to_rows(key), to_rows(value), dropout=dropout)  # (h, b, n*d)
    return output.reshape(h, b, n, d).permute(1, 0, 2, 3)  # (b, h, n, d)


class MultiHeadedIntersampleAttention(nn.Module):
    """Multi-head attention across the samples (rows) of a batch.

    Tokens are projected to query/key/value, split into ``h`` heads of width
    ``d_k``, and passed to ``intersample`` so samples attend to one another
    within each head; heads are then merged and projected back to ``d_model``.

    Args:
        h: number of attention heads; must divide ``d_model``.
        d_model: width of each input token.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super().__init__()
        assert d_model % h == 0, "d_model must be divisible by the number of heads"
        # We assume d_v always equals d_k
        # Sized from d_model (not a global config) so the layer matches its arguments.
        self.to_qkv = nn.Linear(d_model, 3 * d_model)
        self.d_k = d_model // h
        self.h = h
        # linears[0:3] project q, k, v; linears[3] is the output projection.
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        self.attn = None  # not populated by this class
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply intersample attention.

        Args:
            x: batch-first tensor ``(batch, n, d_model)``.

        Returns:
            Tensor ``(batch, n, d_model)``.
        """
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        nbatches = q.size(0)

        # 1) Linear projections, then split d_model => h x d_k: (b, h, n, d_k).
        query, key, value = [
            proj(t).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for proj, t in zip(self.linears, (q, k, v))
        ]

        # 2) Attention across samples, independently for each head.
        out = intersample(query, key, value, dropout=self.dropout)

        # 3) Merge heads: (b, h, n, d_k) -> (b, n, h * d_k), then final linear.
        out = out.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](out)
    
# Intersample attention over the self-attention block's output, added back to
# its input as a residual connection.
intersample_attn_block = MultiHeadedIntersampleAttention(config.n_head, config.n_embd)
z1 = intersample_attn_block(z_almost) + z_almost