In [2]:
from torchtext import data
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from dataset import DataHandler, BertField
import torch.nn as nn
import torch
from model import BiAttention, EncoderRNN, SelfAttention
import os
import torchtext
from tensorboardX import SummaryWriter

os.environ["CUDA_VISIBLE_DEVICES"] = '2'

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
train_examples_path = './train_examples.pt'
val_examples_path = './val_examples.pt'

In [4]:
tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased-vocab.txt', do_lower_case=True)

bert_field = BertField(tokenizer)
multi_bert_field = data.NestedField(bert_field)



word_field = data.Field(batch_first=True, sequential=True, tokenize=tokenizer.tokenize, lower=True) # query
multi_word_field = data.NestedField(word_field) 

word_field_sup = data.Field(batch_first=True, sequential=True, tokenize=tokenizer.tokenize, lower=True, fix_length=320)
multi_word_field_sup = data.NestedField(word_field_sup) 

bert_field_sup = BertField(tokenizer, fix_length=320)
multi_bert_field_sup = data.NestedField(bert_field_sup)

raw = data.RawField()
raw.is_target = False

label_field = data.Field(sequential=False, is_target=True, use_vocab=False)

dict_field = {
    'id': ('id', raw),
    'supports': [('s_glove', multi_word_field_sup), ('s_bert', multi_bert_field_sup)],
    'query': [('q_glove', word_field), ('q_bert', bert_field)],
    'answer': [('a_glove', word_field), ('a_bert', bert_field)],
    'candidates': [('c_glove', multi_word_field), ('c_bert', multi_bert_field)],
    'label': ('label', label_field)

}

In [5]:
data_handler = DataHandler(train_examples_path, val_examples_path, dict_field)

load examples.pt  :./train_examples.pt, ./val_examples.pt


In [6]:
multi_word_field_sup.build_vocab(data_handler.trainset, data_handler.valset, vectors=torchtext.vocab.GloVe(dim=300,name='6B') )
word_field.vocab = multi_word_field_sup.vocab
word_field.include_lengths = True

In [7]:
train_iter = data_handler.get_train_iter(batch_size=1)
val_iter = data_handler.get_val_iter(batch_size=1)

### Embedding

这一层需要频繁的改动，所以暂时不放在py文件中

In [8]:
class EmbeddingLayer(nn.Module):
    
    def __init__(self, word_field, bert_model_path='./bert-base-uncased/', use_all=False):
        super(EmbeddingLayer, self).__init__()
        self.word_embedding_layer = nn.Embedding.from_pretrained(embeddings=word_field.vocab.vectors)
        
        model = BertModel.from_pretrained(bert_model_path)   
        self.bert_model = model
        
        self.use_all = use_all
        self.freeze()
        
    def freeze(self):
        for param in self.bert_model.parameters():
            param.requires_grad = False
        self.word_embedding_layer.weight.requires_grad = False
        
    def forward(self, word_tokens, bert_tokens, input_mask=None):
        '''
        input:
            x: [batch_size, seg_len]
        
        return embeddings: [batch_size, seq_len, glove_dim + bert_dim]    
        '''
        word_embeddings = self.word_embedding_layer(word_tokens)
        
        # encoded_layers: [batch_size, seq_len, bert_embedding_dim] * num_of_layers
        encoded_layers, _ = self.bert_model(bert_tokens, attention_mask=input_mask)
        
        bert_embeddings = torch.zeros_like(encoded_layers[-1])
        if self.use_all:
            for layer in encoded_layers:
                bert_embeddings += layer
            bert_embeddings /= len(encoded_layers)
        else:
            bert_embeddings += encoded_layers[-1]
        
        out = torch.cat([word_embeddings, bert_embeddings], dim=-1)
        return out
        

In [9]:
class SimpleQANet(nn.Module):
    
    def __init__(self, config, word_field):
        super(SimpleQANet, self).__init__()
        self.config = config
        self.use_cuda = config.use_cuda
        
        self.embedding_layer = EmbeddingLayer(word_field, config.bert_path, config.use_all)
        self.rnn = EncoderRNN(config.word_dim + config.bert_dim, config.hidden, 1, True, True, 0.2, False)
        
        self.qc_att = BiAttention(config.hidden*2, 0.2)
        self.linear_1 = nn.Sequential(
                nn.Linear(config.hidden*8, config.hidden),
                nn.ReLU()
        )    
        
        self.rnn_2 = EncoderRNN(config.hidden, config.hidden, 1, False, True, 0.2, False)
        
        self.self_att = SelfAttention(config.hidden*2, config.hidden*2, 0.2)       
        self.self_att_2 = SelfAttention(config.hidden*2, config.hidden*2, 0.2)        
        
        self.self_att_c = SelfAttention(config.hidden*2, config.hidden*2, 0.2)        
        
        
    def forward(self, batch):
        q_glove, _ = batch.q_glove
        q_bert = batch.q_bert
        s_glove = batch.s_glove
        s_bert = batch.s_bert
        c_glove = batch.c_glove
        c_bert = batch.c_bert
        
        if self.use_cuda:
            q_glove = q_glove.cuda()
            q_bert = q_bert.cuda()
            s_glove = s_glove.cuda().squeeze(0)
            s_bert = s_bert.cuda().squeeze(0)
            c_glove = c_glove.cuda().squeeze(0)
            c_bert = c_bert.cuda().squeeze(0)
            
        # Embedding 
        context_mask = (s_bert > 0).float()
        ques_mask = (q_bert > 0).float()
        
        q_out = self.embedding_layer(q_glove, q_bert)
        s_out = self.embedding_layer(s_glove, s_bert, context_mask)
        c_out = self.embedding_layer(c_glove, c_bert)

        q_out = self.rnn(q_out)
        c_out = self.rnn(c_out)
        
        s_out = self.rnn(s_out)

        # bi-attention on supports and  question
        context_mask = (c_bert.squeeze() > 0).float()
        ques_mask = (q_bert > 0).float()
        
        support_len = s_out.size(0)
        q_out = q_out.expand(support_len, q_out.size(1), q_out.size(2))
        ques_mask = ques_mask.expand(support_len, q_out.size(1))        
        
        # s_out:[supports_len, seq_len, hidden*2], q_out: [support_len, seq_len, hidden*2]
        output = self.qc_att(s_out, q_out, ques_mask)
        output = self.linear_1(output)
        output = self.rnn_2(output)
        
        # self-attention pooling 
        # [support_len, hidden*2]
        output = self.self_att(output)
        # [1, hidden*2]
        output = self.self_att_2(output.unsqueeze(0))

        # [candidate_len, hidden*2]
        c_out = self.self_att_c(c_out)
        
        # score layer
        score = torch.mm(c_out, torch.tanh(output.transpose(0, 1)))
        return score

In [10]:
class Config:
    
    def __init__(self):
        self.hidden = 100
        self.word_dim = 300
        self.bert_dim = 768
        self.use_cuda = True
        self.bert_path = './bert-base-uncased/'
        self.use_all = True
        self.lr = 1e-4
        self.epochs = 30
        self.log_dir = './logs'
        self.model_name = 'simpleQANet'
        self.batch_size = 8

In [11]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [12]:
def train(epoch, data_iter, model, criterion, optimizer, cuda, batch_size=1):
    losses = AverageMeter()
    acces = AverageMeter()
    model.train()
    #model.embedding_layer.eval()
    for idx, batch in enumerate(data_iter):
        score = model(batch)
        label = batch.label
        if cuda:
            label = label.cuda()
        score = score.transpose(0,1)      
        
        loss = criterion(score, label)

        loss = loss / batch_size
        loss.backward()
        if (idx+1)%batch_size == 0 :
            optimizer.step()
            optimizer.zero_grad()        

        losses.update(loss.item())
        
        pred = score.argmax(1)
        acc = pred.eq(label).sum().item()   
        acces.update(acc)
        if (idx+1) % (batch_size*100) == 0:
            print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')
    return losses.avg, acces.avg

def val(epoch, data_iter, model, criterion, cuda):
    losses = AverageMeter()
    acces = AverageMeter()
    model.eval()
    for idx, batch in enumerate(data_iter):
        with torch.no_grad():
            score = model(batch)
            
        label = batch.label
        if cuda:
            label = label.cuda()
        score = score.transpose(0,1)      
        
        loss = criterion(score, label)
        losses.update(loss.item())
        
        pred = score.argmax(1)
        acc = pred.eq(label).sum().item()   
        acces.update(acc)
        if idx % 100 == 0:
            print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')
    return losses.avg, acces.avg

In [13]:
config = Config()
model = SimpleQANet(config, word_field)
if config.use_cuda:
    model = model.cuda()

In [14]:
optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                             lr=config.lr)

criterion = nn.CrossEntropyLoss()

In [15]:
save_path = config.model_name + '_epoch'+str(config.epochs) + '_lr'+ str(config.lr)+ '_useall'+ \
                str(config.use_all) + '_batchsize' + str(config.batch_size)

save_path = os.path.join(config.log_dir, save_path)
if not os.path.exists(save_path):
    os.makedirs(save_path)
    
print(save_path)

writer = SummaryWriter(save_path)

./logs/simpleQANet_epoch30_lr0.0001_useallTrue_batchsize8


In [None]:
best_acc = 0.0
for epoch in range(config.epochs):
    train_loss, train_acc = train(epoch, train_iter, model, criterion, optimizer, 
                                  config.use_cuda, config.batch_size)
    val_loss, val_acc = val(epoch, val_iter, model, criterion, config.use_cuda)
    
    writer.add_scalar('train_loss', train_loss, epoch+1)
    writer.add_scalar('val_loss', val_loss, epoch+1)
    writer.add_scalar('train_acc', train_acc, epoch+1)
    writer.add_scalar('val_acc', val_acc, epoch+1)
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), os.path.join(save_path, 'best.pth'))

  alphas = self.softmax(alphas)  # (bsz, sent_len)


epoch:0, idx:799/43738, loss:0.28659709547646345, acc:0.265
epoch:0, idx:1599/43738, loss:0.2743978534586495, acc:0.296875
epoch:0, idx:2399/43738, loss:0.27228552452870647, acc:0.3070833333333333
epoch:0, idx:3199/43738, loss:0.268721807145339, acc:0.3184375
epoch:0, idx:3999/43738, loss:0.26731599562172775, acc:0.32175
epoch:0, idx:4799/43738, loss:0.26574684386413233, acc:0.3258333333333333
epoch:0, idx:5599/43738, loss:0.2635271064802406, acc:0.3291071428571429
epoch:0, idx:6399/43738, loss:0.261661778416601, acc:0.3309375
epoch:0, idx:7199/43738, loss:0.25964686856229996, acc:0.3347222222222222
epoch:0, idx:7999/43738, loss:0.25868207126075865, acc:0.337125
epoch:0, idx:8799/43738, loss:0.25706368385930545, acc:0.3427272727272727
epoch:0, idx:9599/43738, loss:0.2563256980104294, acc:0.34510416666666666
epoch:0, idx:10399/43738, loss:0.2555852147268328, acc:0.34615384615384615
epoch:0, idx:11199/43738, loss:0.2546444391916988, acc:0.34794642857142855
epoch:0, idx:11999/43738, loss:

epoch:1, idx:7999/43738, loss:0.2055365326437168, acc:0.456125
epoch:1, idx:8799/43738, loss:0.2057263868378306, acc:0.4553409090909091
epoch:1, idx:9599/43738, loss:0.20574462676692443, acc:0.45489583333333333
epoch:1, idx:10399/43738, loss:0.20619887379571222, acc:0.4530769230769231
epoch:1, idx:11199/43738, loss:0.20683590275767658, acc:0.45026785714285716
epoch:1, idx:11999/43738, loss:0.2065541827729903, acc:0.4513333333333333
epoch:1, idx:12799/43738, loss:0.20572538802487542, acc:0.455078125
epoch:1, idx:13599/43738, loss:0.2055854808707612, acc:0.45544117647058824
epoch:1, idx:14399/43738, loss:0.20571503492777185, acc:0.4542361111111111
epoch:1, idx:15199/43738, loss:0.2056441273185498, acc:0.45480263157894735
epoch:1, idx:15999/43738, loss:0.205246537637664, acc:0.4558125
epoch:1, idx:16799/43738, loss:0.20491843341478325, acc:0.45613095238095236
epoch:1, idx:17599/43738, loss:0.2045202373347605, acc:0.4575568181818182
epoch:1, idx:18399/43738, loss:0.20404567944619845, acc:0

epoch:2, idx:15199/43738, loss:0.18893969198186442, acc:0.4869078947368421
epoch:2, idx:15999/43738, loss:0.18883445119240788, acc:0.4870625
epoch:2, idx:16799/43738, loss:0.18826402264863384, acc:0.4892261904761905
epoch:2, idx:17599/43738, loss:0.18816463076028, acc:0.4900568181818182
epoch:2, idx:18399/43738, loss:0.18777470174528982, acc:0.49130434782608695
epoch:2, idx:19199/43738, loss:0.18751743714334831, acc:0.4925
epoch:2, idx:19999/43738, loss:0.18747297888682224, acc:0.4933
epoch:2, idx:20799/43738, loss:0.18708908152139675, acc:0.49514423076923075
epoch:2, idx:21599/43738, loss:0.18715216019432301, acc:0.49527777777777776
epoch:2, idx:22399/43738, loss:0.1872454044004969, acc:0.4959375
epoch:2, idx:23199/43738, loss:0.18725957362992882, acc:0.49530172413793105
epoch:2, idx:23999/43738, loss:0.18718895005361022, acc:0.49525
epoch:2, idx:24799/43738, loss:0.1869849894380182, acc:0.4958467741935484
epoch:2, idx:25599/43738, loss:0.18661253569352992, acc:0.497109375
epoch:2, id

epoch:3, idx:22399/43738, loss:0.17387903640537844, acc:0.5340178571428571
epoch:3, idx:23199/43738, loss:0.17414201459101228, acc:0.5335344827586207
epoch:3, idx:23999/43738, loss:0.17390148590622023, acc:0.5343333333333333
epoch:3, idx:24799/43738, loss:0.17353586397134238, acc:0.535
epoch:3, idx:25599/43738, loss:0.1733046379653797, acc:0.535546875
epoch:3, idx:26399/43738, loss:0.17329811796362982, acc:0.5358333333333334
epoch:3, idx:27199/43738, loss:0.1733474590874539, acc:0.5358455882352942
epoch:3, idx:27999/43738, loss:0.17305401336648252, acc:0.5367142857142857
epoch:3, idx:28799/43738, loss:0.1731525816052494, acc:0.536875
epoch:3, idx:29599/43738, loss:0.17339660238733, acc:0.5362837837837838
epoch:3, idx:30399/43738, loss:0.1732394272435225, acc:0.53625
epoch:3, idx:31199/43738, loss:0.17277638978766985, acc:0.5371794871794872
epoch:3, idx:31999/43738, loss:0.17291489277527217, acc:0.53703125
epoch:3, idx:32799/43738, loss:0.1727415617793073, acc:0.5373780487804878
epoch:3

epoch:4, idx:29599/43738, loss:0.16371314631998413, acc:0.5579054054054055
epoch:4, idx:30399/43738, loss:0.1635928927448664, acc:0.5579605263157895
epoch:4, idx:31199/43738, loss:0.16361875235786638, acc:0.5580448717948718
epoch:4, idx:31999/43738, loss:0.163355323785363, acc:0.55853125
epoch:4, idx:32799/43738, loss:0.162941727275114, acc:0.5593902439024391
epoch:4, idx:33599/43738, loss:0.16293588416027238, acc:0.5594642857142857
epoch:4, idx:34399/43738, loss:0.16298279079605912, acc:0.5591569767441861
epoch:4, idx:35199/43738, loss:0.16294060283665948, acc:0.5594602272727273
epoch:4, idx:35999/43738, loss:0.16271456345768334, acc:0.5599722222222222
epoch:4, idx:36799/43738, loss:0.16264376807955358, acc:0.5605434782608696
epoch:4, idx:37599/43738, loss:0.16264113996076357, acc:0.5603191489361702
epoch:4, idx:38399/43738, loss:0.16264504548402328, acc:0.559921875
epoch:4, idx:39199/43738, loss:0.16275095654002922, acc:0.559719387755102
epoch:4, idx:39999/43738, loss:0.1628568768330

epoch:5, idx:36799/43738, loss:0.1553879528299465, acc:0.581929347826087
epoch:5, idx:37599/43738, loss:0.1555203914304358, acc:0.5812765957446808
epoch:5, idx:38399/43738, loss:0.15548108225725932, acc:0.5812239583333333
epoch:5, idx:39199/43738, loss:0.15560503446400803, acc:0.5807908163265306
epoch:5, idx:39999/43738, loss:0.1554827303652535, acc:0.581075
epoch:5, idx:40799/43738, loss:0.1555644623934652, acc:0.5811519607843137
epoch:5, idx:41599/43738, loss:0.15548518679197654, acc:0.5814663461538462
epoch:5, idx:42399/43738, loss:0.155535976165139, acc:0.5817216981132075
epoch:5, idx:43199/43738, loss:0.15551344928895425, acc:0.581875
epoch:5, idx:0/5129, loss:2.427467107772827, acc:0.0
epoch:5, idx:100/5129, loss:1.5300386329688649, acc:0.44554455445544555
epoch:5, idx:200/5129, loss:1.5125855183719996, acc:0.46766169154228854
epoch:5, idx:300/5129, loss:1.468357071131963, acc:0.5116279069767442
epoch:5, idx:400/5129, loss:1.4564084883342658, acc:0.5211970074812967
epoch:5, idx:5

epoch:6, idx:100/5129, loss:1.5049036617326264, acc:0.46534653465346537
epoch:6, idx:200/5129, loss:1.4728945707207295, acc:0.5024875621890548
epoch:6, idx:300/5129, loss:1.4565832601037136, acc:0.5282392026578073
epoch:6, idx:400/5129, loss:1.4459782386955775, acc:0.5336658354114713
epoch:6, idx:500/5129, loss:1.3538801030008616, acc:0.562874251497006
epoch:6, idx:600/5129, loss:1.330689064377358, acc:0.5773710482529119
epoch:6, idx:700/5129, loss:1.335703962029472, acc:0.5763195435092725
epoch:6, idx:800/5129, loss:1.325396959701281, acc:0.5792759051186017
epoch:6, idx:900/5129, loss:1.323291491588927, acc:0.5749167591564928
epoch:6, idx:1000/5129, loss:1.3336367520180854, acc:0.5714285714285714
epoch:6, idx:1100/5129, loss:1.3277203310467134, acc:0.5667574931880109
epoch:6, idx:1200/5129, loss:1.339394301970138, acc:0.5645295587010825
epoch:6, idx:1300/5129, loss:1.3439612665559402, acc:0.5641813989239047
epoch:6, idx:1400/5129, loss:1.3201747920196623, acc:0.5731620271234832
epoch:

epoch:7, idx:1100/5129, loss:1.3246446199085797, acc:0.5685740236148955
epoch:7, idx:1200/5129, loss:1.338835012332089, acc:0.5653621981681932
epoch:7, idx:1300/5129, loss:1.3402689679323208, acc:0.5680245964642583
epoch:7, idx:1400/5129, loss:1.3119148703067325, acc:0.5788722341184868
epoch:7, idx:1500/5129, loss:1.3197808449543134, acc:0.5762824783477681
epoch:7, idx:1600/5129, loss:1.3097229914915405, acc:0.5777638975640225
epoch:7, idx:1700/5129, loss:1.302986242468395, acc:0.5778953556731334
epoch:7, idx:1800/5129, loss:1.2960802257425053, acc:0.5796779566907274
epoch:7, idx:1900/5129, loss:1.2975088266370673, acc:0.578642819568648
epoch:7, idx:2000/5129, loss:1.2894202236799166, acc:0.5797101449275363
epoch:7, idx:2100/5129, loss:1.2891846000013438, acc:0.5768681580199905
epoch:7, idx:2200/5129, loss:1.2848180110131757, acc:0.5751930940481599
epoch:7, idx:2300/5129, loss:1.2844975243724466, acc:0.5749674054758801
epoch:7, idx:2400/5129, loss:1.2840684380867937, acc:0.574760516451

epoch:8, idx:2000/5129, loss:1.2556607690320976, acc:0.5997001499250375
epoch:8, idx:2100/5129, loss:1.2596373079802978, acc:0.595430747263208
epoch:8, idx:2200/5129, loss:1.2564749604041119, acc:0.5938209904588824
epoch:8, idx:2300/5129, loss:1.2568385425480695, acc:0.5923511516731855
epoch:8, idx:2400/5129, loss:1.2579750244454113, acc:0.588088296543107
epoch:8, idx:2500/5129, loss:1.2589610436412155, acc:0.5881647341063575
epoch:8, idx:2600/5129, loss:1.2655678212081227, acc:0.5851595540176855
epoch:8, idx:2700/5129, loss:1.2648773762317818, acc:0.5879303961495742
epoch:8, idx:2800/5129, loss:1.2630650453516858, acc:0.5862192074259193
epoch:8, idx:2900/5129, loss:1.2619767414058582, acc:0.5863495346432265
epoch:8, idx:3000/5129, loss:1.273076828406557, acc:0.5854715094968344
epoch:8, idx:3100/5129, loss:1.2730927828174374, acc:0.5859400193485972
epoch:8, idx:3200/5129, loss:1.2766425780572581, acc:0.584817244611059
epoch:8, idx:3300/5129, loss:1.278600190312203, acc:0.58255074219933

epoch:9, idx:3000/5129, loss:1.289837769915743, acc:0.5891369543485505
epoch:9, idx:3100/5129, loss:1.2906027488802103, acc:0.5894872621734925
epoch:9, idx:3200/5129, loss:1.2922627451488913, acc:0.5895032802249297
epoch:9, idx:3300/5129, loss:1.2915434141148765, acc:0.5889124507724932
epoch:9, idx:3400/5129, loss:1.2840216336194048, acc:0.5901205527785945
epoch:9, idx:3500/5129, loss:1.2810832335310676, acc:0.5912596401028277
epoch:9, idx:3600/5129, loss:1.2758303899292116, acc:0.5923354623715634
epoch:9, idx:3700/5129, loss:1.2792953928296904, acc:0.5909213726019995
epoch:9, idx:3800/5129, loss:1.2761463906377812, acc:0.5935280189423836
epoch:9, idx:3900/5129, loss:1.2788592779336145, acc:0.5924122019994873
epoch:9, idx:4000/5129, loss:1.2800515231710616, acc:0.5931017245688578
epoch:9, idx:4100/5129, loss:1.2819286958073528, acc:0.593026091197269
epoch:9, idx:4200/5129, loss:1.276370114070678, acc:0.5936681742442276
epoch:9, idx:4300/5129, loss:1.2719902934164842, acc:0.594745408044

epoch:10, idx:3800/5129, loss:1.2687684281605915, acc:0.5953696395685346
epoch:10, idx:3900/5129, loss:1.2718437476796511, acc:0.5931812355806203
epoch:10, idx:4000/5129, loss:1.2728183479740809, acc:0.5933516620844789
epoch:10, idx:4100/5129, loss:1.2740553567225512, acc:0.593026091197269
epoch:10, idx:4200/5129, loss:1.2681641185696537, acc:0.5936681742442276
epoch:10, idx:4300/5129, loss:1.2645296730868514, acc:0.5945129039758196
epoch:10, idx:4400/5129, loss:1.2667462792424444, acc:0.5944103612815269
epoch:10, idx:4500/5129, loss:1.2605549232631597, acc:0.5954232392801599
epoch:10, idx:4600/5129, loss:1.2628146285072561, acc:0.5942186481199739
epoch:10, idx:4700/5129, loss:1.2638648841938194, acc:0.5943416294405446
epoch:10, idx:4800/5129, loss:1.2584926669680434, acc:0.5950843574255363
epoch:10, idx:4900/5129, loss:1.25946146529947, acc:0.5949806162007754
epoch:10, idx:5000/5129, loss:1.2608570052868389, acc:0.5940811837632474
epoch:10, idx:5100/5129, loss:1.257373096385941, acc:0

epoch:11, idx:4600/5129, loss:1.2546938504020537, acc:0.5913931753966529
epoch:11, idx:4700/5129, loss:1.2559093231122511, acc:0.5909380982769623
epoch:11, idx:4800/5129, loss:1.2525666363070196, acc:0.590918558633618
epoch:11, idx:4900/5129, loss:1.2517326969081006, acc:0.590695776372169
epoch:11, idx:5000/5129, loss:1.2519003593148148, acc:0.5896820635872826
epoch:11, idx:5100/5129, loss:1.2479787449542505, acc:0.5896882964124681
epoch:12, idx:799/43738, loss:0.13021203923737631, acc:0.635
epoch:12, idx:1599/43738, loss:0.12836231577326543, acc:0.64625
epoch:12, idx:2399/43738, loss:0.12782689435795572, acc:0.65
epoch:12, idx:3199/43738, loss:0.1270065000874456, acc:0.6521875
epoch:12, idx:3999/43738, loss:0.12617076353379525, acc:0.65375
epoch:12, idx:4799/43738, loss:0.12559310430757856, acc:0.6564583333333334
epoch:12, idx:5599/43738, loss:0.12381762449114052, acc:0.6614285714285715
epoch:12, idx:6399/43738, loss:0.12315032865924877, acc:0.66234375
epoch:12, idx:7199/43738, loss:0

epoch:13, idx:2399/43738, loss:0.11747758071714391, acc:0.68875
epoch:13, idx:3199/43738, loss:0.1152769587340299, acc:0.6940625
epoch:13, idx:3999/43738, loss:0.11515830614138395, acc:0.6935
epoch:13, idx:4799/43738, loss:0.11478907497172865, acc:0.6939583333333333
epoch:13, idx:5599/43738, loss:0.1148397826373444, acc:0.6923214285714285
epoch:13, idx:6399/43738, loss:0.11596359468094306, acc:0.689375
epoch:13, idx:7199/43738, loss:0.1154970343706211, acc:0.6894444444444444
epoch:13, idx:7999/43738, loss:0.11657355956002721, acc:0.68775
epoch:13, idx:8799/43738, loss:0.11718859655711143, acc:0.6859090909090909
epoch:13, idx:9599/43738, loss:0.11778687441423244, acc:0.6845833333333333
epoch:13, idx:10399/43738, loss:0.11778228643074937, acc:0.6854807692307693
epoch:13, idx:11199/43738, loss:0.11733460706386332, acc:0.6858035714285714
epoch:13, idx:11999/43738, loss:0.11767132386342079, acc:0.68475
epoch:13, idx:12799/43738, loss:0.11797605420486434, acc:0.682421875
epoch:13, idx:13599/

epoch:14, idx:8799/43738, loss:0.11410108470728367, acc:0.6927272727272727
epoch:14, idx:9599/43738, loss:0.11461364180771245, acc:0.6898958333333334
epoch:14, idx:10399/43738, loss:0.11487289185984992, acc:0.69
epoch:14, idx:11199/43738, loss:0.11465658691570363, acc:0.689375
epoch:14, idx:11999/43738, loss:0.1149478206470764, acc:0.6886666666666666
epoch:14, idx:12799/43738, loss:0.11568318519544846, acc:0.685703125
epoch:14, idx:13599/43738, loss:0.11599167751100407, acc:0.6849264705882353
epoch:14, idx:14399/43738, loss:0.11605929821477427, acc:0.6847222222222222
epoch:14, idx:15199/43738, loss:0.11637396225457594, acc:0.6840131578947368
epoch:14, idx:15999/43738, loss:0.11615818783265422, acc:0.6846875
epoch:14, idx:16799/43738, loss:0.11604217694656524, acc:0.6860714285714286
epoch:14, idx:17599/43738, loss:0.11584713875130877, acc:0.6864772727272728
epoch:14, idx:18399/43738, loss:0.11593713531964053, acc:0.6865217391304348
epoch:14, idx:19199/43738, loss:0.11569435956378583, ac

epoch:15, idx:15199/43738, loss:0.11251443849458347, acc:0.6932894736842106
epoch:15, idx:15999/43738, loss:0.11233393797084863, acc:0.694375
epoch:15, idx:16799/43738, loss:0.11209858408634318, acc:0.6950595238095238
epoch:15, idx:17599/43738, loss:0.11272932311772654, acc:0.6935227272727272
epoch:15, idx:18399/43738, loss:0.11258880846479492, acc:0.6942934782608695
epoch:15, idx:19199/43738, loss:0.11297052136567799, acc:0.6925
epoch:15, idx:19999/43738, loss:0.11327393752805656, acc:0.691
epoch:15, idx:20799/43738, loss:0.11338515772738342, acc:0.6905288461538461
epoch:15, idx:21599/43738, loss:0.11364825260863183, acc:0.6899537037037037
epoch:15, idx:22399/43738, loss:0.11393148065406422, acc:0.6889732142857142
epoch:15, idx:23199/43738, loss:0.1139213906927536, acc:0.6894827586206896
epoch:15, idx:23999/43738, loss:0.11395137404926936, acc:0.6894166666666667
epoch:15, idx:24799/43738, loss:0.11402424551611354, acc:0.6889112903225807
epoch:15, idx:25599/43738, loss:0.11369528540312

epoch:16, idx:21599/43738, loss:0.11128777706078297, acc:0.6993981481481482
epoch:16, idx:22399/43738, loss:0.1111506932078815, acc:0.7000446428571429
epoch:16, idx:23199/43738, loss:0.11086038054991715, acc:0.7003879310344827
epoch:16, idx:23999/43738, loss:0.11080017016382772, acc:0.699875
epoch:16, idx:24799/43738, loss:0.11073602253590868, acc:0.7002822580645162
epoch:16, idx:25599/43738, loss:0.11073554395949942, acc:0.70015625
epoch:16, idx:26399/43738, loss:0.11078574018684308, acc:0.7000757575757576
epoch:16, idx:27199/43738, loss:0.11058147512983209, acc:0.7005514705882353
epoch:16, idx:27999/43738, loss:0.11036811114778643, acc:0.7008928571428571
epoch:16, idx:28799/43738, loss:0.1103532761747758, acc:0.7005555555555556
epoch:16, idx:29599/43738, loss:0.11025656656721589, acc:0.7006418918918919
epoch:16, idx:30399/43738, loss:0.11008231362187776, acc:0.7011842105263157
epoch:16, idx:31199/43738, loss:0.10991511039848541, acc:0.7015064102564103
epoch:16, idx:31999/43738, loss:

epoch:17, idx:27199/43738, loss:0.10858318479587872, acc:0.7052941176470588
epoch:17, idx:27999/43738, loss:0.1086485939906644, acc:0.705
epoch:17, idx:28799/43738, loss:0.10852708270870304, acc:0.7051041666666666
epoch:17, idx:29599/43738, loss:0.10863349282465926, acc:0.7047635135135135
epoch:17, idx:30399/43738, loss:0.10855541923898272, acc:0.7052960526315789
epoch:17, idx:31199/43738, loss:0.1086237394105261, acc:0.7051923076923077
epoch:17, idx:31999/43738, loss:0.10862103512999602, acc:0.705375
epoch:17, idx:32799/43738, loss:0.10850056462580475, acc:0.7057317073170731
epoch:17, idx:33599/43738, loss:0.10865621286932202, acc:0.7051785714285714
epoch:17, idx:34399/43738, loss:0.1087969847884419, acc:0.7047383720930233
epoch:17, idx:35199/43738, loss:0.10897993303581395, acc:0.7038352272727273
epoch:17, idx:35999/43738, loss:0.10902091216523614, acc:0.7036111111111111
epoch:17, idx:36799/43738, loss:0.10897019661244248, acc:0.7035054347826087
epoch:17, idx:37599/43738, loss:0.1090

epoch:18, idx:33599/43738, loss:0.10542733354236546, acc:0.7091964285714286
epoch:18, idx:34399/43738, loss:0.10560474442413177, acc:0.7088081395348838
epoch:18, idx:35199/43738, loss:0.10554472639900433, acc:0.7094318181818182
epoch:18, idx:35999/43738, loss:0.10576882562753033, acc:0.7092777777777778
epoch:18, idx:36799/43738, loss:0.10598509332507092, acc:0.7089130434782609
epoch:18, idx:37599/43738, loss:0.1060311066688, acc:0.7092021276595745
epoch:18, idx:38399/43738, loss:0.1060850446851479, acc:0.7089583333333334
epoch:18, idx:39199/43738, loss:0.105997779756355, acc:0.7091581632653061
epoch:18, idx:39999/43738, loss:0.10585081959555682, acc:0.7094
epoch:18, idx:40799/43738, loss:0.10578026416823645, acc:0.7096078431372549
epoch:18, idx:41599/43738, loss:0.10573844774158575, acc:0.7099278846153846
epoch:18, idx:42399/43738, loss:0.10574196154653931, acc:0.7097877358490566
epoch:18, idx:43199/43738, loss:0.10571444238197862, acc:0.7098611111111112
epoch:18, idx:0/5129, loss:0.78

epoch:19, idx:39999/43738, loss:0.10273516970038327, acc:0.720275
epoch:19, idx:40799/43738, loss:0.10273977530137642, acc:0.7205637254901961
epoch:19, idx:41599/43738, loss:0.10281951642142023, acc:0.7202884615384615
epoch:19, idx:42399/43738, loss:0.10306246785283499, acc:0.7198820754716981
epoch:19, idx:43199/43738, loss:0.1032160468832754, acc:0.7198148148148148
epoch:19, idx:0/5129, loss:1.1972217559814453, acc:0.0
epoch:19, idx:100/5129, loss:1.5227117656481148, acc:0.5247524752475248
epoch:19, idx:200/5129, loss:1.4853588248885685, acc:0.5522388059701493
epoch:19, idx:300/5129, loss:1.447273274370008, acc:0.574750830564784
epoch:19, idx:400/5129, loss:1.4171223785923306, acc:0.57356608478803
epoch:19, idx:500/5129, loss:1.3183364522998442, acc:0.5968063872255489
epoch:19, idx:600/5129, loss:1.2895124238139382, acc:0.610648918469218
epoch:19, idx:700/5129, loss:1.3037171953818767, acc:0.6077032810271041
epoch:19, idx:800/5129, loss:1.2959696318531007, acc:0.602996254681648
epoch:

epoch:20, idx:300/5129, loss:1.45313947114388, acc:0.584717607973422
epoch:20, idx:400/5129, loss:1.4505474629006034, acc:0.5685785536159601
epoch:20, idx:500/5129, loss:1.3414959283452668, acc:0.5948103792415169
epoch:20, idx:600/5129, loss:1.302140328276921, acc:0.6056572379367721
epoch:20, idx:700/5129, loss:1.3104618399559431, acc:0.6062767475035663
epoch:20, idx:800/5129, loss:1.3157888723400126, acc:0.6004993757802747
epoch:20, idx:900/5129, loss:1.3025942537706183, acc:0.607103218645949
epoch:20, idx:1000/5129, loss:1.305036355055623, acc:0.6063936063936064
epoch:20, idx:1100/5129, loss:1.2968688994531032, acc:0.6049046321525886
epoch:20, idx:1200/5129, loss:1.3025989155506066, acc:0.6061615320566195
epoch:20, idx:1300/5129, loss:1.3042632687255613, acc:0.6102997694081476
epoch:20, idx:1400/5129, loss:1.2707283413017716, acc:0.6202712348322627
epoch:20, idx:1500/5129, loss:1.2879350729792973, acc:0.6169220519653564
epoch:20, idx:1600/5129, loss:1.2851859127089316, acc:0.61711430

epoch:21, idx:1100/5129, loss:1.2974309376918869, acc:0.6103542234332425
epoch:21, idx:1200/5129, loss:1.2997255252610636, acc:0.6144879267277269
epoch:21, idx:1300/5129, loss:1.299003789142679, acc:0.6179861644888547
epoch:21, idx:1400/5129, loss:1.262804915571451, acc:0.6288365453247681
epoch:21, idx:1500/5129, loss:1.2735056360509696, acc:0.6269153897401732
epoch:21, idx:1600/5129, loss:1.2709520725366996, acc:0.6277326670830731
epoch:21, idx:1700/5129, loss:1.261713146042081, acc:0.6272780717225162
epoch:21, idx:1800/5129, loss:1.2514874168000971, acc:0.6302054414214325
epoch:21, idx:1900/5129, loss:1.2522876493119366, acc:0.6275644397685429
epoch:21, idx:2000/5129, loss:1.2523123784654442, acc:0.6251874062968515
epoch:21, idx:2100/5129, loss:1.2541250923895824, acc:0.6220847215611613
epoch:21, idx:2200/5129, loss:1.2494746714234624, acc:0.6238073602907769
epoch:21, idx:2300/5129, loss:1.2496509857731557, acc:0.625380269448066
epoch:21, idx:2400/5129, loss:1.253554774250452, acc:0.

epoch:22, idx:1900/5129, loss:1.274652397743123, acc:0.6354550236717517
epoch:22, idx:2000/5129, loss:1.2689469087047913, acc:0.6356821589205397
epoch:22, idx:2100/5129, loss:1.27093912123116, acc:0.6354117087101381
epoch:22, idx:2200/5129, loss:1.267229222038804, acc:0.6369831894593366
epoch:22, idx:2300/5129, loss:1.2678497424498945, acc:0.6384180790960452
epoch:22, idx:2400/5129, loss:1.270567241821846, acc:0.6347355268638067
epoch:22, idx:2500/5129, loss:1.277948364445635, acc:0.6337465013994402
epoch:22, idx:2600/5129, loss:1.2897055246385645, acc:0.6309111880046137
epoch:22, idx:2700/5129, loss:1.2878311027430331, acc:0.6305072195483155
epoch:22, idx:2800/5129, loss:1.284063976692945, acc:0.6301320956801142
epoch:22, idx:2900/5129, loss:1.2860917394781104, acc:0.6294381247845571
epoch:22, idx:3000/5129, loss:1.2981657828565878, acc:0.6284571809396867
epoch:22, idx:3100/5129, loss:1.2984216311277976, acc:0.6281844566268946
epoch:22, idx:3200/5129, loss:1.3025790029919657, acc:0.62

epoch:23, idx:2700/5129, loss:1.2931565225554678, acc:0.6227323213624584
epoch:23, idx:2800/5129, loss:1.294333087620204, acc:0.6219207425919314
epoch:23, idx:2900/5129, loss:1.2968550335997262, acc:0.6204756980351603
epoch:23, idx:3000/5129, loss:1.3083430012090331, acc:0.6201266244585139
epoch:23, idx:3100/5129, loss:1.306771341181501, acc:0.6214124475975492
epoch:23, idx:3200/5129, loss:1.312887499780999, acc:0.6188691034051859
epoch:23, idx:3300/5129, loss:1.3122184261277892, acc:0.6192063011208725
epoch:23, idx:3400/5129, loss:1.2981731053775845, acc:0.6236401058512202
epoch:23, idx:3500/5129, loss:1.2961743509716661, acc:0.6243930305626963
epoch:23, idx:3600/5129, loss:1.2871932142006295, acc:0.6262149402943626
epoch:23, idx:3700/5129, loss:1.2908450818221269, acc:0.6263172115644421
epoch:23, idx:3800/5129, loss:1.282946305972311, acc:0.6272033675348593
epoch:23, idx:3900/5129, loss:1.2875400352855426, acc:0.6257369905152524
epoch:23, idx:4000/5129, loss:1.2882203919288369, acc:0

epoch:24, idx:3400/5129, loss:1.3032592700986574, acc:0.6218759188473978
epoch:24, idx:3500/5129, loss:1.2982227306677865, acc:0.6246786632390745
epoch:24, idx:3600/5129, loss:1.28899323077905, acc:0.627048042210497
epoch:24, idx:3700/5129, loss:1.291652356924414, acc:0.6265874088084301
epoch:24, idx:3800/5129, loss:1.2827170698917467, acc:0.6277295448566167
epoch:24, idx:3900/5129, loss:1.287300544032926, acc:0.6267623686234299
epoch:24, idx:4000/5129, loss:1.2866786926262261, acc:0.6263434141464633
epoch:24, idx:4100/5129, loss:1.2900618286183392, acc:0.6264325774201415
epoch:24, idx:4200/5129, loss:1.2831668809705392, acc:0.6269935729588193
epoch:24, idx:4300/5129, loss:1.2783541084929584, acc:0.627993489886073
epoch:24, idx:4400/5129, loss:1.278407408070548, acc:0.6284935241990457
epoch:24, idx:4500/5129, loss:1.2717913345042082, acc:0.629193512552766
epoch:24, idx:4600/5129, loss:1.2737205542123622, acc:0.6276896326885459
epoch:24, idx:4700/5129, loss:1.2768659782178908, acc:0.626

epoch:25, idx:4200/5129, loss:1.2827679892557202, acc:0.6279457272078076
epoch:25, idx:4300/5129, loss:1.2794618642027364, acc:0.6286910020925366
epoch:25, idx:4400/5129, loss:1.2801740032854987, acc:0.6291751874573961
epoch:25, idx:4500/5129, loss:1.2721651442244672, acc:0.6311930682070651
epoch:25, idx:4600/5129, loss:1.2735498618788756, acc:0.6300804173005868
epoch:25, idx:4700/5129, loss:1.2773935184567387, acc:0.62965326526271
epoch:25, idx:4800/5129, loss:1.2725886995765447, acc:0.6296604873984587
epoch:25, idx:4900/5129, loss:1.2757140334404455, acc:0.629463374821465
epoch:25, idx:5000/5129, loss:1.2787444631196074, acc:0.6286742651469706
epoch:25, idx:5100/5129, loss:1.275517024821369, acc:0.6288962948441482
epoch:26, idx:799/43738, loss:0.08557836972875521, acc:0.76125
epoch:26, idx:1599/43738, loss:0.08314809171133675, acc:0.774375
epoch:26, idx:2399/43738, loss:0.08359731910633855, acc:0.7720833333333333
epoch:26, idx:3199/43738, loss:0.08584918272827054, acc:0.766875
epoch:

epoch:26, idx:5000/5129, loss:1.2927229169192278, acc:0.6280743851229754
epoch:26, idx:5100/5129, loss:1.2884516166828002, acc:0.6277200548911978
epoch:27, idx:799/43738, loss:0.07233596364210826, acc:0.805
epoch:27, idx:1599/43738, loss:0.07774412773171208, acc:0.79125
epoch:27, idx:2399/43738, loss:0.07706241062958724, acc:0.7916666666666666
epoch:27, idx:3199/43738, loss:0.07723759949018132, acc:0.7925
epoch:27, idx:3999/43738, loss:0.07848630557616706, acc:0.78975
epoch:27, idx:4799/43738, loss:0.08096730528738894, acc:0.7795833333333333
epoch:27, idx:5599/43738, loss:0.08157123045289025, acc:0.7769642857142857
epoch:27, idx:6399/43738, loss:0.08128193423523043, acc:0.77734375
epoch:27, idx:7199/43738, loss:0.08181795872243432, acc:0.7747222222222222
epoch:27, idx:7999/43738, loss:0.08221371719695161, acc:0.77425
epoch:27, idx:8799/43738, loss:0.08237324438010216, acc:0.772159090909091
epoch:27, idx:9599/43738, loss:0.0823893969206741, acc:0.7729166666666667
epoch:27, idx:10399/437

In [None]:
layer = model.linear_1[0]

In [None]:
model.train()
cuda = True
#model.embedding_layer.eval()
for idx, batch in enumerate(train_iter):
    score = model(batch)
    label = batch.label
    if cuda:
        label = label.cuda()
    score = score.transpose(0,1)      

    loss = criterion(score, label)

    optimizer.zero_grad()
    loss.backward()
    # optimizer.step()
    break