In [11]:
import onmt
from onmt.inputters.inputter import _load_vocab, _build_fields_vocab, get_fields, IterOnDevice
from onmt.inputters.corpus import ParallelCorpus
from onmt.inputters.dynamic_iterator import DynamicDatasetIter
from onmt.translate import GNMTGlobalScorer, Translator, TranslationBuilder
from onmt.utils.misc import set_random_seed

In [12]:
import yaml
import torch
import torch.nn as nn
from argparse import Namespace
from collections import defaultdict, Counter

In [13]:
# Enable OpenNMT logging for the rest of the notebook. `init_logger()` is the
# last expression of the cell, so its return value (the root logger) is echoed
# as the cell output.
from onmt.utils.logging import init_logger, logger
init_logger()

<RootLogger root (INFO)>

In [14]:
# Locations of the pre-built source / target vocabulary files
# (paths are relative to the notebook's working directory).
src_vocab_path, tgt_vocab_path = (
    "vul_data/data.vocab.src",
    "vul_data/data.vocab.tgt",
)

In [15]:
# One Counter per side ('src' / 'tgt'); `_load_vocab` is handed this mapping
# for both sides so the token frequencies end up in a single structure.
counters = defaultdict(Counter)

# Source side.
_src_vocab, _src_vocab_size = _load_vocab(src_vocab_path, 'src', counters)

# Target side.
_tgt_vocab, _tgt_vocab_size = _load_vocab(tgt_vocab_path, 'tgt', counters)

[2022-06-15 08:44:15,035 INFO] Loading src vocabulary from vul_data/data.vocab.src
[2022-06-15 08:44:15,083 INFO] Loaded src vocab has 36442 tokens.
[2022-06-15 08:44:15,093 INFO] Loading tgt vocabulary from vul_data/data.vocab.tgt
[2022-06-15 08:44:15,100 INFO] Loaded tgt vocab has 5924 tokens.


In [16]:
# Word-level features are not supported for now, so both sides use zero
# extra feature columns.
src_nfeats = 0
tgt_nfeats = 0

# Create the (still vocabulary-less) text fields for source and target.
fields = get_fields('text', src_nfeats, tgt_nfeats)

In [17]:
# Options controlling how the field vocabularies are built.
share_vocab = False
vocab_size_multiple = 1
src_vocab_size = tgt_vocab_size = 30000
src_words_min_frequency = tgt_words_min_frequency = 1

# Attach vocabularies to the fields using the counters filled while loading
# the vocab files. NOTE(review): the loaded source vocab has 36442 tokens but
# the log below reports a source vocab size of 30002, so the 30000 cap appears
# to drop part of the source vocabulary -- confirm this is intended.
vocab_fields = _build_fields_vocab(
    fields,
    counters,
    'text',
    share_vocab,
    vocab_size_multiple,
    src_vocab_size,
    src_words_min_frequency,
    tgt_vocab_size,
    tgt_words_min_frequency,
)

[2022-06-15 08:44:24,059 INFO]  * tgt vocab size: 5928.
[2022-06-15 08:44:24,085 INFO]  * src vocab size: 30002.


In [18]:
# Pull the base text field of each side out of the built fields.
src_text_field, tgt_text_field = (
    vocab_fields[side].base_field for side in ("src", "tgt")
)

# Vocabulary objects attached to those fields.
src_vocab, tgt_vocab = src_text_field.vocab, tgt_text_field.vocab

# Integer index of the padding token on each side; used further down for the
# embedding padding index and for the loss `ignore_index`.
src_padding = src_vocab.stoi[src_text_field.pad_token]
tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]

In [19]:
emb_size = 100
rnn_size = 500
seed = 1111

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the random number generators *before* any parameter is created.
# Without this the embedding / LSTM / generator weights are initialised
# differently on every kernel restart, so results cannot be reproduced with
# "Restart & Run All".
set_random_seed(seed, device == "cuda")

# Specify the core model.

# Source-side embeddings; the padding index is taken from the source vocab.
encoder_embeddings = onmt.modules.Embeddings(emb_size, len(src_vocab),
                                             word_padding_idx=src_padding)

# Single-layer bidirectional LSTM encoder.
encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,
                                   rnn_type="LSTM", bidirectional=True,
                                   embeddings=encoder_embeddings)

# Target-side embeddings; the padding index is taken from the target vocab.
decoder_embeddings = onmt.modules.Embeddings(emb_size, len(tgt_vocab),
                                             word_padding_idx=tgt_padding)

# Single-layer input-feed LSTM decoder, told that the encoder is bidirectional.
decoder = onmt.decoders.decoder.InputFeedRNNDecoder(
    hidden_size=rnn_size, num_layers=1, bidirectional_encoder=True,
    rnn_type="LSTM", embeddings=decoder_embeddings)

model = onmt.models.model.NMTModel(encoder, decoder)
model.to(device)

# Specify the tgt word generator and loss computation module.
# The generator projects decoder states (rnn_size) onto the target vocabulary
# and emits log-probabilities. It is attached after `model.to(device)`, so it
# is moved to the device explicitly.
model.generator = nn.Sequential(
    nn.Linear(rnn_size, len(tgt_vocab)),
    nn.LogSoftmax(dim=-1)).to(device)

# Summed negative log-likelihood; target padding positions are ignored.
loss = onmt.utils.loss.NMTLossCompute(
    criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction="sum"),
    generator=model.generator)

In [20]:
# Plain SGD over every model parameter, wrapped in OpenNMT's Optimizer,
# which is also given the learning rate and a max gradient norm of 2.
lr = 1
torch_optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr,
)
optim = onmt.utils.optimizers.Optimizer(
    torch_optimizer,
    learning_rate=lr,
    max_grad_norm=2,
)

In [21]:
# Parallel text files for fine-tuning (training split).
src_train, tgt_train = (
    "vul_data/random_fine_tune_train.src.txt",
    "vul_data/random_fine_tune_train.tgt.txt",
)
# Parallel text files for validation.
src_val, tgt_val = (
    "vul_data/random_fine_tune_valid.src.txt",
    "vul_data/random_fine_tune_valid.tgt.txt",
)

# Wrap each source/target file pair in a named ParallelCorpus.
corpus = ParallelCorpus("corpus", src_train, tgt_train)
valid = ParallelCorpus("valid", src_val, tgt_val)