# Global settings

In [None]:
# References:
# This source code file refers to:
# https://github.com/ICL-ml4csec/VulBERTa
# https://towardsdatascience.com/text-classification-with-bert-in-pytorch-887965e5820f
# https://huggingface.co/docs/transformers/model_doc/roberta
# https://colab.research.google.com/github/dpressel/dlss-tutorial/blob/master/1_pretrained_vectors.ipynb


In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import random
import torch
import numpy as np
import shutil

def write_to_file(text, path, mode='a'): # 'a': append; 'w': overwrite
    with open(path, mode) as f:
        f.write(text)

def mkdir_if_not_exist(directory):
    if not directory: return
    if not os.path.exists(directory):
        os.mkdir(directory)

def remove_file_if_exist(path):
    if not path: return
    if os.path.exists(path):
        try:
            os.remove(path)
        except:
            shutil.rmtree(path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('using', device)

# The following randomization refers to: https://github.com/ICL-ml4csec/VulBERTa/blob/main/Finetuning_VulBERTa-MLP.ipynb
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
os.environ['WANDB_DISABLED'] = 'true'
os.environ['WANDB_MODE'] = 'dryrun'

# -------------------------------------- start

DATASET_NAME = 'qemu'
# DATASET_MASKING = 'masked_'
DATASET_MASKING = ''

codeTF_check_point = 'checkpoint-6681'
msgTF_check_point = 'roberta_large_qemu_0.976_ep12.pt'

# -------------------------------------- end

root_directory = '/root/autodl-tmp'
dataset_directory = f'{root_directory}/output_dataset_1/{DATASET_MASKING}{DATASET_NAME}'
init_train_path = f'{dataset_directory}/train.json'
init_val_path = f'{dataset_directory}/val.json'
init_test_path = f'{dataset_directory}/test.json'
intermediate_directory = f'{root_directory}/intermediate/{DATASET_MASKING}{DATASET_NAME}'
mkdir_if_not_exist(f'{root_directory}/intermediate')
mkdir_if_not_exist(intermediate_directory)

finetuned_ct_model_path = f'{root_directory}/codeTF_check_point/{DATASET_MASKING}{DATASET_NAME}/{codeTF_check_point}'
intermediate_ct_train_path = f'{intermediate_directory}/ct_train.txt'
intermediate_ct_val_path = f'{intermediate_directory}/ct_val.txt'
intermediate_ct_test_path = f'{intermediate_directory}/ct_test.txt'

finetuned_mt_model_path = f'{root_directory}/msgTF_check_point/{DATASET_MASKING}{DATASET_NAME}/{msgTF_check_point}'
intermediate_mt_train_path = f'{intermediate_directory}/mt_train.txt'
intermediate_mt_val_path = f'{intermediate_directory}/mt_val.txt'
intermediate_mt_test_path = f'{intermediate_directory}/mt_test.txt'


using cuda


# CodeTransformer

In [2]:
from tqdm import tqdm
import sys
import pandas as pd
import numpy as np
import csv
import pickle
import re
import torch
import sklearn
import random
import clang
from clang import *
from clang import cindex
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
from tokenizers.implementations import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from torch.utils.data import Dataset, DataLoader, IterableDataset
from transformers import RobertaConfig
from transformers import RobertaForMaskedLM, RobertaForSequenceClassification
from transformers import RobertaTokenizerFast
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import LineByLineTextDataset
from transformers.modeling_outputs import SequenceClassifierOutput
from tokenizers.pre_tokenizers import PreTokenizer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers import NormalizedString,PreTokenizedString
from typing import List
from tokenizers import Tokenizer
from tokenizers import normalizers,decoders
from tokenizers.normalizers import StripAccents, unicode_normalizer_from_str, Replace
from tokenizers.processors import TemplateProcessing
from tokenizers import processors,pre_tokenizers
from tokenizers.models import BPE

# definitions
class MyTokenizer:
    cidx = cindex.Index.create()

    def clang_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        ## Tokkenize using clang
        tok = []
        tu = self.cidx.parse('tmp.c',
                       args=[''],  
                       unsaved_files=[('tmp.c', str(normalized_string.original))],  
                       options=0)
        for t in tu.get_tokens(extent=tu.cursor.extent):
            spelling = t.spelling.strip()
            if spelling == '': continue
            ## Keyword no need
            ## Punctuations no need
            ## Literal all to BPE
            #spelling = spelling.replace(' ', '')
            tok.append(NormalizedString(spelling))
        return(tok)

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.clang_split)

def process_encodings(encodings):
    input_ids=[]
    attention_mask=[]
    for enc in encodings:
        input_ids.append(enc.ids)
        attention_mask.append(enc.attention_mask)
    return {'input_ids':input_ids, 'attention_mask':attention_mask}

class MyCustomDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
        assert len(self.encodings['input_ids']) == len(self.encodings['attention_mask']) ==  len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

# ------------------------------------------------------------------------------
# tokenize and load dataset
vocab, merges = BPE.read_file(vocab="./tokenizer/drapgh-vocab.json", merges="./tokenizer/drapgh-merges.txt")
my_tokenizer = Tokenizer(BPE(vocab, merges, unk_token="<unk>"))

my_tokenizer.normalizer = normalizers.Sequence([StripAccents(), Replace(" ", "Ä")])
my_tokenizer.pre_tokenizer = PreTokenizer.custom(MyTokenizer())
my_tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
my_tokenizer.post_processor = TemplateProcessing(
    single="<s> $A </s>",
    special_tokens=[
    ("<s>",0),
    ("<pad>",1),
    ("</s>",2),
    ("<unk>",3),
    ("<mask>",4)
    ]
)

my_tokenizer.enable_truncation(max_length=1024)
my_tokenizer.enable_padding(direction='right', pad_id=1, pad_type_id=0, pad_token='<pad>', length=None, pad_to_multiple_of=None)

m1 = pd.read_json(init_train_path)
m2 = pd.read_json(init_val_path)

train_encodings = my_tokenizer.encode_batch(m1.commit_patch)
train_encodings = process_encodings(train_encodings)

val_encodings = my_tokenizer.encode_batch(m2.commit_patch)
val_encodings = process_encodings(val_encodings)

train_dataset = MyCustomDataset(train_encodings, m1.label.tolist())
val_dataset = MyCustomDataset(val_encodings, m2.label.tolist())

train_loader = DataLoader(train_dataset, batch_size=128)
val_loader = DataLoader(val_dataset, batch_size=128)

# ------------------------------------------------------------------------------
# generate intermediate data by CodeTransformer
model = RobertaForSequenceClassification.from_pretrained(finetuned_ct_model_path)
model.to(device)

def generate_ct_intermediate_dataset(data_loader, intermediate_data_path):
    # model.eval()
    with torch.no_grad():
        for batch in tqdm(data_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            probs = torch.nn.functional.softmax(outputs[0], dim=1).tolist()
            assert(len(probs) == len(labels))
            for i in range(len(probs)):
                prob = probs[i]
                label = int(labels[i])
                content = '\t'.join([str(i) for i in prob + [label]]) + '\n'
                write_to_file(content, intermediate_data_path)

remove_file_if_exist(intermediate_ct_train_path)
remove_file_if_exist(intermediate_ct_val_path)

print('Generating codeTF intermediate dataset:')
generate_ct_intermediate_dataset(train_loader, intermediate_ct_train_path)
del train_loader
generate_ct_intermediate_dataset(val_loader, intermediate_ct_val_path)
del val_loader

# ------------------------------------------------------------------------------
# evaluation
print('\nEvaluation:')
model = RobertaForSequenceClassification.from_pretrained(finetuned_ct_model_path)

test_loader = DataLoader(val_dataset, batch_size=128)

def softmax_accuracy(probs,all_labels):
    def getClass(x):
        return(x.index(max(x)))

    all_labels = all_labels.tolist()
    probs = pd.Series(probs.tolist())
    all_predicted = probs.apply(getClass)
    all_predicted.reset_index(drop=True, inplace=True)
    vc = pd.value_counts(all_predicted == all_labels)
    try:
        acc = vc[1]/len(all_labels)
    except:
        if(vc.index[0]==False):
            acc = 0
        else:
            acc = 1
    return(acc,all_predicted)

model.to(device)

all_pred=[]
all_labels=[]
all_probs=[]
model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        acc_val,pred = softmax_accuracy(torch.nn.functional.softmax(outputs[1],dim=1),labels)
        all_pred += pred.tolist()
        all_labels += labels.tolist()
        all_probs += outputs[1].tolist()

confusion = sklearn.metrics.confusion_matrix(y_true=all_labels, y_pred=all_pred)
print('Confusion matrix: \n',confusion)

tn, fp, fn, tp = confusion.ravel()
print('\nTP:',tp)
print('FP:',fp)
print('TN:',tn)
print('FN:',fn)

probs2=[]
for x in all_probs:
    probs2.append(x[1])

## Performance measure
print('\nAccuracy: '+ str(sklearn.metrics.accuracy_score(y_true=all_labels, y_pred=all_pred)))
print('Precision: '+ str(sklearn.metrics.precision_score(y_true=all_labels, y_pred=all_pred)))
print('Recall: '+ str(sklearn.metrics.recall_score(y_true=all_labels, y_pred=all_pred)))
print('F-measure: '+ str(sklearn.metrics.f1_score(y_true=all_labels, y_pred=all_pred)))
print('Precision-Recall AUC: '+ str(sklearn.metrics.average_precision_score(y_true=all_labels, y_score=probs2)))
print('AUC: '+ str(sklearn.metrics.roc_auc_score(y_true=all_labels, y_score=probs2)))
print('MCC: '+ str(sklearn.metrics.matthews_corrcoef(y_true=all_labels, y_pred=all_pred)))



Generating codeTF intermediate dataset:


100%|██████████| 70/70 [02:07<00:00,  1.81s/it]
100%|██████████| 24/24 [00:42<00:00,  1.78s/it]



Evaluation:


100%|██████████| 24/24 [00:42<00:00,  1.78s/it]

Confusion matrix: 
 [[1267  471]
 [ 564  667]]

TP: 667
FP: 471
TN: 1267
FN: 564

Accuracy: 0.6513977770293028
Precision: 0.5861159929701231
Recall: 0.5418359057676686
F-measure: 0.5631067961165049
Precision-Recall AUC: 0.6056356912068408
AUC: 0.6834732584303275
MCC: 0.2744372153612606





# MsgTransformer

In [3]:
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer
from torch import nn
from transformers import BertModel
from transformers import RobertaModel, RobertaTokenizerFast
from torch.optim import Adam
from tqdm import tqdm
from sklearn import metrics
from torch.nn.parallel import DistributedDataParallel
import os
import random

# definitions
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

BERT_CONFIG = 'roberta-large'
labels = {0:0, 1:1}
BATCH_SIZE = 128
tokenizer = RobertaTokenizerFast.from_pretrained(BERT_CONFIG)

class Dataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.labels = [labels[label] for label in df['label']]
        self.texts = [tokenizer(text, padding='max_length', max_length=512, truncation=True,
                                return_tensors="pt") for text in df['commit_message']]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):
        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_texts, batch_y

class BertClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super(BertClassifier, self).__init__()

        self.bert = RobertaModel.from_pretrained(BERT_CONFIG)
        self.dropout = nn.Dropout(dropout)
        if BERT_CONFIG == 'roberta-large':
            self.linear = nn.Linear(1024, len(labels))
        else:
            self.linear = nn.Linear(768, len(labels))
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        # final_layer = self.relu(linear_output) # IMPO CHANGE
        return linear_output

    def check_parameters(self):
        print('The number of Bert parameters:', self.bert.num_parameters())

import torch.nn.functional as F

class ParallelConv(nn.Module):

    def __init__(self, input_dims, filters, dropout=0.5):
        super().__init__()
        convs = []        
        self.output_dims = sum([t[1] for t in filters])
        for (filter_length, output_dims) in filters:
            pad = filter_length//2
            conv = nn.Sequential(
                nn.Conv1d(input_dims, output_dims, filter_length, padding=pad),
                nn.ReLU()
            )
            convs.append(conv)
        # Add the module so its managed correctly
        self.convs = nn.ModuleList(convs)
        self.conv_drop = nn.Dropout(dropout)

    def forward(self, input_bct):
        mots = []
        for conv in self.convs:
            # In Conv1d, data BxCxT, max over time
            conv_out = conv(input_bct)
            mot, _ = conv_out.max(2)
            mots.append(mot)
        mots = torch.cat(mots, 1)
        return self.conv_drop(mots)

class ConvClassifier(nn.Module):

    def __init__(self, embed_dims,
                 filters=[(2, 100), (3, 100), (4, 100)],
                 dropout=0.5, hidden_units=[]):
        super().__init__()
        self.bert = RobertaModel.from_pretrained(BERT_CONFIG)
        self.dropout = nn.Dropout(dropout)
        self.convs = ParallelConv(embed_dims, filters, dropout)
        
        input_units = self.convs.output_dims
        output_units = self.convs.output_dims
        sequence = []
        for h in hidden_units:
            sequence.append(self.dropout(nn.Linear(input_units, h)))
            input_units = h
            output_units = h
            
        sequence.append(nn.Linear(output_units, 2))
        self.outputs = nn.Sequential(*sequence)

    def forward(self, input_id, mask):
        x, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        embed = self.dropout(x)
        embed = embed.transpose(1, 2).contiguous()
        hidden = self.convs(embed)
        linear = self.outputs(hidden)
        return F.log_softmax(linear, dim=-1)


# ------------------------------------------------------------------------------
# generate intermediate data by MsgTransformer
embed_dim = 1024
model = ConvClassifier(embed_dim)
model.load_state_dict(torch.load(finetuned_mt_model_path))
model.to(device)

def generate_mt_intermediate_dataset(input_data, intermediate_data_path):
    data_loader = torch.utils.data.DataLoader(Dataset(input_data), batch_size=BATCH_SIZE)
    
    model.eval()
    with torch.no_grad():
        for texts, labels in tqdm(data_loader):
            labels = labels.to(device)
            masks = texts['attention_mask'].to(device)
            input_ids = texts['input_ids'].squeeze(1).to(device)
            outputs = model(input_ids, masks)

            probs = torch.nn.functional.softmax(outputs, dim=1).tolist()
            assert(len(probs) == len(labels))
            for i in range(len(probs)):
                prob = probs[i]
                label = int(labels[i])
                content = '\t'.join([str(i) for i in prob + [label]]) + '\n'
                write_to_file(content, intermediate_data_path)

df_train = pd.read_json(init_train_path)
df_val = pd.read_json(init_val_path)

remove_file_if_exist(intermediate_mt_train_path)
remove_file_if_exist(intermediate_mt_val_path)

generate_mt_intermediate_dataset(df_train, intermediate_mt_train_path)
generate_mt_intermediate_dataset(df_val, intermediate_mt_val_path)

# ------------------------------------------------------------------------------
# evaluation
print('\nEvaluation:')

seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def evaluate(model, test_data):
    test = Dataset(test_data)
    test_dataloader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    total_acc_test = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    model.eval()
    with torch.no_grad():
        for test_input, test_label in test_dataloader:
            test_label = test_label.to(device)
            mask = test_input['attention_mask'].to(device)
            input_id = test_input['input_ids'].squeeze(1).to(device)

            output = model(input_id, mask)

            acc = (output.argmax(dim=1) == test_label).sum().item()
            total_acc_test += acc

            test_label = test_label.data.cpu().numpy()
            predic = output.argmax(dim=1).data.cpu().numpy()
            labels_all = np.append(labels_all, test_label)
            predict_all = np.append(predict_all, predic)

    report = metrics.classification_report(labels_all, predict_all, target_names=['benign', 'vulnerable'], digits=4)
    confusion = metrics.confusion_matrix(labels_all, predict_all)
    print(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')
    print(report)
    print(confusion)

embed_dim = 1024
model = ConvClassifier(embed_dim)
model.to(device)
model.load_state_dict(torch.load(finetuned_mt_model_path))
evaluate(model, df_val)



Some weights of the model checkpoint at roberta-large were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 70/70 [02:32<00:00,  2.17s/it]
100%|██████████| 24/24 [00:50<00:00,  2.12s/it]



Evaluation:


Some weights of the model checkpoint at roberta-large were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Test Accuracy:  0.976
              precision    recall  f1-score   support

      benign     0.9749    0.9839    0.9794      1738
  vulnerable     0.9770    0.9643    0.9706      1231

    accuracy                         0.9757      2969
   macro avg     0.9759    0.9741    0.9750      2969
weighted avg     0.9758    0.9757    0.9757      2969

[[1710   28]
 [  44 1187]]


# Combine everything into intermediate dataset

In [4]:
intermediate_train_path = f'{intermediate_directory}/train.txt'
intermediate_val_path = f'{intermediate_directory}/val.txt'

def generate_intermediate_dataset(intermediate_mt_data_path, intermediate_ct_data_path, intermediate_data_path):
    with open(intermediate_mt_data_path) as f:
        mt_data_list = f.read().split('\n')
    
    with open(intermediate_ct_data_path) as f:
        ct_data_list = f.read().split('\n')
    
    mt_data_list = mt_data_list[:-1] if not mt_data_list[-1] else mt_data_list
    ct_data_list = ct_data_list[:-1] if not ct_data_list[-1] else ct_data_list

    assert(len(mt_data_list) == len(ct_data_list))
    
    for i in range(len(mt_data_list)):
        mt_data = mt_data_list[i].split('\t')
        ct_data = ct_data_list[i].split('\t')
        assert(mt_data[2] == ct_data[2])
        label = mt_data[2]
        content = '\t'.join(mt_data[:2] + ct_data[:2] + [label])
        content = content + '\n' if i < len(mt_data_list) - 1 else content
        write_to_file(content, intermediate_data_path)

remove_file_if_exist(intermediate_train_path)
remove_file_if_exist(intermediate_val_path)

generate_intermediate_dataset(intermediate_mt_train_path, intermediate_ct_train_path, intermediate_train_path)
generate_intermediate_dataset(intermediate_mt_val_path, intermediate_ct_val_path, intermediate_val_path)


# Ensemble learning

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn import metrics

EPOCHS = 150
LR = 1e-6
BATCH_SIZE = 4
intermediate_train_path = f'{intermediate_directory}/train.txt'
intermediate_val_path = f'{intermediate_directory}/val.txt'
MODEL_SAVE_PATH = f'{root_directory}/ensemble_model/{DATASET_MASKING}{DATASET_NAME}'
mkdir_if_not_exist(f'{root_directory}/ensemble_model')
remove_file_if_exist(MODEL_SAVE_PATH)
mkdir_if_not_exist(MODEL_SAVE_PATH)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

class MyDataset(Dataset):
    def __init__(self, path):
        with open(path) as f:
            data_list = f.read().split('\n')
        self.labels = [ int(data.split('\t')[-1]) for data in data_list ]
        self.inputs = [ [float(v) for v in data.split('\t')[:-1]] for data in data_list ]
        assert(len(self.labels) == len(self.inputs))

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.inputs[idx]
        y = self.labels[idx]
        return x[0], x[1], x[2], x[3], y

class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()

        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(input_dim, 20)
        self.out = nn.Linear(20, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.out(x)
        return x

def train(model, train_dataset, val_dataset):
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)

    for epoch_num in range(EPOCHS):
        model.train()
        total_acc_train = 0
        total_loss_train = 0
        for x1, x2, x3, x4, y in tqdm(train_dataloader):
            x = torch.transpose(torch.stack([x1, x2, x3, x4]), 0, 1).float().to(device)
            y = y.to(device)
            y_pred = model(x)

            loss = criterion(y_pred, y)
            total_loss_train += loss.item()

            acc = (y_pred.argmax(dim=1) == y).sum().item()
            total_acc_train += acc

            model.zero_grad()
            loss.backward()
            optimizer.step()

        total_acc_val = 0
        total_loss_val = 0
        model.eval()
        with torch.no_grad():
            for x1, x2, x3, x4, y in val_dataloader:
                x = torch.transpose(torch.stack([x1, x2, x3, x4]), 0, 1).float().to(device)
                y = y.to(device)
                y_pred = model(x)

                loss = criterion(y_pred, y)
                total_loss_val += loss.item()

                acc = (y_pred.argmax(dim=1) == y).sum().item()
                total_acc_val += acc

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_dataset): .4f} \
            | Train Accuracy: {total_acc_train / len(train_dataset): .4f} \
            | Val Loss: {total_loss_val / len(val_dataset): .4f} \
            | Val Accuracy: {total_acc_val / len(val_dataset): .4f}')

        val_acc = f'{total_acc_val / len(val_dataset):.4f}'
        torch.save(model.state_dict(), f'{MODEL_SAVE_PATH}/ensemble2_cnn_{val_acc}_epoch{epoch_num + 1}.pt')

train_dataset = MyDataset(intermediate_train_path)
val_dataset = MyDataset(intermediate_val_path)

model = MLP(4, 2)
train(model, train_dataset, val_dataset)


100%|██████████| 2227/2227 [00:03<00:00, 626.13it/s]


Epochs: 1 | Train Loss:  0.1421             | Train Accuracy:  0.8946             | Val Loss:  0.1453             | Val Accuracy:  0.8164


100%|██████████| 2227/2227 [00:03<00:00, 598.17it/s]


Epochs: 2 | Train Loss:  0.1410             | Train Accuracy:  0.8949             | Val Loss:  0.1442             | Val Accuracy:  0.8218


100%|██████████| 2227/2227 [00:03<00:00, 619.31it/s]


Epochs: 3 | Train Loss:  0.1395             | Train Accuracy:  0.9064             | Val Loss:  0.1431             | Val Accuracy:  0.8265


100%|██████████| 2227/2227 [00:03<00:00, 600.42it/s]


Epochs: 4 | Train Loss:  0.1386             | Train Accuracy:  0.8988             | Val Loss:  0.1420             | Val Accuracy:  0.8286


100%|██████████| 2227/2227 [00:03<00:00, 625.17it/s]


Epochs: 5 | Train Loss:  0.1371             | Train Accuracy:  0.9067             | Val Loss:  0.1410             | Val Accuracy:  0.8356


100%|██████████| 2227/2227 [00:04<00:00, 544.15it/s]


Epochs: 6 | Train Loss:  0.1359             | Train Accuracy:  0.9083             | Val Loss:  0.1399             | Val Accuracy:  0.8390


100%|██████████| 2227/2227 [00:03<00:00, 590.84it/s]


Epochs: 7 | Train Loss:  0.1348             | Train Accuracy:  0.9075             | Val Loss:  0.1388             | Val Accuracy:  0.8447


100%|██████████| 2227/2227 [00:03<00:00, 558.45it/s]


Epochs: 8 | Train Loss:  0.1335             | Train Accuracy:  0.9120             | Val Loss:  0.1377             | Val Accuracy:  0.8474


100%|██████████| 2227/2227 [00:04<00:00, 545.32it/s]


Epochs: 9 | Train Loss:  0.1325             | Train Accuracy:  0.9098             | Val Loss:  0.1366             | Val Accuracy:  0.8515


100%|██████████| 2227/2227 [00:04<00:00, 555.61it/s]


Epochs: 10 | Train Loss:  0.1312             | Train Accuracy:  0.9122             | Val Loss:  0.1355             | Val Accuracy:  0.8555


100%|██████████| 2227/2227 [00:03<00:00, 615.68it/s]


Epochs: 11 | Train Loss:  0.1299             | Train Accuracy:  0.9178             | Val Loss:  0.1344             | Val Accuracy:  0.8606


100%|██████████| 2227/2227 [00:03<00:00, 617.62it/s]


Epochs: 12 | Train Loss:  0.1282             | Train Accuracy:  0.9236             | Val Loss:  0.1333             | Val Accuracy:  0.8636


100%|██████████| 2227/2227 [00:03<00:00, 647.55it/s]


Epochs: 13 | Train Loss:  0.1272             | Train Accuracy:  0.9258             | Val Loss:  0.1322             | Val Accuracy:  0.8680


100%|██████████| 2227/2227 [00:03<00:00, 568.58it/s]


Epochs: 14 | Train Loss:  0.1258             | Train Accuracy:  0.9268             | Val Loss:  0.1311             | Val Accuracy:  0.8713


100%|██████████| 2227/2227 [00:03<00:00, 573.21it/s]


Epochs: 15 | Train Loss:  0.1249             | Train Accuracy:  0.9270             | Val Loss:  0.1300             | Val Accuracy:  0.8754


100%|██████████| 2227/2227 [00:04<00:00, 554.62it/s]


Epochs: 16 | Train Loss:  0.1237             | Train Accuracy:  0.9287             | Val Loss:  0.1288             | Val Accuracy:  0.8794


100%|██████████| 2227/2227 [00:03<00:00, 566.48it/s]


Epochs: 17 | Train Loss:  0.1224             | Train Accuracy:  0.9311             | Val Loss:  0.1277             | Val Accuracy:  0.8835


100%|██████████| 2227/2227 [00:04<00:00, 548.22it/s]


Epochs: 18 | Train Loss:  0.1211             | Train Accuracy:  0.9345             | Val Loss:  0.1265             | Val Accuracy:  0.8872


100%|██████████| 2227/2227 [00:03<00:00, 566.83it/s]


Epochs: 19 | Train Loss:  0.1196             | Train Accuracy:  0.9395             | Val Loss:  0.1254             | Val Accuracy:  0.8922


100%|██████████| 2227/2227 [00:03<00:00, 617.66it/s]


Epochs: 20 | Train Loss:  0.1180             | Train Accuracy:  0.9424             | Val Loss:  0.1242             | Val Accuracy:  0.8976


100%|██████████| 2227/2227 [00:03<00:00, 598.58it/s]


Epochs: 21 | Train Loss:  0.1169             | Train Accuracy:  0.9451             | Val Loss:  0.1231             | Val Accuracy:  0.9037


100%|██████████| 2227/2227 [00:03<00:00, 564.00it/s]


Epochs: 22 | Train Loss:  0.1157             | Train Accuracy:  0.9442             | Val Loss:  0.1219             | Val Accuracy:  0.9084


100%|██████████| 2227/2227 [00:03<00:00, 622.86it/s]


Epochs: 23 | Train Loss:  0.1146             | Train Accuracy:  0.9437             | Val Loss:  0.1208             | Val Accuracy:  0.9148


100%|██████████| 2227/2227 [00:03<00:00, 605.52it/s]


Epochs: 24 | Train Loss:  0.1132             | Train Accuracy:  0.9495             | Val Loss:  0.1196             | Val Accuracy:  0.9229


100%|██████████| 2227/2227 [00:03<00:00, 618.83it/s]


Epochs: 25 | Train Loss:  0.1120             | Train Accuracy:  0.9462             | Val Loss:  0.1184             | Val Accuracy:  0.9286


100%|██████████| 2227/2227 [00:03<00:00, 588.49it/s]


Epochs: 26 | Train Loss:  0.1104             | Train Accuracy:  0.9547             | Val Loss:  0.1173             | Val Accuracy:  0.9330


100%|██████████| 2227/2227 [00:03<00:00, 625.30it/s]


Epochs: 27 | Train Loss:  0.1091             | Train Accuracy:  0.9540             | Val Loss:  0.1161             | Val Accuracy:  0.9427


100%|██████████| 2227/2227 [00:03<00:00, 565.32it/s]


Epochs: 28 | Train Loss:  0.1080             | Train Accuracy:  0.9539             | Val Loss:  0.1149             | Val Accuracy:  0.9475


100%|██████████| 2227/2227 [00:03<00:00, 590.65it/s]


Epochs: 29 | Train Loss:  0.1066             | Train Accuracy:  0.9613             | Val Loss:  0.1138             | Val Accuracy:  0.9535


100%|██████████| 2227/2227 [00:03<00:00, 600.88it/s]


Epochs: 30 | Train Loss:  0.1053             | Train Accuracy:  0.9611             | Val Loss:  0.1126             | Val Accuracy:  0.9586


100%|██████████| 2227/2227 [00:03<00:00, 638.19it/s]


Epochs: 31 | Train Loss:  0.1042             | Train Accuracy:  0.9571             | Val Loss:  0.1114             | Val Accuracy:  0.9616


100%|██████████| 2227/2227 [00:03<00:00, 631.98it/s]


Epochs: 32 | Train Loss:  0.1026             | Train Accuracy:  0.9620             | Val Loss:  0.1103             | Val Accuracy:  0.9660


100%|██████████| 2227/2227 [00:03<00:00, 627.48it/s]


Epochs: 33 | Train Loss:  0.1016             | Train Accuracy:  0.9611             | Val Loss:  0.1091             | Val Accuracy:  0.9697


100%|██████████| 2227/2227 [00:03<00:00, 623.36it/s]


Epochs: 34 | Train Loss:  0.1000             | Train Accuracy:  0.9644             | Val Loss:  0.1079             | Val Accuracy:  0.9714


100%|██████████| 2227/2227 [00:03<00:00, 582.42it/s]


Epochs: 35 | Train Loss:  0.0986             | Train Accuracy:  0.9662             | Val Loss:  0.1068             | Val Accuracy:  0.9724


100%|██████████| 2227/2227 [00:03<00:00, 568.22it/s]


Epochs: 36 | Train Loss:  0.0975             | Train Accuracy:  0.9664             | Val Loss:  0.1056             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:03<00:00, 612.55it/s]


Epochs: 37 | Train Loss:  0.0961             | Train Accuracy:  0.9696             | Val Loss:  0.1044             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:03<00:00, 607.29it/s]


Epochs: 38 | Train Loss:  0.0946             | Train Accuracy:  0.9728             | Val Loss:  0.1033             | Val Accuracy:  0.9741


100%|██████████| 2227/2227 [00:03<00:00, 609.24it/s]


Epochs: 39 | Train Loss:  0.0936             | Train Accuracy:  0.9705             | Val Loss:  0.1021             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:03<00:00, 601.81it/s]


Epochs: 40 | Train Loss:  0.0924             | Train Accuracy:  0.9719             | Val Loss:  0.1010             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:03<00:00, 593.63it/s]


Epochs: 41 | Train Loss:  0.0910             | Train Accuracy:  0.9736             | Val Loss:  0.0998             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:03<00:00, 604.35it/s]


Epochs: 42 | Train Loss:  0.0897             | Train Accuracy:  0.9751             | Val Loss:  0.0987             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:03<00:00, 614.35it/s]


Epochs: 43 | Train Loss:  0.0884             | Train Accuracy:  0.9746             | Val Loss:  0.0975             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:04<00:00, 535.94it/s]


Epochs: 44 | Train Loss:  0.0874             | Train Accuracy:  0.9751             | Val Loss:  0.0964             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:04<00:00, 554.50it/s]


Epochs: 45 | Train Loss:  0.0859             | Train Accuracy:  0.9777             | Val Loss:  0.0953             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:03<00:00, 602.48it/s]


Epochs: 46 | Train Loss:  0.0847             | Train Accuracy:  0.9801             | Val Loss:  0.0941             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:03<00:00, 581.63it/s]


Epochs: 47 | Train Loss:  0.0835             | Train Accuracy:  0.9777             | Val Loss:  0.0930             | Val Accuracy:  0.9741


100%|██████████| 2227/2227 [00:03<00:00, 576.28it/s]


Epochs: 48 | Train Loss:  0.0819             | Train Accuracy:  0.9815             | Val Loss:  0.0919             | Val Accuracy:  0.9741


100%|██████████| 2227/2227 [00:03<00:00, 560.61it/s]


Epochs: 49 | Train Loss:  0.0811             | Train Accuracy:  0.9804             | Val Loss:  0.0908             | Val Accuracy:  0.9741


100%|██████████| 2227/2227 [00:03<00:00, 562.29it/s]


Epochs: 50 | Train Loss:  0.0797             | Train Accuracy:  0.9811             | Val Loss:  0.0897             | Val Accuracy:  0.9741


100%|██████████| 2227/2227 [00:03<00:00, 559.05it/s]


Epochs: 51 | Train Loss:  0.0785             | Train Accuracy:  0.9811             | Val Loss:  0.0886             | Val Accuracy:  0.9744


100%|██████████| 2227/2227 [00:03<00:00, 560.56it/s]


Epochs: 52 | Train Loss:  0.0773             | Train Accuracy:  0.9832             | Val Loss:  0.0875             | Val Accuracy:  0.9747


100%|██████████| 2227/2227 [00:03<00:00, 586.83it/s]


Epochs: 53 | Train Loss:  0.0758             | Train Accuracy:  0.9847             | Val Loss:  0.0864             | Val Accuracy:  0.9747


100%|██████████| 2227/2227 [00:03<00:00, 610.85it/s]


Epochs: 54 | Train Loss:  0.0748             | Train Accuracy:  0.9864             | Val Loss:  0.0853             | Val Accuracy:  0.9747


100%|██████████| 2227/2227 [00:03<00:00, 589.05it/s]


Epochs: 55 | Train Loss:  0.0736             | Train Accuracy:  0.9845             | Val Loss:  0.0842             | Val Accuracy:  0.9747


100%|██████████| 2227/2227 [00:04<00:00, 549.33it/s]


Epochs: 56 | Train Loss:  0.0723             | Train Accuracy:  0.9847             | Val Loss:  0.0831             | Val Accuracy:  0.9747


100%|██████████| 2227/2227 [00:04<00:00, 547.20it/s]


Epochs: 57 | Train Loss:  0.0713             | Train Accuracy:  0.9874             | Val Loss:  0.0821             | Val Accuracy:  0.9747


100%|██████████| 2227/2227 [00:03<00:00, 580.14it/s]


Epochs: 58 | Train Loss:  0.0697             | Train Accuracy:  0.9879             | Val Loss:  0.0810             | Val Accuracy:  0.9747


100%|██████████| 2227/2227 [00:03<00:00, 576.88it/s]


Epochs: 59 | Train Loss:  0.0691             | Train Accuracy:  0.9846             | Val Loss:  0.0799             | Val Accuracy:  0.9747


100%|██████████| 2227/2227 [00:03<00:00, 560.07it/s]


Epochs: 60 | Train Loss:  0.0679             | Train Accuracy:  0.9842             | Val Loss:  0.0789             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 605.10it/s]


Epochs: 61 | Train Loss:  0.0664             | Train Accuracy:  0.9873             | Val Loss:  0.0778             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 556.88it/s]


Epochs: 62 | Train Loss:  0.0651             | Train Accuracy:  0.9890             | Val Loss:  0.0768             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 610.16it/s]


Epochs: 63 | Train Loss:  0.0641             | Train Accuracy:  0.9885             | Val Loss:  0.0758             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 645.62it/s]


Epochs: 64 | Train Loss:  0.0631             | Train Accuracy:  0.9889             | Val Loss:  0.0747             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 661.20it/s]


Epochs: 65 | Train Loss:  0.0619             | Train Accuracy:  0.9885             | Val Loss:  0.0737             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 637.20it/s]


Epochs: 66 | Train Loss:  0.0607             | Train Accuracy:  0.9906             | Val Loss:  0.0727             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 603.49it/s]


Epochs: 67 | Train Loss:  0.0594             | Train Accuracy:  0.9900             | Val Loss:  0.0717             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 611.33it/s]


Epochs: 68 | Train Loss:  0.0589             | Train Accuracy:  0.9898             | Val Loss:  0.0707             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 647.90it/s]


Epochs: 69 | Train Loss:  0.0573             | Train Accuracy:  0.9914             | Val Loss:  0.0697             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 614.20it/s]


Epochs: 70 | Train Loss:  0.0564             | Train Accuracy:  0.9919             | Val Loss:  0.0687             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 629.05it/s]


Epochs: 71 | Train Loss:  0.0552             | Train Accuracy:  0.9914             | Val Loss:  0.0678             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 590.26it/s]


Epochs: 72 | Train Loss:  0.0540             | Train Accuracy:  0.9924             | Val Loss:  0.0668             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 571.20it/s]


Epochs: 73 | Train Loss:  0.0533             | Train Accuracy:  0.9914             | Val Loss:  0.0659             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 593.69it/s]


Epochs: 74 | Train Loss:  0.0523             | Train Accuracy:  0.9916             | Val Loss:  0.0649             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 561.78it/s]


Epochs: 75 | Train Loss:  0.0513             | Train Accuracy:  0.9918             | Val Loss:  0.0640             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:04<00:00, 547.05it/s]


Epochs: 76 | Train Loss:  0.0502             | Train Accuracy:  0.9926             | Val Loss:  0.0631             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 634.09it/s]


Epochs: 77 | Train Loss:  0.0492             | Train Accuracy:  0.9928             | Val Loss:  0.0622             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:04<00:00, 545.20it/s]


Epochs: 78 | Train Loss:  0.0480             | Train Accuracy:  0.9932             | Val Loss:  0.0613             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 560.80it/s]


Epochs: 79 | Train Loss:  0.0472             | Train Accuracy:  0.9927             | Val Loss:  0.0604             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 583.66it/s]


Epochs: 80 | Train Loss:  0.0463             | Train Accuracy:  0.9940             | Val Loss:  0.0595             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 603.28it/s]


Epochs: 81 | Train Loss:  0.0452             | Train Accuracy:  0.9945             | Val Loss:  0.0587             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 636.29it/s]


Epochs: 82 | Train Loss:  0.0444             | Train Accuracy:  0.9940             | Val Loss:  0.0578             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 644.97it/s]


Epochs: 83 | Train Loss:  0.0439             | Train Accuracy:  0.9930             | Val Loss:  0.0570             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 613.58it/s]


Epochs: 84 | Train Loss:  0.0423             | Train Accuracy:  0.9954             | Val Loss:  0.0561             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 627.43it/s]


Epochs: 85 | Train Loss:  0.0415             | Train Accuracy:  0.9960             | Val Loss:  0.0553             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 627.81it/s]


Epochs: 86 | Train Loss:  0.0408             | Train Accuracy:  0.9951             | Val Loss:  0.0545             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 589.90it/s]


Epochs: 87 | Train Loss:  0.0399             | Train Accuracy:  0.9961             | Val Loss:  0.0537             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 599.81it/s]


Epochs: 88 | Train Loss:  0.0389             | Train Accuracy:  0.9969             | Val Loss:  0.0530             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 617.18it/s]


Epochs: 89 | Train Loss:  0.0380             | Train Accuracy:  0.9974             | Val Loss:  0.0522             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 609.02it/s]


Epochs: 90 | Train Loss:  0.0372             | Train Accuracy:  0.9970             | Val Loss:  0.0514             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 591.02it/s]


Epochs: 91 | Train Loss:  0.0366             | Train Accuracy:  0.9973             | Val Loss:  0.0507             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 603.09it/s]


Epochs: 92 | Train Loss:  0.0358             | Train Accuracy:  0.9970             | Val Loss:  0.0500             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 599.11it/s]


Epochs: 93 | Train Loss:  0.0348             | Train Accuracy:  0.9970             | Val Loss:  0.0493             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 606.29it/s]


Epochs: 94 | Train Loss:  0.0340             | Train Accuracy:  0.9983             | Val Loss:  0.0486             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 641.16it/s]


Epochs: 95 | Train Loss:  0.0337             | Train Accuracy:  0.9973             | Val Loss:  0.0479             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 624.64it/s]


Epochs: 96 | Train Loss:  0.0325             | Train Accuracy:  0.9988             | Val Loss:  0.0472             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 604.81it/s]


Epochs: 97 | Train Loss:  0.0319             | Train Accuracy:  0.9980             | Val Loss:  0.0465             | Val Accuracy:  0.9761


100%|██████████| 2227/2227 [00:03<00:00, 601.48it/s]


Epochs: 98 | Train Loss:  0.0315             | Train Accuracy:  0.9976             | Val Loss:  0.0459             | Val Accuracy:  0.9761


100%|██████████| 2227/2227 [00:03<00:00, 593.59it/s]


Epochs: 99 | Train Loss:  0.0303             | Train Accuracy:  0.9985             | Val Loss:  0.0453             | Val Accuracy:  0.9761


100%|██████████| 2227/2227 [00:03<00:00, 621.07it/s]


Epochs: 100 | Train Loss:  0.0299             | Train Accuracy:  0.9989             | Val Loss:  0.0447             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 615.49it/s]


Epochs: 101 | Train Loss:  0.0291             | Train Accuracy:  0.9991             | Val Loss:  0.0441             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 596.53it/s]


Epochs: 102 | Train Loss:  0.0282             | Train Accuracy:  0.9989             | Val Loss:  0.0435             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 600.83it/s]


Epochs: 103 | Train Loss:  0.0275             | Train Accuracy:  0.9993             | Val Loss:  0.0429             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 606.76it/s]


Epochs: 104 | Train Loss:  0.0268             | Train Accuracy:  0.9991             | Val Loss:  0.0423             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 619.00it/s]


Epochs: 105 | Train Loss:  0.0267             | Train Accuracy:  0.9988             | Val Loss:  0.0418             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 605.02it/s]


Epochs: 106 | Train Loss:  0.0258             | Train Accuracy:  0.9994             | Val Loss:  0.0412             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 603.91it/s]


Epochs: 107 | Train Loss:  0.0251             | Train Accuracy:  0.9992             | Val Loss:  0.0407             | Val Accuracy:  0.9757


100%|██████████| 2227/2227 [00:03<00:00, 619.57it/s]


Epochs: 108 | Train Loss:  0.0246             | Train Accuracy:  0.9993             | Val Loss:  0.0402             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 646.35it/s]


Epochs: 109 | Train Loss:  0.0238             | Train Accuracy:  0.9991             | Val Loss:  0.0397             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 609.45it/s]


Epochs: 110 | Train Loss:  0.0233             | Train Accuracy:  0.9994             | Val Loss:  0.0392             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 627.89it/s]


Epochs: 111 | Train Loss:  0.0227             | Train Accuracy:  0.9997             | Val Loss:  0.0388             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 597.43it/s]


Epochs: 112 | Train Loss:  0.0223             | Train Accuracy:  0.9994             | Val Loss:  0.0383             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 610.01it/s]


Epochs: 113 | Train Loss:  0.0218             | Train Accuracy:  0.9992             | Val Loss:  0.0379             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 598.92it/s]


Epochs: 114 | Train Loss:  0.0210             | Train Accuracy:  0.9990             | Val Loss:  0.0374             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 630.09it/s]


Epochs: 115 | Train Loss:  0.0207             | Train Accuracy:  0.9997             | Val Loss:  0.0370             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 584.46it/s]


Epochs: 116 | Train Loss:  0.0199             | Train Accuracy:  0.9993             | Val Loss:  0.0366             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 609.58it/s]


Epochs: 117 | Train Loss:  0.0198             | Train Accuracy:  0.9994             | Val Loss:  0.0362             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 626.67it/s]


Epochs: 118 | Train Loss:  0.0190             | Train Accuracy:  0.9997             | Val Loss:  0.0358             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 573.96it/s]


Epochs: 119 | Train Loss:  0.0187             | Train Accuracy:  0.9993             | Val Loss:  0.0354             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:04<00:00, 553.93it/s]


Epochs: 120 | Train Loss:  0.0185             | Train Accuracy:  0.9996             | Val Loss:  0.0351             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 639.76it/s]


Epochs: 121 | Train Loss:  0.0178             | Train Accuracy:  0.9999             | Val Loss:  0.0347             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 625.75it/s]


Epochs: 122 | Train Loss:  0.0174             | Train Accuracy:  0.9998             | Val Loss:  0.0344             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 654.28it/s]


Epochs: 123 | Train Loss:  0.0170             | Train Accuracy:  0.9996             | Val Loss:  0.0340             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 643.34it/s]


Epochs: 124 | Train Loss:  0.0166             | Train Accuracy:  0.9998             | Val Loss:  0.0337             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 598.99it/s]


Epochs: 125 | Train Loss:  0.0164             | Train Accuracy:  0.9999             | Val Loss:  0.0334             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 632.73it/s]


Epochs: 126 | Train Loss:  0.0157             | Train Accuracy:  0.9999             | Val Loss:  0.0331             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 653.62it/s]


Epochs: 127 | Train Loss:  0.0154             | Train Accuracy:  0.9996             | Val Loss:  0.0328             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 664.62it/s]


Epochs: 128 | Train Loss:  0.0150             | Train Accuracy:  0.9994             | Val Loss:  0.0325             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 656.76it/s]


Epochs: 129 | Train Loss:  0.0147             | Train Accuracy:  0.9997             | Val Loss:  0.0322             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 626.98it/s]


Epochs: 130 | Train Loss:  0.0143             | Train Accuracy:  0.9999             | Val Loss:  0.0319             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 627.48it/s]


Epochs: 131 | Train Loss:  0.0140             | Train Accuracy:  0.9996             | Val Loss:  0.0317             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 622.79it/s]


Epochs: 132 | Train Loss:  0.0135             | Train Accuracy:  0.9997             | Val Loss:  0.0314             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 651.15it/s]


Epochs: 133 | Train Loss:  0.0131             | Train Accuracy:  0.9999             | Val Loss:  0.0312             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 661.55it/s]


Epochs: 134 | Train Loss:  0.0129             | Train Accuracy:  0.9998             | Val Loss:  0.0310             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 640.97it/s]


Epochs: 135 | Train Loss:  0.0124             | Train Accuracy:  0.9998             | Val Loss:  0.0307             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 637.71it/s]


Epochs: 136 | Train Loss:  0.0122             | Train Accuracy:  0.9997             | Val Loss:  0.0305             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 650.29it/s]


Epochs: 137 | Train Loss:  0.0120             | Train Accuracy:  0.9999             | Val Loss:  0.0303             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 636.04it/s]


Epochs: 138 | Train Loss:  0.0116             | Train Accuracy:  0.9996             | Val Loss:  0.0301             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 635.90it/s]


Epochs: 139 | Train Loss:  0.0114             | Train Accuracy:  0.9998             | Val Loss:  0.0299             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 613.90it/s]


Epochs: 140 | Train Loss:  0.0110             | Train Accuracy:  0.9999             | Val Loss:  0.0297             | Val Accuracy:  0.9751


100%|██████████| 2227/2227 [00:03<00:00, 568.28it/s]


Epochs: 141 | Train Loss:  0.0109             | Train Accuracy:  0.9999             | Val Loss:  0.0296             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:04<00:00, 552.08it/s]


Epochs: 142 | Train Loss:  0.0105             | Train Accuracy:  0.9998             | Val Loss:  0.0294             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 655.42it/s]


Epochs: 143 | Train Loss:  0.0101             | Train Accuracy:  0.9999             | Val Loss:  0.0292             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 642.68it/s]


Epochs: 144 | Train Loss:  0.0099             | Train Accuracy:  0.9998             | Val Loss:  0.0291             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 635.63it/s]


Epochs: 145 | Train Loss:  0.0097             | Train Accuracy:  0.9999             | Val Loss:  0.0290             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 600.12it/s]


Epochs: 146 | Train Loss:  0.0095             | Train Accuracy:  0.9998             | Val Loss:  0.0288             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 589.64it/s]


Epochs: 147 | Train Loss:  0.0094             | Train Accuracy:  0.9996             | Val Loss:  0.0287             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 598.23it/s]


Epochs: 148 | Train Loss:  0.0090             | Train Accuracy:  0.9999             | Val Loss:  0.0286             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 623.20it/s]


Epochs: 149 | Train Loss:  0.0088             | Train Accuracy:  0.9998             | Val Loss:  0.0284             | Val Accuracy:  0.9754


100%|██████████| 2227/2227 [00:03<00:00, 641.20it/s]


Epochs: 150 | Train Loss:  0.0085             | Train Accuracy:  0.9997             | Val Loss:  0.0283             | Val Accuracy:  0.9754


In [12]:
def evaluate(model, test_dataset):
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
    model = model.to(device)
    
    total_acc_test = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    model.eval()
    with torch.no_grad():
        for x1, x2, x3, x4, y in test_dataloader:
            x = torch.transpose(torch.stack([x1, x2, x3, x4]), 0, 1).float().to(device)
            y = y.to(device)
            y_pred = model(x)
            
            acc = (y_pred.argmax(dim=1) == y).sum().item()
            total_acc_test += acc
            
            y = y.data.cpu().numpy()
            predic = y_pred.argmax(dim=1).data.cpu().numpy()
            labels_all = np.append(labels_all, y)
            predict_all = np.append(predict_all, predic)

    report = metrics.classification_report(labels_all, predict_all, target_names=['benign', 'vulnerable'], digits=4)
    confusion = metrics.confusion_matrix(labels_all, predict_all)
    print(f'Test Accuracy: {total_acc_test / len(test_dataset): .4f}')
    print(report)
    print(confusion)

mkdir_if_not_exist(MODEL_SAVE_PATH)
model = MLP(4, 2)
saved_model_name = 'ensemble2_cnn_0.9761_epoch99.pt'
model.load_state_dict(torch.load(f'{MODEL_SAVE_PATH}/{saved_model_name}'))
evaluate(model, val_dataset)



Test Accuracy:  0.9761
              precision    recall  f1-score   support

      benign     0.9738    0.9856    0.9797      1738
  vulnerable     0.9793    0.9626    0.9709      1231

    accuracy                         0.9761      2969
   macro avg     0.9766    0.9741    0.9753      2969
weighted avg     0.9761    0.9761    0.9761      2969

[[1713   25]
 [  46 1185]]
