数据导入

In [1]:
import pandas as pd
df=pd.read_csv(r"/content/drive/My Drive/KAGGLE/nlp/twitter/data/train.csv")
raw_x=df["text"].values.tolist()
raw_y=df["target"].values.tolist()
raw_data=list(zip(raw_x,raw_y))#x为推特文本 y为是否是灾难的标签

训练测试集划分 保存

In [None]:
from sklearn.model_selection import train_test_split
training_data,test_data=train_test_split(raw_data,test_size=0.2, random_state=16)

In [None]:
train=pd.DataFrame(training_data,columns=["text","target"])

In [None]:
train.to_csv("/content/drive/My Drive/KAGGLE/nlp/twitter/train.csv")

模型及预处理部分

In [4]:
! pip install pytorch_pretrained_bert



In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel
from torch.autograd import Variable
#下游任务模型
class bigru_attention(nn.Module):
    def __init__(self, bert_config, tagset_size, embedding_dim, hidden_dim, rnn_layers, dropout_ratio, dropout1, use_cuda):
        super(bigru_attention, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.rnn_layers = rnn_layers
        self.word_embeds = BertModel.from_pretrained(bert_config)
        # 双向GRU，//操作为了与后面的Attention操作维度匹配，hidden_dim要取偶数！
        self.bigru = nn.GRU(embedding_dim, hidden_dim, num_layers=rnn_layers, bidirectional=True,dropout=dropout_ratio, batch_first=True)
        self.dropout1 = nn.Dropout(p=dropout1)
        # 由nn.Parameter定义的变量都为requires_grad=True状态
        self.weight_W = nn.Parameter(torch.Tensor(hidden_dim*2, hidden_dim*2))
        self.weight_proj = nn.Parameter(torch.Tensor(hidden_dim*2, 1))
        self.fc = nn.Linear(hidden_dim*2,tagset_size)
        nn.init.uniform_(self.weight_W, -0.1, 0.1)
        nn.init.uniform_(self.weight_proj, -0.1, 0.1)
        self.use_cuda =  use_cuda

    def rand_init_hidden(self, batch_size):
        if self.use_cuda:
            return Variable(
                torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)).cuda(), Variable(
                torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)).cuda()
        else:
            return Variable(
                torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)), Variable(
                torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim))
 
    def forward(self, sentence, attention_mask=None):
        batch_size = sentence.size(0)
        seq_length = sentence.size(1)
        embeds, _ = self.word_embeds(sentence, attention_mask=attention_mask, output_all_encoded_layers=False)
        hidden = self.rand_init_hidden(batch_size)
        gru_out, hiden = self.bigru(embeds) # [seq_len, bs, hid_dim]
        d_gru_out = self.dropout1(gru_out)
        x = d_gru_out
        # # # Attention过程，与上图中三个公式对应
        u = torch.tanh(torch.matmul(x, self.weight_W))
        att = torch.matmul(u, self.weight_proj)
        att_score = F.softmax(att, dim=1)
        scored_x = x * att_score
        # # # Attention过程结束
        feat = torch.sum(scored_x, dim=1)
        y = self.fc(feat)
        return y


备用

In [17]:
import pandas as pd
class InputFeatures(object):
    def __init__(self, text, label, input_id, input_mask):
        self.text = text
        self.label = label
        self.input_id = input_id
        self.input_mask = input_mask
#读取bert的词汇表 之后对单词进行编码
def load_vocab(vocab_file):
    vocab = {}
    index = 0
    with open(vocab_file, "r", encoding="utf-8") as reader:
        while True:
            token = reader.readline()
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab
#读取训练 测试文件 分为输入texts 和标签labels
def load_file(file_path):
    df = pd.read_csv(file_path)
    raw_x=df["text"].values.tolist()
    raw_y=df["target"].values.tolist()
    raw_data=list(zip(raw_x,raw_y))
    texts = [[w for w in sample[0].split()] for sample in raw_data]
    labels = [sample[1] for sample in raw_data]
    return texts, labels
#进行对texts进行编码操作 补齐至最大长度
def load_data(file_path, max_length, vocab):
    texts, labels = load_file(file_path)
    assert len(texts) == len(labels)
    result = []
    for i in range(len(texts)):
        token = texts[i]
        label = int(labels[i])
        if len(token) > max_length-2:
            token = token[0:(max_length-2)]
        tokens_f =['[CLS]'] + token + ['[SEP]']
        input_ids = [int(vocab[i]) if i in vocab else int(vocab['[UNK]']) for i in tokens_f]
        mask_bool=1
        input_mask = [mask_bool] * len(input_ids)
        while len(input_ids) < max_length:
            input_ids.append(0)
            input_mask.append(0)
        assert len(input_ids) == max_length
        assert len(input_mask) == max_length
        #assert len(label_ids) == max_length 实体识别标签序列用
        feature = InputFeatures(text=tokens_f, label=label, input_id=input_ids, input_mask=input_mask)
        result.append(feature)
    return result

In [18]:
#文件位置以及超参数设置
train_file='/content/drive/My Drive/KAGGLE/nlp/twitter/train.csv'
dev_file='/content/drive/My Drive/KAGGLE/nlp/twitter/test.csv'
max_length=100
vocab_file='/content/drive/My Drive/KAGGLE/nlp/twitter/bert/albert_base_v2.zip_files/vocab.txt'
batch_size=16
tagset_size = 2
epochs=20

模型训练及评估

In [19]:
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
    print('device',device)
    use_cuda = True
else:
    device = torch.device("cpu")
    use_cuda = False
vocab = load_vocab(vocab_file)
#load data
train_data = load_data(train_file, max_length=max_length, vocab=vocab)
train_ids = torch.LongTensor([temp.input_id for temp in train_data])
train_masks = torch.LongTensor([temp.input_mask for temp in train_data])
train_tags = torch.LongTensor([temp.label for temp in train_data])
train_dataset = TensorDataset(train_ids, train_masks, train_tags)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

dev_data = load_data(dev_file, max_length=max_length, vocab=vocab)
dev_ids = torch.LongTensor([temp.input_id for temp in dev_data])
dev_masks = torch.LongTensor([temp.input_mask for temp in dev_data])
dev_tags = torch.LongTensor([temp.label for temp in dev_data])
dev_dataset = TensorDataset(dev_ids, dev_masks, dev_tags)
dev_loader = DataLoader(dev_dataset, shuffle=True, batch_size=batch_size)

#模型及超参数设置        
model = bigru_attention('bert-base-cased', tagset_size, 768, 200, 1,
                      dropout_ratio=0.4, dropout1=0.4, use_cuda = use_cuda)
if use_cuda:
    model.cuda()
model.train()
losser=torch.nn.CrossEntropyLoss()
optimizer = getattr(optim, 'Adam')#优化器Adam
optimizer = optimizer(model.parameters(), lr=0.00003, weight_decay=0.00005)
best_f = -100
for epoch in range(epochs):
    print('epoch: {}trrain'.format(epoch))
    for i, train_batch in enumerate(tqdm(train_loader)):
        model.train()
        model.zero_grad()
        sentence, masks, tags = train_batch
        sentence, masks, tags = Variable(sentence), Variable(masks), Variable(tags)
        if use_cuda:
            sentence = sentence.cuda()
            masks = masks.cuda()
            tags = tags.cuda()
        loss = losser(model(sentence),tags)
        loss.backward()
        optimizer.step()
    print('epoch: {}loss: {}'.format(epoch, loss.item()))
    model.eval()
    pred = []
    true=[]
    for i, dev_batch in enumerate(dev_loader):
        model.zero_grad()
        sentence, masks, tags = dev_batch
        sentence, masks, tags = Variable(sentence), Variable(masks), Variable(tags)
        if use_cuda:
            sentence = sentence.cuda()
            asks = masks.cuda()
            tags = tags.cuda()
        predict_tags = F.softmax(model(sentence)).tolist()
        pred_tag=[i.index(max(i)) for i in predict_tags]
        pred.extend(pred_tag)
        true.extend(tags.tolist())
    print(classification_report(true, pred))

device cuda:0




  0%|          | 0/381 [00:00<?, ?it/s][A[A

epoch: 0trrain




  0%|          | 1/381 [00:00<02:26,  2.59it/s][A[A

  1%|          | 2/381 [00:00<02:21,  2.67it/s][A[A

  1%|          | 3/381 [00:01<02:18,  2.72it/s][A[A

  1%|          | 4/381 [00:01<02:18,  2.72it/s][A[A

  1%|▏         | 5/381 [00:01<02:15,  2.77it/s][A[A

  2%|▏         | 6/381 [00:02<02:14,  2.78it/s][A[A

  2%|▏         | 7/381 [00:02<02:14,  2.78it/s][A[A

  2%|▏         | 8/381 [00:02<02:13,  2.78it/s][A[A

  2%|▏         | 9/381 [00:03<02:14,  2.77it/s][A[A

  3%|▎         | 10/381 [00:03<02:12,  2.80it/s][A[A

  3%|▎         | 11/381 [00:03<02:12,  2.79it/s][A[A

  3%|▎         | 12/381 [00:04<02:11,  2.80it/s][A[A

  3%|▎         | 13/381 [00:04<02:11,  2.80it/s][A[A

  4%|▎         | 14/381 [00:05<02:11,  2.79it/s][A[A

  4%|▍         | 15/381 [00:05<02:11,  2.78it/s][A[A

  4%|▍         | 16/381 [00:05<02:11,  2.78it/s][A[A

  4%|▍         | 17/381 [00:06<02:11,  2.78it/s][A[A

  5%|▍         | 18/381 [00:06<02:10,  2.77it/s][A[A


epoch: 0loss: 0.21777376532554626




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.76      0.95      0.84       867
           1       0.90      0.60      0.72       656

    accuracy                           0.80      1523
   macro avg       0.83      0.77      0.78      1523
weighted avg       0.82      0.80      0.79      1523

epoch: 1trrain




  0%|          | 1/381 [00:00<02:12,  2.87it/s][A[A

  1%|          | 2/381 [00:00<02:14,  2.81it/s][A[A

  1%|          | 3/381 [00:01<02:16,  2.77it/s][A[A

  1%|          | 4/381 [00:01<02:16,  2.76it/s][A[A

  1%|▏         | 5/381 [00:01<02:17,  2.74it/s][A[A

  2%|▏         | 6/381 [00:02<02:17,  2.72it/s][A[A

  2%|▏         | 7/381 [00:02<02:17,  2.73it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.72it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.71it/s][A[A

  3%|▎         | 10/381 [00:03<02:16,  2.71it/s][A[A

  3%|▎         | 11/381 [00:04<02:16,  2.71it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.70it/s][A[A

  3%|▎         | 13/381 [00:04<02:15,  2.71it/s][A[A

  4%|▎         | 14/381 [00:05<02:15,  2.71it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.71it/s][A[A

  4%|▍         | 16/381 [00:05<02:14,  2.71it/s][A[A

  4%|▍         | 17/381 [00:06<02:14,  2.71it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.70it/s][A[A


epoch: 1loss: 0.8442856073379517




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.81      0.83      0.82       867
           1       0.77      0.74      0.76       656

    accuracy                           0.79      1523
   macro avg       0.79      0.79      0.79      1523
weighted avg       0.79      0.79      0.79      1523

epoch: 2trrain




  0%|          | 1/381 [00:00<02:13,  2.84it/s][A[A

  1%|          | 2/381 [00:00<02:15,  2.79it/s][A[A

  1%|          | 3/381 [00:01<02:17,  2.75it/s][A[A

  1%|          | 4/381 [00:01<02:17,  2.74it/s][A[A

  1%|▏         | 5/381 [00:01<02:17,  2.73it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 7/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.69it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.70it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.70it/s][A[A

  4%|▎         | 14/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 17/381 [00:06<02:14,  2.70it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.70it/s][A[A


epoch: 2loss: 0.18934759497642517




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.80      0.85      0.82       867
           1       0.78      0.71      0.75       656

    accuracy                           0.79      1523
   macro avg       0.79      0.78      0.79      1523
weighted avg       0.79      0.79      0.79      1523

epoch: 3trrain




  0%|          | 1/381 [00:00<02:14,  2.82it/s][A[A

  1%|          | 2/381 [00:00<02:16,  2.78it/s][A[A

  1%|          | 3/381 [00:01<02:17,  2.75it/s][A[A

  1%|          | 4/381 [00:01<02:17,  2.73it/s][A[A

  1%|▏         | 5/381 [00:01<02:18,  2.72it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 7/381 [00:02<02:18,  2.69it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.70it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.69it/s][A[A

  4%|▎         | 14/381 [00:05<02:16,  2.69it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.69it/s][A[A

  4%|▍         | 17/381 [00:06<02:15,  2.70it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.70it/s][A[A


epoch: 3loss: 0.03269965574145317




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.80      0.81      0.81       867
           1       0.75      0.74      0.74       656

    accuracy                           0.78      1523
   macro avg       0.78      0.78      0.78      1523
weighted avg       0.78      0.78      0.78      1523

epoch: 4trrain




  0%|          | 1/381 [00:00<02:16,  2.79it/s][A[A

  1%|          | 2/381 [00:00<02:16,  2.77it/s][A[A

  1%|          | 3/381 [00:01<02:18,  2.74it/s][A[A

  1%|          | 4/381 [00:01<02:19,  2.71it/s][A[A

  1%|▏         | 5/381 [00:01<02:18,  2.72it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.70it/s][A[A

  2%|▏         | 7/381 [00:02<02:18,  2.70it/s][A[A

  2%|▏         | 8/381 [00:02<02:18,  2.70it/s][A[A

  2%|▏         | 9/381 [00:03<02:18,  2.69it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.69it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.68it/s][A[A

  3%|▎         | 12/381 [00:04<02:17,  2.68it/s][A[A

  3%|▎         | 13/381 [00:04<02:17,  2.69it/s][A[A

  4%|▎         | 14/381 [00:05<02:16,  2.68it/s][A[A

  4%|▍         | 15/381 [00:05<02:16,  2.67it/s][A[A

  4%|▍         | 16/381 [00:05<02:16,  2.68it/s][A[A

  4%|▍         | 17/381 [00:06<02:15,  2.69it/s][A[A

  5%|▍         | 18/381 [00:06<02:15,  2.68it/s][A[A


epoch: 4loss: 0.3751208484172821




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.80      0.83      0.81       867
           1       0.76      0.73      0.74       656

    accuracy                           0.78      1523
   macro avg       0.78      0.78      0.78      1523
weighted avg       0.78      0.78      0.78      1523

epoch: 5trrain




  0%|          | 1/381 [00:00<02:13,  2.85it/s][A[A

  1%|          | 2/381 [00:00<02:15,  2.81it/s][A[A

  1%|          | 3/381 [00:01<02:16,  2.76it/s][A[A

  1%|          | 4/381 [00:01<02:17,  2.75it/s][A[A

  1%|▏         | 5/381 [00:01<02:17,  2.74it/s][A[A

  2%|▏         | 6/381 [00:02<02:17,  2.72it/s][A[A

  2%|▏         | 7/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.70it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.70it/s][A[A

  4%|▎         | 14/381 [00:05<02:16,  2.70it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 17/381 [00:06<02:15,  2.69it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.70it/s][A[A


epoch: 5loss: 0.013402575626969337




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.80      0.84      0.82       867
           1       0.77      0.71      0.74       656

    accuracy                           0.79      1523
   macro avg       0.79      0.78      0.78      1523
weighted avg       0.79      0.79      0.79      1523

epoch: 6trrain




  0%|          | 1/381 [00:00<02:13,  2.84it/s][A[A

  1%|          | 2/381 [00:00<02:15,  2.80it/s][A[A

  1%|          | 3/381 [00:01<02:16,  2.76it/s][A[A

  1%|          | 4/381 [00:01<02:17,  2.74it/s][A[A

  1%|▏         | 5/381 [00:01<02:17,  2.73it/s][A[A

  2%|▏         | 6/381 [00:02<02:17,  2.72it/s][A[A

  2%|▏         | 7/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.71it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.69it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.70it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.70it/s][A[A

  4%|▎         | 14/381 [00:05<02:16,  2.70it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 17/381 [00:06<02:14,  2.70it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.70it/s][A[A


epoch: 6loss: 0.18471141159534454




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.77      0.91      0.83       867
           1       0.84      0.65      0.73       656

    accuracy                           0.79      1523
   macro avg       0.80      0.78      0.78      1523
weighted avg       0.80      0.79      0.79      1523

epoch: 7trrain




  0%|          | 1/381 [00:00<02:14,  2.83it/s][A[A

  1%|          | 2/381 [00:00<02:15,  2.80it/s][A[A

  1%|          | 3/381 [00:01<02:17,  2.75it/s][A[A

  1%|          | 4/381 [00:01<02:17,  2.74it/s][A[A

  1%|▏         | 5/381 [00:01<02:17,  2.73it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 7/381 [00:02<02:17,  2.72it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 11/381 [00:04<02:16,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.70it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.70it/s][A[A

  4%|▎         | 14/381 [00:05<02:17,  2.68it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.71it/s][A[A

  4%|▍         | 16/381 [00:05<02:14,  2.71it/s][A[A

  4%|▍         | 17/381 [00:06<02:14,  2.70it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.71it/s][A[A


epoch: 7loss: 0.24139034748077393




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.80      0.83      0.81       867
           1       0.76      0.73      0.74       656

    accuracy                           0.78      1523
   macro avg       0.78      0.78      0.78      1523
weighted avg       0.78      0.78      0.78      1523

epoch: 8trrain




  0%|          | 1/381 [00:00<02:15,  2.81it/s][A[A

  1%|          | 2/381 [00:00<02:16,  2.78it/s][A[A

  1%|          | 3/381 [00:01<02:17,  2.74it/s][A[A

  1%|          | 4/381 [00:01<02:18,  2.73it/s][A[A

  1%|▏         | 5/381 [00:01<02:18,  2.72it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 7/381 [00:02<02:18,  2.70it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.70it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.69it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:17,  2.69it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.69it/s][A[A

  4%|▎         | 14/381 [00:05<02:16,  2.70it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 17/381 [00:06<02:14,  2.70it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.70it/s][A[A


epoch: 8loss: 0.0009091615793295205




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.81      0.79      0.80       867
           1       0.73      0.76      0.74       656

    accuracy                           0.78      1523
   macro avg       0.77      0.77      0.77      1523
weighted avg       0.78      0.78      0.78      1523

epoch: 9trrain




  0%|          | 1/381 [00:00<02:14,  2.83it/s][A[A

  1%|          | 2/381 [00:00<02:16,  2.79it/s][A[A

  1%|          | 3/381 [00:01<02:17,  2.75it/s][A[A

  1%|          | 4/381 [00:01<02:17,  2.73it/s][A[A

  1%|▏         | 5/381 [00:01<02:18,  2.72it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 7/381 [00:02<02:18,  2.70it/s][A[A

  2%|▏         | 8/381 [00:02<02:18,  2.70it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.70it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.69it/s][A[A

  4%|▎         | 14/381 [00:05<02:16,  2.69it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.69it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.69it/s][A[A

  4%|▍         | 17/381 [00:06<02:15,  2.68it/s][A[A

  5%|▍         | 18/381 [00:06<02:15,  2.69it/s][A[A


epoch: 9loss: 0.0014937877422198653




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.80      0.83      0.81       867
           1       0.77      0.72      0.74       656

    accuracy                           0.78      1523
   macro avg       0.78      0.78      0.78      1523
weighted avg       0.78      0.78      0.78      1523

epoch: 10trrain




  0%|          | 1/381 [00:00<02:14,  2.82it/s][A[A

  1%|          | 2/381 [00:00<02:15,  2.79it/s][A[A

  1%|          | 3/381 [00:01<02:17,  2.75it/s][A[A

  1%|          | 4/381 [00:01<02:17,  2.74it/s][A[A

  1%|▏         | 5/381 [00:01<02:17,  2.73it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 7/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 9/381 [00:03<02:18,  2.69it/s][A[A

  3%|▎         | 10/381 [00:03<02:16,  2.71it/s][A[A

  3%|▎         | 11/381 [00:04<02:16,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:17,  2.69it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.70it/s][A[A

  4%|▎         | 14/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 17/381 [00:06<02:14,  2.70it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.70it/s][A[A


epoch: 10loss: 0.01809924840927124




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.77      0.91      0.84       867
           1       0.84      0.65      0.73       656

    accuracy                           0.80      1523
   macro avg       0.81      0.78      0.78      1523
weighted avg       0.80      0.80      0.79      1523

epoch: 11trrain




  0%|          | 1/381 [00:00<02:14,  2.82it/s][A[A

  1%|          | 2/381 [00:00<02:15,  2.79it/s][A[A

  1%|          | 3/381 [00:01<02:17,  2.75it/s][A[A

  1%|          | 4/381 [00:01<02:17,  2.74it/s][A[A

  1%|▏         | 5/381 [00:01<02:18,  2.72it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 7/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:17,  2.69it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.70it/s][A[A

  4%|▎         | 14/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 15/381 [00:05<02:16,  2.69it/s][A[A

  4%|▍         | 16/381 [00:05<02:16,  2.68it/s][A[A

  4%|▍         | 17/381 [00:06<02:15,  2.70it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.69it/s][A[A


epoch: 11loss: 0.024718236178159714




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.80      0.79      0.79       867
           1       0.72      0.73      0.73       656

    accuracy                           0.76      1523
   macro avg       0.76      0.76      0.76      1523
weighted avg       0.77      0.76      0.77      1523

epoch: 12trrain




  0%|          | 1/381 [00:00<02:14,  2.83it/s][A[A

  1%|          | 2/381 [00:00<02:16,  2.78it/s][A[A

  1%|          | 3/381 [00:01<02:16,  2.76it/s][A[A

  1%|          | 4/381 [00:01<02:17,  2.75it/s][A[A

  1%|▏         | 5/381 [00:01<02:17,  2.73it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.72it/s][A[A

  2%|▏         | 7/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 8/381 [00:02<02:18,  2.70it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.71it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.71it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.71it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.70it/s][A[A

  4%|▎         | 14/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 17/381 [00:06<02:15,  2.69it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.70it/s][A[A


epoch: 12loss: 0.04814837500452995




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.81      0.76      0.78       867
           1       0.71      0.76      0.73       656

    accuracy                           0.76      1523
   macro avg       0.76      0.76      0.76      1523
weighted avg       0.76      0.76      0.76      1523

epoch: 13trrain




  0%|          | 1/381 [00:00<02:11,  2.89it/s][A[A

  1%|          | 2/381 [00:00<02:13,  2.84it/s][A[A

  1%|          | 3/381 [00:01<02:17,  2.75it/s][A[A

  1%|          | 4/381 [00:01<02:16,  2.76it/s][A[A

  1%|▏         | 5/381 [00:01<02:16,  2.74it/s][A[A

  2%|▏         | 6/381 [00:02<02:18,  2.71it/s][A[A

  2%|▏         | 7/381 [00:02<02:17,  2.72it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.70it/s][A[A

  3%|▎         | 10/381 [00:03<02:17,  2.71it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.69it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.70it/s][A[A

  4%|▎         | 14/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.69it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.69it/s][A[A

  4%|▍         | 17/381 [00:06<02:15,  2.69it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.69it/s][A[A


epoch: 13loss: 0.0053279162384569645




  0%|          | 0/381 [00:00<?, ?it/s][A[A

              precision    recall  f1-score   support

           0       0.80      0.82      0.81       867
           1       0.76      0.73      0.74       656

    accuracy                           0.78      1523
   macro avg       0.78      0.78      0.78      1523
weighted avg       0.78      0.78      0.78      1523

epoch: 14trrain




  0%|          | 1/381 [00:00<02:12,  2.86it/s][A[A

  1%|          | 2/381 [00:00<02:14,  2.82it/s][A[A

  1%|          | 3/381 [00:01<02:16,  2.77it/s][A[A

  1%|          | 4/381 [00:01<02:16,  2.76it/s][A[A

  1%|▏         | 5/381 [00:01<02:17,  2.74it/s][A[A

  2%|▏         | 6/381 [00:02<02:17,  2.72it/s][A[A

  2%|▏         | 7/381 [00:02<02:17,  2.72it/s][A[A

  2%|▏         | 8/381 [00:02<02:17,  2.71it/s][A[A

  2%|▏         | 9/381 [00:03<02:17,  2.71it/s][A[A

  3%|▎         | 10/381 [00:03<02:16,  2.71it/s][A[A

  3%|▎         | 11/381 [00:04<02:17,  2.70it/s][A[A

  3%|▎         | 12/381 [00:04<02:16,  2.71it/s][A[A

  3%|▎         | 13/381 [00:04<02:16,  2.70it/s][A[A

  4%|▎         | 14/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 15/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 16/381 [00:05<02:15,  2.70it/s][A[A

  4%|▍         | 17/381 [00:06<02:14,  2.70it/s][A[A

  5%|▍         | 18/381 [00:06<02:14,  2.71it/s][A[A


KeyboardInterrupt: ignored

In [6]:
for i, dev_batch in enumerate(dev_loader):
    sentence, masks, tags = dev_batch

In [8]:
masks

tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [20]:
print(classification_report(true, pred))

              precision    recall  f1-score   support

           0       0.80      0.82      0.81       867
           1       0.76      0.73      0.74       656

    accuracy                           0.78      1523
   macro avg       0.78      0.78      0.78      1523
weighted avg       0.78      0.78      0.78      1523



In [21]:
model_name = '/content/drive/My Drive/KAGGLE/nlp/twitter' + str(epochs) + ".pkl"
torch.save(model.state_dict(), model_name)

In [None]:
import torch
model =bigru_attention('/content/drive/My Drive/chinese_roberta_wwm_ext_pytorch.zip_files', tagset_size, 768, 200, 1,
                      dropout_ratio=0.4, dropout1=0.4, use_cuda = True)
model.load_state_dict(torch.load('/content/drive/My Drive/Colab Notebooks/kashgari/ka re/50.pkl'))
model.cuda()

  "num_layers={}".format(dropout, num_layers))


bigru_attention(
  (word_embeds): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
          