### 当前实验模型内容

1. use mentions
2. 63.3

3. add passage score

model_name |  param | dev_acc|
---| --- | ---
use mentions | lr=1e-3,hidden=50 | 63.3

In [1]:
import os
import torch
import torch.nn as nn
import torchtext
from tensorboardX import SummaryWriter
import random
import numpy as np

from torchtext.data import NestedField, Field, RawField
from model import *
from dataset import DataHandler
%load_ext autoreload

%autoreload 2
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

03/05/2019 21:03:42 - INFO - summarizer.preprocessing.cleaner -   'pattern' package not found; tag filters are not available for English


In [2]:
class Config:
    def __init__(self):
        self.hidden = 50
        self.embedding_dim = 300 + 100
        self.lr = 1e-4
        self.epochs = 30
        self.fix_length = None
        
        self.log_dir = './logs'
        self.model_name = 'CFC_um_ps'
        self.batch_size = 4
        self.train_data = './data/train_filter.pt'
        self.dev_data = './data/dev_filter.pt'
        
        self.word_vocab = './data/glove_vocab.pt'
        self.charNGram_vocab = './data/charNGram_vocab.pt'
        
        self.dropout = 0.2
        self.seed = 1023
        
config = Config()
device = torch.device("cuda:0")


In [3]:
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)

In [4]:
save_path = config.model_name  + '_lr_'+ str(config.lr)+ '__hidden__' + str(config.hidden) \
            + '_batchsize_' + str(config.batch_size) +  '_p'+ str(config.dropout)
save_path = os.path.join(config.log_dir, save_path)   
print(save_path)
config.save_path = save_path

./logs/CFC_um_ps_lr_0.0001__hidden__50_batchsize_4_p0.2


### Define Fileds

In [5]:
word_field = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True) # query
multi_word_field = NestedField(word_field) 

word_field_sup = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True, fix_length=config.fix_length)
multi_word_field_sup = NestedField(word_field_sup) 

charNGram_field = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True) # query
multi_charNGram_field = NestedField(charNGram_field) 

charNGram_field_sup = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True, fix_length=config.fix_length)
multi_charNGram_field_sup = NestedField(charNGram_field_sup) 

raw = RawField()
raw.is_target = False

label_field = Field(sequential=False, is_target=True, use_vocab=False)

dict_field = {
    'id': ('id', raw),
    'supports': [('s_glove', multi_word_field_sup), ('s_charNGram', multi_charNGram_field_sup)],
    'query': [('q_glove', word_field), ('q_charNGram', charNGram_field)],
    'candidates': [('c_glove', multi_word_field), ('c_charNGram', multi_charNGram_field)],
    'label': ('label', label_field),
    'mentions': ('mentions', raw),
    'para_label': ('para_label', raw)
}

In [6]:
data_handler = DataHandler(config.train_data, config.dev_data, dict_field)

# torch.save(data_handler.trainset.examples, './data/train_example.pt')
# torch.save(data_handler.valset.examples, './data/dev_example.pt')

load examples.pt  :./data/train_filter.pt, ./data/dev_filter.pt


In [7]:
from tqdm import tqdm

def add_mentions(examples):

    for example in tqdm(examples):
        candidates = example.c_glove
        supports = example.s_glove


        all_mentions = []

        for candidate in candidates:
            mentions = []
            c = ' '.join(candidate)
            for idx, support in enumerate(supports):

                for i in range(len(support)):
                    token = support[i]
                    if token == candidate[0]:
                        s = ' '.join(support[i:i+len(candidate)])
                        if s == c:
                            mentions.append([idx, i, i+len(candidate)])
            all_mentions.append(mentions)
            
        example.mentions = all_mentions
        
def add_para_label(examples):
    filter_examples=  []
    for example in tqdm(examples):
        candidates = example.c_glove
        supports = example.s_glove    

        label = example.label
        mentions = example.mentions
        answser_mentions = mentions[label]
        if len(answser_mentions) != 0:
            para_label = [0]*len(supports)
            for mentions in answser_mentions:
                para_label[mentions[0]] = 1
            example.para_label = para_label
            filter_examples.append(example)
    print(f'before filter: {len(examples)}, after:{len(filter_examples)}')
    return filter_examples

#add_mentions(data_handler.valset.examples)
#add_mentions(data_handler.trainset.examples)

#train_filter_examples  = add_para_label(data_handler.trainset.examples)
#dev_filter_examples  = add_para_label(data_handler.valset.examples)

#torch.save(train_filter_examples, './data/train_filter.pt')
#torch.save(dev_filter_examples, './data/dev_filter.pt')        

In [8]:
#add_mentions(data_handler.valset.examples)
#add_mentions(data_handler.trainset.examples)

### Build Vocab

In [9]:
if config.charNGram_vocab is not None:
    charNGram_vocab = torch.load(config.charNGram_vocab)
    charNGram_field_sup.vocab = charNGram_vocab
else:
    charNGram_field_sup.build_vocab(data_handler.trainset, data_handler.valset, 
                                          vectors=torchtext.vocab.CharNGram())

if config.word_vocab is not None:
    word_vocab = torch.load(config.word_vocab)
    word_field_sup.vocab = word_vocab
else:
    word_field_sup.build_vocab(data_handler.trainset, data_handler.valset, 
                                 vectors=torchtext.vocab.GloVe(dim=300,name='6B') )

word_field.vocab = word_field_sup.vocab
charNGram_field.vocab = charNGram_field_sup.vocab

# torch.save(word_field.vocab, './data/glove_vocab.pt')
# torch.save(charNGram_field.vocab, './data/charNGram_vocab.pt')

### Get data_iter

In [10]:
train_iter = data_handler.get_train_iter(batch_size=config.batch_size)
val_iter = data_handler.get_val_iter(batch_size=config.batch_size)

In [11]:
for idx, batch in enumerate(val_iter):
    break
batch


[torchtext.data.batch.Batch of size 4]
	[.id]:['WH_dev_0', 'WH_dev_1', 'WH_dev_2', 'WH_dev_3']
	[.s_glove]:[torch.LongTensor of size 4x15x292]
	[.s_charNGram]:[torch.LongTensor of size 4x15x292]
	[.q_glove]:[torch.LongTensor of size 4x11]
	[.q_charNGram]:[torch.LongTensor of size 4x11]
	[.c_glove]:[torch.LongTensor of size 4x18x4]
	[.c_charNGram]:[torch.LongTensor of size 4x18x4]
	[.label]:[torch.LongTensor of size 4]
	[.mentions]:[[[[6, 145, 146], [6, 173, 174], [7, 78, 79]], [[3, 50, 53], [5, 28, 31], [13, 1, 4]], [[6, 135, 136], [6, 218, 219], [6, 261, 262], [8, 45, 46], [12, 98, 99]], [[0, 2, 4], [7, 1, 3], [13, 64, 66], [13, 69, 71]], [[0, 36, 38], [10, 1, 3], [13, 75, 77]], [[0, 14, 15], [1, 63, 64], [1, 138, 139], [1, 186, 187], [1, 238, 239], [2, 128, 129], [9, 8, 9], [9, 43, 44], [10, 19, 20], [10, 34, 35], [11, 37, 38], [11, 79, 80], [13, 56, 57]], [[7, 37, 40]], [[6, 180, 181], [12, 101, 102]], [[12, 96, 97]], [[8, 43, 46]], [[6, 169, 172]], [[6, 179, 181]], [[7, 38, 40]], 

In [12]:
def get_para_label(batch):
    para_size = batch.s_glove.size(1)
    results = []
    for label in batch.para_label:
        padding = [0]*(para_size - len(label))
        n_label = label[:]
        n_label += padding
        results.append(n_label)
    results = torch.tensor(results, dtype=torch.long)
    return results

### Define Model

In [13]:
class SimpleQANet(nn.Module):
    
    def __init__(self, config, word_vectors, charNGram_vectors, device):
        super(SimpleQANet, self).__init__()
        self.config = config
        self.device = device
        
        self.embedding_layer = EmbeddingLayer(word_vectors, charNGram_vectors)
        

        self.rnn = EncoderRNN(config.embedding_dim, config.hidden, 1, True, True, config.dropout, False)
                
            
        self.co_att = CoAttention(config.hidden*2, att_type=2, dropout=config.dropout)
        
        self.linear_1 = nn.Sequential(
                        nn.Linear(config.hidden*4, config.hidden),
                        nn.ReLU()
                    )        
        self.rnn2 =  EncoderRNN(config.hidden, config.hidden, 1, True, True, config.dropout, False)
        
        self.word_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        self.word_att_q = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        self.p_score = BilinearSeqAttn(config.hidden*2, config.hidden*2, identity=False, dropout=config.dropout)
        
        self.pass_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        
        self.c_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        
        self.mention_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        
        
        #self.fusion = FusionLayer(config.hidden*2, dropout=config.dropout)
        self.max_pooling = PoolingLayer()     
        
        self.fc = nn.Linear(config.hidden*2, config.hidden*4)
        
        self.to(device)
        
    def forward(self, batch, return_label = True):
        if type(batch.q_glove) is tuple:
            q_glove, _ = batch.q_glove
            q_charNGram, _ = batch.q_charNGram
        else:
            q_glove = batch.q_glove
            q_charNGram = batch.q_charNGram            
        
        s_glove = batch.s_glove
        s_charNGram = batch.s_charNGram
        
        c_glove = batch.c_glove
        c_charNGram = batch.c_charNGram
        
        q_glove = q_glove.to(self.device)
        q_charNGram = q_charNGram.to(self.device)

        s_glove = s_glove.to(self.device)
        s_charNGram = s_charNGram.to(self.device)

        c_glove = c_glove.to(self.device)
        c_charNGram = c_charNGram.to(self.device)
        
        ### Embedding and Encoder
        
        q_out = self.embedding_layer(q_glove, q_charNGram)
        s_out = self.embedding_layer(s_glove, s_charNGram,)
        c_out = self.embedding_layer(c_glove, c_charNGram)
        
        batch_size=  s_out.size(0)
        
        s_len = s_out.size(1)
        c_len = c_out.size(1)
        
        s_word_len = s_out.size(2)
        c_word_len = c_out.size(2)
        
        hidden = s_out.size(-1)
        
        s_out = s_out.view(batch_size*s_len, s_word_len, hidden).contiguous()
        c_out = c_out.view(batch_size*c_len, c_word_len, hidden).contiguous()
        
        q_out = self.rnn(q_out)
        c_out = self.rnn(c_out)
        s_out = self.rnn(s_out)
        
        # Attention
        
        q_word_len = q_out.size(1)
        q_out_expand = q_out.unsqueeze(1).expand(batch_size, s_len, q_word_len, q_out.size(-1)).contiguous()
        q_out_expand = q_out_expand.view(batch_size*s_len, q_word_len, q_out.size(-1)).contiguous()
        
        s_out_att, q_out_att = self.co_att(s_out, q_out_expand)
        #S_s = self.fusion(s_out, s_out_att)
        #S_q = self.fusion(q_out, q_out_att)
        
        S_s = self.linear_1(s_out_att)
        S_s = self.rnn2(S_s)
        
        
        
        batch_c_m = []
        for i in range(batch_size):
            # get mention embedding
            mentions = batch.mentions[i]
            c_ms = torch.zeros(c_len, s_len, s_out.size(-1))
            for idx, c_mention in enumerate(mentions):
                c_m_dict = {}
                for mention in c_mention:
                    m = s_out[i*s_len + mention[0]][mention[1]:mention[2]]
                    m = self.max_pooling(m.unsqueeze(0)).squeeze()
                    if mention[0] not in c_m_dict:
                        c_m_dict[mention[0]] = []
                    c_m_dict[mention[0]].append(m)
                c_m = torch.zeros(s_len, s_out.size(-1))
                for key in c_m_dict:
                    for m in c_m_dict[key]:
                        c_m[key] += m.cpu()
                    c_m[key] /= len(c_m_dict[key])
                c_ms[idx] = c_m
            batch_c_m.append(c_ms)
        batch_c_m = torch.stack(batch_c_m)
        batch_c_m = batch_c_m.to(self.device)
        batch_c_m = batch_c_m.view(batch_size*c_len, s_len, -1)
        batch_c_m = self.mention_att(batch_c_m)
        batch_c_m = batch_c_m.view(batch_size, c_len, -1)
        

        
        C_s = self.word_att(S_s)
        C_q = self.word_att_q(q_out)
        
        C_s = C_s.view(batch_size, s_len, -1)
        
        P_scores = self.p_score(C_s, C_q)
        
        C_s = self.pass_att(C_s) 

        
        C_c = self.c_att(c_out)        
        C_c = C_c.view(batch_size, c_len, -1)
        
        C_c = torch.cat([C_c, batch_c_m],-1)
        
        C_s = torch.tanh(self.fc(C_s))
        
        score = torch.bmm(C_c, C_s.unsqueeze(-1))
        score = score.squeeze(-1)
        if return_label:
            label = batch.label.to(self.device)
            P_label = get_para_label(batch)
            P_label = P_label.to(device)
            return score, P_scores, label, P_label
        return score, P_scores

#### test model

In [14]:
model = SimpleQANet(config, word_field.vocab.vectors, charNGram_field.vocab.vectors, device)
#score,P_score, label, P_label = model(batch)
#print(score, label, P_score, P_label)
#print(score.shape, label.shape, P_score.shape, P_label.shape)

In [15]:
from utils import AverageMeter

def train(epoch, data_iter, model, criterion, criterion_bce, optimizer, batch_size=1, joint_begin=-1):
    losses = AverageMeter()
    acces = AverageMeter()
    model.train()
    #model.embedding_layer.eval()
    for idx, batch in enumerate(data_iter):
        score, P_score, label, P_label = model(batch)

        loss1 = criterion(score, label)
        loss_p = criterion_bce(P_score, P_label.float())
        if epoch > joint_begin:
            loss = loss1 + loss_p     
        else:
            loss = loss_p

        loss = loss / batch_size
        loss.backward()
        if (idx+1)%batch_size == 0 :
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)            
            optimizer.step()
            optimizer.zero_grad()        

        losses.update(loss.item()*batch_size)
        
        pred = score.argmax(1)
        acc = pred.eq(label).sum().item()  / pred.size(0)
        acces.update(acc)
        if (idx+1) % (batch_size*100) == 0:
            print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')
    return losses.avg, acces.avg

def val(epoch, data_iter, model, criterion, criterion_bce, joint_begin=-1):
    losses = AverageMeter()
    acces = AverageMeter()
    model.eval()
    for idx, batch in enumerate(data_iter):
        with torch.no_grad():
            score, P_score, label, P_label = model(batch)
                    
        loss1 = criterion(score, label)
        loss_p = criterion_bce(P_score, P_label.float())
        if epoch > joint_begin:
            loss = loss1 + loss_p     
        else:
            loss = loss_p
        
        losses.update(loss.item())
        
        pred = score.argmax(1)
        acc = pred.eq(label).sumbest
            print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')
    return losses.avg, acces.avg

In [16]:
optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                             lr=config.lr)

criterion = nn.CrossEntropyLoss()
criterion_bce = nn.BCEWithLogitsLoss()

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
#train(0, train_iter, model, criterion, optimizer, batch_size=config.batch_size)
# val(0, val_iter, model,criterion)

In [17]:
if not os.path.exists(config.save_path):
    os.makedirs(config.save_path)
writer = SummaryWriter(config.save_path)

best_acc = 0.0
for epoch in range(config.epochs):
    scheduler.step()
    train_loss, train_acc = train(epoch, train_iter, model, criterion, criterion_bce, optimizer, 
                                     1)
    val_loss, val_acc = val(epoch, val_iter, model, criterion, criterion_bce)
    writer.add_scalar('train_loss', train_loss, epoch+1)
    writer.add_scalar('val_loss', val_loss, epoch+1)
    writer.add_scalar('train_acc', train_acc, epoch+1)
    writer.add_scalar('val_acc', val_acc, epoch+1)
    
    state = {
        'val_acc': val_acc,
        'train_acc': train_acc,
        'epoch': epoch,
        'model': model.state_dict()
    }
    #torch.save(state, os.path.join(config.save_path,'lastest.pth'))
    if val_acc > best_acc:
        best_acc = val_acc
        #torch.save(state, os.path.join(save_path, 'best.pth'))

  alphas = self.softmax(alphas)  # (bsz, sent_len)


epoch:0, idx:99/10845, loss:3.8411771512031554, acc:0.1875
epoch:0, idx:199/10845, loss:3.464788509607315, acc:0.22375
epoch:0, idx:299/10845, loss:3.2804538949330646, acc:0.24666666666666667
epoch:0, idx:399/10845, loss:3.129928788244724, acc:0.26375
epoch:0, idx:499/10845, loss:3.036080674648285, acc:0.2775
epoch:0, idx:599/10845, loss:2.951693090200424, acc:0.28708333333333336
epoch:0, idx:699/10845, loss:2.903557346037456, acc:0.2932142857142857
epoch:0, idx:799/10845, loss:2.8517998991906643, acc:0.300625
epoch:0, idx:899/10845, loss:2.799626319143507, acc:0.3080555555555556
epoch:0, idx:999/10845, loss:2.7602129591703415, acc:0.315
epoch:0, idx:1099/10845, loss:2.739209978688847, acc:0.31727272727272726
epoch:0, idx:1199/10845, loss:2.7170045927166937, acc:0.31833333333333336
epoch:0, idx:1299/10845, loss:2.691451080579024, acc:0.3219230769230769
epoch:0, idx:1399/10845, loss:2.6678093465736934, acc:0.3267857142857143
epoch:0, idx:1499/10845, loss:2.644142016331355, acc:0.3303333

epoch:0, idx:900/1275, loss:1.9913261348875726, acc:0.43784683684794673
epoch:0, idx:1000/1275, loss:2.0048809521324507, acc:0.4373126873126873
epoch:0, idx:1100/1275, loss:2.002054981271534, acc:0.4391462306993642
epoch:0, idx:1200/1275, loss:2.0064881068284466, acc:0.4377601998334721
epoch:1, idx:99/10845, loss:1.9928994929790498, acc:0.42
epoch:1, idx:199/10845, loss:1.9036356392502785, acc:0.47
epoch:1, idx:299/10845, loss:1.925157831509908, acc:0.4666666666666667
epoch:1, idx:399/10845, loss:1.9418779096007348, acc:0.46125
epoch:1, idx:499/10845, loss:1.9272917999625205, acc:0.4665
epoch:1, idx:599/10845, loss:1.9131777634720009, acc:0.4675
epoch:1, idx:699/10845, loss:1.9142278436677798, acc:0.4660714285714286
epoch:1, idx:799/10845, loss:1.9173747937381267, acc:0.4653125
epoch:1, idx:899/10845, loss:1.9155907676617305, acc:0.4666666666666667
epoch:1, idx:999/10845, loss:1.9038710365891456, acc:0.47075
epoch:1, idx:1099/10845, loss:1.9090830592133783, acc:0.46863636363636363
epoc

epoch:1, idx:400/1275, loss:1.76406916418575, acc:0.506857855361596
epoch:1, idx:500/1275, loss:1.7602203544266448, acc:0.5054890219560878
epoch:1, idx:600/1275, loss:1.7573173126444444, acc:0.5045757071547421
epoch:1, idx:700/1275, loss:1.7476601326006456, acc:0.507132667617689
epoch:1, idx:800/1275, loss:1.7648820746108684, acc:0.49812734082397003
epoch:1, idx:900/1275, loss:1.7603050737613843, acc:0.4997225305216426
epoch:1, idx:1000/1275, loss:1.7751759862030423, acc:0.4985014985014985
epoch:1, idx:1100/1275, loss:1.7698245346546173, acc:0.49909173478655766
epoch:1, idx:1200/1275, loss:1.773644283922586, acc:0.49646128226477937
epoch:2, idx:99/10845, loss:1.745439372062683, acc:0.4925
epoch:2, idx:199/10845, loss:1.7044200116395951, acc:0.51875
epoch:2, idx:299/10845, loss:1.6579422881205876, acc:0.5216666666666666
epoch:2, idx:399/10845, loss:1.6883791287243366, acc:0.518125
epoch:2, idx:499/10845, loss:1.7073645687103272, acc:0.5165
epoch:2, idx:599/10845, loss:1.704043060441812,

epoch:2, idx:0/1275, loss:1.5896270275115967, acc:0.5
epoch:2, idx:100/1275, loss:1.7734209139748376, acc:0.5024752475247525
epoch:2, idx:200/1275, loss:1.7134486166991998, acc:0.5211442786069652
epoch:2, idx:300/1275, loss:1.706498766658314, acc:0.5282392026578073
epoch:2, idx:400/1275, loss:1.6984842574209942, acc:0.529925187032419
epoch:2, idx:500/1275, loss:1.6933465927184936, acc:0.530439121756487
epoch:2, idx:600/1275, loss:1.694623831603769, acc:0.5262063227953411
epoch:2, idx:700/1275, loss:1.6859435581106603, acc:0.5278174037089871
epoch:2, idx:800/1275, loss:1.7028886073091056, acc:0.5205992509363296
epoch:2, idx:900/1275, loss:1.6948982665065127, acc:0.5208102108768036
epoch:2, idx:1000/1275, loss:1.7057320207923086, acc:0.5202297702297702
epoch:2, idx:1100/1275, loss:1.7023425898315905, acc:0.5197547683923706
epoch:2, idx:1200/1275, loss:1.7065847133170755, acc:0.5172772689425479
epoch:3, idx:99/10845, loss:1.6549401676654816, acc:0.55
epoch:3, idx:199/10845, loss:1.6347173

epoch:3, idx:10499/10845, loss:1.5988734224140644, acc:0.5525476190476191
epoch:3, idx:10599/10845, loss:1.598963531591701, acc:0.5524292452830188
epoch:3, idx:10699/10845, loss:1.5989348750930523, acc:0.5525934579439252
epoch:3, idx:10799/10845, loss:1.6002635232510942, acc:0.5522685185185185
epoch:3, idx:0/1275, loss:1.3900163173675537, acc:0.5
epoch:3, idx:100/1275, loss:1.6721746228118934, acc:0.5470297029702971
epoch:3, idx:200/1275, loss:1.6060120770290716, acc:0.5572139303482587
epoch:3, idx:300/1275, loss:1.605477677033193, acc:0.5681063122923588
epoch:3, idx:400/1275, loss:1.5955139449855633, acc:0.5716957605985037
epoch:3, idx:500/1275, loss:1.5895112309151305, acc:0.5703592814371258
epoch:3, idx:600/1275, loss:1.5909460150263275, acc:0.5694675540765392
epoch:3, idx:700/1275, loss:1.5842949428759017, acc:0.5656205420827389
epoch:3, idx:800/1275, loss:1.601356842135371, acc:0.5571161048689138
epoch:3, idx:900/1275, loss:1.5932833192152664, acc:0.5591009988901221
epoch:3, idx:1

epoch:4, idx:10099/10845, loss:1.5231029774934643, acc:0.5769554455445545
epoch:4, idx:10199/10845, loss:1.5231349433366868, acc:0.576764705882353
epoch:4, idx:10299/10845, loss:1.523532438844587, acc:0.5765291262135922
epoch:4, idx:10399/10845, loss:1.523003983155896, acc:0.5766346153846154
epoch:4, idx:10499/10845, loss:1.5228146801214844, acc:0.5768095238095238
epoch:4, idx:10599/10845, loss:1.5231482290947493, acc:0.5766509433962265
epoch:4, idx:10699/10845, loss:1.5246711483994655, acc:0.5761448598130842
epoch:4, idx:10799/10845, loss:1.5259804879963674, acc:0.5756712962962963
epoch:4, idx:0/1275, loss:1.3715734481811523, acc:0.5
epoch:4, idx:100/1275, loss:1.6279836016716343, acc:0.5544554455445545
epoch:4, idx:200/1275, loss:1.5554469217411915, acc:0.5733830845771144
epoch:4, idx:300/1275, loss:1.5432286704101437, acc:0.5888704318936877
epoch:4, idx:400/1275, loss:1.5340400534377727, acc:0.5885286783042394
epoch:4, idx:500/1275, loss:1.5324034396997708, acc:0.5858283433133733
ep

epoch:5, idx:9699/10845, loss:1.477752209167505, acc:0.5914175257731958
epoch:5, idx:9799/10845, loss:1.47702438417442, acc:0.5913775510204081
epoch:5, idx:9899/10845, loss:1.476997117574769, acc:0.5915151515151515
epoch:5, idx:9999/10845, loss:1.4767453949153424, acc:0.591725
epoch:5, idx:10099/10845, loss:1.4760881889205757, acc:0.591980198019802
epoch:5, idx:10199/10845, loss:1.4769637961408086, acc:0.5917156862745098
epoch:5, idx:10299/10845, loss:1.4768223193620593, acc:0.5916990291262136
epoch:5, idx:10399/10845, loss:1.4758797963622672, acc:0.5919711538461538
epoch:5, idx:10499/10845, loss:1.4756296801453546, acc:0.5921666666666666
epoch:5, idx:10599/10845, loss:1.475504029632177, acc:0.5921462264150943
epoch:5, idx:10699/10845, loss:1.4753825426240947, acc:0.5921261682242991
epoch:5, idx:10799/10845, loss:1.4747279924457823, acc:0.5923379629629629
epoch:5, idx:0/1275, loss:1.3215959072113037, acc:0.5
epoch:5, idx:100/1275, loss:1.603709718083391, acc:0.556930693069307
epoch:5, 

epoch:6, idx:9299/10845, loss:1.4458957571816702, acc:0.6033333333333334
epoch:6, idx:9399/10845, loss:1.4456759495842963, acc:0.6035106382978723
epoch:6, idx:9499/10845, loss:1.44554032472874, acc:0.6033421052631579
epoch:6, idx:9599/10845, loss:1.4464815134027351, acc:0.6028645833333334
epoch:6, idx:9699/10845, loss:1.446071692184382, acc:0.6028865979381444
epoch:6, idx:9799/10845, loss:1.4465116414230088, acc:0.602780612244898
epoch:6, idx:9899/10845, loss:1.445217701246341, acc:0.6030555555555556
epoch:6, idx:9999/10845, loss:1.4448990812763571, acc:0.603025
epoch:6, idx:10099/10845, loss:1.445035984126648, acc:0.6029950495049505
epoch:6, idx:10199/10845, loss:1.444517986381463, acc:0.6031372549019608
epoch:6, idx:10299/10845, loss:1.4442837201319274, acc:0.603131067961165
epoch:6, idx:10399/10845, loss:1.4436339725081164, acc:0.6032211538461538
epoch:6, idx:10499/10845, loss:1.4429726869123323, acc:0.6035238095238096
epoch:6, idx:10599/10845, loss:1.4429900392904034, acc:0.6035613

epoch:7, idx:8899/10845, loss:1.4078795691244723, acc:0.6133988764044944
epoch:7, idx:8999/10845, loss:1.4077035947433776, acc:0.6131666666666666
epoch:7, idx:9099/10845, loss:1.4073732956842735, acc:0.6131868131868132
epoch:7, idx:9199/10845, loss:1.4069999290767896, acc:0.613179347826087
epoch:7, idx:9299/10845, loss:1.4072456215874802, acc:0.6131720430107527
epoch:7, idx:9399/10845, loss:1.4069853658474822, acc:0.6132446808510639
epoch:7, idx:9499/10845, loss:1.406722707430783, acc:0.6131578947368421
epoch:7, idx:9599/10845, loss:1.4076035601606902, acc:0.613125
epoch:7, idx:9699/10845, loss:1.407519941253914, acc:0.6131185567010309
epoch:7, idx:9799/10845, loss:1.406891323979564, acc:0.6133418367346939
epoch:7, idx:9899/10845, loss:1.4062484727370919, acc:0.6136868686868687
epoch:7, idx:9999/10845, loss:1.40552318719998, acc:0.613675
epoch:7, idx:10099/10845, loss:1.4051454158060916, acc:0.6136881188118812
epoch:7, idx:10199/10845, loss:1.404762982349916, acc:0.61375
epoch:7, idx:1

epoch:8, idx:8499/10845, loss:1.3795397469489012, acc:0.6202058823529412
epoch:8, idx:8599/10845, loss:1.3797438970918572, acc:0.6199418604651162
epoch:8, idx:8699/10845, loss:1.379325658589944, acc:0.6204885057471264
epoch:8, idx:8799/10845, loss:1.3792166417142884, acc:0.6207102272727273
epoch:8, idx:8899/10845, loss:1.3781140901181805, acc:0.6208426966292134
epoch:8, idx:8999/10845, loss:1.3781528793954188, acc:0.6210833333333333
epoch:8, idx:9099/10845, loss:1.378487161314422, acc:0.6210714285714286
epoch:8, idx:9199/10845, loss:1.3788497962268151, acc:0.6210054347826087
epoch:8, idx:9299/10845, loss:1.3775588268057632, acc:0.621505376344086
epoch:8, idx:9399/10845, loss:1.3761272666302133, acc:0.621968085106383
epoch:8, idx:9499/10845, loss:1.376824243388678, acc:0.6217368421052631
epoch:8, idx:9599/10845, loss:1.3769467697261522, acc:0.6215885416666667
epoch:8, idx:9699/10845, loss:1.3773855012163674, acc:0.6216752577319588
epoch:8, idx:9799/10845, loss:1.37834007977527, acc:0.62

epoch:9, idx:7999/10845, loss:1.350485333532095, acc:0.635625
epoch:9, idx:8099/10845, loss:1.3508227980063285, acc:0.6356172839506172
epoch:9, idx:8199/10845, loss:1.3501706202528099, acc:0.6357926829268292
epoch:9, idx:8299/10845, loss:1.3496352511051908, acc:0.6355120481927711
epoch:9, idx:8399/10845, loss:1.3503306721008959, acc:0.6355654761904762
epoch:9, idx:8499/10845, loss:1.3494519766358768, acc:0.6361470588235294
epoch:9, idx:8599/10845, loss:1.3488406608149757, acc:0.6362209302325581
epoch:9, idx:8699/10845, loss:1.3487198035059305, acc:0.6363218390804598
epoch:9, idx:8799/10845, loss:1.3481943381780928, acc:0.6363068181818182
epoch:9, idx:8899/10845, loss:1.3474206496323093, acc:0.6364044943820225
epoch:9, idx:8999/10845, loss:1.3473291164769066, acc:0.6367777777777778
epoch:9, idx:9099/10845, loss:1.3475174715951248, acc:0.6367307692307692
epoch:9, idx:9199/10845, loss:1.3467656877928453, acc:0.6370652173913044
epoch:9, idx:9299/10845, loss:1.3464229386532178, acc:0.636989

epoch:10, idx:7499/10845, loss:1.3271073053290448, acc:0.6399333333333334
epoch:10, idx:7599/10845, loss:1.327545882797751, acc:0.6394736842105263
epoch:10, idx:7699/10845, loss:1.3261526543685755, acc:0.6400974025974026
epoch:10, idx:7799/10845, loss:1.3260451615554019, acc:0.6403846153846153
epoch:10, idx:7899/10845, loss:1.3249749927175574, acc:0.6408544303797469
epoch:10, idx:7999/10845, loss:1.3255185186835006, acc:0.64075
epoch:10, idx:8099/10845, loss:1.325393585363104, acc:0.640679012345679
epoch:10, idx:8199/10845, loss:1.3247786509654507, acc:0.640640243902439
epoch:10, idx:8299/10845, loss:1.3241442821138953, acc:0.640933734939759
epoch:10, idx:8399/10845, loss:1.3236239662225402, acc:0.6410714285714286
epoch:10, idx:8499/10845, loss:1.3249660317959155, acc:0.6405882352941177
epoch:10, idx:8599/10845, loss:1.3255150274993028, acc:0.6403488372093024
epoch:10, idx:8699/10845, loss:1.3243888653360907, acc:0.6410057471264368
epoch:10, idx:8799/10845, loss:1.3259308532600038, acc

epoch:11, idx:6899/10845, loss:1.3062632197012072, acc:0.645036231884058
epoch:11, idx:6999/10845, loss:1.3044251596203873, acc:0.6454642857142857
epoch:11, idx:7099/10845, loss:1.3043189216520585, acc:0.6455985915492958
epoch:11, idx:7199/10845, loss:1.3023597077839077, acc:0.6459027777777778
epoch:11, idx:7299/10845, loss:1.3033530276666765, acc:0.645513698630137
epoch:11, idx:7399/10845, loss:1.3016788985096925, acc:0.6461486486486486
epoch:11, idx:7499/10845, loss:1.3011735202034314, acc:0.6460333333333333
epoch:11, idx:7599/10845, loss:1.3003503361246302, acc:0.6460855263157895
epoch:11, idx:7699/10845, loss:1.3005827730448989, acc:0.6457467532467532
epoch:11, idx:7799/10845, loss:1.3011983821598383, acc:0.6454166666666666
epoch:11, idx:7899/10845, loss:1.3018570911903171, acc:0.6451582278481013
epoch:11, idx:7999/10845, loss:1.3029215566720813, acc:0.644875
epoch:11, idx:8099/10845, loss:1.3020521417168187, acc:0.6451851851851852
epoch:11, idx:8199/10845, loss:1.3023366956794407,

epoch:12, idx:6299/10845, loss:1.2789878089856062, acc:0.6577777777777778
epoch:12, idx:6399/10845, loss:1.2785559123021084, acc:0.6578125
epoch:12, idx:6499/10845, loss:1.2807293620671216, acc:0.6572692307692307
epoch:12, idx:6599/10845, loss:1.2798377599126913, acc:0.657689393939394
epoch:12, idx:6699/10845, loss:1.2786466997250248, acc:0.6579850746268656
epoch:12, idx:6799/10845, loss:1.2792360864020884, acc:0.6573897058823529
epoch:12, idx:6899/10845, loss:1.2806382985467064, acc:0.6569927536231884
epoch:12, idx:6999/10845, loss:1.2813736411066992, acc:0.6564285714285715
epoch:12, idx:7099/10845, loss:1.280404529379497, acc:0.6563028169014085
epoch:12, idx:7199/10845, loss:1.2812421393259945, acc:0.6560069444444444
epoch:12, idx:7299/10845, loss:1.2807098072042613, acc:0.6560273972602739
epoch:12, idx:7399/10845, loss:1.2822752249049576, acc:0.6555405405405406
epoch:12, idx:7499/10845, loss:1.2828785928795734, acc:0.6557333333333333
epoch:12, idx:7599/10845, loss:1.2846098709822094

epoch:13, idx:5699/10845, loss:1.2678324347773664, acc:0.66
epoch:13, idx:5799/10845, loss:1.267673541789168, acc:0.6595258620689655
epoch:13, idx:5899/10845, loss:1.267351473942399, acc:0.6600423728813559
epoch:13, idx:5999/10845, loss:1.2676044406977793, acc:0.6597083333333333
epoch:13, idx:6099/10845, loss:1.2672042783691746, acc:0.6597131147540983
epoch:13, idx:6199/10845, loss:1.2656326070343775, acc:0.6597983870967742
epoch:13, idx:6299/10845, loss:1.2658323594040815, acc:0.6599206349206349
epoch:13, idx:6399/10845, loss:1.2657625026779715, acc:0.6601953125
epoch:13, idx:6499/10845, loss:1.2657131924881384, acc:0.6600769230769231
epoch:13, idx:6599/10845, loss:1.2659915204131693, acc:0.6602651515151515
epoch:13, idx:6699/10845, loss:1.2646312861578233, acc:0.6607462686567164
epoch:13, idx:6799/10845, loss:1.2629549848803263, acc:0.6611397058823529
epoch:13, idx:6899/10845, loss:1.2636703415636135, acc:0.6609057971014493
epoch:13, idx:6999/10845, loss:1.2625867079613464, acc:0.660

epoch:14, idx:5099/10845, loss:1.2571497711597703, acc:0.6650980392156862
epoch:14, idx:5199/10845, loss:1.2583015524681944, acc:0.665
epoch:14, idx:5299/10845, loss:1.2583598214584701, acc:0.6652358490566038
epoch:14, idx:5399/10845, loss:1.2586170980913771, acc:0.6647685185185185
epoch:14, idx:5499/10845, loss:1.2591809240335767, acc:0.6645454545454546
epoch:14, idx:5599/10845, loss:1.2583690674629595, acc:0.6646428571428571
epoch:14, idx:5699/10845, loss:1.2583889827811927, acc:0.665219298245614
epoch:14, idx:5799/10845, loss:1.2573501869111225, acc:0.6648275862068965
epoch:14, idx:5899/10845, loss:1.2571033654614525, acc:0.664957627118644
epoch:14, idx:5999/10845, loss:1.258705599324157, acc:0.6645
epoch:14, idx:6099/10845, loss:1.2593482288412872, acc:0.6645081967213115
epoch:14, idx:6199/10845, loss:1.2594286282204332, acc:0.6641935483870968
epoch:14, idx:6299/10845, loss:1.2607627652053321, acc:0.6641666666666667
epoch:14, idx:6399/10845, loss:1.2629184308566619, acc:0.663554687

epoch:15, idx:4599/10845, loss:1.2263258465185114, acc:0.6759782608695653
epoch:15, idx:4699/10845, loss:1.2256603272798214, acc:0.6756914893617021
epoch:15, idx:4799/10845, loss:1.2290559295751153, acc:0.6751041666666666
epoch:15, idx:4899/10845, loss:1.2261484774308544, acc:0.6758163265306123
epoch:15, idx:4999/10845, loss:1.2266608113825321, acc:0.67565
epoch:15, idx:5099/10845, loss:1.227231034887772, acc:0.6748529411764705
epoch:15, idx:5199/10845, loss:1.227221727617658, acc:0.6746634615384616
epoch:15, idx:5299/10845, loss:1.2258949515335964, acc:0.6751415094339622
epoch:15, idx:5399/10845, loss:1.2267123907156012, acc:0.6747685185185185
epoch:15, idx:5499/10845, loss:1.2274264864989302, acc:0.6744545454545454
epoch:15, idx:5599/10845, loss:1.2281623742210546, acc:0.6737053571428572
epoch:15, idx:5699/10845, loss:1.2287990865594984, acc:0.6730263157894737
epoch:15, idx:5799/10845, loss:1.2269083864794208, acc:0.6737068965517241
epoch:15, idx:5899/10845, loss:1.2284116792994535, 

epoch:16, idx:3999/10845, loss:1.2061521564703435, acc:0.6788125
epoch:16, idx:4099/10845, loss:1.2065367600521664, acc:0.6792073170731707
epoch:16, idx:4199/10845, loss:1.2066343561932444, acc:0.6792261904761905
epoch:16, idx:4299/10845, loss:1.20928957995807, acc:0.6781976744186047
epoch:16, idx:4399/10845, loss:1.2096978657269342, acc:0.6785795454545455
epoch:16, idx:4499/10845, loss:1.2117233657091857, acc:0.6773333333333333
epoch:16, idx:4599/10845, loss:1.2125478917052563, acc:0.6771739130434783
epoch:16, idx:4699/10845, loss:1.212454252368275, acc:0.6767553191489362
epoch:16, idx:4799/10845, loss:1.2120721823582425, acc:0.6767708333333333
epoch:16, idx:4899/10845, loss:1.2118521212026172, acc:0.676734693877551
epoch:16, idx:4999/10845, loss:1.212629304035008, acc:0.6767
epoch:16, idx:5099/10845, loss:1.213209146765225, acc:0.6769607843137255
epoch:16, idx:5199/10845, loss:1.2143624336716647, acc:0.6759615384615385
epoch:16, idx:5299/10845, loss:1.2142178071290255, acc:0.67613207

epoch:17, idx:3399/10845, loss:1.182088267410941, acc:0.6842647058823529
epoch:17, idx:3499/10845, loss:1.185220387995243, acc:0.6835
epoch:17, idx:3599/10845, loss:1.1853675612890058, acc:0.6835416666666667
epoch:17, idx:3699/10845, loss:1.186645637583491, acc:0.6832432432432433
epoch:17, idx:3799/10845, loss:1.185893429535392, acc:0.6832894736842106
epoch:17, idx:3899/10845, loss:1.1855844827703177, acc:0.6828205128205128
epoch:17, idx:3999/10845, loss:1.1855094978194685, acc:0.6823125
epoch:17, idx:4099/10845, loss:1.1862564929865482, acc:0.6821341463414634
epoch:17, idx:4199/10845, loss:1.1870687493735126, acc:0.6819047619047619
epoch:17, idx:4299/10845, loss:1.1909845518701991, acc:0.6812790697674419
epoch:17, idx:4399/10845, loss:1.1918942724523895, acc:0.6811931818181818
epoch:17, idx:4499/10845, loss:1.1941668555455076, acc:0.6810555555555555
epoch:17, idx:4599/10845, loss:1.1932696207499374, acc:0.6817934782608696
epoch:17, idx:4699/10845, loss:1.1951767377634632, acc:0.681329

epoch:18, idx:2799/10845, loss:1.1869279956631362, acc:0.6842857142857143
epoch:18, idx:2899/10845, loss:1.1883803178909524, acc:0.6835344827586207
epoch:18, idx:2999/10845, loss:1.1915387023265163, acc:0.6820833333333334
epoch:18, idx:3099/10845, loss:1.1940344254792699, acc:0.6814516129032258
epoch:18, idx:3199/10845, loss:1.1945395242073573, acc:0.68171875
epoch:18, idx:3299/10845, loss:1.193741367479617, acc:0.681969696969697
epoch:18, idx:3399/10845, loss:1.1917834347946679, acc:0.6819117647058823
epoch:18, idx:3499/10845, loss:1.1892807477776492, acc:0.6827857142857143
epoch:18, idx:3599/10845, loss:1.1885249407651524, acc:0.6825694444444445
epoch:18, idx:3699/10845, loss:1.1892556823206109, acc:0.6831756756756757
epoch:18, idx:3799/10845, loss:1.191482368760595, acc:0.682828947368421
epoch:18, idx:3899/10845, loss:1.1942029483024126, acc:0.6810897435897436
epoch:18, idx:3999/10845, loss:1.194158197524026, acc:0.6808125
epoch:18, idx:4099/10845, loss:1.192047379563858, acc:0.6813

epoch:19, idx:2199/10845, loss:1.1679177326234904, acc:0.6921590909090909
epoch:19, idx:2299/10845, loss:1.1706681520135507, acc:0.6919565217391305
epoch:19, idx:2399/10845, loss:1.167605852348109, acc:0.6926041666666667
epoch:19, idx:2499/10845, loss:1.1716417469263076, acc:0.6917
epoch:19, idx:2599/10845, loss:1.168913730296951, acc:0.6918269230769231
epoch:19, idx:2699/10845, loss:1.175239055990069, acc:0.6903703703703704
epoch:19, idx:2799/10845, loss:1.1784653752456817, acc:0.6897321428571429
epoch:19, idx:2899/10845, loss:1.177721985655612, acc:0.6894827586206896
epoch:19, idx:2999/10845, loss:1.1791279492527247, acc:0.6901666666666667
epoch:19, idx:3099/10845, loss:1.1834303458275333, acc:0.6883870967741935
epoch:19, idx:3199/10845, loss:1.1894779432285576, acc:0.685703125
epoch:19, idx:3299/10845, loss:1.1895515421123215, acc:0.686060606060606
epoch:19, idx:3399/10845, loss:1.1867767359360175, acc:0.6866176470588236
epoch:19, idx:3499/10845, loss:1.18752359689985, acc:0.6868571

epoch:20, idx:1599/10845, loss:1.159332407056354, acc:0.695
epoch:20, idx:1699/10845, loss:1.1623345029310268, acc:0.6939705882352941
epoch:20, idx:1799/10845, loss:1.1654497937651145, acc:0.695
epoch:20, idx:1899/10845, loss:1.1608600910674585, acc:0.6957894736842105
epoch:20, idx:1999/10845, loss:1.1616123813688755, acc:0.694875
epoch:20, idx:2099/10845, loss:1.1602588339078994, acc:0.6945238095238095
epoch:20, idx:2199/10845, loss:1.1612163249674168, acc:0.694659090909091
epoch:20, idx:2299/10845, loss:1.1653922176846991, acc:0.6929347826086957
epoch:20, idx:2399/10845, loss:1.167838666435952, acc:0.6917708333333333
epoch:20, idx:2499/10845, loss:1.167991325172782, acc:0.6921
epoch:20, idx:2599/10845, loss:1.1720495669027933, acc:0.6909615384615385
epoch:20, idx:2699/10845, loss:1.1718093069835946, acc:0.6909259259259259
epoch:20, idx:2799/10845, loss:1.1718478139649544, acc:0.69125
epoch:20, idx:2899/10845, loss:1.172991514056921, acc:0.6911206896551724
epoch:20, idx:2999/10845, lo

epoch:21, idx:999/10845, loss:1.163855849944055, acc:0.696
epoch:21, idx:1099/10845, loss:1.164229684126648, acc:0.6943181818181818
epoch:21, idx:1199/10845, loss:1.1581857216296096, acc:0.695
epoch:21, idx:1299/10845, loss:1.162484106507439, acc:0.6932692307692307
epoch:21, idx:1399/10845, loss:1.1598323742992112, acc:0.6941071428571428
epoch:21, idx:1499/10845, loss:1.158835562552015, acc:0.6941666666666667
epoch:21, idx:1599/10845, loss:1.155082787219435, acc:0.69421875
epoch:21, idx:1699/10845, loss:1.1531407040620552, acc:0.6963235294117647
epoch:21, idx:1799/10845, loss:1.1520853259993924, acc:0.6963888888888888
epoch:21, idx:1899/10845, loss:1.1611423995934034, acc:0.6939473684210526
epoch:21, idx:1999/10845, loss:1.1618394646272063, acc:0.694
epoch:21, idx:2099/10845, loss:1.1687625195440792, acc:0.6922619047619047
epoch:21, idx:2199/10845, loss:1.1740814842927185, acc:0.6915909090909091
epoch:21, idx:2299/10845, loss:1.174546288119062, acc:0.6919565217391305
epoch:21, idx:2399

epoch:22, idx:499/10845, loss:1.1674491286575794, acc:0.703
epoch:22, idx:599/10845, loss:1.1440167765070994, acc:0.7075
epoch:22, idx:699/10845, loss:1.1616954161652497, acc:0.7035714285714286
epoch:22, idx:799/10845, loss:1.1489841560646892, acc:0.706875
epoch:22, idx:899/10845, loss:1.1474698954655065, acc:0.7066666666666667
epoch:22, idx:999/10845, loss:1.156591222986579, acc:0.7
epoch:22, idx:1099/10845, loss:1.1558896437558261, acc:0.6984090909090909
epoch:22, idx:1199/10845, loss:1.1497411292543014, acc:0.700625
epoch:22, idx:1299/10845, loss:1.149131754224117, acc:0.7005769230769231
epoch:22, idx:1399/10845, loss:1.1467509163171052, acc:0.7007142857142857
epoch:22, idx:1499/10845, loss:1.149140830208858, acc:0.6995
epoch:22, idx:1599/10845, loss:1.154793304540217, acc:0.7
epoch:22, idx:1699/10845, loss:1.151716375438606, acc:0.7008823529411765
epoch:22, idx:1799/10845, loss:1.1475751069270903, acc:0.7016666666666667
epoch:22, idx:1899/10845, loss:1.1505997799689833, acc:0.70013

epoch:22, idx:1200/1275, loss:1.3870049575522083, acc:0.6286427976686095
epoch:23, idx:99/10845, loss:1.1733797411620617, acc:0.685
epoch:23, idx:199/10845, loss:1.1150830575078725, acc:0.695
epoch:23, idx:299/10845, loss:1.1432192431390285, acc:0.6966666666666667
epoch:23, idx:399/10845, loss:1.139673539698124, acc:0.701875
epoch:23, idx:499/10845, loss:1.1465847598314285, acc:0.6985
epoch:23, idx:599/10845, loss:1.150439391732216, acc:0.6983333333333334
epoch:23, idx:699/10845, loss:1.1767486351941312, acc:0.6939285714285715
epoch:23, idx:799/10845, loss:1.1725833097100258, acc:0.695
epoch:23, idx:899/10845, loss:1.1603897979358833, acc:0.6961111111111111
epoch:23, idx:999/10845, loss:1.1555705578923225, acc:0.69875
epoch:23, idx:1099/10845, loss:1.1559207412803716, acc:0.6977272727272728
epoch:23, idx:1199/10845, loss:1.154215052879105, acc:0.698125
epoch:23, idx:1299/10845, loss:1.153599958425531, acc:0.6984615384615385
epoch:23, idx:1399/10845, loss:1.1532227967626283, acc:0.69946

epoch:23, idx:600/1275, loss:1.3965740370100825, acc:0.6360232945091514
epoch:23, idx:700/1275, loss:1.3954266717128678, acc:0.6348074179743224
epoch:23, idx:800/1275, loss:1.416639473619458, acc:0.6257802746566792
epoch:23, idx:900/1275, loss:1.4009758357979878, acc:0.6287458379578247
epoch:23, idx:1000/1275, loss:1.4053593303907763, acc:0.6301198801198801
epoch:23, idx:1100/1275, loss:1.3988573042359491, acc:0.631244323342416
epoch:23, idx:1200/1275, loss:1.397591065959718, acc:0.6298917568692756
epoch:24, idx:99/10845, loss:1.2175662138313055, acc:0.675
epoch:24, idx:199/10845, loss:1.170062529668212, acc:0.685
epoch:24, idx:299/10845, loss:1.1886169937749704, acc:0.6766666666666666
epoch:24, idx:399/10845, loss:1.2077795685827732, acc:0.675
epoch:24, idx:499/10845, loss:1.182964009642601, acc:0.6815
epoch:24, idx:599/10845, loss:1.188925650715828, acc:0.6841666666666667
epoch:24, idx:699/10845, loss:1.1696806092560292, acc:0.6896428571428571
epoch:24, idx:799/10845, loss:1.16494560

epoch:24, idx:0/1275, loss:1.5138928890228271, acc:0.5
epoch:24, idx:100/1275, loss:1.4928649327542522, acc:0.594059405940594
epoch:24, idx:200/1275, loss:1.406036307229035, acc:0.6268656716417911
epoch:24, idx:300/1275, loss:1.4081328859014368, acc:0.6395348837209303
epoch:24, idx:400/1275, loss:1.3888727423741931, acc:0.6440149625935162
epoch:24, idx:500/1275, loss:1.3809438671567245, acc:0.6462075848303394
epoch:24, idx:600/1275, loss:1.3921178522924218, acc:0.6389351081530782
epoch:24, idx:700/1275, loss:1.3910695077498016, acc:0.6383737517831669
epoch:24, idx:800/1275, loss:1.4128627307089825, acc:0.6295255930087391
epoch:24, idx:900/1275, loss:1.3976990930264719, acc:0.6320754716981132
epoch:24, idx:1000/1275, loss:1.4012471383342615, acc:0.6331168831168831
epoch:24, idx:1100/1275, loss:1.3940145729903088, acc:0.6344232515894641
epoch:24, idx:1200/1275, loss:1.3923828267150378, acc:0.6330141548709409
epoch:25, idx:99/10845, loss:1.157794730812311, acc:0.7125
epoch:25, idx:199/108

epoch:25, idx:10399/10845, loss:1.1657176706782326, acc:0.6934615384615385
epoch:25, idx:10499/10845, loss:1.1656196262460379, acc:0.6933095238095238
epoch:25, idx:10599/10845, loss:1.1663070571148453, acc:0.6928301886792453
epoch:25, idx:10699/10845, loss:1.1668392329663038, acc:0.6928971962616822
epoch:25, idx:10799/10845, loss:1.1668866398592515, acc:0.6929861111111111
epoch:25, idx:0/1275, loss:1.4720144271850586, acc:0.5
epoch:25, idx:100/1275, loss:1.4879711115419274, acc:0.5891089108910891
epoch:25, idx:200/1275, loss:1.4021311114676556, acc:0.6231343283582089
epoch:25, idx:300/1275, loss:1.4058146127236641, acc:0.6320598006644518
epoch:25, idx:400/1275, loss:1.3859416462313803, acc:0.6390274314214464
epoch:25, idx:500/1275, loss:1.3793396831094147, acc:0.6417165668662674
epoch:25, idx:600/1275, loss:1.3900528666183676, acc:0.6356073211314476
epoch:25, idx:700/1275, loss:1.3893456785978162, acc:0.6355206847360912
epoch:25, idx:800/1275, loss:1.4114160044734994, acc:0.62702871410

epoch:26, idx:9799/10845, loss:1.1577055938108538, acc:0.695076530612245
epoch:26, idx:9899/10845, loss:1.1578515647901129, acc:0.6952272727272727
epoch:26, idx:9999/10845, loss:1.1585161117486655, acc:0.695275
epoch:26, idx:10099/10845, loss:1.1578396062678336, acc:0.6956435643564356
epoch:26, idx:10199/10845, loss:1.1576449860493634, acc:0.6956617647058824
epoch:26, idx:10299/10845, loss:1.1576211756272512, acc:0.6955582524271845
epoch:26, idx:10399/10845, loss:1.156689567430518, acc:0.6956490384615385
epoch:26, idx:10499/10845, loss:1.1567295097942862, acc:0.6957619047619048
epoch:26, idx:10599/10845, loss:1.1571058403046908, acc:0.6954716981132075
epoch:26, idx:10699/10845, loss:1.1571640799111016, acc:0.695677570093458
epoch:26, idx:10799/10845, loss:1.1574124034832198, acc:0.6956712962962963
epoch:26, idx:0/1275, loss:1.4651556015014648, acc:0.5
epoch:26, idx:100/1275, loss:1.4890092564101267, acc:0.594059405940594
epoch:26, idx:200/1275, loss:1.4021500835371254, acc:0.6243781094

epoch:27, idx:9199/10845, loss:1.152163189880388, acc:0.6972826086956522
epoch:27, idx:9299/10845, loss:1.151842332226775, acc:0.6974193548387096
epoch:27, idx:9399/10845, loss:1.151303015627918, acc:0.6974468085106383
epoch:27, idx:9499/10845, loss:1.1516081302346368, acc:0.6976052631578947
epoch:27, idx:9599/10845, loss:1.1524071505176834, acc:0.69734375
epoch:27, idx:9699/10845, loss:1.152386342541305, acc:0.6976030927835052
epoch:27, idx:9799/10845, loss:1.1527861019535637, acc:0.6975765306122449
epoch:27, idx:9899/10845, loss:1.1526855787070411, acc:0.6976262626262626
epoch:27, idx:9999/10845, loss:1.1526515168763698, acc:0.697425
epoch:27, idx:10099/10845, loss:1.1532154527210658, acc:0.697029702970297
epoch:27, idx:10199/10845, loss:1.1541683128658755, acc:0.6966421568627451
epoch:27, idx:10299/10845, loss:1.1550441324645744, acc:0.6964563106796117
epoch:27, idx:10399/10845, loss:1.1548021717434032, acc:0.6965625
epoch:27, idx:10499/10845, loss:1.155598604562027, acc:0.696619047

epoch:28, idx:8599/10845, loss:1.144960397368427, acc:0.6980523255813953
epoch:28, idx:8699/10845, loss:1.1440716171478746, acc:0.6983045977011494
epoch:28, idx:8799/10845, loss:1.1443865636270494, acc:0.6984659090909091
epoch:28, idx:8899/10845, loss:1.14554932339389, acc:0.6982022471910112
epoch:28, idx:8999/10845, loss:1.1462508687153459, acc:0.6978888888888889
epoch:28, idx:9099/10845, loss:1.147279826360908, acc:0.6975274725274725
epoch:28, idx:9199/10845, loss:1.1477866055225225, acc:0.6972554347826087
epoch:28, idx:9299/10845, loss:1.1480240923238378, acc:0.6970967741935484
epoch:28, idx:9399/10845, loss:1.1488929893765996, acc:0.6971010638297872
epoch:28, idx:9499/10845, loss:1.1492168651748644, acc:0.6969473684210526
epoch:28, idx:9599/10845, loss:1.1479441863182, acc:0.697109375
epoch:28, idx:9699/10845, loss:1.1476779639528891, acc:0.6970618556701031
epoch:28, idx:9799/10845, loss:1.148281012450402, acc:0.6968112244897959
epoch:28, idx:9899/10845, loss:1.1495936377414249, ac

epoch:29, idx:8099/10845, loss:1.1546135198407703, acc:0.6957407407407408
epoch:29, idx:8199/10845, loss:1.1544966800983358, acc:0.6957621951219513
epoch:29, idx:8299/10845, loss:1.1534244306995927, acc:0.6961144578313253
epoch:29, idx:8399/10845, loss:1.1539023181317107, acc:0.6960416666666667
epoch:29, idx:8499/10845, loss:1.1542593085537938, acc:0.696
epoch:29, idx:8599/10845, loss:1.1544232179657665, acc:0.6960755813953489
epoch:29, idx:8699/10845, loss:1.154519503906198, acc:0.6961494252873563
epoch:29, idx:8799/10845, loss:1.1531330983950334, acc:0.6964204545454545
epoch:29, idx:8899/10845, loss:1.1537459639551935, acc:0.6964044943820225
epoch:29, idx:8999/10845, loss:1.1533500792698728, acc:0.6965
epoch:29, idx:9099/10845, loss:1.1541304354508828, acc:0.6964010989010989
epoch:29, idx:9199/10845, loss:1.1546784833494736, acc:0.696304347826087
epoch:29, idx:9299/10845, loss:1.1536994060697734, acc:0.6966666666666667
epoch:29, idx:9399/10845, loss:1.1550267982926774, acc:0.69630319

In [None]:
score, P_score, label, P_label = model(batch)
        
loss1 = criterion(score, label)
loss_p = criterion_bce(P_score, P_label.float())

loss = loss1 + loss_p

loss.backward()
#optimizer.step()
#optimizer.zero_grad()        

In [None]:
model.p_score.linear.weight.grad

In [None]:
model.mention_att.ws1.weight.grad

In [None]:
outputs = []
def hook(module, input, output):
    outputs.append(output)

In [None]:
hook1 = model.fc.register_forward_hook()
hook12 = model.