In [163]:
## Build the SentencePiece model
import sentencepiece as spm

# NOTE(review): none of these five settings is passed to the Train() call
# below, which uses its own literal values -- confirm which set is intended.
corpus, model_prefix, model_type = "corpus.txt", "unigram_model", "unigram"
vocab_size, character_coverage = 15, 0.9995

# Train one unigram model over both raw corpora (English and Chinese).
# NOTE(review): vocab_size=22 is far smaller than the 8000-entry dictionaries
# logged later in this notebook -- confirm the intended value.
spm.SentencePieceTrainer.Train(
    character_coverage=1,
    vocab_size=22,
    model_type='unigram',
    model_prefix='colacc',
    input='train.raw.en,train.raw.zh',
)


In [167]:
## Encode the raw English text with the trained SentencePiece model
spm_model = spm.SentencePieceProcessor()
spm_model.load("colacc.model")

# Both files are opened as UTF-8 explicitly: SentencePiece pieces carry the
# non-ASCII word-boundary marker (U+2581), so relying on the platform default
# encoding can raise UnicodeEncodeError or corrupt the output.
# The input is opened first so a missing 'train.raw.en' does not leave behind
# a freshly truncated, empty 'train.en'.
with open('train.raw.en', 'r', encoding='utf-8') as in_f, \
        open('train.en', 'w', encoding='utf-8') as out_f:
    for line in in_f:
        line = line.strip()
        # out_type=str yields the piece strings rather than integer ids.
        tok = spm_model.encode(line, out_type=str)
        # One space-separated tokenized sentence per output line.
        print(' '.join(tok), file=out_f)

In [171]:
# binarize
import pathlib as Path
binpath = './data-bin'

!python -m fairseq_cli.preprocess \
    --source-lang {src_lang}\
    --target-lang {tgt_lang}\
    --trainpref {prefix/'train'}\
    --destdir {binpath}\
    --joined-dictionary\
    --workers 2

2023-06-29 17:46:25 | INFO | fairseq_cli.preprocess | Namespace(no_progress_bar=False, log_interval=100, log_format=None, log_file=None, aim_repo=None, aim_run_hash=None, tensorboard_logdir=None, wandb_project=None, azureml_logging=False, seed=1, cpu=False, tpu=False, bf16=False, memory_efficient_bf16=False, fp16=False, memory_efficient_fp16=False, fp16_no_flatten_grads=False, fp16_init_scale=128, fp16_scale_window=None, fp16_scale_tolerance=0.0, on_cpu_convert_precision=False, min_loss_scale=0.0001, threshold_loss_scale=None, amp=False, amp_batch_retries=2, amp_init_scale=128, amp_scale_window=None, user_dir=None, empty_cache_freq=0, all_gather_list_size=16384, model_parallel_size=1, quantization_config_path=None, profile=False, reset_logging=False, suppress_crashes=False, use_plasma_view=False, plasma_path='/tmp/plasma', criterion='cross_entropy', tokenizer=None, bpe=None, optimizer=None, lr_scheduler='fixed', scoring='bleu', task='translation', source_lang='en', target_lang='zh', tr

In [172]:
## Build the fairseq translation task

from fairseq.tasks.translation import TranslationConfig, TranslationTask


## Task configuration: en -> zh over the binarized data in ./data-bin
task_cfg = TranslationConfig(
    data="./data-bin",
    train_subset="train",
    source_lang="en",
    target_lang="zh",
    dataset_impl="mmap",
    required_seq_len_multiple=8,
    upsample_primary=1,
)
# setup_task builds the task object from the config above.
task = TranslationTask.setup_task(task_cfg)


2023-06-29 17:55:45 | INFO | fairseq.tasks.translation | [en] dictionary: 8000 types
2023-06-29 17:55:45 | INFO | fairseq.tasks.translation | [zh] dictionary: 8000 types


In [175]:
# Load the training split (combine=True also picks up extra shards such as
# back-translation data, if any exist).
task.load_dataset(split="train", epoch=1, combine=True)

# Show the raw tensors of example #2, then both sides decoded back to text.
# NOTE(review): `pprint` and `config` are not defined in the visible cells --
# they must already exist in the kernel.
sample = task.dataset("train")[2]
pprint.pprint(sample)
for label, dictionary, key in (
    ("Source: ", task.source_dictionary, 'source'),
    ("Target: ", task.target_dictionary, 'target'),
):
    decoded = dictionary.string(sample[key], config.post_process)
    pprint.pprint(label + decoded)

2023-06-29 17:59:53 | INFO | fairseq.data.data_utils | loaded 390,112 examples from: ./data-bin/train.en-zh.en
2023-06-29 17:59:53 | INFO | fairseq.data.data_utils | loaded 390,112 examples from: ./data-bin/train.en-zh.zh
2023-06-29 17:59:53 | INFO | fairseq.tasks.translation | ./data-bin train en-zh 390112 examples


{'id': 2,
 'source': tensor([  19,   59,  306,    5, 1105,  491,   32, 1228,  156,   31,  138,  161,
          43,  389,    4,   11,   19,  282,   12,  657,   88,   16,   25,   40,
           9,  418,    5,  583,  123,    5,  283,  213,    6,   80,   64,   19,
         179,   12,  349,    9,  231, 1960,    7,    2]),
 'target': tensor([ 630, 3318, 3079,   62, 2568, 1431,  185,   36,  832,  175,  846, 1631,
           8, 2005, 1231,    4,  945,  449, 1118,  678,  112,   41, 1408, 2311,
           8,  142, 2240,   10,    2])}
('Source: i have been blown away by this conference , and i want to thank all '
 'of you for the many nice comments about what i had to say the other night .')
'Target: 這個研討會給我留下了極為深刻的印象 , 我想感謝大家對我之前演講的好評 。'


In [177]:
def load_data_iterator(task, split, epoch=1, max_tokens=20, num_workers=1, cached=True):
    """Build a batch iterator over one already-loaded split of ``task``.

    Relies on the module-level names ``utils`` and ``seed`` being defined
    before the function is called.

    Args:
        task: object providing ``dataset(split)``, ``max_positions()`` and
            ``get_batch_iterator(...)``.
        split: name of the split to iterate over, e.g. ``"train"``.
        epoch: epoch number forwarded to ``get_batch_iterator``.
        max_tokens: token budget forwarded to ``get_batch_iterator`` and also
            used when resolving ``max_positions``.
        num_workers: number of data-loading workers forwarded as-is.
        cached: when False, the iterator cache is disabled so that a changed
            ``max_tokens`` takes effect on later calls.

    Returns:
        Whatever ``task.get_batch_iterator`` returns.
    """
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(split),
        # Forward the caller's budget; a hard-coded 20 here would silently
        # ignore the max_tokens argument.
        max_tokens=max_tokens,
        max_sentences=None,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            max_tokens,
        ),
        ignore_invalid_inputs=True,
        seed=seed,
        num_workers=num_workers,
        epoch=epoch,
        disable_iterator_cache=not cached,
        # Leaving the cache enabled (cached=True) is faster, but then changing
        # max_tokens after the first call of this function has no effect.
    )
    return batch_iterator

# Smoke test: build a small, uncached iterator over the training split and
# display the first batch of a shuffled epoch.
demo_epoch_obj = load_data_iterator(
    task,
    "train",
    cached=False,
    num_workers=1,
    max_tokens=20,
    epoch=1,
)
demo_iter = demo_epoch_obj.next_epoch_itr(shuffle=True)
sample = next(demo_iter)
sample



{'id': tensor([3965]),
 'nsentences': 1,
 'ntokens': 9,
 'net_input': {'src_tokens': tensor([[  1,   1,   1,  31,  48,  13,   5,   6, 121,   6, 145, 209, 403, 292,
             7,   2]]),
  'src_lengths': tensor([13]),
  'prev_output_tokens': tensor([[   2,    5, 1056,  104, 3421,  449,    8, 3448,  163,    1,    1,    1,
              1,    1,    1,    1]])},
 'target': tensor([[   5, 1056,  104, 3421,  449,    8, 3448,  163,    2,    1,    1,    1,
             1,    1,    1,    1]])}

In [180]:
# Keep handles to both vocabularies and display the source vocabulary size.
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
len(src_dict)

8000