In [1]:
from torchtext import data
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from dataset import DataHandler, BertField
import torch.nn as nn
import torch
from model import BiAttention, EncoderRNN, SelfAttention
import os
import torchtext
# Expose only GPU index 2 to this process.
os.environ.update(CUDA_VISIBLE_DEVICES='2')

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
# Pre-built example files for the train and validation splits.
train_examples_path, val_examples_path = './train_examples.pt', './val_examples.pt'

In [3]:
# Lower-casing BERT WordPiece tokenizer; the GloVe-side fields below tokenize with the
# same `tokenizer.tokenize`, so the word view and the BERT view share one token split.
tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased-vocab.txt', do_lower_case=True)

# Supports are padded/truncated to one fixed length. The GloVe and BERT support views are
# concatenated position by position in the embedding layer, so both must use this value.
SUPPORT_LEN = 320

# query / answer: a single token sequence each
bert_field = BertField(tokenizer)
# candidates: a list of sequences, hence the nested field
multi_bert_field = data.NestedField(bert_field)

word_field = data.Field(batch_first=True, sequential=True, tokenize=tokenizer.tokenize, lower=True)  # query
multi_word_field = data.NestedField(word_field)

# supports: a list of fixed-length sequences
word_field_sup = data.Field(batch_first=True, sequential=True, tokenize=tokenizer.tokenize, lower=True,
                            fix_length=SUPPORT_LEN)
multi_word_field_sup = data.NestedField(word_field_sup)

bert_field_sup = BertField(tokenizer, fix_length=SUPPORT_LEN)
multi_bert_field_sup = data.NestedField(bert_field_sup)

# example id is kept as-is and is not a prediction target
raw = data.RawField()
raw.is_target = False

# integer index of the correct candidate; already numeric, so no vocabulary
label_field = data.Field(sequential=False, is_target=True, use_vocab=False)

# example key -> (batch attribute name, field); keys mapped to a list produce one batch
# attribute per (name, field) pair, i.e. a GloVe view and a BERT view of the same text
dict_field = {
    'id': ('id', raw),
    'supports': [('s_glove', multi_word_field_sup), ('s_bert', multi_bert_field_sup)],
    'query': [('q_glove', word_field), ('q_bert', bert_field)],
    'answer': [('a_glove', word_field), ('a_bert', bert_field)],
    'candidates': [('c_glove', multi_word_field), ('c_bert', multi_bert_field)],
    'label': ('label', label_field),
}

In [4]:
# Build the train/val datasets from the example files, using `dict_field` to map each
# example key to its (batch attribute, field) pairs. Presumably DataHandler loads the
# pre-built .pt example files (see its printed output) — confirm in dataset.py.
data_handler = DataHandler(train_examples_path, val_examples_path, dict_field)

load examples.pt  :./train_examples.pt, ./val_examples.pt


In [5]:
# Build the word vocabulary over both splits and attach 300-d GloVe (6B) vectors.
multi_word_field_sup.build_vocab(
    data_handler.trainset,
    data_handler.valset,
    vectors=torchtext.vocab.GloVe(name='6B', dim=300),
)
# Query / answer / candidate word ids share the support vocabulary, and from here on
# q_glove / a_glove batches are (ids, lengths) tuples.
word_field.vocab, word_field.include_lengths = multi_word_field_sup.vocab, True

In [6]:
# One example per batch: the model squeezes away the leading batch dimension of the
# nested support / candidate tensors, so larger batches are not supported.
train_iter, val_iter = (
    data_handler.get_train_iter(batch_size=1),
    data_handler.get_val_iter(batch_size=1),
)

### Embedding

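This layer still changes frequently, so for now it is kept in the notebook instead of being moved into a `.py` module. (原文：这一层需要频繁的改动，所以暂时不放在py文件中)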

In [7]:
class EmbeddingLayer(nn.Module):
    """Concatenate frozen word (GloVe) embeddings with frozen BERT features.

    Args:
        word_field: torchtext field whose ``vocab.vectors`` is used as the
            word-embedding matrix.
        bert_model_path: path handed to ``BertModel.from_pretrained``.
        use_all: if True, average the hidden states of every BERT encoder layer;
            otherwise use only the last layer.

    Both the word embeddings and all BERT parameters are frozen at construction.
    """

    def __init__(self, word_field, bert_model_path='./bert-base-uncased/', use_all=False):
        super(EmbeddingLayer, self).__init__()
        self.word_embedding_layer = nn.Embedding.from_pretrained(embeddings=word_field.vocab.vectors)
        self.bert_model = BertModel.from_pretrained(bert_model_path)
        self.use_all = use_all
        self.freeze()

    def freeze(self):
        """Turn off gradients for every BERT parameter and the word-embedding table."""
        for param in self.bert_model.parameters():
            param.requires_grad = False
        self.word_embedding_layer.weight.requires_grad = False

    def forward(self, word_tokens, bert_tokens, input_mask=None):
        """Embed one batch of token ids.

        Args:
            word_tokens: word-vocabulary ids, [batch_size, seq_len].
            bert_tokens: BERT WordPiece ids, [batch_size, seq_len]. The sequence
                length must equal that of ``word_tokens`` because the two embeddings
                are concatenated position by position.
            input_mask: optional BERT attention mask shaped like ``bert_tokens``
                (callers pass 1.0 for real tokens, 0.0 for padding); ``None`` leaves
                the choice to BERT's default.

        Returns:
            Tensor [batch_size, seq_len, glove_dim + bert_dim].
        """
        word_embeddings = self.word_embedding_layer(word_tokens)

        # encoded_layers: one [batch_size, seq_len, bert_dim] tensor per encoder layer
        encoded_layers, _ = self.bert_model(bert_tokens, attention_mask=input_mask)

        if self.use_all:
            # element-wise mean over the layer axis
            bert_embeddings = torch.stack(list(encoded_layers), dim=0).mean(dim=0)
        else:
            bert_embeddings = encoded_layers[-1]

        return torch.cat([word_embeddings, bert_embeddings], dim=-1)
        

In [8]:
class SimpleQANet(nn.Module):
    """Scores candidate answers for one query given its set of supporting documents.

    Pipeline: GloVe+BERT embeddings -> shared RNN encoder -> bi-attention between every
    support and the query -> linear+ReLU -> RNN -> self-attention pooling over tokens
    ([support_len, hidden*2]) and then over supports ([1, hidden*2]) -> dot-product
    score against each candidate's pooled representation.

    The forward pass handles a batch of exactly one example (see ``_to_device``).
    """

    def __init__(self, config, word_field):
        super(SimpleQANet, self).__init__()
        self.config = config
        self.use_cuda = config.use_cuda

        self.embedding_layer = EmbeddingLayer(word_field, config.bert_path, config.use_all)
        self.rnn = EncoderRNN(config.word_dim + config.bert_dim, config.hidden, 1, True, True, 0.2, False)

        self.qc_att = BiAttention(config.hidden*2, 0.2)
        self.linear_1 = nn.Sequential(
                nn.Linear(config.hidden*8, config.hidden),
                nn.ReLU()
        )

        self.rnn_2 = EncoderRNN(config.hidden, config.hidden, 1, False, True, 0.2, False)

        self.self_att = SelfAttention(config.hidden*2, config.hidden*2, 0.2)
        self.self_att_2 = SelfAttention(config.hidden*2, config.hidden*2, 0.2)

        self.self_att_c = SelfAttention(config.hidden*2, config.hidden*2, 0.2)

    def _to_device(self, tensor, squeeze=False):
        """Optionally drop the leading size-1 batch dim, then move to GPU if ``use_cuda``.

        Supports and candidates come from nested fields with a leading batch dimension
        of 1; squeezing it makes each support / candidate one row. The squeeze happens
        on CPU too, so both devices produce identical shapes.
        """
        if squeeze:
            tensor = tensor.squeeze(0)
        return tensor.cuda() if self.use_cuda else tensor

    def forward(self, batch):
        """Return candidate scores, shape [candidate_len, 1], for a one-example batch."""
        # word_field has include_lengths=True, so q_glove arrives as (ids, lengths)
        q_glove, _ = batch.q_glove
        q_glove = self._to_device(q_glove)
        q_bert = self._to_device(batch.q_bert)
        s_glove = self._to_device(batch.s_glove, squeeze=True)
        s_bert = self._to_device(batch.s_bert, squeeze=True)
        c_glove = self._to_device(batch.c_glove, squeeze=True)
        c_bert = self._to_device(batch.c_bert, squeeze=True)

        # BERT attention masks; token id 0 is treated as padding. Candidates of
        # different lengths are padded to a common length, so they are masked too.
        ques_mask = (q_bert > 0).float()
        support_mask = (s_bert > 0).float()
        cand_mask = (c_bert > 0).float()

        # Embedding: [rows, seq_len, word_dim + bert_dim]
        q_out = self.embedding_layer(q_glove, q_bert, ques_mask)
        s_out = self.embedding_layer(s_glove, s_bert, support_mask)
        c_out = self.embedding_layer(c_glove, c_bert, cand_mask)

        # Shared encoder (call order q, c, s kept as in the original code)
        q_out = self.rnn(q_out)
        c_out = self.rnn(c_out)
        s_out = self.rnn(s_out)

        # Bi-attention on supports and question: broadcast the single query to every support
        support_len = s_out.size(0)
        q_out = q_out.expand(support_len, q_out.size(1), q_out.size(2))
        ques_mask = ques_mask.expand(support_len, q_out.size(1))

        # s_out: [support_len, seq_len, hidden*2], q_out: [support_len, q_len, hidden*2]
        output = self.qc_att(s_out, q_out, ques_mask)
        output = self.linear_1(output)
        output = self.rnn_2(output)

        # Self-attention pooling: [support_len, hidden*2], then [1, hidden*2]
        output = self.self_att(output)
        output = self.self_att_2(output.unsqueeze(0))

        # [candidate_len, hidden*2]
        c_out = self.self_att_c(c_out)

        # Score layer: [candidate_len, 1]
        return torch.mm(c_out, torch.tanh(output.transpose(0, 1)))

In [9]:
class Config:
    """Hyper-parameters and paths for training SimpleQANet.

    The defaults reproduce the original run. Any attribute can be overridden by
    keyword, e.g. ``Config(lr=1e-4, use_cuda=False)``; an unknown keyword raises
    ``TypeError`` so typos are not silently ignored.
    """

    def __init__(self, **overrides):
        self.hidden = 100                     # hidden size handed to the RNN / attention layers
        self.word_dim = 300                   # GloVe vector size (GloVe is loaded with dim=300)
        self.bert_dim = 768                   # hidden size of bert-base
        self.use_cuda = True
        self.bert_path = './bert-base-uncased/'
        self.use_all = False                  # average all BERT layers instead of using the last
        self.lr = 5e-4
        self.epochs = 30
        self.log_dir = './logs'
        self.model_name = 'simpleQANet'

        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f'Config got an unexpected keyword argument {key!r}')
            setattr(self, key, value)

In [10]:
class AverageMeter(object):
    """Keep the latest value together with a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count = self.count + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.count

In [11]:
def _forward_batch(model, batch, cuda):
    """Run ``model`` on ``batch``; return (scores [batch, num_candidates], labels)."""
    score = model(batch)
    label = batch.label
    if cuda:
        label = label.cuda()
    # SimpleQANet returns [num_candidates, batch]; CrossEntropyLoss expects [batch, num_candidates]
    return score.transpose(0, 1), label


def _record_batch(losses, acces, loss, score, label):
    """Update the running meters, weighting each batch by its number of examples.

    ``acces`` receives the fraction of correct predictions in the batch, so its
    average is a per-example accuracy for any batch size.
    """
    n = label.size(0)
    losses.update(loss.item(), n)
    correct = score.argmax(1).eq(label).sum().item()
    acces.update(correct / n, n)


def _log_progress(epoch, idx, total, losses, acces):
    """Print the running loss / accuracy for the current epoch."""
    print(f'epoch:{epoch}, idx:{idx}/{total}, loss:{losses.avg}, acc:{acces.avg}')


def train(epoch, data_iter, model, criterion, optimizer, cuda, log_every=100):
    """Run one training epoch; return (mean loss, accuracy) over all examples seen.

    ``log_every`` controls how often (in batches) progress is printed.
    """
    losses = AverageMeter()
    acces = AverageMeter()
    # NOTE(review): model.train() is recursive, so it presumably also re-enables dropout
    # inside the frozen BERT; call model.embedding_layer.eval() here if that is unwanted.
    model.train()
    for idx, batch in enumerate(data_iter):
        score, label = _forward_batch(model, batch, cuda)
        loss = criterion(score, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _record_batch(losses, acces, loss, score, label)
        if idx % log_every == 0:
            _log_progress(epoch, idx, len(data_iter), losses, acces)
    return losses.avg, acces.avg


def val(epoch, data_iter, model, criterion, cuda, log_every=100):
    """Evaluate one epoch without gradients; return (mean loss, accuracy)."""
    losses = AverageMeter()
    acces = AverageMeter()
    model.eval()
    # forward pass and loss both run without building an autograd graph
    with torch.no_grad():
        for idx, batch in enumerate(data_iter):
            score, label = _forward_batch(model, batch, cuda)
            loss = criterion(score, label)

            _record_batch(losses, acces, loss, score, label)
            if idx % log_every == 0:
                _log_progress(epoch, idx, len(data_iter), losses, acces)
    return losses.avg, acces.avg

In [12]:
# Instantiate the model with default hyper-parameters and move it to the GPU if requested.
config = Config()
model = SimpleQANet(config, word_field)
model = model.cuda() if config.use_cuda else model

In [13]:
# Adam over the trainable parameters only; the frozen GloVe / BERT weights are skipped.
optimizer = torch.optim.Adam(
    params=(param for param in model.parameters() if param.requires_grad),
    lr=config.lr,
)
# The model's scores are treated as logits over the candidates.
criterion = nn.CrossEntropyLoss()

In [15]:
# The run directory name encodes the main hyper-parameters,
# e.g. ./logs/simpleQANet_epoch30_lr0.0005_useallFalse
run_name = f'{config.model_name}_epoch{config.epochs}_lr{config.lr}_useall{config.use_all}'
save_path = os.path.join(config.log_dir, run_name)
# exist_ok avoids the check-then-create race and keeps the cell safe to re-run
os.makedirs(save_path, exist_ok=True)

print(save_path)

./logs/simpleQANet_epoch30_lr0.0005_useallFalse


In [16]:
# TensorBoard event files for this run are written into save_path.
from tensorboardX import SummaryWriter

writer = SummaryWriter(save_path)

In [None]:
# Train for config.epochs epochs, log both splits to TensorBoard, and keep the weights
# with the best validation accuracy in <save_path>/best.pth.
best_acc = 0.0
try:
    for epoch in range(config.epochs):
        train_loss, train_acc = train(epoch, train_iter, model, criterion, optimizer, config.use_cuda)
        val_loss, val_acc = val(epoch, val_iter, model, criterion, config.use_cuda)

        for tag, value in (('train_loss', train_loss), ('val_loss', val_loss),
                           ('train_acc', train_acc), ('val_acc', val_acc)):
            writer.add_scalar(tag, value, epoch + 1)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), os.path.join(save_path, 'best.pth'))
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    # NOTE(review): re-run the SummaryWriter cell before training again in the same kernel.
    writer.close()

  alphas = self.softmax(alphas)  # (bsz, sent_len)


epoch:0, idx:0/43738, loss:3.268662929534912, acc:0.0
epoch:0, idx:100/43738, loss:2.4790367336556463, acc:0.21782178217821782
epoch:0, idx:200/43738, loss:2.4029624050492373, acc:0.21393034825870647
epoch:0, idx:300/43738, loss:2.425575031827752, acc:0.23255813953488372
epoch:0, idx:400/43738, loss:2.418265098376093, acc:0.24688279301745636


In [None]:
# Best validation accuracy reached during training (display only).
best_acc