### 当前实验模型内容

1. use mentions (dev_acc: 63.3)

2. add passage score

model_name | param | dev_acc
--- | --- | ---
use mentions | lr=1e-3, hidden=50 | 63.3
um-ps | lr=5e-4, hidden=50, dropout=0.2 | 64.94

In [None]:
import os
import torch
import torch.nn as nn
import torchtext
from tensorboardX import SummaryWriter
import random
import numpy as np

from torchtext.data import NestedField, Field, RawField
from model import *
from dataset import DataHandler
%load_ext autoreload

%autoreload 2
# Restrict this process to physical GPU 1. The variable is only honoured if it
# is set before CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

In [None]:
class Config:
    """Fixed settings for one experiment run.

    Holds model size, optimisation hyper-parameters, data and vocabulary
    paths, and logging options as plain instance attributes.
    """

    def __init__(self):
        # All settings are installed in a single call; attribute names, values
        # and insertion order match the individual assignments they replace.
        self.__dict__.update(
            # Model size and optimisation.
            hidden=50,
            # presumably 300-d word vectors + 100-d char n-gram vectors — TODO confirm
            embedding_dim=300 + 100,
            lr=5e-4,
            epochs=50,
            fix_length=None,
            # Logging and datasets.
            log_dir='./logs',
            model_name='CFC_um_ps_50',
            batch_size=4,
            train_data='./data/train_filter.pt',
            dev_data='./data/dev_filter.pt',
            # Pre-built vocabularies (set to None to rebuild them from the data).
            word_vocab='./data/glove_vocab.pt',
            charNGram_vocab='./data/charNGram_vocab.pt',
            # Regularisation and reproducibility.
            dropout=0.2,
            seed=1023,
        )
        
# Global experiment settings shared by the cells below.
config = Config()
# "cuda:0" is the first *visible* CUDA device; with CUDA_VISIBLE_DEVICES set to
# '1' earlier in this notebook that is physical GPU 1.
device = torch.device("cuda:0")


In [None]:
# Seed every random number generator in use (Python, NumPy, PyTorch CPU and
# all CUDA devices) with the configured seed for reproducibility.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)

In [None]:
# Run directory, named after the model and its key hyper-parameters:
# <log_dir>/<model_name>_lr_<lr>__hidden__<hidden>_batchsize_<batch_size>_p<dropout>
save_path = os.path.join(
    config.log_dir,
    f'{config.model_name}_lr_{config.lr}__hidden__{config.hidden}'
    f'_batchsize_{config.batch_size}_p{config.dropout}',
)
print(save_path)
# Stored on the config so later cells can find the run directory.
config.save_path = save_path

### Define Fields

In [None]:
# Word-level fields. `word_field` handles the query directly and, wrapped in a
# NestedField, the list of candidates.
word_field = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True) # query
multi_word_field = NestedField(word_field) 

# Same configuration for the support documents, but with an optional fixed
# per-document length taken from the config.
word_field_sup = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True, fix_length=config.fix_length)
multi_word_field_sup = NestedField(word_field_sup) 

# Parallel set of fields for the char n-gram view of the same text; they are
# kept separate from the word fields so they can carry a different vocabulary.
charNGram_field = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True) # query
multi_charNGram_field = NestedField(charNGram_field) 

charNGram_field_sup = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True, fix_length=config.fix_length)
multi_charNGram_field_sup = NestedField(charNGram_field_sup) 

# Pass-through field for values that must not be tokenised or numericalised.
raw = RawField()
# NOTE(review): set explicitly, presumably because the installed torchtext
# expects every field to expose `is_target` — confirm.
raw.is_target = False

# The label is already a number, so no vocabulary is built for it.
label_field = Field(sequential=False, is_target=True, use_vocab=False)

# Maps each key of the source data to one or more (attribute name, field)
# pairs. Text keys are read twice: once for the word view (`*_glove`) and once
# for the char n-gram view (`*_charNGram`).
dict_field = {
    'id': ('id', raw),
    'supports': [('s_glove', multi_word_field_sup), ('s_charNGram', multi_charNGram_field_sup)],
    'query': [('q_glove', word_field), ('q_charNGram', charNGram_field)],
    'candidates': [('c_glove', multi_word_field), ('c_charNGram', multi_charNGram_field)],
    'label': ('label', label_field),
    'mentions': ('mentions', raw),
    'para_label': ('para_label', raw)
}

In [None]:
# Project helper that wraps the train/dev data; presumably it loads both files
# using the field mapping above — see dataset.DataHandler to confirm.
data_handler = DataHandler(config.train_data, config.dev_data, dict_field)

# torch.save(data_handler.trainset.examples, './data/train_example.pt')
# torch.save(data_handler.valset.examples, './data/dev_example.pt')

In [None]:
from tqdm import tqdm

def add_mentions(examples):
    """Attach exact-match candidate mention spans to every example, in place.

    For each example, every candidate (a token sequence in ``example.c_glove``)
    is searched for in every support document (token sequences in
    ``example.s_glove``). The result is stored on ``example.mentions`` as one
    list per candidate, in candidate order, where each entry is
    ``[support_idx, start, end]`` with ``end`` exclusive.

    Token sequences are compared element-wise rather than through their
    space-joined strings, so a match always covers exactly ``len(candidate)``
    tokens. A candidate with no tokens gets an empty mention list instead of
    raising ``IndexError``.

    Args:
        examples: iterable of objects exposing ``c_glove`` and ``s_glove``.

    Returns:
        None. Each example is mutated by setting ``example.mentions``.
    """
    for example in tqdm(examples):
        supports = example.s_glove
        all_mentions = []

        for candidate in example.c_glove:
            mentions = []
            width = len(candidate)
            if width:
                target = list(candidate)
                first = target[0]
                for idx, support in enumerate(supports):
                    # Only start positions that leave room for the whole
                    # candidate can possibly match.
                    for start in range(len(support) - width + 1):
                        # Cheap first-token check before comparing the slice.
                        if support[start] != first:
                            continue
                        end = start + width
                        if list(support[start:end]) == target:
                            mentions.append([idx, start, end])
            all_mentions.append(mentions)

        example.mentions = all_mentions
        
def add_para_label(examples):
    """Label supporting paragraphs and keep only examples with a mentioned answer.

    For every example, the mentions of the gold candidate
    (``example.mentions[example.label]``) are looked up. Examples whose gold
    candidate has no mention are dropped. For the others, ``example.para_label``
    is set to a 0/1 list with one entry per support, where 1 marks a support
    containing at least one mention of the gold candidate.

    Args:
        examples: sequence of examples that already carry ``mentions``.

    Returns:
        The list of examples that were kept (same objects, now labelled).
    """
    kept = []
    for example in tqdm(examples):
        answer_spans = example.mentions[example.label]
        if not answer_spans:
            # The answer never appears verbatim in any support: skip it.
            continue

        para_label = [0] * len(example.s_glove)
        for span in answer_spans:
            # span[0] is the index of the support the mention was found in.
            para_label[span[0]] = 1

        example.para_label = para_label
        kept.append(example)

    print(f'before filter: {len(examples)}, after:{len(kept)}')
    return kept

#add_mentions(data_handler.valset.examples)
#add_mentions(data_handler.trainset.examples)

#train_filter_examples  = add_para_label(data_handler.trainset.examples)
#dev_filter_examples  = add_para_label(data_handler.valset.examples)

#torch.save(train_filter_examples, './data/train_filter.pt')
#torch.save(dev_filter_examples, './data/dev_filter.pt')        

In [None]:
#add_mentions(data_handler.valset.examples)
#add_mentions(data_handler.trainset.examples)

### Build Vocab

In [None]:
# Char n-gram vocabulary: reuse the cached vocab file when one is configured,
# otherwise build it from the train and dev sets with CharNGram vectors.
# NOTE: torch.load unpickles the file, so only load vocab files you trust.
if config.charNGram_vocab is not None:
    charNGram_vocab = torch.load(config.charNGram_vocab)
    charNGram_field_sup.vocab = charNGram_vocab
else:
    charNGram_field_sup.build_vocab(data_handler.trainset, data_handler.valset, 
                                          vectors=torchtext.vocab.CharNGram())

# Word vocabulary: same scheme, falling back to 300-d GloVe 6B vectors.
if config.word_vocab is not None:
    word_vocab = torch.load(config.word_vocab)
    word_field_sup.vocab = word_vocab
else:
    word_field_sup.build_vocab(data_handler.trainset, data_handler.valset, 
                                 vectors=torchtext.vocab.GloVe(dim=300,name='6B') )

# The query/candidate fields share the support fields' vocabularies so that
# all text is numericalised with the same indices.
word_field.vocab = word_field_sup.vocab
charNGram_field.vocab = charNGram_field_sup.vocab

# torch.save(word_field.vocab, './data/glove_vocab.pt')
# torch.save(charNGram_field.vocab, './data/charNGram_vocab.pt')

### Get data_iter

In [None]:
# Batch iterators over the train and dev sets, both using the configured
# batch size.
train_iter = data_handler.get_train_iter(batch_size=config.batch_size)
val_iter = data_handler.get_val_iter(batch_size=config.batch_size)

In [None]:
# Pull the first validation batch and keep it in the global `batch` (with
# `idx` == 0) for interactive inspection in later cells.
for idx, batch in enumerate(val_iter):
    break
batch

In [None]:
def get_para_label(batch):
    """Return the batch's paragraph labels as a zero-padded ``long`` tensor.

    Each entry of ``batch.para_label`` is a list of per-paragraph labels that
    may be shorter than the padded paragraph dimension of the batch. Every
    list is right-padded with zeros up to ``batch.s_glove.size(1)``; the
    original lists are left untouched.

    Args:
        batch: object exposing ``s_glove`` (a tensor) and ``para_label``
            (a list of label lists).

    Returns:
        A ``torch.long`` tensor with one row per example and
        ``batch.s_glove.size(1)`` columns.
    """
    width = batch.s_glove.size(1)
    padded_rows = [
        flags[:] + [0] * (width - len(flags))
        for flags in batch.para_label
    ]
    return torch.tensor(padded_rows, dtype=torch.long)

### Define Model

In [None]:
class SimpleQANet(nn.Module):
    """Multi-passage QA model scoring answer candidates and support passages.

    The query, the support passages and the candidates are embedded and
    encoded by a shared RNN. Supports are co-attended with the query and
    summarised; candidates are represented by their own encoding concatenated
    with features pooled from their mentions inside the supports.

    Args:
        config: object providing ``embedding_dim``, ``hidden`` and ``dropout``.
        word_vectors: pretrained word vectors passed to ``EmbeddingLayer``.
        charNGram_vectors: pretrained char n-gram vectors passed to
            ``EmbeddingLayer``.
        device: device the model is moved to and on which inputs are placed.
    """

    def __init__(self, config, word_vectors, charNGram_vectors, device):
        super(SimpleQANet, self).__init__()
        self.config = config
        self.device = device

        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation consumes the RNG identically from run to run.
        self.embedding_layer = EmbeddingLayer(word_vectors, charNGram_vectors)

        self.rnn = EncoderRNN(config.embedding_dim, config.hidden, 1, True, True, config.dropout, False)

        self.co_att = CoAttention(config.hidden*2, att_type=2, dropout=config.dropout)

        self.linear_1 = nn.Sequential(
            nn.Linear(config.hidden*4, config.hidden),
            nn.ReLU()
        )
        self.rnn2 = EncoderRNN(config.hidden, config.hidden, 1, True, True, config.dropout, False)

        self.word_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        self.word_att_q = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        self.p_score = BilinearSeqAttn(config.hidden*2, config.hidden*2, identity=False, dropout=config.dropout)

        self.pass_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)

        self.c_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)

        self.mention_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)

        #self.fusion = FusionLayer(config.hidden*2, dropout=config.dropout)
        self.max_pooling = PoolingLayer()

        # Projects the support summary to the width of the concatenated
        # [candidate encoding ; mention features] vector (2 * hidden*2).
        self.fc = nn.Linear(config.hidden*2, config.hidden*4)

        self.to(device)

    def _mention_features(self, batch, s_out, batch_size, c_len, s_len):
        """Pool encoded support tokens over each candidate's mention spans.

        For every example, candidate and support passage, the encodings of the
        candidate's mentions in that passage are pooled per mention with
        ``self.max_pooling`` and then averaged; passages without a mention
        keep an all-zero vector.

        Args:
            batch: batch exposing ``mentions``; ``batch.mentions[i]`` holds,
                per candidate, a list of ``[support_idx, start, end]`` spans.
            s_out: encoded supports, flattened to
                ``(batch_size * s_len, words, features)``.
            batch_size: number of examples in the batch.
            c_len: padded number of candidates.
            s_len: padded number of support passages.

        Returns:
            Tensor of shape ``(batch_size, c_len, s_len, features)`` on the
            same device and with the same dtype as ``s_out``.
        """
        feat_dim = s_out.size(-1)
        batch_c_m = []
        for i in range(batch_size):
            mentions = batch.mentions[i]
            # Allocate next to s_out so no per-mention host/device copies are
            # needed while accumulating.
            c_ms = s_out.new_zeros(c_len, s_len, feat_dim)
            for idx, c_mention in enumerate(mentions):
                # Group the pooled mention vectors by the passage they occur in.
                c_m_dict = {}
                for mention in c_mention:
                    m = s_out[i*s_len + mention[0]][mention[1]:mention[2]]
                    m = self.max_pooling(m.unsqueeze(0)).squeeze()
                    c_m_dict.setdefault(mention[0], []).append(m)
                c_m = s_out.new_zeros(s_len, feat_dim)
                for key, pooled in c_m_dict.items():
                    # Average of all mentions of this candidate in passage `key`.
                    for m in pooled:
                        c_m[key] += m
                    c_m[key] /= len(pooled)
                c_ms[idx] = c_m
            batch_c_m.append(c_ms)
        return torch.stack(batch_c_m)

    def forward(self, batch, return_label = True):
        """Score every candidate and every support passage of a batch.

        Args:
            batch: batch exposing ``q_glove``/``q_charNGram`` (optionally as
                ``(tensor, lengths)`` tuples), ``s_glove``/``s_charNGram``,
                ``c_glove``/``c_charNGram`` and ``mentions``; with
                ``return_label`` also ``label`` and ``para_label``.
            return_label: when true, also return the gold labels moved to the
                model's device.

        Returns:
            ``(score, P_scores)`` or ``(score, P_scores, label, P_label)``.
            ``score`` has shape ``(batch, candidates)``; ``P_scores`` are the
            passage scores produced by ``self.p_score`` (presumably one logit
            per support passage — confirm in ``BilinearSeqAttn``).
        """
        # Fields built with include_lengths yield (tensor, lengths) tuples.
        if isinstance(batch.q_glove, tuple):
            q_glove, _ = batch.q_glove
            q_charNGram, _ = batch.q_charNGram
        else:
            q_glove = batch.q_glove
            q_charNGram = batch.q_charNGram

        q_glove = q_glove.to(self.device)
        q_charNGram = q_charNGram.to(self.device)

        s_glove = batch.s_glove.to(self.device)
        s_charNGram = batch.s_charNGram.to(self.device)

        c_glove = batch.c_glove.to(self.device)
        c_charNGram = batch.c_charNGram.to(self.device)

        ### Embedding and Encoder

        q_out = self.embedding_layer(q_glove, q_charNGram)
        s_out = self.embedding_layer(s_glove, s_charNGram,)
        c_out = self.embedding_layer(c_glove, c_charNGram)

        batch_size = s_out.size(0)

        s_len = s_out.size(1)
        c_len = c_out.size(1)

        s_word_len = s_out.size(2)
        c_word_len = c_out.size(2)

        hidden = s_out.size(-1)

        # Fold the passage / candidate axis into the batch axis so the shared
        # RNN sees plain (sequences, words, features) input.
        s_out = s_out.view(batch_size*s_len, s_word_len, hidden).contiguous()
        c_out = c_out.view(batch_size*c_len, c_word_len, hidden).contiguous()

        q_out = self.rnn(q_out)
        c_out = self.rnn(c_out)
        s_out = self.rnn(s_out)

        # Attention

        # Repeat the encoded query once per support passage so it lines up
        # with the flattened supports.
        q_word_len = q_out.size(1)
        q_out_expand = q_out.unsqueeze(1).expand(batch_size, s_len, q_word_len, q_out.size(-1)).contiguous()
        q_out_expand = q_out_expand.view(batch_size*s_len, q_word_len, q_out.size(-1)).contiguous()

        s_out_att, q_out_att = self.co_att(s_out, q_out_expand)
        #S_s = self.fusion(s_out, s_out_att)
        #S_q = self.fusion(q_out, q_out_att)

        S_s = self.linear_1(s_out_att)
        S_s = self.rnn2(S_s)

        # Mention features: one vector per (candidate, passage), then
        # attention over passages to get one vector per candidate.
        batch_c_m = self._mention_features(batch, s_out, batch_size, c_len, s_len)
        batch_c_m = batch_c_m.to(self.device)
        batch_c_m = batch_c_m.view(batch_size*c_len, s_len, -1)
        batch_c_m = self.mention_att(batch_c_m)
        batch_c_m = batch_c_m.view(batch_size, c_len, -1)

        C_s = self.word_att(S_s)
        C_q = self.word_att_q(q_out)

        C_s = C_s.view(batch_size, s_len, -1)

        P_scores = self.p_score(C_s, C_q)

        C_s = self.pass_att(C_s)

        C_c = self.c_att(c_out)
        C_c = C_c.view(batch_size, c_len, -1)

        C_c = torch.cat([C_c, batch_c_m], -1)

        C_s = torch.tanh(self.fc(C_s))

        # Dot product between each candidate vector and the support summary.
        score = torch.bmm(C_c, C_s.unsqueeze(-1))
        score = score.squeeze(-1)
        if return_label:
            label = batch.label.to(self.device)
            P_label = get_para_label(batch)
            # Use the model's own device rather than a module-level global.
            P_label = P_label.to(self.device)
            return score, P_scores, label, P_label
        return score, P_scores

#### test model

In [None]:
# Instantiate the network with the pretrained vectors stored on the two
# vocabularies; the constructor moves it to `device`.
model = SimpleQANet(config, word_field.vocab.vectors, charNGram_field.vocab.vectors, device)
#score,P_score, label, P_label = model(batch)
#print(score, label, P_score, P_label)
#print(score.shape, label.shape, P_score.shape, P_label.shape)

In [None]:
from utils import AverageMeter

def train(epoch, data_iter, model, criterion, criterion_bce, optimizer, batch_size=1, joint_begin=-1):
    """Run one training epoch with gradient accumulation.

    Args:
        epoch: current epoch index (used for logging and for ``joint_begin``).
        data_iter: iterable of batches accepted by ``model``.
        model: network returning ``(score, P_score, label, P_label)``.
        criterion: loss applied to ``(score, label)`` (answer loss).
        criterion_bce: loss applied to ``(P_score, P_label.float())``
            (passage loss).
        optimizer: optimizer updating ``model``'s parameters.
        batch_size: number of iterator batches whose gradients are
            accumulated before each optimizer step. This is *not* the
            iterator's own batch size.
        joint_begin: while ``epoch <= joint_begin`` only the passage loss is
            optimised; afterwards both losses are summed.

    Returns:
        ``(losses.avg, acces.avg)``: the averages tracked over the epoch.
    """
    losses = AverageMeter()
    acces = AverageMeter()
    model.train()
    #model.embedding_layer.eval()

    def apply_update():
        # Clip, step and clear the gradients accumulated so far.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        optimizer.step()
        optimizer.zero_grad()

    # Start from clean gradients so nothing left over from a previous epoch
    # or from a manual backward() call leaks into the first update.
    optimizer.zero_grad()
    pending = 0  # batches back-propagated since the last optimizer step
    for idx, batch in enumerate(data_iter):
        score, P_score, label, P_label = model(batch)

        loss1 = criterion(score, label)
        loss_p = criterion_bce(P_score, P_label.float())
        if epoch > joint_begin:
            loss = loss1 + loss_p
        else:
            loss = loss_p

        # Scale so that the accumulated gradient is the mean over the group.
        loss = loss / batch_size
        loss.backward()
        pending += 1
        if pending == batch_size:
            apply_update()
            pending = 0

        # Undo the accumulation scaling for reporting.
        losses.update(loss.item()*batch_size)

        pred = score.argmax(1)
        acc = pred.eq(label).sum().item() / pred.size(0)
        acces.update(acc)
        if (idx+1) % (batch_size*100) == 0:
            print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')

    # Apply a trailing, incomplete accumulation group instead of silently
    # carrying its gradients over into the next epoch.
    if pending:
        apply_update()
    return losses.avg, acces.avg

def val(epoch, data_iter, model, criterion, criterion_bce, joint_begin=-1):
    """Evaluate the model on one pass over ``data_iter``.

    The model is put in eval mode and run without gradient tracking. The
    reported loss is the passage loss alone while ``epoch <= joint_begin`` and
    the sum of answer and passage loss afterwards. Progress is printed every
    100 batches, starting with the first one.

    Args:
        epoch: current epoch index (logging and ``joint_begin`` comparison).
        data_iter: iterable of batches accepted by ``model``.
        model: network returning ``(score, P_score, label, P_label)``.
        criterion: loss applied to ``(score, label)``.
        criterion_bce: loss applied to ``(P_score, P_label.float())``.
        joint_begin: last epoch for which only the passage loss is reported.

    Returns:
        ``(average loss, average accuracy)`` as tracked by the two meters.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    model.eval()
    include_answer_loss = epoch > joint_begin

    for step, batch in enumerate(data_iter):
        with torch.no_grad():
            score, P_score, label, P_label = model(batch)

        answer_loss = criterion(score, label)
        passage_loss = criterion_bce(P_score, P_label.float())
        total_loss = answer_loss + passage_loss if include_answer_loss else passage_loss
        loss_meter.update(total_loss.item())

        # Fraction of examples whose highest-scoring candidate is the gold one.
        predicted = score.argmax(1)
        n_correct = (predicted == label).sum().item()
        acc_meter.update(n_correct / predicted.size(0))

        if step % 100 == 0:
            print(f'epoch:{epoch}, idx:{step}/{len(data_iter)}, loss:{loss_meter.avg}, acc:{acc_meter.avg}')

    return loss_meter.avg, acc_meter.avg

In [None]:
# Optimise only the parameters that require gradients (frozen ones, if any,
# are filtered out).
optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                             lr=config.lr)

# Answer selection is a multi-class problem over the candidates; passage
# selection is scored per passage, hence the binary loss on raw logits.
criterion = nn.CrossEntropyLoss()
criterion_bce = nn.BCEWithLogitsLoss()

# Cosine learning-rate decay spanning the whole run (one period = all epochs).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
#train(0, train_iter, model, criterion, optimizer, batch_size=config.batch_size)
# val(0, val_iter, model,criterion)

In [None]:
# Create the run directory; exist_ok avoids the check-then-create race and
# makes re-running the cell harmless.
os.makedirs(config.save_path, exist_ok=True)
writer = SummaryWriter(config.save_path)

best_acc = 0.0
for epoch in range(config.epochs):
    # The second-to-last argument (1) is the number of batches accumulated per
    # optimizer step, i.e. no gradient accumulation.
    train_loss, train_acc = train(epoch, train_iter, model, criterion, criterion_bce, optimizer,
                                  1)
    val_loss, val_acc = val(epoch, val_iter, model, criterion, criterion_bce)
    writer.add_scalar('train_loss', train_loss, epoch+1)
    writer.add_scalar('val_loss', val_loss, epoch+1)
    writer.add_scalar('train_acc', train_acc, epoch+1)
    writer.add_scalar('val_acc', val_acc, epoch+1)

    state = {
        'val_acc': val_acc,
        'train_acc': train_acc,
        'epoch': epoch,
        'model': model.state_dict()
    }
    #torch.save(state, os.path.join(config.save_path,'lastest.pth'))
    if val_acc > best_acc:
        best_acc = val_acc
        #torch.save(state, os.path.join(save_path, 'best.pth'))

    # Advance the LR schedule only after this epoch's optimizer updates.
    # Stepping the scheduler before optimizer.step() skips the first value of
    # the schedule on PyTorch >= 1.1.
    scheduler.step()

In [None]:
# Debugging cell: run one forward/backward pass on the batch kept in `batch`
# so that parameter gradients can be inspected in the cells below.
# NOTE: no optimizer step is taken and gradients are not zeroed here, so
# repeated runs of this cell accumulate into the existing `.grad` tensors.
score, P_score, label, P_label = model(batch)
        
loss1 = criterion(score, label)
loss_p = criterion_bce(P_score, P_label.float())

loss = loss1 + loss_p

loss.backward()
#optimizer.step()
#optimizer.zero_grad()        

In [None]:
# Display the best validation accuracy reached during training.
best_acc

In [None]:
# Display the gradient currently stored on the passage scorer's linear weight
# (None if no backward pass has populated it).
model.p_score.linear.weight.grad

In [None]:
# Display the gradient currently stored on the mention attention's `ws1`
# weight, to check that the mention branch receives gradient.
model.mention_att.ws1.weight.grad

In [None]:
# Collected module outputs, in the order the hook was invoked.
outputs = []


def hook(module, input, output):
    """Forward-hook callback that records ``output`` in the global ``outputs``.

    ``module`` and ``input`` are accepted to match the forward-hook signature
    but are not used.
    """
    outputs.extend([output])

In [None]:
# Record the outputs of `model.fc` on every forward pass via the `hook`
# callback defined above (register_forward_hook requires the callable).
# Call `hook1.remove()` to detach it again.
hook1 = model.fc.register_forward_hook(hook)
hook12 = model.