In [1]:
import torch

import pandas as pd
import numpy as np

from sklearn.metrics import confusion_matrix

In [2]:
# Read the CSV into a DataFrame; later cells use its `integer` and `label` columns.
dataset = pd.read_csv(filepath_or_buffer='data/Japan_dataset_octet_3.csv')
# Display how many rows carry each label value (class balance).
dataset.label.value_counts()

label
0    1114187
1     779054
Name: count, dtype: int64

In [3]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that encodes each integer as a fixed-width bit vector.

    A sample is the binary expansion of ``X[idx]`` (most significant bit
    first, zero-padded to ``n_bits``) paired with the label ``Y[idx]``.

    Args:
        X: Indexable sequence of integers in ``[0, 2 ** n_bits)``.
        Y: Indexable sequence of labels with the same length as ``X``.
        n_bits: Number of bits in every encoded sample (default 24, which
            matches the previous hard-coded ``"{:024b}"`` format).

    Raises:
        ValueError: If ``n_bits`` is smaller than 1 or ``X`` and ``Y`` differ
            in length.
    """

    def __init__(self, X, Y, n_bits=24):
        if n_bits < 1:
            raise ValueError(f"n_bits must be >= 1, got {n_bits}")
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(Y)}"
            )
        self.X = X
        self.Y = Y
        self.n_bits = n_bits

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(bits, label)`` for sample ``idx``.

        Returns:
            A float tensor of shape ``(n_bits,)`` holding 0.0/1.0 values and a
            float tensor of shape ``(1,)`` holding the label.

        Raises:
            ValueError: If the stored integer does not fit in ``n_bits`` bits.
                String formatting would silently emit a longer vector (or a
                '-' sign) for such values, so they are rejected explicitly.
        """
        value = int(self.X[idx])
        if not 0 <= value < (1 << self.n_bits):
            raise ValueError(
                f"value {value} at index {idx} does not fit in {self.n_bits} bits"
            )

        # Extract bits from the most significant position down to bit 0.
        bits = [(value >> shift) & 1 for shift in range(self.n_bits - 1, -1, -1)]
        y = self.Y[idx]

        return torch.tensor(bits).float(), torch.tensor([y]).float()

In [4]:
# Pull the raw integer inputs and their labels out of the DataFrame as NumPy arrays.
X = dataset['integer'].to_numpy()
Y = dataset['label'].to_numpy()

In [11]:
# Wrap the arrays in the bit-encoding Dataset and batch them in shuffled
# groups of 1024 samples.
ip_dataset = Dataset(X=X, Y=Y)
dataloader = torch.utils.data.DataLoader(
    dataset=ip_dataset,
    batch_size=1024,
    shuffle=True,
)

In [7]:
# Layer widths: the input width followed by the hidden-layer widths.
model_in = 24
model_arch = [model_in, 256, 128, 64, 32]
model_out = 1

# Build the MLP: each consecutive pair of widths becomes Linear + ReLU, and a
# final Linear layer (no activation) maps the last hidden width to the output.
model = torch.nn.Sequential()
for fan_in, fan_out in zip(model_arch[:-1], model_arch[1:]):
    model.append(torch.nn.Linear(fan_in, fan_out))
    model.append(torch.nn.ReLU())
model.append(torch.nn.Linear(model_arch[-1], model_out))

model

Sequential(
  (0): Linear(in_features=24, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ReLU()
  (4): Linear(in_features=128, out_features=64, bias=True)
  (5): ReLU()
  (6): Linear(in_features=64, out_features=32, bias=True)
  (7): ReLU()
  (8): Linear(in_features=32, out_features=1, bias=True)
)

In [15]:
# Binary cross-entropy computed directly on the model's raw outputs.
loss_fn = torch.nn.BCEWithLogitsLoss()
# Adam over every model parameter with a learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [9]:
def train_one_epoch(step):
    """Train the notebook-level ``model`` for one pass over ``dataloader``.

    Relies on the notebook-level ``model``, ``dataloader``, ``optimizer`` and
    ``loss_fn``. The mean loss of every ``step`` consecutive batches is
    printed as training progresses.

    Args:
        step: Number of batches per reporting window; must be >= 1.

    Returns:
        The mean loss of the last complete window of ``step`` batches. When
        the loader yields fewer than ``step`` batches, the mean loss over the
        batches that were seen is returned instead (``0.0`` only for an empty
        loader).

    Raises:
        ValueError: If ``step`` is smaller than 1.
    """
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    model.train()
    running_loss = 0.
    last_loss = 0.
    batches_in_window = 0
    completed_window = False

    for i, data in enumerate(dataloader):

        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        batches_in_window += 1
        if batches_in_window == step:
            last_loss = running_loss / step
            print(f'batch: {i + 1} | loss: {last_loss}')
            running_loss = 0.
            batches_in_window = 0
            completed_window = True

    # Without this fallback a short loader would report a loss of exactly 0.0
    # even though training happened.
    if not completed_window and batches_in_window:
        last_loss = running_loss / batches_in_window
        print(f'batch: {batches_in_window} | loss: {last_loss}')

    return last_loss


def _metrics(tn, fp, fn, tp):
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    
    return accuracy, recall, precision

def get_model_metrics(thresh=0.5, loader=None):
    """Score the notebook-level ``model`` and return loss plus metrics.

    Relies on the notebook-level ``model``, ``loss_fn`` and ``_metrics``.
    The accumulated confusion counts are printed as ``tn fp fn tp``.

    Args:
        thresh: Probability above which a sample is predicted as class 1.
        loader: Iterable of ``(inputs, labels)`` batches to evaluate. Defaults
            to the notebook-level ``dataloader``.

    Returns:
        Tuple ``(avg_vloss, accuracy, recall, precision)`` where ``avg_vloss``
        is the unweighted mean of the per-batch losses.
    """
    if loader is None:
        # NOTE: `dataloader` is the same loader train_one_epoch iterates, so
        # the default call reports metrics on the training data. Pass a
        # held-out loader to obtain a true validation score.
        loader = dataloader

    model.eval()

    running_loss = []
    tp = tn = fp = fn = 0

    with torch.no_grad():
        for inputs, targets in loader:
            logits = model(inputs)
            probs = torch.sigmoid(logits)
            preds = (probs > thresh).int()

            # Pin the label set so the matrix is always 2x2. Without it, a
            # batch whose targets and predictions contain a single class
            # yields a 1x1 matrix and the 4-way unpack below raises.
            # Integer 1-D arrays are passed so labels match [0, 1] exactly.
            _tn, _fp, _fn, _tp = confusion_matrix(
                targets.reshape(-1).long().cpu().numpy(),
                preds.reshape(-1).cpu().numpy(),
                labels=[0, 1],
            ).ravel()

            tn += _tn
            fp += _fp
            fn += _fn
            tp += _tp

            loss = loss_fn(logits, targets)
            running_loss.append(loss.item())

    avg_vloss = np.mean(running_loss)
    accuracy, recall, precision = _metrics(tn, fp, fn, tp)
    print(tn, fp, fn, tp)
    return avg_vloss, accuracy, recall, precision

In [None]:
# Per the original note, 20 epochs had already been completed before this
# run; this cell trains for a further EPOCHS passes.
EPOCHS = 2

for epoch in range(EPOCHS):
    print(f'EPOCH : {epoch + 1}')

    # One training pass, printing the running loss every 200 batches.
    avg_tloss = train_one_epoch(200)
    # Re-score the model; only accuracy, precision and recall are printed.
    avg_vloss, accuracy, recall, precision = get_model_metrics()

    # The doubled newline reproduces the blank separator line between epochs.
    print(f'Acc {accuracy:.4f} | Precision {precision:.4f} | Recall {recall:.4f}', end='\n\n')

EPOCH : 1
batch: 200 | loss: 0.02045377715257928
batch: 400 | loss: 0.019379164224956183
batch: 600 | loss: 0.019086903231218456
batch: 800 | loss: 0.018537037721835077
batch: 1000 | loss: 0.01852941380115226
batch: 1200 | loss: 0.01875365401385352
batch: 1400 | loss: 0.018552384825889022
batch: 1600 | loss: 0.018250142452307046
batch: 1800 | loss: 0.01820269276155159
1109295 4892 6418 772636
Acc 0.9940 | Precision 0.9937 | Recall 0.9918

EPOCH : 2
batch: 200 | loss: 0.018145487850997597
batch: 400 | loss: 0.017956386196892708
batch: 600 | loss: 0.017283236572984605
batch: 800 | loss: 0.01796381570631638
batch: 1000 | loss: 0.016818508156575263
batch: 1200 | loss: 0.01737298485590145
batch: 1400 | loss: 0.017333653066307308
batch: 1600 | loss: 0.017309259574394675
batch: 1800 | loss: 0.017254087077453732
1109579 4608 6213 772841
Acc 0.9943 | Precision 0.9941 | Recall 0.9920



In [17]:
# Persist only the model's parameters (state dict), not the module object.
torch.save(
    obj=model.state_dict(),
    f="saved_model/Japan_256_128_64_32_fp_4608_later.pth",
)