注意到 `dataset.py` 中直接从 json 中读取了数据，但数据集并不符合这种格式，所以需要将数据集转换为 json 格式。

```py
# 复现代码中读取数据集的部分
            all_raw_base, all_raw_a, all_raw_b, all_raw_res = json.load(open('%s/raw_data'%(total_raw_data_path)))
```

所以我们需要将数据集转换为如下 json 格式：最外层是一个包含四个等长列表的数组，依次为 base、a、b、resolution，相同下标的元素对应同一个冲突块：

```
[
    ["base1", "base2", "base3"],
    ["a1", "a2", "a3"],
    ["b1", "b2", "b3"],
    ["res1", "res2", "res3"]
]
```




In [1]:
# Root directory of the fse2022 dataset; it is searched recursively for metadata files.
data_dir: str = 'RAW_DATA/fse2022'
# Destination file that receives the converted dataset as a single JSON document.
out_file: str = 'RAW_DATA/raw_data'
# Recursively walk a directory and locate its metadata files.
import os


def get_all_json_files(path):
    """Yield the path of every file under ``path`` whose name ends with 'metadata.json'.

    Args:
        path: Root directory to search; all sub-directories are visited.

    Yields:
        ``os.path.join(root, file)`` for each matching file, in a
        deterministic order (directories and file names sorted).
    """
    for root, dirs, files in os.walk(path):
        # os.walk makes no ordering promise. Sorting ``dirs`` in place fixes the
        # descent order (top-down walk), and sorting ``files`` fixes the order
        # inside each directory, so repeated runs produce the same sequence.
        dirs.sort()
        for file in sorted(files):
            if file.endswith('metadata.json'):
                yield os.path.join(root, file)

# Drain the generator into a list so the matches can be counted here and
# iterated again later.
all_json_files = [*get_all_json_files(data_dir)]
print(len(all_json_files))

# Read every metadata file and collect the conflicting chunks.
import json
from tqdm import tqdm

# Four parallel lists: entry i of each list comes from the same conflicting chunk.
o_contents = []  # chunk['base_contents']
a_contents = []  # chunk['a_contents']
b_contents = []  # chunk['b_contents']
r_contents = []  # chunk['res_region']
for file in tqdm(all_json_files):
    # JSON text is UTF-8; passing the encoding explicitly keeps the read from
    # depending on the platform's locale default.
    with open(file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # The parsed document is all that is needed from here on, so the file is
    # already closed while the chunks are processed.
    for chunk in data['conflicting_chunks']:
        # Chunks whose 'res_region' is None are skipped.
        if chunk['res_region'] is None:
            continue
        o_contents.append(chunk['base_contents'])
        a_contents.append(chunk['a_contents'])
        b_contents.append(chunk['b_contents'])
        r_contents.append(chunk['res_region'])

# Sanity check: the four lists must stay index-aligned.
assert len(o_contents) == len(a_contents) == len(b_contents) == len(r_contents)
print(len(o_contents))

# Output layout: [all base, all a, all b, all resolution].
json_arr = [
    o_contents,
    a_contents,
    b_contents,
    r_contents,
]

# Write json_arr to the output file.
with open(out_file, 'w', encoding='utf-8') as f:
    json.dump(json_arr, f)

48785


100%|██████████| 48785/48785 [00:18<00:00, 2667.80it/s]


151426


# 自己收集的数据集 .json 转化为 raw_data

In [22]:
# Directory that holds the self-collected dataset as .json files; it is searched recursively.
data_dir: str = '/root/projects/conflictManager/edit_script_resolver/train_and_infer/data/processed_data/recollect_without_min_bundle_without_file_content'
# Destination file that receives the converted sample as a single JSON document.
out_file: str = 'RAW_DATA/graphQL_raw_data_sample_20'


# 1. List every *.json file under data_dir together with its index.
def get_all_json_files(path):
    """Yield ``(file_path, idx)`` for every file under ``path`` ending in '.json'.

    Args:
        path: Root directory to search; all sub-directories are visited.

    Yields:
        A tuple of the joined file path and ``idx``, the part of the file
        name before its first dot, as a string (``'12'`` for ``12.json`` and
        for ``12.part.json``). Results come in a deterministic order
        (directories and file names sorted).
    """
    import os

    for root, dirs, files in os.walk(path):
        # os.walk makes no ordering promise; sorting pins both the descent
        # order and the order inside each directory.
        dirs.sort()
        for file in sorted(files):
            if file.endswith('.json'):
                # Text before the first dot. This gives the same result as
                # taking Path(file).stem and splitting that on '.'.
                idx = file.split('.', 1)[0]
                yield (os.path.join(root, file), idx)

# Collect every (path, idx) pair into a tuple before any file is read.
tuples = tuple(get_all_json_files(data_dir))

# Parallel accumulators: entry i of each list comes from the same conflict chunk.
o_contents, a_contents, b_contents, r_contents = [], [], [], []


# 2. Read the json files.
from tqdm import tqdm
import json

# Only files whose numeric index is below this limit are converted.
sample_limit = 20

for file, idx in tqdm(tuples, dynamic_ncols=True, desc='Reading json files', leave=False, position=0):
    # int() raises ValueError when the file name does not start with an integer.
    if int(idx) >= sample_limit:
        continue
    # JSON text is UTF-8; passing the encoding explicitly keeps the read from
    # depending on the platform's locale default.
    with open(file, 'r', encoding='utf-8') as f:
        cfs = json.load(f)
    # The parsed document is all that is needed from here on, so the file is
    # already closed while the chunks are processed.
    for cf in tqdm(cfs, dynamic_ncols=True, desc='Reading conflict chunks', leave=False, position=1):
        for chunk in cf['conflict_chunks']:
            o_contents.append(chunk['o_content'])
            a_contents.append(chunk['a_content'])
            b_contents.append(chunk['b_content'])
            r_contents.append(chunk['r_content'])
    # Sanity check after each file: the four lists must stay index-aligned.
    assert len(o_contents) == len(a_contents) == len(b_contents) == len(r_contents)

print(len(o_contents))
print(len(a_contents))
print(len(b_contents))
print(len(r_contents))

# Output layout: [all o, all a, all b, all r].
json_arr = [
    o_contents,
    a_contents,
    b_contents,
    r_contents,
]

# Write json_arr to the output file.
with open(out_file, 'w', encoding='utf-8') as f:
    json.dump(json_arr, f)

                                                                   

358446
358446
358446
358446


# 看看 token_len 分布

In [4]:
# Converted dataset file to analyse. Switch to the commented-out path to
# analyse the other converted file instead.

# data_path = 'RAW_DATA/raw_data'
data_path: str = 'RAW_DATA/graphQL_raw_data_sample_20'

# The file content unpacks as:
#   all_raw_base, all_raw_a, all_raw_b, all_raw_res = json.load(open(data_path, 'r'))
# Token-length statistics are gathered over these entries.

import numpy as np
import json
import pickle
import os
from tqdm import tqdm   
from collections import defaultdict
from transformers import RobertaTokenizer, T5Model, T5ForConditionalGeneration, AdamW

# Model type: the small CodeT5 model. This identifier is only referenced by
# the commented-out from_pretrained call below.
model_type = 'Salesforce/codet5-small'
# Path the tokenizer is actually loaded from (presumably a local copy of the
# same checkpoint - confirm).
local_path = './codet5/codet5-small'

# Initialise the matching tokenizer.
# tokenizer = RobertaTokenizer.from_pretrained(model_type)
tokenizer = RobertaTokenizer.from_pretrained(local_path)

# Length threshold: resolutions with at most this many tokens are counted
# under the True key, longer ones under False.
max_res_tokens = 200

# inputs_lens = defaultdict(int)
# outputs_lens = defaultdict(int)
# Maps (token count <= max_res_tokens) -> number of resolutions.
res_lens = defaultdict(int)

# The with-block closes the file handle, which the bare open() call left
# open, and the explicit UTF-8 keeps the read independent of the locale
# default.
with open(data_path, 'r', encoding='utf-8') as f:
    all_raw_base, all_raw_a, all_raw_b, all_raw_res = json.load(f)

print(len(all_raw_base))
for raw_res in tqdm(all_raw_res):
    # Collapse every run of whitespace into a single space before tokenizing.
    raw_res = ' '.join(raw_res.split())
    # Tokenize the resolution and map the tokens to vocabulary ids.
    tokens_res = tokenizer.tokenize(raw_res)
    ids_res = tokenizer.convert_tokens_to_ids(tokens_res)
    # Count the resolution under True or False according to its length.
    res_lens[len(ids_res) <= max_res_tokens] += 1

print(res_lens)

151426


100%|██████████| 151426/151426 [04:03<00:00, 620.89it/s] 

defaultdict(<class 'int'>, {True: 138641, False: 12785})



