In [1]:
import torch
import torch.nn as nn
import pandas as pd

from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Using Attention
Actually the data is not continuous so it should not use attention. But I dont care.

### Prepare the dataset

In [3]:
train_dataset = pd.read_csv('data/train_dataset.csv')
test_dataset = pd.read_csv('data/test_dataset.csv')

data_columns = ["V"+str(i) for i in range(1,29)]+["Amount"]
label_column ="Class"

X_train = train_dataset[data_columns]
X_test  = test_dataset[data_columns]

y_train = train_dataset[label_column]
y_test  = test_dataset[label_column]

x_train_tensor = torch.from_numpy(X_train.values).to(device)
y_train_tensor = torch.from_numpy(y_train.values).to(device)

x_test_tensor = torch.from_numpy(X_test.values).to(device)
y_test_tensor = torch.from_numpy(y_test.values).to(device)

Train_tensor = TensorDataset(x_train_tensor, y_train_tensor)
Test_tensor = TensorDataset(x_test_tensor, y_test_tensor)

Train_dataset = DataLoader(Train_tensor, batch_size=512, shuffle=True)
Test_dataset = DataLoader(Test_tensor, batch_size=512, shuffle=True)

### Simple attention block
This is the simple structure of attention block. There is not multi-heads or mask. If you want that, simply using torch.view() and torch.maskfilled

In [14]:
class AttentionBlock(nn.Module):
    def __init__(self, num_heads):
        super(AttentionBlock, self).__init__()

        self.num_heads = num_heads
        
        self.W_Q = nn.Linear(num_heads, num_heads)
        self.W_K = nn.Linear(num_heads, num_heads)
        self.W_V = nn.Linear(num_heads, num_heads)
        
    def forward(self, query, key, value):
        
        # softmax (QK^T/sqrt(d_k))V
        Q = self.W_Q(query)
        K = self.W_K(key)
        V = self.W_V(value)
        
        sm = torch.matmul(Q, K.transpose(-1,-2)) / (self.num_heads ** 0.5)
        
        out = torch.matmul(torch.softmax(sm, dim=-1),V)
        
        return out
        
        
        

### An encoder block in transformer architecture

In [None]:
class EncoderBlock(nn.Module):
    def __init__(self, num_heads, num_layers):
        super(EncoderBlock, self).__init__()
        
        self.num_heads = num_heads
        self.num_layers = num_layers
        
        self.attention_blocks = AttentionBlock(num_heads)
        
        self.ffn = nn.Sequential(
            nn.Linear(num_heads, num_heads),
            nn.ReLU(),
            nn.Linear(num_heads, num_heads)
        )
        self.ln1 = nn.LayerNorm(num_heads)
        self.ln2 = nn.LayerNorm(num_heads)
        
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        
    def forward(self, x):
        
        temp = x
        x = self.attention_blocks(x,x,x)
        x += temp
        x = self.ln1(x)
        x = self.dropout1(x)
        
        temp = x
        
        x = self.ffn(x)
        x += temp
        x = self.ln2(x)
        x = self.dropout2(x)
        
        return x

### BERT-like architecture + Classification

In [18]:
class AttentionClassification(nn.Module):
    def __init__(self,input_dim, num_heads, num_layers):
        super(AttentionClassification, self).__init__()
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.input = nn.Linear(input_dim, num_heads)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(num_heads, num_layers) for _ in range(num_layers)])
        self.fc_out = nn.Linear(num_heads, 1)
        
    def forward(self, x):
        x = self.inpu(x)
        # A super mini BERT
        for encoder in self.encoder_blocks:
            x = encoder(x)
        x = self.fc_out(x)
        return x

### Parameter

In [19]:
net = AttentionClassification(29,32, 4).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
epochs = 100

### Training

In [20]:
for epoch in range(epochs):
    for x, y in tqdm(Train_dataset):
        optimizer.zero_grad()
        output = net(x.float())

        loss = criterion(output, y.float().view(-1, 1))
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        correct = 0
        total = 0
        for x, y in tqdm(Test_dataset):
            output = net(x.float())
            predicted = (output >= 0.5).float()
            total += y.size(0)
            correct += (predicted == y.float().view(-1, 1)).sum().item()
        print(f"Epoch {epoch} Accuracy: {correct*100/total:.4f}%")

100%|██████████| 889/889 [00:20<00:00, 44.16it/s]
100%|██████████| 223/223 [00:02<00:00, 83.17it/s]


Epoch 0 Accuracy: 93.5969%


100%|██████████| 889/889 [00:19<00:00, 46.25it/s]
100%|██████████| 223/223 [00:02<00:00, 83.76it/s]


Epoch 1 Accuracy: 95.6492%


100%|██████████| 889/889 [00:19<00:00, 44.74it/s]
100%|██████████| 223/223 [00:02<00:00, 82.68it/s]


Epoch 2 Accuracy: 96.0053%


100%|██████████| 889/889 [00:18<00:00, 46.96it/s]
100%|██████████| 223/223 [00:02<00:00, 85.56it/s]


Epoch 3 Accuracy: 96.1891%


100%|██████████| 889/889 [00:19<00:00, 46.72it/s]
100%|██████████| 223/223 [00:02<00:00, 95.23it/s]


Epoch 4 Accuracy: 96.3377%


100%|██████████| 889/889 [00:18<00:00, 47.74it/s]
100%|██████████| 223/223 [00:02<00:00, 89.37it/s]


Epoch 5 Accuracy: 96.4784%


100%|██████████| 889/889 [00:19<00:00, 45.75it/s]
100%|██████████| 223/223 [00:02<00:00, 88.59it/s]


Epoch 6 Accuracy: 96.6235%


100%|██████████| 889/889 [00:20<00:00, 44.01it/s]
100%|██████████| 223/223 [00:02<00:00, 87.69it/s]


Epoch 7 Accuracy: 96.6490%


100%|██████████| 889/889 [00:19<00:00, 45.61it/s]
100%|██████████| 223/223 [00:02<00:00, 91.36it/s]


Epoch 8 Accuracy: 96.7642%


100%|██████████| 889/889 [00:18<00:00, 47.14it/s]
100%|██████████| 223/223 [00:02<00:00, 89.78it/s]


Epoch 9 Accuracy: 96.8116%


100%|██████████| 889/889 [00:18<00:00, 47.49it/s]
100%|██████████| 223/223 [00:02<00:00, 91.31it/s]


Epoch 10 Accuracy: 96.8450%


100%|██████████| 889/889 [00:18<00:00, 47.05it/s]
100%|██████████| 223/223 [00:02<00:00, 84.31it/s]


Epoch 11 Accuracy: 96.8741%


100%|██████████| 889/889 [00:18<00:00, 47.23it/s]
100%|██████████| 223/223 [00:02<00:00, 89.16it/s]


Epoch 12 Accuracy: 96.9444%


100%|██████████| 889/889 [00:18<00:00, 48.06it/s]
100%|██████████| 223/223 [00:02<00:00, 90.26it/s]


Epoch 13 Accuracy: 97.0499%


100%|██████████| 889/889 [00:18<00:00, 47.93it/s]
100%|██████████| 223/223 [00:02<00:00, 90.16it/s]


Epoch 14 Accuracy: 97.2020%


100%|██████████| 889/889 [00:18<00:00, 47.69it/s]
100%|██████████| 223/223 [00:02<00:00, 89.16it/s]


Epoch 15 Accuracy: 97.3251%


100%|██████████| 889/889 [00:18<00:00, 47.74it/s]
100%|██████████| 223/223 [00:02<00:00, 84.53it/s]


Epoch 16 Accuracy: 97.4016%


100%|██████████| 889/889 [00:18<00:00, 47.17it/s]
100%|██████████| 223/223 [00:02<00:00, 82.52it/s]


Epoch 17 Accuracy: 97.5098%


100%|██████████| 889/889 [00:19<00:00, 46.23it/s]
100%|██████████| 223/223 [00:02<00:00, 86.45it/s]


Epoch 18 Accuracy: 97.6654%


100%|██████████| 889/889 [00:18<00:00, 47.38it/s]
100%|██████████| 223/223 [00:02<00:00, 88.39it/s]


Epoch 19 Accuracy: 97.7349%


100%|██████████| 889/889 [00:18<00:00, 46.84it/s]
100%|██████████| 223/223 [00:02<00:00, 85.69it/s]


Epoch 20 Accuracy: 97.7463%


100%|██████████| 889/889 [00:18<00:00, 47.36it/s]
100%|██████████| 223/223 [00:02<00:00, 85.06it/s]


Epoch 21 Accuracy: 97.9178%


100%|██████████| 889/889 [00:18<00:00, 48.69it/s]
100%|██████████| 223/223 [00:02<00:00, 88.23it/s]


Epoch 22 Accuracy: 97.8826%


100%|██████████| 889/889 [00:17<00:00, 49.56it/s]
100%|██████████| 223/223 [00:02<00:00, 88.31it/s]


Epoch 23 Accuracy: 98.0488%


100%|██████████| 889/889 [00:17<00:00, 49.52it/s]
100%|██████████| 223/223 [00:02<00:00, 94.71it/s]


Epoch 24 Accuracy: 97.8914%


100%|██████████| 889/889 [00:17<00:00, 49.54it/s]
100%|██████████| 223/223 [00:02<00:00, 93.32it/s]


Epoch 25 Accuracy: 98.0945%


100%|██████████| 889/889 [00:17<00:00, 49.95it/s]
100%|██████████| 223/223 [00:02<00:00, 89.08it/s]


Epoch 26 Accuracy: 98.3082%


100%|██████████| 889/889 [00:17<00:00, 49.69it/s]
100%|██████████| 223/223 [00:02<00:00, 92.94it/s]


Epoch 27 Accuracy: 98.2563%


100%|██████████| 889/889 [00:18<00:00, 49.34it/s]
100%|██████████| 223/223 [00:02<00:00, 93.53it/s]


Epoch 28 Accuracy: 98.3539%


100%|██████████| 889/889 [00:17<00:00, 49.57it/s]
100%|██████████| 223/223 [00:02<00:00, 92.07it/s]


Epoch 29 Accuracy: 98.4366%


100%|██████████| 889/889 [00:17<00:00, 49.94it/s]
100%|██████████| 223/223 [00:02<00:00, 90.31it/s]


Epoch 30 Accuracy: 98.4058%


100%|██████████| 889/889 [00:18<00:00, 48.73it/s]
100%|██████████| 223/223 [00:02<00:00, 78.57it/s]


Epoch 31 Accuracy: 98.5043%


100%|██████████| 889/889 [00:20<00:00, 43.69it/s]
100%|██████████| 223/223 [00:02<00:00, 82.94it/s]


Epoch 32 Accuracy: 98.6028%


100%|██████████| 889/889 [00:18<00:00, 47.25it/s]
100%|██████████| 223/223 [00:02<00:00, 86.47it/s]


Epoch 33 Accuracy: 98.5597%


100%|██████████| 889/889 [00:18<00:00, 47.45it/s]
100%|██████████| 223/223 [00:02<00:00, 87.31it/s]


Epoch 34 Accuracy: 98.6406%


100%|██████████| 889/889 [00:18<00:00, 47.40it/s]
100%|██████████| 223/223 [00:02<00:00, 90.18it/s]


Epoch 35 Accuracy: 98.6538%


100%|██████████| 889/889 [00:18<00:00, 47.64it/s]
100%|██████████| 223/223 [00:02<00:00, 88.84it/s]


Epoch 36 Accuracy: 98.6819%


100%|██████████| 889/889 [00:18<00:00, 47.83it/s]
100%|██████████| 223/223 [00:02<00:00, 88.05it/s]


Epoch 37 Accuracy: 98.7655%


100%|██████████| 889/889 [00:19<00:00, 44.86it/s]
100%|██████████| 223/223 [00:02<00:00, 85.13it/s]


Epoch 38 Accuracy: 98.7057%


100%|██████████| 889/889 [00:19<00:00, 46.03it/s]
100%|██████████| 223/223 [00:02<00:00, 80.42it/s]


Epoch 39 Accuracy: 98.7382%


100%|██████████| 889/889 [00:19<00:00, 45.43it/s]
100%|██████████| 223/223 [00:02<00:00, 80.13it/s]


Epoch 40 Accuracy: 98.7549%


100%|██████████| 889/889 [00:19<00:00, 45.89it/s]
100%|██████████| 223/223 [00:02<00:00, 88.04it/s]


Epoch 41 Accuracy: 98.7699%


100%|██████████| 889/889 [00:19<00:00, 46.00it/s]
100%|██████████| 223/223 [00:02<00:00, 80.17it/s]


Epoch 42 Accuracy: 98.7769%


100%|██████████| 889/889 [00:20<00:00, 43.55it/s]
100%|██████████| 223/223 [00:02<00:00, 79.90it/s]


Epoch 43 Accuracy: 98.8534%


100%|██████████| 889/889 [00:19<00:00, 45.94it/s]
100%|██████████| 223/223 [00:02<00:00, 78.76it/s]


Epoch 44 Accuracy: 98.8947%


100%|██████████| 889/889 [00:19<00:00, 44.99it/s]
100%|██████████| 223/223 [00:02<00:00, 83.52it/s]


Epoch 45 Accuracy: 98.8217%


100%|██████████| 889/889 [00:19<00:00, 46.30it/s]
100%|██████████| 223/223 [00:02<00:00, 88.81it/s]


Epoch 46 Accuracy: 98.9316%


100%|██████████| 889/889 [00:20<00:00, 43.53it/s]
100%|██████████| 223/223 [00:02<00:00, 85.80it/s]


Epoch 47 Accuracy: 98.8806%


100%|██████████| 889/889 [00:20<00:00, 44.09it/s]
100%|██████████| 223/223 [00:02<00:00, 82.44it/s]


Epoch 48 Accuracy: 98.8973%


100%|██████████| 889/889 [00:19<00:00, 44.87it/s]
100%|██████████| 223/223 [00:02<00:00, 76.79it/s]


Epoch 49 Accuracy: 98.8569%


100%|██████████| 889/889 [00:20<00:00, 44.16it/s]
100%|██████████| 223/223 [00:02<00:00, 88.88it/s]


Epoch 50 Accuracy: 98.8833%


100%|██████████| 889/889 [00:19<00:00, 45.67it/s]
100%|██████████| 223/223 [00:02<00:00, 86.81it/s]


Epoch 51 Accuracy: 99.0205%


100%|██████████| 889/889 [00:20<00:00, 43.63it/s]
100%|██████████| 223/223 [00:02<00:00, 81.40it/s]


Epoch 52 Accuracy: 98.9176%


100%|██████████| 889/889 [00:20<00:00, 43.45it/s]
100%|██████████| 223/223 [00:02<00:00, 79.95it/s]


Epoch 53 Accuracy: 98.9255%


100%|██████████| 889/889 [00:20<00:00, 43.05it/s]
100%|██████████| 223/223 [00:02<00:00, 76.63it/s]


Epoch 54 Accuracy: 98.9431%


100%|██████████| 889/889 [00:19<00:00, 46.23it/s]
100%|██████████| 223/223 [00:02<00:00, 84.98it/s]


Epoch 55 Accuracy: 99.0284%


100%|██████████| 889/889 [00:20<00:00, 44.00it/s]
100%|██████████| 223/223 [00:02<00:00, 83.17it/s]


Epoch 56 Accuracy: 99.0266%


100%|██████████| 889/889 [00:19<00:00, 45.72it/s]
100%|██████████| 223/223 [00:02<00:00, 86.54it/s]


Epoch 57 Accuracy: 99.0635%


100%|██████████| 889/889 [00:19<00:00, 45.07it/s]
100%|██████████| 223/223 [00:02<00:00, 87.64it/s]


Epoch 58 Accuracy: 98.9747%


100%|██████████| 889/889 [00:20<00:00, 43.87it/s]
100%|██████████| 223/223 [00:02<00:00, 85.42it/s]


Epoch 59 Accuracy: 99.0407%


 57%|█████▋    | 508/889 [00:11<00:08, 43.64it/s]


KeyboardInterrupt: 