注意到 `dataset.py` 中直接从 json 中读取了数据，但数据集并不符合这种格式，所以需要将数据集转换为 json 格式。

```py
# 复现代码中读取数据集的部分
            all_raw_base, all_raw_a, all_raw_b, all_raw_res = json.load(open('%s/raw_data'%(total_raw_data_path)))
```

所以我们需要将数据集转换为 json 格式：

```
[
    ["base1", "base2", "base3"],
    ["a1", "a2", "a3"],
    ["b1", "b2", "b3"],
    ["res1", "res2", "res3"]
]
```




In [1]:
data_dir = 'RAW_DATA/fse2022'
out_file = 'RAW_DATA/raw_data'
# walk 递归找到目录下的所有 json 文件
import os

def get_all_json_files(path):
    for root, dirs, files in os.walk(path):
        for file in files:
            if file.endswith('metadata.json'):
                yield os.path.join(root, file)

all_json_files = list(get_all_json_files(data_dir))
print(len(all_json_files))

# 读取 json 文件
import json
from tqdm import tqdm

o_contents = []
a_contents = []
b_contents = []
r_contents = []
for file in tqdm(all_json_files):
    with open(file, 'r') as f:
        data = json.load(f)
        for chunk in data['conflicting_chunks']:
            if chunk['res_region'] is None:
                continue
            o_contents.append(chunk['base_contents'])
            a_contents.append(chunk['a_contents'])
            b_contents.append(chunk['b_contents'])
            r_contents.append(chunk['res_region'])
    
assert len(o_contents) == len(a_contents) == len(b_contents) == len(r_contents)
print(len(o_contents))

json_arr = [
    o_contents,
    a_contents,
    b_contents,
    r_contents
]

# 把 json_arr 写入文件
with open(out_file, 'w') as f:
    json.dump(json_arr, f)

48785


100%|██████████| 48785/48785 [00:18<00:00, 2667.80it/s]


151426


# 自己收集的数据集 .json 转化为 raw_data

In [22]:
data_dir = '/root/projects/conflictManager/edit_script_resolver/train_and_infer/data/processed_data/recollect_without_min_bundle_without_file_content'
out_file = 'RAW_DATA/graphQL_raw_data_sample_20'


# 1. 列出 data_dir 下所有 xx.json 文件
def get_all_json_files(path):
    import os
    for root, dirs, files in os.walk(path):
        for file in files:
            if file.endswith('.json'):
                from pathlib import Path
                basename = Path(file).stem
                idx = basename.split('.')[0]
                yield (os.path.join(root, file), idx)

tuples = tuple(get_all_json_files(data_dir))

o_contents = []
a_contents = []
b_contents = []
r_contents = []


# 2. 读取 json 文件
from tqdm import tqdm
import json
for file, idx in tqdm(tuples, dynamic_ncols=True, desc='Reading json files', leave=False, position=0):
    if (int(idx) >= 20): continue
    with open(file, 'r') as f:
        cfs = json.load(f)
        for cf in tqdm(cfs, dynamic_ncols=True, desc='Reading conflict chunks', leave=False, position=1):
            for chunk in cf['conflict_chunks']:
                o_contents.append(chunk['o_content'])
                a_contents.append(chunk['a_content'])
                b_contents.append(chunk['b_content'])
                r_contents.append(chunk['r_content'])
    assert len(o_contents) == len(a_contents) == len(b_contents) == len(r_contents)

print(len(o_contents))
print(len(a_contents))
print(len(b_contents))
print(len(r_contents))

json_arr = [
    o_contents,
    a_contents,
    b_contents,
    r_contents
]

# 把 json_arr 写入文件
with open(out_file, 'w') as f:
    json.dump(json_arr, f)

                                                                   

358446
358446
358446
358446


# 看看 token_len 分布

In [4]:
# 找到所有符合这个模式的文件

# data_path = 'RAW_DATA/raw_data'
data_path = 'RAW_DATA/graphQL_raw_data_sample_20'

# 内容是 all_raw_base, all_raw_a, all_raw_b, all_raw_res = json.load(open(data_path, 'r'))
# 统计所有 inputs 和 outputs 的长度分布

import numpy as np
import json
import pickle
import os
from tqdm import tqdm   
from collections import defaultdict
from transformers import RobertaTokenizer, T5Model, T5ForConditionalGeneration, AdamW

# 模型类型设定为 CodeT5 的小模型
model_type = 'Salesforce/codet5-small'
local_path = './codet5/codet5-small'

# 初始化对应的分词器
# tokenizer = RobertaTokenizer.from_pretrained(model_type)
tokenizer = RobertaTokenizer.from_pretrained(local_path)


# inputs_lens = defaultdict(int)
# outputs_lens = defaultdict(int)
res_lens = defaultdict(int)

all_raw_base, all_raw_a, all_raw_b, all_raw_res = json.load(open(data_path, 'r'))

print(len(all_raw_base))
for raw_res in tqdm(all_raw_res):
    
    raw_res = ' '.join(raw_res.split())
    # 对 res 进行分词
    # 利用分词器对各版本代码进行分词
    tokens_res = tokenizer.tokenize(raw_res)
    ids_res = tokenizer.convert_tokens_to_ids(tokens_res)
    # 统计长度
    res_lens[len(ids_res) <= 200] += 1

print(res_lens)

151426


100%|██████████| 151426/151426 [04:03<00:00, 620.89it/s] 

defaultdict(<class 'int'>, {True: 138641, False: 12785})





# 看看 token 级别 的合并（input_txt） 会不会很诡异

In [12]:
import os
import subprocess
import tempfile
from transformers import RobertaTokenizer

# 初始化分词器并添加自定义的特殊 tokens
model_type = 'Salesforce/codet5-small'
local_path = './codet5/codet5-small'

tokenizer = RobertaTokenizer.from_pretrained(local_path)

# 自定义的特殊tokens，用于表示冲突时的括号分隔符
brackets_tokens = ['<lbra>', '<mbra>', '<rbra>']
succeed_num = tokenizer.add_tokens(brackets_tokens)
assert succeed_num == len(brackets_tokens), "Failed to add all special tokens."

class ConflictResolver:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.lbra_token = '<lbra>'
        self.rbra_token = '<rbra>'

    def clean_code(self, code_str):
        """
        清理代码字符串中的多余空格。
        """
        return ' '.join(code_str.split())

    def tokenize_code(self, code_str):
        """
        对代码字符串进行分词。
        """
        return self.tokenizer.tokenize(code_str)

    def execute_command(self, cmd):
        """
        执行shell命令的辅助函数。
        """
        result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        if result.returncode < 0:
            print("Error executing command:", cmd)
            print("Error message:", result.stderr)
            raise Exception("Command failed")
        return result.returncode, result.stdout

    def git_merge(self, tokens_base, tokens_a, tokens_b):
        """
        使用git merge-file命令对base、a、b三个版本进行三方合并，
        解析git产生的冲突标记并将结果转化为对应的token序列格式。
        """
        with tempfile.TemporaryDirectory() as merge_dir:
            base_path = os.path.join(merge_dir, 'base')
            a_path = os.path.join(merge_dir, 'a')
            b_path = os.path.join(merge_dir, 'b')
            merge_output_path = os.path.join(merge_dir, 'merge')

            # 将tokens写入临时文件中
            with open(base_path, 'w') as f:
                f.write('\n'.join(tokens_base))
            with open(a_path, 'w') as f:
                f.write('\n'.join(tokens_a))
            with open(b_path, 'w') as f:
                f.write('\n'.join(tokens_b))

            # 执行git merge-file命令
            merge_cmd = f'git merge-file -L a -L base -L b {a_path} {base_path} {b_path} --diff3 -p > {merge_output_path}'
            ret_code, ret_out = self.execute_command(merge_cmd)
            if ret_code > 0:
                print("%s conflicts occurred during merge." % ret_code)

            # 读取合并结果
            with open(merge_output_path, 'r') as f:
                merge_res = f.read().splitlines()
            merge_res = [x.strip() for x in merge_res if x.strip()]

            # 解析冲突标记行
            format_ids = [k for k, x in enumerate(merge_res) if x in ['<<<<<<< a', '||||||| base', '=======', '>>>>>>> b']]
            if len(format_ids) % 4 != 0:
                raise ValueError("Unexpected number of conflict markers.")

            final_tokens = []
            start = 0
            for k in range(0, len(format_ids), 4):
                assert (merge_res[format_ids[k]] == '<<<<<<< a' and 
                        merge_res[format_ids[k + 1]] == '||||||| base' and 
                        merge_res[format_ids[k + 2]] == '=======' and 
                        merge_res[format_ids[k + 3]] == '>>>>>>> b'), "Conflict markers mismatch."

                # 上下文部分
                context_tokens = merge_res[start:format_ids[k]]
                # 来自a版本的代码片段
                a_tokens = merge_res[format_ids[k] + 1:format_ids[k + 1]]
                # base版本的代码片段
                base_tokens = merge_res[format_ids[k + 1] + 1:format_ids[k + 2]]
                # 来自b版本的代码片段
                b_tokens = merge_res[format_ids[k + 2] + 1:format_ids[k + 3]]

                start = format_ids[k + 3] + 1

                # 添加到最终的token列表中，包含自定义的括号token和sep_token
                final_tokens += context_tokens + [self.lbra_token] + a_tokens + [self.tokenizer.sep_token] + base_tokens + [self.tokenizer.sep_token] + b_tokens + [self.rbra_token]

            # 处理剩余的尾部内容
            if start < len(merge_res):
                final_tokens += merge_res[start:]

            # 在最终序列的首尾加入bos和eos
            final_tokens = [self.tokenizer.bos_token] + final_tokens + [self.tokenizer.eos_token]

            return final_tokens

    def resolve_conflict(self, base, a, b):
        """
        主函数：处理输入的base、a、b代码并返回合并后的token序列。
        """
        # 清理代码
        base_clean = self.clean_code(base)
        a_clean = self.clean_code(a)
        b_clean = self.clean_code(b)

        # 分词
        tokens_base = self.tokenize_code(base_clean)
        tokens_a = self.tokenize_code(a_clean)
        tokens_b = self.tokenize_code(b_clean)

        # 执行合并
        merged_tokens = self.git_merge(tokens_base, tokens_a, tokens_b)

        return merged_tokens

def main():
    # 示例代码输入
    # base_code = """
    #     log_denominator_n = logsumexp(jf_k - ju_kn.T, b=jNN_k, axis=1)
    #     log_numerator_k_= logsumexp(-log_denominator_n - ju_kn, axis=1)
    #     return -1 * jN_k * (1.0 - npj.exp(jf_k + log_numerator_k))
    # jit_mbar_gradient = jax.jit(jax_mbar_gradient)
    # """

    # a_code = """
    #     return -1 * jNk * (1.0 - jnp.exp(f_k + log_numerateor_k)
    # jit_mbar_gradient= jax.jit(jax_mbar_gradient)
    # """

    # b_code = """
    #     return -1 * jNk * (1.0 - jpn.exp(f_k + log_numerator_k))
    # """
    
    base_code = """
        bool m_showFirstRun = false;
        bool m_checkssl     = true;
        bool m_vsModeActive = false;
    """

    a_code = """
        QScopedPointer<QWebSocketServer> m_webChannelServeer;
        uint32_t m_webChannelPort;
        QScopedPointer<QWebChannel> m_webChannel;
        bool m_showFirstRun = false;
        bool m_checkssl
        = true;
        bool m_vsModeActive = false;
    """

    b_code = """
    private:
        VsQuickView m_view;
        VsServerinfo m_currentStudio;
        QScopedPointer<JackTrip> m_jackTrip;
        QSharedPointer<QJackTrip> m_standardWindow;
        QScopedPointer<QNetworkAccessManager> m_networkAccessManager;
        QScopedPointer<VsAuth> m_auth;
        QScopedPointer<VsApi> m_api;
        QScopedPointer<VsDevice> m_devicePtr;
        QScopedPointer<VsWebSocket> m_studioSocketPtr;
        QScopedPointer<VsAudio> m_audioConfigPtr;
        QScopedPointer<QThread> m_audioConfigThread;
        QVector<VsServerInfoPointer> m_servers;
        QVector<VsServerinfo*> m_serverModel; //< qml doesn'tlike sm
        QMap<QString, bool> m_subscribedServers;
        QJsonObject m_regions;
        QJsonObject m_userMetadata;
        QJsonObject m_networkStats;
        QTimer m_startTimer;
        QTimer m_refreshTimer;
        QTimer m_heartbeatTimer;
        QTimer m_networkOutageTimer;
        QMutex m_refreshMutex;
        QUrl m_studioToJoin;
    """

    resolver = ConflictResolver(tokenizer)
    merged_tokens = resolver.resolve_conflict(base_code, a_code, b_code)

    # 打印每个token一行
    print("Merged Tokens:")
    for token in merged_tokens:
        print(token)

main()

1 conflicts occurred during merge.
Merged Tokens:
<s>
<lbra>
Q
Scoped
Pointer
<
Q
WebSocket
Server
>
Ġm
_
web
Channel
Serve
er
;
Ġuint
32
_
t
Ġm
_
web
Channel
Port
;
ĠQ
Scoped
Pointer
<
Q
Web
Channel
>
Ġm
_
web
Channel
;
Ġbool
</s>
bool
</s>
private
:
ĠV
s
Quick
View
<rbra>
Ġm
_
view
;
ĠV
s
Server
info
Ġm
_
current
St
udio
;
ĠQ
Scoped
Pointer
<
J
ack
Trip
>
Ġm
_
j
ack
Trip
;
ĠQ
Shared
Pointer
<
Q
J
ack
Trip
>
Ġm
_
standard
Window
;
ĠQ
Scoped
Pointer
<
Q
Network
Access
Manager
>
Ġm
_
network
Access
Manager
;
ĠQ
Scoped
Pointer
<
Vs
Auth
>
Ġm
_
auth
;
ĠQ
Scoped
Pointer
<
Vs
Api
>
Ġm
_
api
;
ĠQ
Scoped
Pointer
<
Vs
Device
>
Ġm
_
device
Ptr
;
ĠQ
Scoped
Pointer
<
Vs
WebSocket
>
Ġm
_
st
udio
Socket
Ptr
;
ĠQ
Scoped
Pointer
<
Vs
Audio
>
Ġm
_
audio
Config
Ptr
;
ĠQ
Scoped
Pointer
<
Q
Thread
>
Ġm
_
audio
Config
Thread
;
ĠQ
Vector
<
Vs
Server
Info
Pointer
>
Ġm
_
servers
;
ĠQ
Vector
<
Vs
Server
info
*
>
Ġm
_
server
Model
;
Ġ//
<
Ġq
ml
Ġdoesn
't
like
Ġsm
ĠQ
Map
<
Q
String
,
Ġbool
>
Ġm
_
subscribed
Ser

# infer

In [16]:
import os
import torch
import torch.nn as nn
from transformers import RobertaTokenizer, T5ForConditionalGeneration
from collections import namedtuple

# 定义 dotdict 类，允许通过属性访问字典内容
class dotdict(dict):
    def __getattr__(self, name):
        return self[name]

# 定义 MergeT5 类，继承自 nn.Module
class MergeT5(nn.Module):
    def __init__(self, args):
        super(MergeT5, self).__init__()
        # 使用预训练的 T5ForConditionalGeneration 模型
        self.t5 = T5ForConditionalGeneration.from_pretrained(args.model_type)
        # 调整词表大小，包含新增的自定义 tokens
        self.t5.resize_token_embeddings(len(args.tokenizer))
        self.embedding_dim = self.t5.config.hidden_size

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, **kwargs):
        return self.t5(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            **kwargs
        )

# 配置参数
args = dotdict({
    'model_type': './codet5/codet5-small',  # 模型路径
    'max_conflict_length': 500,
    'max_resolve_length': 200,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')
})

# 初始化分词器并添加自定义的特殊 tokens
tokenizer = RobertaTokenizer.from_pretrained(args.model_type)
brackets_tokens = ['<lbra>', '<mbra>', '<rbra>']
succeed_num = tokenizer.add_tokens(brackets_tokens)
assert succeed_num == len(brackets_tokens), "Failed to add all special tokens."
args.tokenizer = tokenizer  # 将 tokenizer 添加到 args 中

# 初始化模型
model = MergeT5(args)
model.to(args.device)

# 加载训练好的模型权重
model_path = 'back/best_model.pt'  # 确保路径正确


if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model file not found at {model_path}")
model.load_state_dict(torch.load(model_path, map_location=args.device))
model.eval()
print("模型加载完成！")


# 定义冲突解决函数
def resolve_conflict(model, tokenizer, base_code, a_code, b_code, args, beam_num=3):
    """
    生成代码合并冲突的解决方案。

    参数:
    - model: 已加载的 MergeT5 模型
    - tokenizer: 分词器
    - base_code: base 分支的代码字符串
    - a_code: a 分支的代码字符串
    - b_code: b 分支的代码字符串
    - args: 配置参数
    - beam_num: beam search 的宽度

    返回:
    - resolved_code: 解决冲突后的代码字符串
    """
    # 清理代码（移除多余空格）
    def clean_code(code_str):
        return ' '.join(code_str.split())

    base_clean = clean_code(base_code)
    a_clean = clean_code(a_code)
    b_clean = clean_code(b_code)

    # 分词
    tokens_base = tokenizer.tokenize(base_clean)
    tokens_a = tokenizer.tokenize(a_clean)
    tokens_b = tokenizer.tokenize(b_clean)

    # 构造合并后的 token 序列，包含特殊的括号 tokens
    # 格式：<s> context <lbra> a_code <sep> base_code <sep> b_code <rbra> </s>
    merged_tokens = [tokenizer.bos_token] + tokens_base + [tokenizer.sep_token, '<lbra>'] + tokens_a + [tokenizer.sep_token] + tokens_b + [tokenizer.sep_token, '<rbra>', tokenizer.eos_token]
    
    # 转换为 token ids
    input_ids = tokenizer.convert_tokens_to_ids(merged_tokens)
    input_ids = torch.tensor([input_ids]).to(args.device)

    # 创建 attention mask
    attention_mask = (input_ids != tokenizer.pad_token_id).long()

    # 生成输出
    with torch.no_grad():
        outputs = model.t5.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=beam_num,
            max_length=args.max_resolve_length,
            early_stopping=True,
            no_repeat_ngram_size=2
        )

    # 解码输出
    resolved_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return resolved_code

# 示例代码输入
base_code = """
    bool m_showFirstRun = false;
    bool m_checkssl     = true;
    bool m_vsModeActive = false;
"""

a_code = """
    QScopedPointer<QWebSocketServer> m_webChannelServeer;
    uint32_t m_webChannelPort;
    QScopedPointer<QWebChannel> m_webChannel;
    bool m_showFirstRun = false;
    bool m_checkssl
    = true;
    bool m_vsModeActive = false;
"""

b_code = """
private:
    VsQuickView m_view;
    VsServerinfo m_currentStudio;
    QScopedPointer<JackTrip> m_jackTrip;
    QSharedPointer<QJackTrip> m_standardWindow;
    QScopedPointer<QNetworkAccessManager> m_networkAccessManager;
    QScopedPointer<VsAuth> m_auth;
    QScopedPointer<VsApi> m_api;
    QScopedPointer<VsDevice> m_devicePtr;
    QScopedPointer<VsWebSocket> m_studioSocketPtr;
    QScopedPointer<VsAudio> m_audioConfigPtr;
    QScopedPointer<QThread> m_audioConfigThread;
    QVector<VsServerInfoPointer> m_servers;
    QVector<VsServerinfo*> m_serverModel; //< qml doesn'tlike sm
    QMap<QString, bool> m_subscribedServers;
    QJsonObject m_regions;
    QJsonObject m_userMetadata;
    QJsonObject m_networkStats;
    QTimer m_startTimer;
    QTimer m_refreshTimer;
    QTimer m_heartbeatTimer;
    QTimer m_networkOutageTimer;
    QMutex m_refreshMutex;
    QUrl m_studioToJoin;
"""

# 生成解决方案
resolved_code = resolve_conflict(model, tokenizer, base_code, a_code, b_code, args, beam_num=3)

# 输出结果
print("Resolved Code:")
print(resolved_code)

  model.load_state_dict(torch.load(model_path, map_location=args.device))


模型加载完成！
Resolved Code:
bool m_showFirstRun = false; bool m2checkssl = true; QScopedPointer<QWebSocketServer> m _webChannelServeer; uint32_t m_%sLastRunTime; VsServerinfo mCurrentEntryPoint; SvgData m0; switch (m_currentStudio) { case VSvgView: V s_view.getViewPointByChannel; break; case VK_USER_DO_NOT_USE_OR_YOU_WILL_BE_FIRED: return; } } /** * This method is neccessary to join a stream. * * @param stream the stream to start from. */ public static QStream Stream(VsQuickView stream) : base(stream); } public void start() { super.start(); end(); } @Override protected void onJoin(){ super._onJoin(); long startTime = 0; long lastFrameSize = System.currentTimeMillis(); V
