### What is in this notebook

- Running the code in `demo.ipynb` with the vulnerabilities dataset

### Loading the library

In [11]:
import onmt
from onmt.inputters.inputter import _load_vocab, _build_fields_vocab, get_fields, IterOnDevice
from onmt.inputters.corpus import ParallelCorpus
from onmt.inputters.dynamic_iterator import DynamicDatasetIter
from onmt.translate import GNMTGlobalScorer, Translator, TranslationBuilder
from onmt.utils.misc import set_random_seed

In [12]:
import yaml
import torch
import torch.nn as nn
from argparse import Namespace
from collections import defaultdict, Counter

In [13]:
# Turn on OpenNMT's logger so the library's INFO messages (vocab loading,
# training progress, ...) appear in the cell outputs below.
# init_logger() is the last expression of the cell, so its return value
# (the root logger) is what the notebook displays as this cell's output.
from onmt.utils.logging import init_logger, logger
init_logger()

<RootLogger root (INFO)>

### Build fields

In [14]:
# Locations of the source-side and target-side vocabulary files
# for the vulnerabilities dataset.
src_vocab_path, tgt_vocab_path = (
    "vul_data/data.vocab.src",
    "vul_data/data.vocab.tgt",
)

In [15]:
# One frequency Counter per side ('src' / 'tgt'), created lazily on first use.
counters = defaultdict(Counter)

# Read the source vocabulary file; the shared `counters` is passed in so the
# 'src' entry is available to the vocab-building step later on.
_src_vocab, _src_vocab_size = _load_vocab(src_vocab_path, 'src', counters)

# Same for the target side, under the 'tgt' entry.
_tgt_vocab, _tgt_vocab_size = _load_vocab(tgt_vocab_path, 'tgt', counters)

[2022-06-15 08:44:15,035 INFO] Loading src vocabulary from vul_data/data.vocab.src
[2022-06-15 08:44:15,083 INFO] Loaded src vocab has 36442 tokens.
[2022-06-15 08:44:15,093 INFO] Loading tgt vocabulary from vul_data/data.vocab.tgt
[2022-06-15 08:44:15,100 INFO] Loaded tgt vocab has 5924 tokens.


In [16]:
# Word-level features are not supported for now, so both sides use zero
# extra feature columns.
src_nfeats = 0
tgt_nfeats = 0

# Create the (still vocabulary-less) text fields for source and target.
fields = get_fields('text', src_nfeats, tgt_nfeats)

In [17]:
# --- vocabulary construction settings ---------------------------------------
# Source and target keep separate vocabularies.
share_vocab = False
vocab_size_multiple = 1

# Upper bound on the number of tokens kept per side.
# NOTE: the source vocab file holds 36442 tokens (see the loading log above),
# so the 30000 cap truncates it; the build log reports a final size of 30002.
src_vocab_size = 30000
tgt_vocab_size = 30000

# Minimum frequency a token needs in order to be kept.
src_words_min_frequency = 1
tgt_words_min_frequency = 1

# Attach vocabularies (built from `counters`) to the text fields.
vocab_fields = _build_fields_vocab(
    fields,
    counters,
    'text',
    share_vocab,
    vocab_size_multiple,
    src_vocab_size,
    src_words_min_frequency,
    tgt_vocab_size,
    tgt_words_min_frequency,
)

[2022-06-15 08:44:24,059 INFO]  * tgt vocab size: 5928.
[2022-06-15 08:44:24,085 INFO]  * src vocab size: 30002.


### Model and optimizer creation

In [18]:
# Pull the underlying text field out of each side's field wrapper.
src_text_field = vocab_fields["src"].base_field
tgt_text_field = vocab_fields['tgt'].base_field

# Vocabulary objects attached to those fields.
src_vocab, tgt_vocab = src_text_field.vocab, tgt_text_field.vocab

# Integer index of the padding token on each side; used below for the
# embedding layers (word_padding_idx) and for the loss (ignore_index).
src_padding = src_vocab.stoi[src_text_field.pad_token]
tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]

In [19]:
emb_size = 100
rnn_size = 500
seed = 1111

# Seed the random number generators *before* any layer is instantiated so the
# initial weights are the same on every Restart & Run All.
is_cuda = torch.cuda.is_available()
set_random_seed(seed, is_cuda)

# Specify the core model.

# Encoder: token embeddings feeding a single-layer bidirectional LSTM.
encoder_embeddings = onmt.modules.Embeddings(emb_size, len(src_vocab),
                                             word_padding_idx=src_padding)

encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,
                                   rnn_type="LSTM", bidirectional=True,
                                   embeddings=encoder_embeddings)

# Decoder: token embeddings feeding a single-layer input-feed LSTM decoder
# that is told the encoder is bidirectional.
decoder_embeddings = onmt.modules.Embeddings(emb_size, len(tgt_vocab),
                                             word_padding_idx=tgt_padding)
decoder = onmt.decoders.decoder.InputFeedRNNDecoder(
    hidden_size=rnn_size, num_layers=1, bidirectional_encoder=True,
    rnn_type="LSTM", embeddings=decoder_embeddings)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if is_cuda else "cpu"
model = onmt.models.model.NMTModel(encoder, decoder)
model.to(device)

# Specify the tgt word generator and loss computation module.
# The generator is attached after model.to(device), so it is moved to the
# device explicitly.
model.generator = nn.Sequential(
    nn.Linear(rnn_size, len(tgt_vocab)),
    nn.LogSoftmax(dim=-1)).to(device)

# Summed negative log-likelihood over the generator's log-probabilities;
# padding positions on the target side are ignored.
loss = onmt.utils.loss.NMTLossCompute(
    criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction="sum"),
    generator=model.generator)

In [29]:
# Learning rate shared by the torch optimizer and the OpenNMT wrapper.
lr = 1e-3

# Plain SGD over every model parameter (the generator is attached to the
# model, so its weights are included).
torch_optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# OpenNMT wrapper around the torch optimizer; gradient norm is capped at 2.
optim = onmt.utils.optimizers.Optimizer(
    torch_optimizer,
    learning_rate=lr,
    max_grad_norm=2,
)

### Create data iterator

In [31]:
# Parallel text files: one source file and one target file per split.
src_train, tgt_train = (
    "vul_data/random_fine_tune_train.src.txt",
    "vul_data/random_fine_tune_train.tgt.txt",
)
src_val, tgt_val = (
    "vul_data/random_fine_tune_valid.src.txt",
    "vul_data/random_fine_tune_valid.tgt.txt",
)

# Wrap each split in a named ParallelCorpus; the names ("corpus", "valid")
# are the keys used when building the iterators.
corpus = ParallelCorpus("corpus", src_train, tgt_train)
valid = ParallelCorpus("valid", src_val, tgt_val)

In [32]:
# Training iterator over the single "corpus" dataset (weight 1, no
# transforms). Batches are sized by token count: up to 4096 tokens each.
train_iter = DynamicDatasetIter(
    corpora={"corpus": corpus},
    corpora_info={"corpus": {"weight": 1}},
    transforms={},
    fields=vocab_fields,
    data_type="text",
    is_train=True,
    batch_type="tokens",
    batch_size=4096,
    batch_size_multiple=1,
)

In [33]:
# Put the batches on the same device as the model: GPU 0 when CUDA is
# available, otherwise the CPU (-1 for CPU, N for GPU N). Hard-coding 0 here
# would fail on a CPU-only machine, where the model itself falls back to "cpu".
device_id = 0 if torch.cuda.is_available() else -1
train_iter = iter(IterOnDevice(train_iter, device_id))

In [34]:
# Validation iterator over the "valid" dataset (weight 1, no transforms).
# Not in training mode; batches are sized by sentence count: 8 per batch.
valid_iter = DynamicDatasetIter(
    corpora={"valid": valid},
    corpora_info={"valid": {"weight": 1}},
    transforms={},
    fields=vocab_fields,
    data_type="text",
    is_train=False,
    batch_type="sents",
    batch_size=8,
    batch_size_multiple=1,
)

In [35]:
# Put validation batches on the same device as the model: GPU 0 when CUDA is
# available, otherwise the CPU (-1 for CPU, N for GPU N).
device_id = 0 if torch.cuda.is_available() else -1
valid_iter = IterOnDevice(valid_iter, device_id)

### Training

In [36]:
# Log training statistics every 50 steps; no TensorBoard writer is attached.
report_manager = onmt.utils.ReportMgr(
    report_every=50,
    start_time=None,
    tensorboard_writer=None,
)

# The same loss object is used for both training and validation.
trainer = onmt.Trainer(
    model=model,
    train_loss=loss,
    valid_loss=loss,
    optim=optim,
    report_manager=report_manager,
    dropout=[0.1],
)

# Train for 10000 steps, validating every 2000 steps.
# NOTE(review): the recorded log below starts at step 600 rather than 50,
# which suggests the optimizer carried step state from an earlier run of this
# cell; TODO confirm with a fresh Restart & Run All.
train_stats = trainer.train(
    train_iter=train_iter,
    train_steps=10000,
    valid_iter=valid_iter,
    valid_steps=2000,
)

# Last expression of the cell: displays the returned statistics object.
train_stats

[2022-06-15 09:16:53,188 INFO] Start training loop and validate every 2000 steps...
[2022-06-15 09:16:53,189 INFO] corpus's transforms: TransformPipe()
[2022-06-15 09:16:53,189 INFO] Weighted corpora loaded so far:
			* corpus: 1
[2022-06-15 09:16:55,016 INFO] Step 600/10000; acc:  18.71; ppl: 192.75; xent: 5.26; lr: 0.00100; 52783/6709 tok/s;      2 sec
[2022-06-15 09:16:58,372 INFO] Step 650/10000; acc:  19.85; ppl: 163.54; xent: 5.10; lr: 0.00100; 56003/4873 tok/s;      5 sec
[2022-06-15 09:17:01,921 INFO] Step 700/10000; acc:  19.13; ppl: 175.59; xent: 5.17; lr: 0.00100; 53477/4649 tok/s;      9 sec
[2022-06-15 09:17:04,183 INFO] Weighted corpora loaded so far:
			* corpus: 2
[2022-06-15 09:17:05,452 INFO] Step 750/10000; acc:  19.29; ppl: 172.78; xent: 5.15; lr: 0.00100; 52247/5676 tok/s;     12 sec
[2022-06-15 09:17:08,794 INFO] Step 800/10000; acc:  19.98; ppl: 173.20; xent: 5.15; lr: 0.00100; 56049/5426 tok/s;     16 sec
[2022-06-15 09:17:12,229 INFO] Step 850/10000; acc:  18.9

<onmt.utils.statistics.Statistics at 0x7f3e0c4708e0>