In [1]:
import import_ipynb
from mydata import mydata
import torch
import torch.nn as nn
from torch import optim
from transformers import *
import pandas as pd
import ast
import copy
import os
from time import strftime,gmtime
from opencc import OpenCC
import pyprind

# Model
class bertwwm(nn.Module):
    """Binary classifier built on the ``hfl/chinese-bert-wwm`` encoder.

    The encoder's pooled output is fed through
    ``Linear(768, 1) -> Dropout(0.1) -> Sigmoid``, yielding one score in
    [0, 1] per input sequence.
    """

    def __init__(self):
        super(bertwwm, self).__init__()
        self.bert_model = BertModel.from_pretrained("hfl/chinese-bert-wwm")
        # The layer order is kept unchanged on purpose: parameter keys such as
        # "bi_decoder.0.weight" must stay stable so that previously saved
        # state dicts can still be loaded.
        self.bi_decoder = nn.Sequential(
            nn.Linear(768, 1),
            nn.Dropout(0.1),
            nn.Sigmoid(),
        )

    def forward(self, input_ids):
        """Score each sequence in ``input_ids``.

        Args:
            input_ids: tensor of token ids; assumed to be shaped
                (batch, seq_len) as expected by the BERT encoder.

        Returns:
            1-D tensor with one sigmoid score per sequence, i.e. shape (1,)
            when a single sequence is passed.
        """
        # Index 1 of the encoder output is the pooled [CLS] representation.
        # Selecting it by position from the front (rather than slicing from
        # the end) stays correct even if the encoder is configured to append
        # hidden states or attentions to its output.
        pooled_output = self.bert_model(input_ids)[1]
        # reshape(-1) equals the former reshape(1) for a single sequence and
        # additionally supports batches of more than one sequence.
        return self.bi_decoder(pooled_output).reshape(-1)

importing Jupyter notebook from mydata.ipynb


In [2]:
def test(model, data, device, threshold=0.4):
    """Evaluate ``model`` on ``data.test``.

    Each item of ``data.test`` is unpacked as ``(content, name)``. The target
    label is 0 when ``name`` is empty and 1 otherwise.

    Args:
        model: module mapping a (1, seq_len) tensor of token ids to a
            1-element tensor of probabilities.
        data: object exposing a sized, iterable ``test`` attribute.
        device: device on which the token ids are placed before the forward
            pass.
        threshold: decision boundary; a prediction counts as correct when it
            is strictly below the threshold for label 0 or strictly above it
            for label 1. Defaults to the previously hard-coded 0.4.

    Returns:
        Tuple ``(loss, acc)``: the BCE loss averaged over ``len(data.test)``
        (a tensor) and the number of correctly classified samples (an int,
        not a ratio).

    Raises:
        ZeroDivisionError: if ``data.test`` is empty.
    """
    print('Testing...')
    loss, acc = 0, 0
    test_num = len(data.test)
    criterion = nn.BCELoss()
    tokenizer = BertTokenizer.from_pretrained("hfl/chinese-bert-wwm")
    model.eval()
    print(f'Testing data: {test_num}')
    with torch.no_grad():
        for content, name in data.test:
            # Compare by value: `len(name) is 0` tested object identity,
            # which only works by accident of small-int caching.
            has_name = len(name) != 0
            label = torch.tensor([1.0 if has_name else 0.0])
            input_ids = torch.tensor(
                [tokenizer.encode(str(content), max_length=512, truncation=True)]
            ).to(device)
            # The loss is computed on CPU, matching where the label lives.
            pred = model(input_ids).to('cpu')
            loss += criterion(pred, label)
            score = pred.item()
            # A score exactly equal to the threshold is counted as wrong for
            # either label, as before.
            if (has_name and score > threshold) or (not has_name and score < threshold):
                acc += 1
    loss /= test_num
    return loss, acc

In [3]:
# lr_rate : learning rate
# w_d     : weight decay

def train(data, lr_rate, w_d, device, accum_steps=8):
    """Train a fresh ``bertwwm`` model for one pass over ``data.train``.

    Samples are processed one at a time; gradients are accumulated and the
    optimizer is stepped once every ``accum_steps`` samples. Each item of
    ``data.train`` is unpacked as ``(content, name)``; the target label is 0
    when ``name`` is empty and 1 otherwise.

    Args:
        data: object exposing a sized, iterable ``train`` attribute.
        lr_rate: learning rate passed to Adadelta.
        w_d: weight decay passed to Adadelta.
        device: device the model and token ids are placed on.
        accum_steps: number of samples whose gradients are accumulated before
            each optimizer step. Defaults to the previously hard-coded 8.

    Returns:
        A deep copy of the trained model's ``state_dict``.
    """
    model = bertwwm().to(device)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=lr_rate, weight_decay=w_d)
    criterion = nn.BCELoss()
    tokenizer = BertTokenizer.from_pretrained("hfl/chinese-bert-wwm")
    model.train()
    optimizer.zero_grad()
    pbar = pyprind.ProgBar(len(data.train))
    step = 0
    for step, (content, name) in enumerate(data.train, start=1):
        input_ids = torch.tensor(
            [tokenizer.encode(str(content), max_length=512, truncation=True)]
        ).to(device)
        # The loss is computed on CPU, matching where the label lives.
        pred = model(input_ids).to('cpu')
        # Compare by value rather than identity (`is 0`).
        label = torch.tensor([0.0 if len(name) == 0 else 1.0])

        batch_loss = criterion(pred, label)
        batch_loss.backward()
        # `step` is 1-based, so the first update happens after exactly
        # `accum_steps` samples (previously it fired one sample early).
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        pbar.update()

    # Apply the gradients of a trailing, incomplete accumulation window so
    # the last few samples are not silently discarded.
    if step % accum_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    best_model = copy.deepcopy(model.state_dict())
    return best_model

In [4]:

    
# Use the first GPU when present; otherwise fall back to CPU so the notebook
# still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data = mydata('./train_fix.csv')
print('Data load down')
# 'train' -> train a new model, save it, then evaluate it.
# anything else -> evaluate the existing checkpoint named below.
mode = 't'

# Compare strings by value; `mode is 'train'` tested object identity.
if mode == 'train':
    best_model = train(data, 0.006, 5e-5, device)

    os.makedirs('saved_models', exist_ok=True)
    modeltime = strftime('%H_%M_%S', gmtime())
    modelname = 'bertWWM_' + modeltime
    # Build the path once so the file that is saved is the file that is
    # loaded (the load previously omitted the ".pt" extension).
    model_path = f'saved_models/{modelname}.pt'
    torch.save(best_model, model_path)
    print(f'Train end, model name is {modelname}.pt')
else:
    modelname = 'bertWWM_06_52_12.pt.pt'
    model_path = f'saved_models/{modelname}'

# Evaluation is identical for both modes.
test_model = bertwwm().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
test_model.load_state_dict(torch.load(model_path, map_location=device))
test_loss, test_acc = test(test_model, data, device)
print(f'Test Loss is {test_loss} ACC is {test_acc}')
    
    
    

Data load down
Testing...
Testing data: 503
Test Loss is 0.02215447835624218 ACC is 500


In [5]:
#if __name__ == '__main__':
#    main()

In [6]:
# bertWWM_06_41_57.pt : 500/503
# bertWWM_06_52_12.pt.pt : 500/503