In [3]:
import pandas as pd
import numpy as np
import json
import os

In [5]:
# Load the training-split category predictions (Excel) and the three score_train_* CSVs
# (accusation / result / reason, going by the file names).
# NOTE(review): df_cat is re-read and df_res is rebound further down in this notebook;
# df_acc and df_rsn are not referenced again in the visible cells.
df_cat = pd.read_excel('./result/train_cat1_pred.xlsx')
df_acc = pd.read_csv('./result/score_train_JudgeAccusation.csv')
df_res = pd.read_csv('./result/score_train_JudgeResult.csv')
df_rsn = pd.read_csv('./result/score_train_JudgeReason.csv')

In [112]:
# config: dataset folders and the output folder
TRAIN_DIR = './wenshu_ms_dataset/train/'
TEST_DIR = './wenshu_ms_dataset/test/'
DEV_DIR = './wenshu_ms_dataset/dev/'
RESULT_DIR = './result/'


def _numeric_stem(file_name):
    """Sort key: the file name without its last five characters, read as an int.

    presumably the five characters are a '.json' suffix (the files are parsed with
    json.load further down) — any entry whose stem is not an integer raises ValueError.
    """
    return int(file_name[:-5])


# Directory listings ordered numerically (so that e.g. '10' sorts after '9').
TRAIN_FILES = sorted(os.listdir(TRAIN_DIR), key=_numeric_stem)
TEST_FILES = sorted(os.listdir(TEST_DIR), key=_numeric_stem)
DEV_FILES = sorted(os.listdir(DEV_DIR), key=_numeric_stem)


In [3]:
import numpy as np
import math, json

def cal_F1_k(ranked_list: list, k: int) -> float:
    """
    Compute the F1 score of a ranked 0/1 relevance list truncated at position k.

    :param ranked_list: relevance labels in rank order (1 = relevant, 0 = not relevant)
    :param k: cutoff, counted from 1; the first k labels form the retrieved set
    :return: harmonic mean of precision@k and recall@k; 0 when the cutoff is not
        positive, when the list contains no relevant label, or when nothing
        relevant appears in the first k positions
    """
    if k <= 0:
        # an empty (or negative) cutoff retrieves nothing; avoids dividing by zero below
        return 0
    count, N_D = sum(ranked_list[:k]), sum(ranked_list)  # relevant labels in the top k / in the whole list
    p_k = count / k  # precision is always taken over k, even if the list is shorter than k
    r_k = (count / N_D) if N_D != 0 else 0
    return (2 * p_k * r_k / (p_k + r_k)) if p_k + r_k != 0 else 0
def cal_DCG_k(ranked_list: list, k: int, penalty=-1) -> float:
    """
    Compute a penalised DCG over the first k positions of a ranked relevance list.

    A truthy label at 0-based position i adds 1 / log2(i + 2); a falsy label adds
    penalty / log2(i + 2).

    :param ranked_list: relevance labels in rank order
    :param k: cutoff; when it exceeds the list length only the available positions
        are scored (the same capping compute_dcg_at_k applies)
    :param penalty: gain assigned to a non-relevant position (default -1)
    :return: the accumulated value, 0 for an empty list or a non-positive cutoff
    """
    value = 0
    # cap at the list length so a short ranking cannot raise IndexError
    for i in range(min(k, len(ranked_list))):
        gain = 1 if ranked_list[i] else penalty
        value += gain / math.log(i + 2, 2)  # i starts at 0, hence the +2
    return value
class LegalRerankingEvaluator():
    """
    Evaluates ranked candidate lists against ground-truth relevance labels.

    The inputs are two dicts keyed by query id (normally loaded from JSON files):
    the ranking maps each query to its ranked list of candidate ids, and the ground
    truth maps each query to a record whose 'gt_idx' entry lists the relevant ids.
    Reported metrics: Accuracy@k, Precision@k, Recall@k, MRR@k, NDCG@k, MAP@k and
    the oracle (best cutoff per query) F1 and DCG.
    """
    def __init__(self,
                 name: str = 'wenshu_ms_dataset',
                 corpus_chunk_size: int = 100,
                 show_progress_bar: bool = True,
                 mrr_at_k = [10, 20, 30, 100],
                 ndcg_at_k = [10, 20, 30, 100],
                 accuracy_at_k = [1, 3, 5, 10, 50, 100],
                 precision_recall_at_k = [1, 3, 5, 10, 50, 100],
                 map_at_k = [10, 100],
                 oracle_at_k = [100],
                 ):
        """
        :param name: dataset label, only used in the printed messages
        :param corpus_chunk_size: largest cutoff tried when searching the oracle F1 / DCG
        :param show_progress_bar: stored on the instance; not read by any method of this class
        :param mrr_at_k: cutoffs for MRR
        :param ndcg_at_k: cutoffs for NDCG
        :param accuracy_at_k: cutoffs for Accuracy
        :param precision_recall_at_k: cutoffs shared by Precision and Recall
        :param map_at_k: cutoffs for MAP
        :param oracle_at_k: truncation lengths for the oracle metrics
        """
        self.name = name
        self.show_progress_bar = show_progress_bar
        self.corpus_chunk_size = corpus_chunk_size

        # Copy the cutoff lists so an instance never shares (and cannot alter) the
        # mutable default arguments or a list owned by the caller.
        self.mrr_at_k = list(mrr_at_k)
        self.ndcg_at_k = list(ndcg_at_k)
        self.accuracy_at_k = list(accuracy_at_k)
        self.precision_recall_at_k = list(precision_recall_at_k)
        self.map_at_k = list(map_at_k)
        self.oracle_at_k = list(oracle_at_k)

    def compute_metrices_from_json(self, rank_lists_path, gt_path=r'wenshu_ms_dataset_all_gt.json'):
        """
        Load a ranking file and a ground-truth file, score them and print the result.

        :param rank_lists_path: JSON file mapping each query key to its ranked candidate ids
        :param gt_path: JSON file mapping each query key to a record with a 'gt_idx' list
        :return: the score dict produced by compute_metrics_from_json
        :raises Exception: when either file cannot be opened or read (the underlying
            OSError is attached as the cause); JSON syntax errors propagate unchanged
        """
        try:
            with open(gt_path, 'r', encoding='utf-8') as f:
                gt_dict = json.load(f)
        except IOError as err:
            raise Exception('The gt file is illegal') from err
        try:
            with open(rank_lists_path, 'r', encoding='utf-8') as f:
                rank_dict = json.load(f)
        except IOError as err:
            raise Exception('The json file is illegal') from err

        print("LegalRerankingEvaluator: Evaluating result on " + self.name + "")
        # Compute scores
        scores = self.compute_metrics_from_json(rank_dict, gt_dict)

        # Output
        print("Dataset: {}".format(self.name))
        self.output_scores(scores)

        return scores

    def output_scores(self, scores):
        """Print every metric of a score dict returned by compute_metrics_from_json."""
        for k in scores['accuracy@k']:
            print("Accuracy@{}: {:.2f}%".format(k, scores['accuracy@k'][k]*100))

        for k in scores['precision@k']:
            print("Precision@{}: {:.2f}%".format(k, scores['precision@k'][k]*100))

        for k in scores['recall@k']:
            print("Recall@{}: {:.2f}%".format(k, scores['recall@k'][k]*100))

        for k in scores['mrr@k']:
            print("MRR@{}: {:.4f}".format(k, scores['mrr@k'][k]))

        for k in scores['ndcg@k']:
            print("NDCG@{}: {:.4f}".format(k, scores['ndcg@k'][k]))

        for k in scores['map@k']:
            print("MAP@{}: {:.4f}".format(k, scores['map@k'][k]))

        for k in scores['oracle@f1']:
            print("ORACLE@F1: {:.4f}".format(scores['oracle@f1'][k]))

        for k in scores['oracle@dcg']:
            print("ORACLE@DCG: {:.4f}".format(scores['oracle@dcg'][k]))

    def compute_metrics_from_json(self, rank_dict, gt_dict):
        """
        Compute all configured metrics, averaged over the queries of rank_dict.

        :param rank_dict: query key -> ranked list of candidate ids (must hold at
            least one query, because the averages divide by len(rank_dict))
        :param gt_dict: query key -> record whose 'gt_idx' entry lists the relevant
            candidate ids; every key of rank_dict must be present (KeyError otherwise)
        :return: dict with the keys 'accuracy@k', 'precision@k', 'recall@k', 'ndcg@k',
            'mrr@k', 'map@k', 'oracle@f1' and 'oracle@dcg', each mapping a cutoff to
            the averaged value

        A query without any relevant document contributes 0 to recall, NDCG and MAP
        instead of dividing by zero (the same convention cal_F1_k uses for recall).
        """
        # Init score computation values
        num_hits_at_k = {k: 0 for k in self.accuracy_at_k}
        precisions_at_k = {k: [] for k in self.precision_recall_at_k}
        recall_at_k = {k: [] for k in self.precision_recall_at_k}
        MRR = {k: 0 for k in self.mrr_at_k}
        ndcg = {k: [] for k in self.ndcg_at_k}
        AveP_at_k = {k: [] for k in self.map_at_k}
        Oracle_F1 = {k: [] for k in self.oracle_at_k}
        Oracle_DCG = {k: [] for k in self.oracle_at_k}

        # Compute scores on results
        for query_key, top_hits in rank_dict.items():
            query_relevant_docs = gt_dict[query_key]['gt_idx']
            num_relevant = len(query_relevant_docs)
            # 0/1 relevance of every ranked hit, computed once and sliced per cutoff
            is_relevant = [1 if hit in query_relevant_docs else 0 for hit in top_hits]

            # Accuracy@k - the query counts as correct if at least one relevant doc is in the top k
            for k_val in self.accuracy_at_k:
                if any(is_relevant[:k_val]):
                    num_hits_at_k[k_val] += 1

            # Precision and Recall@k
            for k_val in self.precision_recall_at_k:
                num_correct = sum(is_relevant[:k_val])
                precisions_at_k[k_val].append(num_correct / k_val)
                recall_at_k[k_val].append((num_correct / num_relevant) if num_relevant else 0)

            # MRR@k - reciprocal rank of the first relevant hit, 0 if there is none in the top k
            for k_val in self.mrr_at_k:
                for rank, flag in enumerate(is_relevant[:k_val]):
                    if flag:
                        MRR[k_val] += 1.0 / (rank + 1)
                        break

            # NDCG@k - DCG of the ranking divided by the DCG of an ideal ranking
            for k_val in self.ndcg_at_k:
                ideal_dcg = self.compute_dcg_at_k([1] * num_relevant, k_val)
                if ideal_dcg:
                    ndcg_value = self.compute_dcg_at_k(is_relevant[:k_val], k_val) / ideal_dcg
                else:
                    ndcg_value = 0  # no relevant document: the ideal DCG is 0
                ndcg[k_val].append(ndcg_value)

            # MAP@k
            for k_val in self.map_at_k:
                num_correct = 0
                sum_precisions = 0

                for rank, flag in enumerate(is_relevant[:k_val]):
                    if flag:
                        num_correct += 1
                        sum_precisions += num_correct / (rank + 1)

                max_correct = min(k_val, num_relevant)
                AveP_at_k[k_val].append((sum_precisions / max_correct) if max_correct else 0)

            # Oracle@k - best F1 / DCG reachable by cutting the list at any position
            for k_val in self.oracle_at_k:
                rel = is_relevant[:k_val]
                # Cutoff 0 scores 0, so the search starts from 0. Cutoffs beyond the
                # available hits are skipped: they cannot raise the F1 and would index
                # past the end of the list in the DCG.
                best_F1, best_DCG = 0, 0
                for i in range(1, min(self.corpus_chunk_size, len(rel)) + 1):
                    best_F1 = max(best_F1, cal_F1_k(rel, i))
                    best_DCG = max(best_DCG, cal_DCG_k(rel, i))
                Oracle_F1[k_val].append(best_F1)
                Oracle_DCG[k_val].append(best_DCG)

        # Compute averages
        for k in num_hits_at_k:
            num_hits_at_k[k] /= len(rank_dict)

        for k in precisions_at_k:
            precisions_at_k[k] = np.mean(precisions_at_k[k])

        for k in recall_at_k:
            recall_at_k[k] = np.mean(recall_at_k[k])

        for k in ndcg:
            ndcg[k] = np.mean(ndcg[k])

        for k in MRR:
            MRR[k] /= len(rank_dict)

        for k in AveP_at_k:
            AveP_at_k[k] = np.mean(AveP_at_k[k])

        for k in Oracle_F1:
            Oracle_F1[k] = np.mean(Oracle_F1[k])
        for k in Oracle_DCG:
            Oracle_DCG[k] = np.mean(Oracle_DCG[k])

        return {'accuracy@k': num_hits_at_k, 'precision@k': precisions_at_k, 'recall@k': recall_at_k, 'ndcg@k': ndcg, 'mrr@k': MRR, 'map@k': AveP_at_k, 'oracle@f1': Oracle_F1, 'oracle@dcg': Oracle_DCG}

    @staticmethod
    def compute_dcg_at_k(relevances, k):
        """
        Discounted cumulative gain of the first k relevance values.

        :param relevances: relevance values in rank order
        :param k: cutoff; capped at len(relevances)
        :return: sum of relevances[i] / log2(i + 2) over the scored positions (0 if none)
        """
        dcg = 0
        for i in range(min(len(relevances), k)):
            dcg += relevances[i] / np.log2(i + 2)  # +2 as we start our idx at 0
        return dcg


# Ground-truth relevant indices (gt_idx) for the train/test queries

In [41]:
# Peek at the ground-truth record of query '0' to see its schema.
# The encoding is given explicitly, like the other JSON reads in this notebook, so the
# result does not depend on the platform's default text encoding.
with open('./wenshu_ms_dataset/wenshu_ms_dataset_train_test_gt.json', 'r', encoding='utf-8') as f:
    js_dt = json.load(f)
print(js_dt['0'])

{'q_id': 0, 'gt_idx': [4, 28, 32, 37, 51, 57, 63, 98]}


# cat_match: category-match score between each query and its candidates

In [90]:
# Reload the training category predictions and rescale the category-score columns.
from sklearn.preprocessing import MinMaxScaler
df_cat = pd.read_excel('./result/train_cat1_pred.xlsx')
Scaler = MinMaxScaler(feature_range=(0, 1))
# MinMaxScaler scales each column independently, so the score block is transposed
# before fitting and transposed back afterwards: every query (row) ends up with its
# own minimum mapped to 0 and its own maximum mapped to 1.
dt = Scaler.fit_transform(df_cat.iloc[:, 2:-1].values.T).T
# iloc[:, 2:-1] skips the first two columns and the last one; in the displayed frame
# those are q_id, cat_1_pred and query.
df_cat.iloc[:, 2:-1] = dt
df_cat

Unnamed: 0,q_id,cat_1_pred,交通事故,公司事务,劳动人事,合同事务,婚姻家庭,建筑工程,房地产纠纷,民事纠纷,知识产权,金融证券保险,query
0,0,公司事务,0.146412,1.000000,0.495448,0.869694,0.526065,0.482255,0.667435,0.640126,0.000000,0.597047,原告彭正坤诉称，2018年1月29日，原告彭正坤与三名被告及本案第三人签订了《股权转让协议书...
1,1,民事纠纷,0.760179,0.656066,0.828115,0.893410,0.809173,0.535864,0.657041,1.000000,0.000000,0.847644,白淑坤向本院提出诉讼请求：1．判令魏金保赔偿各项损失21298.32元；2．诉讼费由魏金保负...
2,2,合同事务,0.382616,0.961299,0.665290,1.000000,0.481345,0.819117,0.704936,0.698347,0.000000,0.526249,原告吴学志向本院提出诉讼请求：1.被告退还原告欠款55000元；2.本案诉讼费用由被告承担。...
3,3,民事纠纷,0.507596,0.089327,0.349861,0.794022,0.767793,0.355061,0.849590,1.000000,0.000000,0.348868,原告诉称，2009年10月10日，原告经青岛市市南区民政局办理了收养邱程程（曾用名邱程仪，身...
4,4,民事纠纷,0.372916,0.312180,0.459128,0.690531,0.364521,0.414519,0.704397,1.000000,0.000000,0.365501,原告高兵向本院提出如下诉讼请求：1、判令被告撤离占用的原、被告之间的公用面积；2、本案诉讼费...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
10676,10676,合同事务,0.588956,0.513407,0.729457,1.000000,0.505690,0.702048,0.730994,0.847926,0.000000,0.540738,原告邱前诉称，自2008年7月18日起，原告陆续给被告送小麦，由被告给原告加工面粉，截止20...
10677,10677,建筑工程,0.398588,0.692087,0.771277,0.936509,0.521309,1.000000,0.704466,0.748660,0.000000,0.666905,腾飞公司向本院提出诉讼请求：1.依法判令被告给付原告工程款暂计943351.81元，并支付利...
10678,10678,婚姻家庭,0.636904,0.552388,0.689808,0.743712,1.000000,0.676500,0.882997,0.848519,0.000000,0.548753,原告赵某某诉称：原告赵某某系被告陈某甲的母亲，原告因年迈、体弱多病，不能照顾自己，且无其他生...
10679,10679,合同事务,0.459772,0.132413,0.054331,1.000000,0.000000,0.243508,0.097008,0.529157,0.061518,0.218855,原告上海龙沿控股集团有限公司向本院提出诉讼请求：判令被告归还借款人民币（币种下同）300万元...


In [91]:
# Keep only the rescaled per-category score columns (the 2:-1 slice), so that a score
# can be looked up by category name and row.
df_cat_dt = df_cat.iloc[:, 2:-1]
df_cat_dt

Unnamed: 0,交通事故,公司事务,劳动人事,合同事务,婚姻家庭,建筑工程,房地产纠纷,民事纠纷,知识产权,金融证券保险
0,0.146412,1.000000,0.495448,0.869694,0.526065,0.482255,0.667435,0.640126,0.000000,0.597047
1,0.760179,0.656066,0.828115,0.893410,0.809173,0.535864,0.657041,1.000000,0.000000,0.847644
2,0.382616,0.961299,0.665290,1.000000,0.481345,0.819117,0.704936,0.698347,0.000000,0.526249
3,0.507596,0.089327,0.349861,0.794022,0.767793,0.355061,0.849590,1.000000,0.000000,0.348868
4,0.372916,0.312180,0.459128,0.690531,0.364521,0.414519,0.704397,1.000000,0.000000,0.365501
...,...,...,...,...,...,...,...,...,...,...
10676,0.588956,0.513407,0.729457,1.000000,0.505690,0.702048,0.730994,0.847926,0.000000,0.540738
10677,0.398588,0.692087,0.771277,0.936509,0.521309,1.000000,0.704466,0.748660,0.000000,0.666905
10678,0.636904,0.552388,0.689808,0.743712,1.000000,0.676500,0.882997,0.848519,0.000000,0.548753
10679,0.459772,0.132413,0.054331,1.000000,0.000000,0.243508,0.097008,0.529157,0.061518,0.218855


In [116]:
# Collect the first-level category ('cat_1') of every candidate case, walking the
# train files first and the test files second, each in the order of its file list.
data = []
for split_dir, split_files in ((TRAIN_DIR, TRAIN_FILES), (TEST_DIR, TEST_FILES)):
    for file in split_files:
        with open(split_dir + file, 'r', encoding='utf-8') as f:
            js_dt = json.load(f)
        # one single-element row per candidate, in the order the 'ctxs' mapping yields them
        data.extend([case['Category']['cat_1']] for case in js_dt['ctxs'].values())
data

[['婚姻家庭'],
 ['建筑工程'],
 ['合同事务'],
 ['合同事务'],
 ['公司事务'],
 ['婚姻家庭'],
 ['合同事务'],
 ['知识产权'],
 ['合同事务'],
 ['合同事务'],
 ['金融证券保险'],
 ['婚姻家庭'],
 ['民事纠纷'],
 ['劳动人事'],
 ['房地产纠纷'],
 ['合同事务'],
 ['劳动人事'],
 ['劳动人事'],
 ['合同事务'],
 ['建筑工程'],
 ['合同事务'],
 ['民事纠纷'],
 ['民事纠纷'],
 ['金融证券保险'],
 ['婚姻家庭'],
 ['建筑工程'],
 ['房地产纠纷'],
 ['劳动人事'],
 ['公司事务'],
 ['民事纠纷'],
 ['知识产权'],
 ['合同事务'],
 ['公司事务'],
 ['公司事务'],
 ['合同事务'],
 ['合同事务'],
 ['合同事务'],
 ['公司事务'],
 ['合同事务'],
 ['婚姻家庭'],
 ['房地产纠纷'],
 ['合同事务'],
 ['民事纠纷'],
 ['合同事务'],
 ['合同事务'],
 ['合同事务'],
 ['房地产纠纷'],
 ['合同事务'],
 ['合同事务'],
 ['合同事务'],
 ['合同事务'],
 ['公司事务'],
 ['合同事务'],
 ['建筑工程'],
 ['合同事务'],
 ['民事纠纷'],
 ['民事纠纷'],
 ['公司事务'],
 ['婚姻家庭'],
 ['合同事务'],
 ['民事纠纷'],
 ['婚姻家庭'],
 ['婚姻家庭'],
 ['公司事务'],
 ['房地产纠纷'],
 ['合同事务'],
 ['合同事务'],
 ['合同事务'],
 ['房地产纠纷'],
 ['合同事务'],
 ['民事纠纷'],
 ['金融证券保险'],
 ['婚姻家庭'],
 ['合同事务'],
 ['金融证券保险'],
 ['婚姻家庭'],
 ['婚姻家庭'],
 ['公司事务'],
 ['合同事务'],
 ['建筑工程'],
 ['合同事务'],
 ['合同事务'],
 ['房地产纠纷'],
 ['知识产权'],
 ['婚姻家庭'],
 ['合同事务'],
 ['房地产纠纷'],
 ['合同事务'],
 ['合同事务'],
 ['婚姻

In [117]:
# Regroup the flat category list into one row per query and show the row at index 2.
# assumes every query file contributes exactly 100 candidates — TODO confirm; reshape
# raises if the total count is not a multiple of 100.
ctxs_cat = np.array(data).reshape((-1, 100))
ctxs_cat[2]

array(['婚姻家庭', '公司事务', '婚姻家庭', '建筑工程', '民事纠纷', '房地产纠纷', '金融证券保险', '公司事务',
       '民事纠纷', '合同事务', '房地产纠纷', '建筑工程', '公司事务', '房地产纠纷', '婚姻家庭', '合同事务',
       '合同事务', '合同事务', '合同事务', '房地产纠纷', '民事纠纷', '婚姻家庭', '知识产权', '婚姻家庭',
       '合同事务', '交通事故', '民事纠纷', '合同事务', '婚姻家庭', '房地产纠纷', '交通事故', '建筑工程',
       '民事纠纷', '劳动人事', '婚姻家庭', '民事纠纷', '劳动人事', '合同事务', '劳动人事', '合同事务',
       '公司事务', '房地产纠纷', '合同事务', '金融证券保险', '合同事务', '合同事务', '合同事务', '合同事务',
       '合同事务', '合同事务', '婚姻家庭', '民事纠纷', '房地产纠纷', '公司事务', '民事纠纷', '知识产权',
       '合同事务', '金融证券保险', '合同事务', '合同事务', '民事纠纷', '婚姻家庭', '合同事务', '合同事务',
       '合同事务', '合同事务', '合同事务', '婚姻家庭', '合同事务', '合同事务', '民事纠纷', '金融证券保险',
       '合同事务', '房地产纠纷', '合同事务', '合同事务', '合同事务', '金融证券保险', '婚姻家庭', '房地产纠纷',
       '民事纠纷', '合同事务', '合同事务', '合同事务', '合同事务', '房地产纠纷', '劳动人事', '房地产纠纷',
       '劳动人事', '婚姻家庭', '合同事务', '合同事务', '合同事务', '合同事务', '合同事务', '建筑工程',
       '建筑工程', '公司事务', '合同事务', '民事纠纷'], dtype='<U6')

In [118]:
# For every (query, candidate) cell, look up the query's rescaled score for the
# candidate's category: the column is picked by category name, the row by query number.
ctxs_cat_dt = np.zeros(ctxs_cat.shape)
for (i, j), category in np.ndenumerate(ctxs_cat):  # row-major walk over the 2-D array
    ctxs_cat_dt[i, j] = df_cat_dt[category][i]
ctxs_cat_dt

array([[0.5260652 , 0.48225466, 0.86969392, ..., 0.64012592, 1.        ,
        0.66743489],
       [0.89341039, 0.53586439, 0.        , ..., 1.        , 1.        ,
        0.82811489],
       [0.48134541, 0.96129865, 0.48134541, ..., 0.96129865, 1.        ,
        0.69834666],
       ...,
       [0.74371154, 0.74371154, 0.68980804, ..., 1.        , 0.63690366,
        0.68980804],
       [0.24350776, 0.        , 1.        , ..., 0.13241285, 0.09700764,
        1.        ],
       [1.        , 0.87435225, 0.67542077, ..., 0.87435225, 0.58948802,
        0.87435225]])

In [119]:
# Wrap the score matrix in a frame (default integer column labels), put the query ids
# in front and save the table.
# NOTE(review): this rebinds df_res, which earlier in the notebook held the
# score_train_JudgeResult table.
df_res = pd.DataFrame(ctxs_cat_dt)
df_res.insert(0, 'q_id', df_cat['q_id'])
# index=None is falsy, so the row index is left out of the CSV
df_res.to_csv(RESULT_DIR+'score_train_cat1.csv', encoding='utf_8_sig', index=None)

In [120]:
# Per-query rescaling of the dev-split category predictions.
df_cat = pd.read_excel('./result/dev_cat1_pred.xlsx')
from sklearn.preprocessing import MinMaxScaler
Scaler = MinMaxScaler(feature_range=(0, 1))
# transpose -> scale -> transpose back, so min/max are taken per row rather than per column
dt = Scaler.fit_transform(df_cat.iloc[:, 2:-1].values.T).T
df_cat.iloc[:, 2:-1] = dt
# df_cat_dt: only the rescaled category-score columns
df_cat_dt = df_cat.iloc[:, 2:-1]
df_cat_dt

Unnamed: 0,交通事故,公司事务,劳动人事,合同事务,婚姻家庭,建筑工程,房地产纠纷,民事纠纷,知识产权,金融证券保险
0,0.392102,0.221442,0.453189,0.618936,0.331105,0.513014,0.547495,1.000000,0.000000,0.534202
1,0.438685,0.545559,0.561261,1.000000,0.393816,0.693116,0.603534,0.615420,0.000000,0.568044
2,0.427748,0.000000,0.617740,1.000000,0.384591,0.559591,0.891498,0.840656,0.262989,0.700733
3,0.515172,0.660417,1.000000,0.906252,0.517865,0.773846,0.528225,0.746158,0.000000,0.448921
4,0.394968,0.553886,1.000000,0.771587,0.378905,0.516321,0.511793,0.600955,0.000000,0.493949
...,...,...,...,...,...,...,...,...,...,...
4495,0.553007,0.586377,0.628944,0.696582,1.000000,0.651553,0.731502,0.693903,0.000000,0.597567
4496,0.397919,0.172637,0.303019,0.752364,0.000000,0.430884,0.299387,1.000000,0.698639,0.214531
4497,0.270436,0.440258,0.506187,1.000000,0.375704,0.501445,0.521392,0.619299,0.000000,0.451974
4498,0.468736,0.579557,0.706967,1.000000,0.453236,0.713956,0.608611,0.676820,0.000000,0.597242


In [121]:
# dev_data: first-level category ('cat_1') of every candidate case in the dev files
data = []
for file in DEV_FILES:
    with open(DEV_DIR + file, 'r', encoding='utf-8') as f:
        js_dt = json.load(f)
    # one single-element row per candidate, in the order the 'ctxs' mapping yields them
    data.extend([case['Category']['cat_1']] for case in js_dt['ctxs'].values())

In [122]:
# Regroup the flat dev category list into one row per query.
# assumes every query file contributes exactly 100 candidates — TODO confirm; reshape
# raises if the total count is not a multiple of 100.
ctxs_cat = np.array(data).reshape((-1, 100))
ctxs_cat

array([['民事纠纷', '民事纠纷', '民事纠纷', ..., '合同事务', '合同事务', '合同事务'],
       ['婚姻家庭', '建筑工程', '婚姻家庭', ..., '劳动人事', '民事纠纷', '合同事务'],
       ['合同事务', '婚姻家庭', '房地产纠纷', ..., '劳动人事', '合同事务', '金融证券保险'],
       ...,
       ['婚姻家庭', '民事纠纷', '民事纠纷', ..., '建筑工程', '房地产纠纷', '金融证券保险'],
       ['合同事务', '合同事务', '合同事务', ..., '知识产权', '房地产纠纷', '民事纠纷'],
       ['合同事务', '合同事务', '合同事务', ..., '合同事务', '合同事务', '公司事务']], dtype='<U6')

In [124]:
# For every (query, candidate) cell of the dev split, look up the query's rescaled
# score for the candidate's category: column by category name, row by query number.
ctxs_cat_dt = np.zeros(ctxs_cat.shape)
for (i, j), category in np.ndenumerate(ctxs_cat):  # row-major walk over the 2-D array
    ctxs_cat_dt[i, j] = df_cat_dt[category][i]
ctxs_cat_dt

array([[1.        , 1.        , 1.        , ..., 0.61893573, 0.61893573,
        0.61893573],
       [0.39381558, 0.69311604, 0.39381558, ..., 0.56126139, 0.61542001,
        1.        ],
       [1.        , 0.38459118, 0.89149835, ..., 0.61774038, 1.        ,
        0.70073348],
       ...,
       [0.37570364, 0.61929882, 0.61929882, ..., 0.50144459, 0.52139227,
        0.45197415],
       [1.        , 1.        , 1.        , ..., 0.        , 0.60861122,
        0.67682031],
       [1.        , 1.        , 1.        , ..., 1.        , 1.        ,
        0.60230238]])

In [125]:
# Wrap the dev score matrix in a frame (default integer column labels), put the query
# ids in front and save the table.
df_res = pd.DataFrame(ctxs_cat_dt)
df_res.insert(0, 'q_id', df_cat['q_id'])
# index=None is falsy, so the row index is left out of the CSV
df_res.to_csv(RESULT_DIR+'score_dev_cat1.csv', encoding='utf_8_sig', index=None)

# v1: merge the per-group classifier predictions (g0–g9)

In [6]:
# Root folder of the classifier data; the per-group sub-folders (g0, g1, ...) are read from it below.
CLS_DATA_DIR = './cls_dataset/'

In [7]:
def get_resdf(k):
    """
    Build the prediction frame of classifier group k.

    The group's index file (g<k>/dev_pred_ori.csv) is read and three columns,
    '0', '1' and 'label', are attached to it side by side:

    * k < 2: no result file is read; every row gets the constant values 1, 1, 1.
    * k == 3: the rows come from results1.txt, results2.txt and results3.txt, in that order.
    * any other k: the rows come from results.txt.

    Result lines are stripped and split on ', ', so their values stay strings.
    The two frames are joined by position; nothing checks that the number of
    result lines matches the number of index rows.

    :param k: group number
    :return: DataFrame with the index file's columns followed by '0', '1', 'label'
    """
    df_idx = pd.read_csv(CLS_DATA_DIR+'g{}/dev_pred_ori.csv'.format(k))

    if k < 2:
        pred_rows = [[1, 1, 1] for _ in range(df_idx.shape[0])]
    else:
        if k != 3:
            result_paths = [CLS_DATA_DIR+'g{}/results.txt'.format(k)]
        else:
            result_paths = [CLS_DATA_DIR+'g{}/results{}.txt'.format(k, j) for j in range(1, 4)]
        pred_rows = []
        for result_path in result_paths:
            with open(result_path, 'r') as f:
                pred_rows.extend(line.strip().split(', ') for line in f)

    df_pred = pd.DataFrame(pred_rows, columns = ['0', '1', 'label'])
    return pd.concat([df_idx, df_pred], axis=1)

In [8]:
# One prediction frame per classifier group, for k = 0..9.
df_pred = [get_resdf(k) for k in range(0, 10)]

In [9]:
# Stack the ten group frames row-wise. Each frame keeps its own index here, so index
# values repeat until the reset in the next cell.
df_res = pd.concat([df_pred[k] for k in range(10)], axis=0)
df_res

Unnamed: 0,q_id,c_id,0,1,label
0,10681,46,1,1,1
1,10682,28,1,1,1
2,10683,67,1,1,1
3,10684,76,1,1,1
4,10684,87,1,1,1
...,...,...,...,...,...
21554,15180,20,-0.4356097877025604,0.5379728674888611,1
21555,15180,24,-0.2748143970966339,0.4656957983970642,1
21556,15180,28,-0.07979962974786758,0.06648369878530502,1
21557,15180,39,0.20516079664230347,-0.3047630488872528,0


In [10]:
# Order the rows by query id, then candidate id, and renumber the index from 0.
df_res = df_res.sort_values(by=['q_id', 'c_id']).reset_index(drop=True)
df_res

Unnamed: 0,q_id,c_id,0,1,label
0,10681,0,-0.6427979469299316,0.8245172500610352,1
1,10681,1,-0.7951300144195557,0.9218880534172058,1
2,10681,2,-2.0496106147766113,2.3105077743530273,1
3,10681,3,-2.0802292823791504,2.3780717849731445,1
4,10681,4,0.1740451157093048,-0.29611217975616455,0
...,...,...,...,...,...
449995,15180,95,0.6444199085235596,-0.5813272595405579,0
449996,15180,96,-0.7608732581138611,1.0080045461654663,1
449997,15180,97,-2.744319200515747,3.0168559551239014,1
449998,15180,98,1.9477359056472778,-2.068024158477783,0


In [12]:
# Save the merged v1 predictions; index=None is falsy, so no index column is written.
df_res.to_csv('./result/cls/dev_pred_cls.csv', index=None, encoding='utf_8_sig')

# v2: groups 0–1 filled with placeholder rows, groups 2–9 read from results-v2.txt

In [24]:
# Read the index file (dev_pred_ori.csv) of every group, then stack them into the two
# batches that are filled in differently below: groups 0-1 and groups 2-9.
df_idx = [pd.read_csv(CLS_DATA_DIR+'g{}/dev_pred_ori.csv'.format(k)) for k in range(10)]
df_idx_merge1 = pd.concat([df_idx[k] for k in range(2)], axis=0).reset_index(drop=True)
df_idx_merge2 = pd.concat([df_idx[k] for k in range(2, 10)], axis=0).reset_index(drop=True)

In [25]:
# Groups 0-1: no result file is read; every index row gets the constant placeholder
# values 1, 1, 1.
# df_idx_merge1 already contains the rows of both groups, so exactly one placeholder
# row per index row is built. Building more rows than the index frame has would make
# the axis=1 concat (an outer join on the index) append surplus rows whose q_id / c_id
# are NaN, and those rows would end up in the saved CSV.
res_lst1 = [[1, 1, 1] for _ in range(df_idx_merge1.shape[0])]
df_pred = pd.DataFrame(res_lst1, columns = ['0', '1', 'label'])
df_res1 = pd.concat([df_idx_merge1, df_pred], axis=1)

# Groups 2-9: each line of the result file is stripped and split on ', ' into the
# three columns (values stay strings).
# presumably the file holds one line per row of df_idx_merge2, in the same order — verify
res_lst2 = []
with open('./cls_dataset/results-v2.txt', 'r') as f:
    for line in f.readlines():
        res_lst2.append(line.strip().split(', '))
df_pred = pd.DataFrame(res_lst2, columns = ['0', '1', 'label'])
df_res2 = pd.concat([df_idx_merge2, df_pred], axis=1)

# Stack both batches and renumber the index from 0.
df_res = pd.concat([df_res1, df_res2], axis=0).reset_index(drop=True)

In [29]:
# Order the rows by query id, then candidate id, renumber the index and save the v2
# predictions (index=None is falsy, so no index column is written).
df_res = df_res.sort_values(by=['q_id', 'c_id']).reset_index(drop=True)
df_res.to_csv('./result/cls/dev_pred_cls-v2.csv', index=None, encoding='utf_8_sig')

# v3: groups 0–1 read from results01-v2.txt, groups 2–9 read from results-v2.txt

In [39]:
# Read the index file (dev_pred_ori.csv) of every group, then stack them into the two
# batches that are paired with separate result files below: groups 0-1 and groups 2-9.
df_idx = [pd.read_csv(CLS_DATA_DIR+'g{}/dev_pred_ori.csv'.format(k)) for k in range(10)]
df_idx_merge1 = pd.concat([df_idx[k] for k in range(2)], axis=0).reset_index(drop=True)
df_idx_merge2 = pd.concat([df_idx[k] for k in range(2, 10)], axis=0).reset_index(drop=True)

In [40]:
def _split_result_lines(path):
    """Return every line of `path`, stripped and split on ', ', as a list of string lists."""
    with open(path, 'r') as f:
        return [line.strip().split(', ') for line in f]


# Groups 0-1: rows of results01-v2.txt placed next to the merged index of groups 0-1.
res_lst1 = _split_result_lines('./cls_dataset/results01-v2.txt')
df_pred = pd.DataFrame(res_lst1, columns = ['0', '1', 'label'])
df_res1 = pd.concat([df_idx_merge1, df_pred], axis=1)

# Groups 2-9: rows of results-v2.txt placed next to the merged index of groups 2-9.
res_lst2 = _split_result_lines('./cls_dataset/results-v2.txt')
df_pred = pd.DataFrame(res_lst2, columns = ['0', '1', 'label'])
df_res2 = pd.concat([df_idx_merge2, df_pred], axis=1)

# Stack both batches and renumber the index from 0.
df_res = pd.concat([df_res1, df_res2], axis=0).reset_index(drop=True)

In [41]:
# Order the rows by query id, then candidate id, renumber the index and save the v3
# predictions (index=None is falsy, so no index column is written).
df_res = df_res.sort_values(by=['q_id', 'c_id']).reset_index(drop=True)
df_res.to_csv('./result/cls/dev_pred_cls-v3.csv', index=None, encoding='utf_8_sig')