## 预备工作

In [None]:
from IPython import get_ipython
import os
from pathlib import Path
# Directory the IPython shell reports as its starting directory.
script_dir = get_ipython().starting_dir
# Switch the working directory to the parent of that directory.
os.chdir(Path(script_dir) / '..')
from collections import defaultdict
import json
from util.conflict_util import Conflict, conflict2file
from tqdm.notebook import tqdm
from typing import List, Dict, Any, Tuple
import re
# Root directory used to build every relative data/output path in this notebook.
work_dir = Path(os.getcwd())
print(work_dir)

class ConflictChunk:
    def __init__(self, m_start, m_end, a_content, b_content, 
                 o_content, r_content, label: str | None, chunk_idx):
        self.m_start = m_start
        self.m_end = m_end
        self.a_content: 'str' = a_content
        self.b_content: 'str' = b_content
        self.o_content: 'str' = o_content
        self.r_content: 'str' = r_content
        self.label = label
        self.chunk_idx = chunk_idx

    def to_dict(self):
        return {
            "m_start": self.m_start,
            "m_end": self.m_end,
            "a_content": self.a_content,
            "b_content": self.b_content,
            "o_content": self.o_content,
            "r_content": self.r_content,
            "label": self.label,
        }
    
    def getJSONstr(self):
        return json.dumps(self, default=lambda o: o.__dict__, indent=4)


class ConflictFile:
    """A conflicting file: its versions plus the conflict chunks found in it.

    Attributes:
        path: Path of the file as recorded by the collector.
        repo_url: Repository identifier (may be None when not recorded).
        file_a_content: Full content of side A.
        file_b_content: Full content of side B.
        file_o_content: Full content of the base version.
        file_r_content: Full content of the resolved version.
        file_m_content: Full content of the merged version.
        commit_hash: Commit hash recorded by the collector (may be None).
        conflict_chunks: Chunk objects appended via :meth:`add_conflict_chunk`.
    """

    def __init__(self, path, repo_url, file_a_content, file_b_content, file_o_content, file_r_content, file_m_content, commit_hash):
        self.path = path
        self.repo_url = repo_url
        self.file_a_content = file_a_content
        self.file_b_content = file_b_content
        self.file_o_content = file_o_content
        self.file_r_content = file_r_content
        self.file_m_content = file_m_content
        self.commit_hash = commit_hash
        self.conflict_chunks = []

    def add_conflict_chunk(self, conflict_chunk_obj):
        """Append one chunk object to ``conflict_chunks``."""
        self.conflict_chunks.append(conflict_chunk_obj)

    def to_dict(self):
        """Return the file as a plain dict, chunks converted via their ``to_dict()``.

        ``commit_hash`` is included so that this view carries the same
        top-level fields as the JSON produced by :meth:`getJSONstr`.
        """
        return {
            "path": self.path,
            "repo_url": self.repo_url,
            "file_a_content": self.file_a_content,
            "file_b_content": self.file_b_content,
            "file_o_content": self.file_o_content,
            "file_r_content": self.file_r_content,
            "file_m_content": self.file_m_content,
            "commit_hash": self.commit_hash,
            "conflict_chunks": [chunk.to_dict() for chunk in self.conflict_chunks],
        }
    
    def getJSONstr(self):
        """Serialize the file and its chunks (via ``__dict__``) to indented JSON."""
        return json.dumps(self, default=lambda o: o.__dict__, indent=4)
    
class ConflictFileCollector:
    """Base class that turns a conflict dataset into batches of conflict files.

    Subclasses implement :meth:`collect`; this class supplies batching, JSON
    persistence, chunk labelling and a few small helpers.
    """

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
    
    @staticmethod
    def sample(output_dir, n, random_seed=0, label=None):
        """Yield at most ``n`` chunk dicts read from the JSON files under ``output_dir``.

        Args:
            output_dir: Directory searched recursively for ``.json`` files.
            n: Maximum number of chunks to yield.
            random_seed: Accepted for API compatibility but currently unused;
                chunks are yielded in traversal order, not randomly.
            label: When not None, only chunks whose ``'label'`` equals it are
                yielded.
        """
        cnt = 0
        jsons = list(ConflictFileCollector.getAllJsonsUnder(output_dir))
        print(f"Found {len(jsons)} JSON files in {output_dir}")
        for json_file in jsons:
            with open(json_file) as f:
                data = json.load(f)
            for conflict_file in data:
                for chunk in conflict_file['conflict_chunks']:
                    if label is None or chunk['label'] == label:
                        if cnt >= n:
                            return
                        cnt += 1
                        yield chunk


    def collect(self):
        '''
        Return an iterator that produces one ConflictFile object per step.

        Must be overridden by subclasses.
        '''
        raise NotImplementedError
        
    def collect_in_batches(self, batch_size=10000):
        """Group the items of :meth:`collect` into lists of ``batch_size``.

        ``None`` items are skipped; the final, possibly shorter, batch is
        yielded as well.
        """
        batch = []
        for conflict_file in self.collect():
            if conflict_file is None:
                continue
            batch.append(conflict_file)
            if len(batch) >= batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    def collect_and_save(self, output_dir, batch_size=10000):
        """Write every batch to ``<output_dir>/<batch index>.json``.

        ``output_dir`` (and its parents) is created when missing.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        for i, batch in enumerate(self.collect_in_batches(batch_size)):
            with open(output_dir / f"{i}.json", 'w') as f:
                print(f"Saving batch {i} to {output_dir / f'{i}.json'}")
                # Round-trip through getJSONstr() so nested objects become plain dicts.
                json.dump([json.loads(x.getJSONstr()) for x in batch], f)
    
    @staticmethod
    def preprocessContent(content: str):
        """Normalize whitespace for comparison.

        Blank input becomes ``''``; otherwise every whitespace run collapses to
        a single space and the result ends with exactly one space, so two
        normalized texts can be concatenated and compared directly.
        """
        return '' if content.strip() == '' else re.sub(r'\s+', ' ', content.strip() + '\n')
    
    @staticmethod
    def getLabel(a: str, b: str, o: str, r: str):
        """Classify how resolution ``r`` relates to sides ``a``, ``b`` and base ``o``.

        Comparisons are whitespace-insensitive (see :meth:`preprocessContent`).

        Returns:
            ``"same modification, formatting maybe different"`` when A and B
            agree; ``"A"``, ``"B"``, ``"O"``, ``"AB"`` or ``"BA"`` when the
            resolution equals that version/concatenation; ``'newline'`` when
            the resolution contains a non-blank line found in none of A, B and
            O; ``'mixline'`` otherwise.
        """
        r_processed = ConflictFileCollector.preprocessContent(r)
        a_processed = ConflictFileCollector.preprocessContent(a)
        b_processed = ConflictFileCollector.preprocessContent(b)
        o_processed = ConflictFileCollector.preprocessContent(o)
        if a_processed == b_processed:
            return "same modification, formatting maybe different"
        if r_processed == a_processed:
            return "A"
        if r_processed == b_processed:
            return "B"
        if r_processed == o_processed:
            return "O"
        if r_processed == a_processed + b_processed:
            return "AB"
        if r_processed == b_processed + a_processed:
            return "BA"

        r_lines = set(r.split('\n'))
        a_lines = set(a.split('\n'))
        b_lines = set(b.split('\n'))
        o_lines = set(o.split('\n'))
        for rl in r_lines:
            # Blank lines never count as new content. ``rl.strip()`` is used
            # because ``''.isspace()`` is False, which would let the empty
            # string produced by a trailing newline be reported as a new line.
            if rl.strip() and (rl not in a_lines) and (rl not in b_lines) and (rl not in o_lines):
                return 'newline'
        return 'mixline'

    @staticmethod
    def getAllJsonsUnder(dirPath: str):
        """Yield the path of every ``.json`` file found recursively under ``dirPath``."""
        for root, _, files in os.walk(dirPath):
            for file in files:
                if file.endswith(".json"):
                    yield os.path.join(root, file)
    
    @staticmethod
    def list2str(l):
        """Join a list of lines into newline-terminated text.

        ``[]`` and ``['']`` map to the empty string.
        """
        if l == [] or l == ['']:
            return ''
        return '\n'.join(l) + '\n'


## 观察样本

In [None]:
# Take up to 10 chunks labelled 'mixline' from the previously saved JSON output.
chunks = list(ConflictFileCollector.sample(work_dir / 'data_collect_analysis' / 'output' / '100+stars_4GB-_multidev_org_lang', n=10, label='mixline'))

### 观察不能被 es 解决的 mixline

In [None]:
# NOTE(review): `chunks4debug` is not defined anywhere in the visible part of
# this notebook; presumably it comes from an earlier interactive session.
print(len(chunks4debug))
import requests
# Send a hand-picked subset of the chunks to a service listening on localhost:3000.
requests.post('http://localhost:3000/api/versions', json=chunks4debug[:8] + chunks4debug[15:20])

# chunk = chunks4debug[0]
# print(chunk)

## 收集数据

统一将数据格式化为 conflictMap
```json
{
    "path": , // 文件相对路径
    "repo_url": , // 仓库地址
    "file_a_content": , // 文件 A 内容
    "file_b_content": , // 文件 B 内容
    "file_o_content": , // 文件 base 内容
    "file_r_content": , // 文件 Resolved 内容
    "file_m_content": , // 文件 Merged 内容
    "commit_hash": ,    // commit hash
    "conflict_chunks": [
        {
            "m_start": , // merge 起始行
            "m_end": , // merge 结束行
            "a_content": , // A 内容
            "b_content": , // B 内容
            "o_content": , // base 内容
            "r_content": , // resolved 内容
            "label": , // conflict 类型
            "chunk_idx": , // chunk 在文件中是第几个 chunk     // 有可能有的 chunk 没有 resolution
        }
    ]
    
}
```

In [None]:
# Input location of the CONGRA dataset and output location of the converted JSON.
data_dir = work_dir / "data" / "congra_dataset"
output_dir = work_dir / "data_collect_analysis" / "output" / "congra_dataset"

# Module-level counters updated by CONGRACollector.collect():
# number of collected conflict chunks / number of files skipped after an error.
cc_cnt = 0
illegal_cnt = 0

class CONGRACollector(ConflictFileCollector):
    '''
    Collect conflict files from the CONGRA dataset.

    Expected layout, for every directory whose name starts with
    ``conflict_files_``:
        regions/<name>.region   region descriptions
        merged/<name>           merged file containing conflict markers
        resolved/<name>         resolved file
    '''
    def __init__(self, dataset_path):
        super().__init__(dataset_path)
    
    def collect(self):
        """Yield one ConflictFile per region file that can be parsed.

        A region file's first line is skipped; every following non-blank line
        is a tuple such as ``(168, 192, 171, 195)`` meaning
        (origin_conflict_start, origin_conflict_end, resolved_start,
        resolved_end).

        File-level contents of the yielded ConflictFile objects are left
        empty; only the conflict chunks are filled in.

        Side effects: adds to the module-level counters ``cc_cnt`` (chunks
        collected) and ``illegal_cnt`` (files skipped because of an error).
        """
        global cc_cnt, illegal_cnt

        conflict_files_dirs = []
        for root, dirs, _ in os.walk(self.dataset_path):
            for d in dirs:
                if d.startswith('conflict_files_'):
                    conflict_files_dirs.append(os.path.join(root, d))

        print(len(conflict_files_dirs), '个冲突文件')

        # Step 1: gather every region file path.
        regions = []
        for conflict_dir in tqdm(conflict_files_dirs):
            regions_dir = os.path.join(conflict_dir, 'regions')
            if os.path.isdir(regions_dir):
                for filename in os.listdir(regions_dir):
                    regions.append(os.path.join(regions_dir, filename))

        # Step 2: parse each region file into its list of 4-int tuples.
        conflict_data = []
        for region_file in regions:
            with open(region_file, 'r') as file:
                lines = file.readlines()
            tuples = [
                tuple(map(int, line.strip()[1:-1].split(',')))
                for line in lines[1:]
                if line.strip()  # tolerate blank lines
            ]
            conflict_data.append((region_file, tuples))

        print(len(conflict_data), '个冲突')

        for region_file, tuples in tqdm(conflict_data):
            try:
                conflict_file = self._build_conflict_file(region_file, tuples)
            except Exception:
                # Best effort: unreadable or inconsistent files are counted and skipped.
                illegal_cnt += 1
                continue
            cc_cnt += len(conflict_file.conflict_chunks)
            yield conflict_file

    def _build_conflict_file(self, region_file, tuples):
        """Build the ConflictFile for one region file; raise when data is inconsistent."""
        # <conflict_dir>/regions/<name>.region -> <conflict_dir>; derived with
        # os.path so it does not depend on '/' being the path separator.
        conflict_dir = os.path.dirname(os.path.dirname(region_file))
        base_name = os.path.basename(region_file).removesuffix('.region')
        merged_file_path = os.path.join(conflict_dir, 'merged', base_name)
        resolved_file_path = os.path.join(conflict_dir, 'resolved', base_name)

        with open(merged_file_path, 'r') as f:
            merged_lines = f.readlines()
        with open(resolved_file_path, 'r') as f:
            resolved_lines = f.readlines()

        conflict_file = ConflictFile(merged_file_path, str(merged_file_path), '', '', '', '', '', '')

        # Step 3: scan the merged file for diff3-style conflict blocks:
        #   <<<<<<< a ... ||||||| base ... ======= ... >>>>>>> b
        m_start = rec = None          # rec: first line of the section being read
        a_content = o_content = None  # reset per chunk so nothing leaks between chunks
        for i, line in enumerate(merged_lines):
            if line.startswith('<<<<<<< a'):
                m_start = i
                rec = i + 1
                a_content = o_content = None
            elif rec is None:
                # Outside a conflict block: marker-like lines are ordinary content.
                continue
            elif line.startswith('||||||| base'):
                a_content = ''.join(merged_lines[rec:i])
                rec = i + 1
            elif line.startswith('======='):
                o_content = ''.join(merged_lines[rec:i])
                rec = i + 1
            elif line.startswith('>>>>>>> b'):
                if a_content is None or o_content is None:
                    raise ValueError(f"incomplete conflict block ending at line {i} of {merged_file_path}")
                b_content = ''.join(merged_lines[rec:i])
                # label and r_content are filled in by step 4.
                cc = ConflictChunk(m_start, i, a_content, b_content, o_content, '', None,
                                   len(conflict_file.conflict_chunks))
                conflict_file.add_conflict_chunk(cc)
                m_start = rec = None

        if len(conflict_file.conflict_chunks) != len(tuples):
            raise ValueError(
                f"{merged_file_path}: found {len(conflict_file.conflict_chunks)} conflict blocks "
                f"but the region file lists {len(tuples)}"
            )

        # Step 4: attach the resolved text of every region and label the chunk.
        for cc, (_, _, resolved_start, resolved_end) in zip(conflict_file.conflict_chunks, tuples):
            if resolved_end - 1 > len(resolved_lines) or resolved_start > resolved_end:
                print('resolved_end 超出 resolved 文件长度')
            cc.r_content = ''.join(resolved_lines[resolved_start: resolved_end - 1])
            cc.label = ConflictFileCollector.getLabel(cc.a_content, cc.b_content, cc.o_content, cc.r_content)
        return conflict_file

# Convert the CONGRA dataset and report the counters maintained during collection.
collector = CONGRACollector(data_dir)
collector.collect_and_save(output_dir)
print('illegal file: ', illegal_cnt)
print('conflict chunk: ', cc_cnt)

In [None]:
# Machine-specific absolute input path; output goes below the notebook's work_dir.
data_dir = "/root/projects/gitMergeScenario/collect_output/output/conflictFiles"
output_dir = work_dir / "data_collect_analysis" / "output" / "100+stars_4GB-_multidev_org_lang"

class GraphQLFilteredRepoCollector(ConflictFileCollector):
    '''
    100+ stars, non_fork, 10+devs, org, 4GB- repos on GitHub
    '''
    # A file version consisting of one single line longer than this is treated
    # as minified/generated and the whole file is skipped.
    _MAX_SINGLE_LINE_LEN = 2000

    def __init__(self, dataset_path):
        super().__init__(dataset_path)
    
    def collect(self):
        """Yield a ConflictFile for every entry of every ``conflictFilesMetadata.json``.

        Each metadata file holds a list of entries. Every entry is converted,
        not only the last one of each file. Entries whose path contains
        ``.min.`` or ``.bundle.``, or whose ours/theirs/base/resolved content
        is one very long line, are skipped individually.

        Raises:
            FileNotFoundError: no JSON file exists under ``dataset_path``.
            ValueError: a JSON file with another name than
                ``conflictFilesMetadata.json`` is encountered.
        """
        metadata_jsonPaths = list(self.getAllJsonsUnder(self.dataset_path))
        if not metadata_jsonPaths:
            raise FileNotFoundError("No metadata json files found in the dataset path")

        for jsonPath in tqdm(metadata_jsonPaths):
            if os.path.basename(jsonPath) != 'conflictFilesMetadata.json':
                raise ValueError("conflictFilesMetadata.json file name error")

            ret = []
            with open(jsonPath, 'r') as f:
                try:
                    metadata_list = json.load(f)
                    for metadata in metadata_list:
                        conflict_file = self._build_conflict_file(metadata)
                        if conflict_file is not None:
                            ret.append(conflict_file)
                except Exception as e:
                    # Best effort: report the broken file and keep whatever
                    # was converted before the error.
                    print(f"Error reading {jsonPath}: {e} (type: {type(e).__name__})")
                    import traceback
                    traceback.print_exc()  # print the full stack trace
            yield from ret

    def _build_conflict_file(self, metadata):
        """Convert one metadata entry to a ConflictFile, or return None when it is filtered out."""
        path = metadata['filePath']
        # Skip .min.xx / .bundle.xx artifacts.
        if '.min.' in path or '.bundle.' in path:
            return None
        # Skip files whose content is a single very long line (e.g. some form
        # of minified JavaScript that is not named *.min.js).
        for lines in (metadata['oursContent'], metadata['theirsContent'],
                      metadata['baseContent'], metadata['resolvedContent']):
            if len(lines) == 1 and len(lines[0]) > self._MAX_SINGLE_LINE_LEN:
                return None

        # No trailing '\n' is appended after joining: per the original author's
        # note the dataset was split with String.split('\n', -1), which keeps
        # trailing empty strings.
        a_content = '\n'.join(metadata['oursContent'])
        b_content = '\n'.join(metadata['theirsContent'])
        base_content = '\n'.join(metadata['baseContent'])
        merged_content = '\n'.join(metadata['mergedContent'])
        r_content = '\n'.join(metadata['resolvedContent'])

        # repo_url is not recorded in this dataset.
        conflict_file = ConflictFile(path, None, a_content, b_content, base_content,
                                     r_content, merged_content, metadata['resolvedCommitHash'])

        for chunk in metadata['conflictChunks']:
            # No resolution was found for this chunk by the upstream aligner.
            if chunk.get('resolution') is None:
                continue
            # Chunk contents are arrays of lines, so list2str() joins them and
            # appends the final newline. chunk_idx is not available here.
            cc = ConflictChunk(
                    chunk['startLine'], 
                    chunk['endLine'], 
                    self.list2str(chunk['ours']), 
                    self.list2str(chunk['theirs']), 
                    self.list2str(chunk['base']), 
                    self.list2str(chunk['resolution']), 
                    None, None)
            cc.label = self.getLabel(cc.a_content, cc.b_content, cc.o_content, cc.r_content)
            conflict_file.add_conflict_chunk(cc)
        return conflict_file

collector = GraphQLFilteredRepoCollector(data_dir)
collector.collect_and_save(output_dir)

# TODO: find out why many chunks have no mergedContent.

# The conflict-chunk ranges of the file-level merged content collected here are
# produced by jgit's formatter; one problem is that it excludes lines common to A and B.
# A separate script can regenerate the conflict chunks on top of this without
# excluding the lines common to A and B (it uses git merge-file).

In [None]:
# With the JSON output already written, post-process it to drop .min.js and .bundle.js files.
# Temporary content: a remedial step after one particular collection run.

data_dir = Path("/Volumes/urine_bag/100+stars_4GB-_multidev_org_lang")
# output_dir = work_dir / "data_collect_analysis" / "output" / "100+stars"
output_dir = Path("/Volumes/urine_bag/recollect_without_min_bundle_without_file_content")

# ensure output_dir exists
output_dir.mkdir(parents=True, exist_ok=True)

def filter_cf(cf: dict) -> bool:
    """Tell whether a serialized conflict-file record should be kept.

    A record is rejected when its path contains ``.min.`` or ``.bundle.``, or
    when any of its A/B/base/resolved file contents is longer than 2000
    characters while containing at most one newline.
    """
    file_path = cf['path']
    if '.min.' in file_path or '.bundle.' in file_path:
        return False
    versions = (
        cf['file_a_content'],
        cf['file_b_content'],
        cf['file_o_content'],
        cf['file_r_content'],
    )
    looks_minified = any(len(text) > 2000 and text.count('\n') <= 1 for text in versions)
    return not looks_minified
# 1. Find all JSON files directly inside data_dir
# 2. Load the conflict-file records stored in each JSON file
# 3. Drop .min.js / .bundle.js files and files that are one overly long line
# 4. Save the result under the same file name in output_dir
jsons = list(data_dir.glob('*.json'))
for json_file in tqdm(jsons):
    with open(json_file) as f:
        data = json.load(f)
    
    # Keep only a few fields (the full file contents are dropped).
    filtered_data = [{
        "path": cf['path'],
        "repo_url": cf['repo_url'],
        "commit_hash": cf['commit_hash'],
        "conflict_chunks": cf['conflict_chunks'],
    } for cf in data if filter_cf(cf)]
    with open(output_dir / json_file.name, 'w') as f:
        json.dump(filtered_data, f)


In [None]:
# Input and output locations for the "2000repos" dataset conversion.
data_dir = work_dir / "data" / "2000repos"
output_dir = work_dir / "data_collect_analysis" / "output" / "2000repos"

class MergeNatureRepoCollector(ConflictFileCollector):
    '''
    Converts the "2000 repos" dataset into the conflictMap format.
    '''
    def __init__(self, dataset_path):
        super().__init__(dataset_path)

    def collect(self):
        """Yield one ConflictFile per ``metadata.json`` found under the dataset path.

        Next to every ``metadata.json`` the files ``ours``, ``theirs``,
        ``base``, ``conflict`` and ``resolve`` (each followed by the metadata's
        ``filetype`` value) are read. Samples for which reading any of them
        fails are skipped silently. Chunks without a ``resolve`` value are
        left out.

        Raises:
            FileNotFoundError: no JSON file exists under ``dataset_path``.
            ValueError: a JSON file is not named ``metadata.json``.
        """
        metadata_jsonPaths = list(self.getAllJsonsUnder(self.dataset_path))
        if not metadata_jsonPaths:
            raise FileNotFoundError("No metadata json files found in the dataset path")

        for jsonPath in tqdm(metadata_jsonPaths):
            sample_dir, json_name = os.path.split(jsonPath)
            if json_name != 'metadata.json':
                raise ValueError("metadata.json file name error")

            # Much of the metadata was never recorded (repo_url is unavailable).
            with open(jsonPath, 'r') as f:
                metadata = json.load(f)
            repo_url = None
            path = metadata['path']
            suffix = metadata['filetype']
            conflict_chunks = metadata['conflicting_chunks']
            commit_hash = metadata['commitID']

            # Order: A, B, base, merged, resolved.
            version_paths = [
                os.path.join(sample_dir, stem + suffix)
                for stem in ('ours', 'theirs', 'base', 'conflict', 'resolve')
            ]

            texts = []
            try:
                for version_path in version_paths:
                    with open(version_path, 'r') as f:
                        texts.append(f.read())
            except Exception:
                # Some files do not exist; skip the sample.
                continue
            a_content, b_content, base_content, merged_content, r_content = texts

            conflict_file = ConflictFile(path, repo_url, a_content, b_content, base_content, r_content, merged_content, commit_hash)
            for chunk in conflict_chunks:
                if 'resolve' not in chunk or chunk['resolve'] is None:
                    continue
                # m_start, m_end and chunk_idx are hard to obtain here; placeholders are stored.
                cc = ConflictChunk(-1, -1, chunk['a_contents'], chunk['b_contents'],
                                   chunk['base_contents'], chunk['resolve'], None, None)
                cc.label = self.getLabel(cc.a_content, cc.b_content, cc.o_content, cc.r_content)
                conflict_file.add_conflict_chunk(cc)
            yield conflict_file


# Run the conversion and write the batches as JSON into output_dir.
collector = MergeNatureRepoCollector(data_dir)
collector.collect_and_save(output_dir)

In [None]:
# Input and output locations for the "top50" dataset conversion.
data_dir = work_dir / "data" / "top50"
output_dir = work_dir / "data_collect_analysis" / "output" / "top50"

class MergeNatureRepoTop50Collector(ConflictFileCollector):
    '''
    Converts the top50/2000 repos dataset into the conflictMap format.
    '''
    def __init__(self, dataset_path):
        super().__init__(dataset_path)
    
    def collect(self):
        """Yield one ConflictFile per JSON metadata file under the dataset path.

        The version files are located by replacing ``_metadata.json`` in the
        metadata file name with ``_a``, ``_b``, ``_base``, ``_merged`` and
        ``_resolved`` followed by the metadata's ``filetype`` value. When
        reading any of them fails, the metadata path and the error are printed
        and the sample is skipped. ``path``, ``repo_url`` and ``commit_hash``
        are not available and are stored as None. Chunks without a ``resolve``
        value are left out.

        Raises:
            FileNotFoundError: no JSON file exists under ``dataset_path``.
        """
        metadata_jsonPaths = list(self.getAllJsonsUnder(self.dataset_path))
        if not metadata_jsonPaths:
            raise FileNotFoundError("No metadata json files found in the dataset path")

        for jsonPath in tqdm(metadata_jsonPaths):
            dirname, basename = os.path.split(jsonPath)

            # Much of the metadata was never recorded for this dataset.
            with open(jsonPath, 'r') as f:
                metadata = json.load(f)
            repo_url = None
            path = None
            suffix = metadata['filetype']
            conflict_chunks = metadata['conflicting_chunks']
            commit_hash = None

            # Order: A, B, base, merged, resolved.
            version_paths = [
                os.path.join(dirname, basename.replace('_metadata.json', tag + suffix))
                for tag in ('_a', '_b', '_base', '_merged', '_resolved')
            ]

            texts = []
            try:
                for version_path in version_paths:
                    with open(version_path, 'r') as f:
                        texts.append(f.read())
            except Exception as e:
                print(jsonPath)
                print(e)
                continue
            a_content, b_content, base_content, merged_content, resolved_content = texts

            conflict_file = ConflictFile(path, repo_url, a_content, b_content, base_content, resolved_content, merged_content, commit_hash)
            for chunk in conflict_chunks:
                if 'resolve' not in chunk or chunk['resolve'] is None:
                    continue
                # m_start, m_end and chunk_idx are hard to obtain here; placeholders are stored.
                cc = ConflictChunk(-1, -1, chunk['a_contents'], chunk['b_contents'],
                                   chunk['base_contents'], chunk['resolve'], None, None)
                cc.label = self.getLabel(cc.a_content, cc.b_content, cc.o_content, cc.r_content)
                conflict_file.add_conflict_chunk(cc)
            yield conflict_file

# Run the conversion and write the batches as JSON into output_dir.
collector = MergeNatureRepoTop50Collector(data_dir)
collector.collect_and_save(output_dir)

In [None]:
# Alternative configuration restricted to the TypeScript subset (disabled):
# data_dir = work_dir / "data" / "mergebert_data" / "automated-analysis-data" / "TypeScript"
# output_dir = work_dir / "data_collect_analysis" / "output" / "mergebert_ts"

# Active configuration: the whole automated-analysis-data directory.
data_dir = work_dir / "data" / "mergebert_data" / "automated-analysis-data"
output_dir = work_dir / "data_collect_analysis" / "output" / "mergebert_all_lang"

class MergeBERTConflictFileCollector(ConflictFileCollector):
    '''
    Converts the MergeBERT dataset into the conflictMap format.
    '''
    def __init__(self, dataset_path):
        super().__init__(dataset_path)
    
    def collect(self):
        """Yield one ConflictFile per JSON metadata file under the dataset path.

        The version files are located by replacing ``_metadata.json`` in the
        metadata file name with ``_a.``, ``_b.``, ``_base.``, ``_merged.`` and
        ``_resolved.`` followed by the text after the last ``.`` of the
        metadata's ``fname``. Errors while reading them are not caught.
        Chunks whose ``res_region`` is None are counted and left out; the
        dataset's own ``label`` (if any) is kept on each chunk as
        ``mergebert_label``. After the last file a summary line with both
        counters is printed.

        Raises:
            FileNotFoundError: no JSON file exists under ``dataset_path``.
        """
        total_chunks = 0
        unresolved_chunks = 0

        metadata_jsonPaths = list(self.getAllJsonsUnder(self.dataset_path))
        if not metadata_jsonPaths:
            raise FileNotFoundError("No metadata json files found in the dataset path")

        for jsonPath in tqdm(metadata_jsonPaths):
            dirname, basename = os.path.split(jsonPath)

            with open(jsonPath, 'r') as f:
                metadata = json.load(f)
            repo_url = metadata['repo']
            path = metadata['fname']
            suffix = path.split('.')[-1]
            conflict_chunks = metadata['conflicting_chunks']
            commit_hash = metadata['commitHash']

            # Order: A, B, base, merged, resolved.
            version_paths = [
                os.path.join(dirname, basename.replace('_metadata.json', tag + suffix))
                for tag in ('_a.', '_b.', '_base.', '_merged.', '_resolved.')
            ]

            texts = []
            for version_path in version_paths:
                with open(version_path, 'r') as f:
                    texts.append(f.read())
            a_content, b_content, base_content, merged_content, resolved_content = texts

            conflict_file = ConflictFile(path, repo_url, a_content, b_content, base_content, resolved_content, merged_content, commit_hash)
            for chunk in conflict_chunks:
                total_chunks += 1
                if chunk['res_region'] is None:
                    unresolved_chunks += 1
                    continue
                # m_start, m_end and chunk_idx are hard to obtain and, per the
                # original author, probably not important for this dataset.
                cc = ConflictChunk(-1, -1, chunk['a_contents'], chunk['b_contents'],
                                   chunk['base_contents'], chunk['res_region'], None, None)
                # Label values the original author listed for this dataset:
                #   'A', 'AB', 'B', 'BA', 'BASE', None, 'OTHER',
                #   'REM-BASE-A', 'REM-BASE-AB', 'REM-BASE-B', 'REM-BASE-BA',
                #   'RES_EMPTY', 'RES_FILE_EMPTY'
                cc.mergebert_label = chunk.get('label', None)  # type: ignore

                cc.label = self.getLabel(cc.a_content, cc.b_content, cc.o_content, cc.r_content)

                conflict_file.add_conflict_chunk(cc)
            yield conflict_file
        print(f"Total chunk count: {total_chunks}, chunk without r: {unresolved_chunks}")

# Run the conversion and write the batches as JSON into output_dir.
collector = MergeBERTConflictFileCollector(data_dir)
collector.collect_and_save(output_dir)

## 分析冲突块的类型分布

In [None]:
# Read every JSON file in a directory and count the label distribution of the
# ConflictChunks inside the ConflictFiles.
# NOTE: the first assignment is immediately overwritten by the second one,
# so "congra_dataset" is the directory actually analysed.
dir2analyze = work_dir / "data_collect_analysis" / "output" / "100+stars_sample"
dir2analyze = work_dir / "data_collect_analysis" / "output" / "congra_dataset"
# dir2analyze = work_dir / "data_collect_analysis" / "output" / "100+_recollect"
# dir2analyze = work_dir / "data_collect_analysis" / "output" / "mergebert_ts"

# Input: a directory of ConflictFile JSON; output: a label -> count map, plus a pie chart.
def analyze_label_distribution(dir2analyze):
    """Count chunk labels over all JSON files under ``dir2analyze``.

    Shows a pie chart of the distribution, pretty-prints the counts and
    returns them as a ``defaultdict(int)`` keyed by label. Chunks without a
    ``'label'`` key are ignored. A JSON file that cannot be processed is
    reported (message plus traceback) and skipped.

    Raises:
        FileNotFoundError: no JSON file exists under ``dir2analyze``.
    """
    jsonPaths = list(ConflictFileCollector.getAllJsonsUnder(dir2analyze))
    if not jsonPaths:
        raise FileNotFoundError("No metadata json files found in the dataset path")

    label_cnt = defaultdict(int)
    for jsonPath in tqdm(jsonPaths, position=0, leave=True, dynamic_ncols=True):
        with open(jsonPath, 'r') as f:
            try:
                conflict_files = json.load(f)
                for conflict_file in tqdm(conflict_files, position=1, leave=False, dynamic_ncols=True):
                    for chunk in conflict_file['conflict_chunks']:
                        # Temporary guard: because of an earlier bug, some
                        # chunks for which no resolution was found were stored
                        # without a label; skip them.
                        if 'label' in chunk:
                            label_cnt[chunk['label']] += 1
            except Exception as e:
                print(f"Error reading {jsonPath}: {e} (type: {type(e).__name__})")
                import traceback
                traceback.print_exc()

    import plotly.graph_objects as go
    from pprint import pprint

    # Pie chart: one slice per label.
    pie = go.Pie(labels=list(label_cnt.keys()), values=list(label_cnt.values()))
    fig = go.Figure(data=[pie])
    fig.update_layout(title_text="各类型冲突占比", width=600, height=400)
    fig.show()

    pprint(label_cnt)
    return label_cnt


# Compute the distribution from the saved JSON files.
label_cnt = analyze_label_distribution(dir2analyze)

## 分析一个文件中的冲突块数量分布

In [None]:
# Read every JSON file in a directory and count how many ConflictChunks each ConflictFile has.
dir2analyze = work_dir / "data_collect_analysis" / "output" / "100+stars_sample"
# dir2analyze = work_dir / "data_collect_analysis" / "output" / "congra_dataset"
# dir2analyze = work_dir / "data_collect_analysis" / "output" / "100+_recollect"
# dir2analyze = work_dir / "data_collect_analysis" / "output" / "mergebert_ts"

# Input: a directory of ConflictFile JSON; counts the chunks per ConflictFile and draws a bar chart.
def analyze_chunk_num_distribution(dir2analyze):
    num_cnt = defaultdict(int)
    # 获取所有 json 文件名
    jsonPaths = [path for path in ConflictFileCollector.getAllJsonsUnder(dir2analyze)]
    if len(jsonPaths) == 0:
        raise FileNotFoundError("No metadata json files found in the dataset path")
    for jsonPath in tqdm(jsonPaths, position=0, leave=True, dynamic_ncols=True):
        # jsonData
        with open(jsonPath, 'r') as f:
            try:
                for x in tqdm(json.load(f), position=1, leave=False, dynamic_ncols=True):
                    num_cnt[len(x['conflict_chunks'])] += 1
            except Exception as e:
                print(f"Error reading {jsonPath}: {e} (type: {type(e).__name__})")
                import traceback
                traceback.print_exc()
            
    import plotly.graph_objects as go
    x = list(num_cnt.keys())
    y = list(num_cnt.values())
    # 计算百分比
    total = sum(y)
    percentages = [f"{(value / total) * 100:.2f}%" for value in y]
    # 创建直方图并添加百分比标注
    fig = go.Figure(data=[go.Bar(x=x, y=y, text=percentages, textposition='outside')])
    # 设置布局
    fig.update_layout(title_text="冲突块数量分布", width=600, height=400)
    # 显示图形
    fig.show()


# Plot the chunks-per-file distribution for the directory selected in this cell
analyze_chunk_num_distribution(dir2analyze)

## 提取指定语言数据集

In [None]:
# Read every JSON file under the dataset directory and split the conflict
# files into two groups according to the extension of their path.
# dir2analyze = work_dir / "data_collect_analysis" / "output" / "mergebert_all_lang"
from math import ceil


dir2analyze = "/Volumes/Q1571825323/recollect_without_min_bundle_without_file_content"

codebert_conflict_files_output_dir = work_dir / "data_collect_analysis" / "output" / "codebert_conflict_files"
zero_shot_ext_conflict_files_output_dir = work_dir / "data_collect_analysis" / "output" / "zero_shot_conflict_files"

# Extension groups, chosen after analysing the dataset
codebert_lang_ext = {'.go', '.java', '.rb', '.js', '.py', '.php'}
zero_shot_ext = {'.h', '.mm', '.cs', '.swift', '.rs', '.c', '.ts', '.m', '.hpp', '.cpp'}

codebert_conflict_files = []
zero_shot_ext_conflict_files = []
# Extensions that belong to neither group -> number of entries seen
unexpected_ext_cnt = defaultdict(int)
# Collect all JSON file paths
jsonPaths = list(ConflictFileCollector.getAllJsonsUnder(dir2analyze))
if not jsonPaths:
    raise FileNotFoundError("No metadata json files found in the dataset path")
for jsonPath in tqdm(jsonPaths, position=0, leave=True, dynamic_ncols=True):
    # JSON text is UTF-8 by specification; do not depend on the platform default encoding.
    with open(jsonPath, 'r', encoding='utf-8') as f:
        try:
            for x in tqdm(json.load(f), position=1, leave=False, dynamic_ncols=True):
                ext = os.path.splitext(x['path'])[1]
                if ext in codebert_lang_ext:
                    codebert_conflict_files.append(x)
                elif ext in zero_shot_ext:
                    zero_shot_ext_conflict_files.append(x)
                else:
                    # An `assert False` here would be swallowed by the broad except
                    # below and silently drop the rest of this file's entries.
                    # Tally the unexpected extension and keep going instead.
                    unexpected_ext_cnt[ext] += 1
        except Exception as e:
            print(f"Error reading {jsonPath}: {e} (type: {type(e).__name__})")
            import traceback
            traceback.print_exc()

if unexpected_ext_cnt:
    print(f"Skipped entries with unexpected extensions: {dict(unexpected_ext_cnt)}")

# Save each group in shards of at most 10000 entries per file.
# Slicing clamps at the end of the list, so no explicit min() is needed.
os.makedirs(codebert_conflict_files_output_dir, exist_ok=True)
for i in tqdm(range(ceil(len(codebert_conflict_files) / 10000)), position=0, leave=True, dynamic_ncols=True):
    with open(codebert_conflict_files_output_dir / f"{i}.json", 'w', encoding='utf-8') as f:
        json.dump(codebert_conflict_files[i * 10000:(i + 1) * 10000], f)

os.makedirs(zero_shot_ext_conflict_files_output_dir, exist_ok=True)
for i in tqdm(range(ceil(len(zero_shot_ext_conflict_files) / 10000)), position=0, leave=True, dynamic_ncols=True):
    with open(zero_shot_ext_conflict_files_output_dir / f"{i}.json", 'w', encoding='utf-8') as f:
        json.dump(zero_shot_ext_conflict_files[i * 10000:(i + 1) * 10000], f)

## 回溯分析解决上界

In [None]:
# Dataset directories to run the backtracking analysis on.
# Only the uncommented append is active; the others are kept for switching datasets.
dirs = []
# dirs.append(work_dir / "data_collect_analysis" / "output" / "mergebert_ts")
# dirs.append(work_dir / "data_collect_analysis" / "output" / "mergebert_all_lang")
# dirs.append(work_dir / "data_collect_analysis" / "output" / "100+stars_4GB-_multidev_org_lang")
dirs.append(work_dir / "data_collect_analysis" / "output" / "100+stars_sample")
# dirs.append(work_dir / "data_collect_analysis" / "output" / "2000repos")
# dirs.append(work_dir / "data_collect_analysis" / "output" / "top50")



from util.edit_script import compute, SequenceDiff

class EditScriptLabel:
    """An edit script tagged with its origin and an accept/reject flag.

    Attributes:
        edit_script: The wrapped ``SequenceDiff``.
        _from: Origin tag of the script (the analysis in this file uses
            ``'ours'`` and ``'theirs'``).
        accept: Whether the script is currently marked as accepted.
    """

    def __init__(self, sd: SequenceDiff, _from: str, accept: bool):
        # Assigned left to right, so the attribute order stays edit_script, _from, accept.
        self.edit_script, self._from, self.accept = sd, _from, accept

def analyze_edit_script(dir2analyze):
    """Estimate how many conflict chunks can be resolved by combining edit scripts.

    For every conflict chunk found in the JSON files under ``dir2analyze`` the
    edit scripts base->ours and base->theirs are computed with ``compute``. A
    backtracking search then tries accepting/rejecting each script (and both
    concatenation orders when two scripts cover the same base range) and checks
    whether some combination reproduces the recorded resolution, ignoring
    whitespace differences. Statistics are written to
    ``<work_dir>/data_collect_analysis/bt_log/<dataset name>.log``.

    Args:
        dir2analyze: Directory holding the conflict-file JSON dumps; its base
            name is used as the dataset name.

    Raises:
        FileNotFoundError: If no JSON file is found under ``dir2analyze``.
    """
    dataset_name = os.path.basename(dir2analyze)
    print(f'在 {dataset_name} 下统计')
    accept_mark_cnt = defaultdict(int)        # accept flag -> count, over chunks solved by backtracking
    es_cnt = defaultdict(int)                 # number of edit scripts in a chunk -> number of chunks
    resolvable_cc_cnt = 0
    non_resolvable_cc_cnt = 0
    too_many_lines_cnt = 0
    too_many_es_cnt = 0
    label_cnt = defaultdict(int)              # label -> chunks that passed both size filters
    label_resolvable_cnt = defaultdict(int)   # label -> chunks counted as resolvable

    def cc_check(chunk: ConflictChunk) -> None:
        """Classify one chunk as resolvable or not and update the counters.

        Chunks with more than 200 non-blank lines on any side, or with more
        than 10 edit scripts, are only counted as skipped. Edit scripts are
        computed on stripped, non-blank lines and the final comparison with the
        resolution is whitespace-insensitive.
        """
        # Only the int counters are rebound here; the dict counters are mutated in place.
        nonlocal resolvable_cc_cnt
        nonlocal non_resolvable_cc_cnt
        nonlocal too_many_lines_cnt
        nonlocal too_many_es_cnt

        def es_gen_str2list(content: str) -> List[str]:
            """Split into lines, strip each line and drop blank ones."""
            return [line.strip() for line in content.split('\n') if line.strip() != '']

        a_contents = es_gen_str2list(chunk.a_content)
        b_contents = es_gen_str2list(chunk.b_content)
        o_contents = es_gen_str2list(chunk.o_content)
        r_contents = es_gen_str2list(chunk.r_content)

        def compareInToken(a_ls: List[str], b_ls: List[str]) -> bool:
            """Compare two line lists after collapsing every whitespace run to one space."""
            def toUnifiedStr(ls: List[str]) -> str:
                return '' if ls == [] or ls == [''] else re.sub(r'\s+', ' ', '\n'.join(ls).strip() + '\n')
            return toUnifiedStr(a_ls) == toUnifiedStr(b_ls)

        def side_lines(label: EditScriptLabel) -> List[str]:
            """Return the lines a script contributes, taken from the side it came from."""
            rng = label.edit_script.seq2Range
            side = a_contents if label._from == 'ours' else b_contents
            return side[rng.start:rng.end]

        def bt(generated, i, last_end, all_edit_scripts: List[EditScriptLabel]) -> bool:
            """Backtrack over the scripts from index ``i``; True if the resolution is reachable.

            ``generated`` holds the lines produced so far and ``last_end`` is the
            base position up to which the base has already been consumed. The
            ``accept`` flags along a successful path are left set on the scripts.
            """
            # All scripts decided: append the untouched base tail and compare.
            if i == len(all_edit_scripts):
                whole_generated = generated + o_contents[last_end:]
                return compareInToken(whole_generated, r_contents)

            current = all_edit_scripts[i]
            base_range = current.edit_script.seq1Range

            # Option 1: reject this script.
            current.accept = False
            if bt(generated, i + 1, last_end, all_edit_scripts):
                return True

            # A script starting before last_end overlaps an already accepted one
            # and cannot be accepted. The comparison is strict, so a script that
            # starts exactly at last_end is still allowed.
            if base_range.start < last_end:
                return False

            # Option 2: accept this script.
            curr_content = side_lines(current)
            prefix = generated + o_contents[last_end:base_range.start]
            current.accept = True
            if bt(prefix + curr_content, i + 1, base_range.end, all_edit_scripts):
                return True

            # Option 3: the next script covers exactly the same base range, so try
            # taking both of them.
            if (
                i + 1 < len(all_edit_scripts)
                and base_range == all_edit_scripts[i + 1].edit_script.seq1Range
            ):
                following = all_edit_scripts[i + 1]
                next_content = side_lines(following)
                following.accept = True
                # next + current. For an empty base range (both sides insert at the
                # same position) this is the only extra order needed: option 2
                # already reaches current + next, since the next script starts
                # exactly at last_end.
                if bt(prefix + next_content + curr_content, i + 2, base_range.end, all_edit_scripts):
                    return True
                # For a non-empty base range option 2 cannot accept the next script
                # (it would overlap), so current + next must be tried explicitly.
                if len(base_range) > 0:
                    if bt(prefix + curr_content + next_content, i + 2, base_range.end, all_edit_scripts):
                        return True

            # Every option failed; return an explicit bool instead of falling
            # through with None.
            return False

        # Skip chunks where any side is too long.
        if any(len(content) > 200 for content in (a_contents, b_contents, o_contents, r_contents)):
            too_many_lines_cnt += 1
            return

        # Tag each edit script with the side it comes from.
        from_ours = [EditScriptLabel(sd, 'ours', False) for sd in compute(o_contents, a_contents)]
        from_theirs = [EditScriptLabel(sd, 'theirs', False) for sd in compute(o_contents, b_contents)]
        all_edit_scripts = from_ours + from_theirs
        # Counted after the line-count filter but before the script-count filter.
        es_cnt[len(all_edit_scripts)] += 1

        # Cap the number of scripts to keep the search affordable.
        if len(all_edit_scripts) > 10:
            too_many_es_cnt += 1
            return

        all_edit_scripts.sort(key=lambda editScriptLabel: editScriptLabel.edit_script.seq1Range)

        kind = chunk.label
        label_cnt[kind] += 1
        # 'newline' conflicts are counted as not resolvable without searching.
        if kind == 'newline':
            non_resolvable_cc_cnt += 1
            return
        # This label is counted as resolvable without searching.
        if kind == 'same modification, formatting maybe different':
            resolvable_cc_cnt += 1
            label_resolvable_cnt[kind] += 1
            return

        if bt([], 0, 0, all_edit_scripts):
            resolvable_cc_cnt += 1
            label_resolvable_cnt[kind] += 1
            # Record which scripts were accepted on the successful path.
            for es in all_edit_scripts:
                accept_mark_cnt[es.accept] += 1
        else:
            non_resolvable_cc_cnt += 1

    # Walk the dataset.
    jsonPaths = list(ConflictFileCollector.getAllJsonsUnder(dir2analyze))
    if not jsonPaths:
        raise FileNotFoundError("No metadata json files found in the dataset path")
    for jsonPath in tqdm(jsonPaths, desc="Processing files", position=0, leave=True, dynamic_ncols=True):
        try:
            # JSON text is UTF-8 by specification; do not depend on the platform default encoding.
            with open(jsonPath, 'r', encoding='utf-8') as f:
                cfs = json.load(f)
        except Exception as e:
            print(f"Error reading {jsonPath}: {e} (type: {type(e).__name__})")
            import traceback
            traceback.print_exc()
            # Skip this file. Without this, `cfs` is either undefined or still
            # holds the previous file's entries, which would be counted twice.
            continue
        for cf in tqdm(cfs, desc="Process items", position=1, leave=False, dynamic_ncols=True):
            for cc in cf['conflict_chunks']:
                # Some chunks were stored without a label (see the workaround in
                # analyze_label_distribution); skip them instead of raising KeyError.
                if 'label' not in cc:
                    continue
                chunk = ConflictChunk(cc['m_start'], cc['m_end'], cc['a_content'], cc['b_content'],
                                      cc['o_content'], cc['r_content'], cc['label'], cc['chunk_idx'])
                cc_check(chunk)

    def print_res_to_file(file=None):
        """Print the collected statistics to ``file`` (``None`` means standard output)."""
        total = resolvable_cc_cnt + non_resolvable_cc_cnt
        # Avoid ZeroDivisionError when no chunk survived the filters.
        ratio = resolvable_cc_cnt / total * 100 if total else 0.0
        print(f'在 {dataset_name} 下统计结果:', file=file)
        print(f'过滤过长后，共有 {total} 个冲突块，其中 {resolvable_cc_cnt} 个可以用编辑脚本解决，占比 {ratio:.2f}%', file=file)
        print(f'有 {too_many_lines_cnt} 个冲突块的行数过大，无法处理', file=file)
        print(f'有 {too_many_es_cnt} 个冲突块的编辑脚本数量过大，无法处理', file=file)
        print(f'编辑脚本数量分布: {es_cnt}', file=file)
        print(f'接受标记分布: {accept_mark_cnt}', file=file)
        print(f'类型分布: {label_cnt}', file=file)
        print(f'可解决类型分布: {label_resolvable_cnt}', file=file)
        # Every value in label_cnt is at least 1, so the division below is safe.
        for k, v in label_cnt.items():
            print(f'{k}: {v}, 可解决: {label_resolvable_cnt[k]}，占比: {label_resolvable_cnt[k] / v * 100:.2f}%', file=file)

    # Write the log, closing the file handle deterministically.
    log_dir = work_dir / 'data_collect_analysis' / 'bt_log'
    os.makedirs(log_dir, exist_ok=True)
    # read_bt_log reopens this file with the platform default encoding, so no
    # encoding is forced here, to keep writer and reader in agreement.
    with open(log_dir / f'{dataset_name}.log', 'w') as log_file:
        print_res_to_file(file=log_file)

# Run the backtracking analysis once per selected dataset directory
for dir2analyze in dirs:
    analyze_edit_script(dir2analyze)


In [None]:
# Read a backtracking log file and plot its results
log_path = work_dir / "data_collect_analysis" / "bt_log" / "100+stars_sample.log"
out_dir = work_dir / "data_collect_analysis" / "bt_log"
# Example of the two log lines that read_bt_log parses (quoted verbatim from a log):
# 类型分布: defaultdict(<class 'int'>, {'mixline': 4715, 'AB': 1485, 'B': 1590, 'A': 1969, 'newline': 3063, 'BA': 426, 'same modification, formatting maybe different': 68, 'O': 188})
# 可解决类型分布: defaultdict(<class 'int'>, {'AB': 1382, 'B': 1558, 'A': 1932, 'mixline': 3452, 'BA': 391, 'same modification, formatting maybe different': 67, 'O': 188})
# Dataset name = log file name without its extension, cut at the first remaining dot
dataset_name = log_path.stem.split('.')[0]

def paint_bt_result(kind_counter, kind_resolvable, out_dir, dataset_name):
    """Draw resolvable vs. non-resolvable chunk counts per label and save them as HTML.

    Args:
        kind_counter: Mapping label -> total number of chunks.
        kind_resolvable: Mapping label -> number of resolvable chunks; labels
            missing from it count as 0.
        out_dir: Directory (path object supporting ``/``) receiving the HTML file.
        dataset_name: Used in the chart title and the output file name.
    """
    import plotly.graph_objects as go

    labels = list(kind_counter.keys())
    resolvable = []
    non_resolvable = []
    for label in labels:
        solved = kind_resolvable[label] if label in kind_resolvable else 0
        resolvable.append(solved)
        non_resolvable.append(kind_counter[label] - solved)

    solved_bar = go.Bar(x=labels, y=resolvable, name='可解决')
    # The second trace sits on top of the first one via an explicit base.
    unsolved_bar = go.Bar(x=labels, y=non_resolvable, name='无法解决', base=resolvable)

    fig = go.Figure()
    for trace in (solved_bar, unsolved_bar):
        fig.add_trace(trace)

    fig.update_layout(
        barmode='stack',
        title=f'{dataset_name} 回溯上界统计',
        xaxis_title='冲突类型',
        yaxis_title='数量',
    )

    # Save as an HTML file
    target = out_dir / f'{dataset_name}_bt_result.html'
    fig.write_html(target)

def read_bt_log(log_path):
    """Read the two label distributions back from a backtracking log file.

    The log is expected to contain one line starting with ``类型分布`` and one
    starting with ``可解决类型分布``, each followed by the repr of a
    ``defaultdict(int)``, i.e. ``defaultdict(<class 'int'>, {...})``. If a
    prefix occurs several times, the last occurrence wins.

    Args:
        log_path: Path of the log file.

    Returns:
        ``(kind_counter, kind_resolvable)`` as plain dicts.

    Raises:
        ValueError: If either line is missing or its dict part is malformed.
    """
    import ast

    def parse_counter(line):
        # Take the dict literal from the first '{' to the last '}'. This does
        # not depend on how the line ends (trailing newline or not).
        start, end = line.find('{'), line.rfind('}')
        if start == -1 or end < start:
            raise ValueError(f"Malformed distribution line in {log_path}: {line!r}")
        # literal_eval only accepts Python literals, so nothing from the file
        # is executed (unlike eval).
        return ast.literal_eval(line[start:end + 1])

    kind_counter = None
    kind_resolvable = None
    with open(log_path, 'r') as f:
        for line in f:
            if line.startswith('类型分布'):
                kind_counter = parse_counter(line)
            elif line.startswith('可解决类型分布'):
                kind_resolvable = parse_counter(line)
    if kind_counter is None or kind_resolvable is None:
        raise ValueError(f"Distribution lines not found in {log_path}")
    return kind_counter, kind_resolvable

# Parse the log and write the stacked bar chart next to it
kind_counter, kind_resolvable = read_bt_log(log_path)
paint_bt_result(kind_counter, kind_resolvable, out_dir, dataset_name)