In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
from transformers import BertModel, BertTokenizer, BertConfig
import torch as t
import sys
sys.path.append('/home/zhongjc/DataFound/Agricul/Bert_BiLSTM_CRF')
from sklearn.model_selection import train_test_split
import torch.optim as optim
from model import *
import torch.utils.data as tud

device = t.device('cuda:7' if t.cuda.is_available() else 'cpu')
pretrained_path = 'bert-base-chinese'
bert_embedding = 768

In [2]:
df_train = pd.read_csv('./data/train_clean.csv')
df_test = pd.read_csv('./data/test.csv')
tokenizer = BertTokenizer.from_pretrained(pretrained_path)
idx2tag = ['<pad>', '<sos>', '<eos>', 'e_medicine', 'O', 'e_crop', 'b_crop', 'm_crop', 'm_disease', 'b_medicine', 'e_disease', 'b_disease', 'm_medicine']
tag2idx = {v:k for k, v in enumerate(idx2tag)}

In [3]:
BATCH_SIZE = 10
N_EPOCHS = 60
MAX_LEN = 400
rnn_hidden = 500
dropout1 = 0.5
dropout_ratio = 0.5
rnn_layer = 1
lr = 0.0001
lr_decay = 0.00001
weight_decay = 0.00005

## data set

In [4]:
class AgriDataset(tud.Dataset):
    def __init__(self, datas):
        super(AgriDataset, self).__init__()
        self.datas = datas
        self.lens = len(datas)
        
    def __len__(self):
        return self.lens
    
    def __getitem__(self, idx):
        data = self.datas[idx][-1]
        words, tags = [], []
        for i in data.split(' '):
            if '/' not in i: continue
            words.append(i.split('/')[0])
            tags.append(i.split('/')[1])
        seq_dict = tokenizer.encode_plus(''.join(words), max_length=MAX_LEN, add_special_tokens=True, pad_to_max_length=True)
        input_ids = seq_dict['input_ids']
        atten_mask = seq_dict['attention_mask']
        tags = ['<sos>'] + tags + ['<eos>']
        tags_ids = [tag2idx[i] for i in tags] + [tag2idx['<pad>']]*(MAX_LEN-len(tags))
        
        input_ids = t.LongTensor(input_ids)
        atten_mask = t.LongTensor(atten_mask)
        tags_ids = t.LongTensor(tags_ids)
        return input_ids, atten_mask, tags_ids

In [5]:
train_data = AgriDataset(df_train.values)
train_data, val_data = train_test_split(train_data, test_size=0.005, random_state=2020)
train_iter = tud.DataLoader(train_data, batch_size=BATCH_SIZE)
val_iter = tud.DataLoader(val_data, batch_size=BATCH_SIZE)

In [6]:
model = BERT_LSTM_CRF(pretrained_path, len(idx2tag), bert_embedding, rnn_hidden, rnn_layer, dropout_ratio, dropout1)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

In [7]:
def train(model, data_iter, optimizer):
    model.train()
    total_loss, total_cnt = 0., 0.
    for idx, (input_ids, atten_mask, tag_ids) in enumerate(data_iter):
        input_ids, atten_mask, tag_ids = input_ids.to(device), atten_mask.to(device), tag_ids.to(device)
        feats = model(input_ids, atten_mask)
        loss = model.loss(feats, atten_mask, tag_ids)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_cnt += len(input_ids)
        total_loss += loss.item()
    return total_loss / total_cnt

def evaluate(model, data_iter):
    model.eval()
    total_loss, total_cnt = 0., 0.
    with t.no_grad():
        for idx, (input_ids, atten_mask, tag_ids) in enumerate(data_iter):
            input_ids, atten_mask, tag_ids = input_ids.to(device), atten_mask.to(device), tag_ids.to(device)
            feats = model(input_ids, atten_mask)
            loss = model.loss(feats, atten_mask, ag_ids)
            
            total_cnt += len(input_ids)
            total_loss += loss.item()
    return total_loss / total_cnt

In [8]:
val_history = []
for e in range(N_EPOCHS):
    train_loss = train(model, train_iter, optimizer)
    val_loss = evaluate(model, val_iter)
    
    if len(val_history)==0 or min(val_history)>val_loss:
        t.save('./save_models/1.pt', model.state_dict())
    val_history.append(val_loss)
    print("Epoch: {}".format(e))
    print("Train Loss:{:.4f} Val Loss:{:.4f}".format(train_loss, val_loss))

NameError: name 'ag_ids' is not defined