In [1]:
import math
import pandas as pd
import numpy as np
import copy
from einops import rearrange
from typing import List, Dict, Union
from argparse import Namespace

import torch
import torch.nn as nn
from torch import einsum
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from operations.data import generate_dataset
from operations.data import generate_dataloader
from operations.embeds import Embedding
from operations.model import NewGELU
from operations.utils import generate_splits
from operations.utils import preprocess
from operations.utils import CutMix, Mixup

from sklearn.base import TransformerMixin
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import StandardScaler

In [2]:
# Configuration namespace. It starts out holding only the CSV locations of the
# feature ("target") and label files for each of the train / val / test splits;
# later cells attach further settings to the same object.
config = Namespace(
    # training split
    train_csv_path='data/train/target/train_targets.csv',
    train_y_csv_path='data/train/label/train_labels.csv',
    # validation split
    val_csv_path='data/val/target/val_targets.csv',
    val_y_csv_path='data/val/label/val_labels.csv',
    # test split
    test_csv_path='data/test/target/test_targets.csv',
    test_y_csv_path='data/test/label/test_labels.csv',
)

In [3]:
# load the raw table from disk
data = pd.read_csv('data/creditcard.csv')

# row-index sets for the supervised-train, validation, test and
# self-supervised-train partitions, derived from the number of rows
sup_train_indices, val_indices, test_indices, ssl_train_indices = generate_splits(len(data))

# split off the 'Class' label column and run features + labels through the
# project's preprocessing helper; besides the processed frame and labels it
# returns the values stored later as no_num / no_cat / cats
# (the meaning of the trailing 0 is defined by `preprocess` -- TODO confirm)
df_proc, y_proc, no_num, no_cat, cats = preprocess(data.drop(columns=['Class']), data['Class'], 0)

In [4]:
# materialise each partition by positional row lookup into the processed
# features (df_proc) and the processed labels (y_proc)

# supervised training rows
train_df = df_proc.iloc[sup_train_indices]
train_y = y_proc.iloc[sup_train_indices]

# validation rows
val_df = df_proc.iloc[val_indices]
val_y = y_proc.iloc[val_indices]

# held-out test rows
test_df = df_proc.iloc[test_indices]
test_y = y_proc.iloc[test_indices]

In [5]:
# Build the three datasets from the CSV locations stored on `config`.
# NOTE(review): nothing in the visible cells writes train_df / val_df / test_df
# to these paths, so the files are presumably produced elsewhere -- confirm.
train_dataset, val_dataset, test_dataset = generate_dataset(
    train_csv_path=config.train_csv_path,
    val_csv_path=config.val_csv_path,
    test_csv_path=config.test_csv_path,
    train_y_csv_path=config.train_y_csv_path,
    val_y_csv_path=config.val_y_csv_path,
    test_y_csv_path=config.test_y_csv_path,
)

# Build the train / validation / test loaders. The helper receives the config
# attributes as a plain dict through `data_paths`.
train_loader, validation_loader, test_loader = generate_dataloader(
    train_bs=16,
    val_bs=16,
    num_workers=0,
    data_paths=vars(config),
)

In [33]:
# Model / augmentation hyper-parameters, attached to the same `config` object
# that already carries the CSV paths. Statement order is kept as-is because
# vars(config) has already been handed to generate_dataloader.

# width of each feature embedding; also the in/out width of the attention and
# feed-forward layers defined below
config.n_embd = 10
# values returned by preprocess(); forwarded to Embedding(...)
config.no_num = no_num
config.no_cat = no_cat
config.cats = cats
# number of attention heads
config.n_head = 2
# not read by any class in the visible cells
config.resid_pdrop = 0.8
# augmentation settings; presumably consumed by CutMix(config) / Mixup(config) -- TODO confirm
config.prob_cutmix = 0.3 # used in paper
config.mixup_alpha = 0.2 # used in paper
# per-head query/key width
config.d_k = config.n_embd // config.n_head
# NOTE(review): derived from the head count, whereas the attention classes in
# this notebook scale scores by 1/sqrt(last dim of q); not read in the visible
# cells -- confirm the intended base (d_k or dim_head) before using it
config.scale = config.n_head ** -0.5
# d_v, dim_head and inner_dim are not read by any class in the visible cells
config.d_v = 10
config.dim_head = 16
config.inner_dim = config.n_head * config.dim_head
# only used to size the std of the normal weight initialisation in the attention classes
config.d_model = no_num + no_cat
# not read in the visible cells
config.mask = None
# not read in the visible cells; possibly consumed by CutMix / Mixup -- TODO confirm
config.alpha = 1.0
# dropout probability applied to the output projection of both attention classes
config.attn_pdrop = 0.1

In [8]:
# pull a single mini-batch from the training loader for interactive inspection;
# per the original note x is expected to be (16, 31)
first_batch = next(iter(train_loader))
x, y = first_batch

In [9]:
class Xi_Pi(nn.Module):
    """Build two embedded views of a batch: a plain one and an augmented one.

    The plain view is the batch passed through the first embedding layer.
    The augmented view is the batch passed through CutMix, then through the
    second embedding layer, then through Mixup.

    Args:
        config: settings object; it is handed whole to ``CutMix`` and ``Mixup``
            and supplies ``n_embd``, ``no_num``, ``no_cat`` and ``cats`` to both
            ``Embedding`` layers.
    """

    def __init__(self, config):
        super().__init__()

        # augmentation operators
        self.cut_mix = CutMix(config)
        self.mix_up = Mixup(config)

        # two separately constructed embedding layers built from identical arguments
        embedding_args = (config.n_embd, config.no_num, config.no_cat, config.cats)
        self.em_1 = Embedding(*embedding_args)
        self.em_2 = Embedding(*embedding_args)

    def forward(self, x):
        """Return ``(pi, pi_prime)``: the plain and the augmented embedding of ``x``."""
        # plain view: embed the batch as-is
        plain_view = self.em_1(x)

        # augmented view: CutMix on the raw batch, embed, then Mixup on the embeddings
        cutmixed_batch = self.cut_mix(x)
        cutmixed_embedded = self.em_2(cutmixed_batch)
        augmented_view = self.mix_up(cutmixed_embedded)

        return plain_view, augmented_view

In [65]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention along the second axis of the input.

    Args:
        config: settings object providing ``n_embd`` (input/output width),
            ``n_head`` (number of heads), ``d_k`` (per-head query/key width),
            ``d_model`` (only used to size the weight-initialisation std) and
            ``attn_pdrop`` (dropout probability on the output projection).

    Shapes (B = batch, T = length of the attended axis, C = ``n_embd``):
        input ``x``: (B, T, C)
        output: (B, T, C)
        attention weights: (n_head, B, T, T)
    """

    def __init__(self, config):
        super().__init__()

        # keep the head count on the module so forward() never depends on a
        # notebook-level `config` variable being alive
        self.n_head = config.n_head

        # joint projection producing the q / k / v inputs, followed by the
        # per-head projections: queries and keys to d_k per head, values to
        # n_embd per head
        self.to_qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.w_qs = nn.Linear(config.n_embd, config.n_head * config.d_k)
        self.w_ks = nn.Linear(config.n_embd, config.n_head * config.d_k)
        self.w_vs = nn.Linear(config.n_embd, config.n_head * config.n_embd)

        # initialise every projection from a normal distribution; each of
        # w_qs, w_ks and w_vs gets its own call
        qk_std = math.sqrt(2.0 / (config.d_model + config.d_k))
        v_std = math.sqrt(2.0 / (config.d_model + config.n_embd))
        nn.init.normal_(self.to_qkv.weight, mean=0, std=v_std)
        nn.init.normal_(self.w_qs.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_ks.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_vs.weight, mean=0, std=v_std)

        # linear projection after attention computation (Xavier-initialised)
        self.fc = nn.Linear(config.n_head * config.n_embd, config.n_embd)
        nn.init.xavier_normal_(self.fc.weight)

        # regularization
        self.layer_norm = nn.LayerNorm(config.n_embd)
        self.dropout = nn.Dropout(p=config.attn_pdrop)

    def _split_heads(self, t):
        """Reshape (B, T, n_head * d) into (n_head, B, T, d)."""
        batch, length, width = t.shape
        return t.reshape(batch, length, self.n_head, width // self.n_head).permute(2, 0, 1, 3)

    def forward(self, x, mask=None):
        """Attend over axis 1 of ``x``.

        Args:
            x: tensor of shape (B, T, n_embd).
            mask: optional boolean tensor; after a leading axis is added it must
                broadcast against (n_head, B, T, T). Positions that are True are
                excluded from attention.

        Returns:
            ``(output, attn)`` with ``output`` of shape (B, T, n_embd) and
            ``attn`` of shape (n_head, B, T, T), each row summing to 1.
        """
        batch, length, _ = x.shape

        # query, key, value inputs for all heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)  # each (B, T, C)

        # residual connection: add back the layer input itself
        residual = x

        q = self._split_heads(self.w_qs(q))  # (h, B, T, d_k)
        k = self._split_heads(self.w_ks(k))  # (h, B, T, d_k)
        v = self._split_heads(self.w_vs(v))  # (h, B, T, C)

        # scaled dot-product scores
        attn = torch.einsum('hblk,hbtk->hblt', q, k) / math.sqrt(q.shape[-1])  # (h, B, T, T)
        if mask is not None:
            attn = attn.masked_fill(mask[None], float('-inf'))
        attn = torch.softmax(attn, dim=-1)  # (h, B, T, T)

        # weighted sum of values, then concatenate the heads on the last axis
        output = torch.einsum('hblt,hbtv->hblv', attn, v)  # (h, B, T, C)
        output = output.permute(1, 2, 0, 3).reshape(batch, length, -1)  # (B, T, h * C)

        # apply dropout to the linearly projected output
        output = self.dropout(self.fc(output))
        # layer normalization of output + residual
        output = self.layer_norm(output + residual)
        return output, attn

In [191]:
class IntersampleAttention(nn.Module):
    """Multi-head attention across the samples of a batch.

    For every head, each sample is flattened into one vector (all positions of
    axis 1 concatenated) and scaled dot-product attention is computed between
    samples, so every sample's output is a mixture of the value vectors of all
    samples in the batch.

    Args:
        config: settings object providing ``n_embd``, ``n_head``, ``d_k``,
            ``d_model`` (only used to size the weight-initialisation std) and
            ``attn_pdrop`` (dropout probability on the output projection).

    Shapes (B = batch, T = size of axis 1, C = ``n_embd``):
        input ``x``: (B, T, C)
        output: (B, T, C)
        attention weights: (n_head, B, B)
    """

    def __init__(self, config):
        super().__init__()

        # keep the head count on the module so forward() never depends on a
        # notebook-level `config` variable being alive
        self.n_head = config.n_head

        # joint projection producing the q / k / v inputs, followed by the
        # per-head projections: queries and keys to d_k per head, values to
        # n_embd per head
        self.to_qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.w_qs = nn.Linear(config.n_embd, config.n_head * config.d_k)
        self.w_ks = nn.Linear(config.n_embd, config.n_head * config.d_k)
        self.w_vs = nn.Linear(config.n_embd, config.n_head * config.n_embd)

        # initialise every projection from a normal distribution; each of
        # w_qs, w_ks and w_vs gets its own call
        qk_std = math.sqrt(2.0 / (config.d_model + config.d_k))
        v_std = math.sqrt(2.0 / (config.d_model + config.n_embd))
        nn.init.normal_(self.to_qkv.weight, mean=0, std=v_std)
        nn.init.normal_(self.w_qs.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_ks.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_vs.weight, mean=0, std=v_std)

        # output projection; its input width must match the concatenated
        # per-head value width produced by w_vs
        self.fc = nn.Linear(config.n_head * config.n_embd, config.n_embd)
        nn.init.xavier_normal_(self.fc.weight)

        # regularization
        self.layer_norm = nn.LayerNorm(config.n_embd)
        self.dropout = nn.Dropout(p=config.attn_pdrop)

    def _flatten_per_head(self, t):
        """Reshape (B, T, n_head * d) into (n_head, B, T * d): one flat vector per sample per head."""
        batch, length, width = t.shape
        per_head = t.reshape(batch, length, self.n_head, width // self.n_head)
        return per_head.permute(2, 0, 1, 3).reshape(self.n_head, batch, -1)

    def forward(self, x):
        """Attend across the batch axis of ``x``.

        Args:
            x: tensor of shape (B, T, n_embd).

        Returns:
            ``(output, attn)`` with ``output`` of shape (B, T, n_embd) and
            ``attn`` of shape (n_head, B, B), each row summing to 1.
        """
        batch, length, _ = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)  # each (B, T, C)

        # residual connection: add back the layer input itself
        residual = x

        q = self._flatten_per_head(self.w_qs(q))  # (h, B, T * d_k)
        k = self._flatten_per_head(self.w_ks(k))  # (h, B, T * d_k)
        v = self._flatten_per_head(self.w_vs(v))  # (h, B, T * C)

        # sample-to-sample scaled dot-product scores
        attn = torch.einsum('hik,hjk->hij', q, k) / math.sqrt(q.shape[-1])  # (h, B, B)
        attn = torch.softmax(attn, dim=-1)  # (h, B, B)

        # mix the value vectors of all samples, then undo the flattening using
        # the actual input sizes instead of fixed numbers
        output = torch.einsum('hij,hjv->hiv', attn, v)  # (h, B, T * C)
        output = output.reshape(self.n_head, batch, length, -1)  # (h, B, T, C)
        output = output.permute(1, 2, 0, 3).reshape(batch, length, -1)  # (B, T, h * C)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, attn

In [210]:
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> NewGELU -> Dropout -> Linear.

    Args:
        config: settings object providing ``n_embd`` (input and output width).
            Two optional attributes are honoured when present:
            ``ff_hidden_dim`` (hidden width, default 20) and ``ff_pdrop``
            (dropout probability, default 0.1).
    """

    def __init__(self, config):
        super().__init__()

        # defaults reproduce the previously fixed sizes
        hidden_dim = getattr(config, 'ff_hidden_dim', 20)
        pdrop = getattr(config, 'ff_pdrop', 0.1)

        self.feed_forward = nn.ModuleDict(dict(
            proj_1 = nn.Linear(config.n_embd, hidden_dim),
            proj_2 = nn.Linear(hidden_dim, config.n_embd),
            dropout = nn.Dropout(pdrop),
            activation = NewGELU()
            ))

    def mlpf(self, x):
        """Apply the MLP to ``x``.

        Implemented as a bound method rather than a lambda stored on the
        instance: a stored closure keeps pointing at the sub-modules it was
        created with, so a ``copy.deepcopy`` of the module would silently run
        the original's weights, and pickling the module would fail.
        """
        m = self.feed_forward
        return m['proj_2'](m['dropout'](m['activation'](m['proj_1'](x))))

    def forward(self, x):
        """Return the MLP output for ``x``; the last dimension stays ``n_embd``."""
        return self.mlpf(x)

In [222]:
class SaintPipeline(nn.Module):
    """One attention stage: self-attention, feed-forward, intersample attention, feed-forward.

    Each of the four sub-layers is wrapped as ``LayerNorm(sublayer(z)) + z``.

    Args:
        config: settings object forwarded unchanged to ``MultiHeadAttention``,
            ``IntersampleAttention`` and both ``FeedForward`` blocks; it also
            supplies ``n_embd`` for the layer norm.
    """

    def __init__(self, config):
        super().__init__()

        # NOTE(review): a single LayerNorm (one set of affine parameters) is
        # shared by all four sub-layers -- confirm this is intended
        self.layer_norm = nn.LayerNorm(config.n_embd)
        self.multihead_attention = MultiHeadAttention(config)
        self.FF1 = FeedForward(config)
        self.MISA = IntersampleAttention(config)
        self.FF2 = FeedForward(config)

    def forward(self, x):
        """Run ``x`` through the four sub-layers and return the final representation.

        Args:
            x: tensor whose last dimension is ``n_embd``; the attention
                sub-layers unpack the input as three axes.

        Returns:
            Tensor produced by the last residual sum. The attention weights
            returned by the two attention sub-layers are discarded.
        """
        # self-attention over the input; computed here rather than read from a
        # variable left over from another notebook cell
        x_attn, _ = self.multihead_attention(x)
        # residual adds the block input, matching the three sub-layers below
        z1 = self.layer_norm(x_attn) + x
        z2 = self.layer_norm(self.FF1(z1)) + z1
        z2_attn, _ = self.MISA(z2)
        z3 = self.layer_norm(z2_attn) + z2
        r = self.layer_norm(self.FF2(z3)) + z3
        return r
    
# instantiate the pipeline from the shared notebook configuration
sp = SaintPipeline(config=config)