In [1]:
import torch
import torch.nn as nn
import pandas as pd

from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Using VAE for self-supervised task

The main idea of this approach is training unsupervised VAE by generating data similar to original. Then the output of VAE's encoder in latent space will go through a classification layer.

### Prepare the dataset

In [3]:
train_dataset = pd.read_csv('data/train_dataset.csv')
test_dataset = pd.read_csv('data/test_dataset.csv')

data_columns = ["V"+str(i) for i in range(1,29)]+["Amount"]
label_column ="Class"

X_train = train_dataset[data_columns]
X_test  = test_dataset[data_columns]

y_train = train_dataset[label_column]
y_test  = test_dataset[label_column]

In [5]:
x_train_tensor = torch.from_numpy(X_train.values).to(device)
y_train_tensor = torch.from_numpy(y_train.values).to(device)

x_test_tensor = torch.from_numpy(X_test.values).to(device)
y_test_tensor = torch.from_numpy(y_test.values).to(device)

Train_tensor = TensorDataset(x_train_tensor, y_train_tensor)
Test_tensor = TensorDataset(x_test_tensor, y_test_tensor)

Train_dataset = DataLoader(Train_tensor, batch_size=512, shuffle=True)
Test_dataset = DataLoader(Test_tensor, batch_size=512, shuffle=True)

## The Models

### VAE
This is the simplest version of VAE. You can add more layers, CNNs if you like.

In [50]:
class VAE(nn.Module):
    def __init__(self,input_dim=29, hidden_dim=20, latent_dim=6):
        super(VAE, self).__init__()
        
        # Encoder
        self.ln1_encoder = nn.Linear(input_dim, hidden_dim)
        
        self.mu_encoder = nn.Linear(hidden_dim, latent_dim)
        self.logvar_encoder = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder
        self.ln1_decoder = nn.Linear(latent_dim, hidden_dim)
        self.ln2_decoder = nn.Linear(hidden_dim, input_dim)
        
        self.activation = nn.ReLU()
        
    def encode(self, x):
        h1 = self.activation(self.ln1_encoder(x))
        return self.mu_encoder(h1), self.logvar_encoder(h1)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def decoder(self, z):
        h1 = self.activation(self.ln1_decoder(z))
        return torch.sigmoid(self.ln2_decoder(h1))
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

### ClassificationModel

In [51]:
class Classification(nn.Module):
    def __init__(self, input_dim=6):
        super(Classification, self).__init__()
        self.ln1 = nn.Linear(input_dim, input_dim)
        self.output = nn.Linear(input_dim, 1)
        self.activation = nn.ReLU()
    def forward(self, x):
        x = self.activation(self.ln1(x))
        return self.output(x)
                            

### The combination of 2 models

In [62]:
class VAE_Classification(nn.Module):
    def __init__(self, vae, classifier):
        super(VAE_Classification, self).__init__()
        self.vae = vae
        self.classifier = classifier
    
    def forward(self, x,):
    
        mu, logvar = self.vae.encode(x)
        return self.classifier(self.vae.reparameterize(mu,logvar))

### Loss funtion for VAE
$$
L_{loss} = L_{recon} + L_{D_{KL}}
$$

In [53]:
def VAEloss(x, x_hat, mu, logvar):
    BCE = nn.MSELoss()(x_hat, x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),dim=1)
    return BCE + KLD.mean()

## Parameter

In [56]:
epochs = 100

#opitmizer = torch.optim.Adam(net.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

## Training

### Training VAE

In [58]:
vae = VAE().to(device)
opitmizer = torch.optim.Adam(vae.parameters(), lr=0.001)

In [60]:
for epoch in range(epochs):
    losses = 0
    for x, y in tqdm(Train_dataset):
        opitmizer.zero_grad()
        x_hat, mu, log_var = vae(x.float())
        loss = VAEloss(x.float(), x_hat, mu, log_var)
        # loss = criterion(output, y.float().view(-1, 1))
        # loss += kl
        loss.backward()
        opitmizer.step()
        losses += loss.item()
    print(f"Epoch {epoch} Loss: {losses/len(Train_dataset)}")
        

100%|██████████| 889/889 [00:05<00:00, 148.84it/s]


Epoch 0 Loss: 12572.512447052517


100%|██████████| 889/889 [00:05<00:00, 152.21it/s]


Epoch 1 Loss: 11067.14281485078


100%|██████████| 889/889 [00:05<00:00, 150.49it/s]


Epoch 2 Loss: 10773.97712657744


100%|██████████| 889/889 [00:05<00:00, 157.33it/s]


Epoch 3 Loss: 10671.828011854963


100%|██████████| 889/889 [00:05<00:00, 154.98it/s]


Epoch 4 Loss: 10596.135566977468


100%|██████████| 889/889 [00:05<00:00, 152.58it/s]


Epoch 5 Loss: 10552.543262048299


100%|██████████| 889/889 [00:05<00:00, 156.48it/s]


Epoch 6 Loss: 10518.025304393102


100%|██████████| 889/889 [00:05<00:00, 156.83it/s]


Epoch 7 Loss: 10488.944529821956


100%|██████████| 889/889 [00:06<00:00, 142.05it/s]


Epoch 8 Loss: 10469.018691999438


100%|██████████| 889/889 [00:05<00:00, 152.42it/s]


Epoch 9 Loss: 10452.478661724901


100%|██████████| 889/889 [00:05<00:00, 157.40it/s]


Epoch 10 Loss: 10437.78784426849


100%|██████████| 889/889 [00:06<00:00, 142.42it/s]


Epoch 11 Loss: 10424.472018024115


100%|██████████| 889/889 [00:06<00:00, 145.50it/s]


Epoch 12 Loss: 10413.438234344243


100%|██████████| 889/889 [00:06<00:00, 142.45it/s]


Epoch 13 Loss: 10404.08506090059


100%|██████████| 889/889 [00:05<00:00, 159.24it/s]


Epoch 14 Loss: 10397.575882421435


100%|██████████| 889/889 [00:05<00:00, 163.15it/s]


Epoch 15 Loss: 10391.569141503796


100%|██████████| 889/889 [00:06<00:00, 145.75it/s]


Epoch 16 Loss: 11137.38223908535


100%|██████████| 889/889 [00:06<00:00, 144.38it/s]


Epoch 17 Loss: 10389.319407339708


100%|██████████| 889/889 [00:06<00:00, 146.13it/s]


Epoch 18 Loss: 10386.802798636987


100%|██████████| 889/889 [00:05<00:00, 155.91it/s]


Epoch 19 Loss: 10384.307454060918


100%|██████████| 889/889 [00:05<00:00, 150.44it/s]


Epoch 20 Loss: 10382.214459825822


100%|██████████| 889/889 [00:05<00:00, 154.88it/s]


Epoch 21 Loss: 10380.211808057684


100%|██████████| 889/889 [00:05<00:00, 155.14it/s]


Epoch 22 Loss: 10378.54729737152


100%|██████████| 889/889 [00:06<00:00, 146.83it/s]


Epoch 23 Loss: 10376.798030617267


100%|██████████| 889/889 [00:05<00:00, 148.64it/s]


Epoch 24 Loss: 10375.065038623101


100%|██████████| 889/889 [00:05<00:00, 156.33it/s]


Epoch 25 Loss: 10373.253930966148


100%|██████████| 889/889 [00:06<00:00, 146.87it/s]


Epoch 26 Loss: 10371.766560918166


100%|██████████| 889/889 [00:06<00:00, 144.66it/s]


Epoch 27 Loss: 10369.963699117689


100%|██████████| 889/889 [00:06<00:00, 143.60it/s]


Epoch 28 Loss: 10368.264131046119


100%|██████████| 889/889 [00:06<00:00, 138.44it/s]


Epoch 29 Loss: 10366.225570558563


100%|██████████| 889/889 [00:05<00:00, 150.31it/s]


Epoch 30 Loss: 10364.533209166726


100%|██████████| 889/889 [00:05<00:00, 155.95it/s]


Epoch 31 Loss: 10362.772967893174


100%|██████████| 889/889 [00:06<00:00, 147.57it/s]


Epoch 32 Loss: 10360.69153806946


100%|██████████| 889/889 [00:06<00:00, 140.56it/s]


Epoch 33 Loss: 10358.967175855947


100%|██████████| 889/889 [00:06<00:00, 127.76it/s]


Epoch 34 Loss: 10357.538600033395


100%|██████████| 889/889 [00:06<00:00, 127.81it/s]


Epoch 35 Loss: 10356.026034892611


100%|██████████| 889/889 [00:06<00:00, 141.64it/s]


Epoch 36 Loss: 10354.800808712387


100%|██████████| 889/889 [00:06<00:00, 140.26it/s]


Epoch 37 Loss: 10353.600199816332


100%|██████████| 889/889 [00:06<00:00, 144.18it/s]


Epoch 38 Loss: 10352.368257434617


100%|██████████| 889/889 [00:06<00:00, 130.57it/s]


Epoch 39 Loss: 10350.96480200717


100%|██████████| 889/889 [00:06<00:00, 133.99it/s]


Epoch 40 Loss: 10349.185469980315


100%|██████████| 889/889 [00:05<00:00, 149.51it/s]


Epoch 41 Loss: 10348.547780160292


100%|██████████| 889/889 [00:06<00:00, 140.93it/s]


Epoch 42 Loss: 10344.763587840094


100%|██████████| 889/889 [00:06<00:00, 147.58it/s]


Epoch 43 Loss: 10341.255033855632


100%|██████████| 889/889 [00:06<00:00, 136.22it/s]


Epoch 44 Loss: 10338.981691375493


100%|██████████| 889/889 [00:05<00:00, 155.17it/s]


Epoch 45 Loss: 10335.926121783605


100%|██████████| 889/889 [00:05<00:00, 154.75it/s]


Epoch 46 Loss: 10337.892494639342


100%|██████████| 889/889 [00:05<00:00, 158.47it/s]


Epoch 47 Loss: 10332.61294456113


100%|██████████| 889/889 [00:05<00:00, 154.03it/s]


Epoch 48 Loss: 10329.16519065488


100%|██████████| 889/889 [00:06<00:00, 141.75it/s]


Epoch 49 Loss: 10324.591334957642


100%|██████████| 889/889 [00:06<00:00, 139.23it/s]


Epoch 50 Loss: 10320.347910002462


100%|██████████| 889/889 [00:06<00:00, 127.07it/s]


Epoch 51 Loss: 10307.980547292429


100%|██████████| 889/889 [00:06<00:00, 140.90it/s]


Epoch 52 Loss: 10303.08632307192


100%|██████████| 889/889 [00:07<00:00, 120.99it/s]


Epoch 53 Loss: 10300.67237811094


100%|██████████| 889/889 [00:07<00:00, 124.58it/s]


Epoch 54 Loss: 10297.642930742055


100%|██████████| 889/889 [00:06<00:00, 135.05it/s]


Epoch 55 Loss: 10295.8247191147


100%|██████████| 889/889 [00:06<00:00, 142.93it/s]


Epoch 56 Loss: 10293.703189811235


100%|██████████| 889/889 [00:06<00:00, 143.55it/s]


Epoch 57 Loss: 10290.191926936866


100%|██████████| 889/889 [00:05<00:00, 157.30it/s]


Epoch 58 Loss: 10287.570426194285


100%|██████████| 889/889 [00:06<00:00, 147.43it/s]


Epoch 59 Loss: 10285.778314820023


100%|██████████| 889/889 [00:07<00:00, 126.06it/s]


Epoch 60 Loss: 10284.384909527911


100%|██████████| 889/889 [00:06<00:00, 132.46it/s]


Epoch 61 Loss: 10283.047275950858


100%|██████████| 889/889 [00:07<00:00, 126.82it/s]


Epoch 62 Loss: 10281.669743918728


100%|██████████| 889/889 [00:06<00:00, 146.25it/s]


Epoch 63 Loss: 10280.327079781531


100%|██████████| 889/889 [00:06<00:00, 139.11it/s]


Epoch 64 Loss: 10279.787674001687


100%|██████████| 889/889 [00:06<00:00, 130.94it/s]


Epoch 65 Loss: 10278.980357801955


100%|██████████| 889/889 [00:05<00:00, 149.51it/s]


Epoch 66 Loss: 10278.393761753901


100%|██████████| 889/889 [00:05<00:00, 152.02it/s]


Epoch 67 Loss: 10276.67065182526


100%|██████████| 889/889 [00:05<00:00, 149.00it/s]


Epoch 68 Loss: 10276.030789730385


100%|██████████| 889/889 [00:05<00:00, 156.11it/s]


Epoch 69 Loss: 10275.509828788492


100%|██████████| 889/889 [00:05<00:00, 151.52it/s]


Epoch 70 Loss: 10274.863284545487


100%|██████████| 889/889 [00:06<00:00, 129.84it/s]


Epoch 71 Loss: 10273.637442658535


100%|██████████| 889/889 [00:07<00:00, 115.79it/s]


Epoch 72 Loss: 10272.599951556349


100%|██████████| 889/889 [00:07<00:00, 118.81it/s]


Epoch 73 Loss: 10271.550664809476


100%|██████████| 889/889 [00:07<00:00, 121.81it/s]


Epoch 74 Loss: 10270.689997429521


100%|██████████| 889/889 [00:06<00:00, 128.28it/s]


Epoch 75 Loss: 10270.135409343364


100%|██████████| 889/889 [00:06<00:00, 143.93it/s]


Epoch 76 Loss: 10269.27444811586


100%|██████████| 889/889 [00:05<00:00, 151.81it/s]


Epoch 77 Loss: 10269.736302859603


100%|██████████| 889/889 [00:06<00:00, 145.37it/s]


Epoch 78 Loss: 10268.076063123945


100%|██████████| 889/889 [00:05<00:00, 154.33it/s]


Epoch 79 Loss: 10268.276754407165


100%|██████████| 889/889 [00:05<00:00, 149.03it/s]


Epoch 80 Loss: 10267.415494388884


100%|██████████| 889/889 [00:05<00:00, 158.31it/s]


Epoch 81 Loss: 10266.90777855649


100%|██████████| 889/889 [00:05<00:00, 148.51it/s]


Epoch 82 Loss: 10266.409569653402


100%|██████████| 889/889 [00:05<00:00, 152.81it/s]


Epoch 83 Loss: 10265.8906859665


100%|██████████| 889/889 [00:05<00:00, 158.38it/s]


Epoch 84 Loss: 10265.885849290811


100%|██████████| 889/889 [00:05<00:00, 151.39it/s]


Epoch 85 Loss: 10266.065780107565


100%|██████████| 889/889 [00:06<00:00, 146.92it/s]


Epoch 86 Loss: 10265.500049981545


100%|██████████| 889/889 [00:05<00:00, 153.65it/s]


Epoch 87 Loss: 10265.458548272287


100%|██████████| 889/889 [00:05<00:00, 152.53it/s]


Epoch 88 Loss: 10264.193211627355


100%|██████████| 889/889 [00:06<00:00, 134.65it/s]


Epoch 89 Loss: 10266.415379046857


100%|██████████| 889/889 [00:06<00:00, 132.22it/s]


Epoch 90 Loss: 10264.280527189962


100%|██████████| 889/889 [00:06<00:00, 134.71it/s]


Epoch 91 Loss: 10265.023131569003


100%|██████████| 889/889 [00:06<00:00, 143.71it/s]


Epoch 92 Loss: 10263.237587000844


100%|██████████| 889/889 [00:06<00:00, 140.44it/s]


Epoch 93 Loss: 10262.453826389377


100%|██████████| 889/889 [00:05<00:00, 158.50it/s]


Epoch 94 Loss: 10262.451121893455


100%|██████████| 889/889 [00:06<00:00, 144.05it/s]


Epoch 95 Loss: 10261.865153086334


100%|██████████| 889/889 [00:05<00:00, 151.48it/s]


Epoch 96 Loss: 10261.397753137304


100%|██████████| 889/889 [00:06<00:00, 142.92it/s]


Epoch 97 Loss: 10374.760232485587


100%|██████████| 889/889 [00:05<00:00, 148.44it/s]


Epoch 98 Loss: 10264.955586728416


100%|██████████| 889/889 [00:06<00:00, 144.86it/s]

Epoch 99 Loss: 10264.386522668554





### Training model

In [63]:
net = VAE_Classification(vae, Classification()).to(device)
opitmizer = torch.optim.Adam(net.parameters(), lr=0.001)

In [65]:
for epoch in range(epochs):
    for x, y in tqdm(Train_dataset):
        opitmizer.zero_grad()
        output = net(x.float())

        loss = criterion(output, y.float().view(-1, 1))
        loss.backward()
        opitmizer.step()

    with torch.no_grad():
        correct = 0
        total = 0
        for x, y in tqdm(Test_dataset):
            output = net(x.float())
            predicted = (output >= 0.5).float()
            total += y.size(0)
            correct += (predicted == y.float().view(-1, 1)).sum().item()
        print(f"Epoch {epoch} Accuracy: {correct*100/total:.4f}%")

100%|██████████| 889/889 [00:06<00:00, 145.36it/s]
100%|██████████| 223/223 [00:00<00:00, 253.94it/s]


Epoch 0 Accuracy: 96.7633%


100%|██████████| 889/889 [00:05<00:00, 172.29it/s]
100%|██████████| 223/223 [00:01<00:00, 181.03it/s]


Epoch 1 Accuracy: 97.9240%


100%|██████████| 889/889 [00:05<00:00, 175.08it/s]
100%|██████████| 223/223 [00:01<00:00, 220.27it/s]


Epoch 2 Accuracy: 98.6309%


100%|██████████| 889/889 [00:05<00:00, 177.56it/s]
100%|██████████| 223/223 [00:00<00:00, 224.71it/s]


Epoch 3 Accuracy: 99.0213%


100%|██████████| 889/889 [00:05<00:00, 176.68it/s]
100%|██████████| 223/223 [00:01<00:00, 219.70it/s]


Epoch 4 Accuracy: 99.0882%


100%|██████████| 889/889 [00:05<00:00, 175.45it/s]
100%|██████████| 223/223 [00:00<00:00, 262.40it/s]


Epoch 5 Accuracy: 99.3379%


100%|██████████| 889/889 [00:05<00:00, 164.27it/s]
100%|██████████| 223/223 [00:01<00:00, 182.35it/s]


Epoch 6 Accuracy: 99.3361%


100%|██████████| 889/889 [00:06<00:00, 136.97it/s]
100%|██████████| 223/223 [00:01<00:00, 185.30it/s]


Epoch 7 Accuracy: 99.4715%


100%|██████████| 889/889 [00:06<00:00, 133.07it/s]
100%|██████████| 223/223 [00:01<00:00, 220.70it/s]


Epoch 8 Accuracy: 99.6078%


100%|██████████| 889/889 [00:06<00:00, 140.42it/s]
100%|██████████| 223/223 [00:01<00:00, 173.82it/s]


Epoch 9 Accuracy: 99.5990%


100%|██████████| 889/889 [00:05<00:00, 155.28it/s]
100%|██████████| 223/223 [00:01<00:00, 210.66it/s]


Epoch 10 Accuracy: 99.6931%


100%|██████████| 889/889 [00:05<00:00, 174.71it/s]
100%|██████████| 223/223 [00:01<00:00, 211.69it/s]


Epoch 11 Accuracy: 99.7309%


100%|██████████| 889/889 [00:05<00:00, 166.01it/s]
100%|██████████| 223/223 [00:00<00:00, 230.66it/s]


Epoch 12 Accuracy: 99.6922%


100%|██████████| 889/889 [00:05<00:00, 169.71it/s]
100%|██████████| 223/223 [00:00<00:00, 239.72it/s]


Epoch 13 Accuracy: 99.7257%


100%|██████████| 889/889 [00:05<00:00, 166.57it/s]
100%|██████████| 223/223 [00:00<00:00, 237.03it/s]


Epoch 14 Accuracy: 99.7380%


100%|██████████| 889/889 [00:05<00:00, 174.59it/s]
100%|██████████| 223/223 [00:01<00:00, 222.80it/s]


Epoch 15 Accuracy: 99.7441%


100%|██████████| 889/889 [00:05<00:00, 173.10it/s]
100%|██████████| 223/223 [00:01<00:00, 211.58it/s]


Epoch 16 Accuracy: 99.7705%


100%|██████████| 889/889 [00:05<00:00, 171.76it/s]
100%|██████████| 223/223 [00:00<00:00, 228.65it/s]


Epoch 17 Accuracy: 99.7819%


100%|██████████| 889/889 [00:05<00:00, 168.66it/s]
100%|██████████| 223/223 [00:01<00:00, 206.83it/s]


Epoch 18 Accuracy: 99.7802%


100%|██████████| 889/889 [00:05<00:00, 168.08it/s]
100%|██████████| 223/223 [00:00<00:00, 245.92it/s]


Epoch 19 Accuracy: 99.7160%


100%|██████████| 889/889 [00:05<00:00, 176.27it/s]
100%|██████████| 223/223 [00:01<00:00, 214.80it/s]


Epoch 20 Accuracy: 99.7257%


100%|██████████| 889/889 [00:05<00:00, 172.92it/s]
100%|██████████| 223/223 [00:01<00:00, 220.56it/s]


Epoch 21 Accuracy: 99.7767%


100%|██████████| 889/889 [00:05<00:00, 175.71it/s]
100%|██████████| 223/223 [00:00<00:00, 252.98it/s]


Epoch 22 Accuracy: 99.8039%


100%|██████████| 889/889 [00:05<00:00, 164.71it/s]
100%|██████████| 223/223 [00:00<00:00, 238.05it/s]


Epoch 23 Accuracy: 99.7951%


100%|██████████| 889/889 [00:05<00:00, 165.66it/s]
100%|██████████| 223/223 [00:01<00:00, 216.13it/s]


Epoch 24 Accuracy: 99.7547%


100%|██████████| 889/889 [00:06<00:00, 138.84it/s]
100%|██████████| 223/223 [00:00<00:00, 227.13it/s]


Epoch 25 Accuracy: 99.7907%


100%|██████████| 889/889 [00:05<00:00, 149.56it/s]
100%|██████████| 223/223 [00:00<00:00, 237.49it/s]


Epoch 26 Accuracy: 99.7371%


100%|██████████| 889/889 [00:05<00:00, 149.32it/s]
100%|██████████| 223/223 [00:01<00:00, 198.26it/s]


Epoch 27 Accuracy: 99.8329%


100%|██████████| 889/889 [00:06<00:00, 145.35it/s]
100%|██████████| 223/223 [00:00<00:00, 250.94it/s]


Epoch 28 Accuracy: 99.8013%


100%|██████████| 889/889 [00:05<00:00, 168.53it/s]
100%|██████████| 223/223 [00:01<00:00, 218.86it/s]


Epoch 29 Accuracy: 99.7687%


100%|██████████| 889/889 [00:05<00:00, 172.67it/s]
100%|██████████| 223/223 [00:01<00:00, 219.33it/s]


Epoch 30 Accuracy: 99.8136%


100%|██████████| 889/889 [00:05<00:00, 167.08it/s]
100%|██████████| 223/223 [00:00<00:00, 247.03it/s]


Epoch 31 Accuracy: 99.8259%


 27%|██▋       | 240/889 [00:01<00:04, 141.47it/s]


KeyboardInterrupt: 

In [None]:
# 99.83 ok?