# Training TranX using the mined GPT3 samples

## Pytorch imports

In [13]:
from src.RawDataLoaders import CodeSearchNet_RawDataLoader, CoNaLa_RawDataLoader, Django_RawDataLoader
from src.models_and_transforms.text_transforms import Numericalise_Transform, Rename_Transform, Denumericalise_Transform
from src.models_and_transforms.complex_transforms import *

## TranX imports

In [1]:
import sys
sys.path.insert(0, "src/external_repos/tranX")

In [20]:
import argparse
import json
import os
import pickle
import sys
import nltk
nltk.download('punkt')

import numpy as np

from asdl.hypothesis import *
from asdl.lang.py3.py3_transition_system import python_ast_to_asdl_ast, asdl_ast_to_python_ast, Python3TransitionSystem
from asdl.transition_system import *
from components.action_info import get_action_infos
from components.dataset import Example
from components.vocab import Vocab, VocabEntry
from datasets.conala.evaluator import ConalaEvaluator
from datasets.conala.util import *
from datasets.conala.dataset import *

assert astor.__version__ == '0.7.1'

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


## Loading the GPT3 samples

In [22]:
conala_raw_dataloader = CoNaLa_RawDataLoader()
gpt3_mined_samples = conala_raw_dataloader.get_samples('mined_GPT3')
gpt3_mined_samples = Rename_Transform([('description','intent'),('code','snippet')])(gpt3_mined_samples)

In [26]:
gpt3_mined_samples[30]

{'intent': 'get the index of elements in a list `a` that are not equal to zero',
 'snippet': '[i for i, e in enumerate(a) if e != 0]',
 'question_id': 30}

In [24]:
for i, sample in enumerate(gpt3_mined_samples):
    sample['question_id'] = i

In [27]:
json.dump(gpt3_mined_samples, open('datasets/CoNaLa/conala-corpus/CoNaLa_gpt3.json', 'w'))

In [31]:
preprocess_conala_dataset(train_file='datasets/CoNaLa/conala-corpus/conala-train.json',
                              test_file='datasets/CoNaLa/conala-corpus/conala-test.json',
                              mined_data_file='datasets/CoNaLa/conala-corpus/CoNaLa_gpt3.json',
                              api_data_file='',
                              grammar_file='src/external_repos/tranX/asdl/lang/py3/py3_asdl.simplified.txt',
                              src_freq=3, code_freq=3,
                              vocab_size=20000,
                              num_mined=50000,
                              out_dir='datasets/CoNaLa/conala-corpus/')

process gold training data...


Skipped due to exceptions: 4


use mined data:  50000
from file:  datasets/CoNaLa/conala-corpus/CoNaLa_gpt3.json


Skipped due to exceptions: 74
52101 training instances
200 dev instances


process testing data...


Skipped due to exceptions: 0
500 testing instances


number of word types: 5939, number of word types w/ frequency > 1: 3314
number of singletons:  2625
number of words not included: 3382
total token count:  489746
unk token count:  4139
number of word types: 26510, number of word types w/ frequency > 1: 9208
number of singletons:  17302
number of words not included: 14578
total token count:  322997
unk token count:  18364
number of word types: 25687, number of word types w/ frequency > 1: 8180
number of singletons:  17507
number of words not included: 15147
total token count:  769603
unk token count:  18474


generated vocabulary Vocab(source Vocabulary[size=2561]words, primitive Vocabulary[size=5426]words, code Vocabulary[size=4857]words)
Max action len: 221
Avg action len: 21
Actions larger than 100: 14


## Training TranX

In [36]:
from exp import *

In [41]:
arg_parser = init_arg_parser()
args = arg_parser.parse_args('''--cuda \
--seed ${seed} \
--mode train \
--batch_size ${batch_size} \
--evaluator conala_evaluator \
--asdl_file asdl/lang/py3/py3_asdl.simplified.txt \
--transition_system python3 \
--train_file ${train_file} \
--dev_file ${dev_file} \
--vocab ${vocab} \
--lstm ${lstm} \
--no_parent_field_type_embed \
--no_parent_production_embed \
--hidden_size ${hidden_size} \
--embed_size ${embed_size} \
--action_embed_size ${action_embed_size} \
--field_embed_size ${field_embed_size} \
--type_embed_size ${type_embed_size} \
--dropout ${dropout} \
--patience 5 \
--max_num_trial 5 \
--glorot_init \
--lr ${lr} \
--lr_decay ${lr_decay} \
--lr_decay_after_epoch ${lr_decay_after_epoch} \
--max_epoch ${max_epoch} \
--beam_size ${beam_size} \
--log_every 50 \
--save_to saved_models/conala/${model_name} 2>&1 | tee logs/conala/${model_name}.log'''.split())

# seed the RNG
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
np.random.seed(int(args.seed * 13 / 7))

usage: ipykernel_launcher.py [-h] [--seed SEED] [--cuda]
                             [--lang {python,lambda_dcs,wikisql,prolog,python3}]
                             [--asdl_file ASDL_FILE] --mode
                             {train,test,interactive,train_paraphrase_identifier,train_reconstructor,rerank}
                             [--parser PARSER]
                             [--transition_system TRANSITION_SYSTEM]
                             [--evaluator EVALUATOR] [--lstm {lstm}]
                             [--embed_size EMBED_SIZE]
                             [--action_embed_size ACTION_EMBED_SIZE]
                             [--field_embed_size FIELD_EMBED_SIZE]
                             [--type_embed_size TYPE_EMBED_SIZE]
                             [--hidden_size HIDDEN_SIZE]
                             [--ptrnet_hidden_dim PTRNET_HIDDEN_DIM]
                             [--att_vec_size ATT_VEC_SIZE]
                             [--no_query_vec_to_action_map]
       

Traceback (most recent call last):
  File "/usr/lib/python3.6/argparse.py", line 2303, in _get_value
    result = type_func(arg_string)
ValueError: invalid literal for int() with base 10: '${seed}'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.6/argparse.py", line 1775, in parse_known_args
    namespace, args = self._parse_known_args(args, namespace)
  File "/usr/lib/python3.6/argparse.py", line 1981, in _parse_known_args
    start_index = consume_optional(start_index)
  File "/usr/lib/python3.6/argparse.py", line 1921, in consume_optional
    take_action(action, args, option_string)
  File "/usr/lib/python3.6/argparse.py", line 1833, in take_action
    argument_values = self._get_values(action, argument_strings)
  File "/usr/lib/python3.6/argparse.py", line 2274, in _get_values
    value = self._get_value(action, arg_string)
  File "/usr/lib/python3.6/argparse.py", line 2316, in _get_value
    raise Ar

TypeError: object of type 'NoneType' has no len()