In [None]:
import pandas as pd
import numpy as np
import random as rn
import seaborn as sns
import matplotlib.pyplot as plt
import copy
from tqdm.notebook import tqdm
import re
import gc
import sys,os

from scipy.stats import spearmanr
from math import floor, ceil
from scipy import stats

# np.set_printoptions(suppress=True)
# pd.set_option('colwidth',50)
# pd.set_option('max_rows',50)

from sklearn.model_selection import GroupKFold,KFold,StratifiedKFold

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader, Dataset,RandomSampler, SequentialSampler

from transformers import AutoTokenizer
from transformers import BertModel, BertTokenizer,BertPreTrainedModel,BertConfig,BertTokenizerFast
from transformers import RobertaTokenizer,RobertaTokenizerFast,RobertaModel,RobertaConfig
from transformers import get_linear_schedule_with_warmup,get_cosine_with_hard_restarts_schedule_with_warmup
import tokenizers
import transformers
print(transformers.__version__)

# Walk the Kaggle input directory and print the full path of every file found,
# so the available datasets and model files are visible in the cell output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

In [None]:
def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to every generator.
    """
    rn.seed(seed)
    # Exported for child processes; hash randomization of the already running
    # interpreter is fixed at start-up and is not affected by this assignment.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device, not only the current one: the model may be
    # replicated across several GPUs. The call is a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick different convolution algorithms from run to run.
    torch.backends.cudnn.benchmark = False

seed_everything(8421)

In [None]:
# Load the competition CSV files from the Kaggle input directory.
train_df = pd.read_csv('/kaggle/input/tweet-sentiment-extraction/train.csv')
test_df = pd.read_csv('/kaggle/input/tweet-sentiment-extraction/test.csv')
sub_df = pd.read_csv('/kaggle/input/tweet-sentiment-extraction/sample_submission.csv')

# Drop training rows that contain missing values (in place).
train_df.dropna(inplace=True)
# Print two specific rows (looked up by index label) for manual inspection.
print(train_df.loc[8729])
print(train_df.loc[21376])
train_df.drop(8729,inplace=True) # the selected_text of this row is wrong

# train_df.drop(21376,inplace=True)

# train_df = pd.read_csv('/kaggle/input/selactect-extraction-get-new/train.csv')
train_df.reset_index(drop=True,inplace=True)
# Shuffle the rows twice. No random_state is passed, so the resulting order
# depends on the global RNG state at this point.
# NOTE(review): presumably reproducible only because seed_everything ran in an
# earlier cell — confirm before re-running cells out of order.
train_df = train_df.sample(frac=1).reset_index(drop=True)
train_df = train_df.sample(frac=1).reset_index(drop=True)
# train_df = train_df.sample(frac=1).reset_index(drop=True)
# train_df = train_df.sample(frac=1).reset_index(drop=True)
# train_df = train_df.sample(frac=1).reset_index(drop=True)
# Give the test frame a selected_text column equal to the full text.
# NOTE(review): presumably a placeholder so the same preprocessing can run on
# the test set — confirm.
test_df['selected_text'] = test_df['text']
train_df.shape,test_df.shape,sub_df.shape

In [None]:
# train_df[train_df['selected_text'].apply(lambda x:len(x.split())==0)]

In [None]:
train_df

In [None]:
train_df.columns,test_df.columns,sub_df.columns

# Data cleaning and preprocessing

In [None]:
def process_data(tweet, selected_text, sentiment, textID, tokenizer, max_len):
    """Encode one tweet and its selected span into model inputs.

    The span ``selected_text`` is located inside ``tweet`` by exact substring
    search, the tweet is tokenized, and the tokens overlapping the span give
    the start/end token positions. The sequence is laid out as
    ``[0, sentiment_id, 2, 2] + tweet_ids + [2]`` and right-padded with id 1 up
    to ``max_len``.

    Args:
        tweet: Raw tweet text (converted with ``str``).
        selected_text: Span text that must occur verbatim in ``tweet``.
        sentiment: One of ``'positive'``, ``'negative'``, ``'neutral'``.
        textID: Row identifier, passed through unchanged.
        tokenizer: Object whose ``encode(text)`` returns something exposing
            ``.ids`` and ``.offsets`` (character spans per token).
        max_len: Length to pad to. Longer sequences are NOT truncated; they are
            returned at their natural length.

    Returns:
        dict with keys ``input_ids``, ``token_type_ids``, ``attention_mask``,
        ``sentiment``, ``offset_mapping``, ``textID``, ``text``,
        ``start_position``, ``end_position`` and ``selected_text``.

    Raises:
        ValueError: if ``selected_text`` is empty, does not occur in ``tweet``,
            or overlaps no token produced by the tokenizer.
        KeyError: if ``sentiment`` is not one of the three known labels.
    """
    ori_selected_text = str(selected_text)
    selected_text = ori_selected_text

    ori_tweet = str(tweet)
    tweet = ori_tweet

    # ------------------------------------------------------------------
    # Optional text-normalisation helpers. Their call sites below are
    # disabled, so none of them currently affects the result.
    # ------------------------------------------------------------------

    # Clean URLs: shorten every URL to ' http' while keeping `clean_offset`
    # (one original-character index per cleaned character) in sync.
    def clean_text(text,clean_offset):
        text_list = list(text)
        while re.search(r'(\shttps?:?//\S+)|(\swww\.\S+)',''.join(text_list)):
            old_s,old_e = re.search(r'(\shttps?:?//\S+)|(\swww\.\S+)',''.join(text_list)).span()
            new_len = len(' http')
            clean_offset[old_s+new_len-1] = clean_offset[old_e-1]
            for i in range(old_e-1,old_s+new_len-1,-1):
                clean_offset.pop(i)
                text_list.pop(i)
            text_list[old_s:old_s+new_len] = list(' http')

        assert len(clean_offset)==len(text_list)
        return ''.join(text_list),clean_offset

    # Identity mapping while the cleaning steps are disabled:
    # clean_offset[i] is the index in the original tweet of character i.
    clean_offset = list(range(len(tweet)))
    # Disabled: tweet, clean_offset = clean_text(tweet, clean_offset)

    # Insert a space wherever the character class changes (letter / digit /
    # other). The split is ambiguous around spaces, so the output may be
    # shifted and needs another function to correct it.
    def separate_alphanum(text,offset):
        assert len(text)==len(offset)
        outstr = text[0]
        for i,char in enumerate(text[1:],start=1):
            if text[i-1].isspace() or char.isspace():
                outstr += char
                continue
            if text[i-1].isalpha() and char.isalpha():
                outstr += char
                continue
            if text[i-1].isdigit() and char.isdigit():
                outstr += char
                continue
            if (not text[i-1].isalnum()) and (not char.isalnum()):
                outstr += char
                continue
            outstr += ' '
            outstr += char
        i = len(outstr)-1
        j = len(text)-1
        while i>=0:
            if outstr[i]!=text[j]:
                assert outstr[i]==' '
                offset.insert(j+1,offset[j+1])
                i-=1
            i-=1
            j-=1
        return outstr,offset

    # Disabled: tweet, clean_offset = separate_alphanum(tweet, clean_offset)

    # `offset` is relative to the original string. Locate the original
    # selected_text in the original text, then remap both ends through
    # `offset` to obtain the selection inside the cleaned text.
    def get_new_selectext(selectext, text, new_text, offset):
        len_st = len(selectext) - 1
        idx0 = None
        idx1 = None
        for ind in (i for i, e in enumerate(text) if e == selectext[1]):
            if " " + text[ind: ind+len_st] == selectext:
                idx0 = ind
                idx1 = ind + len_st - 1
                break
        for i,v in enumerate(offset):
            if idx0<=v:
                idx0 = i
                break
        for i,v in enumerate(offset):
            if idx1<=v:
                idx1 = i
                break
        return " " + new_text[idx0:idx1+1]

    # Disabled: selected_text = get_new_selectext(selected_text, ori_tweet, tweet, clean_offset)

    # Snap the words of the selection onto whole words of the text.
    def align_selectext(selectext,text):
        t = text.split()
        st = selectext.split()
        out_str = []
        get = False
        for i,vt in enumerate(t):
            if (st[0] in vt) and (len(st)+i-1)<len(t):
                for j,vst in enumerate(st):
                    if vst not in t[i+j]:
                        get = False
                        break
                    get =True
                if get:

                    for j,vst in enumerate(st):
                        if vst!=t[i+j] :
                            if len(vst)>=len(t[i+j])/2 or len(st)<2:
                                out_str.append(t[i+j])
                            else:
                                continue
                        elif vst==t[i+j]:
                            out_str.append(t[i+j])
                    break
        if not get:
            raise
        else:
            return " " + ' '.join(out_str)

    # Disabled: selected_text = align_selectext(selected_text, tweet)

    # ------------------------------------------------------------------
    # Live path
    # ------------------------------------------------------------------

    def get_sted(selectext, text):
        """Return inclusive (first, last) character indices of the first
        occurrence of `selectext` in `text`."""
        if not selectext:
            raise ValueError('selected_text is empty')
        idx0 = text.find(selectext)
        if idx0 < 0:
            raise ValueError(
                'selected_text {!r} not found in text {!r}'.format(selectext, text))
        # A last index of 0 (one-character selection at position 0) is valid,
        # so "not found" is signalled by the exception, never by falsiness.
        return idx0, idx0 + len(selectext) - 1

    idx0,idx1 = get_sted(selected_text,
                         tweet)

    # Per-character mask: 1 for characters inside the selection, 0 elsewhere.
    char_targets = [0] * len(tweet)
    for ct in range(idx0, idx1 + 1):
        char_targets[ct] = 1

    tok_tweet = tokenizer.encode(tweet)
    input_ids_orig = tok_tweet.ids
    tweet_offsets = tok_tweet.offsets

    # Tokens whose character span overlaps the selection.
    target_idx = [
        j for j, (offset1, offset2) in enumerate(tweet_offsets)
        if sum(char_targets[offset1: offset2]) > 0
    ]
    if not target_idx:
        raise ValueError(
            'no token overlaps selected_text {!r} in text {!r}'.format(selected_text, tweet))

    targets_start = target_idx[0]
    targets_end = target_idx[-1]

    # Hard-coded token ids, one per sentiment label.
    # NOTE(review): presumably the RoBERTa vocabulary ids of the three
    # sentiment words — confirm against the tokenizer's vocab file.
    sentiment_id = {
        'positive': 1313,
        'negative': 2430,
        'neutral': 7974
    }

    # NOTE(review): ids 0 / 2 / 1 are presumably RoBERTa's <s> / </s> / <pad>.
    input_ids = [0] + [sentiment_id[sentiment]] + [2] + [2] + input_ids_orig + [2]
    token_type_ids = [0, 0, 0, 0] + [0] * (len(input_ids_orig) + 1)
    mask = [1] * len(token_type_ids)
    tweet_offsets = [(0, 0)] * 4 + tweet_offsets + [(0, 0)]
    # Four prefix tokens were inserted in front of the tweet tokens.
    targets_start += 4
    targets_end += 4

    padding_length = max_len - len(input_ids)
    if padding_length > 0:
        input_ids = input_ids + ([1] * padding_length)
        mask = mask + ([0] * padding_length)
        token_type_ids = token_type_ids + ([0] * padding_length)
        tweet_offsets = tweet_offsets + ([(0, 0)] * padding_length)

    # Map token spans back to indices of the original tweet. Spans whose end
    # is 0 (special and padding tokens) stay (0, 0).
    tweet_offsets = [(clean_offset[s],clean_offset[e-1]+1) if e!=0 else (0,0) for s,e in tweet_offsets]

    # A (0, 0) span at a target position means the label points at a token
    # without a character span.
    if tweet_offsets[targets_start] ==(0,0):
        print('offset error start')
    if tweet_offsets[targets_end] ==(0,0):
        print('offset error end')

    encoded_dict = {'input_ids':input_ids,
                   'token_type_ids':token_type_ids,
                    'attention_mask':mask,
                    'sentiment': sentiment,
                    'offset_mapping':tweet_offsets,
                    'textID':textID,
                    'text':ori_tweet,
                    'start_position':targets_start,
                    'end_position':targets_end,
                    'selected_text':ori_selected_text,
                   }
    return encoded_dict

# dataset

In [None]:
class SpanDataset(Dataset):
    """Dataset that turns DataFrame rows into span-extraction tensors.

    Each item is built lazily by ``process_data`` from the ``text``,
    ``sentiment`` and ``textID`` columns plus a span column: ``selected_text``
    when ``is_raw`` is true, otherwise ``new_selectext``.
    """

    def __init__(self, tokenizer, data_df, max_seq_length=256, is_raw=True):
        """Store the inputs; the frame is re-indexed to 0..n-1 so that
        positional indices from the DataLoader can be used as labels."""
        self.data_df = data_df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.is_raw = is_raw

        print('dataset len:',self.__len__())

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.data_df)

    def __getitem__(self, index):
        """Encode row ``index``.

        Returns:
            dict with long tensors ``input_ids``, ``token_type_ids``,
            ``attention_mask``, ``start_position``, ``end_position``, an int
            tensor ``offset_mapping``, and the raw ``textID``, ``text``,
            ``sentiment`` and ``selected_text`` values.

        Raises:
            Whatever ``process_data`` or the column lookup raises; the failing
            index is printed before the exception propagates.
        """
        # The two modes differ only in which column holds the span text.
        span_column = 'selected_text' if self.is_raw else 'new_selectext'
        try:
            data = process_data(
                self.data_df.loc[index,'text'],
                self.data_df.loc[index,span_column],
                self.data_df.loc[index,'sentiment'],
                self.data_df.loc[index,'textID'],
                self.tokenizer,
                self.max_seq_length,
            )
        except Exception:
            # Report which row failed, then let the error surface unchanged.
            print('data error',index)
            raise

        encoded_dict = {'input_ids': torch.tensor(data["input_ids"], dtype=torch.long),
                       'token_type_ids': torch.tensor(data["token_type_ids"], dtype=torch.long),
                        'attention_mask': torch.tensor(data["attention_mask"], dtype=torch.long),

                        'offset_mapping': torch.tensor(data["offset_mapping"], dtype=torch.int),
                        'textID': data["textID"],
                        'text':data['text'],

                        'sentiment':data['sentiment'],
                        'start_position':torch.tensor(data["start_position"], dtype=torch.long),
                        'end_position':torch.tensor(data["end_position"], dtype=torch.long),
                        'selected_text':data['selected_text'],
                       }

        return encoded_dict
# tokenizer = tokenizers.ByteLevelBPETokenizer(
#         vocab_file="/kaggle/input/roberta-base/vocab.json", 
#         merges_file="/kaggle/input/roberta-base/merges.txt", 
#         lowercase=True,
#         add_prefix_space=True
#     )
# test_dataset = SpanDataset(tokenizer,
#                             train_df,
#                             96)
# def get_length(text):
#     tok_tweet = tokenizer.encode(text)
#     input_ids_orig = tok_tweet.ids
#     tweet_offsets = tok_tweet.offsets
#     return len(input_ids_orig)
# train_df['enc_len'] = train_df['text'].apply(get_length)
# train_df['enc_len'].describe()
# for i in range(len(test_dataset)):
#     test_dataset[i]
# 8728\26004
# test_dataset[258]
# test_dataset[5]
# 5696\6112
# test_dataset[5696]
# test_dataset[6112]
# test_dataset[18]
# test_dataset.error_num
# test_dataset[26]
# test_dataset[1900]  # token 分割句子，那末尾字符也是分配在一个token内
# 3621/5188/15205/  # 21374被删除
# test_dataset[8267]
# test_dataset[3754]

# model

In [None]:
class SpanBert(BertPreTrainedModel):
    """Span-extraction head on top of a pretrained encoder.

    The encoder's last two hidden-state layers are concatenated per token.
    A linear layer scores every token as span start; the features of the
    highest-scoring start token are then appended to every token's features
    and a second linear layer scores every token as span end.
    """

    def __init__(self, config, model, PTM_path):
        """
        Args:
            config: Encoder configuration; ``output_hidden_states`` is forced on.
            model: Model class providing ``from_pretrained``.
            PTM_path: Path or name of the pretrained weights to load.
        """
        config.output_hidden_states = True
        super().__init__(config)

        self.bert = model.from_pretrained(PTM_path, config=config)

        # NOTE(review): Dropout2d is applied to what appears to be a 3-D
        # (batch, seq, features) tensor — confirm this is the intended variant.
        self.dropout = nn.Dropout2d(0.1)

        # Two concatenated hidden layers per token.
        feature_size = config.hidden_size * 2
        self.liner_to_start = nn.Linear(feature_size, 1)
        # End head also sees the selected start token's features.
        self.liner_to_end = nn.Linear(feature_size * 2, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        """Return ``(start_logits, end_logits)``, each with the trailing
        singleton dimension squeezed away."""
        # The encoder is expected to return exactly three outputs, the third
        # being the per-layer hidden states (embeddings + one per layer).
        _, _, hidden_states = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        last_two = torch.cat((hidden_states[-1], hidden_states[-2]), dim=-1)
        features = self.dropout(last_two)
        seq_len = features.size(1)
        feature_size = features.size(2)

        start_logits = self.liner_to_start(features)

        # Pick, per example, the features of the token with the highest start
        # score (gather lets each example index a different position).
        best_start = start_logits.argmax(dim=1, keepdim=True)
        gather_index = best_start.repeat(1, 1, feature_size)
        start_token = torch.gather(features, 1, gather_index)

        # Broadcast the start token's features to every position.
        start_broadcast = start_token.repeat(1, seq_len, 1)
        end_features = torch.cat([features, start_broadcast], dim=2)
        end_logits = self.liner_to_end(end_features)

        return start_logits.squeeze(-1), end_logits.squeeze(-1)
    
# model = SpanBert(config = Model_Class[args.model_name][0].from_pretrained('/kaggle/input/roberta-base'),
#                  model = Model_Class['roberta'][2],
#                  PTM_path = '/kaggle/input/roberta-base')

# model(input_ids=torch.tensor([258,369,456,156,896,845,812,123]).view(2,4))


# Optimizer

In [None]:
def get_model_optimizer(model,args):
    """Build an AdamW optimizer with two parameter groups.

    Parameters whose name contains ``bias``, ``LayerNorm.bias`` or
    ``LayerNorm.weight`` get no weight decay; all others use 0.001.

    Args:
        model: Module whose ``named_parameters()`` are optimised.
        args: Object providing the learning rate as ``args.lr``.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    exempt_markers = ["bias","LayerNorm.bias","LayerNorm.weight"]

    with_decay = []
    without_decay = []
    for name, param in model.named_parameters():
        is_exempt = any(marker in name for marker in exempt_markers)
        (without_decay if is_exempt else with_decay).append(param)

    param_groups = [
        {'params': with_decay, 'weight_decay': 0.001},
        {'params': without_decay, 'weight_decay': 0.0},
    ]
    return torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=0)
# model = CustomBert.from_pretrained(args.bert_model)
# opti = get_model_optimizer(model)
# opti.state_dict()['param_groups'][0]['lr'] #state

In [None]:
def loss_fn(start_preds, end_preds, start_labels, end_labels):
    """Sum of the cross-entropy losses for the start and end predictions.

    Args:
        start_preds, end_preds: Logits over positions for each example.
        start_labels, end_labels: Target position indices.

    Returns:
        Scalar tensor: mean start cross-entropy plus mean end cross-entropy.
    """
    losses = [
        F.cross_entropy(logits, labels)
        for logits, labels in ((start_preds, start_labels), (end_preds, end_labels))
    ]
    return losses[0] + losses[1]

# Metric

In [None]:
def get_output_string(text, offset, pred_st, pred_ed):
    """Turn predicted start/end token indices into the predicted substring.

    Args:
        text: Original tweet text.
        offset: Per-token (char_start, char_end) pairs into ``text``.
        pred_st: Predicted start token index.
        pred_ed: Predicted end token index.

    Returns:
        The whole ``text`` when the span is inverted or contains no words;
        otherwise the post-processed (``pp``) substring.
    """
    if pred_st <= pred_ed:
        char_begin = offset[pred_st][0]
        char_stop = offset[pred_ed][1]
        candidate = text[char_begin : char_stop]
        # An empty / whitespace-only slice falls back to the full text below.
        if candidate.split():
            return pp(candidate,text)
    return text

def jaccard(str1, str2):
    """Word-level Jaccard similarity of two strings (case-insensitive).

    Both strings are lower-cased and split on whitespace; the score is
    ``|A ∩ B| / |A ∪ B|`` over the resulting word sets.

    Returns:
        A float in [0, 1]. When neither string contains a word the ratio is
        undefined (0 / 0); 0.5 is returned instead of raising
        ZeroDivisionError.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    if not a and not b:
        # Undefined ratio: use a neutral score rather than crashing the
        # evaluation loop on an empty / whitespace-only pair.
        return 0.5
    c = a.intersection(b)
    score = float(len(c)) / (len(a) + len(b) - len(c))
    return score

In [None]:
import re
def pp(filtered_output, real_tweet):
    """Post-process a predicted span using whitespace quirks of the raw tweet.

    Steps, in order:
      1. Collapse whitespace runs in ``filtered_output`` to single spaces.
      2. If the tweet has fewer than two words, return the tweet itself.
      3. If the prediction is a single word ending in ``..`` or ``!!``, runs of
         three or more of that character are shortened (to zero, one or two
         characters) depending on whether the tweet starts with a space and
         whether a double space occurs before the prediction in the tweet;
         the result is returned immediately.
      4. Otherwise, if the tweet starts with a space or contains a double
         space, the prediction is located in a whitespace-normalised copy of
         the tweet and those indices (end extended by 2) are used to slice
         the RAW tweet.

    NOTE(review): presumably this compensates for span labels that are shifted
    by extra spaces in the raw tweets — confirm against the data.

    Args:
        filtered_output: Predicted span text.
        real_tweet: Raw tweet text the span was predicted from.

    Returns:
        The adjusted span text.
    """
    # Normalise internal whitespace of the prediction.
    filtered_output = ' '.join(filtered_output.split())
    if len(real_tweet.split()) < 2:
        # Zero- or one-word tweet: the whole tweet is the answer.
        filtered_output = real_tweet
    else:
        if len(filtered_output.split()) == 1:
            if filtered_output.endswith(".."):
                # The regex only matches runs of 3+ dots; a prediction ending
                # in exactly two dots is returned unchanged.
                if real_tweet.startswith(" "):
                    # st: first occurrence of the prediction in the raw tweet
                    # (-1 if absent); fl: first double space (-1 if absent).
                    st = real_tweet.find(filtered_output)
                    fl = real_tweet.find("  ")
                    if fl != -1 and fl < st:
                        filtered_output = re.sub(r'(\.)\1{2,}', '', filtered_output)
                    else:
                        filtered_output = re.sub(r'(\.)\1{2,}', '.', filtered_output)
                else:
                    st = real_tweet.find(filtered_output)
                    fl = real_tweet.find("  ")
                    if fl != -1 and fl < st:
                        filtered_output = re.sub(r'(\.)\1{2,}', '.', filtered_output)
                    else:
                        filtered_output = re.sub(r'(\.)\1{2,}', '..', filtered_output)
                return filtered_output
            if filtered_output.endswith('!!'):
                # Same scheme as for dots, applied to runs of 3+ '!'.
                if real_tweet.startswith(" "):
                    st = real_tweet.find(filtered_output)
                    fl = real_tweet.find("  ")
                    if fl != -1 and fl < st:
                        filtered_output = re.sub(r'(\!)\1{2,}', '', filtered_output)
                    else:
                        filtered_output = re.sub(r'(\!)\1{2,}', '!', filtered_output)
                else:
                    st = real_tweet.find(filtered_output)
                    fl = real_tweet.find("  ")
                    if fl != -1 and fl < st:
                        filtered_output = re.sub(r'(\!)\1{2,}', '!', filtered_output)
                    else:
                        filtered_output = re.sub(r'(\!)\1{2,}', '!!', filtered_output)
                return filtered_output

        if real_tweet.startswith(" "):
            filtered_output = filtered_output.strip()
            # Indices are computed on the whitespace-normalised tweet ...
            text_annotetor = ' '.join(real_tweet.split())
            start = text_annotetor.find(filtered_output)
            end = start + len(filtered_output)
            start -= 0  # no-op; start is used as found
            end += 2
            flag = real_tweet.find("  ")
            # ... but applied to the raw tweet. flag is -1 when the tweet has
            # no double space, which also satisfies this test for start >= 0.
            if flag < start:
                filtered_output = real_tweet[start:end]

        if "  " in real_tweet and not real_tweet.startswith(" "):
            filtered_output = filtered_output.strip()
            # Here only runs of spaces are collapsed (other whitespace kept).
            text_annotetor = re.sub(" {2,}", " ", real_tweet)
            start = text_annotetor.find(filtered_output)
            end = start + len(filtered_output)
            start -= 0  # no-op; start is used as found
            end += 2
            flag = real_tweet.find("  ")
            # Only re-slice when the first double space precedes the match.
            if flag < start:
                filtered_output = real_tweet[start:end]
    return filtered_output

In [None]:
def xiuzhen_str(pred,text):
    """Expand a predicted span to whole words of ``text``.

    Finds the first run of consecutive words in ``text`` such that every word
    of ``pred`` is a substring of the word at the same position in the run,
    and returns that run joined by single spaces.

    Args:
        pred: Predicted span (words may be fragments of the tweet's words).
        text: Tweet text.

    Returns:
        The matched whole words, space-joined.

    Raises:
        ValueError: if ``pred`` contains no word or no matching run exists.
    """
    text_words = text.split()
    pred_words = pred.split()
    if not pred_words:
        raise ValueError('pred contains no words: {!r}'.format(pred))

    span_len = len(pred_words)
    # Only start positions that leave room for the whole prediction.
    for first in range(len(text_words) - span_len + 1):
        window = text_words[first:first + span_len]
        if all(fragment in word for fragment, word in zip(pred_words, window)):
            return ' '.join(window)

    raise ValueError('pred {!r} cannot be aligned to text {!r}'.format(pred, text))
def get_wrong_str(pred,text,selected_text):
    """Debug helper: report predicted words that are only fragments of the
    tweet's words.

    The prediction is aligned to the first run of consecutive words of
    ``text`` in which every predicted word is a substring of the word at the
    same position. If any predicted word differs from its aligned word, the
    text, prediction, reference selection and the (fragment, word) pairs are
    printed. Nothing is returned.
    """
    text_words = text.split()
    pred_words = pred.split()
    span_len = len(pred_words)

    mismatches = []
    for first, word in enumerate(text_words):
        # Candidate start: first predicted word occurs in this word and the
        # whole prediction fits in the remaining words.
        if pred_words[0] not in word or first + span_len > len(text_words):
            continue
        window = text_words[first:first + span_len]
        if all(fragment in target for fragment, target in zip(pred_words, window)):
            mismatches = [
                (repr(fragment), repr(target))
                for fragment, target in zip(pred_words, window)
                if fragment != target
            ]
            break

    if mismatches:
        print('get_wrong_str...')
        print(repr(text))
        print(repr(pred))
        print(repr(selected_text))
        print(mismatches)

# Evaluation

In [None]:
def evaluate(model, data_loader, criterion, args):
    """Run the model over a validation loader and collect metrics.

    Args:
        model: Callable returning ``(start_logits, end_logits)``.
        data_loader: Yields batches shaped like ``SpanDataset`` items.
        criterion: ``criterion(start_logits, end_logits, start, end)`` -> loss.
        args: Needs ``is_cuda``; when true, tensors are moved to the
            module-level ``device``.

    Returns:
        ``(mean batch loss, mean Jaccard score, DataFrame)``; both means are
        rounded to 4 decimals. The DataFrame has one row per example with the
        texts, offsets, clipped logits, predicted indices and human-readable
        "token /logit" strings.
    """
    model.eval()

    batch_losses = []
    example_scores = []
    records = []

    def describe_logits(tokens, logit_row):
        # Negative logits are clipped to 0 and values rounded to 2 decimals;
        # the readable form pairs every token's text with its value.
        clipped = [max(0,round(value,2)) for value in logit_row]
        pairs = [' /'.join([str(token),str(value)]) for value,token in zip(clipped,tokens)]
        return clipped, '   '.join(pairs)

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):

            if args.is_cuda:
                for key in batch:
                    if isinstance(batch[key],torch.Tensor):
                        batch[key] = batch[key].to(device)

            gold_start = batch["start_position"]
            gold_end = batch["end_position"]

            start_logits, end_logits = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"],
            )

            # loss
            batch_losses.append(
                criterion(start_logits, end_logits, gold_start, gold_end).item())

            # predicted token indices
            pred_start = F.softmax(start_logits,dim=1).argmax(dim=-1).cpu().data.numpy()
            pred_end = F.softmax(end_logits,dim=1).argmax(dim=-1).cpu().data.numpy()

            start_rows = start_logits.cpu().data.numpy()
            end_rows = end_logits.cpu().data.numpy()
            offset_mapping = batch['offset_mapping'].cpu().data.numpy()
            texts = batch['text']
            references = batch['selected_text']

            for exam_idx in range(gold_start.shape[0]):
                text = texts[exam_idx]
                offsets = offset_mapping[exam_idx]

                pred_str = get_output_string(text,
                                             offsets,
                                             pred_start[exam_idx],
                                             pred_end[exam_idx])
                score = jaccard(pred_str,references[exam_idx])

                # Text covered by each token, for the readable dumps.
                token_texts = [text[span[0] : span[1]] for span in offsets]
                start_logit, read_start = describe_logits(token_texts, start_rows[exam_idx])
                end_logit, read_end = describe_logits(token_texts, end_rows[exam_idx])

                records.append({'text':text,
                                'selected_text':references[exam_idx],
                                'pred_text':pred_str,
                                'offset':offsets.tolist(),
                                'start_logits':start_logit,
                                'end_logits':end_logit,
                                'pred_start':pred_start[exam_idx],
                                'pred_end':pred_end[exam_idx],
                                'read_start':read_start,
                                'read_end':read_end,
                               })
                example_scores.append(score)

    val_avg_loss = round(sum(batch_losses)/len(batch_losses),4)
    val_acc_score = round(sum(example_scores)/len(example_scores),4)
    valid_preds_df = pd.DataFrame(records)

    # Release cached GPU memory and collect garbage before returning.
    if args.is_cuda:
        torch.cuda.empty_cache()
    gc.collect()
    return val_avg_loss, val_acc_score, valid_preds_df

# Inference

In [None]:
def infer(model, data_loader, args):
    """Predict the selected text for every example of a loader.

    Args:
        model: Callable returning ``(start_logits, end_logits)``.
        data_loader: Yields batches shaped like ``SpanDataset`` items.
        args: Needs ``is_cuda``; when true, tensors are moved to the
            module-level ``device``.

    Returns:
        DataFrame with columns ``textID`` and ``selected_text``, one row per
        example in loader order (empty, with the same columns, for an empty
        loader).
    """
    model.eval()

    # Rows are collected in a list and converted once at the end; growing a
    # DataFrame one `.loc` row at a time re-allocates on every insertion.
    records = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Infering"):

            if args.is_cuda:
                for key,value in batch.items():
                    if isinstance(value,torch.Tensor):
                        batch[key] = value.to(device)

            start_logits, end_logits = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"],
            )

            pred_start = F.softmax(start_logits,dim=1).argmax(dim=-1).cpu().data.numpy()
            pred_end = F.softmax(end_logits,dim=1).argmax(dim=-1).cpu().data.numpy()

            offset_mapping = batch['offset_mapping'].cpu().data.numpy()

            for exam_idx in range(pred_start.shape[0]):
                pred_str = get_output_string(batch['text'][exam_idx],
                                             offset_mapping[exam_idx],
                                             pred_start[exam_idx],
                                             pred_end[exam_idx])
                records.append({'textID': batch['textID'][exam_idx],
                                'selected_text': pred_str})

    # Release cached GPU memory and collect garbage before returning.
    if args.is_cuda:
        torch.cuda.empty_cache()
    gc.collect()

    # Explicit columns keep the schema stable even when no rows were produced.
    return pd.DataFrame(records, columns=['textID','selected_text'])

# Cross-validation

In [None]:
def cross_validation_split(train_df, args):
    """Yield train/validation loaders for each stratified fold.

    Folds are stratified on ``train_df.sentiment``. For every fold the fold
    number and the first five train / validation indices are printed.

    Args:
        train_df: Training frame (needs a ``sentiment`` column).
        args: Provides ``n_splits``, ``TOKENIZER``, ``max_seq_length`` and
            ``batch``.

    Yields:
        ``(fold, train_loader, valid_loader, train_index, val_index)``.
    """
    def make_loader(row_positions, shuffle):
        # Dataset over the selected rows, wrapped in a 2-worker loader.
        subset = SpanDataset(args.TOKENIZER,
                             train_df.iloc[row_positions],
                             args.max_seq_length,
                             is_raw=True)
        return DataLoader(subset,
                          shuffle=shuffle,
                          batch_size=args.batch,
                          num_workers=2)

    splitter = StratifiedKFold(n_splits=args.n_splits)
    folds = splitter.split(X=train_df, y=train_df.sentiment)

    for fold, (train_index, val_index) in enumerate(folds):
        print('fold: ',fold)
        print(train_index[:5])
        print(val_index[:5])

        # Only the training loader is shuffled.
        train_loader = make_loader(train_index, True)
        valid_loader = make_loader(val_index, False)

        yield fold, train_loader, valid_loader, train_index, val_index

# train 

In [None]:
def train_loop(model, data_loader, optimizer, criterion, scheduler, iteration, args):
    """Train the model for one pass over ``data_loader``.

    Gradients are accumulated and the optimizer (and scheduler, if given)
    steps whenever the running batch counter reaches a multiple of
    ``args.batch_accumulation``.

    Args:
        model: Callable returning ``(start_logits, end_logits)``.
        data_loader: Yields batches shaped like ``SpanDataset`` items.
        optimizer: Optimizer to step.
        criterion: ``criterion(start_logits, end_logits, start, end)`` -> loss.
        scheduler: Learning-rate scheduler or ``None``.
        iteration: Number of batches processed before this call.
        args: Needs ``is_cuda`` and ``batch_accumulation``; when ``is_cuda``
            is true, tensors are moved to the module-level ``device``.

    Returns:
        ``(mean batch loss rounded to 4 decimals, 0.0, updated iteration)``.
        The score slot is always 0.0 (no training score is computed).
    """
    model.train()

    batch_losses = []

    optimizer.zero_grad()

    for batch in tqdm(data_loader, desc="Training"):
        if args.is_cuda:
            for key in batch:
                if isinstance(batch[key],torch.Tensor):
                    batch[key] = batch[key].to(device)

        start_logits, end_logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
        )

        # loss
        loss = criterion(start_logits, end_logits,
                         batch["start_position"], batch["end_position"])
        loss.backward()
        batch_losses.append(loss.item())

        # Delayed parameter update: accumulating gradients over several
        # batches emulates a larger batch size.
        iteration += 1
        if iteration % args.batch_accumulation == 0:
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            optimizer.zero_grad()

    avg_loss = round(sum(batch_losses)/len(batch_losses),4)
    acc_score = 0.0

    # Release cached GPU memory and collect garbage before returning.
    if args.is_cuda:
        torch.cuda.empty_cache()
    gc.collect()

    return avg_loss, acc_score, iteration

# main

In [None]:

def main(args,train_df,test_df):
    """Train one model per selected cross-validation fold and collect per-epoch metrics.

    Parameters
    ----------
    args : configuration namespace. Attributes read here: ``bert_model``,
        ``fold``, ``model_name``, ``lr``, ``batch``, ``batch_accumulation``,
        ``epochs``, ``is_cuda``, ``is_scheduler``, ``is_save``,
        ``checkpoints_path`` and ``predictions_path``.
    train_df : training frame handed to ``cross_validation_split``.
    test_df : accepted for interface compatibility; not used in this function.

    Returns
    -------
    pandas.DataFrame
        One row per (fold, epoch) with columns ``fold``, ``epoch``,
        ``avg_loss``, ``acc_score``, ``val_avg_loss`` and ``val_acc_score``.

    Side effects
    ------------
    Creates one checkpoint directory and one prediction directory per trained
    fold and, when ``args.is_save`` is set, writes ``best_model.pth`` and
    ``best_preds.csv`` whenever the validation score improves.
    """
    seed_everything(8421)
    print('model path ... ',args.bert_model)
    chart_df = pd.DataFrame(columns=['fold', 'epoch', 'avg_loss', 'acc_score','val_avg_loss', 'val_acc_score'])

    # Take config and backbone classes from the same registry entry so they
    # cannot disagree (the backbone used to be hard-coded to 'roberta').
    config_cls, _, backbone_cls = Model_Class[args.model_name]
    criterion = loss_fn

    for fold, train_loader, valid_loader, train_index, val_index in cross_validation_split(train_df, args):
        if fold not in args.fold:
            continue

        model = SpanBert(config = config_cls.from_pretrained(args.bert_model),
                         model = backbone_cls,
                         PTM_path = args.bert_model)

        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            model = nn.DataParallel(model)
        if args.is_cuda:
            model.to(device)

        # nn.DataParallel hides the wrapped network behind `.module`; keep a
        # handle on the bare model so its layers can still be inspected.
        bare_model = model.module if isinstance(model, nn.DataParallel) else model

        optimizer = get_model_optimizer(model,args)

        run_tag = "model_{}_{}_{}_{}".format(args.model_name,fold,args.lr,args.batch)
        fold_checkpoints = os.path.join(args.checkpoints_path, run_tag)
        fold_predictions = os.path.join(args.predictions_path, run_tag)
        os.makedirs(fold_checkpoints, exist_ok=True)
        os.makedirs(fold_predictions, exist_ok=True)

        iteration=0
        best_score = 0.0

        if not args.is_scheduler:
            scheduler=None
            print('not schedule')
        else:
            print('has schedule')
            # The iteration counter is carried across epochs and the optimizer
            # steps once every `batch_accumulation` batches, so the number of
            # scheduler steps is the floor of the total batch count over that
            # factor. Integer arithmetic keeps the step counts as ints.
            # Assumes the train loader keeps the last partial batch - TODO confirm.
            batches_per_epoch = ceil(len(train_index)/args.batch)
            sche_step = (args.epochs * batches_per_epoch) // args.batch_accumulation
            warmup_steps = sche_step//3
            print('schedule step:',sche_step)
            print('schedule warm up :',warmup_steps)
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=sche_step,
            )

        for epoch in range(args.epochs):
            # Debug output: per-row variance of the two span-head weight matrices.
            print(np.around(bare_model.liner_to_start.weight.var(dim=1).cpu().data.numpy(),6))
            print(np.around(bare_model.liner_to_end.weight.var(dim=1).cpu().data.numpy(),6))

            avg_loss, acc_score, iteration = train_loop(
                model, train_loader, optimizer, criterion, scheduler, iteration, args)

            val_avg_loss, val_acc_score, valid_preds_df = evaluate(
                model, valid_loader, criterion, args)

            print("Epoch {}/{}:  loss={:.3f} score={:.3f} val_loss={:.3f} val_score={:.3f} ".format(
                    epoch + 1, args.epochs, avg_loss, acc_score, val_avg_loss, val_acc_score))

            chart_df.loc[chart_df.shape[0]] = [fold, epoch, avg_loss, acc_score, val_avg_loss, val_acc_score]

            if val_acc_score > best_score and args.is_save:
                best_score = val_acc_score
                torch.save( model.state_dict(), os.path.join(fold_checkpoints, "best_model.pth"))
                valid_preds_df.to_csv(os.path.join(fold_predictions, "best_preds.csv"),index=False)

        # Release this fold's objects before the next fold builds a new model.
        # Doing it inside the loop also avoids a NameError when every fold is
        # skipped and none of these names were ever bound.
        del bare_model, model, optimizer, scheduler
        del valid_loader, train_loader
        if args.is_cuda:
            torch.cuda.empty_cache()
        gc.collect()

    return chart_df

# Configuration

In [None]:
class args:
    """Experiment settings, passed around as the class itself and read as attributes."""

    # ---- pretrained weights -------------------------------------------------
    model_name = 'roberta'  # key into the Model_Class registry
    # Alternatives that were tried:
    #   '/kaggle/input/robertalargehugging-face'
    #   '/kaggle/input/bert-base-uncased'
    bert_model = '/kaggle/input/roberta-base'

    # ---- cross-validation ---------------------------------------------------
    n_splits = 5  # presumably the number of CV folds produced - verify in cross_validation_split
    fold = [0, 1, 2, 3, 4]  # folds that main() actually trains; others are skipped

    # ---- optimisation -------------------------------------------------------
    lr = 3e-5
    epochs = 6
    batch = 72
    batch_accumulation = 1  # optimizer steps once every N batches
    max_seq_length = 112
    # warmup_steps = 30
    # multi_task_balance = 0.5

    # ---- runtime switches ---------------------------------------------------
    is_cuda = torch.cuda.is_available()
    is_scheduler = True       # build a linear warmup/decay LR schedule in main()
    is_finetune_code = False  # True -> the run cell trains on the first 3000 rows only
    is_save = True            # write best checkpoint / predictions when the score improves

    # ---- output locations ---------------------------------------------------
    checkpoints_path = "model_dir"
    predictions_path = "prediction_dir"

    # ---- tokenizer ----------------------------------------------------------
    # Byte-level BPE tokenizer built from the vocabulary and merges files that
    # sit next to the pretrained weights.
    TOKENIZER = tokenizers.ByteLevelBPETokenizer(
        vocab_file=f"{bert_model}/vocab.json",
        merges_file=f"{bert_model}/merges.txt",
        lowercase=True,
        add_prefix_space=True,
    )
# argsa()   
# args.lr
# Pick the GPU only when one is actually available, so `device` agrees with
# `args.is_cuda` and tensors are not sent to a non-existent CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Registry: model family name -> [config class, fast tokenizer class, backbone model class].
Model_Class = {'bert':[BertConfig, BertTokenizerFast, BertModel],
              'roberta':[RobertaConfig, RobertaTokenizerFast, RobertaModel]}


# CheckPoint = [
#                 '/kaggle/input/sentiment-extraction-roberta-0/model_dir/model_roberta_7/best_model.pth',
#                 '/kaggle/input/sentiment-extraction-roberta-0/model_dir/model_roberta_4/best_model.pth',
#                 '/kaggle/input/sentiment-extraction-roberta-0/model_dir/model_roberta_1/best_model.pth',
#              ]

# batch 96

In [None]:
%%time
# Start from a clean slate: drop cached CUDA blocks and collect garbage.
if args.is_cuda:
    torch.cuda.empty_cache()
gc.collect()

# In "finetune code" mode only the first 3000 rows are used, for a quick
# smoke test of the pipeline; otherwise train on the full frame.
chart_df = main(
    args,
    train_df.iloc[:3000] if args.is_finetune_code else train_df,
    test_df,
)

# Result charts

In [None]:
sns.set_style("whitegrid")
fig, axes = plt.subplots(ncols=2,figsize=(16, 8))

# seaborn refuses `hue` together with wide-form `data`, so reshape each pair
# of metrics to long form and plot against the epoch: one colour per fold,
# one line style per metric (train vs. validation).
for ax, metric_columns, y_label in (
    (axes[0], ['avg_loss', 'val_avg_loss'], 'loss'),
    (axes[1], ['acc_score', 'val_acc_score'], 'score'),
):
    long_df = chart_df.melt(
        id_vars=['fold', 'epoch'],
        value_vars=metric_columns,
        var_name='metric',
        value_name=y_label,
    )
    # chart_df is filled row by row, so its columns may be object dtype;
    # cast explicitly so the axes are numeric and fold is a discrete category.
    long_df['epoch'] = long_df['epoch'].astype(int)
    long_df[y_label] = long_df[y_label].astype(float)
    long_df['fold'] = long_df['fold'].astype(int).astype(str)
    sns.lineplot(data=long_df, x='epoch', y=y_label, hue='fold', style='metric', ax=ax)
    ax.set_title('{} per epoch'.format(y_label))

chart_df.to_csv('verbose_chart.csv',index=False)
chart_df

In [None]:
# Best value reached in each fold, followed by the mean of those bests across
# folds, for training loss, validation loss and validation score respectively.
metrics_by_fold = chart_df.groupby('fold')
best_train_loss = metrics_by_fold['avg_loss'].min()
best_val_loss = metrics_by_fold['val_avg_loss'].min()
best_val_score = metrics_by_fold['val_acc_score'].max()

(
    best_train_loss,
    best_train_loss.mean(),
    best_val_loss,
    best_val_loss.mean(),
    best_val_score,
    best_val_score.mean(),
)