### 多级推理模块

0. 由于没有self-attention pooling了，所以再加一层self-attention层
1. 每次更新段落的Summary vectors 
    input: [batch_size, para_num, para_len, dim]
    query: [batch_size, dim]
    
2. expand -> view -> biSeqAtt -> sum
3. ori san

In [1]:
import os
import torch
import torch.nn as nn
import torchtext
from tensorboardX import SummaryWriter
import random
import numpy as np

from torchtext.data import NestedField, Field, RawField
from model import *
from dataset import DataHandler
%load_ext autoreload

%autoreload 2

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
class Config:
    """Hyper-parameters, data locations and run options for this experiment.

    Everything is stored as a plain instance attribute so the object can be
    handed to the model constructor and extended later in the notebook
    (``save_path`` is attached after the run name has been built).
    """

    def __init__(self):
        # Model size and optimisation settings.
        self.hidden = 50
        self.embedding_dim = 300
        self.lr = 5e-4
        self.epochs = 50
        # Fixed token length applied to the support-document fields.
        self.fix_length = 256

        # Logging / run naming.
        self.log_dir = './logs'
        self.model_name = 'gan'
        self.batch_size = 4

        # Serialized example files. The "*_filter.pt" files are an alternative
        # dataset that was used previously and is kept here for reference.
        #self.train_data = './data/train_filter.pt'
        #self.dev_data = './data/dev_filter.pt'
        self.train_data = './data/train_graph.pt'
        self.dev_data = './data/dev_graph.pt'
        # Cached vocabulary; set to None to rebuild the vocabulary from the data.
        self.word_vocab = './data/glove_vocab.pt'
        #self.word_vocab = None
        #self.charNGram_vocab = None

        self.dropout = 0.2
        self.seed = 1023
        # Number of reasoning turns and score-aggregation mode passed to SAN.
        self.steps = 3
        self.memory_type = 1


config = Config()
# Prefer the first GPU, but fall back to the CPU so the notebook can still be
# executed (slowly) on a machine without CUDA instead of failing later.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [3]:
# Sanity check: report whether PyTorch can see a CUDA device.
torch.cuda.is_available()

True

In [4]:
# Seed every random number generator used in this notebook (Python's random,
# NumPy, PyTorch CPU and all CUDA devices) with the configured seed.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)

In [5]:
# Build a run directory name that encodes the main hyper-parameters, place it
# under the log directory and remember it on the config object.
run_name_parts = [
    config.model_name,
    '_lr_', str(config.lr),
    '__hidden__', str(config.hidden),
    '_batchsize_', str(config.batch_size),
    '_p', str(config.dropout),
    '_steps_', str(config.steps),
    'memory_type_', str(config.memory_type),
]
save_path = os.path.join(config.log_dir, ''.join(run_name_parts))
config.save_path = save_path
print(save_path)

./logs/gan_lr_0.0005__hidden__50_batchsize_4_p0.2_steps_3memory_type_1


### Define Fields

In [6]:
# Token-level field without a fixed length; used directly for the query and,
# wrapped in a NestedField, for the list of candidates.
word_field = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True) # query
multi_word_field = NestedField(word_field) 

# Same tokenization, but every sequence is fixed to config.fix_length tokens;
# the nested variant is used for the list of support documents.
word_field_sup = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True, fix_length=config.fix_length)
multi_word_field_sup = NestedField(word_field_sup) 

# Parallel set of fields for char-n-gram inputs. They are defined here but are
# not referenced in dict_field below.
charNGram_field = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True) # query
multi_charNGram_field = NestedField(charNGram_field) 

charNGram_field_sup = Field(batch_first=True, sequential=True, tokenize="spacy", lower=True, fix_length=config.fix_length)
multi_charNGram_field_sup = NestedField(charNGram_field_sup) 

# RawField for values that need no tokenization or vocabulary (ids, mention
# spans, paragraph labels, graphs); explicitly marked as a non-target field.
raw = RawField()
raw.is_target = False

# Non-sequential target field without a vocabulary (use_vocab=False).
label_field = Field(sequential=False, is_target=True, use_vocab=False)

# Maps each key of a raw example to (attribute name on the batch, field).
dict_field = {
    'id': ('id', raw),
    'supports': ('s_glove', multi_word_field_sup), 
    'query': ('q_glove', word_field), 
    'candidates': ('c_glove', multi_word_field),
    'label': ('label', label_field),
    'mentions': ('mentions', raw),
    'para_label': ('para_label', raw),
    'graph': ('graph', raw)

}

In [7]:
# DataHandler (from dataset.py) receives the train/dev example files and the
# field mapping; the notebook later reads its trainset/valset attributes and
# asks it for the batch iterators.
data_handler = DataHandler(config.train_data, config.dev_data, dict_field)

# torch.save(data_handler.trainset.examples, './data/train_example.pt')
# torch.save(data_handler.valset.examples, './data/dev_example.pt')

load examples.pt  :./data/train_graph.pt, ./data/dev_graph.pt


def add_graph(examples):
    """Attach a candidate co-occurrence adjacency matrix to every example.

    Two distinct candidates ``i`` and ``j`` are connected when at least one
    support document contains a mention of both. The matrix is stored in
    place as ``example.graph``: a float tensor of shape
    ``[candidate_num, candidate_num]`` with 0/1 entries and a zero diagonal.

    Args:
        examples: iterable of examples exposing ``c_glove`` (candidates),
            ``s_glove`` (supports) and ``mentions``. ``mentions[i]`` lists the
            mentions of candidate ``i``; the first element of each mention is
            used as the index of the support document it occurs in.

    Note:
        ``tqdm`` must already be imported in the notebook namespace when this
        function is called.
    """
    for example in tqdm(examples):
        candidate_num = len(example.c_glove)
        support_num = len(example.s_glove)

        # mask[i, s] == 1 iff candidate i is mentioned in support document s.
        mask = torch.zeros(candidate_num, support_num)
        for i, candidate_mention in enumerate(example.mentions):
            for mention in candidate_mention:
                mask[i][mention[0]] = 1

        # (mask @ mask^T)[i, j] counts the support documents shared by
        # candidates i and j; any positive count is an edge. Self-loops are
        # removed so the diagonal stays zero.
        shared = torch.mm(mask, mask.t())
        graph = (shared > 0).float() * (1 - torch.eye(candidate_num))

        example.graph = graph
    

# Compute and attach the candidate graph for every dev and train example
# (the examples are modified in place).
add_graph(data_handler.valset.examples)
add_graph(data_handler.trainset.examples)

In [8]:
#torch.save(data_handler.valset.examples, './data/dev_graph.pt')
#torch.save(data_handler.trainset.examples, './data/train_graph.pt')  

### Build Vocab

In [9]:
# Reuse the cached vocabulary when a path is configured; otherwise build the
# vocabulary from the train and dev sets with pretrained GloVe vectors.
# NOTE(review): this fallback loads GloVe '6B', whereas the cell that writes
# ./data/glove_vocab.pt uses '840B' — confirm which one is intended.
if config.word_vocab is not None:
    word_vocab = torch.load(config.word_vocab)
    multi_word_field_sup.vocab = word_vocab
    word_field_sup.vocab = word_vocab
else:
    multi_word_field_sup.build_vocab(data_handler.trainset, data_handler.valset, 
                                 vectors=torchtext.vocab.GloVe(dim=300,name='6B') )

# The query/candidate token field shares the vocabulary of the supports.
word_field.vocab = multi_word_field_sup.vocab

In [10]:
# Shape of the embedding matrix attached to the vocabulary.
print(multi_word_field_sup.vocab.vectors.shape)

torch.Size([312667, 300])


# One-off step: build the vocabulary over train+dev with GloVe 840B (300d)
# vectors and cache it at the path that config.word_vocab points to.
multi_word_field_sup.build_vocab(data_handler.trainset, data_handler.valset, 
                         vectors=torchtext.vocab.GloVe(dim=300,name='840B') )
torch.save(multi_word_field_sup.vocab, './data/glove_vocab.pt')

### Get data_iter

In [11]:
# Batch iterators over the training and validation sets.
train_iter = data_handler.get_train_iter(batch_size=config.batch_size)
val_iter = data_handler.get_val_iter(batch_size=config.batch_size)

In [12]:
# Take the first validation batch and display it; `batch` is reused by the
# debugging cells at the end of the notebook.
for idx, batch in enumerate(val_iter):
    break
batch


[torchtext.data.batch.Batch of size 4]
	[.id]:['WH_dev_0', 'WH_dev_1', 'WH_dev_2', 'WH_dev_3']
	[.s_glove]:[torch.LongTensor of size 4x15x256]
	[.q_glove]:[torch.LongTensor of size 4x11]
	[.c_glove]:[torch.LongTensor of size 4x18x4]
	[.label]:[torch.LongTensor of size 4]
	[.mentions]:[[[[6, 145, 146], [6, 173, 174], [7, 78, 79]], [[3, 50, 53], [5, 28, 31], [13, 1, 4]], [[6, 135, 136], [6, 218, 219], [6, 261, 262], [8, 45, 46], [12, 98, 99]], [[0, 2, 4], [7, 1, 3], [13, 64, 66], [13, 69, 71]], [[0, 36, 38], [10, 1, 3], [13, 75, 77]], [[0, 14, 15], [1, 63, 64], [1, 138, 139], [1, 186, 187], [1, 238, 239], [2, 128, 129], [9, 8, 9], [9, 43, 44], [10, 19, 20], [10, 34, 35], [11, 37, 38], [11, 79, 80], [13, 56, 57]], [[7, 37, 40]], [[6, 180, 181], [12, 101, 102]], [[12, 96, 97]], [[8, 43, 46]], [[6, 169, 172]], [[6, 179, 181]], [[7, 38, 40]], [[6, 147, 148], [6, 182, 183]], [[6, 171, 172], [9, 6, 7], [11, 35, 36], [11, 121, 122]], [[2, 125, 127], [8, 8, 10], [12, 89, 91]], [[1, 0, 2], [1, 1

### GNN

In [13]:
class GraphAttentionLayer(nn.Module):
    """
    Single-head graph attention layer in the spirit of
    https://arxiv.org/abs/1710.10903.

    Node features are projected with ``W``; an attention logit is computed
    for every ordered node pair from the concatenated projections and the
    vector ``a``; logits of non-adjacent pairs are masked out before the
    row-wise softmax.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout          # dropout probability on the attention weights
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha              # negative slope of the LeakyReLU
        self.concat = concat            # when True, an ELU is applied to the output

        # Projection matrix, Xavier-initialised.
        weight = torch.zeros(size=(in_features, out_features))
        nn.init.xavier_uniform_(weight, gain=1.414)
        self.W = nn.Parameter(weight)

        # Attention vector applied to [h_i || h_j], Xavier-initialised.
        attn_vec = torch.zeros(size=(2 * out_features, 1))
        nn.init.xavier_uniform_(attn_vec, gain=1.414)
        self.a = nn.Parameter(attn_vec)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        """Return the attended node features for one graph.

        ``input`` is a 2-D node feature matrix (rows are nodes) and ``adj``
        a square matrix whose positive entries mark the allowed edges.
        """
        projected = torch.mm(input, self.W)
        num_nodes = projected.size(0)

        # Row i * N + j of the pair matrix is [h_i || h_j].
        left = projected.repeat(1, num_nodes).view(num_nodes * num_nodes, -1)
        right = projected.repeat(num_nodes, 1)
        pairs = torch.cat([left, right], dim=1).view(num_nodes, -1, 2 * self.out_features)
        logits = self.leakyrelu(torch.matmul(pairs, self.a).squeeze(2))

        # Non-edges get a very large negative logit so softmax ignores them.
        blocked = -9e15 * torch.ones_like(logits)
        weights = F.softmax(torch.where(adj > 0, logits, blocked), dim=1)
        weights = F.dropout(weights, self.dropout, training=self.training)
        aggregated = torch.matmul(weights, projected)

        return F.elu(aggregated) if self.concat else aggregated

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__, self.in_features, self.out_features)

class GAN(nn.Module):
    """Applies one shared GraphAttentionLayer to every graph of a batch."""

    def __init__(self, in_features, out_features, dropout):
        super(GAN, self).__init__()
        # The LeakyReLU slope of the attention layer is fixed at 0.1.
        self.gan = GraphAttentionLayer(in_features, out_features, dropout, 0.1)

    def forward(self, input, graph):
        """Run the attention layer on each (features, adjacency) pair.

        ``input`` and ``graph`` are indexed along their first dimension; the
        per-example outputs are stacked back into a single tensor.
        """
        per_example = [self.gan(input[b], graph[b]) for b in range(input.shape[0])]
        return torch.stack(per_example)




In [14]:
def generate_graph(batch):
    """Pad the per-example candidate graphs of a batch into one dense tensor.

    Args:
        batch: object with ``c_glove`` (a tensor whose first two dimensions
            are ``[batch_size, candidate_num]``) and ``graph`` (a sequence
            with one square adjacency matrix per example).

    Returns:
        Float tensor of shape ``[batch_size, candidate_num, candidate_num]``.
        The ``n x n`` graph of example ``i`` fills the top-left corner of
        slice ``i``; the remaining (padding) entries are zero.
    """
    # Only the candidate tensor is needed to size the output.
    batch_size, candidate_num = batch.c_glove.shape[:2]
    batch_graph = torch.zeros(batch_size, candidate_num, candidate_num)
    for i in range(batch_size):
        graph = batch.graph[i]
        n = graph.shape[0]
        batch_graph[i, :n, :n] = graph
    return batch_graph

class GNN(nn.Module):
    """Gated graph update over candidate nodes.

    Each node combines its own projection (``W_1``) with the sum of its
    neighbours' projections (``W_2``, aggregated through the adjacency
    matrix). A sigmoid gate computed from the old state and the update
    interpolates between the two.
    """

    def __init__(self, hidden, dropout=0.2):
        super(GNN, self).__init__()
        self.W_1 = nn.Linear(hidden, hidden)      # self transform
        self.W_2 = nn.Linear(hidden, hidden)      # neighbour transform
        self.W_g = nn.Linear(hidden*2, hidden)    # gate over [x, update]
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj):
        '''
        input:
            x:   [batch, candidate_num, hidden]
            adj: [batch, candidate_num, candidate_num]
        output:
            updated node states, same shape as x
        '''
        x1 = self.dropout(self.W_1(x))
        x2 = self.dropout(self.W_2(x))

        # Aggregate the W_2-projected neighbour features. Aggregating the raw
        # `x` here would discard the projection computed on the line above
        # and leave W_2 without any effect on the output.
        x2 = torch.bmm(adj, x2)

        u = x1 + x2
        g = self.W_g(torch.cat([x, u], dim=-1))
        g = torch.sigmoid(g)
        g = self.dropout(g)

        # Gated interpolation between the candidate update and the old state.
        x = torch.tanh(u)*g + x * (1-g)
        return x

### Define Model

In [15]:
def generate_mask(x_size, num_turn, dropout_p=0.0, is_training=False):
    """Sample a dropout mask over reasoning turns.

    Every entry is kept with probability ``1 - dropout_p`` and kept entries
    are scaled by ``1 / (1 - dropout_p)``. In each row one randomly chosen
    turn is always kept. Outside of training the dropout probability is
    forced to zero, so the mask is all ones.

    Args:
        x_size: number of rows (examples).
        num_turn: number of columns (turns).
        dropout_p: probability of dropping a turn.
        is_training: whether dropout is active.

    Returns:
        Float tensor of shape ``[x_size, num_turn]`` that does not require
        gradients.
    """
    if not is_training:
        dropout_p = 0.0

    keep_prob = (1 - dropout_p) * torch.ones(x_size, num_turn)
    # Guarantee at least one surviving turn per row.
    for row in range(keep_prob.size(0)):
        always_kept = random.randint(0, keep_prob.size(1) - 1)
        keep_prob[row][always_kept] = 1

    rescale = 1.0 / (1 - dropout_p)
    mask = rescale * torch.bernoulli(keep_prob)
    mask.requires_grad = False
    return mask

class SAN(nn.Module):
    """Multi-turn reasoning module that scores candidates against a question.

    On every turn the question vector attends over the words of each
    paragraph, then over the resulting paragraph summaries, and is updated
    with a GRU cell. Candidates are scored against the updated question
    vector on each turn; the per-turn scores are combined according to
    ``memory_type``:

        0 - average over turns after applying a random per-turn dropout mask
        1 - plain average over turns
        2 - scores of the last turn only
    """

    def __init__(self, question_dim, support_dim, candidate_dim, num_turn=5, dropout=0.2, memo_dropout=0.4, memory_type=0, gan_dropout=0.5,
                 device=None):
        super(SAN,self).__init__()
        # Word-level and paragraph-level attention of the question over the supports.
        self.qp_bilinear_attention_word = BilinearSeqAttn(support_dim, question_dim, dropout=dropout)
        self.qp_bilinear_attention_para = BilinearSeqAttn(support_dim, question_dim, dropout=dropout)

        self.candidates_scorer = BilinearSeqAttn(candidate_dim, question_dim, dropout=dropout)
        self.gru = nn.GRUCell(support_dim, question_dim)
        # Only used in forward() when a graph is supplied.
        self.gnn = GNN(candidate_dim,gan_dropout)

        self.num_turn = num_turn

        self.dropout = nn.Dropout(p=dropout)
        self.memo_dropout=memo_dropout
        self.device = device
        self.memory_type = memory_type

    def forward(self, question_embedding, para_embedding, candidates_embedding, para_length, graph=None):
        '''
        input:
            question_embedding: [batch_size, hidden_dim]
            para_embedding: [batch_size*para_num, para_length, hidden_dim]
            candidates_embedding: [batch_size, candidates_num, hidden_dim]
            para_length: despite its name this is the number of paragraphs
                per example (para_num); it is the middle dimension used when
                the question is expanded and the paragraph summaries are
                reshaped to [batch_size, para_length, hidden_dim]
            graph: optional candidate adjacency passed to the GNN on every
                turn; the graph update is skipped when it is None
        output:
            candidate scores combined over the turns; expected to be
            [batch_size, candidates_num]
        raises:
            ValueError: if memory_type is not 0, 1 or 2
        '''
        score_list = []
        batch_size = question_embedding.size(0)
        hidden = question_embedding.size(1)
        for turn in range(self.num_turn):
            # One copy of the current question vector per paragraph.
            question_embedding_expand = question_embedding.unsqueeze(1).expand(batch_size, para_length, hidden).contiguous()
            question_embedding_expand = question_embedding_expand.view(-1,hidden)

            # update paragraph embedding
            qp_score_word = self.qp_bilinear_attention_word(para_embedding, question_embedding_expand)
            qp_score_word = F.softmax(qp_score_word, 1)
            para_embedding_summary = torch.bmm(qp_score_word.unsqueeze(1), para_embedding).squeeze(1)
            para_embedding_summary = para_embedding_summary.contiguous().view(batch_size, para_length, hidden)

            # update question embedding
            qp_score_para = self.qp_bilinear_attention_para(para_embedding_summary, question_embedding)
            qp_score_para = F.softmax(qp_score_para, 1)
            S = torch.bmm(qp_score_para.unsqueeze(1), para_embedding_summary).squeeze(1)

            S = self.dropout(S)
            question_embedding = self.gru(S, question_embedding)

            # Graph update
            if graph is not None:
                candidates_embedding = self.gnn(candidates_embedding, graph)

            # compute candidates score
            candidates_score = self.candidates_scorer(candidates_embedding, question_embedding)

            score_list.append(candidates_score)
        if self.memory_type == 0:
            # Stochastic turn dropout (active in training only), then average.
            mask = generate_mask(batch_size,self.num_turn, self.memo_dropout, self.training)
            mask = mask.to(self.device)
            mask = [m.contiguous() for m in torch.unbind(mask, 1)]

            score_list = [mask[idx].view(batch_size, 1).expand_as(inp) * inp for idx, inp in enumerate(score_list)]
            scores = torch.stack(score_list, 2)
            scores = torch.mean(scores, 2)
        elif self.memory_type == 1:
            scores = torch.stack(score_list, 2)
            scores = torch.mean(scores, 2)
        elif self.memory_type == 2:
            scores = score_list[-1]
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError('Unsupported memory_type: {!r} (expected 0, 1 or 2)'.format(self.memory_type))

        return scores
    
    

In [16]:
class SimpleQANet(nn.Module):
    """Multi-hop QA model: encodes query, supports and candidates, then
    scores every candidate with the multi-turn SAN module.

    The building blocks (EmbeddingLayer, EncoderRNN, CoAttention,
    SelfAttention, PoolingLayer) come from the project's ``model`` module.
    """

    def __init__(self, config, word_vectors, device):
        """
        config: object providing embedding_dim, hidden, dropout, steps and
            memory_type
        word_vectors: pretrained embedding matrix handed to EmbeddingLayer
        device: device the module is moved to and on which inputs are placed
        """
        super(SimpleQANet, self).__init__()
        self.config = config
        self.device = device

        self.embedding_layer = EmbeddingLayer(word_vectors)

        # Shared encoder for query, candidates and supports.
        self.rnn = EncoderRNN(config.embedding_dim, config.hidden, 1, True, True, config.dropout, False)

        self.co_att = CoAttention(config.hidden*2, att_type=2, dropout=config.dropout)

        # Projects the co-attention output (hidden*4) back to `hidden`.
        self.linear_1 = nn.Sequential(
                        nn.Linear(config.hidden*4, config.hidden),
                        nn.ReLU()
                    )
        self.rnn2 =  EncoderRNN(config.hidden, config.hidden, 1, True, True, config.dropout, False)

        # NOTE: word_att, pass_att, max_pooling and fc are constructed but not
        # used in forward(); they are kept so the module's parameters (and
        # therefore saved checkpoints) stay unchanged.
        self.word_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)
        self.word_att_q = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)

        self.pass_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)

        self.c_att = SelfAttention(config.hidden*2, config.hidden*2, config.dropout)

        #self.fusion = FusionLayer(config.hidden*2, dropout=config.dropout)
        self.max_pooling = PoolingLayer()

        self.fc = nn.Linear(config.hidden*2, config.hidden*4)
        # Candidate vectors are hidden*6 wide: self-attention summary
        # (hidden*2) concatenated with max- and mean-pooled mention vectors
        # (hidden*2 each).
        self.san = SAN(config.hidden*2,config.hidden*2,config.hidden*6, num_turn=config.steps, memory_type=config.memory_type, device=device)

        self.to(device)

    def get_candidate_vectors(self, batch, support_vectors, device):
        """Pool the support-token vectors at the mention positions of each candidate.

        batch: provides c_glove [batch, candidate_num, _], s_glove
            [batch, support_num, support_length] and mentions, where every
            mention is (support index, start token, end token) with an
            exclusive end
        support_vectors: encoded supports; viewed here as
            [batch, support_num*support_length, hidden]
        device: device on which the mention masks are created

        Returns [batch, candidate_num, 2*hidden]: the max and the mean over
        all support positions of the masked vectors. Positions outside the
        mentions are zeroed by the mask and still take part in both poolings.
        """
        batch_size, candidate_num,_ = batch.c_glove.shape
        _,support_num, support_length = batch.s_glove.shape
        hidden = support_vectors.shape[-1]

        masks = []
        for idx, candidate_mentions in enumerate(batch.mentions):
            # mask[c, s, t] == 1 iff token t of support s belongs to a mention of candidate c.
            mask = torch.zeros(candidate_num, support_num, support_length)
            for i in range(len(candidate_mentions)):
                candidate_mention = candidate_mentions[i]
                for mention in candidate_mention:
                    mask[i][mention[0]][mention[1]:mention[2]] = 1
            masks.append(mask)
        masks = torch.stack(masks).to(device)

        # [batch, 1, support_num*support_length, hidden], broadcast over candidates.
        support_vectors = support_vectors.view(batch_size,-1,hidden).unsqueeze(1)

        masks = masks.view(batch_size,candidate_num,-1)
        masks_expand = masks.unsqueeze(-1).expand(batch_size, candidate_num, support_length*support_num, hidden)

        candidates = support_vectors * masks_expand

        candidates_max = candidates.max(-2)[0]
        candidates_mean = torch.mean(candidates,-2)
        candidates_vectors = torch.cat([candidates_max, candidates_mean],-1)

        return candidates_vectors

    def forward(self, batch, return_label = True):
        """Score every candidate of every example in the batch.

        batch: torchtext batch with q_glove, s_glove, c_glove, mentions and
            (when return_label is True) label
        return_label: when True, also return the labels moved to the model's
            device

        Returns score [batch_size, candidates_num], or (score, label).
        Candidates whose token ids are all <= 1 are treated as padding and
        receive a score of -1e15.
        """
        if type(batch.q_glove) is tuple:
            q_glove, _ = batch.q_glove
        else:
            q_glove = batch.q_glove
        s_glove = batch.s_glove
        c_glove = batch.c_glove

        q_glove = q_glove.to(self.device)
        s_glove = s_glove.to(self.device)
        c_glove = c_glove.to(self.device)

        q_out = self.embedding_layer(q_glove) # [batch_size, question_length, embedding_dim]
        s_out = self.embedding_layer(s_glove) # [batch_size, support_num, support_length, embedding_dim]
        c_out = self.embedding_layer(c_glove) # [batch_size, candidates_num, candidates_length, embedding_dim]

        batch_size=  s_out.size(0)

        s_len = s_out.size(1)   # number of supports
        c_len = c_out.size(1)   # number of candidates

        s_word_len = s_out.size(2)
        c_word_len = c_out.size(2)

        hidden = s_out.size(-1)

        # Fold the support / candidate dimension into the batch dimension.
        s_out = s_out.view(batch_size*s_len, s_word_len, hidden).contiguous()
        c_out = c_out.view(batch_size*c_len, c_word_len, hidden).contiguous()

        q_out = self.rnn(q_out) # [batch_size, question_length, hidden*2]
        c_out = self.rnn(c_out) # [batch_size * candidates_num, candidates_length, hidden*2]
        s_out = self.rnn(s_out) # [batch_size * support_num, support_length, hidden*2]

        # Attention

        # Repeat the encoded question once per support document.
        q_word_len = q_out.size(1)
        q_out_expand = q_out.unsqueeze(1).expand(batch_size, s_len, q_word_len, q_out.size(-1)).contiguous()
        q_out_expand = q_out_expand.view(batch_size*s_len, q_word_len, q_out.size(-1)).contiguous()

        s_out_att, q_out_att = self.co_att(s_out, q_out_expand)
        #S_s = self.fusion(s_out, s_out_att)
        #S_q = self.fusion(q_out, q_out_att)

        S_s = self.linear_1(s_out_att)
        S_s = self.rnn2(S_s) # [batch_size * para_num, para_length, hidden*2]

        candidates_vectors = self.get_candidate_vectors(batch, S_s, self.device)
        question_summary = self.word_att_q(q_out)

        candidates_summary = self.c_att(c_out)
        candidates_summary = candidates_summary.view(batch_size, c_len, -1)

        candidates_summary = torch.cat([candidates_summary, candidates_vectors],-1)

        # The graph update is currently disabled.
        #graph = generate_graph(batch).to(self.device)
        graph = None
        score = self.san(question_summary, S_s, candidates_summary, s_len, graph)

        # Use the module's own device rather than a notebook-level global so
        # the model does not depend on hidden state.
        candidates_mask = ((c_glove > 1).sum(-1) > 0).float().to(self.device)
        score = score * candidates_mask + (-1e15)*(1-candidates_mask)

        if return_label:
            label = batch.label.to(self.device)
            return score, label
        return score

#### test model

In [17]:
# Use the GPU when one is available and fall back to the CPU otherwise, so
# that building the model does not fail on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleQANet(config, word_field.vocab.vectors, device)
#score, label= model(batch)
#print(score.shape, label.shape)

In [18]:
from tqdm import tqdm, trange

In [19]:
from utils import AverageMeter

def train(epoch, data_iter, model, criterion, optimizer, batch_size=1):
    """Run one training epoch with optional gradient accumulation.

    Args:
        epoch: epoch index, used only for progress reporting.
        data_iter: iterable of batches that supports ``len()``.
        model: called as ``model(batch)`` and expected to return
            ``(score, label)``.
        criterion: loss function applied to ``(score, label)``.
        optimizer: optimizer over the model parameters.
        batch_size: number of consecutive batches whose gradients are
            accumulated before one optimizer step. Each loss is divided by
            this value.

    Returns:
        ``(losses.avg, acces.avg)`` — the loss and accuracy meters' averages
        over the epoch (AverageMeter comes from the project's ``utils``).
    """
    losses = AverageMeter()
    acces = AverageMeter()
    model.train()
    # Start from clean gradients: when the previous call ended in the middle
    # of an accumulation window, its leftover gradients would otherwise be
    # added to the first optimizer step of this epoch.
    optimizer.zero_grad()
    #model.embedding_layer.eval()
    with trange(len(data_iter)) as t:
        for idx, batch in enumerate(data_iter):
            score, label = model(batch)

            loss = criterion(score, label)

            # Scale so the accumulated gradient is the mean over the window.
            loss = loss / batch_size
            loss.backward()
            if (idx+1)%batch_size == 0 :
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                optimizer.step()
                optimizer.zero_grad()

            # Undo the scaling so the reported loss is per batch.
            losses.update(loss.item()*batch_size)

            pred = score.argmax(1)
            acc = pred.eq(label).sum().item()  / pred.size(0)
            acces.update(acc)

            matrix = {
                'acc':acces.avg,
                'epoch':epoch,
                'loss': losses.avg
            }
            t.set_postfix(matrix)
            t.update()
            if (idx+1) % (batch_size*100) == 0:
                print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')
    return losses.avg, acces.avg

def val(epoch, data_iter, model, criterion):
    """Evaluate the model on ``data_iter`` without computing gradients.

    The model is switched to eval mode and called as ``model(batch)``; it is
    expected to return ``(score, label)``. Progress is printed every 100
    batches. Returns ``(losses.avg, acces.avg)``, the averages of the loss
    and accuracy meters over all batches.
    """
    model.eval()
    losses, acces = AverageMeter(), AverageMeter()

    for idx, batch in enumerate(data_iter):
        with torch.no_grad():
            score, label = model(batch)
            batch_loss = criterion(score, label).item()
        losses.update(batch_loss)

        predictions = score.argmax(1)
        n_correct = (predictions == label).sum().item()
        acces.update(n_correct / predictions.size(0))

        if idx % 100 == 0:
            print(f'epoch:{epoch}, idx:{idx}/{len(data_iter)}, loss:{losses.avg}, acc:{acces.avg}')

    return losses.avg, acces.avg

In [None]:
# Adam over the trainable parameters only (those with requires_grad=True).
# NOTE: optimizer and scheduler are created again at the start of every cycle
# in the training loop, so only `criterion` from this cell is reused there.
optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                             lr=config.lr)

# Cross-entropy over the candidate scores.
criterion = nn.CrossEntropyLoss()

# Cosine learning-rate schedule spanning config.epochs scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
#train(0, train_iter, model, criterion, optimizer, batch_size=config.batch_size)
# val(0, val_iter, model,criterion)

In [None]:
# cycle_len: number of training cycles (a fresh optimizer and scheduler per cycle).
# cycle_iter: epochs per cycle; also used as T_max of the cosine schedule.
cycle_len = 1
cycle_iter = 50

In [None]:
# Main training loop: logs to TensorBoard under config.save_path and writes a
# checkpoint after every epoch.
if not os.path.exists(config.save_path):
    os.makedirs(config.save_path)
writer = SummaryWriter(config.save_path)

best_acc = 0.0
for i in range(cycle_len):
    # Every cycle restarts Adam and the cosine schedule from config.lr.
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                             lr=config.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cycle_iter)
    for epoch in range(cycle_iter):
        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # steps. Recent PyTorch releases expect scheduler.step() after
        # optimizer.step() and warn otherwise — confirm against the installed
        # version before moving this call.
        scheduler.step()
        # The last argument (1) disables gradient accumulation in train().
        train_loss, train_acc = train(epoch, train_iter, model, criterion, optimizer, 1)
        val_loss, val_acc = val(epoch, val_iter, model, criterion)
        # 1-based epoch counter that keeps increasing across cycles.
        global_epoch = cycle_iter * i + epoch + 1
        writer.add_scalar('train_loss', train_loss, global_epoch)
        writer.add_scalar('val_loss', val_loss, global_epoch)
        writer.add_scalar('train_acc', train_acc, global_epoch)
        writer.add_scalar('val_acc', val_acc, global_epoch)

        state = {
            'val_acc': val_acc,
            'train_acc': train_acc,
            'epoch': epoch
            ,
            'model': model.state_dict()
        }
        # Always overwrite the most recent checkpoint.
        # NOTE(review): 'lastest' looks like a typo for 'latest'; it is left
        # unchanged because renaming would change the checkpoint path.
        torch.save(state, os.path.join(config.save_path,'lastest.pth'))
        # Additionally keep a copy whenever validation accuracy improves
        # (save_path holds the same directory as config.save_path).
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(state, os.path.join(save_path, f'best_epoch{epoch}.pth'))

  alphas = self.softmax(alphas)  # (bsz, sent_len)
  1%|          | 100/10845 [00:24<42:42,  4.19it/s, acc=0.188, epoch=0, loss=2.44]

epoch:0, idx:99/10845, loss:2.440446113348007, acc:0.1875


  2%|▏         | 200/10845 [00:48<42:47,  4.15it/s, acc=0.251, epoch=0, loss=2.35]

epoch:0, idx:199/10845, loss:2.3479348433017733, acc:0.25125


  3%|▎         | 300/10845 [01:12<43:39,  4.03it/s, acc=0.277, epoch=0, loss=2.29]

epoch:0, idx:299/10845, loss:2.2941722665230433, acc:0.27666666666666667


  4%|▎         | 400/10845 [01:36<41:13,  4.22it/s, acc=0.294, epoch=0, loss=2.24]

epoch:0, idx:399/10845, loss:2.2385188657045365, acc:0.29375


  5%|▍         | 500/10845 [02:00<40:32,  4.25it/s, acc=0.3, epoch=0, loss=2.22]  

epoch:0, idx:499/10845, loss:2.223029584169388, acc:0.3


  6%|▌         | 600/10845 [02:24<39:55,  4.28it/s, acc=0.306, epoch=0, loss=2.19]

epoch:0, idx:599/10845, loss:2.192085688014825, acc:0.30583333333333335


  6%|▋         | 700/10845 [02:49<42:00,  4.03it/s, acc=0.312, epoch=0, loss=2.18]

epoch:0, idx:699/10845, loss:2.179809968812125, acc:0.31214285714285717


  7%|▋         | 800/10845 [03:13<42:10,  3.97it/s, acc=0.317, epoch=0, loss=2.17]

epoch:0, idx:799/10845, loss:2.1671188152208924, acc:0.316875


  8%|▊         | 900/10845 [03:37<40:44,  4.07it/s, acc=0.326, epoch=0, loss=2.14]

epoch:0, idx:899/10845, loss:2.138180994656351, acc:0.3258333333333333


  9%|▉         | 1000/10845 [04:01<37:34,  4.37it/s, acc=0.333, epoch=0, loss=2.12]

epoch:0, idx:999/10845, loss:2.1191528463363647, acc:0.333


 10%|█         | 1100/10845 [04:25<39:54,  4.07it/s, acc=0.337, epoch=0, loss=2.1] 

epoch:0, idx:1099/10845, loss:2.1001409003815867, acc:0.33704545454545454


 11%|█         | 1200/10845 [04:50<38:47,  4.14it/s, acc=0.339, epoch=0, loss=2.1]

epoch:0, idx:1199/10845, loss:2.0968695268407465, acc:0.3385416666666667


 12%|█▏        | 1300/10845 [05:14<36:48,  4.32it/s, acc=0.344, epoch=0, loss=2.08]

epoch:0, idx:1299/10845, loss:2.082605100996219, acc:0.34423076923076923


 13%|█▎        | 1400/10845 [05:38<41:20,  3.81it/s, acc=0.349, epoch=0, loss=2.07]

epoch:0, idx:1399/10845, loss:2.069813695624471, acc:0.34875


 14%|█▍        | 1500/10845 [06:02<37:24,  4.16it/s, acc=0.354, epoch=0, loss=2.06]

epoch:0, idx:1499/10845, loss:2.0574700970550377, acc:0.354


 15%|█▍        | 1600/10845 [06:27<37:28,  4.11it/s, acc=0.356, epoch=0, loss=2.05]

epoch:0, idx:1599/10845, loss:2.045000130040571, acc:0.35625


 16%|█▌        | 1700/10845 [06:52<38:41,  3.94it/s, acc=0.358, epoch=0, loss=2.04]

epoch:0, idx:1699/10845, loss:2.044044786788085, acc:0.35823529411764704


 17%|█▋        | 1800/10845 [07:16<35:18,  4.27it/s, acc=0.361, epoch=0, loss=2.03]

epoch:0, idx:1799/10845, loss:2.032346332594752, acc:0.36125


 18%|█▊        | 1900/10845 [07:40<37:57,  3.93it/s, acc=0.362, epoch=0, loss=2.02]

epoch:0, idx:1899/10845, loss:2.0236138432198447, acc:0.36236842105263156


 18%|█▊        | 2000/10845 [08:05<37:18,  3.95it/s, acc=0.365, epoch=0, loss=2.01]

epoch:0, idx:1999/10845, loss:2.0144051858708263, acc:0.3655


 19%|█▉        | 2100/10845 [08:29<33:37,  4.33it/s, acc=0.367, epoch=0, loss=2.01]

epoch:0, idx:2099/10845, loss:2.0069683878691422, acc:0.36714285714285716


 20%|██        | 2200/10845 [08:53<36:17,  3.97it/s, acc=0.37, epoch=0, loss=2]    

epoch:0, idx:2199/10845, loss:1.9970509629290212, acc:0.37


 21%|██        | 2300/10845 [09:17<33:17,  4.28it/s, acc=0.371, epoch=0, loss=1.99]

epoch:0, idx:2299/10845, loss:1.9878806694111097, acc:0.3714130434782609


 22%|██▏       | 2341/10845 [09:27<33:04,  4.29it/s, acc=0.372, epoch=0, loss=1.98]

In [None]:
# Debugging: run the model on the validation batch captured earlier and
# print the shapes of the scores and labels.
score, label= model(batch)
print(score.shape, label.shape)

In [None]:
# Debugging: inspect the candidate token ids of the captured batch.
batch.c_glove

In [None]:
# Debugging: inspect the candidate scores from the forward pass above.
score