### 当前实验模型内容

1. use mentions
2. 63.3

3. add passage score

model_name | param | dev_acc
--- | --- | ---
use mentions | lr=1e-3, hidden=50 | 63.3
um-ps | lr=5e-4, hidden=50, dropout:0.2 | 64.94
um-ps | reason p:0.2, step=3 | 65.00
um-ps | reason p:0.4, step=5 | 66.12
um-ps no n-gram char | reason p:0.4, step=5 | 65.50

In [1]:
import os
import torch
import torch.nn as nn
import torchtext
from tensorboardX import SummaryWriter
import random
import numpy as np

from torchtext.data import NestedField, Field, RawField
from model import *
from dataset import DataHandler
%load_ext autoreload

%autoreload 2

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
class Config:
    """Hyper-parameters and file locations for this experiment run."""

    def __init__(self):
        # Model size and optimisation settings.
        self.hidden = 50
        self.embedding_dim = 300
        self.lr = 5e-4
        self.epochs = 50
        # Passed as `fix_length` to the support-document fields; None leaves it unset.
        self.fix_length = None

        # Logging location and run name (both feed into the checkpoint directory).
        self.log_dir = './logs'
        self.model_name = 'CFC_um_ps_no_char_save'
        self.batch_size = 4
        self.train_data = './data/train_filter.pt'
        self.dev_data = './data/dev_filter.pt'

        # Pre-built vocabularies. Set either one to None to rebuild that
        # vocabulary from the train/dev sets in the "Build Vocab" cell.
        self.word_vocab = './data/glove_vocab.pt'
        self.charNGram_vocab = './data/charNGram_vocab.pt'

        self.dropout = 0.2
        self.seed = 1023
        # Number of reasoning turns handed to the SAN answer module.
        self.steps = 5


config = Config()
# Use the first GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [3]:
torch.cuda.is_available()

True

In [4]:
# Seed the Python, NumPy and PyTorch (CPU + every CUDA device) generators with
# the same configured value, in that order.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(config.seed)

In [5]:
# The run directory name encodes the hyper-parameters that distinguish runs,
# so checkpoints and TensorBoard logs from different settings never collide.
run_name = (
    f"{config.model_name}_lr_{config.lr}__hidden__{config.hidden}"
    f"_batchsize_{config.batch_size}_p{config.dropout}_steps{config.steps}cycle_lr"
)
save_path = os.path.join(config.log_dir, run_name)
print(save_path)
config.save_path = save_path

./logs/CFC_um_ps_no_char_save_lr_0.0005__hidden__50_batchsize_4_p0.2_steps5cycle_lr


### Define Fields

In [6]:
def _spacy_field(fix_length=None):
    """Build a batch-first, lower-cased sequential field tokenised with spaCy."""
    return Field(batch_first=True, sequential=True, tokenize="spacy", lower=True,
                 fix_length=fix_length)


# Word-level fields: flat for the query, nested for candidates and supports.
word_field = _spacy_field()  # query
multi_word_field = NestedField(word_field)

word_field_sup = _spacy_field(fix_length=config.fix_length)
multi_word_field_sup = NestedField(word_field_sup)

# Parallel set of fields that will carry the character n-gram vocabulary.
charNGram_field = _spacy_field()  # query
multi_charNGram_field = NestedField(charNGram_field)

charNGram_field_sup = _spacy_field(fix_length=config.fix_length)
multi_charNGram_field_sup = NestedField(charNGram_field_sup)

# Pass-through field for values that must reach the model untouched.
raw = RawField()
raw.is_target = False

label_field = Field(sequential=False, is_target=True, use_vocab=False)

# Maps each key of a stored example to the (attribute name, field) pair(s)
# under which it is exposed on a batch.
dict_field = {
    'id': ('id', raw),
    'supports': [('s_glove', multi_word_field_sup), ('s_charNGram', multi_charNGram_field_sup)],
    'query': [('q_glove', word_field), ('q_charNGram', charNGram_field)],
    'candidates': [('c_glove', multi_word_field), ('c_charNGram', multi_charNGram_field)],
    'label': ('label', label_field),
    'mentions': ('mentions', raw),
    'para_label': ('para_label', raw)
}

In [7]:
# Build the data handler from the train/dev example files and the field
# mapping; it exposes `trainset`/`valset` and the iterator factories used below.
data_handler = DataHandler(config.train_data, config.dev_data, dict_field)

# Optional: cache the processed examples for faster reloads.
# torch.save(data_handler.trainset.examples, './data/train_example.pt')
# torch.save(data_handler.valset.examples, './data/dev_example.pt')

load examples.pt  :./data/train_filter.pt, ./data/dev_filter.pt


### Build Vocab

In [8]:
# Character n-gram vocabulary: build it (with CharNGram vectors) from the
# train/dev sets when no cached file is configured, otherwise load the cache
# and attach it to both the nested field and its inner field.
if config.charNGram_vocab is None:
    multi_charNGram_field_sup.build_vocab(data_handler.trainset, data_handler.valset,
                                          vectors=torchtext.vocab.CharNGram())
else:
    charNGram_vocab = torch.load(config.charNGram_vocab)
    multi_charNGram_field_sup.vocab = charNGram_field_sup.vocab = charNGram_vocab

# Word vocabulary: same scheme, with 300-d GloVe 6B vectors when rebuilding.
if config.word_vocab is None:
    multi_word_field_sup.build_vocab(data_handler.trainset, data_handler.valset,
                                     vectors=torchtext.vocab.GloVe(dim=300, name='6B'))
else:
    word_vocab = torch.load(config.word_vocab)
    multi_word_field_sup.vocab = word_field_sup.vocab = word_vocab

# The query fields (also the inner fields of the candidate fields) share the
# support vocabularies so every input is numericalised with the same index space.
word_field.vocab = multi_word_field_sup.vocab
charNGram_field.vocab = multi_charNGram_field_sup.vocab



In [9]:
# Sanity check: shapes of the word and character n-gram embedding matrices.
print(multi_word_field_sup.vocab.vectors.shape,multi_charNGram_field_sup.vocab.vectors.shape )

torch.Size([312667, 300]) torch.Size([312667, 100])


### Get data_iter

In [10]:
# Batch iterators over the training and validation sets.
train_iter = data_handler.get_train_iter(batch_size=config.batch_size)
val_iter = data_handler.get_val_iter(batch_size=config.batch_size)

In [11]:
# Pull only the first validation batch (the loop exits immediately) and show
# it as the cell output to inspect the tensor shapes and raw fields.
for idx, batch in enumerate(val_iter):
    break
batch


[torchtext.data.batch.Batch of size 4]
	[.id]:['WH_dev_0', 'WH_dev_1', 'WH_dev_2', 'WH_dev_3']
	[.s_glove]:[torch.LongTensor of size 4x15x292]
	[.s_charNGram]:[torch.LongTensor of size 4x15x292]
	[.q_glove]:[torch.LongTensor of size 4x11]
	[.q_charNGram]:[torch.LongTensor of size 4x11]
	[.c_glove]:[torch.LongTensor of size 4x18x4]
	[.c_charNGram]:[torch.LongTensor of size 4x18x4]
	[.label]:[torch.LongTensor of size 4]
	[.mentions]:[[[[6, 145, 146], [6, 173, 174], [7, 78, 79]], [[3, 50, 53], [5, 28, 31], [13, 1, 4]], [[6, 135, 136], [6, 218, 219], [6, 261, 262], [8, 45, 46], [12, 98, 99]], [[0, 2, 4], [7, 1, 3], [13, 64, 66], [13, 69, 71]], [[0, 36, 38], [10, 1, 3], [13, 75, 77]], [[0, 14, 15], [1, 63, 64], [1, 138, 139], [1, 186, 187], [1, 238, 239], [2, 128, 129], [9, 8, 9], [9, 43, 44], [10, 19, 20], [10, 34, 35], [11, 37, 38], [11, 79, 80], [13, 56, 57]], [[7, 37, 40]], [[6, 180, 181], [12, 101, 102]], [[12, 96, 97]], [[8, 43, 46]], [[6, 169, 172]], [[6, 179, 181]], [[7, 38, 40]], 

In [12]:
def get_para_label(batch):
    """Return the batch's paragraph labels as one zero-padded LongTensor.

    Every entry of ``batch.para_label`` is right-padded with zeros up to
    ``batch.s_glove.size(1)`` (the padded number of support documents) so the
    rows can be combined into a single tensor. The input lists are not
    modified.
    """
    width = batch.s_glove.size(1)
    padded_rows = [row + [0] * (width - len(row)) for row in batch.para_label]
    return torch.tensor(padded_rows, dtype=torch.long)

### Define Model

In [13]:
def generate_mask(x_size, num_turn, dropout_p=0.0, is_training=False):
    """Sample a scaled dropout mask over reasoning turns.

    Returns a tensor of shape ``(x_size, num_turn)`` whose entries are either
    0 or ``1 / (1 - dropout_p)``. In every row one randomly chosen turn is
    always kept, so no row is dropped entirely. Outside training the dropout
    probability is forced to 0, which yields an all-ones mask.
    """
    keep_prob = 1 - dropout_p if is_training else 1 - 0.0
    keep_probs = torch.ones(x_size, num_turn) * keep_prob
    # Give one random turn per row a keep-probability of 1 so it always survives.
    for row_idx in range(x_size):
        keep_probs[row_idx, random.randint(0, num_turn - 1)] = 1
    # Multiply by the reciprocal so the surviving turns are rescaled.
    mask = torch.bernoulli(keep_probs) * (1.0 / keep_prob)
    mask.requires_grad_(False)
    return mask

class SAN(nn.Module):
    """Multi-turn answer module.

    For ``num_turn`` turns the module attends over the support representations
    with the current state, updates the state with a GRU cell, and scores every
    candidate against the projected attended vector. The per-turn scores are
    then masked with a turn-level dropout mask (see ``generate_mask``) and
    averaged.

    Args:
        x_size: feature size of the support representations.
        h_size: feature size of the recurrent state.
        c_size: feature size of the candidate representations.
        num_turn: number of reasoning turns.
        dropout: dropout applied inside the attention and on the state.
        memo_dropout: probability of dropping a whole turn during training.
        device: kept for backward compatibility; the turn mask now follows the
            device of the input tensor, so this may be left as None.
    """

    def __init__(self, x_size, h_size, c_size, num_turn=5, dropout=0.2, memo_dropout=0.4, device=None):
        super(SAN, self).__init__()
        self.att = BilinearSeqAttn(x_size, h_size, dropout=dropout)
        self.rnn = nn.GRUCell(x_size, h_size)
        self.num_turn = num_turn
        self.fc = nn.Linear(x_size, c_size)
        self.dropout = nn.Dropout(p=dropout)
        self.memo_dropout = memo_dropout
        self.device = device

    def forward(self, x, h0, c):
        '''
        x: [batch, sup_len, x_size]
        h0:[batch, h_size]
        c: [batch, can_len, c_size]

        Returns candidate scores of shape [batch, can_len].
        '''
        score_list = []

        for _ in range(self.num_turn):
            # Attention weights over the supports, conditioned on the current state.
            score = self.att(x, h0)
            x_att = torch.bmm(F.softmax(score, 1).unsqueeze(1), x).squeeze(1)  # [batch, x_size]

            h0 = self.dropout(h0)
            h0 = self.rnn(x_att, h0)

            # Project into the candidate space before the dot product with c.
            x_att = torch.tanh(self.fc(x_att))

            score = torch.bmm(c, x_att.unsqueeze(-1))
            score = score.squeeze(-1)  # [batch, can_len]
            score_list.append(score)

        # Build the turn mask on the same device as the inputs. The previous
        # code moved it to `self.device`, which is None by default and so left
        # the mask on the CPU while the scores lived on the GPU.
        mask = generate_mask(x.size(0), self.num_turn, self.memo_dropout, self.training)
        mask = mask.to(x.device)

        scores = torch.stack(score_list, 2)   # [batch, can_len, num_turn]
        scores = scores * mask.unsqueeze(1)   # broadcast the mask over candidates
        return torch.mean(scores, 2)
    
    

In [14]:
class SimpleQANet(nn.Module):
    """Multi-document QA model with mention features and a passage-score head.

    The forward pass encodes query, candidates and support documents with a
    shared RNN, co-attends supports with the query, pools the support encodings
    at each candidate's mention spans, and scores candidates with the
    multi-turn ``SAN`` module. A bilinear head additionally scores every
    support document against the query.

    Args:
        config: object providing ``embedding_dim``, ``hidden``, ``dropout`` and
            ``steps``.
        word_vectors: pretrained word embedding matrix for ``EmbeddingLayer``.
        charNGram_vectors: pretrained char n-gram matrix for ``EmbeddingLayer``.
        device: device the model is moved to and on which inputs are placed.
    """

    def __init__(self, config, word_vectors, charNGram_vectors, device):
        super(SimpleQANet, self).__init__()
        self.config = config
        self.device = device

        self.embedding_layer = EmbeddingLayer(word_vectors, charNGram_vectors)

        self.rnn = EncoderRNN(config.embedding_dim, config.hidden, 1, True, True, config.dropout, False)

        self.co_att = CoAttention(config.hidden*2, att_type=2, dropout=config.dropout)

        self.linear_1 = nn.Sequential(
            nn.Linear(config.hidden*4, config.hidden),
            nn.ReLU()
        )
        self.rnn2 = EncoderRNN(config.hidden, config.hidden, 1, True, True, config.dropout, False)

        self.word_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        self.word_att_q = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        self.p_score = BilinearSeqAttn(config.hidden*2, config.hidden*2, identity=False, dropout=config.dropout)

        # NOTE: `pass_att` and `fc` are not used by forward(); they are kept so
        # parameter initialisation order and checkpoint keys stay unchanged.
        self.pass_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)

        self.c_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)

        self.mention_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)

        #self.fusion = FusionLayer(config.hidden*2, dropout=config.dropout)
        self.max_pooling = PoolingLayer()

        self.fc = nn.Linear(config.hidden*2, config.hidden*4)
        self.san = SAN(config.hidden*2, config.hidden*2, config.hidden*4, num_turn=config.steps, device=device)

        self.to(device)

    def _pool_mentions(self, mentions_batch, s_out, batch_size, c_len, s_len):
        """Build per-candidate, per-support mention features.

        ``mentions_batch[i][c]`` lists the mentions of candidate ``c`` in
        example ``i`` as ``[support_idx, start, end]`` triples. Each span of
        the encoded supports is max-pooled; spans of one candidate that fall
        in the same support are averaged. Positions without a mention stay 0.

        Returns a tensor of shape [batch_size, c_len, s_len, feat], allocated
        with the dtype and device of ``s_out``.
        """
        pooled = s_out.new_zeros(batch_size, c_len, s_len, s_out.size(-1))
        for i in range(batch_size):
            for c_idx, c_mentions in enumerate(mentions_batch[i]):
                by_support = {}
                for mention in c_mentions:
                    # s_out is flattened to [batch*s_len, ...], hence the offset.
                    span = s_out[i * s_len + mention[0]][mention[1]:mention[2]]
                    vec = self.max_pooling(span.unsqueeze(0)).squeeze()
                    by_support.setdefault(mention[0], []).append(vec)
                for support_idx, vecs in by_support.items():
                    pooled[i, c_idx, support_idx] = torch.stack(vecs).mean(0)
        return pooled

    def forward(self, batch, return_label=True):
        """Score candidates and support documents for one batch.

        Reads ``q_glove``, ``s_glove``, ``c_glove`` and ``mentions`` from
        ``batch`` (plus ``label`` and ``para_label`` when ``return_label`` is
        true). The char n-gram inputs are not used by this model variant.

        Returns:
            ``(score, P_scores, label, P_label)`` when ``return_label`` is
            true, otherwise ``(score, P_scores)``. ``score`` holds the
            candidate scores produced by ``SAN``; ``P_scores`` is the output
            of the bilinear passage head.
        """
        # The query field may arrive as a tuple whose first item is the tensor
        # (presumably (tensor, lengths) when lengths are included - TODO confirm).
        q_glove = batch.q_glove[0] if isinstance(batch.q_glove, tuple) else batch.q_glove

        q_glove = q_glove.to(self.device)
        s_glove = batch.s_glove.to(self.device)
        c_glove = batch.c_glove.to(self.device)

        ### Embedding and Encoder
        # Word embeddings only; to re-enable char n-grams pass the matching
        # `*_charNGram` tensors as a second argument to the embedding layer.
        q_out = self.embedding_layer(q_glove)
        s_out = self.embedding_layer(s_glove)
        c_out = self.embedding_layer(c_glove)

        batch_size = s_out.size(0)

        s_len = s_out.size(1)
        c_len = c_out.size(1)

        s_word_len = s_out.size(2)
        c_word_len = c_out.size(2)

        hidden = s_out.size(-1)

        # Flatten documents/candidates into the batch axis for the encoder.
        s_out = s_out.view(batch_size*s_len, s_word_len, hidden).contiguous()
        c_out = c_out.view(batch_size*c_len, c_word_len, hidden).contiguous()

        q_out = self.rnn(q_out)
        c_out = self.rnn(c_out)
        s_out = self.rnn(s_out)

        # Attention: repeat the query once per support document so it lines up
        # with the flattened supports.
        q_word_len = q_out.size(1)
        q_out_expand = q_out.unsqueeze(1).expand(batch_size, s_len, q_word_len, q_out.size(-1)).contiguous()
        q_out_expand = q_out_expand.view(batch_size*s_len, q_word_len, q_out.size(-1)).contiguous()

        s_out_att, _ = self.co_att(s_out, q_out_expand)

        S_s = self.linear_1(s_out_att)
        S_s = self.rnn2(S_s)

        # Mention features, summarised over the support axis per candidate.
        batch_c_m = self._pool_mentions(batch.mentions, s_out, batch_size, c_len, s_len)
        batch_c_m = batch_c_m.view(batch_size*c_len, s_len, -1)
        batch_c_m = self.mention_att(batch_c_m)
        batch_c_m = batch_c_m.view(batch_size, c_len, -1)

        C_s = self.word_att(S_s)
        C_q = self.word_att_q(q_out)

        C_s = C_s.view(batch_size, s_len, -1)

        P_scores = self.p_score(C_s, C_q)

        C_c = self.c_att(c_out)
        C_c = C_c.view(batch_size, c_len, -1)

        C_c = torch.cat([C_c, batch_c_m], -1)

        score = self.san(C_s, C_q, C_c)

        if return_label:
            label = batch.label.to(self.device)
            # Use the model's own device rather than a notebook-level global.
            P_label = get_para_label(batch).to(self.device)
            return score, P_scores, label, P_label
        return score, P_scores

#### test model

In [15]:
# Instantiate the model with the pretrained embedding matrices; the
# constructor moves it to `device`.
model = SimpleQANet(config, word_field.vocab.vectors, charNGram_field.vocab.vectors, device)
# Optional smoke test on the batch pulled above:
#score,P_score, label, P_label = model(batch)
#print(score, label, P_score, P_label)
#print(score.shape, label.shape, P_score.shape, P_label.shape)

In [17]:
from tqdm import tqdm, trange

In [18]:
from utils import AverageMeter

def train(epoch, data_iter, model, criterion, criterion_bce, optimizer, batch_size=1, joint_begin=-1):
    """Run one training epoch with gradient accumulation.

    Args:
        epoch: current epoch index (used for logging and the joint-loss switch).
        data_iter: sized iterable of batches accepted by ``model``.
        model: returns ``(score, P_score, label, P_label)`` for a batch.
        criterion: loss for the candidate scores against ``label``.
        criterion_bce: loss for the passage scores against ``P_label``.
        optimizer: optimizer stepping the model parameters.
        batch_size: number of iterator batches whose gradients are accumulated
            before each optimizer step (not the iterator's batch size).
        joint_begin: the answer loss is added only when ``epoch > joint_begin``;
            before that only the passage loss is optimised.

    Returns:
        ``(mean loss, mean accuracy)`` over the epoch.
    """
    losses = AverageMeter()
    acces = AverageMeter()
    model.train()
    num_batches = len(data_iter)
    # Start from clean gradients so nothing left over from earlier calls leaks
    # into the first update of this epoch.
    optimizer.zero_grad()
    with trange(num_batches) as t:
        for idx, batch in enumerate(data_iter):
            score, P_score, label, P_label = model(batch)

            loss_p = criterion_bce(P_score, P_label.float())
            if epoch > joint_begin:
                loss = criterion(score, label) + loss_p
            else:
                loss = loss_p

            # Scale so the accumulated gradient is an average over the group.
            loss = loss / batch_size
            loss.backward()
            # Also step on the final batch: otherwise the gradients of a
            # trailing incomplete group are never applied and carry over into
            # the next epoch.
            if (idx+1) % batch_size == 0 or (idx+1) == num_batches:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                optimizer.step()
                optimizer.zero_grad()

            # Undo the accumulation scaling for reporting.
            losses.update(loss.item()*batch_size)

            pred = score.argmax(1)
            acc = pred.eq(label).sum().item() / pred.size(0)
            acces.update(acc)

            t.update()
            if (idx+1) % (batch_size*100) == 0:
                print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')
    return losses.avg, acces.avg

def val(epoch, data_iter, model, criterion, criterion_bce, joint_begin=-1):
    """Evaluate the model on ``data_iter`` without gradient tracking.

    Uses the same loss as training: passage loss only while
    ``epoch <= joint_begin``, answer loss plus passage loss afterwards.
    Progress is printed every 100 batches.

    Returns:
        ``(mean loss, mean accuracy)`` over all batches.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    model.eval()
    use_joint_loss = epoch > joint_begin
    for step, batch in enumerate(data_iter):
        with torch.no_grad():
            score, P_score, label, P_label = model(batch)
            answer_loss = criterion(score, label)
            passage_loss = criterion_bce(P_score, P_label.float())

        total_loss = answer_loss + passage_loss if use_joint_loss else passage_loss
        loss_meter.update(total_loss.item())

        n_correct = score.argmax(1).eq(label).sum().item()
        acc_meter.update(n_correct / score.size(0))

        if step % 100 == 0:
            print(f'epoch:{epoch}, idx:{step}/{len(data_iter)}, loss:{loss_meter.avg}, acc:{acc_meter.avg}')
    return loss_meter.avg, acc_meter.avg

In [19]:
# Adam over the trainable parameters only (frozen ones are filtered out).
# NOTE(review): the training cell below creates its own optimizer and scheduler
# for every cycle, so the `optimizer`/`scheduler` built here are only relevant
# for the ad-hoc calls commented out at the end of this cell.
optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                             lr=config.lr)

# Answer head: cross-entropy over candidates. Passage head: BCE on raw logits.
criterion = nn.CrossEntropyLoss()
criterion_bce = nn.BCEWithLogitsLoss()

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
# NOTE(review): these calls predate the `criterion_bce` argument of train/val.
#train(0, train_iter, model, criterion, optimizer, batch_size=config.batch_size)
# val(0, val_iter, model,criterion)

In [20]:
# Learning-rate schedule layout used by the training cell below:
# `cycle_len` cosine-annealing cycles (optimizer re-created per cycle),
# each lasting `cycle_iter` epochs.
cycle_len = 1
cycle_iter = 50

In [None]:
# Create the run directory (no error if it already exists) and log there.
os.makedirs(config.save_path, exist_ok=True)
writer = SummaryWriter(config.save_path)

best_acc = 0.0
try:
    for i in range(cycle_len):
        # Every cycle restarts from the base learning rate with a fresh
        # optimizer and cosine schedule.
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                                     lr=config.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cycle_iter)
        for epoch in range(cycle_iter):
            # NOTE(review): stepping the scheduler before training follows the
            # pre-1.1 PyTorch convention; newer releases expect it after the
            # optimizer steps. Confirm against the installed version before
            # moving this call.
            scheduler.step()
            # The trailing 1 is the gradient-accumulation factor of train().
            train_loss, train_acc = train(epoch, train_iter, model, criterion, criterion_bce, optimizer,
                                          1)
            val_loss, val_acc = val(epoch, val_iter, model, criterion, criterion_bce)
            global_epoch = cycle_iter * i + epoch + 1
            writer.add_scalar('train_loss', train_loss, global_epoch)
            writer.add_scalar('val_loss', val_loss, global_epoch)
            writer.add_scalar('train_acc', train_acc, global_epoch)
            writer.add_scalar('val_acc', val_acc, global_epoch)

            state = {
                'val_acc': val_acc,
                'train_acc': train_acc,
                'epoch': epoch,
                'model': model.state_dict()
            }
            # Always keep the most recent checkpoint, plus one per new best.
            torch.save(state, os.path.join(config.save_path, 'lastest.pth'))
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(state, os.path.join(config.save_path, f'best_epoch{epoch}.pth'))
finally:
    # Flush and close the event file even when the run is interrupted.
    writer.close()

  alphas = self.softmax(alphas)  # (bsz, sent_len)
  1%|          | 100/10845 [01:02<1:32:44,  1.93it/s]

epoch:0, idx:99/10845, loss:3.5322713935375214, acc:0.19


  2%|▏         | 200/10845 [02:00<1:35:07,  1.87it/s]

epoch:0, idx:199/10845, loss:3.1706510162353516, acc:0.25


  3%|▎         | 300/10845 [03:01<1:39:05,  1.77it/s]

epoch:0, idx:299/10845, loss:2.9915048694610595, acc:0.27166666666666667


  4%|▎         | 400/10845 [03:58<1:15:50,  2.30it/s]

epoch:0, idx:399/10845, loss:2.85627171933651, acc:0.295625


  5%|▍         | 500/10845 [04:58<1:46:36,  1.62it/s]

epoch:0, idx:499/10845, loss:2.783329175949097, acc:0.303


  6%|▌         | 600/10845 [05:56<1:14:00,  2.31it/s]

epoch:0, idx:599/10845, loss:2.71351801554362, acc:0.31625


  6%|▋         | 700/10845 [07:00<1:48:19,  1.56it/s]

epoch:0, idx:699/10845, loss:2.6874175867864065, acc:0.3192857142857143


  7%|▋         | 800/10845 [08:00<1:47:16,  1.56it/s]

epoch:0, idx:799/10845, loss:2.635710633546114, acc:0.32875


  8%|▊         | 900/10845 [09:03<1:51:41,  1.48it/s]

epoch:0, idx:899/10845, loss:2.5842120410336387, acc:0.34194444444444444


  9%|▉         | 1000/10845 [10:02<1:32:34,  1.77it/s]

epoch:0, idx:999/10845, loss:2.55336423265934, acc:0.34625


 10%|█         | 1100/10845 [11:00<1:46:41,  1.52it/s]

epoch:0, idx:1099/10845, loss:2.5309066755121403, acc:0.3509090909090909


 11%|█         | 1200/10845 [11:59<1:39:09,  1.62it/s]

epoch:0, idx:1199/10845, loss:2.511420675218105, acc:0.35229166666666667


 12%|█▏        | 1300/10845 [12:56<1:54:39,  1.39it/s]

epoch:0, idx:1299/10845, loss:2.4866002217622905, acc:0.3573076923076923


 13%|█▎        | 1400/10845 [13:55<1:20:32,  1.95it/s]

epoch:0, idx:1399/10845, loss:2.4656204162750925, acc:0.36142857142857143


 14%|█▍        | 1500/10845 [14:57<1:42:06,  1.53it/s]

epoch:0, idx:1499/10845, loss:2.4496890643835068, acc:0.36566666666666664


 15%|█▍        | 1600/10845 [15:59<1:40:03,  1.54it/s]

epoch:0, idx:1599/10845, loss:2.432006627395749, acc:0.36734375


 16%|█▌        | 1700/10845 [17:03<1:59:45,  1.27it/s]

epoch:0, idx:1699/10845, loss:2.4255990909127627, acc:0.3683823529411765


 17%|█▋        | 1800/10845 [18:01<1:17:31,  1.94it/s]

epoch:0, idx:1799/10845, loss:2.4088736352324487, acc:0.3726388888888889


 18%|█▊        | 1900/10845 [19:01<1:15:28,  1.98it/s]

epoch:0, idx:1899/10845, loss:2.398387525865906, acc:0.3738157894736842


 18%|█▊        | 2000/10845 [20:03<1:25:29,  1.72it/s]

epoch:0, idx:1999/10845, loss:2.387915161818266, acc:0.376375


 19%|█▉        | 2100/10845 [21:03<1:19:47,  1.83it/s]

epoch:0, idx:2099/10845, loss:2.381273531346094, acc:0.3769047619047619


 20%|██        | 2200/10845 [22:02<1:04:06,  2.25it/s]

epoch:0, idx:2199/10845, loss:2.369785328073935, acc:0.3794318181818182


 21%|██        | 2300/10845 [23:00<1:07:22,  2.11it/s]

epoch:0, idx:2299/10845, loss:2.3590276363621587, acc:0.3804347826086957


 22%|██▏       | 2400/10845 [24:02<1:04:59,  2.17it/s]

epoch:0, idx:2399/10845, loss:2.350936353156964, acc:0.3827083333333333


 23%|██▎       | 2500/10845 [25:01<1:13:29,  1.89it/s]

epoch:0, idx:2499/10845, loss:2.339938402915001, acc:0.385


 24%|██▍       | 2600/10845 [26:00<1:39:14,  1.38it/s]

epoch:0, idx:2599/10845, loss:2.3379019493323105, acc:0.3858653846153846


 25%|██▍       | 2700/10845 [27:02<1:07:06,  2.02it/s]

epoch:0, idx:2699/10845, loss:2.3263152370629485, acc:0.38851851851851854


 26%|██▌       | 2800/10845 [28:00<1:22:09,  1.63it/s]

epoch:0, idx:2799/10845, loss:2.318104272761515, acc:0.390625


 27%|██▋       | 2900/10845 [28:57<1:06:57,  1.98it/s]

epoch:0, idx:2899/10845, loss:2.307880508529729, acc:0.39232758620689656


 28%|██▊       | 3000/10845 [29:58<57:04,  2.29it/s]  

epoch:0, idx:2999/10845, loss:2.303415585269531, acc:0.39325


 29%|██▊       | 3100/10845 [30:58<1:26:54,  1.49it/s]

epoch:0, idx:3099/10845, loss:2.2966377992110867, acc:0.3945161290322581


 30%|██▉       | 3200/10845 [31:56<1:28:26,  1.44it/s]

epoch:0, idx:3199/10845, loss:2.290208999952301, acc:0.3959375


 30%|███       | 3300/10845 [32:57<1:13:02,  1.72it/s]

epoch:0, idx:3299/10845, loss:2.2820074848604923, acc:0.39765151515151514


 31%|███▏      | 3400/10845 [33:55<1:07:41,  1.83it/s]

epoch:0, idx:3399/10845, loss:2.2734897138178347, acc:0.4002941176470588


 32%|███▏      | 3500/10845 [34:57<1:02:57,  1.94it/s]

epoch:0, idx:3499/10845, loss:2.2670167533925603, acc:0.4007857142857143


 33%|███▎      | 3600/10845 [35:55<1:05:48,  1.83it/s]

epoch:0, idx:3599/10845, loss:2.260362837190429, acc:0.4025


 34%|███▍      | 3700/10845 [36:53<1:12:34,  1.64it/s]

epoch:0, idx:3699/10845, loss:2.2537598332601623, acc:0.40364864864864863


 35%|███▌      | 3800/10845 [37:53<59:55,  1.96it/s]  

epoch:0, idx:3799/10845, loss:2.246082622565721, acc:0.4061842105263158


 36%|███▌      | 3900/10845 [38:52<1:08:08,  1.70it/s]

epoch:0, idx:3899/10845, loss:2.2392650289566087, acc:0.4078846153846154


 37%|███▋      | 4000/10845 [39:53<1:08:57,  1.65it/s]

epoch:0, idx:3999/10845, loss:2.2330236473977565, acc:0.409375


 38%|███▊      | 4101/10845 [40:56<46:17,  2.43it/s]  

epoch:0, idx:4099/10845, loss:2.231180366728364, acc:0.4096341463414634


 39%|███▊      | 4200/10845 [41:59<1:23:25,  1.33it/s]

epoch:0, idx:4199/10845, loss:2.2275672911249456, acc:0.41035714285714286


 40%|███▉      | 4300/10845 [42:56<48:38,  2.24it/s]  

epoch:0, idx:4299/10845, loss:2.222125321176163, acc:0.4105813953488372


 41%|████      | 4400/10845 [43:58<1:07:21,  1.59it/s]

epoch:0, idx:4399/10845, loss:2.2184889268400996, acc:0.413125


 41%|████▏     | 4500/10845 [44:56<1:00:13,  1.76it/s]

epoch:0, idx:4499/10845, loss:2.212154425435596, acc:0.4146111111111111


 42%|████▏     | 4600/10845 [45:53<54:45,  1.90it/s]  

epoch:0, idx:4599/10845, loss:2.20696970379223, acc:0.41657608695652176


 43%|████▎     | 4700/10845 [46:51<1:14:24,  1.38it/s]

epoch:0, idx:4699/10845, loss:2.201538470989846, acc:0.4176063829787234


 44%|████▍     | 4800/10845 [47:50<56:35,  1.78it/s]  

epoch:0, idx:4799/10845, loss:2.1986355964280664, acc:0.418125


 45%|████▌     | 4900/10845 [48:52<51:01,  1.94it/s]  

epoch:0, idx:4899/10845, loss:2.1946113447997035, acc:0.4190816326530612


 46%|████▌     | 5000/10845 [49:49<1:00:22,  1.61it/s]

epoch:0, idx:4999/10845, loss:2.188883385461569, acc:0.42045


 47%|████▋     | 5100/10845 [50:45<55:23,  1.73it/s]  

epoch:0, idx:5099/10845, loss:2.183839599157081, acc:0.4218627450980392


 48%|████▊     | 5200/10845 [51:44<1:01:39,  1.53it/s]

epoch:0, idx:5199/10845, loss:2.1808771324100404, acc:0.4226923076923077


 49%|████▉     | 5300/10845 [52:44<1:01:59,  1.49it/s]

epoch:0, idx:5299/10845, loss:2.177008451262735, acc:0.4232075471698113


 50%|████▉     | 5400/10845 [53:42<49:19,  1.84it/s]  

epoch:0, idx:5399/10845, loss:2.1729828052995384, acc:0.4237037037037037


 51%|█████     | 5500/10845 [54:39<54:51,  1.62it/s]  

epoch:0, idx:5499/10845, loss:2.1675649666677823, acc:0.4240909090909091


 52%|█████▏    | 5600/10845 [55:37<41:07,  2.13it/s]  

epoch:0, idx:5599/10845, loss:2.162313947964992, acc:0.4253571428571429


 53%|█████▎    | 5700/10845 [56:38<49:21,  1.74it/s]  

epoch:0, idx:5699/10845, loss:2.1588635301171686, acc:0.42653508771929827


 53%|█████▎    | 5800/10845 [57:38<49:49,  1.69it/s]  

epoch:0, idx:5799/10845, loss:2.1545096818640315, acc:0.4279741379310345


 54%|█████▍    | 5900/10845 [58:40<50:29,  1.63it/s]  

epoch:0, idx:5899/10845, loss:2.1498081111200786, acc:0.428728813559322


 55%|█████▌    | 6000/10845 [59:39<45:45,  1.76it/s]  

epoch:0, idx:5999/10845, loss:2.1433430577466885, acc:0.42995833333333333


 56%|█████▌    | 6100/10845 [1:00:39<43:36,  1.81it/s]  

epoch:0, idx:6099/10845, loss:2.1388148945228, acc:0.4310245901639344


 57%|█████▋    | 6200/10845 [1:01:38<51:23,  1.51it/s]  

epoch:0, idx:6199/10845, loss:2.134961709711821, acc:0.43233870967741933


 58%|█████▊    | 6300/10845 [1:02:36<50:17,  1.51it/s]  

epoch:0, idx:6299/10845, loss:2.1332014785541427, acc:0.43313492063492065


 59%|█████▉    | 6400/10845 [1:03:40<34:28,  2.15it/s]  

epoch:0, idx:6399/10845, loss:2.1294459889968858, acc:0.43390625


 60%|█████▉    | 6500/10845 [1:04:40<39:59,  1.81it/s]  

epoch:0, idx:6499/10845, loss:2.1245726598501204, acc:0.43503846153846154


 61%|██████    | 6600/10845 [1:05:40<44:01,  1.61it/s]  

epoch:0, idx:6599/10845, loss:2.1193559113325495, acc:0.4363257575757576


 62%|██████▏   | 6700/10845 [1:06:42<41:00,  1.68it/s]  

epoch:0, idx:6699/10845, loss:2.115453531092672, acc:0.4373507462686567


 63%|██████▎   | 6800/10845 [1:07:40<42:41,  1.58it/s]

epoch:0, idx:6799/10845, loss:2.1100338214460543, acc:0.4386029411764706


 64%|██████▎   | 6900/10845 [1:08:40<28:03,  2.34it/s]

epoch:0, idx:6899/10845, loss:2.1072976026068564, acc:0.4393478260869565


 65%|██████▍   | 7000/10845 [1:09:42<39:19,  1.63it/s]  

epoch:0, idx:6999/10845, loss:2.103830952086619, acc:0.44014285714285717


 65%|██████▌   | 7100/10845 [1:10:38<34:54,  1.79it/s]

epoch:0, idx:7099/10845, loss:2.0996364765142053, acc:0.4408802816901408


 66%|██████▋   | 7200/10845 [1:11:34<32:42,  1.86it/s]

epoch:0, idx:7199/10845, loss:2.0974150148572193, acc:0.44125


 67%|██████▋   | 7300/10845 [1:12:34<36:17,  1.63it/s]

epoch:0, idx:7299/10845, loss:2.094158499224545, acc:0.4420890410958904


 68%|██████▊   | 7400/10845 [1:13:31<30:27,  1.88it/s]

epoch:0, idx:7399/10845, loss:2.091736463352635, acc:0.44280405405405404


 69%|██████▉   | 7500/10845 [1:14:35<37:49,  1.47it/s]  

epoch:0, idx:7499/10845, loss:2.089976219681899, acc:0.4433666666666667


 70%|███████   | 7600/10845 [1:15:35<29:15,  1.85it/s]

epoch:0, idx:7599/10845, loss:2.0870255687323054, acc:0.4440789473684211


 71%|███████   | 7700/10845 [1:16:36<27:21,  1.92it/s]

epoch:0, idx:7699/10845, loss:2.0836030472369935, acc:0.44487012987012986


 72%|███████▏  | 7800/10845 [1:17:36<27:20,  1.86it/s]

epoch:0, idx:7799/10845, loss:2.0801554384693883, acc:0.4454166666666667


 73%|███████▎  | 7900/10845 [1:18:37<23:07,  2.12it/s]

epoch:0, idx:7899/10845, loss:2.077641548844455, acc:0.44617088607594935


 74%|███████▍  | 8000/10845 [1:19:37<25:59,  1.82it/s]

epoch:0, idx:7999/10845, loss:2.076718298414722, acc:0.4463125


 75%|███████▍  | 8100/10845 [1:20:37<37:01,  1.24it/s]

epoch:0, idx:8099/10845, loss:2.074404303996283, acc:0.44688271604938273


 76%|███████▌  | 8200/10845 [1:21:34<22:07,  1.99it/s]

epoch:0, idx:8199/10845, loss:2.0709767611306615, acc:0.4473170731707317


 77%|███████▋  | 8300/10845 [1:22:32<28:19,  1.50it/s]

epoch:0, idx:8299/10845, loss:2.067500458936016, acc:0.44810240963855424


 77%|███████▋  | 8400/10845 [1:23:33<27:17,  1.49it/s]

epoch:0, idx:8399/10845, loss:2.0666592726022714, acc:0.44821428571428573


 78%|███████▊  | 8500/10845 [1:24:32<22:25,  1.74it/s]

epoch:0, idx:8499/10845, loss:2.0642658773748312, acc:0.4485588235294118


 79%|███████▉  | 8600/10845 [1:25:29<18:25,  2.03it/s]

epoch:0, idx:8599/10845, loss:2.062657406671449, acc:0.4490406976744186


 80%|████████  | 8700/10845 [1:26:32<15:53,  2.25it/s]

epoch:0, idx:8699/10845, loss:2.0604484089917836, acc:0.449683908045977


 81%|████████  | 8800/10845 [1:27:35<24:23,  1.40it/s]

epoch:0, idx:8799/10845, loss:2.058222128107128, acc:0.45034090909090907


 82%|████████▏ | 8900/10845 [1:28:33<16:59,  1.91it/s]

epoch:0, idx:8899/10845, loss:2.0559238479194346, acc:0.4510393258426966


 83%|████████▎ | 9000/10845 [1:29:29<20:16,  1.52it/s]

epoch:0, idx:8999/10845, loss:2.0535882480823333, acc:0.4518611111111111


 84%|████████▍ | 9100/10845 [1:30:28<15:14,  1.91it/s]

epoch:0, idx:9099/10845, loss:2.0504806395857544, acc:0.45274725274725275


 85%|████████▍ | 9200/10845 [1:31:24<17:17,  1.59it/s]

epoch:0, idx:9199/10845, loss:2.0484896737709644, acc:0.4534239130434783


 86%|████████▌ | 9300/10845 [1:32:28<14:58,  1.72it/s]

epoch:0, idx:9299/10845, loss:2.0457748244830998, acc:0.454005376344086


 87%|████████▋ | 9400/10845 [1:33:27<16:31,  1.46it/s]

epoch:0, idx:9399/10845, loss:2.043319661045011, acc:0.4546276595744681


 88%|████████▊ | 9500/10845 [1:34:28<14:09,  1.58it/s]

epoch:0, idx:9499/10845, loss:2.040936368052897, acc:0.4551842105263158


 89%|████████▊ | 9600/10845 [1:35:25<12:15,  1.69it/s]

epoch:0, idx:9599/10845, loss:2.038880929532461, acc:0.455546875


 89%|████████▉ | 9700/10845 [1:36:28<13:19,  1.43it/s]

epoch:0, idx:9699/10845, loss:2.0379381360697377, acc:0.4559020618556701


 90%|█████████ | 9800/10845 [1:37:30<08:57,  1.94it/s]

epoch:0, idx:9799/10845, loss:2.0360002088409908, acc:0.45622448979591834


 91%|█████████▏| 9900/10845 [1:38:29<09:23,  1.68it/s]

epoch:0, idx:9899/10845, loss:2.0344203929482685, acc:0.4569191919191919


 92%|█████████▏| 10000/10845 [1:39:30<07:55,  1.78it/s]

epoch:0, idx:9999/10845, loss:2.0327889201238754, acc:0.457625


 93%|█████████▎| 10100/10845 [1:40:31<10:22,  1.20it/s]

epoch:0, idx:10099/10845, loss:2.0300659541667687, acc:0.4583910891089109


 94%|█████████▍| 10200/10845 [1:41:30<07:42,  1.40it/s]

epoch:0, idx:10199/10845, loss:2.0292424558439093, acc:0.4586029411764706


 95%|█████████▍| 10300/10845 [1:42:31<05:22,  1.69it/s]

epoch:0, idx:10299/10845, loss:2.0274440393300313, acc:0.4587378640776699


 96%|█████████▌| 10400/10845 [1:43:30<04:40,  1.58it/s]

epoch:0, idx:10399/10845, loss:2.025812394964294, acc:0.45913461538461536


 97%|█████████▋| 10500/10845 [1:44:29<02:44,  2.10it/s]

epoch:0, idx:10499/10845, loss:2.023207047407116, acc:0.4597857142857143


 98%|█████████▊| 10600/10845 [1:45:28<02:16,  1.80it/s]

epoch:0, idx:10599/10845, loss:2.0198343247751582, acc:0.4604009433962264


 99%|█████████▊| 10700/10845 [1:46:25<01:16,  1.90it/s]

epoch:0, idx:10699/10845, loss:2.0179514120686277, acc:0.46102803738317755


100%|█████████▉| 10800/10845 [1:47:25<00:24,  1.81it/s]

epoch:0, idx:10799/10845, loss:2.0153534377490483, acc:0.4621527777777778


100%|██████████| 10845/10845 [1:47:50<00:00,  1.69it/s]


epoch:0, idx:0/1275, loss:1.5959317684173584, acc:0.25
epoch:0, idx:100/1275, loss:1.784310257080758, acc:0.5173267326732673
epoch:0, idx:200/1275, loss:1.7295144169188257, acc:0.5335820895522388
epoch:0, idx:300/1275, loss:1.720778158534801, acc:0.5307308970099668
epoch:0, idx:400/1275, loss:1.7179905489793146, acc:0.5305486284289277
epoch:0, idx:500/1275, loss:1.7194111429288716, acc:0.5349301397205589
epoch:0, idx:600/1275, loss:1.7249160044006817, acc:0.528702163061564
epoch:0, idx:700/1275, loss:1.7150068969257208, acc:0.5310271041369472
epoch:0, idx:800/1275, loss:1.726561816890588, acc:0.5227840199750312
epoch:0, idx:900/1275, loss:1.7245922780923388, acc:0.5230299667036626
epoch:0, idx:1000/1275, loss:1.735311128668018, acc:0.51998001998002
epoch:0, idx:1100/1275, loss:1.731954125438789, acc:0.5199818346957311
epoch:0, idx:1200/1275, loss:1.7339711444364003, acc:0.5181099084096586


  1%|          | 100/10845 [00:57<1:42:29,  1.75it/s]

epoch:1, idx:99/10845, loss:1.8177885353565215, acc:0.495


  2%|▏         | 200/10845 [01:55<1:44:18,  1.70it/s]

epoch:1, idx:199/10845, loss:1.730317063778639, acc:0.54


  3%|▎         | 300/10845 [02:58<1:47:21,  1.64it/s]

epoch:1, idx:299/10845, loss:1.763959134221077, acc:0.5341666666666667


  4%|▎         | 400/10845 [03:57<1:32:40,  1.88it/s]

epoch:1, idx:399/10845, loss:1.766843361929059, acc:0.53


  5%|▍         | 500/10845 [04:57<1:55:01,  1.50it/s]

epoch:1, idx:499/10845, loss:1.745472249507904, acc:0.536


  6%|▌         | 600/10845 [05:55<1:21:33,  2.09it/s]

epoch:1, idx:599/10845, loss:1.741843127856652, acc:0.5395833333333333


  6%|▋         | 700/10845 [06:57<1:48:37,  1.56it/s]

epoch:1, idx:699/10845, loss:1.7524484353831835, acc:0.5335714285714286


  7%|▋         | 800/10845 [07:58<1:30:58,  1.84it/s]

epoch:1, idx:799/10845, loss:1.7490693258494139, acc:0.533125


  8%|▊         | 900/10845 [08:58<1:56:59,  1.42it/s]

epoch:1, idx:899/10845, loss:1.7401029430826505, acc:0.5366666666666666


  9%|▉         | 1000/10845 [09:59<1:47:14,  1.53it/s]

epoch:1, idx:999/10845, loss:1.7330821735262871, acc:0.5395


 10%|█         | 1100/10845 [11:03<1:27:25,  1.86it/s]

epoch:1, idx:1099/10845, loss:1.7335585122216832, acc:0.538409090909091


 11%|█         | 1200/10845 [12:05<1:23:09,  1.93it/s]

epoch:1, idx:1199/10845, loss:1.7220050815989574, acc:0.5427083333333333


 12%|█▏        | 1300/10845 [13:08<1:39:10,  1.60it/s]

epoch:1, idx:1299/10845, loss:1.7260112381325319, acc:0.5432692307692307


 13%|█▎        | 1400/10845 [14:09<1:35:25,  1.65it/s]

epoch:1, idx:1399/10845, loss:1.72311439730227, acc:0.54375


 14%|█▍        | 1500/10845 [15:09<2:14:30,  1.16it/s]

epoch:1, idx:1499/10845, loss:1.7248426174620788, acc:0.543


 15%|█▍        | 1600/10845 [16:10<1:47:20,  1.44it/s]

epoch:1, idx:1599/10845, loss:1.7266835379693657, acc:0.54390625


 16%|█▌        | 1700/10845 [17:10<1:28:08,  1.73it/s]

epoch:1, idx:1699/10845, loss:1.7215924330932253, acc:0.545


 17%|█▋        | 1800/10845 [18:12<1:26:03,  1.75it/s]

epoch:1, idx:1799/10845, loss:1.7248831511951155, acc:0.5445833333333333


 18%|█▊        | 1900/10845 [19:10<1:19:41,  1.87it/s]

epoch:1, idx:1899/10845, loss:1.7280837834744076, acc:0.5444736842105263


 18%|█▊        | 2000/10845 [20:10<1:45:16,  1.40it/s]

epoch:1, idx:1999/10845, loss:1.7265507308468222, acc:0.54425


 19%|█▉        | 2100/10845 [21:09<1:40:42,  1.45it/s]

epoch:1, idx:2099/10845, loss:1.7236591964676267, acc:0.5448809523809524


 20%|██        | 2200/10845 [22:13<1:29:55,  1.60it/s]

epoch:1, idx:2199/10845, loss:1.7273123903233896, acc:0.5448863636363637


 21%|██        | 2300/10845 [23:13<1:33:31,  1.52it/s]

epoch:1, idx:2299/10845, loss:1.7253958266779132, acc:0.5453260869565217


 22%|██▏       | 2400/10845 [24:11<1:36:51,  1.45it/s]

epoch:1, idx:2399/10845, loss:1.7282506868305305, acc:0.5435416666666667


 23%|██▎       | 2500/10845 [25:10<1:47:18,  1.30it/s]

epoch:1, idx:2499/10845, loss:1.7308862277448178, acc:0.5426


 24%|██▍       | 2600/10845 [26:11<1:07:32,  2.03it/s]

epoch:1, idx:2599/10845, loss:1.7297656976546232, acc:0.5440384615384616


 25%|██▍       | 2700/10845 [27:09<1:12:29,  1.87it/s]

epoch:1, idx:2699/10845, loss:1.7316072344835158, acc:0.543425925925926


 26%|██▌       | 2800/10845 [28:10<1:29:41,  1.49it/s]

epoch:1, idx:2799/10845, loss:1.7324345372670462, acc:0.5435714285714286


 27%|██▋       | 2900/10845 [29:09<1:04:38,  2.05it/s]

epoch:1, idx:2899/10845, loss:1.7294947588906207, acc:0.5443103448275862


 28%|██▊       | 3000/10845 [30:12<1:23:26,  1.57it/s]

epoch:1, idx:2999/10845, loss:1.7311064419398705, acc:0.5438333333333333


 29%|██▊       | 3100/10845 [31:10<1:18:31,  1.64it/s]

epoch:1, idx:3099/10845, loss:1.7332800585175714, acc:0.5434677419354839


 30%|██▉       | 3200/10845 [32:12<1:12:52,  1.75it/s]

epoch:1, idx:3199/10845, loss:1.7329600097378717, acc:0.543203125


 30%|███       | 3300/10845 [33:15<1:19:12,  1.59it/s]

epoch:1, idx:3299/10845, loss:1.732973053125721, acc:0.5437121212121212


 31%|███▏      | 3400/10845 [34:16<58:42,  2.11it/s]  

epoch:1, idx:3399/10845, loss:1.7364405343243305, acc:0.5413235294117648


 32%|███▏      | 3500/10845 [35:13<1:10:46,  1.73it/s]

epoch:1, idx:3499/10845, loss:1.7383587622259344, acc:0.5415


 33%|███▎      | 3600/10845 [36:13<1:23:56,  1.44it/s]

epoch:1, idx:3599/10845, loss:1.7389145771621002, acc:0.5416666666666666


 34%|███▍      | 3700/10845 [37:09<1:22:18,  1.45it/s]

epoch:1, idx:3699/10845, loss:1.738304227449604, acc:0.5414189189189189


 35%|███▌      | 3800/10845 [38:06<58:52,  1.99it/s]  

epoch:1, idx:3799/10845, loss:1.7348752183153442, acc:0.5417763157894737


 36%|███▌      | 3900/10845 [39:03<57:24,  2.02it/s]  

epoch:1, idx:3899/10845, loss:1.7298047720736418, acc:0.5431410256410256


 37%|███▋      | 4000/10845 [40:02<1:16:53,  1.48it/s]

epoch:1, idx:3999/10845, loss:1.7305173481814564, acc:0.5428125


 38%|███▊      | 4100/10845 [41:03<1:14:14,  1.51it/s]

epoch:1, idx:4099/10845, loss:1.7320330925794636, acc:0.5422560975609756


 40%|███▉      | 4300/10845 [43:02<1:14:06,  1.47it/s]

epoch:1, idx:4299/10845, loss:1.7282253867699657, acc:0.5430813953488373


 41%|████      | 4400/10845 [43:59<1:01:43,  1.74it/s]

epoch:1, idx:4399/10845, loss:1.7257455086674203, acc:0.5434659090909091


 41%|████▏     | 4500/10845 [44:55<54:51,  1.93it/s]  

epoch:1, idx:4499/10845, loss:1.726587992982732, acc:0.5431111111111111


 42%|████▏     | 4600/10845 [45:53<1:10:48,  1.47it/s]

epoch:1, idx:4599/10845, loss:1.7249372605234385, acc:0.5433695652173913


 43%|████▎     | 4700/10845 [46:53<1:02:16,  1.64it/s]

epoch:1, idx:4699/10845, loss:1.7260930758651267, acc:0.543031914893617


 44%|████▍     | 4800/10845 [47:49<1:08:08,  1.48it/s]

epoch:1, idx:4799/10845, loss:1.7256438323296606, acc:0.5426041666666667


 45%|████▌     | 4900/10845 [48:47<50:29,  1.96it/s]  

epoch:1, idx:4899/10845, loss:1.7262358651173357, acc:0.5423979591836735


 46%|████▌     | 5000/10845 [49:50<1:04:44,  1.50it/s]

epoch:1, idx:4999/10845, loss:1.72781841340065, acc:0.5421


 47%|████▋     | 5100/10845 [50:47<58:13,  1.64it/s]  

epoch:1, idx:5099/10845, loss:1.7279972675851747, acc:0.5415686274509804


 48%|████▊     | 5200/10845 [51:47<51:16,  1.83it/s]  

epoch:1, idx:5199/10845, loss:1.7251745783938812, acc:0.5422115384615385


 49%|████▉     | 5300/10845 [52:48<46:50,  1.97it/s]  

epoch:1, idx:5299/10845, loss:1.7260940191880711, acc:0.5417924528301887


 50%|████▉     | 5400/10845 [53:52<57:48,  1.57it/s]  

epoch:1, idx:5399/10845, loss:1.7265049231769862, acc:0.5416666666666666


 51%|█████     | 5500/10845 [54:51<1:09:05,  1.29it/s]

epoch:1, idx:5499/10845, loss:1.724634905397892, acc:0.5421818181818182


 52%|█████▏    | 5600/10845 [55:56<1:10:09,  1.25it/s]

epoch:1, idx:5599/10845, loss:1.7260120240811792, acc:0.5415178571428572


 53%|█████▎    | 5700/10845 [56:56<57:34,  1.49it/s]  

epoch:1, idx:5699/10845, loss:1.7255042659112236, acc:0.5417543859649123


 53%|█████▎    | 5800/10845 [57:52<48:17,  1.74it/s]  

epoch:1, idx:5799/10845, loss:1.7258978182925233, acc:0.5416810344827586


 54%|█████▍    | 5900/10845 [58:51<50:41,  1.63it/s]  

epoch:1, idx:5899/10845, loss:1.7243424807791994, acc:0.5422033898305084


 55%|█████▌    | 6000/10845 [59:58<53:08,  1.52it/s]  

epoch:1, idx:5999/10845, loss:1.7265786679064234, acc:0.542125


 56%|█████▌    | 6100/10845 [1:01:03<51:52,  1.52it/s]  

epoch:1, idx:6099/10845, loss:1.7267551109561177, acc:0.5418852459016393


 57%|█████▋    | 6200/10845 [1:02:02<42:48,  1.81it/s]  

epoch:1, idx:6199/10845, loss:1.7271504300904852, acc:0.5419758064516129


 58%|█████▊    | 6300/10845 [1:03:01<53:38,  1.41it/s]

epoch:1, idx:6299/10845, loss:1.7263338939041373, acc:0.5418650793650793


 59%|█████▉    | 6400/10845 [1:04:00<39:51,  1.86it/s]  

epoch:1, idx:6399/10845, loss:1.7264993261056951, acc:0.5416796875


 60%|█████▉    | 6500/10845 [1:05:00<39:43,  1.82it/s]  

epoch:1, idx:6499/10845, loss:1.7252365375780143, acc:0.542


 61%|██████    | 6600/10845 [1:06:02<45:59,  1.54it/s]  

epoch:1, idx:6599/10845, loss:1.7227091786287951, acc:0.5427272727272727


 62%|██████▏   | 6700/10845 [1:07:01<43:01,  1.61it/s]

epoch:1, idx:6699/10845, loss:1.7231703815064323, acc:0.5426865671641791


 63%|██████▎   | 6800/10845 [1:08:03<35:17,  1.91it/s]  

epoch:1, idx:6799/10845, loss:1.7224613710602417, acc:0.5430882352941176


 64%|██████▎   | 6900/10845 [1:09:01<39:09,  1.68it/s]

epoch:1, idx:6899/10845, loss:1.722850740282, acc:0.5427536231884058


 65%|██████▍   | 7000/10845 [1:10:04<43:08,  1.49it/s]

epoch:1, idx:6999/10845, loss:1.722363219152604, acc:0.5428928571428572


 65%|██████▌   | 7100/10845 [1:11:04<35:44,  1.75it/s]  

epoch:1, idx:7099/10845, loss:1.7211455613227797, acc:0.5434859154929578


 66%|██████▋   | 7200/10845 [1:12:03<35:59,  1.69it/s]

epoch:1, idx:7199/10845, loss:1.72111616578574, acc:0.5439236111111111


 67%|██████▋   | 7243/10845 [1:12:30<37:30,  1.60it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 96%|█████████▌| 10400/10845 [1:43:50<04:59,  1.49it/s]

epoch:1, idx:10399/10845, loss:1.7032602607372862, acc:0.5488221153846153


 97%|█████████▋| 10500/10845 [1:44:50<03:52,  1.48it/s]

epoch:1, idx:10499/10845, loss:1.7031744786387397, acc:0.5488571428571428


 98%|█████████▊| 10600/10845 [1:45:52<02:59,  1.37it/s]

epoch:1, idx:10599/10845, loss:1.703335246966695, acc:0.5487971698113208


 99%|█████████▊| 10700/10845 [1:46:51<01:37,  1.49it/s]

epoch:1, idx:10699/10845, loss:1.7033808979726284, acc:0.5485981308411215


100%|█████████▉| 10800/10845 [1:47:51<00:27,  1.63it/s]

epoch:1, idx:10799/10845, loss:1.704900074871602, acc:0.5485648148148148


100%|██████████| 10845/10845 [1:48:16<00:00,  2.55it/s]


epoch:1, idx:0/1275, loss:1.2758736610412598, acc:0.5
epoch:1, idx:100/1275, loss:1.6536215602761448, acc:0.5717821782178217
epoch:1, idx:200/1275, loss:1.5987262997164655, acc:0.568407960199005
epoch:1, idx:300/1275, loss:1.5883581433383334, acc:0.5681063122923588
epoch:1, idx:400/1275, loss:1.5862848625664698, acc:0.5704488778054863
epoch:1, idx:500/1275, loss:1.5864739381148667, acc:0.5708582834331337
epoch:1, idx:600/1275, loss:1.5917361690081693, acc:0.5682196339434277
epoch:1, idx:700/1275, loss:1.5813205777833532, acc:0.572039942938659
epoch:1, idx:800/1275, loss:1.597754700576768, acc:0.5624219725343321
epoch:1, idx:900/1275, loss:1.5954834112316603, acc:0.5635405105438401
epoch:1, idx:1000/1275, loss:1.608658006260326, acc:0.560939060939061
epoch:1, idx:1100/1275, loss:1.6022305364071727, acc:0.5626702997275205
epoch:1, idx:1200/1275, loss:1.603012339211225, acc:0.5624479600333055


  1%|          | 100/10845 [00:54<1:35:18,  1.88it/s]

epoch:2, idx:99/10845, loss:1.627691715657711, acc:0.585


  2%|▏         | 200/10845 [01:56<1:28:24,  2.01it/s]

epoch:2, idx:199/10845, loss:1.5700693368911742, acc:0.60625


  3%|▎         | 300/10845 [02:59<2:07:51,  1.37it/s]

epoch:2, idx:299/10845, loss:1.554404634932677, acc:0.6025


  4%|▎         | 400/10845 [03:55<1:39:09,  1.76it/s]

epoch:2, idx:399/10845, loss:1.54267543528229, acc:0.601875


  5%|▍         | 500/10845 [04:58<2:09:37,  1.33it/s]

epoch:2, idx:499/10845, loss:1.5834038315713406, acc:0.594


  6%|▌         | 600/10845 [05:58<1:55:03,  1.48it/s]

epoch:2, idx:599/10845, loss:1.5806845682611068, acc:0.59


  6%|▋         | 700/10845 [06:55<1:31:06,  1.86it/s]

epoch:2, idx:699/10845, loss:1.5622436601136411, acc:0.595


  7%|▋         | 800/10845 [07:57<1:41:10,  1.65it/s]

epoch:2, idx:799/10845, loss:1.565002490207553, acc:0.595


  8%|▊         | 900/10845 [08:57<1:49:30,  1.51it/s]

epoch:2, idx:899/10845, loss:1.5713239905238152, acc:0.5925


  9%|▉         | 1000/10845 [10:00<1:31:02,  1.80it/s]

epoch:2, idx:999/10845, loss:1.5715386326909064, acc:0.59225


  9%|▉         | 1020/10845 [10:11<1:39:25,  1.65it/s]