In [1]:
import torch
from torch import nn
from torch.functional import F

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from tqdm.notebook import tqdm

from trace_process import *

In [2]:
# Locations of the test-trace outputs; the *_raw_* variants are the same paths
# seen from the parent directory.
network_out_path, memory_path = (
    f"{DEFAULT_PREFIX}/{TEST_PATH}/{leaf}" for leaf in (NETWORK_OUT, MEMORY)
)
network_out_raw_path, memory_raw_path = (
    f"../{path}" for path in (network_out_path, memory_path)
)

# Number of packets in each model input window.
SEQ_L = 30
# Threshold handed to get_flows_index (presumably the gap that splits flows);
# NOTE(review): 500 is scaled by NANO_TO_MICRO from trace_process — confirm the unit.
TIME_DELTA = 500 * NANO_TO_MICRO

In [3]:
# Load the processed network trace; the CSV's first column becomes the index.
packets_df = pd.read_csv(network_out_path + ".csv", index_col=[0])
trace = packets_df.to_numpy()

In [4]:
# Split the trace into flows; column 0 of flows_indexes is later used as each flow's packet position.
(flows_indexes, flow_sizes) = get_flows_index(trace, TIME_DELTA)

In [5]:
# Column 1 of the trace array holds the per-packet values that form the model's input windows.
packets = trace[..., 1]

In [6]:
# Index of the first flow whose column-0 packet position exceeds SEQ_L, so that a
# complete SEQ_L-packet window exists before it.
# NOTE(review): assumes column 0 is non-decreasing so every later flow also qualifies.
_eligible_flows = np.flatnonzero(flows_indexes[:, 0] > SEQ_L)
if _eligible_flows.size == 0:
    raise ValueError(
        f"no flow starts after packet position {SEQ_L}; cannot build a full input window"
    )
start = _eligible_flows[0]

In [7]:
class TraceDataset(torch.utils.data.IterableDataset):
    """Stream (packet window, elephant-flow label) pairs from a packet trace.

    For every flow index ``idx`` in ``[start, len(flows_indexes) - 1)`` this
    yields:

    * ``x``: the ``seq_len`` entries of ``packet_trace`` immediately before
      position ``flows_indexes[idx, 0]``, as float32 with a leading channel
      axis, i.e. shape ``(1, seq_len)``;
    * ``y``: ``1.0`` if ``flow_sizes[idx + 1]`` exceeds ``EM_THRESHOLD``,
      otherwise ``0.0``.

    Samples come out in index order. There is no per-worker sharding, so use
    the DataLoader with ``num_workers=0`` (its default) or each worker will
    replay the whole stream.
    """

    # Flow-size cut-off above which a flow is labelled 1.0.
    EM_THRESHOLD = 195810
    # NOTE(review): not referenced anywhere in this class.
    END = 500000

    def __init__(self, packet_trace, flows_indexes, flow_sizes, start, seq_len=None):
        """
        Args:
            packet_trace: 1-D sequence of per-packet values (``trace[:, 1]`` here).
            flows_indexes: 2-D array; column 0 is read as the packet position of
                each flow and used as the (exclusive) end of its input window.
            flow_sizes: per-flow sizes compared against ``EM_THRESHOLD``.
            start: first flow index to emit.
            seq_len: window length; ``None`` uses the notebook-level ``SEQ_L``.
        """
        self.packet_trace = packet_trace
        self.flows_indexes = flows_indexes
        self.flow_sizes = flow_sizes
        self.start = start
        # Resolved here (not as a default argument) so the global is only needed
        # when the caller does not pass seq_len explicitly.
        self.seq_len = SEQ_L if seq_len is None else seq_len

    def __iter__(self):
        # The label reads flow_sizes[idx + 1], so the last usable idx is len - 2.
        for idx in range(self.start, len(self.flows_indexes) - 1):
            target = self.flows_indexes[idx, 0]
            if target < self.seq_len:
                # A shorter (or negatively-indexed) slice would break batching later.
                raise ValueError(
                    f"flow {idx} starts at packet {target}, before a full "
                    f"{self.seq_len}-packet window is available"
                )
            x = self.packet_trace[target - self.seq_len:target]
            # NOTE(review): the window precedes flow idx but the label is taken from
            # flow idx + 1 — confirm this look-ahead is intended.
            y = (self.flow_sizes[idx + 1] > self.EM_THRESHOLD) * 1.
            # Prepend a channel axis: (seq_len,) -> (1, seq_len) for Conv1d.
            yield x[None, :].astype(np.float32), y

    def __len__(self):
        """Number of samples __iter__ yields (never negative)."""
        return max(len(self.flows_indexes) - 1 - self.start, 0)

In [8]:
class CNNModel(nn.Module):
    """Three-stage 1-D CNN that outputs one probability per input window.

    Each stage is Conv1d(kernel=3, no padding) -> ReLU -> max-pool(2) ->
    BatchNorm1d. The result is flattened into a single-output Linear layer
    followed by a sigmoid.

    Args:
        filters: output channels of the three conv stages.
        seq_len: input window length, used to size the Linear layer. The
            default of 30 reproduces the original ``filters[2] * 2`` input size.

    Input shape is ``(batch, 1, seq_len)``; output shape is ``(batch, 1)``.

    Raises:
        ValueError: if ``seq_len`` is too short to survive three stages.
    """

    KERNEL_SIZE = 3
    POOL_SIZE = 2

    def __init__(self, filters=(100, 50, 25), seq_len=30):
        super().__init__()
        self.filters = filters
        self.seq_len = seq_len
        self.out_len = self._reduced_length(seq_len)
        if self.out_len < 1:
            raise ValueError(
                f"seq_len={seq_len} is too short for three conv/pool stages"
            )

        self.conv1 = nn.Conv1d(1, filters[0], self.KERNEL_SIZE)
        self.bn1 = nn.BatchNorm1d(filters[0])
        self.conv2 = nn.Conv1d(filters[0], filters[1], self.KERNEL_SIZE)
        self.bn2 = nn.BatchNorm1d(filters[1])
        self.conv3 = nn.Conv1d(filters[1], filters[2], self.KERNEL_SIZE)
        self.bn3 = nn.BatchNorm1d(filters[2])
        self.linear = nn.Linear(filters[2] * self.out_len, 1)

    @classmethod
    def _reduced_length(cls, length):
        """Sequence length after three (conv k=3, no padding) + max-pool(2) stages."""
        for _ in range(3):
            length = (length - (cls.KERNEL_SIZE - 1)) // cls.POOL_SIZE
        return length

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool1d(x, self.POOL_SIZE)
        x = self.bn1(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool1d(x, self.POOL_SIZE)
        x = self.bn2(x)

        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool1d(x, self.POOL_SIZE)
        x = self.bn3(x)

        # flatten keeps the batch axis: a wrong input length now fails in the
        # Linear layer instead of being silently folded into extra "samples".
        x = torch.flatten(x, 1)

        x = self.linear(x)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(x)


In [9]:
# Training schedule: full passes over the stream, and samples per optimiser step.
N_EPOCHS, BATCH_SIZE = 10, 64

In [10]:
# No shuffle argument is passed, so batches arrive in the order TraceDataset yields them.
train_dataset = TraceDataset(packets, flows_indexes, flow_sizes, start)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
)

In [11]:
# Seed torch before building the model so weight initialisation is reproducible.
SEED = 0
LEARNING_RATE = 0.005
MOMENTUM = 0.5

torch.manual_seed(SEED)

# BCELoss expects probabilities, matching CNNModel's sigmoid output.
criterion = nn.BCELoss()
model = CNNModel()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

In [12]:
# Use the first CUDA GPU when one is visible; otherwise everything stays on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    model = model.cuda()
    criterion = criterion.cuda()

In [13]:

def get_loss_and_correct(model, batch, criterion, device):
    """Run one forward pass on ``batch`` and score it.

    Args:
        model: callable mapping a float tensor batch to probabilities shaped
            ``(B, 1)`` or ``(B,)``.
        batch: ``(data, target)`` pair as produced by the DataLoader.
        criterion: loss taking ``(probabilities, targets)`` of equal shape,
            e.g. ``nn.BCELoss``.
        device: device to move ``data`` and ``target`` to.

    Returns:
        ``(loss, true_num)``: the criterion's loss tensor (still attached to
        the autograd graph) and a 0-d tensor counting predictions that,
        rounded to 0/1, equal the target.
    """
    data, target = batch
    data = data.to(device, dtype=torch.float)
    target = target.to(device, dtype=torch.float).reshape(-1)
    # reshape(-1) rather than squeeze(): squeeze() turns the (1, 1) output of a
    # size-1 final batch into a 0-d tensor that no longer matches target's (1,).
    output = model(data).reshape(-1)

    loss = criterion(output, target)

    # Rounding thresholds the probability at 0.5; no gradient needed for the count.
    with torch.no_grad():
        pred = torch.round(output)
        true_num = pred.eq(target).sum()

    return loss, true_num

def step(loss, optimizer):
    """Perform one optimisation update driven by ``loss``.

    Clears stale gradients, backpropagates ``loss``, then lets ``optimizer``
    apply the update.
    """
    # Gradients accumulate by default, so clear them before the new backward pass.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [14]:
train_losses = []
train_accuracies = []

pbar = tqdm(range(N_EPOCHS))

for epoch in pbar:
    total_train_loss = 0.0
    total_train_correct = 0.0
    # Count the samples actually seen instead of trusting len(train_dataset):
    # the final batch may be partial, and the dataset's __len__ is not
    # guaranteed to match what its iterator yields.
    n_train_samples = 0

    model.train()

    for batch in tqdm(train_dataloader, leave=False):
        loss, correct = get_loss_and_correct(model, batch, criterion, device)
        step(loss, optimizer)
        n_batch = len(batch[1])
        # BCELoss() defaults to reduction='mean', so weight by batch size to
        # turn the epoch total into a true per-sample average below.
        total_train_loss += loss.item() * n_batch
        total_train_correct += correct.item()
        n_train_samples += n_batch

    if n_train_samples == 0:
        raise RuntimeError("train_dataloader yielded no samples")

    mean_train_loss = total_train_loss / n_train_samples
    train_accuracy = total_train_correct / n_train_samples

    train_losses.append(mean_train_loss)
    train_accuracies.append(train_accuracy)

    pbar.set_postfix({'train_loss': mean_train_loss, 'train_accuracy': train_accuracy})

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/31294 [00:00<?, ?it/s]



  0%|          | 0/31294 [00:00<?, ?it/s]

  0%|          | 0/31294 [00:00<?, ?it/s]

  0%|          | 0/31294 [00:00<?, ?it/s]

  0%|          | 0/31294 [00:00<?, ?it/s]

  0%|          | 0/31294 [00:00<?, ?it/s]

  0%|          | 0/31294 [00:00<?, ?it/s]

  0%|          | 0/31294 [00:00<?, ?it/s]

  0%|          | 0/31294 [00:00<?, ?it/s]

  0%|          | 0/31294 [00:00<?, ?it/s]