In [1]:
import os.path
import time
from itertools import product

import itables
import pandas as pd
import torch
from tqdm import tqdm

from codegen_sources.model.src.data.dictionary import Dictionary
from codegen_sources.model.src.model import build_model
from codegen_sources.model.src.utils import AttrDict

N_LAYERS_ENCODER = N_LAYERS_DECODER = [2, 4, 6]
EMB_DIM = [256, 512, 1024]

MODEL_PATH = r"/home/igor/PycharmProjects/CodeGen/training_artifacts/models/DOBF_var_shuffled.pth"
TMP_PATH = r"/home/igor/PycharmProjects/CodeGen/training_artifacts/models/tmp.pth"

REPEAT = 10

adding to path /home/igor/PycharmProjects/CodeGen


In [2]:
def gen_random_batch(batch_size):
    x1 = torch.randint(64000, size=(100, batch_size))
    len1 = torch.randint(x1.size(0), size=(batch_size,))
    langs1 = torch.ones_like(x1)
    x2 = torch.randint(64000, size=(10, batch_size))
    len2 = torch.randint(x2.size(0), size=(batch_size,))
    langs2 = torch.ones_like(x2)
    y = torch.randint(64000, size=(batch_size,))
    pred_mask = torch.zeros_like(x2, dtype=torch.bool)
    pred_mask[1, :] = True
    return x1, len1, langs1, x2, len2, langs2, y, pred_mask


@torch.no_grad()
def time_forward(encoder, decoder, langs1, langs2, len1, len2, pred_mask, spans, x1, x2, y):
    # encode source sentence
    total_start = time.perf_counter()
    start = time.perf_counter()
    for _ in range(REPEAT):
        enc1 = encoder(
            "fwd", x=x1, lengths=len1, langs=langs1, causal=False, spans=spans
        )
    time_enc = (time.perf_counter() - start) / REPEAT
    enc1 = enc1.transpose(0, 1)
    # decode target sentence
    start = time.perf_counter()
    for _ in range(REPEAT):
        dec2 = decoder(
            "fwd",
            x=x2,
            lengths=len2,
            langs=langs2,
            causal=True,
            src_enc=enc1,
            src_len=len1,
            spans=spans,
        )
    time_dec = (time.perf_counter() - start) / REPEAT
    # loss
    start = time.perf_counter()
    for _ in range(REPEAT):
        scores, loss = decoder(
            "predict", tensor=dec2, pred_mask=pred_mask, y=y, get_scores=True
        )
    time_pred = (time.perf_counter() - start) / REPEAT
    total = (time.perf_counter() - total_start) / REPEAT
    return {"time_enc": time_enc, "time_dec": time_dec, "time_pred": time_pred, "total_time": total}

def _reload_model(model_path, gpu=False):
    # reload model
    reloaded = torch.load(model_path, map_location="cpu")
    # change params of the reloaded model so that it will
    # relaod its own weights and not the MLM or DOBF pretrained model
    reloaded["params"]["reload_model"] = ",".join([model_path] * 2)
    reloaded["params"]["lgs_mapping"] = ""
    reloaded["params"]["reload_encoder_for_decoder"] = False
    reloaded_params = AttrDict(reloaded["params"])

    # build dictionary / update parameters
    dico = Dictionary(
        reloaded["dico_id2word"], reloaded["dico_word2id"], reloaded["dico_counts"]
    )

    # build model / reload weights (in the build_model method)
    return reloaded_params, dico, build_model(reloaded_params, dico, gpu)

In [3]:
batch_size = 16
params, dico, (encoder, decoder) = _reload_model(MODEL_PATH)
encoder = encoder[0]
decoder = decoder[0]
print("Encoder:", encoder.parameters)
print("Decoder:", decoder.parameters)
encoder.eval()
decoder.eval()
x1, len1, langs1, x2, len2, langs2, y, pred_mask = gen_random_batch(batch_size)
d = {}
d["n_enc"] = params.n_layers_encoder
d["n_dec"] = params.n_layers_decoder
d["emb_dim"] = params.emb_dim_encoder
d.update(time_forward(encoder, decoder, langs1, langs2, len1, len2, pred_mask, None, x1, x2, y))
torch.save((encoder.state_dict(), decoder.state_dict()), TMP_PATH)
d["size"] = os.path.getsize(TMP_PATH) / 1024 / 1024
d["enc_ps"] = sum([p.numel() for p in encoder.parameters() if p.requires_grad])
d["dec_ps"] = sum([p.numel() for p in decoder.parameters() if p.requires_grad])
d["sum_ps"] = d["enc_ps"] + d["dec_ps"]
d["emb_ps"] = encoder.embeddings.weight.numel()
params.reload_model = ''
del encoder, decoder
res = {'dobf': d}
for i, (n_layers_encoder, n_layers_decoder, emb_dim) in tqdm(
        list(enumerate(product(N_LAYERS_ENCODER, N_LAYERS_DECODER, EMB_DIM)))):
    d = {}
    d["n_enc"] = params.n_layers_encoder = n_layers_encoder
    d["n_dec"] = params.n_layers_decoder = n_layers_decoder
    d["emb_dim"] = params.emb_dim = params.emb_dim_encoder = params.emb_dim_decoder = emb_dim
    encoder, decoder = build_model(params, dico, gpu=False)
    encoder, decoder = encoder[0], decoder[0]
    d.update(time_forward(encoder, decoder, langs1, langs2, len1, len2, pred_mask, None, x1, x2, y))
    torch.save((encoder.state_dict(), decoder.state_dict()), TMP_PATH)
    d["size"] = os.path.getsize(TMP_PATH) / 1024 / 1024
    d["enc_ps"] = sum([p.numel() for p in encoder.parameters() if p.requires_grad])
    d["dec_ps"] = sum([p.numel() for p in decoder.parameters() if p.requires_grad])
    d["sum_ps"] = d["enc_ps"] + d["dec_ps"]
    d["emb_ps"] = encoder.embeddings.weight.numel()
    res[i] = d



Encoder: <bound method Module.parameters of TransformerModel(
  (position_embeddings): Embedding(2048, 1024)
  (lang_embeddings): Embedding(2, 1024)
  (embeddings): Embedding(64000, 1024, padding_idx=2)
  (layer_norm_emb): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (attentions): ModuleList(
    (0): MultiHeadAttention(
      (q_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (k_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (v_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (out_lin): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (1): MultiHeadAttention(
      (q_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (k_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (v_lin): Linear(in_features=1024, out_features=1024, bias=True)
      (out_lin): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (2): MultiHeadAttention(
      (q_lin): Linear(in_featur

100%|██████████| 27/27 [01:48<00:00,  4.03s/it]


In [4]:
res_df = pd.DataFrame.from_dict(res, orient='index')
display(res_df)

Unnamed: 0,n_enc,n_dec,emb_dim,time_enc,time_dec,time_pred,total_time,size,enc_ps,dec_ps,sum_ps,emb_ps
dobf,6,6,1024,0.54381,0.139082,0.012677,0.695569,1189.350755,143278592,168481280,311759872,65536000
0,2,2,256,0.014599,0.005066,0.004129,0.023795,143.588017,18552832,19080192,37633024,16384000
1,2,2,512,0.04218,0.012898,0.007396,0.062476,314.65833,40187392,42290688,82478080,32768000
2,2,2,1024,0.156485,0.047851,0.013473,0.21781,740.798955,92893696,101294592,194188288,65536000
3,2,4,256,0.015,0.009659,0.004276,0.028936,151.641198,18552832,21187072,39739904,16384000
4,2,4,512,0.042696,0.025528,0.007617,0.075843,346.74862,40187392,50698752,90886144,32768000
5,2,4,1024,0.157633,0.09494,0.013558,0.266133,868.963464,92893696,134887936,227781632,65536000
6,2,6,256,0.014363,0.013875,0.004173,0.032413,159.694383,18552832,23293952,41846784,16384000
7,2,6,512,0.042421,0.038838,0.007431,0.088692,378.838914,40187392,59106816,99294208,32768000
8,2,6,1024,0.155652,0.143203,0.013679,0.312535,997.127976,92893696,168481280,261374976,65536000


In [5]:
itables.show(res_df)

Unnamed: 0,n_enc,n_dec,emb_dim,time_enc,time_dec,time_pred,total_time,size,enc_ps,dec_ps,sum_ps,emb_ps
Loading... (need help?),,,,,,,,,,,,
