In [6]:
from tqdm import tqdm
from typing import List, Dict
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from flowprintOptimal.sekigo.core.flowRepresentation import FlowRepresentation,PacketFlowRepressentation
from flowprintOptimal.sekigo.dataAnalysis.vNATDataFrameProcessor import VNATDataFrameProcessor
from flowprintOptimal.sekigo.core.flowConfig import FlowConfig
import random
from flowprintOptimal.sekigo.flowUtils.flowDatasets import PacketFlowDataset
from torch.utils.data import Dataset,DataLoader
from torchsampler import ImbalancedDatasetSampler
from sklearn.model_selection import train_test_split
from flowprintOptimal.sekigo.flowUtils.commons import normalizePacketRep
import os
from joblib import Parallel, delayed
from flowprintOptimal.sekigo.flowUtils.commons import saveFlows,loadFlows
from flowprintOptimal.sekigo.dataAnalysis.dataFrameProcessor import UTMobileNetProcessor
from flowprintOptimal.sekigo.flowUtils.dataGetter import getTrainTestOOD
from sklearn.metrics import confusion_matrix
import json
from flowprintOptimal.sekigo.modeling.trainers import NNClassificationTrainer
from flowprintOptimal.sekigo.modeling.neuralNetworks import LSTMNetwork,TransformerGenerator,CNNNetwork1D
from flowprintOptimal.sekigo.modeling.loggers import Logger
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
from flowprintOptimal.sekigo.earlyClassification.DQL.core import MemoryElement,Rewarder,State
from flowprintOptimal.sekigo.earlyClassification.DQL.memoryFiller import MemoryFillerV2
from flowprintOptimal.sekigo.earlyClassification.DQL.datasets import MemoryDataset
from flowprintOptimal.sekigo.earlyClassification.DQL.trainers import EarlyClassificationtrainer
from flowprintOptimal.sekigo.utils.documentor import Documenter
from flowprintOptimal.sekigo.utils.evaluations import Evaluator,EarlyEvaluation,EarlyEvaluationV2
import warnings
from flowprintOptimal.sekigo.flowUtils.commons import dropPacketFromPacketRep
from copy import deepcopy
from flowprintOptimal.sekigo.flowUtils.flowDatasets import BaseFlowDataset
warnings.filterwarnings('ignore')
from focal_loss.focal_loss import FocalLoss

In [7]:
class Decider(nn.Module):
    """Small MLP head that scores early-classification actions from predictor features.

    Architecture: ``in_dim -> in_dim//2 -> in_dim//4 -> num_actions`` with ReLU between
    the linear layers. In this notebook it receives the predictor's feature vector
    (size ``lstm_hidden_size``) and its outputs are used as per-action Q-values by the
    trainer, where action index 1 means "predict now".

    Args:
        in_dim: size of the input feature vector. Must be >= 4 so that the narrowest
            hidden layer (``in_dim // 4``) is non-empty.
        num_actions: number of output scores (default 2, which is what the trainer uses).

    Raises:
        ValueError: if ``in_dim < 4`` or ``num_actions < 1``.
    """

    def __init__(self, in_dim: int, num_actions: int = 2) -> None:
        super().__init__()
        if in_dim < 4:
            # in_dim // 4 would be 0 -> a zero-width hidden layer and a constant output.
            raise ValueError(f"in_dim must be >= 4 so every hidden layer is non-empty, got {in_dim}")
        if num_actions < 1:
            raise ValueError(f"num_actions must be >= 1, got {num_actions}")
        # Attribute name `linear` is kept so existing state_dict keys (linear.0.weight, ...) still load.
        self.linear = nn.Sequential(
            nn.Linear(in_dim, in_dim // 2), nn.ReLU(),
            nn.Linear(in_dim // 2, in_dim // 4), nn.ReLU(),
            nn.Linear(in_dim // 4, num_actions),
        )

    def forward(self, X):
        """Map features ``X`` (last dimension == ``in_dim``) to ``num_actions`` scores along the last dimension."""
        return self.linear(X)

In [3]:
# Experiment configuration consumed by the cells below. The "memory_fillter_config"
# spelling is intentional here: later cells read the section under exactly that key.
# NOTE(review): `name` mentions VNAT and `description` says "no ood samples generated",
# yet data_config uses dataset_name="unibs" and memory_fillter_config sets ood_prob=.2 —
# confirm these labels before using them to tag saved results.
configs = {
    "name": "VNAT_no_sample_no_ood_samples",
    "description": "VNAT with OOD detection (no balancers used), no ood samples generated",
    "common_config": {"max_length": 15},
    "full_model_kwargs": {"lstm_hidden_size": 256, "layers": 2, "lstm_input_size": 3},
    "early_model_kwargs": {"lstm_input_size": 3, "lstm_hidden_size": 256, "layers": 2},
    "data_config": {
        "dataset_name": "unibs",
        "subsampleConfig": None,  # alternative tried earlier: {"max_gap": 20, "min_gap": 5}
        "max_flow_length": 80,  # seconds; no flow sample may exceed this duration
        "test_size": .2,
        "ood_classes": ["Skype"],
        "do_balance": False,
    },
    "rewarder_config": {"l": .5},
    "dataset_config": {"aug": [0, .2]},
    "memory_fillter_config": {
        "ood_config": {"ood_aug": [.6, .9], "ood_prob": .2},
        "min_length": 5,
        "use_balancer": False,
    },
    "full_trainer_config": {"use_sampler": False},
    "early_trainer_config": {"use_sampler": False},  # sampler would give more weight to wait samples
}

In [5]:
# Split the dataset into train / test / OOD flow lists; each flow is capped at the shared
# packet limit (common_config.max_length).
train_flows, test_flows, ood_flows = getTrainTestOOD(
    packet_limit=configs["common_config"]["max_length"],
    **configs["data_config"],
)

full class distrubation
BROWSERS    42855
P2P         21514
OTHER        5218
MAIL         4521
Skype        1280
Name: count, dtype: int64
using no sampling
filtering max_flow_length = 80
balancing
keep_number = 1758
post num packet filter class distrubation
BROWSERS    1746
MAIL        1717
OTHER       1711
P2P         1682
Skype        894
Name: count, dtype: int64
------------------------------
train class distrubation
BROWSERS    1386
MAIL        1380
OTHER       1370
P2P         1348
Name: count, dtype: int64
test class distrubation
BROWSERS    360
OTHER       341
MAIL        337
P2P         334
Name: count, dtype: int64


In [8]:
# Only the training dataset receives the `aug` setting; the test dataset reuses the
# training label mapping so class indices agree between the two splits.
train_dataset = PacketFlowDataset(flows=train_flows, label_to_index=None, aug=configs["dataset_config"]["aug"])
test_dataset = PacketFlowDataset(flows=test_flows, label_to_index=train_dataset.label_to_index)
# OOD flows are optional. `is not None` is the correct identity check: `!= None` is
# evaluated elementwise (and becomes ambiguous in a boolean context) for array-likes.
ood_dataset = (
    PacketFlowDataset(flows=ood_flows, label_to_index=None)
    if ood_flows is not None and len(ood_flows) > 0
    else None
)

In [9]:
# Propagate dataset-derived sizes into the config sections used by the model and
# rewarder constructors further down.
num_labels = len(train_dataset.label_to_index)
configs["full_model_kwargs"].update(output_dim=num_labels)
configs["early_model_kwargs"].update(output_dim=num_labels)
configs["rewarder_config"].update(
    num_labels=num_labels,
    max_length=configs["common_config"]["max_length"],
)

In [10]:
# Build the rewarder, then the memory filler that turns `train_dataset` into replay-memory
# entries using the rewarder and the configured min/max lengths.
rewarder = Rewarder(**configs["rewarder_config"])
memory_filler = MemoryFillerV2(
    dataset=train_dataset,
    rewarder=rewarder,
    min_length=configs["memory_fillter_config"]["min_length"],
    max_length=rewarder.max_length,
    ood_config=configs["memory_fillter_config"]["ood_config"],
    use_balancer=configs["memory_fillter_config"]["use_balancer"],
)

In [11]:
# Materialise the replay memory and report how many entries it holds.
memory = memory_filler.processDataset()
print(f"{len(memory)}")

241296


In [12]:
class EarlyClassificationTrainerV2:
    """Jointly trains a flow classifier (``predictor``) and a double-DQN ``decider``.

    Every optimisation step updates both networks on the same replay-memory batch:

    * Q-loss: the decider's value for the taken action is regressed (MSE) onto the
      double-DQN target ``r + lam * Q_lag(s', argmax_a Q(s', a))``. The online decider
      picks the next action and ``lag_decider`` (a copy refreshed every
      ``model_replacement_steps`` steps) scores it. The bootstrap term is zeroed for
      terminal transitions.
    * Classification loss: per-sample cross-entropy of the predictor, weighted by
      ``state_length / memory_dataset.max_length`` and masked out for OOD rows
      (``label == -1``).

    Action ``1`` means "predict now"; any other action is treated as "wait". The reward
    stored in the memory batch is NOT used. It is recomputed online from whether the
    current predictor's argmax matches the label: ``CORRECT_REWARD`` / ``INCORRECT_REWARD``
    when predicting, ``WAIT_REWARD`` otherwise. An OOD row can never match, so predicting
    on it always earns ``INCORRECT_REWARD``.

    Args:
        decider: module mapping predictor features to one value per action.
        predictor: module returning ``(logits, features)`` for a batch of states.
        train_dataset, test_dataset: datasets scored by ``evalTrain`` / ``evalTest``.
        memory_dataset: replay memory; must expose ``collateFn``, ``min_length`` and
            ``max_length``.
        hint_loss_alpha, hint_loss_gap: stored only; no hint loss is computed here.
        q_loss_alpha: weight of the Q-loss in the total loss.
        ood_dataset: optional OOD dataset scored by ``evalOOD`` (may be ``None``).
        logger: metric sink.
        model_replacement_steps: refresh period (in training steps) of ``lag_decider``.
        device: torch device both networks are moved to.
    """

    # Online reward scheme applied in ``_onlineReward``.
    CORRECT_REWARD = 1.0
    INCORRECT_REWARD = -1.0
    WAIT_REWARD = -0.1

    def __init__(self, decider, predictor: LSTMNetwork, train_dataset: BaseFlowDataset, memory_dataset: MemoryDataset,
                 hint_loss_alpha: float, q_loss_alpha: float, hint_loss_gap: float, test_dataset: BaseFlowDataset,
                 ood_dataset: BaseFlowDataset, logger: Logger, model_replacement_steps: int, device: str):
        self.device = device
        self.predictor = predictor.to(device)
        self.decider = decider.to(self.device)
        # Target ("lag") copy of the decider used to score next-state actions.
        self.lag_decider = deepcopy(decider).to(device)
        self.lag_decider.eval()

        self.hint_loss_alpha = hint_loss_alpha  # stored only; no hint loss is computed
        self.q_loss_alpha = q_loss_alpha
        self.hint_loss_gap = hint_loss_gap  # stored only; no hint loss is computed

        self.train_dataset = train_dataset
        self.memory_dataset = memory_dataset
        self.test_dataset = test_dataset
        self.ood_dataset = ood_dataset
        self.logger = logger

        # Best snapshots (by test macro-F1), updated in ``evalTest``.
        self.best = dict(
            score=0,
            predictor=deepcopy(self.predictor),
            decider=deepcopy(self.decider),
        )

        self.evaluator = EarlyEvaluationV2(min_steps=memory_dataset.min_length, device=device,
                                           model=self.predictor, decider=self.decider)
        self.mse_loss_function = nn.MSELoss(reduction="none")
        self.model_replacement_steps = model_replacement_steps

        for metric_name in ("test_eval_f1", "train_eval_f1", "train_eval_time", "test_eval_time",
                            "ood_eval", "ood_eval_time", "incorrect_ood_test", "incorrect_ood_train"):
            self.logger.setMetricReportSteps(metric_name=metric_name, step_size=1)

        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
        # Not used by trainStep; kept for experimenting with a classification-style decider loss.
        self.focal_loss = FocalLoss(gamma=.7, reduction="none")
        self.softmax = nn.Softmax(dim=-1)

    def __refreshLagModel(self):
        """Replace the lag decider with a fresh eval-mode copy of the online decider."""
        self.lag_decider = deepcopy(self.decider)
        self.lag_decider.eval()

    def _setTrainMode(self):
        """Put the trainable networks in training mode (``lag_decider`` stays in eval mode)."""
        self.predictor.train()
        self.decider.train()

    def eval(self, dataset: BaseFlowDataset):
        """Return ``(macro_f1, time, incorrect_ood)`` of the current networks on ``dataset``."""
        metrices = self.evaluator.getMetrices(dataset=dataset, ood_dataset=None)
        return metrices["macro_f1"], metrices["time"], metrices["incorrect_ood"]

    def evalTrain(self):
        """Evaluate on the training dataset and log the metrics."""
        f1, average_time, incorrect_ood = self.eval(dataset=self.train_dataset)
        self.logger.addMetric(metric_name="train_eval_f1", value=f1)
        self.logger.addMetric(metric_name="train_eval_time", value=average_time)
        self.logger.addMetric(metric_name="incorrect_ood_train", value=incorrect_ood)

    def evalTest(self):
        """Evaluate on the test dataset, snapshot both networks if F1 did not drop, and log."""
        f1, average_time, incorrect_ood = self.eval(dataset=self.test_dataset)

        if f1 >= self.best["score"]:
            self.best["score"] = f1
            self.best["predictor"] = deepcopy(self.predictor)
            self.best["decider"] = deepcopy(self.decider)

        self.logger.addMetric(metric_name="test_eval_f1", value=f1)
        self.logger.addMetric(metric_name="test_eval_time", value=average_time)
        self.logger.addMetric(metric_name="incorrect_ood_test", value=incorrect_ood)

    def evalOOD(self):
        """Evaluate on the OOD dataset and log its accuracy and timing."""
        metrices = self.evaluator.getMetrices(ood_dataset=self.ood_dataset, dataset=None)
        self.logger.addMetric(metric_name="ood_eval", value=metrices["ood_accuracy"])
        self.logger.addMetric(metric_name="ood_eval_time", value=metrices["ood_time"])

    def _onlineReward(self, state_logits, label, action):
        """Recompute per-transition rewards from the current predictor's correctness.

        Returns a fresh float tensor shaped like ``action``. The batch tensors are not
        mutated, and the values are not truncated even if the stored reward had an int dtype.
        """
        correct = torch.argmax(state_logits, dim=-1) == label  # never true for OOD rows (label == -1)
        do_predict = action == 1
        reward = torch.full(action.shape, self.WAIT_REWARD, device=action.device)
        reward[do_predict & correct] = self.CORRECT_REWARD
        reward[do_predict & ~correct] = self.INCORRECT_REWARD
        return reward

    @torch.no_grad()
    def _qTarget(self, next_state, is_terminal, state_logits, label, action, lam):
        """Double-DQN target ``r + lam * Q_lag(s', argmax_a Q(s', a))``, shape (BS,)."""
        # A single predictor pass on next_state feeds both the online and the lag decider.
        next_features = self.predictor(next_state)[1]
        next_actions = torch.argmax(self.decider(next_features), dim=-1, keepdim=True)  # (BS, 1)
        next_values = torch.gather(input=self.lag_decider(next_features), dim=1, index=next_actions)  # (BS, 1)
        # No bootstrapping past terminal transitions.
        next_values = next_values * torch.logical_not(is_terminal).unsqueeze(-1)
        reward = self._onlineReward(state_logits, label, action)
        return reward + lam * next_values.squeeze(-1)

    def _classificationLoss(self, state_logits, label, state_length):
        """Length-weighted cross-entropy, with OOD rows (``label == -1``) masked out."""
        is_ood = label == -1
        ood_mask = (~is_ood).float()
        # OOD rows get a dummy in-range class; their loss is zeroed by ood_mask anyway.
        safe_label = label.masked_fill(is_ood, 0)
        per_sample = self.cross_entropy_loss(state_logits, safe_label)
        length_weight = state_length / self.memory_dataset.max_length
        return (per_sample * length_weight * ood_mask).mean()

    def trainStep(self, steps, batch: dict, lam: float, predictor_optimizer, decider_optimizer):
        """Run one optimisation step on both networks.

        Args:
            steps: global step counter (before increment); used for lag-model refreshes.
            batch: collated memory batch with keys ``state``, ``next_state``, ``action``,
                ``is_terminal``, ``label`` and ``state_length`` (``reward`` is present but
                unused; the reward is recomputed online).
            lam: discount factor of the Q-target.
            predictor_optimizer, decider_optimizer: optimisers stepped after the backward pass.
        """
        state = batch["state"].to(self.device)
        next_state = batch["next_state"].to(self.device)
        action = batch["action"].to(self.device)
        is_terminal = batch["is_terminal"].to(self.device)
        label = batch["label"].to(self.device)
        state_length = batch["state_length"].to(self.device)

        predictor_state_output, predictor_state_features = self.predictor(state)
        predicted_values = self.decider(predictor_state_features)  # (BS, num_actions)

        target = self._qTarget(next_state, is_terminal, predictor_state_output, label, action, lam)  # (BS,)
        predicted_values_for_taken_action = torch.gather(
            input=predicted_values, dim=1, index=action.unsqueeze(-1)).squeeze(-1)  # (BS,)
        q_loss = self.mse_loss_function(predicted_values_for_taken_action, target).mean()

        cross_entropy_loss = self._classificationLoss(predictor_state_output, label, state_length)
        loss = self.q_loss_alpha * q_loss + cross_entropy_loss

        self.logger.addMetric(metric_name="q_loss", value=q_loss.item())
        self.logger.addMetric(metric_name="cross_entropy_loss", value=cross_entropy_loss.item())

        decider_optimizer.zero_grad()
        predictor_optimizer.zero_grad()
        loss.backward()
        predictor_optimizer.step()
        decider_optimizer.step()

        if steps % self.model_replacement_steps == 0:
            self.hint_memory = []  # vestigial: not read anywhere in this class
            self.__refreshLagModel()

    def train(self, epochs: int, batch_size=64, lr=.001, lam=.99, eval_steps: int = 1000, train_eval_steps: int = 2000):
        """Train predictor and decider on the replay memory.

        Can't stress enough how important ``shuffle=True`` is in the DataLoader. Without it,
        batches follow the memory's storage order (presumably consecutive transitions of
        the same flows), which gives highly correlated updates.

        Args:
            epochs: number of passes over ``memory_dataset``.
            batch_size: transitions per batch; the final incomplete batch is dropped.
            lr: Adam learning rate for both optimisers.
            lam: discount factor of the Q-target.
            eval_steps: run ``evalTest`` (and ``evalOOD`` when an OOD dataset exists) every
                this many steps.
            train_eval_steps: run ``evalTrain`` every this many steps.
        """
        # TODO add batch_sampler
        train_dataloader = DataLoader(dataset=self.memory_dataset, collate_fn=self.memory_dataset.collateFn,
                                      batch_size=batch_size, drop_last=True, shuffle=True)
        predictor_optimizer = torch.optim.Adam(params=self.predictor.parameters(), lr=lr)
        decider_optimizer = torch.optim.Adam(params=self.decider.parameters(), lr=lr)
        self._setTrainMode()
        steps = 0

        for _ in range(epochs):
            for batch in train_dataloader:
                self.trainStep(steps=steps, batch=batch, lam=lam,
                               predictor_optimizer=predictor_optimizer, decider_optimizer=decider_optimizer)
                steps += 1

                evaluated = False
                if steps % eval_steps == 0:
                    self.evalTest()
                    if self.ood_dataset is not None:
                        self.evalOOD()
                    evaluated = True
                if steps % train_eval_steps == 0:
                    self.evalTrain()
                    evaluated = True
                if evaluated:
                    # The evaluator shares these modules and may have switched them to eval mode.
                    self._setTrainMode()

In [13]:
# Wrap the replay memory for batching, then build the early predictor, the logger and
# the decider head (sized to the predictor's LSTM hidden size). Construction order is
# kept as-is so random weight initialisation draws happen in the same sequence.
memory_dataset = MemoryDataset(
    memories=memory,
    num_classes=len(train_dataset.label_to_index),
    min_length=memory_filler.min_length,
    max_length=memory_filler.max_length,
)
predictor = LSTMNetwork(**configs["early_model_kwargs"])
logger = Logger(verbose=True)
logger.default_step_size = 1000
decider = Decider(in_dim=configs["early_model_kwargs"]["lstm_hidden_size"])

In [14]:
# The trainer stores hint_loss_alpha / hint_loss_gap but computes no hint loss from them.
# The lag decider is refreshed every 500 training steps.
ddq_model = EarlyClassificationTrainerV2(
    decider=decider,
    predictor=predictor,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    memory_dataset=memory_dataset,
    hint_loss_gap=.05,
    ood_dataset=ood_dataset,
    hint_loss_alpha=0,
    q_loss_alpha=1,
    logger=logger,
    device=device,
    model_replacement_steps=500,
)

In [15]:
# Train on the replay memory (the recorded run below was interrupted manually).
ddq_model.train(
    epochs=20,
    batch_size=128,
    lr=3e-4,
    lam=.99,
)

 ---- 1000 metric q_loss = 0.41060874742269515
 ---- 1000 metric cross_entropy_loss = 0.6940028650164605
 ---- 1 metric test_eval_f1 = 0.42347371375009146
 ---- 1 metric test_eval_time = 6.330903790087463
 ---- 1 metric incorrect_ood_test = 0.021137026239067054
 ---- 1 metric ood_eval = 0.0
 ---- 1 metric ood_eval_time = 5.10178970917226
 ---- 2000 metric q_loss = 0.39067640374600887
 ---- 2000 metric cross_entropy_loss = 0.607127954185009
 ---- 2 metric test_eval_f1 = 0.5721010407426271
 ---- 2 metric test_eval_time = 7.004373177842566
 ---- 2 metric incorrect_ood_test = 0.0036443148688046646
 ---- 2 metric ood_eval = 0.0
 ---- 2 metric ood_eval_time = 10.980984340044742
 ---- 1 metric train_eval_f1 = 0.5342392939869054
 ---- 1 metric train_eval_time = 7.435083880379286
 ---- 1 metric incorrect_ood_train = 0.027169948942377828
 ---- 3000 metric q_loss = 0.38608128279447557
 ---- 3000 metric cross_entropy_loss = 0.5768830538988113
 ---- 3 metric test_eval_f1 = 0.7144037764397771
 ---- 

KeyboardInterrupt: 

In [None]:
# Evaluate the live predictor/decider objects (the state training reached, not the
# ddq_model.best snapshots). min_steps comes from the memory dataset, exactly as the
# trainer's own evaluator does, instead of a hard-coded copy of the configured min_length.
evaluator = EarlyEvaluationV2(min_steps=memory_dataset.min_length, device=device, model=predictor, decider=decider)

In [None]:
# Score the evaluated networks on the held-out test split.
# NOTE(review): the recorded output below shows a 5x5 confusion matrix with thousands of
# samples per row, while the test split printed earlier has 4 classes and ~1.4k flows.
# It appears to be stale output from a different run; re-execute before trusting it.
evaluator.getMetrices(dataset= test_dataset)

{'micro_f1': 0.7592301355500637,
 'macro_f1': 0.7644449466867931,
 'accuracy': 0.7592301355500637,
 'cm': array([[1299,   39,   65,  869,  276],
        [  33, 3005,   25,   36,   27],
        [  62,   22, 1320,   79,   25],
        [ 415,   34,   46, 2172,  493],
        [ 150,   22,   15,  482, 2342]]),
 'incorrect_ood': 0.0,
 'time': 5.061184752490077}