In [3]:
from pathlib import Path
import json
from random import choice

import numpy as np
import torch
from bert import modeling, optimization
from transformers import BertConfig, BertForPreTraining, Trainer, TrainingArguments, InputFeatures, BatchEncoding
from sklearn.model_selection import train_test_split
import wandb

In [4]:
# Authenticate this kernel with Weights & Biases before any run is created.
wandb.login()
# IPython magic: sets the WANDB_WATCH environment variable for this kernel.
# NOTE(review): presumably read by the Trainer's W&B integration to decide
# what to log (gradients / parameters) — confirm for the installed version.
%env WANDB_WATCH=all

env: WANDB_WATCH=all


In [5]:
# Directory whose immediate entries are treated as repositories.
data_root_path = Path('/home/maxkvant/data/pretraining_dataset/prepr/cleaned/')
# iterdir() yields entries in arbitrary, filesystem-dependent order; sorting
# makes the input to the split (and therefore the split itself) stable.
repos = sorted(data_root_path.iterdir())
# A fixed seed keeps the same repositories in train / validation across kernel
# restarts, so resuming from a checkpoint never moves validation repositories
# into the training set.
train_repos, val_repos = train_test_split(repos, random_state=42)

In [6]:
# Each vocabulary line holds one token wrapped in a single delimiter character
# on each side (e.g. quotes); the token's id is its zero-based line index.
with open(
    '/home/maxkvant/data/pretraining_dataset/20200621_Python_github_python_minus_ethpy150open_deduplicated_vocabulary.txt',
    encoding='utf-8',
) as vocab_file:
    # Strip the line terminator explicitly before dropping the delimiters, so a
    # final line without a trailing newline (or a CRLF file) is not truncated.
    tokens = [line.rstrip('\r\n')[1:-1] for line in vocab_file]
token_ids = {token: token_id for token_id, token in enumerate(tokens)}

In [7]:
# Sentinel value the dataset writes into the label array at the positions it
# selects for masking.
TOKEN_ID_MASKED = -100
# Per-token probability with which a position is selected for masking.
MASKING_PROBABILITY = 0.15


def make_masked(input_ids):
    """Unimplemented placeholder: ignores ``input_ids`` and returns ``None``."""
    return None


class TokenizedReposDataset(torch.utils.data.IterableDataset):
    """Streams line-pair training examples from tokenized repository files.

    Every ``*.json`` file found recursively under the given repository paths is
    expected to hold a list of lines, each line being a list of token strings.
    For each line A a partner line B from the same file is chosen: either the
    line that directly follows A or a random other line. Each example is the
    tuple ``(InputFeatures(input_ids, label=labels), next_label, type_ids)``:

    * ``input_ids`` - token ids of A followed by token ids of B (plain list);
    * ``labels`` - int64 array equal to ``input_ids`` except that positions
      drawn with probability ``MASKING_PROBABILITY`` hold ``TOKEN_ID_MASKED``;
    * ``next_label`` - ``True`` iff B is the line directly following A;
    * ``type_ids`` - 0 for every token of A, 1 for every token of B.

    NOTE(review): ``input_ids`` is yielded unmasked while ``labels`` keeps the
    original ids everywhere except the selected positions. Masked-LM training
    normally does the opposite (mask the input, ignore unselected labels).
    Fixing that needs the vocabulary's mask-token id, which is not visible
    here, so the behaviour is left unchanged - confirm the intent.

    NOTE(review): ``next_label`` is 1 for a true continuation; check this
    against the label convention of the model consuming it.

    Args:
        repo_paths: iterable of ``pathlib.Path`` directories to scan.
        token_ids: mapping from token string to integer id.
        max_length: pairs whose combined length exceeds this are skipped.

    Raises:
        KeyError: during iteration, if a file contains a token that is missing
            from ``token_ids``.
    """

    def __init__(self, repo_paths, token_ids, max_length=512):
        self.file_paths = [
            file_path
            for repo_path in repo_paths
            for file_path in repo_path.glob('**/*.json')
        ]
        self.token_ids = token_ids
        self.max_length = max_length
        # Total number of lines over all readable files. This is an upper bound
        # on the number of examples: over-long pairs and single-line files are
        # skipped during iteration but still counted here.
        self.len = sum(len(self._read_lines(path)) for path in self.file_paths)

    @staticmethod
    def _read_lines(file_path):
        """Return the list stored in ``file_path``, or ``[]`` if it is unusable."""
        try:
            # The context manager closes the handle even if parsing fails.
            with file_path.open() as file:
                lines = json.load(file)
        except (OSError, ValueError):
            # Unreadable, badly encoded or malformed JSON: skip on purpose.
            return []
        return lines if isinstance(lines, list) else []

    def __iter__(self):
        # With several DataLoader workers each worker holds a full copy of this
        # dataset; give every worker a disjoint slice of the files so examples
        # are not replayed once per worker.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            file_paths = self.file_paths
        else:
            file_paths = self.file_paths[worker_info.id::worker_info.num_workers]

        for file_path in file_paths:
            file_lines = [
                [self.token_ids[token] for token in line]
                for line in self._read_lines(file_path)
            ]
            last_i = len(file_lines) - 1

            for i, line in enumerate(file_lines):
                # The last line has no successor, so it always gets a random partner.
                next_label = bool(np.random.rand() > .5 and i != last_i)
                if next_label:
                    b_line = file_lines[i + 1]
                else:
                    try:
                        # Uniform pick among all indices except i, without
                        # copying the list of lines for every example.
                        j = choice(range(last_i))
                    except IndexError:
                        # Single-line file: there is no other line to pair with.
                        continue
                    if j >= i:
                        j += 1
                    b_line = file_lines[j]
                    # The random pick may land on the true successor; label it
                    # as such rather than emitting a mislabelled negative.
                    next_label = j == i + 1

                input_ids = line + b_line
                # Check the length first so no masking work is done for pairs
                # that are going to be dropped anyway.
                if len(input_ids) > self.max_length:
                    continue
                type_ids = [0] * len(line) + [1] * len(b_line)
                # Explicit dtype keeps an empty pair from becoming a float array.
                masked_input = np.array(input_ids, dtype=np.int64)
                mask = np.random.binomial(1, MASKING_PROBABILITY, len(input_ids)).astype(bool)
                masked_input[mask] = TOKEN_ID_MASKED
                yield InputFeatures(input_ids, label=masked_input), next_label, type_ids

    def __len__(self):
        return self.len

In [5]:
# with open('train.txt', 'wt') as val_ds_file:
#     for line_token_ids in TokenizedReposDataset(train_repos, token_ids):
#         val_ds_file.write(' '.join(map(str, line_token_ids)) + '\n')

In [6]:
# def read(path):
#     data = []
#     with open(path) as f:
#         for l in f.readlines():
#             input_ids = list(map(int, l.split()))
#             masked_input = np.array(input_ids)
#             mask = np.random.binomial(1, MASKING_PROBABILITY, len(input_ids)).astype(np.bool)
#             masked_input[mask] = TOKEN_ID_MASKED
#             data.append(InputFeatures(input_ids, label=masked_input))
#     return data

In [8]:
# Id of the '<pad>_' vocabulary entry; collate() appends it to right-pad the
# shorter sequences of a batch.
TOKEN_ID_PAD = token_ids['<pad>_']


def collate(data):
    """Right-pad a list of dataset examples into one batch of long tensors.

    Args:
        data: non-empty list of ``(features, next_label, type_ids)`` tuples,
            where ``features.input_ids`` is a list of token ids,
            ``features.label`` is an equally long integer sequence,
            ``next_label`` is a truthy/falsy flag and ``type_ids`` is a list of
            segment ids as long as ``input_ids``.

    Returns:
        ``BatchEncoding`` with ``input_ids``, ``attention_mask``, ``labels``,
        ``next_sentence_label`` and ``token_type_ids``. All are int64 tensors;
        the per-token ones have shape ``(len(data), longest sequence)``.

    Raises:
        ValueError: if ``data`` is empty.
    """
    # Padded positions must not contribute to the language-model loss, so their
    # labels get the ignore value (-100, same value as TOKEN_ID_MASKED) instead
    # of 0, which is a real token id.
    label_pad_id = -100

    batch_width = max(len(features.input_ids) for features, _, _ in data)
    input_rows, attention_rows, label_rows, next_labels, type_rows = [], [], [], [], []
    for features, next_label, type_ids in data:
        ids = list(features.input_ids)
        seq_len = len(ids)
        pad_len = batch_width - seq_len
        input_rows.append(ids + [TOKEN_ID_PAD] * pad_len)
        attention_rows.append([1] * seq_len + [0] * pad_len)
        label_rows.append(np.concatenate((
            np.asarray(features.label, dtype=np.int64),
            np.full(pad_len, label_pad_id, dtype=np.int64),
        )))
        # The flag may arrive as a Python or NumPy bool; normalise to int.
        next_labels.append(int(next_label))
        type_rows.append(list(type_ids) + [0] * pad_len)

    return BatchEncoding(data={
        'input_ids': torch.tensor(input_rows, dtype=torch.long),
        'attention_mask': torch.tensor(attention_rows, dtype=torch.long),
        # Stack once in NumPy rather than converting a list of arrays.
        'labels': torch.from_numpy(np.stack(label_rows)),
        'next_sentence_label': torch.tensor(next_labels, dtype=torch.long),
        'token_type_ids': torch.tensor(type_rows, dtype=torch.long),
    })

In [9]:
# Size the embedding table from the loaded vocabulary instead of a hard-coded
# 50000, so every id in `token_ids` is a valid embedding index.
model = BertForPreTraining(BertConfig(vocab_size=len(tokens), return_dict=True))
# model = BertForPreTraining.from_pretrained('pytorch_bert_training_09_14/checkpoint-12500/')

In [None]:
# One tag names both the checkpoint directory and the logged run.
RUN_NAME = '10_20'

train_args = TrainingArguments(
    # Where checkpoints go and how the run is labelled.
    output_dir=RUN_NAME,
    run_name=RUN_NAME,
    # What to run.
    do_train=True,
    do_eval=True,
    evaluation_strategy='epoch',
    num_train_epochs=10,
    # Batching: 16 sequences per device, accumulated over 16 steps.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=16,
    dataloader_num_workers=6,
    # Logging / checkpoint cadence, in steps.
    logging_steps=100,
    save_steps=250,
    save_total_limit=2,
)
trainer = Trainer(
    model=model,
    args=train_args,
    data_collator=collate,
    # Train and validation sets stream from disjoint repository lists.
    train_dataset=TokenizedReposDataset(train_repos, token_ids),
    eval_dataset=TokenizedReposDataset(val_repos, token_ids),
)

In [None]:
# Launch pre-training with the Trainer configured in the previous cell.
trainer.train()