In [None]:
import json
from pprint import pprint
from pathlib import Path

import pandas as pd
# Show up to 300 characters per cell when a DataFrame is displayed.
# The fully-qualified option key is used on purpose: abbreviated keys are resolved
# by pattern matching and can become ambiguous if pandas adds a similarly named option.
pd.set_option('display.max_colwidth', 300)

In [None]:
# Peek at one raw record to see which fields a dataset entry carries.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform's
# locale default (pandas.read_json also decodes as UTF-8 by default).
with open('data/python/final/jsonl/train/python_train_0.jsonl', 'r', encoding='utf-8') as f:
    sample_file = f.readlines()  # one JSON document per line
pprint(json.loads(sample_file[0]))

In [None]:
# Collect every .jsonl file below the Python dataset directory, in sorted path order.
# To scan all languages instead, point the root at Path('data/').
python_root = Path('data/python/')
files = sorted(python_root.glob('**/*.jsonl'))

# Columns kept by default when loading the dataset files.
columns_long_list = ['repo', 'path', 'url', 'code',
                     'code_tokens', 'docstring', 'docstring_tokens',
                     'language', 'partition']

def jsonl_list_to_dataframe(file_list, columns=columns_long_list):
    """Load a collection of JSON-lines files into one pandas DataFrame.

    Every file is parsed with ``pd.read_json(..., lines=True)`` and reduced to
    ``columns`` before the per-file frames are concatenated. Compression is left
    at pandas' default, which infers it from the file extension, so plain
    ``.jsonl`` files are read as-is.

    Args:
        file_list: Iterable of paths to JSON-lines files.
        columns: Column names to keep from each file. Defaults to
            ``columns_long_list``. A file lacking one of them raises ``KeyError``.

    Returns:
        A DataFrame holding the selected columns of all files, in ``file_list``
        order. Row labels are not renumbered, so the per-file labels repeat
        across files. When ``file_list`` is empty, an empty DataFrame with the
        requested columns is returned.
    """
    frames = [
        pd.read_json(f, orient='records', lines=True)[columns]
        for f in file_list
    ]
    if not frames:
        # pd.concat raises ValueError on an empty sequence; hand back an empty
        # frame with the expected schema instead.
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, sort=False)

# Build the full-width frame from every collected file and preview its first row.
df = jsonl_list_to_dataframe(file_list=files, columns=columns_long_list)
df.head(n=1)

In [None]:
# Focus on the Python data first: check which languages were actually loaded.
print(df.language.value_counts())

# Next steps:
# 1. filter out code_tokens that start with '#' (comments)
# 2. concatenate the remaining tokens into a single code string
# then continue with further processing

# `df` already holds the code_tokens column parsed from these same files, so it is
# taken from there instead of re-reading and re-parsing every .jsonl file from disk.
# .copy() yields an independent frame, so columns added to `code` later do not
# touch `df`.
columns_short_list = ['code_tokens']
code = df[columns_short_list].copy()

In [None]:
def _without_comment_tokens(tokens):
    """Return the tokens that are non-empty and do not begin with '#'."""
    return [tok for tok in tokens if not (len(tok) == 0 or tok[0] == '#')]


# Apply the filter row by row; a plain list is assigned, so values are placed
# positionally.
code['filtered_code_tokens'] = [
    _without_comment_tokens(tokens) for tokens in code['code_tokens']
]
code.head(1)

In [None]:
# Rebuild one space-separated source string per row from its filtered tokens.
code['code_string'] = list(map(' '.join, code['filtered_code_tokens']))
code.head(1)

In [None]:
# introduce the local language model to do the "multi mask filling"

from transformers import RobertaTokenizer, RobertaForMaskedLM

# Tokenizer and masked-LM weights are loaded from the same checkpoint identifier.
CODEBERT_MLM_CHECKPOINT = "microsoft/codebert-base-mlm"
tokenizer = RobertaTokenizer.from_pretrained(CODEBERT_MLM_CHECKPOINT)
model = RobertaForMaskedLM.from_pretrained(CODEBERT_MLM_CHECKPOINT)

# Print the id of the mask token and the tokenizer's vocabulary size, one per line.
print(tokenizer.mask_token_id, tokenizer.vocab_size, sep='\n')

In [None]:
# Simulate the multi-mask scenario with two <mask> placeholders.
# The snippet is bound to `masked_code` rather than `code`: `code` is the DataFrame
# built in the earlier cells, and rebinding it to a string would make those cells
# fail when they are re-run.
masked_code = "if ( <mask> is not None ) <mask> ( x > 1 )"
token_ids = tokenizer.encode(masked_code, return_tensors='pt')
# nonzero() returns one row per matching position; flatten it into a plain list of ints.
masked_position = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero()
masked_pos = masked_position.flatten().tolist()
# masked_pos  # [3, 8]

In [None]:
import torch
import torch.nn.functional as F

# Forward pass only: gradient tracking is switched off for this call.
with torch.no_grad():
    output = model(input_ids=token_ids)
output.logits.shape  # output[0].shape is torch.Size([1, 15, 50265])

In [None]:
# NOTE: despite its name, `last_hidden_state` holds the model's per-position
# vocabulary logits (output[0] / output.logits), not hidden states. The name is
# kept so that any code relying on it keeps working.
last_hidden_state = output[0].squeeze()

list_of_list = []  # one (words, token ids, probabilities) tuple per masked token
for mask_index in masked_pos:
    mask_logits = last_hidden_state[mask_index]
    # Normalise over the whole vocabulary before picking the top 5, so the reported
    # values are the model's probabilities. Applying softmax to the 5 top logits
    # alone would always make them sum to 1, however unsure the model is.
    # softmax is monotonic, so the selected token ids are the same either way.
    vocab_prob = F.softmax(mask_logits, dim=0)
    top_prob, top_indices = torch.topk(vocab_prob, k=5, dim=0)
    top_words = [tokenizer.decode(i.item()).strip() for i in top_indices]
    list_of_list.append((top_words, top_indices.tolist(), top_prob.tolist()))

list_of_list