## 1. Load Data

In [35]:
import json

def get_preds(path):
    """Load a JSON predictions file and return the parsed object.

    The file is decoded explicitly as UTF-8. The answer texts contain
    non-ASCII characters (e.g. Korean), and relying on the platform's locale
    default encoding can fail or mis-decode them.

    Returns whatever ``json.load`` produces. The ensemble cells in this
    notebook treat it as a dict mapping question id -> list of candidate
    dicts with "text" and "probability" keys (assumed schema — confirm
    against the file producer).
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)

# n-best predictions from a single model run; the ensemble cells below read
# each candidate's "text" and "probability" fields from this object.
path = "/opt/ml/outputs/preds/backup/no1_xlm_roberta/nbest_predictions.json"
preds = get_preds(path=path)

## -1. Save ensemble result (run last, after one of the section 2 cells has set `ensemble_result`)

In [None]:
from utils import increment_path

def save_json(preds, path):
    """Write ``preds`` to ``path`` as pretty-printed (indent=4) JSON.

    ``ensure_ascii=False`` keeps non-ASCII answer text human-readable in the
    output. The file is therefore opened with an explicit UTF-8 encoding:
    the platform default may be unable to encode those characters
    (``UnicodeEncodeError``) or may produce a non-UTF-8 file. A trailing
    newline is written after the JSON document.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps(preds, indent=4, ensure_ascii=False) + "\n")

# Output file name comes from utils.increment_path (presumably picks a new,
# non-clashing path under the preds directory — confirm in utils).
# Requires `ensemble_result` to have been produced by one of the section-2 cells.
out_path = increment_path("/opt/ml/outputs/preds/", infix="ensemble", name="predictions.json")
save_json(preds=ensemble_result, path=out_path)

## 2-1. Sum ensemble (sum of candidate probabilities per answer text)

In [70]:
from collections import defaultdict
from operator import itemgetter


def _get_text_scores(ans_list):
    """Sum candidate probabilities per distinct answer text.

    Returns a ``defaultdict(float)`` keyed by each candidate's "text" (in
    first-seen order), whose value is the total "probability" of all
    candidates carrying that text.
    """
    totals = defaultdict(float)
    for text, prob in ((cand["text"], cand["probability"]) for cand in ans_list):
        totals[text] += prob
    return totals

def _get_best_text(text_scores):
    """Return the key of ``text_scores`` whose value is largest.

    Ties resolve to the key that comes first in iteration order; an empty
    mapping raises ``ValueError`` (from ``max``).
    """
    best_text, _best_score = max(text_scores.items(), key=itemgetter(1))
    return best_text

def get_self_ensemble_result(preds: dict):
    """Choose one answer per question by summed candidate probability.

    For every ``q_id -> ans_list`` pair in ``preds``, candidates that share
    the same "text" have their "probability" values added together. The
    text with the highest total becomes that question's answer.

    Returns a dict mapping each question id to its chosen text, in the same
    order as ``preds`` (dicts preserve insertion order).
    """
    return {
        q_id: _get_best_text(_get_text_scores(ans_list))
        for q_id, ans_list in preds.items()
    }

# Sum ensemble over the loaded n-best predictions.
ensemble_result = get_self_ensemble_result(preds=preds)
# ensemble_result  # uncomment to display in the notebook

## 2-2. Count ensemble (number of n-best candidates per answer text)

In [76]:
from collections import defaultdict
from operator import itemgetter


def _get_text_counts(ans_list):
    """Count how many candidates in ``ans_list`` carry each answer text.

    Returns a ``defaultdict(int)`` mapping each "text" to its number of
    occurrences, with keys in first-seen order.
    """
    counts = defaultdict(int)
    for text in (candidate["text"] for candidate in ans_list):
        counts[text] += 1
    return counts

def _get_best_text(text_counts):
    """Return the key of ``text_counts`` with the largest value.

    ``max`` keeps the first maximal key it meets, so ties go to the earliest
    key in iteration order; an empty mapping raises ``ValueError``.
    """
    return max(text_counts, key=text_counts.get)

def get_self_ensemble_result(preds: dict):
    """Choose one answer per question by candidate frequency.

    For every question id in ``preds``, the candidate "text" that appears
    most often in its n-best list is selected (ties go to the text seen
    first).

    Returns a dict mapping each question id to its chosen text, in the same
    order as ``preds``.
    """
    result = {}
    for q_id in preds:
        counts = _get_text_counts(preds[q_id])
        result[q_id] = _get_best_text(counts)
    return result

# Count ensemble over the loaded n-best predictions.
ensemble_result = get_self_ensemble_result(preds=preds)
# ensemble_result  # uncomment to display in the notebook

## 2-3. Sum × Count ensemble (probability sum weighted by `count ** exp`)

In [113]:
from collections import defaultdict
from operator import itemgetter

def _get_text_counts(ans_list):
    """Tally the occurrences of each candidate "text" in ``ans_list``.

    Returns a ``defaultdict(int)`` keyed by text in first-seen order.
    """
    occurrences = defaultdict(int)
    texts = [entry["text"] for entry in ans_list]
    for text in texts:
        occurrences[text] = occurrences[text] + 1
    return occurrences

def _get_text_scores(ans_list):
    """Accumulate each candidate's "probability" under its "text".

    Returns a ``defaultdict(float)`` whose keys are the distinct texts in
    first-seen order and whose values are the summed probabilities.
    """
    prob_sums = defaultdict(float)
    for entry in ans_list:
        text = entry["text"]
        prob_sums[text] = prob_sums[text] + entry["probability"]
    return prob_sums

def _get_best_text(text_scores):
    """Return the text whose score in ``text_scores`` is highest.

    The first maximal entry in iteration order wins ties; an empty mapping
    raises ``ValueError`` (from ``max``).
    """
    scored_entries = text_scores.items()
    top_entry = max(scored_entries, key=itemgetter(1))
    return top_entry[0]

def get_self_ensemble_result(preds: dict, exp=1.0):
    """Choose one answer per question by probability sum weighted by frequency.

    For every ``q_id -> ans_list`` pair in ``preds``, candidates that share
    the same "text" are grouped, and each group is scored as

        sum(probability) * count ** exp

    The highest-scoring text is chosen (ties go to the text seen first).
    ``exp`` controls how strongly frequency is rewarded. ``exp=0`` makes the
    weight 1, which reduces this to the plain probability-sum ensemble.

    Returns a dict mapping each question id to its chosen text, in the same
    order as ``preds``.
    """
    result = dict()  # dict keeps input sequence order
    for q_id, ans_list in preds.items():
        text_scores = _get_text_scores(ans_list)
        text_counts = _get_text_counts(ans_list)
        # Look each count up by its text instead of zipping the two dicts'
        # values, so the pairing never depends on both helpers producing the
        # exact same key order.
        text_weighted_scores = {
            text: score * (text_counts[text] ** exp)
            for text, score in text_scores.items()
        }
        result[q_id] = _get_best_text(text_weighted_scores)
    return result

# Sum*count ensemble with linear count weighting (exp=1); the bare name on
# the last line makes the notebook display the result.
ensemble_result = get_self_ensemble_result(preds=preds, exp=1)
ensemble_result

{'mrc-1-000653': '지구',
 'mrc-1-001113': '플레이오프',
 'mrc-0-002191': '빌헬름 미클라스',
 'mrc-0-003951': '뉴질랜드',
 'mrc-1-001272': '북한군',
 'mrc-1-000993': '아래턱',
 'mrc-0-005021': '근대 자본주의',
 'mrc-1-000163': '인격성',
 'mrc-0-001283': '순조 11년(1811)',
 'mrc-0-004543': '고전도성 철',
 'mrc-0-000439': '점쟁이',
 'mrc-0-002895': '칼라치 전방 약250km 지점',
 'mrc-0-000535': '롭 포드',
 'mrc-1-001724': '〈상당(上堂)〉',
 'mrc-0-000901': '고려 현종',
 'mrc-0-001606': '그로인볼트',
 'mrc-0-000266': '공무원',
 'mrc-0-001326': '《코믹 바이어스 가이드》',
 'mrc-0-000032': '국방부 보고서',
 'mrc-0-005215': '영산사',
 'mrc-0-005407': '숙의 정씨',
 'mrc-0-003683': '국가인권위원회',
 'mrc-0-003644': '이탈리아',
 'mrc-0-002835': '바이샬리',
 'mrc-0-000049': '아이템',
 'mrc-1-001829': '화흡',
 'mrc-1-001662': '보신 전쟁',
 'mrc-0-001206': '2018년',
 'mrc-0-004007': '로스쿨',
 'mrc-1-000418': '일등바위',
 'mrc-0-003133': '잡화점',
 'mrc-0-004646': '건늠선',
 'mrc-0-001058': '카필라 성',
 'mrc-0-002361': '종가의 자존심',
 'mrc-0-004830': '13개',
 'mrc-0-002762': '‘남고사’',
 'mrc-0-000395': '오달제',
 'mrc-0-001668': '어머니',
 'mrc-0-