---
# 使い方
### ①「実験環境指定」内の、ARCHITECTURE_NAMEとEXP_NUMを代入しているセルで実験環境を指定。
### ②chaii/model/{ARCHITECTURE_NAME}/exp/{EXP_NUM}フォルダ内のconfig.pyで、ハイパーパラメータを設定。
### ③実行
---

# 実験環境指定


In [None]:
import sys

# Identify the runtime from an environment-specific module that is already
# imported. Markers are checked in order and the first match wins; None means
# "none of the known environments".
# NOTE(review): the mac marker only works if something in the process has
# already imported ctypes.macholib.dyld, so it is a heuristic.
ENV_FLAG = next(
    (
        flag
        for module_name, flag in (
            ('google.colab', 'colab'),
            ('kaggle_web_client', 'kaggle'),
            ('ctypes.macholib.dyld', 'mac'),
        )
        if module_name in sys.modules
    ),
    None,
)

In [None]:
# Show the detected environment ('colab', 'kaggle', 'mac' or None) as the cell output.
ENV_FLAG

In [None]:
# These packages are not installed in a fresh Colab runtime, so install them there.
# NOTE(review): other environments are assumed to already have them — confirm.
if ENV_FLAG=='colab':
    !pip install wandb transformers

    !pip install pytorch_lightning
    !pip install python-box
    !pip install sentencepiece

In [None]:
# Select the experiment: set the architecture folder name and experiment number.
ARCHITECTURE_NAME='external-answer-weight'
EXP_NUM=1
EXP_PATH=f'chaii/model/{ARCHITECTURE_NAME}/exp/{EXP_NUM}'

# Directory that contains this experiment's config.py, per environment.
if ENV_FLAG=='colab':
    from google.colab import drive
    drive.mount('/content/drive')
    CONFIG_DIR=f'/content/drive/MyDrive/kaggle/{EXP_PATH}'
elif ENV_FLAG=='kaggle':
    CONFIG_DIR=f'../input/{ARCHITECTURE_NAME}-exp{EXP_NUM}'
elif ENV_FLAG=='mac':
    CONFIG_DIR=f'/Volumes/GoogleDrive/マイドライブ/kaggle/{EXP_PATH}'
else:
    CONFIG_DIR=None

if CONFIG_DIR is not None:
    # Put the experiment directory first so its config.py wins over any other
    # module named `config` on sys.path (append would let those shadow it).
    # Removing it first avoids duplicate entries when the cell is re-run.
    if CONFIG_DIR in sys.path:
        sys.path.remove(CONFIG_DIR)
    sys.path.insert(0, CONFIG_DIR)

# Forget a `config` module cached by a previous run of this cell; otherwise a
# changed ARCHITECTURE_NAME / EXP_NUM would silently keep the old settings.
sys.modules.pop('config', None)

from box import Box
import config
cfg=Box(vars(config.Config()))


In [None]:
# Show the loaded settings so they can be checked by eye.
cfg

# 共通セットアップ

In [None]:
# List the GPU(s) visible to this runtime.
!nvidia-smi -L

In [None]:
# Copy kaggle.json from Google Drive into ~/.kaggle so the `kaggle` CLI can use it.
if ENV_FLAG=='colab':

    !mkdir -p ~/.kaggle
    !cp /content/drive/MyDrive/kaggle/kaggle.json ~/.kaggle/kaggle.json
    # NOTE(review): assumes ~ is /root (the default Colab user); restrict permissions on the file.
    !chmod 600 /root/.kaggle/kaggle.json

In [None]:
import os
import sys
import copy
from pathlib import Path
import collections
import gc
gc.enable()
import math
import json
import time
import random
import multiprocessing
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import wandb

import numpy as np
import pandas as pd
#notebook用tqdm。出力が綺麗になる。
from tqdm.notebook import tqdm
from sklearn import model_selection

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import torch.optim as optim
from torch.utils.data import (
    Dataset, DataLoader,
    SequentialSampler, RandomSampler
)
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from pytorch_lightning import Trainer,seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint,EarlyStopping,LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger


APEX_INSTALLED = False
try:
    # Apex AMP (mixed precision) is optional: record whether it can be imported.
    from apex import amp
except ImportError:
    pass
else:
    APEX_INSTALLED = True

import transformers
from transformers import (
    #学習済みモデルの重みのファイル名を返却
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    logging,
    #qaタスクにおける、コンフィグとモデルの組み合わせを返却
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
)
# Only show errors from transformers. (A preceding set_verbosity_warning() call
# was immediately overridden by this one, so it is not needed.)
logging.set_verbosity_error()


print(f"Apex AMP Installed :: {APEX_INSTALLED}")
# Config classes that have a question-answering head, and their model_type names.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

In [None]:
# Where to load the model weights and the tokenizer from.
if ENV_FLAG=='kaggle':
    # On Kaggle, use the copy that was downloaded in advance as a dataset.
    MODEL_PATH='../input/xlm-roberta-squad2/deepset/xlm-roberta-large-squad2'
    TOKENIZER_PATH='../input/xlm-roberta-squad2/deepset/xlm-roberta-large-squad2'
else:
    # Colab and every other environment load from the Hugging Face Hub, so
    # MODEL_PATH / TOKENIZER_PATH are always defined (a mac or unknown
    # environment would otherwise fail later with a NameError).
    MODEL_PATH="deepset/xlm-roberta-large-squad2"
    TOKENIZER_PATH="xlm-roberta-large"

In [None]:
# Log in to Weights & Biases; WandbLogger is used for every fold during training.
wandb.login()

In [None]:
def set_seed(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU, the
    current CUDA device and all CUDA devices), and exports ``PYTHONHASHSEED``.
    Note: PYTHONHASHSEED only affects interpreters started after this call
    (e.g. worker processes); it does not change hashing in the running one.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

# 訓練データの用意

In [None]:
# (train, test, sample submission) CSV paths for each environment. The Kaggle
# copy of the training data is read from train.csv rather than train.csv.zip.
_chaii_input_files = {
    'colab': (
        '/content/drive/MyDrive/kaggle/chaii/data/input/train.csv.zip',
        '/content/drive/MyDrive/kaggle/chaii/data/input/test.csv',
        '/content/drive/MyDrive/kaggle/chaii/data/input/sample_submission.csv',
    ),
    'kaggle': (
        '../input/chaii-hindi-and-tamil-question-answering/train.csv',
        '../input/chaii-hindi-and-tamil-question-answering/test.csv',
        '../input/chaii-hindi-and-tamil-question-answering/sample_submission.csv',
    ),
    'mac': (
        '/Volumes/GoogleDrive/マイドライブ/kaggle/chaii/data/input/train.csv.zip',
        '/Volumes/GoogleDrive/マイドライブ/kaggle/chaii/data/input/test.csv',
        '/Volumes/GoogleDrive/マイドライブ/kaggle/chaii/data/input/sample_submission.csv',
    ),
}
if ENV_FLAG in _chaii_input_files:
    _train_csv, _test_csv, _submission_csv = _chaii_input_files[ENV_FLAG]
    train = pd.read_csv(_train_csv)
    test = pd.read_csv(_test_csv)
    submission = pd.read_csv(_submission_csv)

In [None]:
# Extra training data: download the rhtsingh/mlqa-hindi-processed Kaggle dataset,
# unzip it and load mlqa_hindi.csv and xquad.csv. Only done on Colab, so other
# environments train on the competition data alone.
if ENV_FLAG=='colab':
    
    !kaggle datasets download -d rhtsingh/mlqa-hindi-processed 

    import zipfile
    with zipfile.ZipFile('/content/mlqa-hindi-processed.zip') as zip_ref:   
        zip_ref.extractall('/content')
    
    external_mlqa = pd.read_csv('/content/mlqa_hindi.csv')
    external_xquad = pd.read_csv('/content/xquad.csv')

    # Stack both sources; their original row indices are kept (and may repeat).
    external_train = pd.concat([external_mlqa, external_xquad])

In [None]:
# Clean the training data: drop the rows whose id is listed in mistake_id
# (presumably examples with wrong annotations — TODO confirm) and print the
# row count before and after.
print(len(train))
mistake_id = [
    'bc9f0d533', '1a2160a69', '632c16ba0',
    'e0090c270', '1b8635229', '8997bf894',
    '33d679522', 'f22ab8d6b', '3a4db1dda',
]

is_mistake = train['id'].isin(mistake_id).to_numpy()
mistake_index = train.index[is_mistake]
train = train.drop(index=mistake_index).reset_index(drop=True)
print(len(train))

In [None]:
def apply_stratified_kfolds(data, num_splits, seed=None):
    """Assign a language-stratified fold number to every row of ``data``.

    The rows are split with a shuffled ``StratifiedKFold`` on the ``language``
    column. Each row's ``kfold`` value is the number of the split in which that
    row is in the validation part.

    Args:
        data: DataFrame with a ``language`` column. It is modified in place
            (a ``kfold`` column is added or overwritten).
        num_splits: Number of folds.
        seed: Shuffle seed. When None, ``cfg.seed`` is used, as before.

    Returns:
        ``data`` itself, with ``kfold`` set to 0..num_splits-1.
    """
    if seed is None:
        seed = cfg.seed
    data["kfold"] = -1
    kf = model_selection.StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=seed)
    kfold_col = data.columns.get_loc("kfold")
    # kf.split yields *positional* (train_idx, valid_idx) arrays, so assign
    # positionally; .loc would treat them as labels and mislabel (or append)
    # rows whenever the index is not a 0..n-1 RangeIndex.
    for fold, (_, valid_idx) in enumerate(kf.split(X=data, y=data['language'])):
        data.iloc[valid_idx, kfold_col] = fold
    return data

train = apply_stratified_kfolds(train, num_splits=cfg.num_fold)
# On Colab, append the external data. It gets kfold=-1, so it is never chosen
# as a validation fold, and sequential integer ids only for bookkeeping.
if ENV_FLAG=='colab':
    external_train = external_train.assign(
        kfold=-1,
        id=list(np.arange(1, len(external_train) + 1)),
    )
    train = pd.concat([train, external_train]).reset_index(drop=True)

def convert_answers(row):
    """Wrap an (answer_start, answer_text) row into the answers dict used downstream.

    ``row`` is read by position: element 0 is the answer's character offset
    and element 1 its text. For a pandas Series ``.iloc`` is used, because
    integer ``[]`` lookup on a label-indexed Series is a deprecated fallback.
    Plain sequences (tuple/list) are still accepted.

    Returns:
        ``{'answer_start': [start], 'text': [text]}``.
    """
    values = row.iloc if hasattr(row, 'iloc') else row
    return {'answer_start': [values[0]], 'text': [values[1]]}

# Per-row answers dict consumed by the prepare_*_features functions.
train['answers'] = train.loc[:, ['answer_start', 'answer_text']].apply(convert_answers, axis=1)

In [None]:
# Check the language counts and ratios in each fold (kfold=-1 is the external data, if any).
pivot=train.pivot_table(values='id',index='kfold',columns='language',aggfunc='count')
# Row-normalise the counts to ratios. A fold without one language has NaN in the
# pivot; DataFrame.sum skips it, whereas the builtin sum() would make the whole
# ratio row NaN.
pivot2=pivot.div(pivot.sum(axis=1),axis=0)

print(pivot,pivot2,sep='\n')

# Dataset,Dataloader関連

In [None]:
def _answer_token_span(answers, offsets, sequence_ids, num_tokens, cls_index):
    """Map the first answer's character span to (start_token, end_token) in one window.

    Returns ``(cls_index, cls_index)`` when there is no answer, or when the
    answer is not entirely inside this window's context tokens.
    """
    if len(answers["answer_start"]) == 0:
        return cls_index, cls_index

    start_char = answers["answer_start"][0]
    end_char = start_char + len(answers["text"][0])

    # First and last tokens of the context (sequence id 1).
    first = 0
    while sequence_ids[first] != 1:
        first += 1
    last = num_tokens - 1
    while sequence_ids[last] != 1:
        last -= 1

    if offsets[first][0] > start_char or offsets[last][1] < end_char:
        return cls_index, cls_index

    # Move past every token that starts at or before start_char; the answer
    # starts at the token just before the stopping point.
    while first < len(offsets) and offsets[first][0] <= start_char:
        first += 1
    # Move back over every token that ends at or after end_char; the answer
    # ends at the token just after the stopping point.
    while offsets[last][1] >= end_char:
        last -= 1
    return first - 1, last + 1


def prepare_train_features(example,tokenizer,max_length,doc_stride):
    """Tokenize one training example into one feature dict per context window.

    The question/context pair is tokenized with only the context truncated.
    Overflowing windows overlap by ``doc_stride`` tokens and are padded to
    ``max_length``. Each feature holds ``input_ids``, ``attention_mask``,
    ``offset_mapping`` and the token-level ``start_position``/``end_position``
    of the answer. These point at the CLS token when the example has no answer,
    or when the window does not fully contain it.

    Note: ``example["question"]`` is left-stripped in place.
    """
    example["question"] = example["question"].lstrip()
    encoded = tokenizer(
        example["question"],
        example["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # Window -> example mapping; unused because one example is encoded per call.
    encoded.pop("overflow_to_sample_mapping")
    window_offsets = encoded.pop("offset_mapping")

    features = []
    for window, offsets in enumerate(window_offsets):
        ids = encoded["input_ids"][window]
        cls_index = ids.index(tokenizer.cls_token_id)
        start, end = _answer_token_span(
            example["answers"],
            offsets,
            encoded.sequence_ids(window),
            len(ids),
            cls_index,
        )
        features.append({
            'input_ids': ids,
            'attention_mask': encoded["attention_mask"][window],
            'offset_mapping': offsets,
            "start_position": start,
            "end_position": end,
        })
    return features

In [None]:
def prepare_eval_features(example,tokenizer,max_length,doc_stride):
    """Tokenize one validation example into one feature dict per context window.

    Windowing and answer labelling are the same as for training. Each feature
    also carries ``example_id``, ``context`` and ``question`` (used to turn
    predictions back into text), plus ``sequence_ids`` with special tokens
    (``None``) stored as 0.

    Note: ``example["question"]`` is left-stripped in place.
    """
    example["question"] = example["question"].lstrip()
    enc = tokenizer(
        example["question"],
        example["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # Window -> example mapping; unused because one example is encoded per call.
    enc.pop("overflow_to_sample_mapping")
    window_offsets = enc.pop("offset_mapping")

    features = []
    for w, offsets in enumerate(window_offsets):
        ids = enc["input_ids"][w]
        mask = enc["attention_mask"][w]
        example_id, context, question = example['id'], example['context'], example['question']

        cls_index = ids.index(tokenizer.cls_token_id)
        seq_ids = enc.sequence_ids(w)
        answers = example["answers"]

        # Default label is the CLS token, meaning "no answer in this window".
        start_pos = end_pos = cls_index
        if len(answers["answer_start"]) != 0:
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # First and last context tokens (sequence id 1) of the window.
            lo = 0
            while seq_ids[lo] != 1:
                lo += 1
            hi = len(ids) - 1
            while seq_ids[hi] != 1:
                hi -= 1

            # Label the answer tokens only when the window fully contains the answer.
            if offsets[lo][0] <= start_char and offsets[hi][1] >= end_char:
                while lo < len(offsets) and offsets[lo][0] <= start_char:
                    lo += 1
                while offsets[hi][1] >= end_char:
                    hi -= 1
                start_pos, end_pos = lo - 1, hi + 1

        features.append({
            "example_id": example_id,
            'context': context,
            'question': question,
            'input_ids': ids,
            'attention_mask': mask,
            'offset_mapping': offsets,
            "start_position": start_pos,
            "end_position": end_pos,
            'sequence_ids': [0 if s is None else s for s in seq_ids],
        })
    return features

In [None]:
def prepare_test_features(example, tokenizer,max_length,doc_stride):
    """Tokenize one test example into unlabeled feature dicts, one per window.

    Each dict has ``example_id``, ``context``, ``question``, ``input_ids``,
    ``attention_mask``, ``offset_mapping`` and ``sequence_ids``. In
    ``sequence_ids``, special tokens (sequence id ``None``) are stored as 0.

    Note: ``example["question"]`` is left-stripped in place.
    """
    example["question"] = example["question"].lstrip()

    enc = tokenizer(
        example["question"],
        example["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    def _window(k):
        # Everything inference needs for window k of this example.
        return {
            "example_id": example['id'],
            'context': example['context'],
            'question': example['question'],
            'input_ids': enc['input_ids'][k],
            'attention_mask': enc['attention_mask'][k],
            'offset_mapping': enc['offset_mapping'][k],
            'sequence_ids': [0 if s is None else s for s in enc.sequence_ids(k)],
        }

    return [_window(k) for k in range(len(enc["input_ids"]))]

In [None]:
class ChaiiDataset(Dataset):
    """Torch Dataset over pre-tokenized feature dicts.

    ``mode`` selects what ``__getitem__`` returns:

    * ``'train'``: ``input_ids``, ``attention_mask``, ``start_position``,
      ``end_position`` as ``torch.long`` tensors.
    * ``'eval'``: ``input_ids``/``attention_mask`` plus the raw
      ``offset_mapping``, ``sequence_ids``, ``example_id``, ``context`` and
      ``question`` used for post-processing, then the two position tensors.
    * ``'test'``: like ``'eval'`` but without positions. The example id is
      returned under the key ``'id'``.

    Raises:
        ValueError: if ``mode`` is not one of the above. Previously an unknown
            mode made ``__getitem__`` silently return None.
    """

    _MODES = ('train', 'eval', 'test')

    def __init__(self, features, mode='train'):
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.features = features
        self.mode = mode

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        feature = self.features[index]
        item = {
            'input_ids': torch.tensor(feature['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(feature['attention_mask'], dtype=torch.long),
        }
        if self.mode != 'train':
            # Raw (non-tensor) fields needed to map logits back to answer text.
            item['offset_mapping'] = feature['offset_mapping']
            item['sequence_ids'] = feature['sequence_ids']
            id_key = 'id' if self.mode == 'test' else 'example_id'
            item[id_key] = feature['example_id']
            item['context'] = feature['context']
            item['question'] = feature['question']
        if self.mode != 'test':
            item['start_position'] = torch.tensor(feature['start_position'], dtype=torch.long)
            item['end_position'] = torch.tensor(feature['end_position'], dtype=torch.long)
        return item

In [None]:
def make_loader(
    data,
    tokenizer,
    max_length,
    doc_stride,
    train_batch_size,
    eval_batch_size,
    fold
):
    """Build the train/validation DataLoaders for one fold.

    Rows whose ``kfold`` equals ``fold`` form the validation set; every other
    row is used for training. The validation DataFrame and its features are
    returned too, because the Jaccard monitor needs them to turn logits back
    into answer text.

    Returns:
        ``(train_dataloader, valid_dataloader, valid_df, valid_features)``.
    """
    in_fold = data['kfold'] == fold
    train_df = data[~in_fold]
    valid_df = data[in_fold].reset_index(drop=True)

    def featurize(frame, prepare):
        # Tokenize every row of `frame`, collecting all windows into one list.
        collected = []
        for _, row in tqdm(frame.iterrows(), total=len(frame)):
            collected.extend(prepare(row, tokenizer, max_length, doc_stride))
        return collected

    train_features = featurize(train_df, prepare_train_features)
    valid_features = featurize(valid_df, prepare_eval_features)

    shared_loader_args = dict(num_workers=2, pin_memory=True, drop_last=False)
    train_dataloader = DataLoader(
        ChaiiDataset(train_features),
        batch_size=train_batch_size,
        shuffle=True,
        **shared_loader_args,
    )
    valid_dataloader = DataLoader(
        ChaiiDataset(valid_features, mode='eval'),
        batch_size=eval_batch_size,
        shuffle=False,
        **shared_loader_args,
    )

    return train_dataloader, valid_dataloader, valid_df, valid_features



# 損失

In [None]:
# def loss_fn(preds, labels):
#     start_preds, end_preds = preds
#     start_labels, end_labels = labels

#     start_loss = nn.CrossEntropyLoss()(start_preds, start_labels)
#     end_loss = nn.CrossEntropyLoss()(end_preds, end_labels)
#     total_loss = (start_loss + end_loss) / 2

#     return total_loss

def jaccard(str1, str2):
    """Word-level Jaccard similarity of two strings, ignoring case.

    Both strings are lower-cased and split on whitespace. The score is the
    size of the intersection of the two word sets divided by the size of
    their union. When neither string contains any word (e.g. both are empty),
    the strings are treated as identical and 1.0 is returned instead of
    raising ZeroDivisionError.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    union = a | b
    if not union:
        return 1.0
    return len(a & b) / len(union)

def weight_loss_fn(preds, labels,tokenizer,alpha=2, input_ids=None):
    """Span-extraction loss that up-weights windows containing an answer.

    For every sample, the start and end cross-entropy losses are averaged.
    Samples whose start label is not the CLS position (the window contains the
    answer) are multiplied by ``alpha``, and the weighted losses are averaged
    over the batch.

    Args:
        preds: ``(start_logits, end_logits)``, each ``(batch, seq_len)``.
        labels: ``(start_positions, end_positions)``, each ``(batch,)`` long.
        tokenizer: Provides ``cls_token_id``; only used together with ``input_ids``.
        alpha: Weight for samples that contain an answer.
        input_ids: Optional ``(batch, seq_len)`` token ids. When given, each
            row's CLS position is the first occurrence of ``tokenizer.cls_token_id``.
            Otherwise CLS is assumed to be at position 0 (BERT / XLM-R style).

    Returns:
        Scalar loss tensor.
    """
    start_preds, end_preds = preds
    start_labels, end_labels = labels

    # Per-sample losses, so that the weights apply per sample. A weight applied
    # to the already-averaged batch loss only rescales it by the mean weight.
    criterion = nn.CrossEntropyLoss(reduction='none')
    per_sample = (criterion(start_preds, start_labels) + criterion(end_preds, end_labels)) / 2

    # Windows without an answer are labelled with the CLS *position*, so compare
    # against that position, not against the CLS token id.
    if input_ids is None:
        cls_positions = torch.zeros_like(start_labels)
    else:
        cls_positions = (input_ids == tokenizer.cls_token_id).long().argmax(dim=1)
    have_answer = start_labels != cls_positions

    # Built like the loss (same device/dtype) instead of hard-coding 'cuda'.
    weight = torch.ones_like(per_sample)
    weight[have_answer] = alpha

    return (per_sample * weight).mean()

# 推論用関数。valid_dataのjaccard係数を算出するために定義

In [None]:
def postprocess_qa_predictions(examples, features1, raw_predictions,tokenizer,n_best_size = 20, max_answer_length = 30):
    """Turn per-feature start/end logits into one answer string per example.

    For every example, the top ``n_best_size`` start and end indices of each of
    its features are paired. Pairs are kept only if both tokens belong to the
    context (sequence id 1), end >= start, and the span is at most
    ``max_answer_length`` tokens. The highest-scoring span (start + end logit)
    wins. The null / CLS answer is never predicted.

    Args:
        examples: DataFrame with ``id`` and ``context`` columns.
        features1: Feature dicts with ``example_id``, ``offset_mapping`` and
            ``sequence_ids``, in the same order as the logit rows. Not modified.
        raw_predictions: ``(all_start_logits, all_end_logits)`` arrays, one row per feature.
        tokenizer: Unused; kept so existing calls keep working.
        n_best_size: Number of top start/end indices considered per feature.
        max_answer_length: Maximum answer length in tokens.

    Returns:
        OrderedDict mapping each example id to its predicted text ("" when no
        valid span exists).
    """
    all_start_logits, all_end_logits = raw_predictions

    # Group feature indices by example id. Keying by id rather than by DataFrame
    # position keeps the grouping correct for any index on `examples`.
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features1):
        features_per_example[feature["example_id"]].append(i)

    predictions = collections.OrderedDict()

    print(f"Post-processing {len(examples)} example predictions split into {len(features1)} features.")

    for _, example in examples.iterrows():
        context = example["context"]
        valid_answers = []

        for feature_index in features_per_example[example["id"]]:
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            feature = features1[feature_index]
            sequence_ids = feature["sequence_ids"]

            # Only context tokens (sequence id 1) may be part of an answer; mask
            # the others with None. This builds a local list, so the caller's
            # features are left untouched without deep-copying them.
            offset_mapping = [
                o if sequence_ids[k] == 1 else None
                for k, o in enumerate(feature["offset_mapping"])
            ]

            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            # Keep only candidate spans that lie in the context and are well-formed.
            for start_index in start_indexes:
                for end_index in end_indexes:
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Skip spans with negative length or longer than max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char],
                        }
                    )

        # Impossible (null) answers are not considered. Based on:
        # https://colab.research.google.com/drive/13QRHItm8dLliHUFckUwppsLK_U2vCmgS#scrollTo=rdckkKKuj48S
        # max() returns the first highest-scoring candidate, like a stable descending sort.
        if valid_answers:
            best_answer = max(valid_answers, key=lambda x: x["score"])
        else:
            best_answer = {"text": "", "score": 0.0}

        predictions[example["id"]] = best_answer["text"]

    return predictions


# tokenizer,seedの設定

In [None]:
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_PATH)
# Seed all random number generators from the config.
set_seed(cfg.seed)

# MODEL(pytorch-lightning)

In [None]:
class ChaiiModel(pl.LightningModule):
    """Transformer encoder with a linear start/end span head, as a LightningModule.

    Uses manual optimization so that gradients can be accumulated over
    ``gradient_accumulation_steps`` batches. After every validation run, the
    collected logits are decoded into answer strings, and the mean word-level
    Jaccard score against ``valid_df['answer_text']`` is logged as ``jaccard``.
    The train/valid DataLoaders and the validation frame/features are passed in
    directly; one instance is built per fold.
    """
    def __init__(
        self,
        model_path,
        tokenizer,
        weight_decay,
        learning_rate,
        epoch,
        train_batch_size,
        valid_batch_size,
        warmup_ratio,
        gradient_accumulation_steps,
        t_dataloader,
        v_dataloader,
        valid_df,
        valid_features
    ):

        super().__init__()
        # Manual optimization, so training_step can implement gradient accumulation.
        self.automatic_optimization = False

        # Hyper-parameters.
        self.weight_decay=weight_decay
        self.learning_rate=learning_rate
        self.epoch=epoch
        self.train_batch_size=train_batch_size
        self.valid_batch_size=valid_batch_size
        self.warmup_ratio=warmup_ratio
        self.gradient_accumulation_steps=gradient_accumulation_steps

        self.tokenizer=tokenizer

        # Backbone, with hidden states exposed and hidden dropout set to 0.1.
        self.model_config=AutoConfig.from_pretrained(model_path)
        self.model_config.update(
            {
                "output_hidden_states": True,
                "hidden_dropout_prob": 0.1,
            }
        )
        self.model=AutoModel.from_pretrained(model_path,config=self.model_config)
        # Span head: one start logit and one end logit per token.
        self.qa_outputs = nn.Linear(self.model_config.hidden_size, 2)
        self._init_weights(self.qa_outputs)

        # NOTE(review): defined but never applied in _shared_step.
        self.dropout = nn.Dropout(self.model_config.hidden_dropout_prob)

        self._train_dataloader=t_dataloader
        self._valid_dataloader=v_dataloader

        # Used by validation_epoch_end to turn logits back into answer text.
        self.valid_df=valid_df
        self.valid_features=valid_features

        self.save_hyperparameters('weight_decay',
                                    'learning_rate',
                                    'epoch',
                                    'train_batch_size',
                                    'valid_batch_size',
                                    'warmup_ratio',
                                    'gradient_accumulation_steps')

    def _init_weights(self, module):
        """Normal(0, initializer_range) weights and zero bias for a Linear layer."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.model_config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

    def configure_optimizers(self):
        """AdamW (no weight decay on biases/LayerNorm weights) plus a linear warmup/decay schedule."""
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.weight_decay,
            },
            {
                "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
            ]
        optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.learning_rate,
            )

        # One optimizer step per `gradient_accumulation_steps` batches; a final
        # partial group in each epoch also counts as a step (see training_step).
        num_training_steps=math.ceil(len(self._train_dataloader)/self.gradient_accumulation_steps)*self.epoch
        # The scheduler API expects an integer number of warmup steps.
        num_warmup_steps=int(num_training_steps*self.warmup_ratio)

        scheduler=get_linear_schedule_with_warmup(
            optimizer,
            num_training_steps=num_training_steps,
            num_warmup_steps=num_warmup_steps,
        )

        return {'optimizer':optimizer,'lr_scheduler':scheduler}

    def forward(self):
        # Not used: the training and validation steps call _shared_step directly.
        pass

    def _shared_step(self, batch):
        """Run the backbone and span head on a batch; return (start_logits, end_logits)."""
        input_ids=batch['input_ids']
        attention_mask=batch['attention_mask']

        outputs=self.model(input_ids,attention_mask)
        last_hidden_state=outputs.last_hidden_state

        # (batch, max_length, 2)
        qa_logits=self.qa_outputs(last_hidden_state)

        # Split the last dimension into two (batch, max_length, 1) tensors...
        start_logits, end_logits=qa_logits.split(1,dim=-1)

        # ...and drop it, giving (batch, max_length) each.
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        return (start_logits,end_logits)

    def training_step(self,batch, batch_idx):
        """Backward one batch; step the optimizer/scheduler once per accumulation group."""
        opt = self.optimizers()
        sch = self.lr_schedulers()

        start_logits, end_logits=self._shared_step(batch)
        start_position=batch['start_position']
        end_position=batch['end_position']

        loss=weight_loss_fn((start_logits,end_logits),(start_position,end_position),self.tokenizer,alpha=2)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        # Average the gradient over the accumulated batches.
        self.manual_backward(loss / self.gradient_accumulation_steps)

        # Step every `gradient_accumulation_steps` batches, and also on the last
        # batch of the epoch. Otherwise a partial group's gradients would be
        # carried into the next epoch, and the step count would disagree with
        # the ceil() used for the scheduler in configure_optimizers.
        is_last_batch = (batch_idx + 1) == len(self._train_dataloader)
        if (batch_idx + 1) % self.gradient_accumulation_steps == 0 or is_last_batch:
            opt.step()
            sch.step()
            opt.zero_grad()

    def validation_step(self, batch, batch_idx):
        """Log the weighted validation loss and return the logits for the epoch-end Jaccard."""
        start_logits, end_logits=self._shared_step(batch)

        start_position=batch['start_position']
        end_position=batch['end_position']

        loss=weight_loss_fn((start_logits,end_logits),(start_position,end_position),self.tokenizer,alpha=2)
        self.log_dict({"val_loss":loss},on_step=True,on_epoch=True,prog_bar=True)

        return [start_logits,end_logits]

    def validation_epoch_end(self,val_step_outputs):
        """Decode all validation logits into answers and log the mean Jaccard score."""
        # Stack the per-batch logits into (num_features, max_length) float64 arrays.
        all_start_logits=torch.cat([out[0] for out in val_step_outputs]).detach().cpu().double().numpy()
        all_end_logits=torch.cat([out[1] for out in val_step_outputs]).detach().cpu().double().numpy()

        predictions=postprocess_qa_predictions(self.valid_df,self.valid_features,(all_start_logits, all_end_logits),self.tokenizer)
        # Copy, so adding a column does not write into a selection of valid_df.
        res=self.valid_df[['id','answer_text']].copy()
        # An example without a prediction counts as an empty answer instead of NaN.
        res['prediction']=res['id'].map(predictions).fillna('')
        all_jaccard=[
            jaccard(answer, prediction)
            for answer, prediction in zip(res['answer_text'], res['prediction'])
        ]
        epoch_jaccard=np.mean(all_jaccard)
        self.log_dict({"jaccard":epoch_jaccard})

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._valid_dataloader

    

# 訓練実行

In [None]:
for fold in range(cfg.num_fold):

    print(f'FOLD: {fold+1}')
    train_dataloader,val_dataloader,valid_df,valid_features=\
        make_loader(train,
                    tokenizer,
                    cfg.max_length,
                    cfg.doc_stride,
                    cfg.train_batch_size,
                    cfg.valid_batch_size,
                    fold=fold)

    # A fresh model for every fold.
    model=ChaiiModel(
            MODEL_PATH,
            tokenizer,
            cfg.weight_decay,
            cfg.lr,
            cfg.epoch,
            cfg.train_batch_size,
            cfg.valid_batch_size,
            cfg.warmup_ratio,
            cfg.gradient_accumulation_steps,
            train_dataloader,
            val_dataloader,
            valid_df,
            valid_features
        )

    # Keep only the weights of the best checkpoint (highest validation Jaccard).
    checkpoint_callback = ModelCheckpoint(monitor='jaccard',
                                save_top_k=1,
                                save_weights_only=True,
                                dirpath=cfg.param_dir,
                                filename=f'{fold+1}fold/{fold+1}fold',
                                verbose=False,
                                mode='max')

    # Log the learning rate at every step.
    learning_rate_monitor=LearningRateMonitor(logging_interval='step')

    # Stop once the validation Jaccard has not improved for 6 consecutive checks.
    early_stopping = EarlyStopping(monitor='jaccard',mode='max',patience=6)

    # One W&B run per fold; the folds of one experiment share a group.
    wandb_logger = WandbLogger(project=cfg.architecture_name,
                            name=f'{fold+1}fold',
                            group=cfg.exp_name,
                            )

    trainer = pl.Trainer(max_epochs=cfg.epoch,
                        checkpoint_callback=True,
                        gpus=1,
                        # Validate five times per epoch.
                        val_check_interval=0.2,
                        # For reproducibility (see https://kutohonn.hatenablog.com/entry/2021/01/04/232434).
                        deterministic=True,
                        callbacks=[checkpoint_callback,learning_rate_monitor,early_stopping],
                        logger=wandb_logger,
                        # No sanity-check validation: it would run only a few batches, whose
                        # logits would not line up with the full valid_features used in
                        # validation_epoch_end.
                        num_sanity_val_steps=0
                        )

    print('training start')
    try:
        trainer.fit(model)
    finally:
        # Finish this fold's W&B run even if fit() raises; the next fold's
        # logger only starts a new run once the current one is finished.
        wandb.finish()

    # Free memory before the next fold. The trainer and the dataloaders still
    # reference the model and the tokenized data, so deleting only `model`
    # would leave it reachable; drop them all before collecting.
    del model, trainer, train_dataloader, val_dataloader
    gc.collect()
    torch.cuda.empty_cache()
