TransCoder

PyTorch implementation of TransCoder from the paper Unsupervised Translation of Programming Languages.

Release

Pre-trained Models

We provide TransCoder models. We used the validation set to select the best checkpoint for each language pair, and chose the model used to compute the test scores.

The models used in the original TransCoder paper are the following (directions selected using the validation set):

Note: to reproduce the outputs of these models exactly, you need to change the constant LAYER_NORM_EPSILON from 1e-5 to 1e-12. If you don't, the results will be identical in more than 99% of cases and only slightly different otherwise.

Update (2021): a better model for translating between Java and Python, pretrained with our new DOBF objective:

Its computational accuracy (CA@1) scores are 39.5% for Python -> Java (44.7% with beam size 10) and 49.2% for Java -> Python (52.5% with beam size 10).

You can use these models to translate functions with this script:

python -m codegen_sources.model.translate --src_lang python --tgt_lang java --model_path <model_path> --beam_size 1 < my_python_file_to_translate.py
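
For example, a hypothetical my_python_file_to_translate.py could contain a single standalone function like the one below; the script then outputs its Java translation.

def count_set_bits(n):
    # Count the set bits in the binary representation of n
    count = 0
    while n:
        count += n & 1
        n >>= 1
    return count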

Parallel Validation and Test Sets

We release our parallel validation and test datasets. They consist of parallel functions in C++, Java, and Python. You can download the raw data and the binarized version to be used for model evaluation (transcoder_test_set.zip). Note that this test set is released under the Creative Commons Attribution-ShareAlike 2.0 license. See https://creativecommons.org/licenses/by-sa/2.0/ for more information.

The format of each line in each file is <FUNCTION_ID> | <function>. The functions are tokenized; you can detokenize them with the script preprocessing/detokenize.py. You can extract the function id and use it to find the corresponding test script in data/evaluation/geeks_for_geeks_successful_test_scripts/<language>, if it exists. For instance, for the line COUNT_SET_BITS_IN_AN_INTEGER_3 | <function> in the file test.cpp.shuf.valid.tok, the corresponding test script can be found in data/evaluation/geeks_for_geeks_successful_test_scripts/cpp/COUNT_SET_BITS_IN_AN_INTEGER_3.cpp. If the script is missing, it means there was an issue with our automatically created tests for the corresponding function.
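
As a minimal sketch of how to use this format (the helper names and the extension mapping are hypothetical; check the actual file names in the release):

from pathlib import Path

def parse_line(line):
    # The function id and the tokenized function are separated by " | "
    func_id, function = line.split(" | ", 1)
    return func_id.strip(), function.strip()

def find_test_script(data_root, language, func_id):
    # Assumed extension mapping for the test scripts
    ext = {"cpp": "cpp", "java": "java", "python": "py"}[language]
    path = (Path(data_root) / "evaluation" / "geeks_for_geeks_successful_test_scripts"
            / language / f"{func_id}.{ext}")
    return path if path.exists() else None  # None: no automatic test was created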

The code generated by your model can be tested by injecting it where the TO_FILL comment is in the test script.
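
For instance, a rough sketch of the injection step (assuming the placeholder appears as a comment line containing TO_FILL; check the exact marker in the released scripts):

def inject_generated_code(test_script, generated_function):
    # Replace the whole placeholder line (e.g. "//TO_FILL" in C++ or Java,
    # "#TO_FILL" in Python) with the generated function
    lines = [generated_function if "TO_FILL" in line else line
             for line in test_script.splitlines()]
    return "\n".join(lines)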

Training

Dataset

Overview

For training on NGPU GPUs, you need the following data. To pretrain a model with MLM:

  • training data (monolingual): source code in each language, e.g. train.python.[0..NGPU-1].pth (the data is split across GPUs)
  • test / valid data (monolingual): source code in each language to test the perplexity of the model, e.g. test.python.pth / valid.python.pth

Data you need to train with AE and BT:

  • training data (monolingual functions): standalone functions in each language, e.g. train.python_sa.[0..NGPU-1].pth
  • test / valid data (monolingual functions + GfG parallel functions):
    • monolingual functions to test the perplexity of the model, e.g. test.python_sa.pth / valid.python_sa.pth
    • parallel functions to test the translation model (with BLEU and computational accuracy), e.g. test.python_sa-cpp_sa.pth / valid.python_sa-cpp_sa.pth

All of these files should be contained in the same folder, whose path is given as the --data_path argument.

In our case, we use NGPU=8.
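
For illustration, a hypothetical helper that verifies the MLM files described above are present in --data_path:

from pathlib import Path

def check_mlm_files(data_path, langs=("cpp", "java", "python"), ngpu=8):
    # Return the expected MLM file names that are missing from data_path
    data_path = Path(data_path)
    expected = []
    for lang in langs:
        expected += [f"train.{lang}.{i}.pth" for i in range(ngpu)]
        expected += [f"valid.{lang}.pth", f"test.{lang}.pth"]
    return [name for name in expected if not (data_path / name).exists()]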

Get Training Data

First, get the raw data from Google BigQuery (see the repository documentation for instructions).

Then run the following command to get the monolingual data for MLM:

python -m codegen_sources.preprocessing.preprocess \
    <DATASET_PATH> \
    --langs cpp java python \
    --mode=monolingual \
    --local=True \
    --bpe_mode=fast \
    --train_splits=NGPU

where:

  • <DATASET_PATH>: folder containing the raw data, i.e. json.gz files
  • --langs: languages to preprocess
  • --mode: dataset mode
  • --local: run on your local machine if True; if False, run on a cluster (requires a submitit setup)
  • --train_splits: number of splits for the training data - should correspond to the number of GPUs you have

To get the monolingual functions data for DAE and BT, change the mode:

--mode=monolingual_functions

Note that if your data is small enough to fit on a single GPU, then NGPU=1 and loading this single split on all GPUs is the normal thing to do. Note also that if you run your training on multiple machines, each with NGPU GPUs, splitting into NGPU parts is fine as well; you will just have to pass --split_data_accross_gpu local in your training parameters. In our case, we used 4 machines with 8 GPUs each, so we set NGPU=8 and --split_data_accross_gpu local.

Get Parallel Test and Validation Data

Simply download the binarized data transcoder_test_set.zip and add it to the same folder as the data you preprocessed above.

Train

Train an MLM model:

python codegen_sources/model/train.py 

## main parameters
--exp_name mlm \
--dump_path '<YOUR_DUMP_PATH>' \ 

## data / objectives
--data_path '<DATA_PATH>' \ 
--split_data_accross_gpu local \
--mlm_steps 'cpp,java,python' \
--add_eof_to_stream true \
--word_mask_keep_rand '0.8,0.1,0.1' \
--word_pred '0.15' \


## model
--encoder_only true \
--n_layers 6  \
--emb_dim 1024  \
--n_heads 8  \
--lgs 'cpp-java-python' \
--max_vocab 64000 \
--gelu_activation false \
--roberta_mode false \
--max_len 512 \

## optimization
--amp 2  \
--fp16 true  \
--batch_size 32 \
--bptt 512 \
--epoch_size 100000 \
--max_epoch 100000 \
--optimizer 'adam_inverse_sqrt,warmup_updates=10000,lr=0.0001,weight_decay=0.01' \
--save_periodic 0 \
--validation_metrics _valid_mlm_ppl \
--stopping_criterion '_valid_mlm_ppl,10' 
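
For intuition about the masking parameters: --word_pred 0.15 selects 15% of tokens as prediction targets, and --word_mask_keep_rand '0.8,0.1,0.1' splits those targets between masking, keeping, and random replacement. A sketch of this standard BERT-style corruption (not necessarily the repository's exact implementation):

import random

def mlm_corrupt(tokens, vocab, word_pred=0.15, mask_keep_rand=(0.8, 0.1, 0.1)):
    # Select ~word_pred of the tokens as targets, then corrupt them:
    # 80% become <MASK>, 10% are kept unchanged, 10% get a random token
    corrupted, targets = list(tokens), []
    for i, tok in enumerate(tokens):
        if random.random() < word_pred:
            targets.append((i, tok))
            r = random.random()
            if r < mask_keep_rand[0]:
                corrupted[i] = "<MASK>"
            elif r >= mask_keep_rand[0] + mask_keep_rand[1]:
                corrupted[i] = random.choice(vocab)
    return corrupted, targets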

To train TransCoder from a pretrained model (MLM or DOBF; for DOBF, see the corresponding documentation):

python codegen_sources/model/train.py   

## main parameters
--exp_name transcoder \
--dump_path '<YOUR_DUMP_PATH>' \ 

## data / objectives
--data_path '<DATA_PATH>' \
--split_data_accross_gpu local \
--bt_steps 'python_sa-java_sa-python_sa,java_sa-python_sa-java_sa,python_sa-cpp_sa-python_sa,java_sa-cpp_sa-java_sa,cpp_sa-python_sa-cpp_sa,cpp_sa-java_sa-cpp_sa' \
--ae_steps 'python_sa,java_sa'  \
--lambda_ae '0:1,30000:0.1,100000:0'  \ 
--word_shuffle 3  \
--word_dropout '0.1' \ 
--word_blank '0.3'  \

## model  
--encoder_only False \
--n_layers 0  \
--n_layers_encoder 6  \
--n_layers_decoder 6 \
--emb_dim 1024  \
--n_heads 8  \
--lgs 'cpp_sa-java_sa-python_sa'  \
--max_vocab 64000 \
--gelu_activation false \
--roberta_mode false   \ 
--max_len 512 \

## model reloading
--reload_model '<PATH_TO_MLM_MODEL>,<PATH_TO_MLM_MODEL>'  \
--reload_encoder_for_decoder true \
--lgs_mapping 'cpp_sa:cpp,java_sa:java,python_sa:python'  \

## optimization
--amp 2  \
--fp16 true  \
--tokens_per_batch 3000  \
--group_by_size true \
--max_batch_size 128 \
--epoch_size 50000  \
--max_epoch 10000000  \
--optimizer 'adam_inverse_sqrt,warmup_updates=10000,lr=0.0001,weight_decay=0.01' \
--eval_bleu true \
--eval_computation true \
--has_sentence_ids "valid|para,test|para" \
--generate_hypothesis true \
--save_periodic 1 \
--validation_metrics 'valid_python_sa-java_sa_mt_comp_acc'  
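
Note that --lambda_ae '0:1,30000:0.1,100000:0' anneals the auto-encoding loss coefficient during training: it starts at 1, decays to 0.1 by step 30000, and reaches 0 at step 100000, after which training relies on back-translation alone. A minimal sketch of such a schedule, assuming piecewise-linear interpolation between the listed (step, value) pairs, as in XLM-style codebases:

def lambda_ae(step, points=((0, 1.0), (30000, 0.1), (100000, 0.0))):
    # Piecewise-linear interpolation between (step, value) breakpoints;
    # constant after the last breakpoint
    for (s0, v0), (s1, v1) in zip(points, points[1:]):
        if step <= s1:
            return v0 + (v1 - v0) * (step - s0) / (s1 - s0)
    return points[-1][1]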

Evaluate

Evaluation is done after each training epoch. However, if you want to evaluate a model without training it, run the same command as the training command with these additional flags:

--reload_model '<TRANSCODER_MODEL_PATH>,<TRANSCODER_MODEL_PATH>' \
--eval_only True \
--reload_encoder_for_decoder false

For instance:

MODEL=<PATH_TO_MODEL>
python codegen_sources/model/train.py \
--exp_name transcoder_eval \
--dump_path '<DUMP_PATH>' \
--data_path '<DATASET_PATH>' \
--bt_steps 'python_sa-java_sa-python_sa,java_sa-python_sa-java_sa,python_sa-cpp_sa-python_sa,java_sa-cpp_sa-java_sa,cpp_sa-python_sa-cpp_sa,cpp_sa-java_sa-cpp_sa'    \
--encoder_only False \
--n_layers 0  \
--n_layers_encoder 6  \
--n_layers_decoder 6 \
--emb_dim 1024  \
--n_heads 8  \
--lgs 'cpp_sa-java_sa-python_sa'  \
--max_vocab 64000 \
--gelu_activation false \
--roberta_mode false  \
--amp 2  \
--fp16 true  \
--tokens_per_batch 3000  \
--max_batch_size 128 \
--eval_bleu true \
--eval_computation true \
--has_sentence_ids "valid|para,test|para" \
--generate_hypothesis true \
--save_periodic 1 \
--reload_model "$MODEL,$MODEL" \
--reload_encoder_for_decoder false \
--eval_only true \
--n_sentences_eval 1500

You do not need to have the training data in your data_path, only the validation and test sets.

Results

Our CA@1 results with the models we provide, for beam sizes 1 (i.e. greedy decoding) and 10 (using length_penalty = 0.5 for beam size 10). The CA@N metrics may vary slightly due to timeouts. Our results for TransCoder are slightly different from those of the original paper due to code and library updates. The model trained from DOBF was selected based on the validation score for Java -> Python.

Model / Task            C++ -> Java     C++ -> Python   Java -> C++     Java -> Python  Python -> C++   Python -> Java
                        k=1     k=10    k=1     k=10    k=1     k=10    k=1     k=10    k=1     k=10    k=1     k=10
TransCoder_model_1      62.99   64.86   44.71   47.08   80.04   78.76   46.87   48.81   31.55   33.69   33.89   35.55
TransCoder_model_2      62.37   62.99   42.33   43.41   77.68   78.54   46.87   47.73   29.61   32.4    32.64   35.97
TransCoder from DOBF    -       -       -       -       -       -       49.24   52.7    -       -       39.5    45.32
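
CA@N (computational accuracy) counts a translation as correct when at least one of the top-N hypotheses passes all unit tests for the function. A minimal sketch of the metric (the input format here is hypothetical):

def ca_at_n(results, n=1):
    # results: one list of booleans per function, ordered by hypothesis rank;
    # True means that hypothesis passed all unit tests
    return 100.0 * sum(any(hyps[:n]) for hyps in results) / len(results)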

Train on Multiple GPUs

To train a model on multiple GPUs, replace python codegen_sources/model/train.py with:

export NGPU=2; python -m torch.distributed.launch --nproc_per_node=$NGPU codegen_sources/model/train.py

References

This repository contains code that was used to train and evaluate the TransCoder model. Our paper was published at NeurIPS 2020:

[1] B. Roziere*, M.A. Lachaux*, L. Chanussot, G. Lample. Unsupervised Translation of Programming Languages.

* Equal Contribution

@article{roziere2020unsupervised,
  title={Unsupervised translation of programming languages},
  author={Roziere, Baptiste and Lachaux, Marie-Anne and Chanussot, Lowik and Lample, Guillaume},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}