### 当前实验模型内容

1. Forward RNN

In [None]:
import os
import torch
import torch.nn as nn
import torchtext
from tensorboardX import SummaryWriter

from torchtext.data import NestedField, Field, RawField
from model import BiAttention, EncoderRNN, SelfAttention, EmbeddingLayer
from dataset import DataHandler

# Expose only GPU index 2 to this process.
# NOTE(review): this assignment runs after `import torch`; presumably it still
# takes effect because CUDA is initialised lazily — confirm, or move it above
# the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

In [None]:
class Config:
    """Hyper-parameters and file locations for the forward-RNN SimpleQANet run."""

    def __init__(self):
        # Model / optimisation settings.
        self.hidden = 100
        # 300 matches the GloVe dimension requested when the vocab is built;
        # the extra 100 is presumably the charNGram vector size — TODO confirm.
        self.embedding_dim = 300 + 100
        self.lr = 1e-4
        self.epochs = 30
        # Passed as `fix_length` to the support-side fields; None means no cap.
        self.fix_length = None

        # Logging / run naming.
        self.log_dir = './logs'
        self.model_name = 'simpleQANet_forwardrnn'
        # Handed to train() as the number of single-example backward passes
        # accumulated per optimizer step.
        self.batch_size = 1

        # Cached torchtext examples.
        self.train_data = './data/train_example.pt'
        self.dev_data = './data/dev_example.pt'

        # Cached vocab objects; set either to None to rebuild that vocab
        # from the datasets instead of loading it.
        self.word_vocab = './data/glove_vocab.pt'
        self.charNGram_vocab = './data/charNGram_vocab.pt'


config = Config()
# Fall back to the CPU when no CUDA device is usable so the notebook still
# runs (slowly) instead of failing when the model is moved to the device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [None]:
# The run directory name spells out the key hyper-parameters of this run.
run_name = (
    f"{config.model_name}_epochs_{config.epochs}_lr_{config.lr}"
    f"_batchsize_{config.batch_size}_fixlength_{config.fix_length}"
)
save_path = os.path.join(config.log_dir, run_name)
config.save_path = save_path
print(save_path)

### Define Fields

In [None]:
def _spacy_text_field(fix_length=None):
    """Return a batch-first, lower-casing, spaCy-tokenised sequential Field."""
    return Field(sequential=True, batch_first=True, lower=True,
                 tokenize="spacy", fix_length=fix_length)


# Query / candidate side: no fixed length.
word_field = _spacy_text_field()
multi_word_field = NestedField(word_field)
charNGram_field = _spacy_text_field()
multi_charNGram_field = NestedField(charNGram_field)

# Support side: identical settings plus the configured fix_length.
word_field_sup = _spacy_text_field(config.fix_length)
multi_word_field_sup = NestedField(word_field_sup)
charNGram_field_sup = _spacy_text_field(config.fix_length)
multi_charNGram_field_sup = NestedField(charNGram_field_sup)

# The example id is carried through untouched and is not a target.
raw = RawField()
raw.is_target = False

# Labels are already numeric, so no vocab is needed.
label_field = Field(sequential=False, is_target=True, use_vocab=False)

# Record key -> (attribute name, field). Keys with two entries are exposed
# twice, once per vocabulary (GloVe words and charNGram).
dict_field = {
    'id': ('id', raw),
    'supports': [
        ('s_glove', multi_word_field_sup),
        ('s_charNGram', multi_charNGram_field_sup),
    ],
    'query': [
        ('q_glove', word_field),
        ('q_charNGram', charNGram_field),
    ],
    'candidates': [
        ('c_glove', multi_word_field),
        ('c_charNGram', multi_charNGram_field),
    ],
    'label': ('label', label_field),
}

In [None]:
# DataHandler is project code (dataset.py); presumably it loads the train/dev
# examples from the two configured paths using the field mapping in
# dict_field — verify against dataset.py. Its .trainset / .valset attributes
# and get_*_iter() methods are used in the cells below.
data_handler = DataHandler(config.train_data, config.dev_data, dict_field)

# One-off export of the parsed examples to the cache files named in Config:
# torch.save(data_handler.trainset.examples, './data/train_example.pt')
# torch.save(data_handler.valset.examples, './data/dev_example.pt')

### Build Vocab

In [None]:
# For each vocabulary: build it from the train + dev sets (attaching the
# pretrained vectors) when no cached file is configured, otherwise load the
# cached vocab object.
# NOTE: torch.load unpickles the file — only point these paths at trusted files.
if config.charNGram_vocab is None:
    charNGram_field_sup.build_vocab(
        data_handler.trainset,
        data_handler.valset,
        vectors=torchtext.vocab.CharNGram(),
    )
else:
    charNGram_vocab = torch.load(config.charNGram_vocab)
    charNGram_field_sup.vocab = charNGram_vocab

if config.word_vocab is None:
    word_field_sup.build_vocab(
        data_handler.trainset,
        data_handler.valset,
        vectors=torchtext.vocab.GloVe(dim=300, name='6B'),
    )
else:
    word_vocab = torch.load(config.word_vocab)
    word_field_sup.vocab = word_vocab

# Query/candidate fields reuse the support-side vocabularies so every input
# is numericalised with the same indices.
charNGram_field.vocab = charNGram_field_sup.vocab
word_field.vocab = word_field_sup.vocab

# One-off export of the vocab objects to the cache files named in Config:
# torch.save(word_field.vocab, './data/glove_vocab.pt')
# torch.save(charNGram_field.vocab, './data/charNGram_vocab.pt')

### Get data_iter

In [None]:
# Both iterators are requested with batch_size=1 regardless of
# config.batch_size: SimpleQANet.forward squeezes dimension 0 of the
# support/candidate tensors, and train() groups config.batch_size consecutive
# examples per optimizer step via gradient accumulation.
train_iter = data_handler.get_train_iter(batch_size=1)
val_iter = data_handler.get_val_iter(batch_size=1)

In [None]:
# Pull a single validation batch for inspection and for the model smoke test
# further down; the bare name on the last line displays it in the notebook.
batch = next(iter(val_iter))
batch

### Define Model

In [None]:
class SimpleQANet(nn.Module):
    """Score answer candidates against a query and a set of support documents.

    Pipeline implemented in ``forward``: embed -> shared RNN encoder ->
    support/query bi-attention -> self-attention pooling per support ->
    RNN over the pooled support vectors -> self-attention pooling ->
    bilinear score against self-attention-pooled candidates.

    Args:
        config: object providing ``hidden`` and ``embedding_dim``.
        word_vectors: pretrained word vectors handed to ``EmbeddingLayer``.
        charNGram_vectors: pretrained charNGram vectors handed to
            ``EmbeddingLayer``.
        device: device the module is moved to on construction and that all
            inputs are moved to in ``forward``.
    """

    def __init__(self, config, word_vectors, charNGram_vectors, device):
        super().__init__()
        self.config = config
        self.device = device

        # NOTE(review): the positional arguments of EncoderRNN / BiAttention /
        # SelfAttention follow the project's `model` module, which is not
        # visible here — check their signatures before changing any of them.
        self.embedding_layer = EmbeddingLayer(word_vectors, charNGram_vectors)
        self.rnn = EncoderRNN(config.embedding_dim, config.hidden, 1, True, True, 0.2, False)

        self.qc_att = BiAttention(config.hidden*2, 0.2)
        self.linear_1 = nn.Sequential(
            nn.Linear(config.hidden*8, config.hidden),
            nn.ReLU()
        )

        self.rnn_2 = EncoderRNN(config.hidden, config.hidden, 1, False, True, 0.2, False)

        self.self_att = SelfAttention(config.hidden*2, config.hidden*2, 0.2)

        self.forward_rnn = EncoderRNN(config.hidden*2, config.hidden, 1, False, True, 0.2, False)

        self.self_att_2 = SelfAttention(config.hidden*2, config.hidden*2, 0.2)

        self.self_att_c = SelfAttention(config.hidden*2, config.hidden*2, 0.2)

        # No bias: linear_2 followed by the matrix product in forward() forms
        # a bilinear score between the pooled supports and each candidate.
        self.linear_2 = nn.Linear(config.hidden*2, config.hidden*2, bias=False)
        self.to(device)

    def forward(self, batch, return_label=True):
        """Compute candidate scores for one example.

        Args:
            batch: object with tensor attributes ``q_glove``, ``q_charNGram``,
                ``s_glove``, ``s_charNGram``, ``c_glove``, ``c_charNGram`` and,
                when ``return_label`` is true, ``label``. The query attributes
                may instead be ``(tensor, lengths)`` pairs; the second item is
                ignored. Support and candidate tensors are expected to carry a
                leading dimension of size 1, which is squeezed away.
            return_label: when true, also return ``batch.label`` moved to the
                model device.

        Returns:
            ``score`` (one row with one column per candidate, going by the
            original shape notes), or ``(score, label)`` when ``return_label``
            is true.
        """
        q_glove, q_charNGram = batch.q_glove, batch.q_charNGram
        # isinstance (rather than an exact type check) also accepts tuple
        # subclasses such as namedtuples.
        if isinstance(q_glove, tuple):
            q_glove, _ = q_glove
            q_charNGram, _ = q_charNGram

        q_glove = q_glove.to(self.device)
        q_charNGram = q_charNGram.to(self.device)

        # Drop the size-1 leading dimension so that supports / candidates
        # become the leading dimension for the encoder.
        s_glove = batch.s_glove.squeeze(0).to(self.device)
        s_charNGram = batch.s_charNGram.squeeze(0).to(self.device)

        c_glove = batch.c_glove.squeeze(0).to(self.device)
        c_charNGram = batch.c_charNGram.squeeze(0).to(self.device)

        q_out = self.embedding_layer(q_glove, q_charNGram)
        s_out = self.embedding_layer(s_glove, s_charNGram)
        c_out = self.embedding_layer(c_glove, c_charNGram)
        # print(f'question, supports, candidates: {q_out.shape}, {s_out.shape}, {c_out.shape}')

        # The same encoder is shared by query, candidates and supports.
        q_out = self.rnn(q_out)
        c_out = self.rnn(c_out)
        s_out = self.rnn(s_out)

        # Repeat the single query along dim 0 so it pairs with every support.
        support_len = s_out.size(0)
        q_out = q_out.expand(support_len, q_out.size(1), q_out.size(2))

        # s_out: [supports_len, seq_len, hidden*2], q_out: [support_len, seq_len, hidden*2]
        output = self.qc_att(s_out, q_out)
        output = self.linear_1(output)
        output = self.rnn_2(output)

        # self-attention pooling
        # [support_len, hidden*2]
        output = self.self_att(output)
        # Add a leading dimension of size 1 so forward_rnn runs over the
        # pooled support vectors.
        output = output.unsqueeze(0)
        output = self.forward_rnn(output)

        # [1, hidden*2]
        output = self.self_att_2(output)

        # [candidate_len, hidden*2]
        c_out = self.self_att_c(c_out)

        # Score [1, candidates]
        out1 = self.linear_2(output)
        score = torch.mm(out1, c_out.transpose(0, 1))

        if return_label:
            label = batch.label.to(self.device)
            return score, label
        return score

#### test model

In [None]:
# Smoke test: build the model from the pretrained vectors attached to the
# vocabularies and push the sample batch through it once.
model = SimpleQANet(
    config=config,
    word_vectors=word_field.vocab.vectors,
    charNGram_vectors=charNGram_field.vocab.vectors,
    device=device,
)
score, label = model(batch)
print(score.shape, label.shape)

In [None]:
from utils import AverageMeter

def train(epoch, data_iter, model, criterion, optimizer, batch_size=1):
    """Run one training epoch with gradient accumulation.

    Every item yielded by ``data_iter`` gets its own forward/backward pass;
    the optimizer steps once per ``batch_size`` items. Gradients still pending
    when the iterator is exhausted are applied before returning, so they can
    no longer leak into the first step of the next epoch.

    Args:
        epoch: epoch index, used only in the progress message.
        data_iter: iterable of batches accepted by ``model``; must support
            ``len()`` for the progress message.
        model: callable returning ``(score, label)`` for a batch.
        criterion: loss function called as ``criterion(score, label)``.
        optimizer: optimizer over the model's parameters.
        batch_size: number of backward passes accumulated per optimizer step.

    Returns:
        ``(losses.avg, acces.avg)`` from the two ``AverageMeter`` instances
        that track the per-item loss and the per-item count of correct
        predictions.
    """
    losses = AverageMeter()
    acces = AverageMeter()
    model.train()
    #model.embedding_layer.eval()
    pending = 0  # backward passes accumulated since the last optimizer step
    for idx, batch in enumerate(data_iter):
        score, label = model(batch)

        # Scale the loss so the accumulated gradient is the mean over the group.
        loss = criterion(score, label) / batch_size
        loss.backward()
        pending += 1
        if pending == batch_size:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        # Undo the scaling so the meter tracks the raw per-item loss.
        losses.update(loss.item()*batch_size)

        pred = score.argmax(1)
        acc = pred.eq(label).sum().item()
        acces.update(acc)
        if (idx+1) % (batch_size*100) == 0:
            print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')

    # Flush a trailing partial group. Its gradient was still scaled by
    # 1/batch_size, so this last step is proportionally smaller.
    if pending:
        optimizer.step()
        optimizer.zero_grad()
    return losses.avg, acces.avg

def val(epoch, data_iter, model, criterion):
    """Evaluate ``model`` on every item of ``data_iter`` without gradients.

    Args:
        epoch: epoch index, used only in the progress message.
        data_iter: iterable of batches accepted by ``model``; must support
            ``len()`` for the progress message.
        model: callable returning ``(score, label)`` for a batch.
        criterion: loss function called as ``criterion(score, label)``.

    Returns:
        ``(losses.avg, acces.avg)`` from the two ``AverageMeter`` instances
        that track the per-item loss and the per-item count of correct
        predictions.
    """
    model.eval()
    losses, acces = AverageMeter(), AverageMeter()
    with torch.no_grad():
        for idx, batch in enumerate(data_iter):
            score, label = model(batch)
            losses.update(criterion(score, label).item())

            # Count the items whose top-scoring candidate matches the label.
            n_correct = score.argmax(1).eq(label).sum().item()
            acces.update(n_correct)

            if idx % 100 != 0:
                continue
            print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')
    return losses.avg, acces.avg

In [None]:
# Only parameters that still require gradients are handed to Adam, so any
# frozen parameters are left out of the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=config.lr)

criterion = nn.CrossEntropyLoss()

# Quick manual checks of a single epoch:
#train(0, train_iter, model, criterion, optimizer, batch_size=config.batch_size)
# val(0, val_iter, model,criterion)

In [None]:
# exist_ok=True replaces the separate exists() check, which could race with
# another process creating the same directory.
os.makedirs(config.save_path, exist_ok=True)
writer = SummaryWriter(config.save_path)

best_acc = 0.0
try:
    for epoch in range(config.epochs):
        train_loss, train_acc = train(epoch, train_iter, model, criterion, optimizer,
                                      config.batch_size)
        val_loss, val_acc = val(epoch, val_iter, model, criterion)

        # Scalars are logged against 1-based epoch numbers.
        writer.add_scalar('train_loss', train_loss, epoch+1)
        writer.add_scalar('val_loss', val_loss, epoch+1)
        writer.add_scalar('train_acc', train_acc, epoch+1)
        writer.add_scalar('val_acc', val_acc, epoch+1)

        state = {
            'val_acc': val_acc,
            'train_acc': train_acc,
            'epoch': epoch,
            'model': model.state_dict()
        }
        # Overwritten after every epoch. The file name (including its
        # spelling) is kept as-is so existing checkpoints still resolve.
        torch.save(state, os.path.join(config.save_path, 'lastest.pth'))
        if val_acc > best_acc:
            best_acc = val_acc
            # Same directory as the latest checkpoint: config.save_path is
            # used consistently instead of the separate global save_path.
            torch.save(state, os.path.join(config.save_path, 'best.pth'))
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()