# BERT
微调时，将最后一层第一个 token（即 [CLS]）的隐藏向量作为句子的表示，再输入 softmax 层进行分类。在本 notebook 的实现中，长文档先被切分为若干片段，每个片段取 [CLS] 向量，再经过双向 LSTM 和注意力层汇聚为文档表示，最后由线性层输出各类别的得分。

In [39]:
!pip install transformers



In [40]:
import logging
import random

import numpy as np
import torch
import pandas as pd
from collections import Counter
from transformers import BasicTokenizer
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, precision_score, recall_score
import time
from sklearn.metrics import classification_report


logging.basicConfig(format='%(asctime)-15s %(levelname)s: %(message)s', level=logging.INFO)

# Seed every random number generator used in this notebook with the same value.
seed = 666
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Select the compute device: GPU `gpu` when CUDA is available, otherwise CPU.
gpu = 0
use_cuda = torch.cuda.is_available() and gpu >= 0
device = torch.device("cuda", gpu) if use_cuda else torch.device("cpu")
if use_cuda:
    torch.cuda.set_device(gpu)
logging.info("Use cuda: %s, gpu id: %d.", use_cuda, gpu)

2020-08-08 08:35:37,329 INFO: Use cuda: True, gpu id: 0.


In [41]:
# split data to 10 fold

def all_data2fold(fold_num, num=10000, data_path=None):
    """Split the first ``num`` rows of the training file into stratified folds.

    The rows are shuffled, grouped by label, and each label's samples are
    dealt out to the folds as evenly as possible. Folds are then trimmed or
    topped up so that every fold holds exactly ``total // fold_num`` samples,
    and each fold is shuffled once more.

    Args:
        fold_num: Number of folds to build (must be >= 1).
        num: Maximum number of rows to take from the start of the file.
        data_path: Tab-separated file with ``text`` and ``label`` columns.
            Defaults to the module-level ``data_file``.

    Returns:
        A list of ``fold_num`` dicts of the form
        ``{'label': [...], 'text': [...]}``.

    Raises:
        ValueError: If ``fold_num`` is smaller than 1.
    """
    if fold_num < 1:
        raise ValueError("fold_num must be a positive integer.")
    if data_path is None:
        data_path = data_file

    frame = pd.read_csv(data_path, sep='\t', encoding='UTF-8')
    texts = frame['text'].tolist()[:num]
    labels = frame['label'].tolist()[:num]
    total = len(labels)

    # Shuffle once so the per-label position lists below are in random order.
    order = list(range(total))
    np.random.shuffle(order)
    all_texts = [texts[i] for i in order]
    all_labels = [labels[i] for i in order]

    # Group the shuffled positions by label.
    label2id = {}
    for pos, label in enumerate(all_labels):
        label2id.setdefault(str(label), []).append(pos)

    # Deal each label's samples to the folds. A running offset is required:
    # the first `extra` folds receive one additional sample, so fold i does
    # not start at i * base. Using i * base made neighbouring folds share
    # samples and left the tail of every label unused.
    all_index = [[] for _ in range(fold_num)]
    for positions in label2id.values():
        base, extra = divmod(len(positions), fold_num)
        offset = 0
        for fold in range(fold_num):
            size = base + 1 if fold < extra else base
            all_index[fold].extend(positions[offset:offset + size])
            offset += size

    # Equalise fold sizes: folds that are too large donate their surplus to a
    # pool, folds that are too small draw from it. Fold sizes are
    # non-increasing in the fold index, so donors always precede takers.
    batch_size = total // fold_num
    fold_data = []
    other_texts = []
    other_labels = []
    start = 0
    for fold in range(fold_num):
        fold_len = len(all_index[fold])
        texts = [all_texts[i] for i in all_index[fold]]
        labels = [all_labels[i] for i in all_index[fold]]

        if fold_len > batch_size:
            fold_texts = texts[:batch_size]
            other_texts.extend(texts[batch_size:])
            fold_labels = labels[:batch_size]
            other_labels.extend(labels[batch_size:])
        elif fold_len < batch_size:
            end = start + batch_size - fold_len
            fold_texts = texts + other_texts[start: end]
            fold_labels = labels + other_labels[start: end]
            start = end
        else:
            fold_texts = texts
            fold_labels = labels

        assert batch_size == len(fold_labels)

        # Shuffle inside the fold so samples are no longer grouped by label.
        index = list(range(batch_size))
        np.random.shuffle(index)
        shuffle_fold_texts = [fold_texts[i] for i in index]
        shuffle_fold_labels = [fold_labels[i] for i in index]

        fold_data.append({'label': shuffle_fold_labels, 'text': shuffle_fold_texts})

    logging.info("Fold lens %s", str([len(data['label']) for data in fold_data]))

    return fold_data

In [42]:
# build vocab
# Module-level BasicTokenizer instance imported from `transformers`.
# NOTE(review): nothing else in the visible part of this notebook references
# this object — confirm before removing it.
basic_tokenizer = BasicTokenizer()

class Vocab():
    """Word and label vocabulary built from a training split.

    Ids 0 and 1 are reserved for ``[PAD]`` and ``[UNK]`` in both the word
    vocabulary and the external (pretrained-embedding) vocabulary.

    Args:
        train_data: Dict with a ``'text'`` list of whitespace-separated
            strings and a ``'label'`` list of integer labels.
    """

    def __init__(self, train_data):
        self.min_count = 5  # words seen fewer times than this are dropped
        self.pad = 0
        self.unk = 1
        self._id2word = ['[PAD]', '[UNK]']
        self._id2extword = ['[PAD]', '[UNK]']

        self._id2label = []
        self.target_names = []

        self.build_vocab(train_data)

        self._word2id = self._reverse(self._id2word)
        self._label2id = self._reverse(self._id2label)
        # Built up front so extword2id() is usable before
        # load_pretrained_embs() has been called.
        self._extword2id = self._reverse(self._id2extword)

        logging.info("Build vocab: words %d, labels %d." % (self.word_size, self.label_size))

    @staticmethod
    def _reverse(items):
        """Return an ``item -> position`` dict for a list of items."""
        return {item: idx for idx, item in enumerate(items)}

    def build_vocab(self, data):
        """Count words and labels in ``data`` and fill the id lists.

        Words are appended in descending frequency order and only when they
        occur at least ``min_count`` times. Labels are taken as
        ``0 .. n_labels - 1`` and mapped to their display names.
        """
        self.word_counter = Counter()
        for text in data['text']:
            self.word_counter.update(text.split())

        for word, count in self.word_counter.most_common():
            if count >= self.min_count:
                self._id2word.append(word)

        label2name = {0: '科技', 1: '股票', 2: '体育', 3: '娱乐', 4: '时政', 5: '社会', 6: '教育', 7: '财经',
                      8: '家居', 9: '游戏', 10: '房产', 11: '时尚', 12: '彩票', 13: '星座'}

        self.label_counter = Counter(data['label'])

        # Labels are enumerated by count, i.e. they are expected to be the
        # contiguous integers 0 .. len(label_counter) - 1.
        for label in range(len(self.label_counter)):
            self._id2label.append(label)
            self.target_names.append(label2name[label])

    def load_pretrained_embs(self, embfile):
        """Load a text embedding file and return the embedding matrix.

        The first line of the file holds ``word_count`` and ``embedding_dim``;
        each following non-blank line is a word followed by its vector.
        The ``[UNK]`` row is set to the mean of all loaded vectors and the
        whole matrix is divided by its standard deviation.

        Args:
            embfile: Path to the UTF-8 embedding file.

        Returns:
            ``numpy.ndarray`` of shape
            ``(word_count + 2, embedding_dim)`` when called on a fresh vocab.
        """
        with open(embfile, encoding='utf-8') as f:
            header = f.readline().split()
            word_count, embedding_dim = int(header[0]), int(header[1])

            index = len(self._id2extword)
            embeddings = np.zeros((word_count + index, embedding_dim))
            for line in f:
                values = line.split()
                if not values:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                self._id2extword.append(values[0])
                vector = np.array(values[1:], dtype='float64')
                embeddings[self.unk] += vector
                embeddings[index] = vector
                index += 1

        embeddings[self.unk] = embeddings[self.unk] / word_count
        embeddings = embeddings / np.std(embeddings)

        self._extword2id = self._reverse(self._id2extword)

        assert len(set(self._id2extword)) == len(self._id2extword)

        return embeddings

    def word2id(self, xs):
        """Map a word or list of words to ids; unknown words map to ``unk``."""
        if isinstance(xs, list):
            return [self._word2id.get(x, self.unk) for x in xs]
        return self._word2id.get(xs, self.unk)

    def extword2id(self, xs):
        """Map a word or list of words to external-vocabulary ids."""
        if isinstance(xs, list):
            return [self._extword2id.get(x, self.unk) for x in xs]
        return self._extword2id.get(xs, self.unk)

    def label2id(self, xs):
        """Map a label or list of labels to ids.

        NOTE(review): unknown labels fall back to ``self.unk`` (1), which is
        also a valid label id — confirm this is intended.
        """
        if isinstance(xs, list):
            return [self._label2id.get(x, self.unk) for x in xs]
        return self._label2id.get(xs, self.unk)

    @property
    def word_size(self):
        return len(self._id2word)

    @property
    def extword_size(self):
        return len(self._id2extword)

    @property
    def label_size(self):
        return len(self._id2label)

In [43]:
# build module

class Attention(nn.Module):
    """Additive-style attention pooling over a sequence of hidden vectors.

    Each hidden vector is linearly projected, scored against a learned query
    vector, and the projected vectors are averaged with the softmax of those
    scores. Masked positions get zero weight.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Projection matrix and query are drawn from N(0, 0.05^2); bias is 0.
        self.weight = nn.Parameter(torch.empty(hidden_size, hidden_size).normal_(mean=0.0, std=0.05))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.query = nn.Parameter(torch.empty(hidden_size).normal_(mean=0.0, std=0.05))

    def forward(self, batch_hidden, batch_masks):
        """Pool ``batch_hidden`` (b x len x hidden) using ``batch_masks`` (b x len).

        Returns:
            A pair ``(pooled, attn_scores)``: the pooled vectors (b x hidden)
            and the softmax scores (b x len) before the extra zero-fill.
        """
        projected = batch_hidden.matmul(self.weight) + self.bias  # b x len x hidden
        raw_scores = projected.matmul(self.query)  # b x len

        padding = (1 - batch_masks).bool()
        attn_scores = F.softmax(raw_scores.masked_fill(padding, float(-1e32)), dim=1)

        # For a fully masked row, -1e32 gives uniform 1/len scores (and -inf
        # would give nan), so masked positions are additionally zeroed.
        pooling_weights = attn_scores.masked_fill(padding, 0.0).unsqueeze(1)  # b x 1 x len

        batch_outputs = torch.bmm(pooling_weights, projected).squeeze(1)  # b x hidden
        return batch_outputs, attn_scores


class WordBertEncoder(nn.Module):
    """Encode each token segment with BERT and return one vector per segment.

    Reads the module-level ``dropout`` and ``bert_path`` settings when it is
    constructed.
    """

    def __init__(self):
        super(WordBertEncoder, self).__init__()
        self.dropout = nn.Dropout(dropout)

        self.tokenizer = WhitespaceTokenizer()
        self.bert = BertModel.from_pretrained(bert_path)

        # False: use the [CLS] hidden state of the last layer;
        # True: use BERT's pooled output instead.
        self.pooled = False
        logging.info('Build Bert encoder with pooled {}.'.format(self.pooled))

    def encode(self, tokens):
        """Convert a list of word strings to BERT ids (with [CLS]/[SEP] added)."""
        tokens = self.tokenizer.tokenize(tokens)
        return tokens

    def get_bert_parameters(self):
        """Return BERT parameter groups for the optimizer.

        Biases and LayerNorm weights get no weight decay; every other
        parameter uses a weight decay of 0.01.
        """
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {'params': [p for n, p in self.bert.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in self.bert.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        return optimizer_parameters

    def forward(self, input_ids, token_type_ids):
        """Return one representation per segment.

        Fine-tuning uses the hidden vector of the first token ([CLS]) of the
        last layer as the segment representation.

        Args:
            input_ids: sen_num x bert_len token ids.
            token_type_ids: sen_num x bert_len segment-type ids.

        Returns:
            sen_num x hidden tensor (dropout applied in training mode).
        """
        # NOTE(review): no attention_mask is passed, so padded positions
        # (id 0) are attended to like real tokens — confirm this is intended.
        outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids)
        # Index by position rather than tuple-unpacking: this works whether
        # the model returns a plain tuple or a ModelOutput (unpacking a
        # ModelOutput yields its keys, not its tensors).
        sequence_output, pooled_output = outputs[0], outputs[1]

        if self.pooled:
            reps = pooled_output
        else:
            reps = sequence_output[:, 0, :]  # sen_num x hidden

        if self.training:
            reps = self.dropout(reps)

        return reps


class WhitespaceTokenizer():
    """WhitespaceTokenizer with vocab.

    Maps already-split word tokens to ids using the ``vocab.txt`` file found
    under the module-level ``bert_path``.
    """

    def __init__(self):
        vocab_file = bert_path + 'vocab.txt'
        self._token2id = self.load_vocab(vocab_file)
        self._id2token = {v: k for k, v in self._token2id.items()}
        self.max_len = 256
        # Use the vocabulary's own [UNK] id; fall back to 1 when it is absent.
        self.unk = self._token2id.get('[UNK]', 1)

        logging.info("Build Bert vocab with size %d." % (self.vocab_size))

    def load_vocab(self, vocab_file):
        """Read one token per line and return a ``token -> line index`` dict."""
        # The context manager closes the file handle, which was leaked before.
        with open(vocab_file, 'r', encoding='utf-8') as f:
            tokens = [line.strip() for line in f]
        return dict(zip(tokens, range(len(tokens))))

    def tokenize(self, tokens):
        """Wrap ``tokens`` in [CLS] ... [SEP] and convert them to ids.

        ``tokens`` must contain at most ``max_len - 2`` items.
        """
        assert len(tokens) <= self.max_len - 2
        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        output_tokens = self.token2id(tokens)
        return output_tokens

    def token2id(self, xs):
        """Map a token or list of tokens to ids; unknown tokens map to ``unk``."""
        if isinstance(xs, list):
            return [self._token2id.get(x, self.unk) for x in xs]
        return self._token2id.get(xs, self.unk)

    @property
    def vocab_size(self):
        return len(self._id2token)


class SentEncoder(nn.Module):
    """Bidirectional LSTM over the sequence of segment representations.

    Layer sizes and the dropout rate come from the module-level
    ``sent_hidden_size``, ``sent_num_layers`` and ``dropout`` settings.
    """

    def __init__(self, sent_rep_size):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.sent_lstm = nn.LSTM(
            sent_rep_size,
            sent_hidden_size,
            num_layers=sent_num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, sent_reps, sent_masks):
        """Encode ``sent_reps`` (b x doc_len x sent_rep_size).

        ``sent_masks`` (b x doc_len) zeroes the outputs at padded segments.
        Returns a b x doc_len x (2 * hidden) tensor.
        """
        lstm_out, _ = self.sent_lstm(sent_reps)
        masked_out = lstm_out * sent_masks.unsqueeze(2)
        return self.dropout(masked_out) if self.training else masked_out

In [44]:
# build model
class Model(nn.Module):
    """Document classifier: BERT per segment, BiLSTM + attention per document.

    ``all_parameters`` exposes two parameter groups for the optimizer:
    ``"basic_parameters"`` (LSTM, attention, output layer) and
    ``"bert_parameters"`` (BERT weight-decay groups).

    Args:
        vocab: Vocabulary object; only ``vocab.label_size`` is read here.
    """

    def __init__(self, vocab):
        super(Model, self).__init__()
        self.doc_rep_size = sent_hidden_size * 2
        self.all_parameters = {}
        parameters = []
        self.word_encoder = WordBertEncoder()
        bert_parameters = self.word_encoder.get_bert_parameters()

        # Take the segment vector width from the loaded BERT config instead
        # of hard-coding 256, so other BERT checkpoints work unchanged.
        self.sent_rep_size = self.word_encoder.bert.config.hidden_size

        self.sent_encoder = SentEncoder(self.sent_rep_size)
        self.sent_attention = Attention(self.doc_rep_size)
        parameters.extend(list(filter(lambda p: p.requires_grad, self.sent_encoder.parameters())))
        parameters.extend(list(filter(lambda p: p.requires_grad, self.sent_attention.parameters())))

        self.out = nn.Linear(self.doc_rep_size, vocab.label_size, bias=True)
        parameters.extend(list(filter(lambda p: p.requires_grad, self.out.parameters())))

        if use_cuda:
            self.to(device)

        if len(parameters) > 0:
            self.all_parameters["basic_parameters"] = parameters
        self.all_parameters["bert_parameters"] = bert_parameters

        logging.info('Build model with bert word encoder, lstm sent encoder.')

        para_num = sum([np.prod(list(p.size())) for p in self.parameters()])
        logging.info('Model param num: %.2f M.' % (para_num / 1e6))

    def forward(self, batch_inputs):
        """Compute class scores for a batch of documents.

        Args:
            batch_inputs: Tuple ``(input_ids, token_type_ids, masks)``, each
                of shape b x doc_len x sent_len.

        Returns:
            b x num_labels tensor of unnormalised class scores.
        """
        batch_inputs1, batch_inputs2, batch_masks = batch_inputs
        batch_size, max_doc_len, max_sent_len = batch_inputs1.shape[0], batch_inputs1.shape[1], batch_inputs1.shape[2]
        # Flatten documents so every segment is one BERT input row.
        batch_inputs1 = batch_inputs1.view(batch_size * max_doc_len, max_sent_len)  # sen_num x sent_len
        batch_inputs2 = batch_inputs2.view(batch_size * max_doc_len, max_sent_len)  # sen_num x sent_len

        # NOTE(review): batch_masks is not forwarded to the word encoder, so
        # BERT sees padded positions unmasked — confirm this is intended.
        sent_reps = self.word_encoder(batch_inputs1, batch_inputs2)  # sen_num x sent_rep_size

        sent_reps = sent_reps.view(batch_size, max_doc_len, self.sent_rep_size)  # b x doc_len x sent_rep_size
        # A segment counts as real when at least one of its tokens is unmasked.
        sent_masks = batch_masks.bool().any(2).float()  # b x doc_len

        sent_hiddens = self.sent_encoder(sent_reps, sent_masks)  # b x doc_len x doc_rep_size
        doc_reps, atten_scores = self.sent_attention(sent_hiddens, sent_masks)  # b x doc_rep_size

        batch_outputs = self.out(doc_reps)  # b x num_labels

        return batch_outputs

In [45]:

class Optimizer:
    """Bundle of optimizers and schedulers, one pair per parameter group.

    Groups whose name starts with ``"basic"`` get Adam with a stepwise
    exponential decay; groups whose name starts with ``"bert"`` get AdamW
    with a linear schedule without warmup. Learning rates and decay settings
    are read from the module-level hyperparameters.

    Args:
        model_parameters: Dict mapping a group name to its parameters.
        steps: Total number of training steps for the linear BERT schedule.

    Raises:
        ValueError: If a group name starts with neither "basic" nor "bert".
    """

    def __init__(self, model_parameters, steps):
        self.all_params = []
        self.optims = []
        self.schedulers = []

        for name, parameters in model_parameters.items():
            if name.startswith("basic"):
                optim = torch.optim.Adam(parameters, lr=learning_rate)
                self.optims.append(optim)

                # Multiply the lr by `decay` once every `decay_step` steps.
                l = lambda step: decay ** (step // decay_step)
                scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=l)
                self.schedulers.append(scheduler)
                self.all_params.extend(parameters)
            elif name.startswith("bert"):
                optim_bert = AdamW(parameters, bert_lr, eps=1e-8)
                self.optims.append(optim_bert)

                scheduler_bert = get_linear_schedule_with_warmup(optim_bert, 0, steps)
                self.schedulers.append(scheduler_bert)

                # BERT parameters arrive as weight-decay groups; flatten them.
                for group in parameters:
                    for p in group['params']:
                        self.all_params.append(p)
            else:
                # The exception was previously constructed but never raised,
                # so unknown groups were silently left without an optimizer.
                raise ValueError("Unknown parameter group: %r" % (name,))

        self.num = len(self.optims)

    def step(self):
        """Step every optimizer and its scheduler, then clear the gradients."""
        for optim, scheduler in zip(self.optims, self.schedulers):
            optim.step()
            scheduler.step()
            optim.zero_grad()

    def zero_grad(self):
        """Clear the gradients of every optimizer."""
        for optim in self.optims:
            optim.zero_grad()

    def get_lr(self):
        """Return the current learning rates formatted as ``' %.5f'`` each."""
        # get_last_lr() is the supported way to read the current lr;
        # calling get_lr() outside of step() is deprecated.
        lrs = tuple(map(lambda x: x.get_last_lr()[-1], self.schedulers))
        lr = ' %.5f' * self.num
        res = lr % lrs
        return res

In [46]:
# build dataset
def sentence_split(text, vocab, max_sent_len=256, max_segment=16):
    """Cut a whitespace-tokenised document into fixed-length segments.

    Words missing from the vocabulary are replaced by ``'<UNK>'``. When the
    document yields more than ``max_segment`` segments, only the first and
    last ``max_segment // 2`` segments are kept.

    Args:
        text: Document as a whitespace-separated string (must not be empty).
        vocab: Object exposing ``_word2id`` (preferred) or ``_id2word``.
        max_sent_len: Maximum number of words per segment.
        max_segment: Maximum number of segments to return.

    Returns:
        List of ``[segment_length, segment_words]`` pairs.
    """
    words = text.strip().split()
    document_len = len(words)

    # Membership must be O(1): scanning the `_id2word` list for every token
    # is a linear search per word. Fall back to a set built once per call
    # when the vocab has no `_word2id` dict.
    known_words = getattr(vocab, '_word2id', None)
    if known_words is None:
        known_words = set(vocab._id2word)

    segments = []
    for start in range(0, document_len, max_sent_len):
        chunk = words[start:start + max_sent_len]
        segment = [word if word in known_words else '<UNK>' for word in chunk]
        segments.append([len(segment), segment])

    assert len(segments) > 0
    if len(segments) > max_segment:
        # Keep the head and the tail of long documents.
        segment_ = int(max_segment / 2)
        return segments[:segment_] + segments[-segment_:]
    return segments


def get_examples(data, word_encoder, vocab, max_sent_len=256, max_segment=8):
    """Turn raw documents into model-ready examples.

    Each example is ``[label_id, n_segments, segments]`` where every segment
    is ``[n_tokens, token_ids, token_type_ids]``. Two positions per segment
    are reserved for the tokens added by ``word_encoder.encode``.
    """
    to_label_id = vocab.label2id
    examples = []

    for text, label in zip(data['text'], data['label']):
        label_id = to_label_id(label)

        doc = []
        for _, sent_words in sentence_split(text, vocab, max_sent_len - 2, max_segment):
            token_ids = word_encoder.encode(sent_words)
            n_tokens = len(token_ids)
            doc.append([n_tokens, token_ids, [0] * n_tokens])

        examples.append([label_id, len(doc), doc])

    logging.info('Total %d docs.' % len(examples))
    return examples

In [47]:
# build loader

def batch_slice(data, batch_size):
    """Yield consecutive lists of at most ``batch_size`` items from ``data``."""
    total = len(data)
    batch_num = int(np.ceil(total / float(batch_size)))
    for batch_idx in range(batch_num):
        first = batch_idx * batch_size
        last = min(first + batch_size, total)
        yield [data[pos] for pos in range(first, last)]


def data_iter(data, batch_size, shuffle=True, noise=1.0):
    """Yield batches of examples.

    With ``shuffle`` on, ``data`` is shuffled in place, sorted by descending
    length (``example[1]``) perturbed by uniform noise in
    ``[-noise, noise]``, cut into batches, and the batches are shuffled.
    With ``shuffle`` off, ``data`` is batched in its given order.
    """
    if shuffle:
        np.random.shuffle(data)
        noisy_keys = [-(example[1] + np.random.uniform(-noise, noise)) for example in data]
        ordered = [data[i] for i in np.argsort(noisy_keys).tolist()]
    else:
        ordered = data

    batches = list(batch_slice(ordered, batch_size))
    if shuffle:
        np.random.shuffle(batches)

    yield from batches

In [48]:
# some function

def get_score(y_ture, y_pred):
    """Return macro precision/recall/F1 as percentages.

    Returns:
        A pair ``(text, f1)`` where ``text`` is the string form of the
        ``(precision, recall, f1)`` tuple and ``f1`` is the macro F1, all
        rounded to two decimals.
    """
    truth = np.array(y_ture)
    predicted = np.array(y_pred)

    f1, p, r = (
        metric(truth, predicted, average='macro') * 100
        for metric in (f1_score, precision_score, recall_score)
    )

    f1_rounded = reformat(f1, 2)
    summary = (reformat(p, 2), reformat(r, 2), f1_rounded)
    return str(summary), f1_rounded


def reformat(num, n):
    """Return ``num`` rounded to ``n`` decimal places via fixed-point formatting."""
    text = '{:0.{}f}'.format(num, n)
    return float(text)

In [49]:
# build trainer

class Trainer():
    """Training, evaluation and prediction driver for the classifier.

    Reads the module-level splits (``train_data``, ``dev_data``,
    ``test_data``), hyperparameters (``epochs``, ``early_stops``, ``clip``,
    ``log_interval``, ``train_batch_size``, ``test_batch_size``), output
    paths (``save_model``, ``save_test``) and device settings
    (``use_cuda``, ``device``).

    Args:
        model: Model exposing ``word_encoder`` and ``all_parameters``.
        vocab: Vocabulary used to build examples and label names.
    """

    def __init__(self, model, vocab):
        self.model = model
        self.report = True  # log a per-class classification report

        self.train_data = get_examples(train_data, model.word_encoder, vocab)
        self.batch_num = int(np.ceil(len(self.train_data) / float(train_batch_size)))
        self.dev_data = get_examples(dev_data, model.word_encoder, vocab)
        self.test_data = get_examples(test_data, model.word_encoder, vocab)

        # criterion
        self.criterion = nn.CrossEntropyLoss()

        # label name
        self.target_names = vocab.target_names

        # optimizer
        self.optimizer = Optimizer(model.all_parameters, steps=self.batch_num * epochs)

        # count
        self.step = 0
        self.early_stop = -1
        self.best_train_f1, self.best_dev_f1 = 0, 0
        self.last_epoch = epochs

    def train(self):
        """Train for up to ``epochs`` epochs with early stopping on dev F1.

        The model state is saved to ``save_model`` whenever the dev F1
        matches or exceeds the best value seen so far.
        """
        logging.info('Start training...')
        for epoch in range(1, epochs + 1):
            train_f1 = self._train(epoch)

            dev_f1 = self._eval(epoch)

            if self.best_dev_f1 <= dev_f1:
                logging.info(
                    "Exceed history dev = %.2f, current dev = %.2f" % (self.best_dev_f1, dev_f1))
                torch.save(self.model.state_dict(), save_model)

                self.best_train_f1 = train_f1
                self.best_dev_f1 = dev_f1
                self.early_stop = 0
            else:
                self.early_stop += 1
                if self.early_stop == early_stops:
                    logging.info(
                        "Eearly stop in epoch %d, best train: %.2f, dev: %.2f" % (
                            epoch - early_stops, self.best_train_f1, self.best_dev_f1))
                    self.last_epoch = epoch
                    break

    def test(self):
        """Reload the saved model state and write test predictions to ``save_test``."""
        # map_location lets a checkpoint saved on GPU be loaded on a
        # CPU-only machine (and vice versa).
        self.model.load_state_dict(torch.load(save_model, map_location=device))
        self._eval(self.last_epoch + 1, test=True)

    def _train(self, epoch):
        """Run one training epoch and return the macro F1 on the training data."""
        self.optimizer.zero_grad()
        self.model.train()

        start_time = time.time()
        epoch_start_time = time.time()
        overall_losses = 0
        losses = 0
        batch_idx = 1
        y_pred = []
        y_true = []
        for batch_data in data_iter(self.train_data, train_batch_size, shuffle=True):
            torch.cuda.empty_cache()
            batch_inputs, batch_labels = self.batch2tensor(batch_data)
            batch_outputs = self.model(batch_inputs)
            loss = self.criterion(batch_outputs, batch_labels)
            loss.backward()

            loss_value = loss.detach().cpu().item()
            losses += loss_value
            overall_losses += loss_value

            y_pred.extend(torch.max(batch_outputs, dim=1)[1].cpu().numpy().tolist())
            y_true.extend(batch_labels.cpu().numpy().tolist())

            # Clip the joint gradient norm of all parameters, then step
            # every optimizer/scheduler pair.
            nn.utils.clip_grad_norm_(self.optimizer.all_params, max_norm=clip)
            for optimizer, scheduler in zip(self.optimizer.optims, self.optimizer.schedulers):
                optimizer.step()
                scheduler.step()
            self.optimizer.zero_grad()

            self.step += 1

            if batch_idx % log_interval == 0:
                elapsed = time.time() - start_time

                lrs = self.optimizer.get_lr()
                logging.info(
                    '| epoch {:3d} | step {:3d} | batch {:3d}/{:3d} | lr{} | loss {:.4f} | s/batch {:.2f}'.format(
                        epoch, self.step, batch_idx, self.batch_num, lrs,
                        losses / log_interval,
                        elapsed / log_interval))

                losses = 0
                start_time = time.time()

            batch_idx += 1

        overall_losses /= self.batch_num
        during_time = time.time() - epoch_start_time

        # reformat
        overall_losses = reformat(overall_losses, 4)
        score, f1 = get_score(y_true, y_pred)

        logging.info(
            '| epoch {:3d} | score {} | f1 {} | loss {:.4f} | time {:.2f}'.format(epoch, score, f1,
                                                                                  overall_losses,
                                                                                  during_time))
        # The report is only produced when every true class was also
        # predicted at least once (and vice versa).
        if set(y_true) == set(y_pred) and self.report:
            report = classification_report(y_true, y_pred, digits=4, target_names=self.target_names)
            logging.info('\n' + report)

        return f1

    def _eval(self, epoch, test=False):
        """Evaluate on the dev split, or predict the test split when ``test``.

        In test mode the predicted labels are written to ``save_test``;
        otherwise the dev scores are logged. Returns the macro F1 computed
        against the labels stored with the evaluated split.
        """
        self.model.eval()
        start_time = time.time()
        data = self.test_data if test else self.dev_data
        y_pred = []
        y_true = []
        with torch.no_grad():
            for batch_data in data_iter(data, test_batch_size, shuffle=False):
                torch.cuda.empty_cache()
                batch_inputs, batch_labels = self.batch2tensor(batch_data)
                batch_outputs = self.model(batch_inputs)
                y_pred.extend(torch.max(batch_outputs, dim=1)[1].cpu().numpy().tolist())
                y_true.extend(batch_labels.cpu().numpy().tolist())

            score, f1 = get_score(y_true, y_pred)

            during_time = time.time() - start_time

            if test:
                df = pd.DataFrame({'label': y_pred})
                df.to_csv(save_test, index=False, sep=',')
            else:
                logging.info(
                    '| epoch {:3d} | dev | score {} | f1 {} | time {:.2f}'.format(epoch, score, f1,
                                                                              during_time))
                if set(y_true) == set(y_pred) and self.report:
                    report = classification_report(y_true, y_pred, digits=4, target_names=self.target_names)
                    logging.info('\n' + report)

        return f1

    def batch2tensor(self, batch_data):
        '''
        Pad a batch of examples into dense tensors.

        batch_data layout:
            [[label, doc_len, [[sent_len, [sent_id0, ...], [sent_id1, ...]], ...]]

        Returns ((input_ids, token_type_ids, masks), labels); the three input
        tensors have shape batch x max_doc_len x max_sent_len and are
        zero-padded, with masks set to 1 at real token positions.
        '''
        batch_size = len(batch_data)
        doc_labels = []
        doc_lens = []
        doc_max_sent_len = []
        for doc_data in batch_data:
            doc_labels.append(doc_data[0])
            doc_lens.append(doc_data[1])
            sent_lens = [sent_data[0] for sent_data in doc_data[2]]
            max_sent_len = max(sent_lens)
            doc_max_sent_len.append(max_sent_len)

        max_doc_len = max(doc_lens)
        max_sent_len = max(doc_max_sent_len)

        batch_inputs1 = torch.zeros((batch_size, max_doc_len, max_sent_len), dtype=torch.int64)
        batch_inputs2 = torch.zeros((batch_size, max_doc_len, max_sent_len), dtype=torch.int64)
        batch_masks = torch.zeros((batch_size, max_doc_len, max_sent_len), dtype=torch.float32)
        batch_labels = torch.LongTensor(doc_labels)

        # Fill one whole segment row per assignment instead of one scalar at
        # a time; this runs for every batch, so the per-token Python loop
        # was a hot spot.
        for b, doc_data in enumerate(batch_data):
            for sent_idx in range(doc_lens[b]):
                sent_len, token_ids, token_type_ids = doc_data[2][sent_idx]
                batch_inputs1[b, sent_idx, :sent_len] = torch.tensor(token_ids[:sent_len], dtype=torch.int64)
                batch_inputs2[b, sent_idx, :sent_len] = torch.tensor(token_type_ids[:sent_len], dtype=torch.int64)
                batch_masks[b, sent_idx, :sent_len] = 1

        if use_cuda:
            batch_inputs1 = batch_inputs1.to(device)
            batch_inputs2 = batch_inputs2.to(device)
            batch_masks = batch_masks.to(device)
            batch_labels = batch_labels.to(device)

        return (batch_inputs1, batch_inputs2, batch_masks), batch_labels

## 初始参数设定

In [52]:
# Index of the fold held out as the dev set, and the total number of folds.
fold_id = 9
fold_num = 10
# Tab-separated input files with a `text` column (the training file also has `label`).
data_file = '../input/train_set.csv'
test_data_file = '../input/test_a.csv'

# Output paths: best model state dict and the CSV of test predictions.
save_model = '../output/bert20200806.bin'
save_test = '../output/bertNext.csv'

# build word encoder
# Directory holding the pretrained BERT weights, config and vocab.txt;
# the trailing slash matters because 'vocab.txt' is appended by string concatenation.
bert_path = '../emb/bert-mini/'

# build sent encoder
# Hidden size per direction and number of layers of the bidirectional LSTM.
sent_hidden_size = 256
sent_num_layers = 2

In [53]:
# Use the configured fold count rather than a literal 10, so the fold list
# always agrees with `fold_num` and `fold_id`.
fold_data = all_data2fold(fold_num, 200000)

2020-08-08 08:36:24,352 INFO: Fold lens [20000, 20000, 20000, 20000, 20000, 20000, 20000, 20000, 20000, 20000]


# build train, dev, test data

In [54]:
# dev
dev_data = fold_data[fold_id]

# train: every fold except the held-out dev fold. Iterating only over
# range(0, fold_id) silently dropped the folds after `fold_id` whenever the
# dev fold was not the last one.
train_texts = []
train_labels = []
for i, data in enumerate(fold_data):
    if i == fold_id:
        continue
    train_texts.extend(data['text'])
    train_labels.extend(data['label'])

train_data = {'label': train_labels, 'text': train_texts}

# test: the test file has no labels, so a dummy label 0 is attached to every
# row to keep the same dict layout as the other splits.
f = pd.read_csv(test_data_file, sep='\t', encoding='UTF-8')
texts = f['text'].tolist()
test_data = {'label': [0] * len(texts), 'text': texts}

In [55]:
# Build the word/label vocabulary from the training split only.
vocab = Vocab(train_data)

2020-08-08 08:37:17,457 INFO: Build vocab: words 5983, labels 14.


In [56]:
# Build the classifier (loads the pretrained BERT from `bert_path`).
# NOTE(review): WordBertEncoder and SentEncoder read the global `dropout`
# when constructed, but in this export `dropout` is only assigned in a later
# cell — confirm the cell order before running top to bottom.
model = Model(vocab)

2020-08-08 08:37:17,473 INFO: Build Bert vocab with size 5981.
2020-08-08 08:37:17,474 INFO: loading configuration file ../emb/bert-mini/config.json
2020-08-08 08:37:17,481 INFO: Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 256,
  "model_type": "bert",
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 5981
}

2020-08-08 08:37:17,482 INFO: loading weights file ../emb/bert-mini/pytorch_model.bin
2020-08-08 08:37:17,570 INFO: All model checkpoint weights were used when initializing BertModel.

2020-08-08 08:37:17,571 INFO: All the weights of BertModel were initialized from the model checkpoint at ../emb/bert-mini/.
If your task is similar to the task the model of the ckeckpoint was tr

## 设置超参

In [57]:
# build optimizer
learning_rate = 2e-4  # Adam learning rate for the non-BERT ("basic") parameters
dropout = 0.15  # dropout probability used by WordBertEncoder and SentEncoder
bert_lr = 5e-5  # AdamW learning rate for the BERT parameters
decay = .75  # the basic learning rate is multiplied by this factor ...
decay_step = 1000  # ... once every `decay_step` scheduler steps
clip = 5.0  # maximum gradient norm passed to clip_grad_norm_
epochs = 4  # maximum number of training epochs
early_stops = 3  # stop once this many epochs in a row fail to reach the best dev F1
log_interval = 50  # log the running training loss every `log_interval` batches
test_batch_size = 16  # batch size for dev/test evaluation
train_batch_size = 16  # batch size for training

In [58]:
# train
# Build the trainer: converts the train/dev/test splits into examples and
# creates the loss and the optimizers.
trainer = Trainer(model, vocab)

2020-08-08 08:53:56,859 INFO: Total 180000 docs.
2020-08-08 08:55:41,583 INFO: Total 20000 docs.
2020-08-08 09:00:04,540 INFO: Total 50000 docs.


## 目标使loss尽量最小之后，再进行预测

In [59]:
# Run training with early stopping; the best model state is saved to `save_model`.
trainer.train()

2020-08-08 09:19:00,651 INFO: Start training...
2020-08-08 09:19:36,874 INFO: | epoch   1 | step  50 | batch  50/11250 | lr 0.00020 0.00005 | loss 2.2175 | s/batch 0.72
2020-08-08 09:20:11,199 INFO: | epoch   1 | step 100 | batch 100/11250 | lr 0.00020 0.00005 | loss 1.4541 | s/batch 0.69
2020-08-08 09:20:48,323 INFO: | epoch   1 | step 150 | batch 150/11250 | lr 0.00020 0.00005 | loss 0.9852 | s/batch 0.74
2020-08-08 09:21:19,674 INFO: | epoch   1 | step 200 | batch 200/11250 | lr 0.00020 0.00005 | loss 0.8033 | s/batch 0.63
2020-08-08 09:21:54,750 INFO: | epoch   1 | step 250 | batch 250/11250 | lr 0.00020 0.00005 | loss 0.7683 | s/batch 0.70
2020-08-08 09:22:29,364 INFO: | epoch   1 | step 300 | batch 300/11250 | lr 0.00020 0.00005 | loss 0.6919 | s/batch 0.69
2020-08-08 09:23:01,724 INFO: | epoch   1 | step 350 | batch 350/11250 | lr 0.00020 0.00005 | loss 0.6007 | s/batch 0.65
2020-08-08 09:23:34,797 INFO: | epoch   1 | step 400 | batch 400/11250 | lr 0.00020 0.00005 | loss 0.5645

2020-08-08 09:56:51,800 INFO: | epoch   1 | step 3400 | batch 3400/11250 | lr 0.00008 0.00005 | loss 0.3035 | s/batch 0.74
2020-08-08 09:57:22,568 INFO: | epoch   1 | step 3450 | batch 3450/11250 | lr 0.00008 0.00005 | loss 0.3224 | s/batch 0.62
2020-08-08 09:57:56,169 INFO: | epoch   1 | step 3500 | batch 3500/11250 | lr 0.00008 0.00005 | loss 0.3223 | s/batch 0.67
2020-08-08 09:58:29,362 INFO: | epoch   1 | step 3550 | batch 3550/11250 | lr 0.00008 0.00005 | loss 0.2747 | s/batch 0.66
2020-08-08 09:59:02,281 INFO: | epoch   1 | step 3600 | batch 3600/11250 | lr 0.00008 0.00005 | loss 0.2617 | s/batch 0.66
2020-08-08 09:59:39,095 INFO: | epoch   1 | step 3650 | batch 3650/11250 | lr 0.00008 0.00005 | loss 0.3114 | s/batch 0.74
2020-08-08 10:00:11,406 INFO: | epoch   1 | step 3700 | batch 3700/11250 | lr 0.00008 0.00005 | loss 0.3060 | s/batch 0.65
2020-08-08 10:00:48,248 INFO: | epoch   1 | step 3750 | batch 3750/11250 | lr 0.00008 0.00005 | loss 0.2061 | s/batch 0.74
2020-08-08 10:01

2020-08-08 10:35:00,479 INFO: | epoch   1 | step 6750 | batch 6750/11250 | lr 0.00004 0.00004 | loss 0.2656 | s/batch 0.73
2020-08-08 10:35:36,138 INFO: | epoch   1 | step 6800 | batch 6800/11250 | lr 0.00004 0.00004 | loss 0.2303 | s/batch 0.71
2020-08-08 10:36:06,908 INFO: | epoch   1 | step 6850 | batch 6850/11250 | lr 0.00004 0.00004 | loss 0.2332 | s/batch 0.62
2020-08-08 10:36:41,873 INFO: | epoch   1 | step 6900 | batch 6900/11250 | lr 0.00004 0.00004 | loss 0.2591 | s/batch 0.70
2020-08-08 10:37:16,577 INFO: | epoch   1 | step 6950 | batch 6950/11250 | lr 0.00004 0.00004 | loss 0.2011 | s/batch 0.69
2020-08-08 10:37:52,017 INFO: | epoch   1 | step 7000 | batch 7000/11250 | lr 0.00003 0.00004 | loss 0.2738 | s/batch 0.71
2020-08-08 10:38:29,650 INFO: | epoch   1 | step 7050 | batch 7050/11250 | lr 0.00003 0.00004 | loss 0.3265 | s/batch 0.75
2020-08-08 10:39:04,229 INFO: | epoch   1 | step 7100 | batch 7100/11250 | lr 0.00003 0.00004 | loss 0.2068 | s/batch 0.69
2020-08-08 10:39

2020-08-08 11:14:18,045 INFO: | epoch   1 | step 10100 | batch 10100/11250 | lr 0.00001 0.00004 | loss 0.1716 | s/batch 0.71
2020-08-08 11:15:00,032 INFO: | epoch   1 | step 10150 | batch 10150/11250 | lr 0.00001 0.00004 | loss 0.1905 | s/batch 0.84
2020-08-08 11:15:37,163 INFO: | epoch   1 | step 10200 | batch 10200/11250 | lr 0.00001 0.00004 | loss 0.2151 | s/batch 0.74
2020-08-08 11:16:11,075 INFO: | epoch   1 | step 10250 | batch 10250/11250 | lr 0.00001 0.00004 | loss 0.2177 | s/batch 0.68
2020-08-08 11:16:42,211 INFO: | epoch   1 | step 10300 | batch 10300/11250 | lr 0.00001 0.00004 | loss 0.1829 | s/batch 0.62
2020-08-08 11:17:17,519 INFO: | epoch   1 | step 10350 | batch 10350/11250 | lr 0.00001 0.00004 | loss 0.1948 | s/batch 0.71
2020-08-08 11:17:53,327 INFO: | epoch   1 | step 10400 | batch 10400/11250 | lr 0.00001 0.00004 | loss 0.1886 | s/batch 0.72
2020-08-08 11:18:30,465 INFO: | epoch   1 | step 10450 | batch 10450/11250 | lr 0.00001 0.00004 | loss 0.1998 | s/batch 0.74


2020-08-08 11:53:41,052 INFO: | epoch   2 | step 12500 | batch 1250/11250 | lr 0.00001 0.00004 | loss 0.1173 | s/batch 0.71
2020-08-08 11:54:12,962 INFO: | epoch   2 | step 12550 | batch 1300/11250 | lr 0.00001 0.00004 | loss 0.1595 | s/batch 0.64
2020-08-08 11:54:43,536 INFO: | epoch   2 | step 12600 | batch 1350/11250 | lr 0.00001 0.00004 | loss 0.1838 | s/batch 0.61
2020-08-08 11:55:16,738 INFO: | epoch   2 | step 12650 | batch 1400/11250 | lr 0.00001 0.00004 | loss 0.1821 | s/batch 0.66
2020-08-08 11:55:49,980 INFO: | epoch   2 | step 12700 | batch 1450/11250 | lr 0.00001 0.00004 | loss 0.1403 | s/batch 0.66
2020-08-08 11:56:23,231 INFO: | epoch   2 | step 12750 | batch 1500/11250 | lr 0.00001 0.00004 | loss 0.2048 | s/batch 0.66
2020-08-08 11:56:55,639 INFO: | epoch   2 | step 12800 | batch 1550/11250 | lr 0.00001 0.00004 | loss 0.1490 | s/batch 0.65
2020-08-08 11:57:28,270 INFO: | epoch   2 | step 12850 | batch 1600/11250 | lr 0.00001 0.00004 | loss 0.1677 | s/batch 0.65
2020-08-

2020-08-08 12:32:05,499 INFO: | epoch   2 | step 15850 | batch 4600/11250 | lr 0.00000 0.00003 | loss 0.1531 | s/batch 0.56
2020-08-08 12:32:41,147 INFO: | epoch   2 | step 15900 | batch 4650/11250 | lr 0.00000 0.00003 | loss 0.1365 | s/batch 0.71
2020-08-08 12:33:14,478 INFO: | epoch   2 | step 15950 | batch 4700/11250 | lr 0.00000 0.00003 | loss 0.1458 | s/batch 0.67
2020-08-08 12:33:47,215 INFO: | epoch   2 | step 16000 | batch 4750/11250 | lr 0.00000 0.00003 | loss 0.1730 | s/batch 0.65
2020-08-08 12:34:22,477 INFO: | epoch   2 | step 16050 | batch 4800/11250 | lr 0.00000 0.00003 | loss 0.2174 | s/batch 0.71
2020-08-08 12:34:59,335 INFO: | epoch   2 | step 16100 | batch 4850/11250 | lr 0.00000 0.00003 | loss 0.1621 | s/batch 0.74
2020-08-08 12:35:35,361 INFO: | epoch   2 | step 16150 | batch 4900/11250 | lr 0.00000 0.00003 | loss 0.1914 | s/batch 0.72
2020-08-08 12:36:08,901 INFO: | epoch   2 | step 16200 | batch 4950/11250 | lr 0.00000 0.00003 | loss 0.1322 | s/batch 0.67
2020-08-

2020-08-08 13:09:44,508 INFO: | epoch   2 | step 19200 | batch 7950/11250 | lr 0.00000 0.00003 | loss 0.1627 | s/batch 0.62
2020-08-08 13:10:16,253 INFO: | epoch   2 | step 19250 | batch 8000/11250 | lr 0.00000 0.00003 | loss 0.1674 | s/batch 0.63
2020-08-08 13:10:46,054 INFO: | epoch   2 | step 19300 | batch 8050/11250 | lr 0.00000 0.00003 | loss 0.2009 | s/batch 0.60
2020-08-08 13:11:16,993 INFO: | epoch   2 | step 19350 | batch 8100/11250 | lr 0.00000 0.00003 | loss 0.1391 | s/batch 0.62
2020-08-08 13:11:48,850 INFO: | epoch   2 | step 19400 | batch 8150/11250 | lr 0.00000 0.00003 | loss 0.1686 | s/batch 0.64
2020-08-08 13:12:21,127 INFO: | epoch   2 | step 19450 | batch 8200/11250 | lr 0.00000 0.00003 | loss 0.1661 | s/batch 0.65
2020-08-08 13:12:53,007 INFO: | epoch   2 | step 19500 | batch 8250/11250 | lr 0.00000 0.00003 | loss 0.1213 | s/batch 0.64
2020-08-08 13:13:27,165 INFO: | epoch   2 | step 19550 | batch 8300/11250 | lr 0.00000 0.00003 | loss 0.1853 | s/batch 0.68
2020-08-

2020-08-08 13:47:36,257 INFO: | epoch   2 | step 22500 | batch 11250/11250 | lr 0.00000 0.00003 | loss 0.1674 | s/batch 0.77
2020-08-08 13:47:36,542 INFO: | epoch   2 | score (94.13, 93.72, 93.92) | f1 93.92 | loss 0.1658 | time 7745.99
2020-08-08 13:47:36,924 INFO: 
              precision    recall  f1-score   support

          科技     0.9500    0.9501    0.9501     35027
          股票     0.9528    0.9559    0.9544     33251
          体育     0.9879    0.9882    0.9880     28283
          娱乐     0.9624    0.9720    0.9672     19920
          时政     0.9132    0.9290    0.9210     13515
          社会     0.9118    0.9079    0.9098     11009
          教育     0.9530    0.9500    0.9515      8987
          财经     0.9015    0.8727    0.8868      7957
          家居     0.9392    0.9378    0.9385      7063
          游戏     0.9334    0.9170    0.9252      5291
          房产     0.9712    0.9591    0.9651      4428
          时尚     0.9339    0.9223    0.9280      2818
          彩票     0.9480    0.

2020-08-08 14:25:09,888 INFO: | epoch   3 | step 24900 | batch 2400/11250 | lr 0.00000 0.00002 | loss 0.1549 | s/batch 0.57
2020-08-08 14:25:43,021 INFO: | epoch   3 | step 24950 | batch 2450/11250 | lr 0.00000 0.00002 | loss 0.1185 | s/batch 0.66
2020-08-08 14:26:16,789 INFO: | epoch   3 | step 25000 | batch 2500/11250 | lr 0.00000 0.00002 | loss 0.1916 | s/batch 0.68
2020-08-08 14:26:52,456 INFO: | epoch   3 | step 25050 | batch 2550/11250 | lr 0.00000 0.00002 | loss 0.1315 | s/batch 0.71
2020-08-08 14:27:26,952 INFO: | epoch   3 | step 25100 | batch 2600/11250 | lr 0.00000 0.00002 | loss 0.1208 | s/batch 0.69
2020-08-08 14:28:01,346 INFO: | epoch   3 | step 25150 | batch 2650/11250 | lr 0.00000 0.00002 | loss 0.1504 | s/batch 0.69
2020-08-08 14:28:38,035 INFO: | epoch   3 | step 25200 | batch 2700/11250 | lr 0.00000 0.00002 | loss 0.1240 | s/batch 0.73
2020-08-08 14:29:14,250 INFO: | epoch   3 | step 25250 | batch 2750/11250 | lr 0.00000 0.00002 | loss 0.1421 | s/batch 0.72
2020-08-

2020-08-08 15:03:03,401 INFO: | epoch   3 | step 28250 | batch 5750/11250 | lr 0.00000 0.00002 | loss 0.0946 | s/batch 0.62
2020-08-08 15:03:40,931 INFO: | epoch   3 | step 28300 | batch 5800/11250 | lr 0.00000 0.00002 | loss 0.1564 | s/batch 0.75
2020-08-08 15:04:16,555 INFO: | epoch   3 | step 28350 | batch 5850/11250 | lr 0.00000 0.00002 | loss 0.1019 | s/batch 0.71
2020-08-08 15:04:48,075 INFO: | epoch   3 | step 28400 | batch 5900/11250 | lr 0.00000 0.00002 | loss 0.1353 | s/batch 0.63
2020-08-08 15:05:23,221 INFO: | epoch   3 | step 28450 | batch 5950/11250 | lr 0.00000 0.00002 | loss 0.1081 | s/batch 0.70
2020-08-08 15:05:57,864 INFO: | epoch   3 | step 28500 | batch 6000/11250 | lr 0.00000 0.00002 | loss 0.1346 | s/batch 0.69
2020-08-08 15:06:32,380 INFO: | epoch   3 | step 28550 | batch 6050/11250 | lr 0.00000 0.00002 | loss 0.1664 | s/batch 0.69
2020-08-08 15:07:05,112 INFO: | epoch   3 | step 28600 | batch 6100/11250 | lr 0.00000 0.00002 | loss 0.1484 | s/batch 0.65
2020-08-

2020-08-08 15:41:20,955 INFO: | epoch   3 | step 31600 | batch 9100/11250 | lr 0.00000 0.00001 | loss 0.1692 | s/batch 0.73
2020-08-08 15:41:53,075 INFO: | epoch   3 | step 31650 | batch 9150/11250 | lr 0.00000 0.00001 | loss 0.1400 | s/batch 0.64
2020-08-08 15:42:27,456 INFO: | epoch   3 | step 31700 | batch 9200/11250 | lr 0.00000 0.00001 | loss 0.1644 | s/batch 0.69
2020-08-08 15:42:58,999 INFO: | epoch   3 | step 31750 | batch 9250/11250 | lr 0.00000 0.00001 | loss 0.1653 | s/batch 0.63
2020-08-08 15:43:30,598 INFO: | epoch   3 | step 31800 | batch 9300/11250 | lr 0.00000 0.00001 | loss 0.1194 | s/batch 0.63
2020-08-08 15:44:03,072 INFO: | epoch   3 | step 31850 | batch 9350/11250 | lr 0.00000 0.00001 | loss 0.1500 | s/batch 0.65
2020-08-08 15:44:32,103 INFO: | epoch   3 | step 31900 | batch 9400/11250 | lr 0.00000 0.00001 | loss 0.1190 | s/batch 0.58
2020-08-08 15:45:10,009 INFO: | epoch   3 | step 31950 | batch 9450/11250 | lr 0.00000 0.00001 | loss 0.1252 | s/batch 0.76
2020-08-

2020-08-08 16:18:36,044 INFO: | epoch   4 | step 34000 | batch 250/11250 | lr 0.00000 0.00001 | loss 0.1212 | s/batch 0.66
2020-08-08 16:19:12,744 INFO: | epoch   4 | step 34050 | batch 300/11250 | lr 0.00000 0.00001 | loss 0.1152 | s/batch 0.73
2020-08-08 16:19:43,011 INFO: | epoch   4 | step 34100 | batch 350/11250 | lr 0.00000 0.00001 | loss 0.0964 | s/batch 0.61
2020-08-08 16:20:15,632 INFO: | epoch   4 | step 34150 | batch 400/11250 | lr 0.00000 0.00001 | loss 0.1491 | s/batch 0.65
2020-08-08 16:20:47,582 INFO: | epoch   4 | step 34200 | batch 450/11250 | lr 0.00000 0.00001 | loss 0.1119 | s/batch 0.64
2020-08-08 16:21:25,002 INFO: | epoch   4 | step 34250 | batch 500/11250 | lr 0.00000 0.00001 | loss 0.1338 | s/batch 0.75
2020-08-08 16:22:01,597 INFO: | epoch   4 | step 34300 | batch 550/11250 | lr 0.00000 0.00001 | loss 0.1169 | s/batch 0.73
2020-08-08 16:22:34,220 INFO: | epoch   4 | step 34350 | batch 600/11250 | lr 0.00000 0.00001 | loss 0.1161 | s/batch 0.65
2020-08-08 16:23

2020-08-08 16:56:20,182 INFO: | epoch   4 | step 37350 | batch 3600/11250 | lr 0.00000 0.00001 | loss 0.1325 | s/batch 0.64
2020-08-08 16:56:54,940 INFO: | epoch   4 | step 37400 | batch 3650/11250 | lr 0.00000 0.00001 | loss 0.0826 | s/batch 0.70
2020-08-08 16:57:27,457 INFO: | epoch   4 | step 37450 | batch 3700/11250 | lr 0.00000 0.00001 | loss 0.1097 | s/batch 0.65
2020-08-08 16:57:58,171 INFO: | epoch   4 | step 37500 | batch 3750/11250 | lr 0.00000 0.00001 | loss 0.0918 | s/batch 0.61
2020-08-08 16:58:30,199 INFO: | epoch   4 | step 37550 | batch 3800/11250 | lr 0.00000 0.00001 | loss 0.1080 | s/batch 0.64
2020-08-08 16:59:04,938 INFO: | epoch   4 | step 37600 | batch 3850/11250 | lr 0.00000 0.00001 | loss 0.1153 | s/batch 0.69
2020-08-08 16:59:39,082 INFO: | epoch   4 | step 37650 | batch 3900/11250 | lr 0.00000 0.00001 | loss 0.0836 | s/batch 0.68
2020-08-08 17:00:14,531 INFO: | epoch   4 | step 37700 | batch 3950/11250 | lr 0.00000 0.00001 | loss 0.1135 | s/batch 0.71
2020-08-

2020-08-08 17:34:24,174 INFO: | epoch   4 | step 40700 | batch 6950/11250 | lr 0.00000 0.00000 | loss 0.1124 | s/batch 0.70
2020-08-08 17:34:57,915 INFO: | epoch   4 | step 40750 | batch 7000/11250 | lr 0.00000 0.00000 | loss 0.1056 | s/batch 0.67
2020-08-08 17:35:32,513 INFO: | epoch   4 | step 40800 | batch 7050/11250 | lr 0.00000 0.00000 | loss 0.0942 | s/batch 0.69
2020-08-08 17:36:06,749 INFO: | epoch   4 | step 40850 | batch 7100/11250 | lr 0.00000 0.00000 | loss 0.1138 | s/batch 0.68
2020-08-08 17:36:42,215 INFO: | epoch   4 | step 40900 | batch 7150/11250 | lr 0.00000 0.00000 | loss 0.0614 | s/batch 0.71
2020-08-08 17:37:13,645 INFO: | epoch   4 | step 40950 | batch 7200/11250 | lr 0.00000 0.00000 | loss 0.1184 | s/batch 0.63
2020-08-08 17:37:50,473 INFO: | epoch   4 | step 41000 | batch 7250/11250 | lr 0.00000 0.00000 | loss 0.0764 | s/batch 0.74
2020-08-08 17:38:23,771 INFO: | epoch   4 | step 41050 | batch 7300/11250 | lr 0.00000 0.00000 | loss 0.1426 | s/batch 0.67
2020-08-

2020-08-08 18:12:00,256 INFO: | epoch   4 | step 44050 | batch 10300/11250 | lr 0.00000 0.00000 | loss 0.0704 | s/batch 0.65
2020-08-08 18:12:31,792 INFO: | epoch   4 | step 44100 | batch 10350/11250 | lr 0.00000 0.00000 | loss 0.0962 | s/batch 0.63
2020-08-08 18:13:02,422 INFO: | epoch   4 | step 44150 | batch 10400/11250 | lr 0.00000 0.00000 | loss 0.1400 | s/batch 0.61
2020-08-08 18:13:34,094 INFO: | epoch   4 | step 44200 | batch 10450/11250 | lr 0.00000 0.00000 | loss 0.0704 | s/batch 0.63
2020-08-08 18:14:14,294 INFO: | epoch   4 | step 44250 | batch 10500/11250 | lr 0.00000 0.00000 | loss 0.1224 | s/batch 0.80
2020-08-08 18:14:50,567 INFO: | epoch   4 | step 44300 | batch 10550/11250 | lr 0.00000 0.00000 | loss 0.1398 | s/batch 0.73
2020-08-08 18:15:25,603 INFO: | epoch   4 | step 44350 | batch 10600/11250 | lr 0.00000 0.00000 | loss 0.1227 | s/batch 0.70
2020-08-08 18:15:59,841 INFO: | epoch   4 | step 44400 | batch 10650/11250 | lr 0.00000 0.00000 | loss 0.0802 | s/batch 0.68


In [60]:

# Invoke the trainer's test() method; this cell comes after the training-run
# log output shown above.
# NOTE(review): presumably this scores the fine-tuned model on the held-out
# test split — confirm against the Trainer.test definition earlier in the
# notebook, which is not visible from this cell.
trainer.test()