In [1]:
import pandas as pd
import numpy as np
import collections
import re
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score, fbeta_score, precision_score, recall_score, roc_auc_score
import warnings
import torch.nn as nn
from tqdm import tqdm_notebook
from tqdm import tqdm
import random
import gensim
import os
from torch.utils import data
from torch import nn
import torch.nn.functional as F
from torch.optim import *
# Display options: show many edge items for tensors, never summarise numpy
# arrays, and silence library warnings.
torch.set_printoptions(edgeitems=768)
warnings.filterwarnings("ignore")
np.set_printoptions(threshold=np.inf)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Basic settings
MAX_LEN = 50
BATCH_SIZE = 512
SEED=888
NAME = 'ESIM'
# Seed every random number generator used in this notebook.
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# DEVICE is a torch.device, not a str: comparing it with 'cuda' is always
# False, so the device *type* has to be checked for this branch to ever run.
if DEVICE.type == 'cuda':
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
DEVICE



device(type='cuda')

In [2]:
# Training pairs: tab-separated, no header row; the three columns are renamed
# to sen1 / sen2 / labels.
train_data = pd.read_csv('../data/gaiic_track3_round1_train_20210228.tsv', sep='\t', header=None)
train_data = train_data.rename(columns={0:'sen1', 1:'sen2', 2:'labels'})
# Drop incomplete rows BEFORE the integer cast: astype(int) raises on a NaN
# label, so casting first would fail before the bad row could be removed.
train_data = train_data.dropna().reset_index(drop=True)
train_data['labels'] = train_data['labels'].astype(int)
# Test pairs have the same layout minus the label column. Rows are kept as-is
# so predictions stay aligned with the input file.
test_data = pd.read_csv('../data/gaiic_track3_round1_testA_20210228.tsv', sep='\t', header=None)
test_data = test_data.rename(columns={0:'sen1', 1:'sen2'})
train_data

Unnamed: 0,sen1,sen2,labels
0,1 2 3 4 5 6 7,8 9 10 4 11,0
1,12 13 14 15,12 15 11 16,0
2,17 18 12 19 20 21 22 23 24,12 23 25 6 26 27 19,1
3,28 29 30 31 11,32 33 34 30 31,1
4,29 35 36 29,29 37 36 29,1
...,...,...,...
99995,12 19 1162 126 53 66,12 19 79 389 126 53 66,1
99996,275 552 553 433 881 338 1104 101 202 2343 14825,995 551 550 1660 2830 1075 662 935,0
99997,421 330 62 12 80 81 82 76,202 62 12 80 838 76,1
99998,177 455 456 3474 964 1364 55 1364,133 134 2246,1


In [4]:
# One dict per training pair, keyed by column name (sen1, sen2, labels).
train_dataset = [
    {
        'sen1': train_data.loc[row, 'sen1'],
        'sen2': train_data.loc[row, 'sen2'],
        'labels': train_data.loc[row, 'labels'],
    }
    for row in tqdm_notebook(range(len(train_data)))
]
# Test pairs carry no label, so the 'labels' entry is always None.
test_dataset = [
    {
        'sen1': test_data.loc[row, 'sen1'],
        'sen2': test_data.loc[row, 'sen2'],
        'labels': None,
    }
    for row in tqdm_notebook(range(len(test_data)))
]

HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))




In [5]:
# Every token id encountered while building any DataSet (train, valid or test).
s = set()


class DataSet(data.Dataset):
    """Fixed-length, integer-encoded sentence pairs.

    Each input record is a dict with 'sen1' and 'sen2' (strings of
    space-separated integer token ids) and 'labels'. Both sentences are
    right-padded with 0 or truncated to ``MAX_LEN`` ids when the dataset is
    constructed.

    Args:
        data: iterable of record dicts as described above.
        mode: 'test' yields items without a label; any other value yields the
            label as the last element of each item.
    """

    def __init__(self, data, mode='train'):
        self.data = data
        self.mode = mode
        self.dataset = self.get_data(self.data,self.mode)

    @staticmethod
    def _encode(sentence):
        """Turn one id string into ``(ids, length)`` with ``len(ids) == MAX_LEN``.

        All ids (including ones that get truncated) are recorded in the global
        set ``s``. ``length`` is the number of real ids kept, i.e. it is clamped
        to ``MAX_LEN`` so it can never exceed the width of the padded sequence.
        """
        ids = [int(tok) for tok in sentence.split(' ')]
        s.update(ids)
        kept = min(len(ids), MAX_LEN)
        ids = ids[:MAX_LEN] + [0] * (MAX_LEN - kept)
        return ids, kept

    def get_data(self, data, mode):
        """Encode every record of ``data``; returns a list of dicts with keys
        'sen1', 'sen2', 'sen1_length', 'sen2_length' and 'labels'
        ('labels' is None when ``mode == 'test'``)."""
        dataset = []
        for data_li in tqdm_notebook(data):
            sen1, sen1_length = self._encode(data_li['sen1'])
            sen2, sen2_length = self._encode(data_li['sen2'])
            labels = data_li['labels'] if mode != 'test' else None
            dataset.append({'sen1':sen1, 'sen2':sen2, 'sen1_length':sen1_length, 'sen2_length':sen2_length, 'labels':labels})
        return dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(sen1, sen2, sen1_length, sen2_length[, labels])`` as tensors;
        the label is omitted in 'test' mode."""
        item = self.dataset[idx]
        sen1 = torch.tensor(item['sen1'])
        sen2 = torch.tensor(item['sen2'])
        sen1_length = torch.tensor(item['sen1_length'])
        sen2_length = torch.tensor(item['sen2_length'])
        if self.mode == 'test':
            return sen1, sen2, sen1_length, sen2_length
        labels = torch.tensor(item['labels'])
        return sen1, sen2, sen1_length, sen2_length, labels

def get_dataloader(dataset, mode):
    """Wrap a list of record dicts in a ``DataSet`` and a ``DataLoader``.

    Args:
        dataset: list of record dicts accepted by ``DataSet``.
        mode: 'train' (shuffled, last incomplete batch dropped) or
            'valid' / 'test' (original order, every sample kept).

    Returns:
        ``(dataloader, torchdata)`` - the loader and the underlying ``DataSet``.

    Raises:
        ValueError: if ``mode`` is not one of the three supported values
            (previously this surfaced as an UnboundLocalError at ``return``).
    """
    # Validate first so a typo fails fast, before the dataset is encoded.
    if mode not in ('train', 'valid', 'test'):
        raise ValueError("mode must be 'train', 'valid' or 'test', got {!r}".format(mode))
    torchdata = DataSet(dataset, mode=mode)
    is_train = mode == 'train'
    dataloader = torch.utils.data.DataLoader(
        torchdata,
        batch_size=BATCH_SIZE,
        shuffle=is_train,
        num_workers=0,
        drop_last=is_train,
    )
    return dataloader, torchdata

# Loaders over the complete train and test sets. Building them also records
# every token id in the global set `s`, whose size is printed afterwards.
train_dataloader, train_torchdata = get_dataloader(dataset=train_dataset, mode='train')
test_dataloader, test_torchdata = get_dataloader(dataset=test_dataset, mode='test')
print(len(s))

HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))


20600


In [7]:
class ESIM(nn.Module):
    """Sentence-pair matching network: shared embedding -> BiLSTM encoder ->
    soft cross-attention -> BiLSTM composition -> avg/max pooling -> MLP with a
    sigmoid output.

    Args:
        vocab_size: number of rows in the embedding table (id 0 is padding).
        embed_dim: embedding width.
        hidden_size: hidden size of each LSTM direction.
        linear_size: width of the two hidden layers of the classifier MLP.
        embeddings: optional initial embedding matrix of shape
            (vocab_size, embed_dim); copied into the embedding layer.
        dropout: dropout probability used in the classifier MLP.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, linear_size, embeddings=None, dropout=0.5):
        super(ESIM, self).__init__()
        self.dropout = dropout
        self.hidden_size = hidden_size
        self.embed_dim = embed_dim
        self.embeds = nn.Embedding(vocab_size, self.embed_dim,padding_idx=0)
        # `if embeddings:` raises for any tensor with more than one element
        # (ambiguous truth value), so test for presence explicitly.
        if embeddings is not None:
            self.embeds.weight.data.copy_(torch.as_tensor(embeddings))
        # The embedding table is always trainable, with or without an initial matrix.
        self.embeds.weight.requires_grad = True
        self.lstm1 = nn.LSTM(self.embed_dim, self.hidden_size, batch_first=True, bidirectional=True)
        # Input is [o, aligned, o - aligned, o * aligned], each 2 * hidden_size wide.
        self.lstm2 = nn.LSTM(self.hidden_size*8, self.hidden_size, batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(dropout)

        # Input is avg+max pooling (2 * 2 * hidden_size) for each of the two sentences.
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_size * 8, linear_size),
            nn.ELU(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(linear_size, linear_size),
            nn.ELU(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(linear_size, 1),
            nn.Sigmoid()
        )

    def soft_attention_align(self, x1, x2, mask1, mask2):
        '''
        Cross-attend the two encoded sentences.

        x1: batch_size * seq_len * dim
        x2: batch_size * seq_len * dim
        mask1, mask2: batch_size * seq_len boolean masks, True at padding.

        Returns (x1_align, x2_align): for every position of one sentence, the
        attention-weighted sum over the non-padding positions of the other.
        NOTE: if a sentence consists only of padding, every score in the
        corresponding softmax is -inf and the result is NaN.
        '''
        attention = torch.matmul(x1, x2.transpose(1, 2))
        neg_inf = float('-inf')

        # Padding positions of the attended sentence get -inf, i.e. zero weight.
        weight1 = F.softmax(attention.masked_fill(mask2.unsqueeze(1), neg_inf), dim=-1)
        x1_align = torch.matmul(weight1, x2)
        weight2 = F.softmax(attention.transpose(1, 2).masked_fill(mask1.unsqueeze(1), neg_inf), dim=-1)
        x2_align = torch.matmul(weight2, x1)

        return x1_align, x2_align

    def submul(self, x1, x2):
        """Concatenate the element-wise difference and product along the last dim."""
        mul = x1 * x2
        sub = x1 - x2
        return torch.cat([sub, mul], -1)

    def apply_multiple(self, x):
        """Average- and max-pool over the whole sequence axis (padding positions
        included) and concatenate the two results along the feature axis."""
        p1 = F.avg_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
        p2 = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
        return torch.cat([p1, p2], 1)

    def forward(self, sent1, sent2, labels_ids=None):
        """Score a batch of sentence pairs.

        Args:
            sent1, sent2: (batch, seq_len) integer token ids, 0 = padding.
            labels_ids: optional 0/1 targets, one per pair.

        Returns:
            The scalar BCE loss when ``labels_ids`` is given, otherwise the
            (batch, 1) tensor of sigmoid outputs.
        """
        mask1, mask2 = sent1.eq(0), sent2.eq(0)
        x1 = self.embeds(sent1)
        x2 = self.embeds(sent2)
        # Both sentences share the same encoder weights.
        o1, _ = self.lstm1(x1)
        o2, _ = self.lstm1(x2)

        q1_align, q2_align = self.soft_attention_align(o1, o2, mask1, mask2)

        q1_combined = torch.cat([o1, q1_align, self.submul(o1, q1_align)], -1)
        q2_combined = torch.cat([o2, q2_align, self.submul(o2, q2_align)], -1)

        q1_compose, _ = self.lstm2(q1_combined)
        q2_compose, _ = self.lstm2(q2_combined)

        q1_rep = self.apply_multiple(q1_compose)
        q2_rep = self.apply_multiple(q2_compose)

        x = torch.cat([q1_rep, q2_rep], -1)
        similarity = self.fc(x)
        if labels_ids is not None:
            loss_fct = nn.BCELoss()
            loss = loss_fct(similarity.view(-1,1).float(), labels_ids.view(-1,1).float())
            return loss
        else:
            return similarity

In [8]:
valid_result = []
def validation_funtion(model, valid_dataloader, valid_torchdata, mode):
    """Run ``model`` over ``valid_dataloader`` and score or return the predictions.

    Args:
        model: network called as ``model(sen1, sen2)`` to obtain one score per pair.
        valid_dataloader: loader yielding ``(sen1, sen2, len1, len2[, labels])``.
        valid_torchdata: dataset behind the loader; only inspected to find out
            whether labels are available.
        mode: 'valid' to compute metrics, anything else to return raw scores.

    Returns:
        ``(auc, precision, recall, f1)`` when ``mode == 'valid'`` (predictions
        thresholded at 0.5), otherwise a flat list with one float score per sample.

    Raises:
        ValueError: if ``mode == 'valid'`` but the data carries no labels.
    """
    model.eval()
    results = []
    true_label = []
    records = valid_torchdata.dataset
    has_labels = len(records) > 0 and records[0]['labels'] is not None
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in tqdm_notebook(valid_dataloader):
            sen1, sen2 = batch[0], batch[1]
            output = model(sen1.to(DEVICE), sen2.to(DEVICE))
            # Flatten (batch, 1) scores into plain floats so downstream numpy /
            # sklearn code receives scalars rather than one-element tensors.
            results.extend(output.detach().cpu().view(-1).tolist())
            if has_labels:
                true_label.extend(batch[4].view(-1).tolist())
    if mode == 'valid':
        if not true_label:
            raise ValueError("mode='valid' requires labelled data")
        predicted = [1 if score >= 0.5 else 0 for score in results]
        auc = roc_auc_score(true_label, results)
        precision = precision_score(true_label, predicted)
        recall = recall_score(true_label, predicted)
        f1 = f1_score(true_label, predicted)
        return auc, precision, recall, f1
    else:
        return results
                            
def train(model, train_dataloader, valid_dataloader, valid_torchdata, epochs, early_stop=None):
    """Train ``model`` with AdamW and cosine warm restarts, validating every epoch.

    Side effects: writes a checkpoint to ``'{NAME}_model_{model_num}.bin'``,
    increments the global ``model_num`` once the run finishes, prints the
    per-epoch metrics and logs them through the global ``logger``.

    Args:
        model: network returning the loss when called with ``labels_ids``.
        train_dataloader: loader yielding ``(sen1, sen2, len1, len2, labels)``.
        valid_dataloader, valid_torchdata: passed to ``validation_funtion``.
        epochs: maximum number of epochs.
        early_stop: if truthy, keep only the best-AUC checkpoint and stop after
            this many consecutive epochs without an AUC improvement; otherwise
            save the model after the final epoch.

    Returns:
        List with the loss of every training batch, in order.
    """
    global logger, model_num
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-5, last_epoch=-1)
    train_loss = []
    best_score = -np.inf
    no_improve = 0
    # model_num only changes when this run ends, so the path is fixed up front.
    checkpoint_path = '{}_model_{}.bin'.format(NAME, model_num)
    for epoch in range(epochs):
        model.train()
        bar = tqdm_notebook(train_dataloader)
        for i, (sen1, sen2, sen1_length, sen2_length, label_ids) in enumerate(bar):
            optimizer.zero_grad()
            loss = model(sen1.to(DEVICE), sen2.to(DEVICE), labels_ids=label_ids.to(DEVICE))
            loss.backward()
            train_loss.append(loss.item())
            # Update the weights first, then move the schedule to the current
            # fractional epoch (the order PyTorch documents for this scheduler).
            optimizer.step()
            scheduler.step(epoch + i / len(train_dataloader))
            # Running mean over every batch seen so far, across epochs.
            bar.set_postfix(tloss=np.mean(train_loss))
        # The second value is computed with precision_score, so it is reported
        # as precision rather than accuracy.
        auc, precision, recall, f1 = validation_funtion(model, valid_dataloader, valid_torchdata, 'valid')
        print('train_loss: {:.5f}, auc: {:.5f}, precision: {:.5f}, recall: {:.5f}, f1_socre: {:.5f}\n'.format(train_loss[-1],auc,precision,recall,f1))
        logger.info('Epoch:[{}]\t auc={:.5f}\t precision={:.5f}\t recall={:.5f}\t f1_socre={:.5f}'.format(epoch,auc,precision,recall,f1))
        if early_stop:
            if auc > best_score:
                best_score = auc
                # Patience counts consecutive bad epochs, so reset on improvement.
                no_improve = 0
                torch.save(model.state_dict(), checkpoint_path)
            else:
                no_improve += 1
            if no_improve >= early_stop:
                model_num += 1
                break
            if epoch == epochs-1:
                model_num += 1
        else:
            if epoch >= epochs-1:
                torch.save(model.state_dict(), checkpoint_path)
                model_num += 1
    return train_loss

In [9]:
import logging
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes formatted records to ``filename``.

    The file is truncated on every call (mode "w"). Calling the function again
    for the same logger and file replaces the previous file handler instead of
    stacking another one, so re-running the cell does not duplicate log lines
    or leak open file handles. No console handler is attached.

    Args:
        filename: path of the log file.
        verbosity: 0 = DEBUG, 1 = INFO, 2 = WARNING.
        name: logger name; ``None`` selects the root logger.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])
    # Detach and close any handler from an earlier call that targets this file.
    target = os.path.abspath(filename)
    for handler in list(logger.handlers):
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
            logger.removeHandler(handler)
            handler.close()
    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    return logger

In [10]:
# K-fold cross-validation: one freshly initialised ESIM model is trained per
# fold, and train() writes the checkpoint files that the later cells load.
FOLD = 5
kf = KFold(n_splits=FOLD, shuffle=True, random_state=SEED)
model_num = 1
test_preds_total = collections.defaultdict(list)
logger = get_logger('{}.log'.format(NAME))
for fold_no, (train_index, test_index) in enumerate(kf.split(train_dataset), start=1):
    print(str(fold_no), '-'*50)
    # Select this fold's records by position.
    tra = list(map(train_dataset.__getitem__, train_index))
    val = list(map(train_dataset.__getitem__, test_index))
    train_dataloader, train_torchdata = get_dataloader(tra, mode='train')
    valid_dataloader, valid_torchdata = get_dataloader(val, mode='valid')
    model = ESIM(22000,768,512,128)
    model.to(DEVICE)
    losses = train(
        model,
        train_dataloader,
        valid_dataloader,
        valid_torchdata,
        epochs=15,
        early_stop=2,
    )
    torch.cuda.empty_cache()

1 --------------------------------------------------


HBox(children=(IntProgress(value=0, max=80000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.38466, auc: 0.90553, accuracy: 0.75261, recall: 0.79551, f1_socre: 0.77346



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.28244, auc: 0.92115, accuracy: 0.79198, recall: 0.79427, f1_socre: 0.79312



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.13995, auc: 0.92253, accuracy: 0.79347, recall: 0.80680, f1_socre: 0.80008



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.23973, auc: 0.91840, accuracy: 0.79588, recall: 0.78174, f1_socre: 0.78875



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.12257, auc: 0.91948, accuracy: 0.80838, recall: 0.76219, f1_socre: 0.78461

2 --------------------------------------------------


HBox(children=(IntProgress(value=0, max=80000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.35931, auc: 0.90616, accuracy: 0.78637, recall: 0.74759, f1_socre: 0.76649



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.27961, auc: 0.92194, accuracy: 0.80269, recall: 0.78753, f1_socre: 0.79504



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.12498, auc: 0.92327, accuracy: 0.79769, recall: 0.80723, f1_socre: 0.80243



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.24431, auc: 0.92019, accuracy: 0.82953, recall: 0.72123, f1_socre: 0.77160



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.12683, auc: 0.91903, accuracy: 0.78017, recall: 0.80804, f1_socre: 0.79386

3 --------------------------------------------------


HBox(children=(IntProgress(value=0, max=80000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.36308, auc: 0.90546, accuracy: 0.77343, recall: 0.75246, f1_socre: 0.76280



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.28561, auc: 0.92091, accuracy: 0.77846, recall: 0.80808, f1_socre: 0.79299



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.12578, auc: 0.92205, accuracy: 0.78935, recall: 0.80006, f1_socre: 0.79467



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.20920, auc: 0.91947, accuracy: 0.80691, recall: 0.75578, f1_socre: 0.78051



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.11804, auc: 0.91711, accuracy: 0.77896, recall: 0.80850, f1_socre: 0.79345

4 --------------------------------------------------


HBox(children=(IntProgress(value=0, max=80000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.37501, auc: 0.90223, accuracy: 0.77685, recall: 0.72958, f1_socre: 0.75247



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.26963, auc: 0.91949, accuracy: 0.78002, recall: 0.80036, f1_socre: 0.79006



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.12338, auc: 0.92018, accuracy: 0.78614, recall: 0.79829, f1_socre: 0.79217



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.25039, auc: 0.91475, accuracy: 0.77131, recall: 0.79760, f1_socre: 0.78424



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.12119, auc: 0.91513, accuracy: 0.76851, recall: 0.80615, f1_socre: 0.78688

5 --------------------------------------------------


HBox(children=(IntProgress(value=0, max=80000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.35384, auc: 0.90647, accuracy: 0.73180, recall: 0.84110, f1_socre: 0.78265



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.22171, auc: 0.92234, accuracy: 0.79810, recall: 0.80049, f1_socre: 0.79929



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.15141, auc: 0.92264, accuracy: 0.80341, recall: 0.79531, f1_socre: 0.79934



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.22912, auc: 0.91661, accuracy: 0.80119, recall: 0.76887, f1_socre: 0.78470



HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


train_loss: 0.14994, auc: 0.91605, accuracy: 0.79617, recall: 0.77664, f1_socre: 0.78629



In [12]:
# model_num = 6  # presumably a manual override for running inference without retraining - confirm
model = ESIM(22000,768,512,128)
model.to(DEVICE)
test_preds_total = []
test_dataloader, test_torchdata = get_dataloader(test_dataset, mode='test')
for i in range(1,model_num):
    # map_location lets checkpoints saved on a GPU load on whatever DEVICE is now.
    state_dict = torch.load('{}_model_{}.bin'.format(NAME, i), map_location=DEVICE)
    model.load_state_dict(state_dict)
    test_pred_results = validation_funtion(model, test_dataloader, test_torchdata, 'test')
    # One flat float vector per checkpoint; float() accepts both plain floats
    # and one-element tensors.
    test_preds_total.append(np.array([float(p) for p in test_pred_results]))
if not test_preds_total:
    raise RuntimeError('no checkpoints to ensemble: model_num is {}'.format(model_num))
# Average the checkpoints' scores sample by sample -> shape (n_test_samples,).
test_preds_merge = np.mean(test_preds_total, axis=0)

HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=49), HTML(value='')))




HBox(children=(IntProgress(value=0, max=49), HTML(value='')))




HBox(children=(IntProgress(value=0, max=49), HTML(value='')))




HBox(children=(IntProgress(value=0, max=49), HTML(value='')))




HBox(children=(IntProgress(value=0, max=49), HTML(value='')))




In [13]:
import os
# Write one averaged score per line. The context manager closes the file even
# if a write fails, and flattening guarantees each line is a bare number
# rather than the repr of a one-element array such as "[0.5]".
with open('submit.txt','w') as f:
    for x in np.asarray(test_preds_merge, dtype=float).reshape(-1):
        f.write(str(float(x))+'\n')