In [18]:
import gc
import sys
import os
from tqdm import tqdm

import re
import nltk

import numpy as np
import pandas as pd #!#
import matplotlib.pyplot as plt
import random
from einops import rearrange, repeat, reduce

from joblib import Parallel, delayed # https://www.notion.so/joblib-da8f5ee8dbd44da19b36da04bd657bb1
from torch.utils.data import Dataset, DataLoader

import tez # https://www.notion.so/Tez-093f1f31cba646e3963108294563ddd1
from tez.callbacks import Callback
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
from sklearn import metrics

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import warnings
warnings.simplefilter("ignore", UserWarning)

In [13]:
NUM_JOBS = 12  # number of worker processes handed to prepare_training_data
args = set_args()  # hyper-parameter container
seed_everything(args.seed)  # fix the RNG seeds for reproducibility
os.makedirs(args.output_path, exist_ok=True)

# Read the k-fold training table. Directory and fold count are taken from `args`
# instead of being hard-coded; with the default settings this resolves to
# <cwd>/input/train_5folds.csv, i.e. the same file as before.
df = pd.read_csv(os.path.join(args.input_path, f'train_{args.kfold}folds.csv'))

# Local checkpoint directory (BigBird for now). TODO: switch to Longformer later.
args.model = './model'
tokenizer = AutoTokenizer.from_pretrained(args.model)
# Only the first 100 rows of the table are used here.
samples = prepare_training_data(df.iloc[:100], tokenizer, args, num_jobs=NUM_JOBS)
collate = Collate(tokenizer)

dataset = FeedbackDataset(samples, args.max_len, tokenizer)
dataloader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size, collate_fn=collate)

In [14]:
# Load the backbone configuration with `output_hidden_states` switched on, then
# build the language model from the same checkpoint using that configuration.
# ref) https://github.com/huggingface/transformers/issues/1827
CONFIG = AutoConfig.from_pretrained(
    args.model,
    output_hidden_states=True,
)
LM = AutoModel.from_pretrained(
    args.model,
    config=CONFIG,
)

Some weights of the model checkpoint at ./model were not used when initializing BigBirdModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BigBirdModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BigBirdModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
# Width of the backbone's hidden states and number of entries in the label map
# (the map includes the "O" and "PAD" entries).
d_model = LM.config.hidden_size
num_labels = len(target_id_map)

# Classification head on top of the backbone.
HEAD = Head(d_model=d_model, num_labels=num_labels)
FEED_BACK_MODEL = FeedbackModel(lm=LM, head=HEAD)

In [None]:
from torch.optim import AdamW

model = FEED_BACK_MODEL.to(device)

EPOCH = 5
LR = 2e-5
optimizer = AdamW(model.parameters(), lr=LR)

# Per-step training loss, stored as plain Python floats.
loss_traj = []

model.train()
for _ in tqdm(range(EPOCH)):
    for batch in dataloader:
        # Move the batch to the training device. 'type' is a list of per-sample
        # tensors, so each element is moved individually.
        batch_device = {
            'ids': batch['ids'].to(device),
            'type': [type_.to(device) for type_ in batch['type']],
            'mask': batch['mask'].to(device),
            'targets': batch['targets'].to(device),
        }

        pred_logits = model(batch_device)
        # Only positions with a non-zero attention mask contribute to the loss.
        loss = calc_loss(pred_logits, batch_device['targets'], batch_device['mask'].gt(0))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Store a detached float, not the tensor: appending `loss` itself keeps
        # every step's autograd graph alive and cannot be plotted directly.
        loss_traj.append(loss.item())

In [None]:
# float() unwraps 0-dim tensors (including grad-tracking / CUDA ones) as well as
# plain numbers, so the plot works whichever of the two loss_traj contains.
fig, ax = plt.subplots()
ax.plot([float(step_loss) for step_loss in loss_traj])
ax.set(title="Training loss", xlabel="optimizer step", ylabel="loss")
plt.show()

In [2]:
def calc_loss(logits, labels, mask, num_labels = 16): # TODO: still experimental code
    '''Cross-entropy over the positions selected by ``mask``.

        Parameters
            logits : float tensor [batch, seq_len, num_classes] of raw,
                unnormalised scores.
            labels : long tensor [batch, seq_len] of class indices.
            mask : boolean tensor [batch, seq_len]; True marks the positions
                that contribute to the loss.
            num_labels : kept for backward compatibility only; the number of
                classes is read from ``logits``.

        Returns
            0-dim tensor holding the mean cross-entropy of the selected positions.
    '''
    # nn.CrossEntropyLoss applies log-softmax itself, so it must receive raw
    # logits. Feeding it softmax outputs (as this function used to) squashes the
    # scores into [0, 1] and puts a floor under the achievable loss.
    loss_func = nn.CrossEntropyLoss()
    mask = mask.bool()
    selected_logits = logits[mask]  # [num_selected, num_classes]
    selected_labels = labels[mask]  # [num_selected]
    return loss_func(selected_logits, selected_labels)

In [3]:
class FeedbackModel(nn.Module):
    '''Backbone language model followed by a classification head.'''

    def __init__(self, lm, head):
        super().__init__()
        self.lm = lm
        self.head = head

    def forward(self, x):
        '''Run the backbone on x['ids'] / x['mask'] and hand its first output,
        together with x['type'], to the head.
        '''
        lm_outputs = self.lm(input_ids = x['ids'], attention_mask = x['mask'])
        token_states = lm_outputs[0]
        return self.head(x['type'], token_states)

# s : sentence, t : token, b : batch, d : d_model
class Head(nn.Module):
    '''Sentence-aware token classification head.

    For every sentence the encoded vector of its first token ("start") and the
    sum of the encoded vectors of its remaining tokens ("sent") are computed.
    Each token then receives the start / sent vector of the sentence it belongs
    to, and two linear layers turn those vectors into per-token class scores
    that are added together.

    Index legend used in the comments: s = sentence, t = token (seq_len),
    b = batch, d = d_model.
    '''
    def __init__(self, d_model, num_labels):
        super(Head, self).__init__()
        self.num_labels = num_labels
        # Integer division so the value can be used as a layer size (it used to
        # be a float). Not used by the layers below yet.
        self.num_labels_sent = (num_labels - 2) // 2

        self.fc_layer_start = nn.Linear(d_model, self.num_labels)
        self.fc_layer_sent  = nn.Linear(d_model, self.num_labels) # I-label, O, PAD

    def forward(self, type_list:list, out:torch.Tensor) -> torch.Tensor:
        '''Parameters
            type_list : one tensor [num_sent, seq_len] per sample; row s marks
                the tokens of sentence s with 1.
            out : encoded vectors [batch, seq_len, d_model].

        Returns
            class scores [batch, seq_len, num_labels].
        '''
        batch_size, seq_len, d_model = out.shape
        # Allocate on the device of the input rather than on a module-level global.
        pred_class = torch.zeros(batch_size, seq_len, self.num_labels, device = out.device)

        start_list, sent_list = self._get_sent_idx(type_list, seq_len)
        for i, (start, sent) in enumerate(zip(start_list, sent_list)):
            start_token = self._pool_by_mask(out[i], start.type(torch.float))
            sent_token = self._pool_by_mask(out[i], sent.type(torch.float))

            # TODO: try restricting the sentence part to num_labels_sent outputs
            # instead of num_labels.
            pred_class[i] = self.fc_layer_start(start_token) + self.fc_layer_sent(sent_token)

        return pred_class

    def _get_sent_idx(self, type_list:list, seq_len):
        '''Split every sentence mask into a one-hot mask of its first token and
        a mask of its remaining tokens. ``seq_len`` is accepted but not used.
        '''
        start_list = []
        sent_list = []

        for type_ in type_list:
            _, first_idx = torch.max(type_, dim = 1)
            # Multiplying by type_ clears the one-hot of rows that contain no
            # token at all (e.g. sentences cut off by truncation). Without it
            # such rows put a spurious 1 on token 0 and a -1 into `sent`.
            start = F.one_hot(first_idx, num_classes = type_.size(-1)) * type_
            sent = type_ - start
            start_list.append(start)
            sent_list.append(sent)

        return start_list, sent_list

    def _pool_by_mask(self, x, mask): # TODO: try mean pooling / torch.masked_select
        '''Sum-pool the encoded vectors of each sentence and give every token
        the pooled vector of the sentence(s) it belongs to.

            Parameters
                x : encoded vectors of one sample, [seq_len, d_model]
                mask : float tensor [num_sent, seq_len]; 1 marks a valid token.
                    ex) [0,0,0,1,1,1,0,0,0]

            Returns
                tensor [seq_len, d_model]; rows of tokens outside every
                sentence are zero.
        '''
        assert mask.dtype == torch.float # matmul needs a floating-point mask
        pooled = mask @ x          # [num_sent, seq_len] @ [seq_len, d_model] -> [num_sent, d_model]
        # Scatter back to tokens. Same result as building the
        # [num_sent, seq_len, d_model] outer product and summing over sentences,
        # without materialising that tensor.
        return mask.t() @ pooled   # [seq_len, num_sent] @ [num_sent, d_model] -> [seq_len, d_model]

In [4]:
class set_args: # hyper-parameter container, instantiated once as `args`
    '''Plain attribute bag holding the run configuration.

    Every setting is a class attribute carrying its default value; callers
    instantiate the class and may override attributes on the instance.
    '''
    seed: int = 42 # RNG seed
    fold: int = 0
    kfold: int = 5
    model = 'allenai/longformer-base-4096' # model name (or local path) of the checkpoint
    lr: float = 1e-5
    output_path = os.path.join(os.getcwd(), 'model') # '../model'
    input_path = os.path.join(os.getcwd(), 'input') # '../input'
    max_len: int = 1024
    batch_size: int = 8
    valid_batch_size: int = 8
    epochs: int = 20
    accumulation_steps = 1 # TODO: original author was unsure what this is for; presumably gradient-accumulation steps -- confirm before use

In [5]:
def seed_everything(seed: int):
    '''Seed every random number generator used in this notebook.

        Parameters
            seed : value used for Python's `random`, PYTHONHASHSEED, NumPy and
                PyTorch (CPU and all CUDA devices).

    Also puts cuDNN into deterministic mode.
    '''
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN pick algorithms by timing them, which can differ
    # between runs; it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False

In [6]:
from torch.nn.utils.rnn import pad_sequence
class Collate:
    '''Collate function that pads every sample of a batch to the longest one.

        Parameters
            tokenizer : provides `padding_side` ("right" pads at the end, any
                other value pads at the front) and `pad_token_id`.
            label_pad_id : value written into padded target positions. -100 is
                the id of "PAD" in target_id_map and the default ignore_index
                of nn.CrossEntropyLoss.

    All four fields are padded on the same side, so ids, mask, targets and
    sentence masks stay aligned for left-padding tokenizers as well.
    '''
    def __init__(self, tokenizer, label_pad_id: int = -100):
        self.tokenizer = tokenizer
        self.label_pad_id = label_pad_id

    def __call__(self, batch):
        '''Parameters
            batch : list of dicts with 1-D tensors "ids", "mask", "targets" and
                a 2-D tensor "type" of shape [num_sent, seq_len].

        Returns
            dict with long tensors "ids", "mask", "targets" of shape
            [batch, batch_max] and "type", a list of long tensors
            [num_sent, batch_max] (one per sample).
        '''
        ids = [sample["ids"] for sample in batch]
        types = [sample["type"] for sample in batch]
        masks = [sample["mask"] for sample in batch]
        targets = [sample["targets"] for sample in batch]

        # longest token sequence of this batch
        batch_max = max((len(seq) for seq in ids), default = 0)

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None: # tokenizer without a pad token: fall back to 0
            pad_token_id = 0

        output = dict()
        output["ids"] = self._pad_1d(ids, pad_token_id, batch_max)
        output["type"] = [self._pad_type(type_, batch_max) for type_ in types]
        output["mask"] = self._pad_1d(masks, 0, batch_max)
        output["targets"] = self._pad_1d(targets, self.label_pad_id, batch_max)
        return output

    def _pads_right(self):
        '''True when padding goes after the real tokens.'''
        return self.tokenizer.padding_side == "right"

    def _pad_1d(self, sequences, pad_value, length):
        '''Stack 1-D tensors into a long tensor [len(sequences), length] filled
        with `pad_value` outside the real tokens.
        '''
        padded = torch.full((len(sequences), length), pad_value, dtype = torch.long)
        pads_right = self._pads_right()
        for row, seq in zip(padded, sequences):
            if seq.numel() == 0: # nothing to copy (and row[-0:] would select the whole row)
                continue
            if pads_right:
                row[: seq.size(0)] = seq
            else:
                row[-seq.size(0):] = seq
        return padded

    def _pad_type(self, type_, length):
        '''Pad a sentence mask [num_sent, seq_len] with zero columns up to `length`.'''
        filler = type_.new_zeros((type_.size(-2), length - type_.size(-1)))
        pieces = (type_, filler) if self._pads_right() else (filler, type_)
        return torch.cat(pieces, dim = -1).type(torch.long)

In [7]:
class FeedbackDataset(Dataset):
    '''Dataset turning prepared samples into tensors for one document each.

        Parameters
            samples : list of dicts with "input_ids", "input_labels" (label
                names) and "token_type_ids" (one 0/1 list per sentence).
            max_len : maximum number of tokens including the start/end tokens.
            tokenizer : provides `cls_token_id` and `sep_token_id`.

    No padding is done here; samples keep their own length.
    '''
    def __init__(self, samples, max_len, tokenizer):
        super(FeedbackDataset, self).__init__()
        self.samples = samples
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.length = len(samples)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        '''Returns a dict of long tensors: "ids", "mask", "targets" of shape
        [seq_len] and "type" of shape [num_sent, seq_len].
        '''
        sample = self.samples[idx]
        other_label_id = target_id_map["O"]

        # prepend the start token; it is labelled "O" and belongs to no sentence
        input_ids = [self.tokenizer.cls_token_id] + sample["input_ids"]
        input_labels = [other_label_id] + [target_id_map[x] for x in sample["input_labels"]]
        input_type_ids = [[0] + type_ids for type_ids in sample["token_type_ids"]]

        # truncate so that one slot is left for the end token
        if len(input_ids) > self.max_len - 1:
            input_ids = input_ids[: self.max_len - 1]
            input_labels = input_labels[: self.max_len - 1]
            input_type_ids = [type_ids[: self.max_len - 1] for type_ids in input_type_ids]

        # append the end token, again labelled "O" and outside every sentence
        input_ids = input_ids + [self.tokenizer.sep_token_id]
        input_labels = input_labels + [other_label_id]
        input_type_ids = [type_ids + [0] for type_ids in input_type_ids]

        attention_mask = [1] * len(input_ids)

        # The explicit reshape keeps the sentence mask 2-D even when the sample
        # has no sentence at all (torch.tensor([]) alone would be 1-D).
        type_tensor = torch.tensor(input_type_ids, dtype=torch.long).reshape(
            len(input_type_ids), len(input_ids)
        )

        return {
            "ids": torch.tensor(input_ids, dtype=torch.long),
            "type": type_tensor,
            "mask": torch.tensor(attention_mask, dtype=torch.long),
            "targets": torch.tensor(input_labels, dtype=torch.long),
        }

In [8]:
def prepare_training_data(df: pd.DataFrame, tokenizer, args, num_jobs):
    '''Run _prepare_training_data_helper in parallel over the document ids.

        Parameters
            df : annotation table with an "id" column.
            tokenizer : tokenizer used for the preprocessing.
            args : run configuration, forwarded to the helper.
            num_jobs : number of parallel workers.

        Returns
            training_samples (list) : the samples of all workers, concatenated
                in the order of the id splits.
    '''
    train_ids = df["id"].unique()

    # array_split returns empty chunks when there are fewer ids than workers;
    # those are skipped.
    id_chunks = [chunk for chunk in np.array_split(train_ids, num_jobs) if len(chunk) > 0]

    # Each worker receives only the rows of its own ids, so every document is
    # processed exactly once.
    results = Parallel(n_jobs=num_jobs, backend="multiprocessing")(
        delayed(_prepare_training_data_helper)(df[df["id"].isin(chunk)], tokenizer, args, chunk)
        for chunk in id_chunks
    )

    training_samples = []
    for result in results:
        training_samples.extend(result)

    return training_samples

def _prepare_training_data_helper(df, tokenizer, args, train_ids = None):
    '''Build one training sample per document id.

        Parameters
            df : annotation table with the columns "id", "discourse_start",
                "discourse_end" and "discourse_type".
            tokenizer : tokenizer whose encode_plus returns "input_ids" and
                "offset_mapping".
            args : run configuration (not read here).
            train_ids : ids to process; every id of df when None.

        Returns
            training_samples (list) : dicts with the keys "id", "input_ids",
                "text", "offset_mapping", "token_type_ids" (one 0/1 list per
                sentence) and "input_labels" ("O", "B-<type>" or "I-<type>"
                per token).
    '''
    # Honour the id subset chosen by the caller.
    ids_to_process = df['id'].unique() if train_ids is None else train_ids

    training_samples = []
    for idx in ids_to_process:
        filename = os.path.join('./input', "train", idx + ".txt")
        with open(filename, "r") as f:
            text = f.read()

        encoded_text = tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
        )

        input_ids = encoded_text["input_ids"]
        offset_mapping = encoded_text["offset_mapping"]

        sample = {
            "id": idx,
            "input_ids": input_ids,
            "text": text,
            "offset_mapping": offset_mapping,
        }

        # Whether a token covers at least one non-whitespace character. This
        # does not depend on the sentence / annotation, so it is computed once.
        token_has_text = [len(text[offset1:offset2].split()) > 0 for offset1, offset2 in offset_mapping]

        # --- sentence masks (token_type_ids); still experimental code ---
        processed_text, processed_idx_list = _replace_awkend(text) # preprocessing for nltk
        start_idx_list, end_idx_list = _extract_sentence_idx(processed_text) # sentence boundaries
        start_idx_list = _postprocess_sent_idx(start_idx_list, processed_idx_list) # back to indices of the raw text
        end_idx_list = _postprocess_sent_idx(end_idx_list, processed_idx_list) # back to indices of the raw text

        token_type_ids_list = []
        for start_idx, end_idx in zip(start_idx_list, end_idx_list):
            # character-level mask of the sentence
            text_type_ids = [0] * len(text)
            text_type_ids[start_idx:end_idx] = [1] * (end_idx - start_idx)

            # a token belongs to the sentence when it has text and overlaps it
            token_type_ids = [
                int(has_text and any(text_type_ids[offset1:offset2]))
                for (offset1, offset2), has_text in zip(offset_mapping, token_has_text)
            ]
            token_type_ids_list.append(token_type_ids)
        sample["token_type_ids"] = token_type_ids_list

        # --- token labels ---
        input_labels = ["O"] * len(input_ids)
        for _, row in df[df['id'] == idx].iterrows():
            discourse_start = int(row["discourse_start"])
            discourse_end = int(row["discourse_end"])
            prediction_label = row["discourse_type"]

            # character-level mask of the annotated span
            text_labels = [0] * len(text)
            text_labels[discourse_start:discourse_end] = [1] * (discourse_end - discourse_start)

            target_idx = [
                map_idx
                for map_idx, (offset1, offset2) in enumerate(offset_mapping)
                if token_has_text[map_idx] and sum(text_labels[offset1:offset2]) > 0
            ]
            if not target_idx:
                # The span covers no token (e.g. empty or whitespace-only
                # annotation): nothing to label, and indexing would fail.
                continue

            targets_start = target_idx[0]
            targets_end = target_idx[-1]
            input_labels[targets_start] = "B-" + prediction_label
            input_labels[targets_start + 1 : targets_end + 1] = ["I-" + prediction_label] * (targets_end - targets_start)
        sample["input_labels"] = input_labels

        training_samples.append(sample)
    return training_samples

In [9]:
def _replace_awkend(text):
    '''"문장.문장", "문장 .문장" 을 "문장. 문장" 으로 바꿔준다.
        Parameters
            - text (str) : "문장 .문장", "문장.문장"
        Return
            - text (str) : "문장. 문장"
        
    nltk 의 nltk.sent_tokenize() 는 문장    
    cf) "U.S. gov" 를 "U. S. gov" 로 바꾸지만, nltk 는 다행히 후자를 하나의 문장으로 취급한다.
    '''
    # "문장 .문장"
    text = re.sub(r' \.', r'. ', text) 
    
    # "문장.문장"
    replace_token = re.findall(r'\w\.\w', text) 
    replace_idx = [text.index(token) + i + 1 for i, token in enumerate(replace_token)]
    for idx in replace_idx:
        text = text[:idx] + '. ' + text[idx+1:]        
    
    return text, replace_idx

In [10]:
# extract sentence index
def _extract_sentence_idx(text):
    '''Split the text into sentences with nltk and locate each one.

        Parameters
            - text : text preprocessed so that nltk does not mis-split it
        Returns
            - start_idx_list (list) : index at which each sentence starts
            - end_idx_list (list) : index at which each sentence ends (exclusive)

        Invariant
            - for every i, text[start_idx_list[i]:end_idx_list[i]] is the i-th
                sentence returned by nltk (guaranteed by str.index, which
                raises ValueError when a sentence cannot be found).
    '''
    sent_list = nltk.sent_tokenize(text)

    start_idx_list = []
    end_idx_list = []
    cursor = 0
    for sent in sent_list:
        # Search from the end of the previous sentence, so a sentence that
        # also occurs earlier in the text is not mapped to that earlier place.
        start = text.index(sent, cursor)
        cursor = start + len(sent)
        start_idx_list.append(start)
        end_idx_list.append(cursor)

    return start_idx_list, end_idx_list

In [11]:
import copy
import warnings

def _postprocess_sent_idx(sent_idx_list, processed_idx_list):
    postprocess_sent_idx_list = copy.deepcopy(sent_idx_list)
    for i, sent_idx in enumerate(sent_idx_list):
        for processed_idx in processed_idx_list:
            if sent_idx > processed_idx:
                postprocess_sent_idx_list[i] -= 1
        
    return postprocess_sent_idx_list

# original_start_idx_list = _postprocess_sent_idx(start_idx_list, processed_idx_list)
# original_end_idx_list = _postprocess_sent_idx(end_idx_list, processed_idx_list)

# # test code: does the text recovered with the restored indices match the sentence extracted from the processed text?
# diff_list = []
# for i in range(len(start_idx_list)):
#     start_idx = start_idx_list[i]
#     end_idx = end_idx_list[i]
#     original_start_idx = original_start_idx_list[i]
#     original_end_idx = original_end_idx_list[i]
    
#     if processed_text[start_idx:end_idx] != text[original_start_idx:original_end_idx]:
#         diff_list.append([processed_text[start_idx:end_idx], text[original_start_idx:original_end_idx]])
    
# if len(diff_list) != 0:
#     for diff in diff_list:
#         print(f"{diff[0]}\n{diff[1]}")
#     warnings.warn("exist different sentence (processed, original)"+f"{diff_list}")

In [12]:
# Label names listed in id order: position in this list == class id.
# B-/I- mark the first / following tokens of a discourse element, "O" marks
# tokens outside every element.
_LABELS_IN_ID_ORDER = [
    "B-Lead",
    "I-Lead",
    "B-Position",
    "I-Position",
    "B-Evidence",
    "I-Evidence",
    "B-Claim",
    "I-Claim",
    "B-Concluding Statement",
    "I-Concluding Statement",
    "B-Counterclaim",
    "I-Counterclaim",
    "B-Rebuttal",
    "I-Rebuttal",
    "O",
]

# label name -> class id; the padding label gets the special id -100
target_id_map = {label: class_id for class_id, label in enumerate(_LABELS_IN_ID_ORDER)}
target_id_map["PAD"] = -100

# class id -> label name
id_target_map = dict(zip(target_id_map.values(), target_id_map.keys()))

# Testcode

In [113]:
def testcode_prepare_training_data_helper(samples:list):
    '''문장이 제대로 추출됐는지 눈으로 확인한다. 
        Parameters
            - samples (dict)
                > input_ids : 각 token 들의 index. 
                > token_type_ids : 문장에 대응하는 token 위치를 1 로 저장해둔 list.
                    ex) [0,0,0,1,1,1,1,0,0,0,...]
        Returns
            - decoded_samples (list) : 자연어로 변환된 결과를 담은 list.
    '''
    decoded_samples = []
    for sample in samples:
        extracted_input_ids = torch.tensor(sample['input_ids']) * torch.tensor(sample['token_type_ids'])
        for ext_ids in extracted_input_ids:
            decoded_samples.append(tokenizer.decode(ext_ids))
    
    return decoded_samples