In [1]:
import os.path
import time
from itertools import product

import itables
import pandas as pd
import torch
from tqdm import tqdm

from codegen_sources.model.deobfuscate import _reload_model
from codegen_sources.model.src.model import build_model
from iren.inference.mute import mute_stdout_stderr

# Architecture grid to sweep: every (encoder layers, decoder layers, embedding dim)
# combination is built and benchmarked. The two layer grids are separate list objects
# so that editing one can never silently change the other.
N_LAYERS_ENCODER = [2, 4, 6]
N_LAYERS_DECODER = [2, 4, 6]
EMB_DIM = [256, 512, 1024]

# Root of the CodeGen checkout. Override with the CODEGEN_DIR environment variable
# instead of editing a machine-specific absolute path.
CODEGEN_DIR = os.environ.get("CODEGEN_DIR", "/home/igor/PycharmProjects/CodeGen")
MODELS_DIR = os.path.join(CODEGEN_DIR, "training_artifacts", "models")
# Pretrained DOBF checkpoint benchmarked as the reference row.
MODEL_PATH = os.path.join(MODELS_DIR, "DOBF_var_shuffled.pth")
# Scratch file that state dicts are written to, only to measure their serialized size.
TMP_PATH = os.path.join(MODELS_DIR, "tmp.pth")

# Number of timed repetitions per forward stage; reported times are per-run averages.
REPEAT = 10

adding to path /home/igor/PycharmProjects/CodeGen


In [2]:
def gen_random_batch(batch_size, src_len=100, tgt_len=10, vocab_size=64000, generator=None):
    """Build a random (source, target) batch in the argument layout used by ``time_forward``.

    Token tensors are sequence-first: ``x1`` has shape ``(src_len, batch_size)`` and
    ``x2`` has shape ``(tgt_len, batch_size)``. Both are filled with ids drawn uniformly
    from ``[0, vocab_size)``. All language ids are 1.

    Lengths are drawn so that every source sequence is non-empty (``len1`` in
    ``[1, src_len]``). Every target sequence also covers the predicted position 1
    (``len2`` in ``[2, tgt_len]``). A zero length would give attention rows with every
    key masked, which is not a meaningful model input.

    Args:
        batch_size: number of sequences in the batch.
        src_len: padded source length (rows of ``x1``).
        tgt_len: padded target length (rows of ``x2``); must be >= 2.
        vocab_size: exclusive upper bound for random token ids and targets.
        generator: optional ``torch.Generator`` for reproducible batches.

    Returns:
        ``(x1, len1, langs1, x2, len2, langs2, y, pred_mask)``. ``pred_mask`` is a
        bool tensor shaped like ``x2`` that is True only on row 1. ``y`` holds one
        random target id per True entry, i.e. ``batch_size`` ids.

    Raises:
        ValueError: if ``src_len < 1`` or ``tgt_len < 2``.
    """
    if src_len < 1:
        raise ValueError(f"src_len must be >= 1, got {src_len}")
    if tgt_len < 2:
        raise ValueError(f"tgt_len must be >= 2 so that target position 1 exists, got {tgt_len}")

    x1 = torch.randint(vocab_size, size=(src_len, batch_size), generator=generator)
    len1 = torch.randint(1, src_len + 1, size=(batch_size,), generator=generator)
    langs1 = torch.ones_like(x1)

    x2 = torch.randint(vocab_size, size=(tgt_len, batch_size), generator=generator)
    # At least 2 tokens so that the predicted position 1 lies inside the sequence.
    len2 = torch.randint(2, tgt_len + 1, size=(batch_size,), generator=generator)
    langs2 = torch.ones_like(x2)

    # One target id per True entry of pred_mask (row 1 of every column).
    y = torch.randint(vocab_size, size=(batch_size,), generator=generator)
    pred_mask = torch.zeros_like(x2, dtype=torch.bool)
    pred_mask[1, :] = True
    return x1, len1, langs1, x2, len2, langs2, y, pred_mask


@torch.no_grad()
def time_forward(encoder, decoder, langs1, langs2, len1, len2, pred_mask, spans, x1, x2, y,
                 repeat=None, warmup=1):
    """Benchmark the encoder pass, the decoder pass and the decoder's prediction head.

    First, ``warmup`` untimed passes through the whole pipeline run so that one-off
    costs (lazy initialisation, allocator growth, ...) do not inflate the numbers.
    Then each stage is run ``repeat`` times back to back. The mean wall-clock time
    per run, in seconds measured with ``time.perf_counter``, is reported.

    Args:
        encoder, decoder: modules called as ``encoder("fwd", ...)``,
            ``decoder("fwd", ...)`` and ``decoder("predict", ...)``.
        langs1, len1, x1: source language ids, lengths and token ids.
        langs2, len2, x2: target language ids, lengths and token ids.
        pred_mask, y: forwarded to ``decoder("predict", ...)``.
        spans: forwarded unchanged to both ``"fwd"`` calls.
        repeat: timed runs per stage; defaults to the notebook-level ``REPEAT``.
        warmup: number of untimed full-pipeline passes executed before timing.

    Returns:
        dict with ``time_enc``, ``time_dec``, ``time_pred`` (mean seconds per run of
        each stage) and ``total_time`` (mean seconds per run of all three stages,
        including the encoder-output transpose).

    Raises:
        ValueError: if ``repeat < 1`` or ``warmup < 0``.
    """
    if repeat is None:
        repeat = REPEAT
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the timer measures the work.
        if x1.is_cuda:
            torch.cuda.synchronize()

    def _encode():
        return encoder("fwd", x=x1, lengths=len1, langs=langs1, causal=False, spans=spans)

    def _decode(src_enc):
        return decoder(
            "fwd",
            x=x2,
            lengths=len2,
            langs=langs2,
            causal=True,
            src_enc=src_enc,
            src_len=len1,
            spans=spans,
        )

    def _predict(tensor):
        return decoder("predict", tensor=tensor, pred_mask=pred_mask, y=y, get_scores=True)

    for _ in range(warmup):
        _predict(_decode(_encode().transpose(0, 1)))
    _sync()

    total_start = time.perf_counter()
    # encode source sentence
    start = time.perf_counter()
    for _ in range(repeat):
        enc1 = _encode()
    _sync()
    time_enc = (time.perf_counter() - start) / repeat
    # Swap the first two dims of the encoder output before feeding it to the decoder
    # (presumably sequence-first -> batch-first; confirm against the model code).
    enc1 = enc1.transpose(0, 1)
    # decode target sentence
    start = time.perf_counter()
    for _ in range(repeat):
        dec2 = _decode(enc1)
    _sync()
    time_dec = (time.perf_counter() - start) / repeat
    # prediction head / loss
    start = time.perf_counter()
    for _ in range(repeat):
        _predict(dec2)
    _sync()
    time_pred = (time.perf_counter() - start) / repeat
    total = (time.perf_counter() - total_start) / repeat
    return {"time_enc": time_enc, "time_dec": time_dec, "time_pred": time_pred, "total_time": total}

In [3]:
batch_size = 16
# Seed so the random batch and the randomly initialised sweep models are reproducible.
torch.manual_seed(0)


def benchmark_model(encoder, decoder, batch, n_enc, n_dec, emb_dim):
    """Time one encoder/decoder pair on ``batch`` and collect its size statistics.

    Args:
        encoder, decoder: the modules to benchmark. They are switched to eval mode.
        batch: the tuple returned by ``gen_random_batch``.
        n_enc, n_dec, emb_dim: architecture values recorded in the result row.

    Returns:
        dict with keys ``n_enc``, ``n_dec``, ``emb_dim``, the ``time_forward``
        timings, ``size`` (serialized state dicts, MiB), ``enc_ps``/``dec_ps``/``sum_ps``
        (trainable parameter counts) and ``emb_ps`` (encoder token-embedding size).
    """
    x1, len1, langs1, x2, len2, langs2, y, pred_mask = batch
    # Measure every model in inference mode. build_model presumably returns modules in
    # PyTorch's default training mode (dropout active), while the DOBF baseline was
    # explicitly put in eval mode, so without this the rows are not comparable.
    encoder.eval()
    decoder.eval()
    stats = {"n_enc": n_enc, "n_dec": n_dec, "emb_dim": emb_dim}
    stats.update(time_forward(encoder, decoder, langs1, langs2, len1, len2, pred_mask, None, x1, x2, y))
    # On-disk size of both state dicts, in MiB.
    torch.save((encoder.state_dict(), decoder.state_dict()), TMP_PATH)
    stats["size"] = os.path.getsize(TMP_PATH) / 1024 / 1024
    stats["enc_ps"] = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
    stats["dec_ps"] = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
    stats["sum_ps"] = stats["enc_ps"] + stats["dec_ps"]
    stats["emb_ps"] = encoder.embeddings.weight.numel()
    return stats


params, dico, (encoder, decoder) = _reload_model(MODEL_PATH)
encoder, decoder = encoder[0], decoder[0]
# Print the modules themselves (their architecture); `.parameters` is a bound method.
print("Encoder:", encoder)
print("Decoder:", decoder)
batch = gen_random_batch(batch_size)
x1, len1, langs1, x2, len2, langs2, y, pred_mask = batch

# Reference row: the pretrained DOBF checkpoint.
d = benchmark_model(encoder, decoder, batch,
                    params.n_layers_encoder, params.n_layers_decoder, params.emb_dim_encoder)
res = {"dobf": d}
# Presumably makes build_model skip reloading checkpoint weights for the sweep below.
params.reload_model = ''
del encoder, decoder

configs = list(product(N_LAYERS_ENCODER, N_LAYERS_DECODER, EMB_DIM))
for i, (n_layers_encoder, n_layers_decoder, emb_dim) in enumerate(tqdm(configs)):
    params.n_layers_encoder = n_layers_encoder
    params.n_layers_decoder = n_layers_decoder
    params.emb_dim = params.emb_dim_encoder = params.emb_dim_decoder = emb_dim
    with mute_stdout_stderr():
        encoder, decoder = build_model(params, dico, gpu=False)
    encoder, decoder = encoder[0], decoder[0]
    d = benchmark_model(encoder, decoder, batch, n_layers_encoder, n_layers_decoder, emb_dim)
    res[i] = d

# The scratch checkpoint was only needed to measure serialized size.
if os.path.exists(TMP_PATH):
    os.remove(TMP_PATH)


INFO - 04/25/22 13:17:40 - 0:00:07 - Reloading encoder from /home/igor/PycharmProjects/CodeGen/training_artifacts/models/DOBF_var_shuffled.pth ...
INFO - 04/25/22 13:17:42 - 0:00:09 - Reloading decoders from /home/igor/PycharmProjects/CodeGen/training_artifacts/models/DOBF_var_shuffled.pth ...
INFO - 04/25/22 13:17:43 - 0:00:10 - Number of parameters (encoder): 143278592
INFO - 04/25/22 13:17:43 - 0:00:10 - Number of parameters (decoders): 168481280
INFO - 04/25/22 13:17:43 - 0:00:10 - Number of decoders: 1



Encoder: <bound method Module.parameters of TransformerModel(
  (position_embeddings): Embedding(2048, 1024)
  (lang_embeddings): Embedding(2, 1024)
  (embeddings): Embedding(64000, 1024, padding_idx=2)
  (layer_norm_emb): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (attentions): ModuleList(
    (0): MultiHeadAttention(
      (q_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (k_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (v_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (out_lin): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (1): MultiHeadAttention(
      (q_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (k_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (v_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (out_lin): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (2): MultiHeadAttention(
      (q_lin): Linear(in_featur

  0%|          | 0/27 [00:00<?, ?it/s]INFO - 04/25/22 13:17:52 - 0:00:19 - Number of parameters (encoder): 18552832
INFO - 04/25/22 13:17:52 - 0:00:19 - Number of parameters (decoders): 19080192
INFO - 04/25/22 13:17:52 - 0:00:19 - Number of decoders: 1

  4%|▎         | 1/27 [00:01<00:29,  1.15s/it]INFO - 04/25/22 13:17:53 - 0:00:21 - Number of parameters (encoder): 40187392
INFO - 04/25/22 13:17:53 - 0:00:21 - Number of parameters (decoders): 42290688
INFO - 04/25/22 13:17:53 - 0:00:21 - Number of decoders: 1

  7%|▋         | 2/27 [00:03<00:41,  1.67s/it]INFO - 04/25/22 13:17:57 - 0:00:24 - Number of parameters (encoder): 92893696
INFO - 04/25/22 13:17:57 - 0:00:24 - Number of parameters (decoders): 101294592
INFO - 04/25/22 13:17:57 - 0:00:24 - Number of decoders: 1

 11%|█         | 3/27 [00:08<01:21,  3.40s/it]INFO - 04/25/22 13:18:00 - 0:00:28 - Number of parameters (encoder): 18552832
INFO - 04/25/22 13:18:00 - 0:00:28 - Number of parameters (decoders): 21187072
INFO - 04/25/22

Unnamed: 0,n_enc,n_dec,emb_dim,time_enc,time_dec,time_pred,total_time,size,enc_ps,dec_ps,sum_ps,emb_ps
Loading... (need help?),,,,,,,,,,,,
