In [1]:
from train_yh import *
from utils_yh import *

import argparse
import os
import gc
import time
import json
import random
import warnings
from pprint import pprint

import numpy as np
import pandas as pd
from einops import rearrange, reduce, repeat
import tez
import torch
import torch.nn as nn
from sklearn import metrics
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, AutoConfig, AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup

import copy
import warnings
import os

import numpy as np
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AdamW, AutoConfig, AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
from joblib import Parallel, delayed 
from tez import enums
from tez.callbacks import Callback
from tqdm import tqdm

import re
import nltk

In [2]:
# Inference-time configuration
class set_args_submission:
    """Plain attribute container holding the settings used by this notebook.

    Attributes (all set on the instance by ``__init__``):
        model: model name passed to the model and tokenizer constructors.
        sbert: boolean flag forwarded to project code through ``args``.
        output: directory the model weights are read from.
        input: directory of the input data.
        max_len: maximum sequence length handed to the dataset.
        valid_batch_size: batch size used for the inference DataLoader.
    """

    def __init__(self):
        defaults = {
            'model': 'longformer-base-4096',
            'sbert': True,
            'output': '../model',
            'input': '../input',
            'max_len': 4096,
            'valid_batch_size': 4,
        }
        # Assign in declaration order so the instance __dict__ keeps the same layout.
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)


args = set_args_submission()

In [40]:
# Load the model
model = FeedbackModel(
    model_name=args.model,
    num_train_steps=0,
    learning_rate=0,
    num_labels=len(target_id_map) - 1,
    steps_per_epoch=0,
    args=args,
)

weight_path = os.path.join(args.output, 'longformer-base-4096-sbert.bin')
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only host;
# load_state_dict then copies the tensors into the model's own parameters.
model.load_state_dict(torch.load(weight_path, map_location="cpu"))
model.eval()

# Load the data (the input directory comes from args instead of a second hard-coded path)
df = pd.read_csv(os.path.join(args.input, "sample_submission.csv"))
df_ids = df["id"].unique()

tokenizer = AutoTokenizer.from_pretrained(args.model)
test_samples = prepare_test_data(df, tokenizer, args)
collate = ValidCollate(tokenizer, args)

test_dataset = FeedbackDatasetValid(test_samples, args.max_len, tokenizer, args=args)

import psutil
# psutil.cpu_count() can return None when the count is undetermined; DataLoader
# requires an int, so fall back to 0 (load in the main process).
n_jobs = psutil.cpu_count() or 0

data_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.valid_batch_size,
    num_workers=n_jobs,
    collate_fn=collate,
    pin_memory=True,
)

def predict(data_loader, model):
    """Run the model over every batch of ``data_loader`` and yield its outputs.

    Args:
        data_loader: sized iterable of batch dicts containing at least the keys
            ``'ids'``, ``'mask'`` and ``'input_type_list'``; the whole dict is
            passed to the model as keyword arguments.
        model: callable exposing a ``device`` attribute and returning a 3-tuple
            whose first element is a tensor.

    Yields:
        numpy.ndarray: the first element of the model output for one batch,
        moved to CPU.
    """
    tk0 = tqdm(data_loader, total=len(data_loader))
    try:
        for data in tk0:
            # Only the forward pass runs under no_grad. Yielding inside the
            # context would leave gradients disabled in the caller's code while
            # the generator is suspended.
            with torch.no_grad():
                data['ids'] = data['ids'].to(model.device)
                data['mask'] = data['mask'].to(model.device)
                data['input_type_list'] = [s.to(model.device) for s in data['input_type_list']]

                output, _, _ = model(**data)
                output = output.cpu().detach().numpy()
            yield output
    finally:
        # Close the progress bar even when the consumer stops iterating early.
        tk0.close()


# Drain the prediction generator, keeping each batch output as float16.
preds_iter = predict(data_loader, model)
raw_preds = [batch_output.astype(np.float16) for batch_output in preds_iter]
current_idx = len(raw_preds)  # number of batches consumed

torch.cuda.empty_cache()

# Flatten the batches into one list entry per sample: the arg-max class index
# and the corresponding maximum value along the last axis, per token.
final_preds = []
final_scores = []
for batch_output in raw_preds:
    final_preds.extend(np.argmax(batch_output, axis=2).tolist())
    final_scores.extend(np.max(batch_output, axis=2).tolist())

# Attach label names and scores to each sample, skipping the first position.
for j, sample in enumerate(test_samples):
    sample["preds"] = [id_target_map[class_idx] for class_idx in final_preds[j][1:]]
    sample["pred_scores"] = final_scores[j][1:]

100%|██████████| 2/2 [00:10<00:00,  5.38s/it]


In [41]:
def jn(pst, start, end):
    """Join ``pst[start:end]`` into a space-separated string of word indices.

    ``-1`` entries are the separators ``link_evidence`` inserts between
    consecutive spans. They are skipped so that a merged span does not carry
    literal ``-1`` tokens in its prediction string.

    Args:
        pst: sequence of integers (word indices and ``-1`` separators).
        start: slice start (inclusive).
        end: slice end (exclusive).

    Returns:
        str: the non-separator values of the slice joined by single spaces.
    """
    return " ".join(str(x) for x in pst[start:end] if x != -1)

def link_evidence(oof, thresh=1, thresh2=26):
    """Merge neighbouring "Evidence" spans of each document into longer spans.

    For every id, the Evidence rows are concatenated (in row order) into one
    list of word indices separated by ``-1``. Two consecutive spans stay merged
    unless the next span starts more than ``thresh`` words after the previous
    one ends, or the next span starts more than ``thresh2`` words after the
    start of the current merged span. Rows of every other class are passed
    through unchanged.

    Args:
        oof: DataFrame with columns ``id``, ``class`` and ``predictionstring``
            (space-separated integer word indices).
        thresh: largest allowed gap between the end of one span and the start
            of the next for them to be merged.
        thresh2: largest allowed distance between the start of the merged span
            and the start of the next span for them to be merged.

    Returns:
        DataFrame: outer merge of the re-linked Evidence rows with the
        non-Evidence rows of ``oof``.
    """
    evidence = oof[oof['class'] == "Evidence"]
    others = oof[oof['class'] != "Evidence"]

    retval = []
    for idv in oof['id'].unique():
        q = evidence[evidence['id'] == idv]
        if len(q) == 0:
            continue

        # Flatten all spans of this document; -1 marks where each span begins.
        pst = []
        for _, r in q.iterrows():
            pst = pst + [-1] + [int(x) for x in r['predictionstring'].split()]

        start = 1
        end = 1
        for i in range(2, len(pst)):
            end = i
            # A separator with nothing after it (empty trailing span) has no
            # successor to compare against, so it never triggers a split.
            if pst[i] != -1 or i + 1 >= len(pst):
                continue
            gap_too_large = pst[i + 1] > pst[i - 1] + thresh
            span_too_long = pst[i + 1] - pst[start] > thresh2
            if gap_too_large or span_too_long:
                retval.append((idv, 'Evidence', jn(pst, start, end)))
                start = i + 1
        # Emit whatever is left after the last split point.
        retval.append((idv, 'Evidence', jn(pst, start, end + 1)))

    roof = pd.DataFrame(retval, columns=['id', 'class', 'predictionstring'])
    return roof.merge(others, how='outer')

# Per-class minimum mean token score: a predicted phrase is kept only if the
# average of its token scores is at least this value.
proba_thresh = {
    "Lead": 0.7,
    "Position": 0.55,
    "Evidence": 0.65,
    "Claim": 0.55,
    "Concluding Statement": 0.7,
    "Counterclaim": 0.5,
    "Rebuttal": 0.55,
}

# Per-class minimum phrase length, counted in words of the prediction string:
# shorter phrases are dropped.
min_thresh = {
    "Lead": 9,
    "Position": 5,
    "Evidence": 14,
    "Claim": 3,
    "Concluding Statement": 11,
    "Counterclaim": 6,
    "Rebuttal": 4,
}

# Turn token-level predictions into (id, class, predictionstring) rows,
# collecting one DataFrame per sample in `submission`.
submission = []
for sample in test_samples:
    preds = sample["preds"]
    offset_mapping = sample["offset_mapping"]
    sample_id = sample["id"]
    sample_text = sample["text"]
    sample_pred_scores = sample["pred_scores"]

    # Pad so that every token offset has a label ("O") and a score (0).
    if len(preds) < len(offset_mapping):
        preds = preds + ["O"] * (len(offset_mapping) - len(preds))
        sample_pred_scores = sample_pred_scores + [0] * (len(offset_mapping) - len(sample_pred_scores))

    # Group consecutive tokens into phrases: a phrase starts at any token and
    # extends while the following tokens carry the matching "I-<label>" tag
    # (or "O" for an outside phrase).
    idx = 0
    phrase_preds = []
    while idx < len(offset_mapping):
        # Take both offsets from the first token so a single-token phrase gets
        # its own end offset instead of a stale one from an earlier phrase.
        start, end = offset_mapping[idx]
        label = preds[idx][2:] if preds[idx] != "O" else "O"
        matching_label = "O" if label == "O" else f"I-{label}"
        phrase_scores = [sample_pred_scores[idx]]
        idx += 1
        while idx < len(offset_mapping) and preds[idx] == matching_label:
            _, end = offset_mapping[idx]
            phrase_scores.append(sample_pred_scores[idx])
            idx += 1
        phrase_preds.append((sample_text[start:end], start, end, label, phrase_scores))

    # Map character spans to whitespace-delimited word indices and keep only
    # the phrases that pass the per-class score and length thresholds.
    total_words = len(sample_text.split())
    rows = []
    for phrase, start, end, label, phrase_scores in phrase_preds:
        if label == "O":
            continue
        word_start = len(sample_text[:start].split())
        word_end = min(word_start + len(sample_text[start:end].split()), total_words)
        word_ids = [str(x) for x in range(word_start, word_end)]
        if sum(phrase_scores) / len(phrase_scores) >= proba_thresh[label]:
            if len(word_ids) >= min_thresh[label]:
                rows.append((sample_id, label, " ".join(word_ids)))

    submission.append(pd.DataFrame(rows, columns=["id", "class", "predictionstring"]))

In [42]:
# Stack the per-sample frames under a fresh 0..n-1 index, re-link the Evidence
# spans, and write the result to disk.
submission = link_evidence(pd.concat(submission, ignore_index=True))
submission.to_csv("submission.csv", index=False)