In [None]:
import warnings 
warnings.simplefilter('ignore')

import os
import gc
import json
import copy
import random

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn.functional as F

%matplotlib inline

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from transformers import  BertTokenizer, WEIGHTS_NAME
from model.modeling_nezha import NeZhaForSequenceClassification, NeZhaModel
from model.configuration_nezha import NeZhaConfig

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup, AutoConfig
import transformers

In [None]:
# Load the raw training table and show its size plus the first rows.
train_csv_path = './data/train.csv'
train = pd.read_csv(train_csv_path)
print(train.shape)
train.head()

In [None]:
# Global RNG seed shared by Config and the seeding cell.
SEED = 42
# NOTE(review): NUM_CLASSES was referenced (by Config and by Nezha_pool_last2emb)
# but never defined in this notebook, so Config() raised NameError.
# 25 matches the FocalLoss(num_class=25) used in run(); confirm against the real
# label set before training.
NUM_CLASSES = 25


class Config:
    """Hyper-parameters and shared objects for the NeZha fine-tuning runs.

    The tokenizer is loaded eagerly from ``tokenizer_path`` when the object is
    created, and the device is CUDA when available, otherwise CPU.
    """

    def __init__(self):
        self.SEED = SEED
        self.tokenizer_path = 'bert_model/nezha-cn-base'
        self.MODEL_PATH = './pretrain/pretrained_0.15_dual/checkpoint-20000'
        self.NUM_LABELS = NUM_CLASSES

        # data
        self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_path)  # load the tokenizer
        self.max_length = 256  # maximum sentence length (tokens)
        self.batch_size = 32

        # optimisation
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.full_finetuning = True
        self.lr = 3e-5
        self.optimizer = 'AdamW'
        self.n_warmup = 0  # number of linear-warmup scheduler steps
        self.save_best_only = True

        self.multi_gpu = False
        self.attack = 'pgd'  # one of: 'fgm', 'pgd', 'freelb'

        self.flooding = False
        self.loss_func = 'ce'  # one of: 'ce', 'dice', 'fl'
        self.epochs = 20


config = Config()

In [None]:
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every RNG.

    Note:
        Setting ``PYTHONHASHSEED`` here only affects subprocesses started later
        (e.g. DataLoader workers spawned fresh). The running interpreter's hash
        seed is fixed at startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly cover every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick different (nondeterministic) kernels per run, which
    # defeats `deterministic = True`; disable it.
    torch.backends.cudnn.benchmark = False


# seed_everything reseeds NumPy itself, so one call leaves every RNG in the same state.
seed_everything(seed=config.SEED)

In [None]:
# Keep only the `description` and `diagnosis` columns (a new frame; `train` is untouched).
train_df = train.loc[:, ['description', 'diagnosis']]
print(train_df.shape)
train_df.head()

In [None]:
class GAIIC_Dataset(Dataset):
    """Seq2seq dataset over a CSV whose cells hold space-separated integer token ids.

    Column 1 of each row is the input sequence. For training rows, column 2 is the
    target sequence (presumably column 0 is a sample id; TODO confirm). Rows with
    fewer than 3 columns are treated as test rows and yield only the input.

    Args:
        data_file: path to the CSV file.
        input_len: length every input sequence is padded/truncated to.
        output_len: length every target sequence (including SOS/EOS) is
            padded/truncated to.
        sos_id, eos_id, pad_id: special token ids.
        skip_header: drop the first CSV row (set True for files with a header
            line, whose text cells would otherwise fail ``int()`` parsing).
    """

    def __init__(self, data_file, input_len, output_len, sos_id=1, eos_id=2, pad_id=0,
                 skip_header=False):
        super(GAIIC_Dataset, self).__init__()
        import csv  # local import: csv is not imported at the top of this notebook

        # newline='' is what the csv module requires to parse quoted newlines correctly.
        with open(data_file, 'r', newline='') as fp:
            reader = csv.reader(fp)
            if skip_header:
                next(reader, None)
            self.samples = list(reader)

        self.input_len = input_len
        self.output_len = output_len
        self.sos_id = sos_id
        self.eos_id = eos_id
        self.pad_id = pad_id

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _pad_truncate(ids, length, pad_id):
        """Right-pad ``ids`` with ``pad_id`` or truncate it to exactly ``length``; return an array."""
        return np.array(ids[:length] + [pad_id] * max(length - len(ids), 0))

    def __getitem__(self, index):
        """Return ``input`` for test rows, else ``(input, target)`` with target = SOS + ids + EOS.

        If the target is longer than ``output_len``, truncation can drop the EOS token.
        """
        row = self.samples[index]
        description = self._pad_truncate([int(x) for x in row[1].split()],
                                         self.input_len, self.pad_id)

        # Test rows have no target column. This was `row < 3`, which compares a
        # list with an int and raises TypeError.
        if len(row) < 3:
            return description

        target = [self.sos_id] + [int(x) for x in row[2].split()] + [self.eos_id]
        return description, self._pad_truncate(target, self.output_len, self.pad_id)

In [None]:
class Nezha_pool_last2emb(nn.Module):
    """NeZha classifier over [pooled output, last-layer CLS, second-to-last-layer CLS].

    The three ``hidden_size`` vectors are concatenated, passed through dropout,
    and mapped to ``NUM_CLASSES`` logits by a single linear layer.

    Args:
        NeZhaConfig: a NeZha config *instance*. The notebook rebinds the imported
            class name to an instance before building the model. It is mutated in
            place: ``output_hidden_states`` is set to True so the encoder returns
            per-layer hidden states.

    Relies on notebook globals ``config.MODEL_PATH`` and ``NUM_CLASSES``.
    """

    def __init__(self, NeZhaConfig):
        # Was `super(BERTClassifier, self)`. BERTClassifier is not defined in
        # this notebook, so constructing the model raised NameError.
        super().__init__()
        NeZhaConfig.output_hidden_states = True
        self.bert = NeZhaModel.from_pretrained(config.MODEL_PATH, config=NeZhaConfig)
        hidden_size = self.bert.config.hidden_size
        self.drop = nn.Dropout(p=0.1)
        # NOTE(review): dense1 is never used in forward(). It is kept so parameter
        # names (and therefore saved state_dict keys) stay compatible with
        # existing checkpoints.
        self.dense1 = nn.Sequential(
            nn.Linear(hidden_size * 3, 128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Dropout(p=0.2),
            nn.Linear(128, NUM_CLASSES)
        )
        self.dense = nn.Sequential(
            nn.Linear(hidden_size * 3, NUM_CLASSES)
        )

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of token ids.

        Assumes NeZhaModel returns (sequence_output, pooled_output, hidden_states)
        like BERT when ``output_hidden_states`` is True — TODO confirm.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]
        hidden_output = outputs[2]
        # Token 0 (CLS) of the last and second-to-last layers, next to the pooled vector.
        classification_input = torch.cat(
            (pooled_output, hidden_output[-1][:, 0], hidden_output[-2][:, 0]), 1
        )
        return self.dense(self.drop(classification_input))

In [None]:
def val_fn(model, valid_dataloader, criterion):
    """Return the mean per-batch validation loss of a seq2seq ``model``.

    Expects batches of ``(input_ids, target_ids)`` tensors, as produced by
    GAIIC_Dataset. Gradients are disabled, and the model's previous train/eval
    mode is restored afterwards.

    NOTE(review): the model is called as ``model(input_ids=..., target_ids=...)``.
    The classifier defined in this notebook takes ``attention_mask`` instead, so
    confirm which model this helper is meant for.
    """
    # Follow the model's own device rather than a notebook-level `device` global.
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    val_loss = 0.0
    with torch.no_grad():
        for input_ids, target_ids in tqdm(valid_dataloader):
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)
            # Was `taregt_ids=taregt_ids` (undefined name -> NameError).
            pred_ids = model(input_ids=input_ids, target_ids=target_ids)
            # Was accumulated into an undefined `train_loss`, leaving val_loss at 0.
            val_loss += criterion(pred_ids, target_ids).item()

    model.train(was_training)
    # Guard against ZeroDivisionError for an empty loader.
    return val_loss / max(len(valid_dataloader), 1)

In [None]:
def train_fn(model, train_dataloader, criterion, optimizer, scheduler=None, epoch=0):
    """Run one training epoch of a seq2seq ``model`` and print the mean loss.

    Args:
        model: module called as ``model(input_ids=..., target_ids=...)``.
            NOTE(review): the classifier in this notebook does not accept
            ``target_ids``; confirm the intended model.
        train_dataloader: yields ``(input_ids, target_ids)`` tensor pairs.
        criterion: loss taking ``(predictions, target_ids)``.
        optimizer: stepped once per batch.
        scheduler: optional LR scheduler, stepped once per batch.
        epoch: only used in the progress-bar label.

    Returns:
        float: mean per-batch training loss. The original returned None, so
        returning a value stays compatible with existing callers.
    """
    # Follow the model's own device rather than a notebook-level `device` global.
    device = next(model.parameters()).device
    model.train()

    train_loss = 0.0
    for input_ids, target_ids in tqdm(train_dataloader, desc='Epoch ' + str(epoch)):
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        optimizer.zero_grad()
        # Was `taregt_ids=taregt_ids` (undefined name -> NameError).
        pred_ids = model(input_ids=input_ids, target_ids=target_ids)
        loss = criterion(pred_ids, target_ids)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

    # Guard against ZeroDivisionError for an empty loader.
    avg_train_loss = train_loss / max(len(train_dataloader), 1)
    print('Training loss:', avg_train_loss)
    return avg_train_loss

In [None]:
import time


def _build_criterion(config):
    """Return the loss module selected by ``config.loss_func`` ('ce', 'dice' or 'fl').

    Raises:
        ValueError: for an unknown ``loss_func``. Previously this surfaced later
            as an unbound ``criterion``.
    """
    if config.loss_func == 'ce':
        return nn.CrossEntropyLoss()
    if config.loss_func == 'dice':
        # NOTE(review): SelfAdjDiceLoss is not defined in this notebook.
        return SelfAdjDiceLoss()
    if config.loss_func == 'fl':
        # NOTE(review): FocalLoss is not defined in this notebook. The class count
        # 25 is hard-coded here, independently of NUM_CLASSES.
        return FocalLoss(num_class=25)
    raise ValueError(f"Unknown config.loss_func: {config.loss_func!r}")


def _build_optimizer(model, config):
    """Build AdamW with weight decay 0.01, except 0.0 for biases and LayerNorm weights.

    Raises:
        NotImplementedError: when ``config.full_finetuning`` is False. Previously
            that path left ``optimizer`` unbound.
    """
    if not config.full_finetuning:
        raise NotImplementedError('config.full_finetuning=False is not supported')
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    named_params = list(model.named_parameters())
    param_groups = [
        {
            "params": [p for n, p in named_params if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in named_params if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return optim.AdamW(param_groups, lr=config.lr, weight_decay=0.01, eps=1e-8)


def _flood(loss, config, level=0.4):
    """Apply loss flooding ``|loss - level| + level`` when ``config.flooding`` is set."""
    return abs(loss - level) + level if config.flooding else loss


def _evaluate(model, dataloader, criterion, device):
    """Return ``(mean per-batch loss, accuracy)`` over dict batches.

    Each batch must contain 'input_ids', 'attention_mask' and 'labels'. Runs
    under ``torch.no_grad()`` and restores the model's previous train/eval mode.
    """
    was_training = model.training
    model.eval()
    total_loss, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Valid'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            total_loss += criterion(logits, labels).item()
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)
    model.train(was_training)
    return total_loss / max(len(dataloader), 1), n_correct / max(n_seen, 1)


def _save_checkpoint(model, path):
    """Save ``model``'s weights to ``path``, skipping 'relative_positions' buffers.

    DataParallel is unwrapped first so the saved keys carry no 'module.' prefix.
    """
    raw_model = getattr(model, 'module', model)
    state_dict = {k: v for k, v in raw_model.state_dict().items()
                  if 'relative_positions' not in k}
    torch.save(state_dict, path)


def run(config):
    """Train one NeZha classifier per fold listed in ``folds.json``.

    For each of the 5 folds, run() builds the loss, model and optimizer and trains
    for up to ``config.epochs`` epochs. Training optionally uses FGM/PGD
    adversarial steps and loss flooding. After every epoch it evaluates on the
    fold's validation split and writes the best-accuracy checkpoint to
    ``models/{model_type}_{fold}.pt``. With ``save_best_only=False`` it writes
    every epoch instead. A fold stops early after 3 consecutive epochs without
    improving validation accuracy.

    Relies on notebook globals ``train_df``, ``model_type`` and ``NeZhaConfig``
    (a config instance by the time this runs).
    NOTE(review): ``TransformerDataset``, ``FGM``, ``PGD`` and ``FreeLB`` are used
    but not defined in this notebook.
    """
    # Was `t`, which the PGD loop below overwrote, so the final timing was garbage.
    start_time = time.time()
    n_splits = 5
    patience = 2  # stop once epoch - best_epoch exceeds this
    pgd_steps = 3
    device = config.device

    with open('folds.json') as f:
        kfolds = json.load(f)

    torch.manual_seed(config.SEED)
    os.makedirs('models', exist_ok=True)

    for fold in range(n_splits):
        criterion = _build_criterion(config)
        # NOTE(review): this was `Nezha_pool_last3emb`, which is defined nowhere.
        # The only classifier in this notebook is Nezha_pool_last2emb
        # (pooled output + last two layers' CLS).
        model = Nezha_pool_last2emb(NeZhaConfig)
        model.to(device)
        if config.multi_gpu:
            model = torch.nn.DataParallel(model)

        attacker = None
        if config.attack == 'fgm':
            attacker = FGM(model, epsilon=1, emb_name='word_embeddings.')
        elif config.attack == 'pgd':
            attacker = PGD(model)
        elif config.attack == 'freelb':
            # NOTE(review): FreeLB is instantiated but never applied in the loop below.
            attacker = FreeLB(model)

        train_indices = kfolds[f'fold_{fold}']['train']
        valid_indices = kfolds[f'fold_{fold}']['valid']
        train_data = TransformerDataset(train_df, train_indices)
        valid_data = TransformerDataset(train_df, valid_indices)
        train_dataloader = DataLoader(train_data, batch_size=config.batch_size, num_workers=4, shuffle=True)
        valid_dataloader = DataLoader(valid_data, batch_size=config.batch_size, num_workers=4, shuffle=False)

        optimizer = _build_optimizer(model, config)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=config.n_warmup,  # was hard-coded 0, ignoring config.n_warmup
            num_training_steps=len(train_dataloader) * config.epochs,
        )

        best_acc = 0
        best_acc_epoch = 0
        for epoch in range(config.epochs):
            model.train()
            train_loss, n_correct, n_seen = 0.0, 0, 0
            for batch in tqdm(train_dataloader, desc='Epoch ' + str(epoch)):
                b_input_ids = batch['input_ids'].to(device)
                b_attention_mask = batch['attention_mask'].to(device)
                b_labels = batch['labels'].to(device)

                optimizer.zero_grad()
                logits = model(input_ids=b_input_ids, attention_mask=b_attention_mask)
                loss = _flood(criterion(logits, b_labels), config)
                train_loss += loss.item()
                n_correct += (logits.argmax(dim=1) == b_labels).sum().item()
                n_seen += b_labels.size(0)
                loss.backward()

                if config.attack == 'fgm':
                    attacker.attack()
                    logits_adv = model(input_ids=b_input_ids, attention_mask=b_attention_mask)
                    _flood(criterion(logits_adv, b_labels), config).backward()
                    attacker.restore()
                elif config.attack == 'pgd':
                    attacker.backup_grad()
                    for k in range(pgd_steps):
                        # Perturb the embeddings. The first attack backs up the original params.
                        attacker.attack(is_first_attack=(k == 0))
                        if k != pgd_steps - 1:
                            model.zero_grad()
                        else:
                            attacker.restore_grad()
                        logits_adv = model(input_ids=b_input_ids, attention_mask=b_attention_mask)
                        # Adversarial gradients accumulate on top of the clean ones.
                        criterion(logits_adv, b_labels).backward()
                    attacker.restore()  # restore the original embedding parameters

                optimizer.step()
                scheduler.step()

            avg_train_loss = train_loss / max(len(train_dataloader), 1)
            # Divide by samples actually seen; the old len(loader) * batch_size
            # was wrong whenever the last batch was partial.
            avg_train_acc = n_correct / max(n_seen, 1)
            print('Training loss:', avg_train_loss, 'Training acc:', avg_train_acc)

            # val_fn expects (input_ids, target_ids) tuples and returns a single float,
            # so it cannot evaluate these dict batches; use the dict-aware helper.
            avg_val_loss, avg_val_acc = _evaluate(model, valid_dataloader, criterion, device)
            print('Validation loss:', avg_val_loss, 'Validation acc:', avg_val_acc)

            # Track the best epoch regardless of save_best_only, so early stopping
            # does not fire blindly after 3 epochs when it is False.
            improved = best_acc < avg_val_acc
            if improved:
                best_acc_epoch = epoch
                best_acc = avg_val_acc
            if improved or not config.save_best_only:
                _save_checkpoint(model, 'models/' + f'{model_type}_{fold}' + '.pt')

            if epoch - best_acc_epoch > patience:
                break

        print(f'Fold {fold} best val acc: {best_acc} (epoch {best_acc_epoch})')
        # Free this fold's model/optimizer before building the next one.
        del model, optimizer, scheduler, attacker, criterion
        gc.collect()
        torch.cuda.empty_cache()

    print('Cost time:{}'.format(time.time() - start_time))

In [None]:
# Module-level torch device, taken from the Config.
device = config.device
# Experiment tag; run() uses it to name saved checkpoints (models/{model_type}_{fold}.pt).
model_type = 'nezha_mlm_last3cls_0.15_dual_grad_clip'
# NOTE(review): this rebinds the imported class name `NeZhaConfig` to a config
# *instance* loaded from MODEL_PATH. run() passes this global to the model
# constructor, so the rebinding is load-bearing. Consider a distinct name such
# as `nezha_config`, and update the callers together with it.
NeZhaConfig = NeZhaConfig.from_pretrained(config.MODEL_PATH)

In [None]:
run(config=config)  # "4:19": presumably a previously observed run time — TODO confirm