## ET-BERT embeddings

In [None]:
%cd ../models/ET-BERT/src

/pscratch/sd/k/kell/demystifying/ET-BERT/src


In [2]:
from finetuning.run_classifier import Classifier, read_dataset, load_or_initialize_parameters, count_labels_num, batch_loader
import torch
from collections import defaultdict
from argparse import Namespace
from uer.layers import *
from uer.encoders import *
from uer.utils.constants import *
from uer.utils import *
from uer.utils.optimizers import *
from uer.opts import finetune_opts
from uer.utils.config import load_hyperparam
import argparse
from tqdm import tqdm
import copy
import threading
import pickle
from torchsummary import summary

In [None]:
def get_embeddings(datafolder, batch_size=64, pretrained_model="/dev/shm/pretrained_model_etbert.bin", limit = 10**30, gpus=4):
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    finetune_opts(parser)
    args = []
    args += ["--train_path", "dummy"]
    args += ["--vocab_path", "../models/ET-BERT/src/models/encryptd_vocab.txt"]
    args += ["--dev_path", "../models/ET-BERT/cic/test_dataset.tsv"]
    args += ["--pretrained_model_path", "dummy"]
    args = parser.parse_args(args)
    args.tokenizer = "bert"
    args.pooling = "first"
    args.soft_targets = False
    args.topk = 1
    args.frozen = False
    args.soft_alpha = 0.5
    
    args = load_hyperparam(args)
    args.tokenizer = str2tokenizer[args.tokenizer](args)
    args.batch_size = batch_size
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def get_model(args, model_path, train_path):
        args.pretrained_model_path = model_path
        args.train_path = train_path
        args.labels_num = count_labels_num(args.train_path)
        model = Classifier(args)
        load_or_initialize_parameters(args, model)
        model = model.to(args.device)
        summary(model)
        return model

    etbert_frozen_model = get_model(args,pretrained_model, f"{datafolder}/train_dataset.tsv")
    data_etbert = read_dataset(args, f"{datafolder}/train_dataset.tsv")

    def encode(model, src_batch, seg_batch):
        emb = model.embedding(src_batch, seg_batch)
        emb = model.encoder(emb, seg_batch)
        emb = emb[:, 0, :]   # pooling = first
        return emb

    etbert_emb = defaultdict(list)
    print(f"Total: {int(len(data_etbert) / args.batch_size)}")
    
    etbert_models = {}
    for i in range(gpus):
        etbert_models[i] = copy.deepcopy(etbert_frozen_model)
        etbert_models[i].to(f"cuda:{i}")
    print()
    
    def encode_and_append(batch, model, result_list, i):
        src_batch, seg_batch = batch
        src_batch = src_batch.to(f"cuda:{i}")
        seg_batch = seg_batch.to(f"cuda:{i}")
        with torch.no_grad():
            result_list.append(encode(model, src_batch, seg_batch).cpu())
        del src_batch, seg_batch, batch
    
    batch_size = args.batch_size
    
    src = torch.LongTensor([example[0] for example in data_etbert])
    tgt = torch.LongTensor([example[1] for example in data_etbert])
    seg = torch.LongTensor([example[2] for example in data_etbert])
    
    loader = batch_loader(batch_size, src, tgt, seg, None)
    iterator = iter(loader)
    embeddings = []
    try:
        for k in tqdm(range(min(int(len(data_etbert) / args.batch_size / gpus)+1, limit))):
            etbert_emb = defaultdict(list)
            batches = [next(iterator) for i in range(gpus)]            
            threads = []
            for i in range(gpus):
                t = threading.Thread(target=encode_and_append, args=((batches[i][0], batches[i][2]), etbert_models[i], etbert_emb[i], i))
                t.start()
                threads.append(t)
            for t in threads:
                t.join()
            del batches
            embeddings.append(torch.cat([torch.cat(etbert_emb[i]) for i in etbert_emb]))
    except StopIteration:
        print("finished")
    except Exception as e:
        print(e)
        raise

    with open(f"{datafolder}/picked_file_record") as f:
        filenames = f.readlines()

    embeddings = torch.cat(embeddings)
    filenames = filenames[:embeddings.shape[0]]  # because of our batch selection algorithm, taken number of files should be integer after / gpus / batch_size, so we took first N files where N is resulting tensor shape
    print(embeddings.shape, len(filenames))
    return embeddings, filenames

In [None]:
get_embeddings("../models/ET-BERT/cic/")

Layer (type:depth-idx)                             Param #
├─WordPosSegEmbedding: 1-1                         --
|    └─Dropout: 2-1                                --
|    └─Embedding: 2-2                              46,083,840
|    └─Embedding: 2-3                              393,216
|    └─Embedding: 2-4                              2,304
|    └─LayerNorm: 2-5                              1,536
├─TransformerEncoder: 1-2                          --
|    └─ModuleList: 2-6                             --
|    |    └─TransformerLayer: 3-1                  7,087,872
|    |    └─TransformerLayer: 3-2                  7,087,872
|    |    └─TransformerLayer: 3-3                  7,087,872
|    |    └─TransformerLayer: 3-4                  7,087,872
|    |    └─TransformerLayer: 3-5                  7,087,872
|    |    └─TransformerLayer: 3-6                  7,087,872
|    |    └─TransformerLayer: 3-7                  7,087,872
|    |    └─TransformerLayer: 3-8                  7,087,872
| 

KeyboardInterrupt: 

In [None]:
def save_emb(label):
    emb = get_embeddings(f"../data/{label}", batch_size=512, gpus=4)
    with open(f"../data/{label}_emb.pkl", "bw") as f:
        pickle.dump(emb, f)

In [8]:
save_emb("cross")

FileNotFoundError: [Errno 2] No such file or directory: '/dev/shm/data/cross/train_dataset.tsv'

In [42]:
save_emb("caida")

Total: 1935



100%|██████████| 484/484 [06:14<00:00,  1.29it/s]


torch.Size([990722, 768]) 990722


In [6]:
save_emb("cicapt")

Total: 2408



100%|█████████▉| 602/603 [08:32<00:00,  1.18it/s]


finished
torch.Size([1232896, 768]) 1232896


In [7]:
save_emb("cicids")

Total: 856



100%|█████████▉| 214/215 [02:59<00:00,  1.19it/s]


finished
torch.Size([438272, 768]) 438272


In [8]:
save_emb("mawi")

Total: 1370



100%|█████████▉| 342/343 [04:49<00:00,  1.18it/s]


finished
torch.Size([700416, 768]) 700416


In [15]:
save_emb("etbert")

Total: 10



 91%|█████████ | 10/11 [00:00<00:00, 88.36it/s]

finished
torch.Size([10, 768]) 10





In [None]:
label = "synth"
emb = get_embeddings(f"../data/{label}", batch_size=1, gpus=1)
with open(f"../data/{label}_emb.pkl", "bw") as f:
    pickle.dump(emb, f)

Layer (type:depth-idx)                             Param #
├─WordPosSegEmbedding: 1-1                         --
|    └─Dropout: 2-1                                --
|    └─Embedding: 2-2                              46,083,840
|    └─Embedding: 2-3                              393,216
|    └─Embedding: 2-4                              2,304
|    └─LayerNorm: 2-5                              1,536
├─TransformerEncoder: 1-2                          --
|    └─ModuleList: 2-6                             --
|    |    └─TransformerLayer: 3-1                  7,087,872
|    |    └─TransformerLayer: 3-2                  7,087,872
|    |    └─TransformerLayer: 3-3                  7,087,872
|    |    └─TransformerLayer: 3-4                  7,087,872
|    |    └─TransformerLayer: 3-5                  7,087,872
|    |    └─TransformerLayer: 3-6                  7,087,872
|    |    └─TransformerLayer: 3-7                  7,087,872
|    |    └─TransformerLayer: 3-8                  7,087,872
| 

100%|█████████▉| 472/473 [00:08<00:00, 53.66it/s] 

finished
torch.Size([472, 768]) 472



