In [1]:
# Automatically reload imported modules (e.g. the local dataset/ and model/ packages)
# before each cell runs, so edits to their source are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
from dataset.idDataModule import IdDataModule
from model.idTransformerModel import IdTransformerModel
import hydra
from omegaconf import OmegaConf
from tqdm.auto import trange, tqdm
import warnings
import torch
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore")

In [3]:
# Compose the run configuration from the base config plus the transformer-encoder
# model sub-config and the java-small dataset sub-config.
conf = OmegaConf.load('configs/config.yaml')
conf['model'] = OmegaConf.load('configs/model/transformer_encoder.yaml')
conf['dataset'] = OmegaConf.load('configs/dataset/java-small.yaml')

# Vocabulary frequency thresholds (presumably minimum token counts to keep a token);
# assigned after loading, so they override whatever java-small.yaml sets.
USAGE_VOCAB_MIN_FREQ = 3
TARGET_VOCAB_MIN_FREQ = 2
conf.dataset.usage_vocab_min_freq = USAGE_VOCAB_MIN_FREQ
conf.dataset.target_vocab_min_freq = TARGET_VOCAB_MIN_FREQ


def build_datamodule(cfg, stage='test'):
    """Create an ``IdDataModule`` from ``cfg`` and run ``setup(stage)`` on it.

    Args:
        cfg: the composed OmegaConf configuration.
        stage: stage name forwarded to ``IdDataModule.setup`` (default ``'test'``).

    Returns:
        The prepared ``IdDataModule``.
    """
    datamodule = IdDataModule(cfg)
    datamodule.setup(stage)
    return datamodule


# The factory has its own name so it stays callable after `dm` is bound to the
# datamodule instance used by the rest of the notebook.
dm = build_datamodule(conf)

Building train dataset...
Building val dataset...
Building test dataset...
Loading usage counter from C:\Users\Igor\IdeaProjects\ide-plugin\dataset\java-small\usage_vocab.json
Loading target counter from C:\Users\Igor\IdeaProjects\ide-plugin\dataset\java-small\target_vocab.json

Usage vocabulary size is 23418
Target vocabulary size is 14377


In [4]:
# Restore the trained transformer-encoder onto the CPU. The extra keyword arguments
# (the datamodule and the model sub-config) are, presumably, handed on to the model
# constructor by load_from_checkpoint — confirm against IdTransformerModel.__init__.
checkpoint_path = r"checkpoints/11.03-transformer_encoder-java-small-epoch=02-val_accuracy=0.19.ckpt"
model = IdTransformerModel.load_from_checkpoint(
    checkpoint_path,
    map_location="cpu",
    dm=dm,
    cfg=conf.model,
)

In [5]:
# Target-side vocabulary of the dataset; itos() uses its `itos` list to map ids to token strings.
vocab = dm.dataset.target_vocabulary
def itos(ids, vocabulary=None):
    """Map token ids to their string tokens.

    Args:
        ids: a tensor/array of token ids of any shape (it is flattened first, so a
            ``(seq_len, 1)`` target column works), or a plain sequence of ints.
        vocabulary: object exposing an ``itos`` list; defaults to the notebook's
            target ``vocab``.

    Returns:
        List of token strings, in the flattened order of ``ids``.
    """
    if vocabulary is None:
        vocabulary = vocab
    # reshape() returns a new tensor/array; its result must be kept for the flatten
    # to take effect. tolist() then yields plain Python ints for list indexing.
    if hasattr(ids, "reshape"):
        ids = ids.reshape(-1).tolist()
    return [vocabulary.itos[i] for i in ids]

In [6]:
# Put the model in inference mode (eval() is nn.Module's alias for train(False));
# binding the returned module to `_` keeps the cell from echoing its repr.
_ = model.eval()

In [7]:
# Iterate over test examples one at a time in shuffled order. A dedicated, seeded
# generator drives the shuffling so the sequence of examples returned by
# next(iterator) is reproducible after a kernel restart. Using a local generator
# leaves the global torch RNG untouched.
SEED = 42
sample_generator = torch.Generator().manual_seed(SEED)
iterator = iter(DataLoader(dm.dataset.test,
                           batch_size=1,
                           shuffle=True,
                           collate_fn=dm.dataset.collate_fn,
                           generator=sample_generator))

In [295]:
# Pull the next shuffled test example (a batch of size 1) and display it.
ex = next(iterator)
ex


[torchtext.legacy.data.batch.Batch of size 1]
	[.target]:('[torch.LongTensor of size 6x1]', '[torch.LongTensor of size 1]')
	[.usages]:[torch.LongTensor of size 1x11x41]

In [296]:
# Ground-truth name for this example. ex.target is a pair of tensors; [0] is the
# (seq_len x 1) token-id column. The second element is presumably the length — TODO confirm.
itos(ex.target[0])

['<s>', 'rpc', 'server', '</s>', '<pad>', '<pad>']

In [305]:
# Decode the current example with beam search and list the top-k hypotheses.
# sum_p accumulates their `.p` values (presumably hypothesis probabilities), showing
# how much mass the listed candidates cover.
sum_p = 0
predictions = model(ex.usages[0], batch_size=10, beam_width=40, topk=10)
for pred in predictions:
    hypothesis = pred[1]
    sum_p += hypothesis.p
    print(f"{itos(hypothesis.token_ids)}: {hypothesis.p}")
print(sum_p)

['<s>', 'server', '</s>']: 0.043029683704450546
['<s>', 'port', '</s>']: 0.030609993522448243
['<s>', 'coordinator', '</s>']: 0.020024495759689228
['<s>', 'connector', '</s>']: 0.017832930840474857
['<s>', 'discovery', 'node', '</s>']: 0.017712834854850104
['<s>', 'host', '</s>']: 0.0168719493946595
['<s>', 'node', '</s>']: 0.014936942953890624
['<s>', 'server', 'node', '</s>']: 0.014028782476434826
['<s>', 'hostname', '</s>']: 0.01330560722522888
['<s>', 'server', 'port', '</s>']: 0.009655817536823354
0.19800903826895017


In [298]:
# Time one decode with batch_size=None and beam_width=20.
# NOTE(review): the following timing cells use beam_width=40, so this run is not directly comparable.
%timeit model(ex.usages[0], batch_size=None, beam_width=20)

255 ms ± 25.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [299]:
# Time one decode with beam_width=40 and batch_size=10 (the settings used for the prediction listing).
%timeit model(ex.usages[0], batch_size=10, beam_width=40)

98 ms ± 2.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [300]:
# Same decode with batch_size=20.
# NOTE(review): in the saved run this was slower than batch_size=30, so timings look noisy;
# consider more loops/repeats before drawing conclusions.
%timeit model(ex.usages[0], batch_size=20, beam_width=40)

151 ms ± 1.21 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [301]:
# Same decode with batch_size=30.
%timeit model(ex.usages[0], batch_size=30, beam_width=40)

119 ms ± 1.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [302]:
# Same decode with batch_size=40 (equal to beam_width).
%timeit model(ex.usages[0], batch_size=40, beam_width=40)

144 ms ± 4.89 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
