<a href="https://colab.research.google.com/github/dh610/ai-intensive2/blob/main/lab5/lab6_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Lab 6 : BERT

@copyright:
    (c) 2023. iKnow Lab. Ajou Univ., All rights reserved.

M.S. Student: Wansik-Jo (jws5327@ajou.ac.kr)

# For assignment

- Implement the commented-out parts of the Python code.
- Fill in the [BLANK] parts of the Markdown cells.
- Write your answer after [ANSWER] in the Markdown cells.
- Show the TA your quiz answers together with your code output, then submit to BB before leaving.

---


## Table of Contents

1. BERT
2. WordPiece Tokenizer
3. BERT Pretraining
    - Masked Language Model
    - Next Sentence Prediction
4. Embedding
    - Positional Embedding
    - Token Embedding
    - Segment Embedding
5. BERT Architecture
    - Multi-Head Attention
    - Feed Forward Network
    - Layer Normalization
    - Residual Connection
    - Encoder
6. Training

## 1. BERT (Introduction)

### BERT

- [BERT](https://arxiv.org/abs/1810.04805) is a model released by Google in October 2018. Its name stands for Bidirectional Encoder Representations from Transformers.

- As the name suggests, BERT is a language model pretrained with a bidirectional Transformer encoder.

- The model obtained through pretraining can then be fine-tuned for a wide range of NLP tasks.

![BERT_pretraining](./figure/BERT_pretraining.png)

- In this lab, we study the structure of BERT and implement it from scratch, together with its pretraining procedure.


### Data

- This lab uses the [Cornell Movie-Dialogs Corpus](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html).

    - 220,000 conversational exchanges between 10,292 pairs of movie characters


In [None]:
!pip install transformers datasets tokenizers
!wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
!unzip -qq cornell_movie_dialogs_corpus.zip
!rm cornell_movie_dialogs_corpus.zip
!mkdir datasets
!mv cornell\ movie-dialogs\ corpus/movie_conversations.txt ./datasets
!mv cornell\ movie-dialogs\ corpus/movie_lines.txt ./datasets

- The data are preprocessed as follows.

1. The corpus is split into two files, `movie_conversations.txt` and `movie_lines.txt`.

2. Each line of `movie_lines.txt` is split on the delimiter `' +++$+++ '` into the line ID, character ID, movie ID, character name, and utterance text. The `lines_dic` dictionary maps each line ID to its utterance text.

3. Each line of `movie_conversations.txt` is split on the same delimiter into the two character IDs, the movie ID, and the list of line IDs that make up the conversation. Every pair of consecutive utterances in a conversation is stored in `pairs`.

4. Finally, the input length is limited to 64 (`MAX_LEN`): each utterance is split on whitespace and only its first 64 words are kept.
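The parsing rules in steps 2 to 4 can be checked on a few hand-written lines. This is a self-contained sketch: the sample lines, IDs and the `build_pairs` helper are hypothetical and only mimic the file formats described above.

```python
import ast

SEP = " +++$+++ "  # field delimiter used by both corpus files

# Hypothetical lines in the movie_lines.txt / movie_conversations.txt formats
movie_lines = [
    "L1 +++$+++ u0 +++$+++ m0 +++$+++ ALICE +++$+++ Hello there.\n",
    "L2 +++$+++ u1 +++$+++ m0 +++$+++ BOB +++$+++ Hi!\n",
    "L3 +++$+++ u0 +++$+++ m0 +++$+++ ALICE +++$+++ How are you?\n",
]
movie_conversations = ["u0 +++$+++ u1 +++$+++ m0 +++$+++ ['L1', 'L2', 'L3']\n"]


def build_pairs(line_rows, conv_rows, max_len=64):
    """Return [first, second] pairs of consecutive utterances (hypothetical helper)."""
    # line ID -> utterance text (the last field)
    lines_dic = {row.split(SEP)[0]: row.split(SEP)[-1].strip() for row in line_rows}
    pairs = []
    for row in conv_rows:
        # the last field is a Python list literal of line IDs
        ids = ast.literal_eval(row.split(SEP)[-1].strip())
        for a, b in zip(ids, ids[1:]):
            # keep at most max_len whitespace-separated words per utterance
            pairs.append([" ".join(lines_dic[a].split()[:max_len]),
                          " ".join(lines_dic[b].split()[:max_len])])
    return pairs


print(build_pairs(movie_lines, movie_conversations))
# [['Hello there.', 'Hi!'], ['Hi!', 'How are you?']]
```

A conversation with n utterances yields n - 1 pairs, so every utterance except the first and last appears once as a first sentence and once as a second sentence.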


In [None]:
import ast
from pathlib import Path

MAX_LEN = 64

### loading all data into memory
corpus_movie_conv = './datasets/movie_conversations.txt'
corpus_movie_lines = './datasets/movie_lines.txt'
with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as c:
    conv = c.readlines()
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as l:
    lines = l.readlines()

### map each line ID to its utterance text (fields are separated by ' +++$+++ ')
lines_dic = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1]

### generate question-answer pairs from consecutive utterances
pairs = []
for con in conv:
    # the last field is a list literal of line IDs, e.g. "['L194', 'L195']";
    # ast.literal_eval parses it safely, unlike eval
    ids = ast.literal_eval(con.split(" +++$+++ ")[-1].strip())
    for i in range(len(ids) - 1):
        first = lines_dic[ids[i]].strip()
        second = lines_dic[ids[i + 1]].strip()

        # keep at most MAX_LEN whitespace-separated words per utterance
        qa_pairs = [' '.join(first.split()[:MAX_LEN]),
                    ' '.join(second.split()[:MAX_LEN])]
        pairs.append(qa_pairs)

In [None]:
print(len(pairs))
print(pairs[42])

221616
['She okay?', 'I hope so.']


## 2. WordPiece Tokenizer

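- BERT uses the WordPiece tokenizer.

- WordPiece splits words into subword units. It is closely related to byte-pair encoding (BPE), which was applied to NMT in [Neural Machine Translation of Rare Words with Subword Units](https://arxiv.org/abs/1508.07909); the two differ mainly in the score used to decide which pair of units to merge, described below.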

>  “I like surfboarding!” → [‘[CLS]’, ‘i’, ‘like’, ‘surf’, ‘##board’, ‘##ing’, ‘!’, ‘[SEP]’] → [1, 48, 250, 4033, 3588, 154, 5, 2]
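Once the vocabulary is built, BERT's WordPiece tokenizer splits each word by greedy longest-match-first search: take the longest prefix that is in the vocabulary, then repeat on the remainder using the `##` continuation prefix. The sketch below assumes a tiny hypothetical vocabulary and a hypothetical `wordpiece_tokenize` helper; the real IDs above come from the trained vocab.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single word."""
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # try the longest remaining substring first, then shrink it
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate   # continuation of the current word
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]                        # no prefix matches: whole word is unknown
        tokens.append(piece)
        start = end
    return tokens


vocab = {"i", "like", "surf", "##b", "##board", "##ing"}   # hypothetical toy vocab
print(wordpiece_tokenize("surfboarding", vocab))           # ['surf', '##board', '##ing']
```

`##board` wins over the shorter `##b` because longer matches are tried first, and a word with no matching prefix becomes `[UNK]`.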

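- During training, WordPiece computes a score for every pair of adjacent units and merges the pair with the highest score:

$$ score = \frac{\text{frequency of the pair}}{\text{frequency of the first element} \times \text{frequency of the second element}} $$

Hugging Face explains the intuition as follows: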

> By dividing the frequency of the pair by the product of the frequencies of each of its parts, the algorithm prioritizes the merging of pairs where the individual parts are less frequent in the vocabulary. For instance, it won’t necessarily merge ("un", "##able") even if that pair occurs very frequently in the vocabulary, because the two pairs "un" and "##able" will likely each appear in a lot of other words and have a high frequency. In contrast, a pair like ("hu", "##gging") will probably be merged faster (assuming the word “hugging” appears often in the vocabulary) since "hu" and "##gging" are likely to be less frequent individually.
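The scoring rule can be illustrated on a toy word-frequency table. This sketch computes the score of every adjacent pair for a single merge step only. The corpus and the `split_word` / `pair_scores` helpers are hypothetical, and this is not necessarily how the `tokenizers` library's trainer is implemented internally.

```python
from collections import Counter


def split_word(word):
    # first character as-is, following characters carry the "##" prefix
    return [word[0]] + ["##" + c for c in word[1:]]


def pair_scores(word_freqs):
    """score(a, b) = freq(a, b) / (freq(a) * freq(b)) over the current units."""
    unit_freq, pair_freq = Counter(), Counter()
    for word, freq in word_freqs.items():
        units = split_word(word)
        for u in units:
            unit_freq[u] += freq
        for a, b in zip(units, units[1:]):
            pair_freq[(a, b)] += freq
    return {p: f / (unit_freq[p[0]] * unit_freq[p[1]]) for p, f in pair_freq.items()}


word_freqs = {"hug": 10, "pug": 5, "pun": 12, "bun": 4, "hugs": 5}   # toy corpus
scores = pair_scores(word_freqs)
best = max(scores, key=scores.get)
print(best, scores[best])   # ('##g', '##s') 0.05
```

Here ("##u", "##g") is the most frequent pair (20 occurrences), so a purely frequency-based BPE merge would pick it first. Its parts are frequent on their own, however, so its score is only 1/36, while the rarer pair ("##g", "##s") gets the highest score.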



- In this lab, we train our own WordPiece tokenizer on the training corpus loaded above.

1. Import `BertWordPieceTokenizer` from the `tokenizers` library.

2. Split the conversation text into multiple .txt files of 10,000 lines each (N=10000).

3. Set the normalization parameters:
    - `clean_text` : remove control characters and normalize whitespace
    - `handle_chinese_chars` : put spaces around Chinese characters so each one is tokenized separately
    - `strip_accents` : remove accents
    - `lowercase` : lowercase the input

4. Train with `tokenizer.train`. Its training parameters are:
    - `files` : list of .txt files used for training
    - `vocab_size` : size of the final vocabulary
    - `min_frequency` : minimum frequency a pair must have to be merged
    - `limit_alphabet` : maximum number of distinct initial characters kept in the alphabet
    - `wordpieces_prefix` : prefix that marks sub-words continuing a word (`##`)
    - `special_tokens` : special tokens added to the vocabulary

5. Save the vocabulary with `tokenizer.save_model`.

In [None]:
import os
from tqdm import tqdm
import torch
from tokenizers import BertWordPieceTokenizer
import transformers, datasets
from transformers import BertTokenizer

os.makedirs('./data', exist_ok=True)
text_data = []
file_count = 0

for sample in tqdm([x[0] for x in pairs]):
    text_data.append(sample)

    # Save at every 10k samples
    if len(text_data) == 10000:
        with open(f'./data/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(text_data))
        text_data = []
        file_count += 1

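# Save the remaining samples that did not fill a complete 10k chunk
if text_data:
    with open(f'./data/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(text_data))

paths = [str(x) for x in Path('./data').glob('**/*.txt')]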

### training own tokenizer
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=False,
    lowercase=True
)

tokenizer.train(
    files=paths,
    vocab_size=30_000,
    min_frequency=5,
    limit_alphabet=1000,
    wordpieces_prefix='##',
    special_tokens=['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]']
    )

os.makedirs('./bert', exist_ok=True)
tokenizer.save_model('./bert', 'bert')
tokenizer = BertTokenizer.from_pretrained('./bert/bert-vocab.txt', local_files_only=True)

100%|██████████| 221616/221616 [00:00<00:00, 1972885.23it/s]









In [None]:
target_sentence = "I was walking on the bridge. People passed my by slowly. And the weather was cloudy."

print(tokenizer.tokenize(target_sentence))
print(tokenizer.encode(target_sentence))
print(tokenizer.decode(tokenizer.encode(target_sentence)))

['i', 'was', 'walking', 'on', 'the', 'bridge', '.', 'people', 'passed', 'my', 'by', 'slowly', '.', 'and', 'the', 'weather', 'was', 'cloud', '##y', '.']
[1, 48, 215, 2643, 192, 150, 2783, 17, 464, 3259, 218, 454, 5583, 17, 179, 150, 3005, 215, 4682, 114, 17, 2]
[CLS] i was walking on the bridge. people passed my by slowly. and the weather was cloudy. [SEP]


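The special tokens in our BERT vocabulary are listed below. Their IDs follow the order of `special_tokens` passed to `tokenizer.train`.

- [PAD] : 0, padding token. Padding positions are excluded from attention and ignored by the MLM loss (`ignore_index=0`), so they contribute no gradient.
- [CLS] : 1, classification token. It is placed at the start of every input, and its final hidden state serves as the representation of the whole sequence (used for NSP).
- [SEP] : 2, separator token. It marks the end of each sentence and separates the first sentence from the second.
- [MASK] : 3, mask token used by the Masked Language Model.
- [UNK] : 4, unknown token that replaces anything not covered by the vocabulary.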

In [None]:
print(tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]', '[PAD]', '[MASK]', '[UNK]']))

[1, 2, 0, 3, 4]


## 3. BERT Pretraining

BERT is not pretrained with next-word prediction, the objective we used for language models (LM).

Instead, BERT uses the following two tasks.

1. Masked Language Model
2. Next Sentence Prediction

### 3.1 Masked Language Model

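- The Masked Language Model is trained as follows.

15% of the tokens in the input sequence are selected and replaced with the [MASK] token, and the model is trained to predict the original tokens at those positions.

Applied naively, this approach has a problem.

The model only learns to make predictions where [MASK] appears in the input, but [MASK] never appears during fine-tuning, so the pretraining and fine-tuning inputs no longer match.

Therefore, of the 15% of selected tokens:
- 80% are replaced with the [MASK] token,
- 10% are replaced with a random token,
- 10% are left unchanged.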
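The 80/10/10 rule can be sketched on a plain list of token IDs. `mask_tokens` is a hypothetical helper, simplified from the `random_word` method below: it selects individual tokens, whereas `random_word` selects whole words and treats all of their sub-word tokens together.

```python
import random


def mask_tokens(token_ids, vocab_size, mask_id, select_prob=0.15, rng=random):
    """Apply BERT-style 80/10/10 masking to a list of token IDs.

    Returns (inputs, labels). labels holds the original ID at selected
    positions and 0 elsewhere (0 is ignored by the loss, like [PAD]).
    """
    inputs, labels = [], []
    for tok in token_ids:
        if rng.random() < select_prob:                    # select ~15% of positions
            r = rng.random()
            if r < 0.8:
                inputs.append(mask_id)                    # 80%: [MASK]
            elif r < 0.9:
                inputs.append(rng.randrange(vocab_size))  # 10%: random token
            else:
                inputs.append(tok)                        # 10%: keep the original
            labels.append(tok)                            # the model must predict the original
        else:
            inputs.append(tok)
            labels.append(0)
    return inputs, labels


data_rng, mask_rng = random.Random(0), random.Random(1)
ids = [data_rng.randrange(5, 1000) for _ in range(100_000)]   # toy IDs, no special tokens
inputs, labels = mask_tokens(ids, vocab_size=1000, mask_id=3, rng=mask_rng)
selected = [i for i, lab in enumerate(labels) if lab != 0]
print(len(selected) / len(ids))   # roughly 0.15
```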


### 3.2 Next Sentence Prediction

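In addition, so that BERT learns the relationship between two sentences, a Next Sentence Prediction (NSP) task is added.

NSP predicts whether the second sentence actually follows the first.

NSP training examples are built as follows.

Each input consists of two sentences, A and B.

For 50% of the inputs, B is the sentence that actually follows A (IsNext); for the other 50%, B is a random sentence from the corpus (NotNext).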
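A minimal sketch of the 50/50 sampling, assuming the data is a list of `[first, second]` pairs like `pairs` above. `make_nsp_example` and the toy data are hypothetical; the function mirrors the `get_sent` method in the code below.

```python
import random


def make_nsp_example(pairs, index, rng=random):
    """Return (sentence_a, sentence_b, is_next) for pairs[index].

    With probability 0.5 the true next sentence is kept (is_next = 1);
    otherwise it is replaced by the second sentence of a random pair (is_next = 0).
    """
    a, b = pairs[index]
    if rng.random() < 0.5:
        return a, b, 1
    return a, pairs[rng.randrange(len(pairs))][1], 0


toy_pairs = [[f"q{i}", f"a{i}"] for i in range(1000)]   # hypothetical toy data
rng = random.Random(0)
examples = [make_nsp_example(toy_pairs, i % 1000, rng) for i in range(10_000)]
print(sum(label for _, _, label in examples) / len(examples))   # roughly 0.5
```

The random sentence can occasionally be the true next sentence; with a large corpus this happens rarely, and the sketch (like the dataset below) does not guard against it.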

In [None]:
from torch.utils.data import Dataset, DataLoader
import itertools
import random

class BERTDataset(Dataset):
    def __init__(self, data_pair, tokenizer, seq_len=64):
        self.lines = data_pair
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, item):
        # Step 1: get random sentence pair, either negative or positive (saved as is_next_label)
        t1, t2, is_next_label = self.get_sent(item)

        # Step 2: replace random words in sentence with mask / random words
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # Step 3:
        # Adding CLS and SEP tokens to the start and end of sentences
        t1 = [self.tokenizer.vocab['[CLS]']] + t1_random + [self.tokenizer.vocab['[SEP]']]
        t2 = t2_random + [self.tokenizer.vocab['[SEP]']]
        # Adding PAD token for labels
        t1_label = [self.tokenizer.vocab['[PAD]']] + t1_label + [self.tokenizer.vocab['[PAD]']]
        t2_label = t2_label + [self.tokenizer.vocab['[PAD]']]

        # Step 4:
        # Combine sentence 1 and 2 as one input
        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]
        # Adding PAD tokens to make the sentence same length as seq_len
        padding = [self.tokenizer.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()}

    def random_word(self, sentence):
        """Tokenize a sentence word by word and apply BERT's 80/10/10 masking.

        Returns (token_ids, labels): labels hold the original token ID at
        selected positions and 0 ([PAD], ignored by the loss) everywhere else.
        """
        tokens = sentence.split()
        output = []
        output_label = []

        for token in tokens:
            prob = random.random()

            # tokenizer() wraps the word in [CLS] ... [SEP]; drop those two special tokens
            token_id = self.tokenizer(token)['input_ids'][1:-1]

            # select about 15% of the words for prediction
            if prob < 0.15:
                prob /= 0.15

                # 80%: replace every sub-token with [MASK]
                if prob < 0.8:
                    output.extend([self.tokenizer.vocab['[MASK]']] * len(token_id))

                # 10%: replace every sub-token with a random token ID
                elif prob < 0.9:
                    output.extend(random.randrange(len(self.tokenizer.vocab)) for _ in token_id)

                # 10%: keep the original sub-tokens
                else:
                    output.extend(token_id)

                # the model must recover the original IDs at these positions
                output_label.extend(token_id)

            else:
                output.extend(token_id)
                output_label.extend([0] * len(token_id))

        assert len(output) == len(output_label)
        return output, output_label

    def get_corpus_line(self, item):
        '''return sentence pair'''
        return self.lines[item][0], self.lines[item][1]

    def get_random_line(self):
        '''return random single sentence'''
        return self.lines[random.randrange(len(self.lines))][1]

    def get_sent(self, index):
        '''return random sentence pair'''
        t1, t2 = self.get_corpus_line(index)

        # negative or positive pair, for next sentence prediction
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0


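The code above builds a PyTorch Dataset for BERT as follows.

1. Randomly pick either a positive or a negative sentence pair, and store the label `is_next` indicating whether the second sentence actually follows the first.

2. Tokenize each sentence word by word and mask random tokens with the probabilities described above (`random_word`). When a word is selected, all of its sub-word tokens are treated together. The labels (`t1_label`, `t2_label`, later combined into `bert_label`) hold the original token ID at selected positions and 0 (`[PAD]`) everywhere else.

3. Add `[CLS]` at the start of the first sentence and `[SEP]` at the end of each sentence.

4. Concatenate the two sentences into a single sequence, with `[SEP]` separating them. Truncate it to `seq_len` and pad it with `[PAD]` tokens up to `seq_len`.

5. Build `segment_label`, which indicates the sentence each token belongs to: `[CLS]`, the first sentence and its `[SEP]` get 1; the second sentence and its `[SEP]` get 2; padding gets 0.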
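Steps 3 to 5 can be illustrated on hand-made token IDs. The sketch assumes this lab's special-token IDs ([PAD] = 0, [CLS] = 1, [SEP] = 2); `build_bert_input` is a hypothetical helper that follows the same truncation and padding order as `__getitem__`.

```python
def build_bert_input(a_ids, b_ids, seq_len, cls_id=1, sep_id=2, pad_id=0):
    """Assemble [CLS] A [SEP] B [SEP], truncate/pad to seq_len, and build segment labels.

    Segment labels: 1 for the first sentence (including [CLS] and its [SEP]),
    2 for the second sentence (including its [SEP]), 0 for padding.
    """
    first = [cls_id] + a_ids + [sep_id]
    second = b_ids + [sep_id]
    tokens = (first + second)[:seq_len]
    segments = ([1] * len(first) + [2] * len(second))[:seq_len]
    n_pad = seq_len - len(tokens)
    return tokens + [pad_id] * n_pad, segments + [0] * n_pad


tokens, segments = build_bert_input([10, 11], [20], seq_len=8)
print(tokens)    # [1, 10, 11, 2, 20, 2, 0, 0]
print(segments)  # [1, 1, 1, 1, 2, 2, 0, 0]
```

Because truncation happens before padding, a long first sentence can push the second sentence out of the window entirely, as in the dataset code.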

---

Create the training data with the `BERTDataset` defined above and split it into batches with a `DataLoader`.

In [None]:
import random

train_data = BERTDataset(
   pairs, seq_len=MAX_LEN, tokenizer=tokenizer)
train_loader = DataLoader(
   train_data, batch_size=32, shuffle=True, pin_memory=True)
sample_data = next(iter(train_loader))
print(train_data[random.randrange(len(train_data))])

{'bert_input': tensor([   1,  162,   11,   58,   40, 1372,   17,    2,  769,   17,   48,   11,
         265,  249,  569,  269,   17,    2,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0]), 'bert_label': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'segment_label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'is_next': tensor(1)}


## 4. Embedding

- BERT uses the following three embeddings.
    - Token Embedding
    - Segment Embedding
    - Position Embedding

### 4.1 Positional Embedding

![BERT_positional_embedding](./figure/positional_embedding.png)

- Like the Transformer, BERT adds position information to the token embeddings.

- Unlike the Transformer, the original BERT *learns* its position embeddings as a trainable lookup table. For simplicity, this lab instead uses the Transformer's fixed sinusoidal encoding, which is not trained.

- The sinusoidal encoding is computed as follows:

    - Each position `pos` (from 0 to `max_len - 1`) is mapped to a vector of dimension `hidden_size`.

    - Even dimensions `2i` use the sine function, and odd dimensions `2i+1` use the cosine function.

    - Each sin/cos pair shares the same frequency, and the wavelength grows geometrically with `i`, from $2\pi$ up to $10000 \cdot 2\pi$.

    $$ PE_{(pos, 2i)} = sin(pos / 10000^{2i / hidden\_size}) $$
    $$ PE_{(pos, 2i+1)} = cos(pos / 10000^{2i / hidden\_size}) $$
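The two formulas can be computed directly. This pure-Python sketch uses a hypothetical `sinusoidal_positions` helper. Its loop variable `i` is the even dimension index itself (the `2i` of the formulas above), so the exponent is `i / d_model`, and each sin/cos pair uses the same angle.

```python
import math


def sinusoidal_positions(max_len, d_model):
    """Return a max_len x d_model list of fixed sinusoidal position encodings."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):              # i is the even dimension (2i in the formula)
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)            # even dimension: sin
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)    # odd dimension: cos, same frequency
    return pe


pe = sinusoidal_positions(max_len=64, d_model=128)
print(pe[0][:4])   # [0.0, 1.0, 0.0, 1.0]
```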


In [None]:
import math

class PositionalEmbedding(torch.nn.Module):
    def __init__(self, d_model, max_len=128):
        super().__init__()

        # Precompute the fixed sinusoidal encodings once; they are not trained.
        pe = torch.zeros(max_len, d_model).float()

        for pos in range(max_len):
            # i is the even dimension index (2i in the formula above)
            for i in range(0, d_model, 2):
                angle = pos / (10000 ** (i / d_model))
                pe[pos, i] = math.sin(angle)
                pe[pos, i + 1] = math.cos(angle)

        # (1, max_len, d_model); a buffer follows the module on .to(device)
        # and is not returned by .parameters()
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x: (batch_size, seq_len) token IDs; the result broadcasts over the batch
        return self.pe[:, :x.size(1)]

- Next, implement `BERTEmbedding`, which sums the Positional Embedding implemented above with a standard Token Embedding and a Segment Embedding to produce a single input embedding.

In [None]:


class BERTEmbedding(torch.nn.Module):
    """
    BERT embedding, the sum of three components:
        1. TokenEmbedding : standard learned embedding matrix
        2. PositionalEmbedding : position information (fixed sin/cos in this lab)
        3. SegmentEmbedding : sentence segment info (sent_A: 1, sent_B: 2, padding: 0)
    followed by dropout.
    """
    def __init__(self, vocab_size, embed_size, seq_len=64, dropout=0.1):
        """
        vocab_size: total vocab size
        embed_size: embedding size (d_model)
        seq_len: maximum sequence length for the positional encoding
        dropout: dropout rate
        """
        super().__init__()
        self.embed_size = embed_size
        # (m, seq_len) --> (m, seq_len, embed_size)
        # padding_idx is not updated during training, remains as fixed pad (0)
        self.token = torch.nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.segment = torch.nn.Embedding(3, embed_size, padding_idx=0)
        self.position = PositionalEmbedding(d_model=embed_size, max_len=seq_len)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        x = self.dropout(x)

        return x

## 5. BERT Architecture

- BERT is built from Transformer encoder layers.

![BERT](./figure/BERT.png)

- Each Transformer encoder layer consists of the following components:

    - Multi-Head Attention
    - Feed Forward Network
    - Layer Normalization
    - Residual Connection

### 5.1 Multi-Head Attention

As in the Transformer lab, BERT uses Multi-Head Attention.

$$ MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O $$
$$ head_i = Attention(QW_i^Q, KW_i^K, VW_i^V) $$
$$ Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V $$
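For a single head, the attention formula can be traced by hand. The pure-Python sketch below uses a hypothetical `attention` helper, not the module used in this lab. It applies the same masking trick as `MultiHeadedAttention`: masked scores are set to -1e9, so their softmax weight underflows to 0.

```python
import math


def softmax(row):
    m = max(row)                       # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V, mask=None):
    """Scaled dot-product attention for one head (hypothetical helper).

    Q: n x d_k, K: m x d_k, V: m x d_v (lists of lists); mask[j] == 0 hides key j.
    Returns (output rows, attention weight rows).
    """
    d_k = len(K[0])
    outputs, weights = [], []
    for q in Q:
        # QK^T / sqrt(d_k) for this query row
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        if mask is not None:
            # same trick as masked_fill(mask == 0, -1e9)
            scores = [s if keep else -1e9 for s, keep in zip(scores, mask)]
        w = softmax(scores)
        weights.append(w)
        # weighted sum of the value rows
        outputs.append([sum(w[j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))])
    return outputs, weights


Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
V = [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]
out, w = attention(Q, K, V, mask=[1, 1, 0])   # the third key plays the role of [PAD]
print(w[0])     # the masked key receives weight 0.0
print(out[0])
```

Without the mask, the third key would dominate (its score is the largest) and pull the output toward `[9, 9]`; with it, the output is a mix of the first two value rows only.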

In [None]:
import torch.nn.functional as F

### attention layers
class MultiHeadedAttention(torch.nn.Module):

    def __init__(self, heads, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()

        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = torch.nn.Dropout(dropout)

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.output_linear = torch.nn.Linear(d_model, d_model)

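    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, d_model)
        mask of shape: (batch_size, 1, max_len, max_len), or any shape that broadcasts
        to (batch_size, h, max_len, max_len); positions where mask == 0 are ignored
        """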
        # (batch_size, max_len, d_model)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)

        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))

        # replace masked-out (0) positions with a large negative number so they get ~0 weight after softmax
        # (batch_size, h, max_len, max_len)
        scores = scores.masked_fill(mask == 0, -1e9)

        # (batch_size, h, max_len, max_len)
        # softmax to put attention weight for all non-pad tokens
        # max_len X max_len matrix of attention
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)

        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

        # (batch_size, max_len, d_model)
        return self.output_linear(context)

### 5.2 Feed Forward Network

As in the Transformer encoder, BERT uses a position-wise Feed Forward Network.

It consists of two FC layers with a [GELU](https://arxiv.org/abs/1606.08415) activation function between them.
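GELU is defined as $x\,\Phi(x)$, where $\Phi$ is the standard normal CDF, and the GELU paper also gives a tanh-based approximation. `torch.nn.GELU` uses the exact form by default. The sketch below compares the two using only the standard library; `gelu` and `gelu_tanh` are illustrative names.

```python
import math


def gelu(x):
    # exact GELU: x * Phi(x), with Phi written via the error function
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # tanh approximation from the GELU paper
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


xs = [i / 10 for i in range(-50, 51)]
print(max(abs(gelu(x) - gelu_tanh(x)) for x in xs))   # small approximation error
```

Unlike ReLU, GELU is smooth and lets small negative inputs through with a small negative output (e.g. `gelu(-1.0)` is about -0.16).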

In [None]:

class FeedForward(torch.nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, middle_dim=256, dropout=0.1):
        super(FeedForward, self).__init__()

        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

Using the `MultiHeadedAttention` and `FeedForward` modules implemented above, together with PyTorch's Layer Normalization (`torch.nn.LayerNorm`), implement `EncoderLayer`.
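The encoder layer wraps each sublayer in a residual connection followed by LayerNorm (post-LN), i.e. `LayerNorm(x + Sublayer(x))`. The sketch below shows that pattern on a single vector, assuming LayerNorm's learnable scale and shift are still at their initial values (1 and 0); `layer_norm` and `post_ln_sublayer` are hypothetical helpers.

```python
import math


def layer_norm(x, eps=1e-5):
    # normalize one vector to zero mean and unit (biased) variance
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def post_ln_sublayer(x, sublayer):
    # residual connection first, then LayerNorm
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


y = post_ln_sublayer([1.0, 2.0, 3.0, 4.0], lambda v: [2 * t for t in v])
print(y)   # zero-mean, unit-variance version of [3, 6, 9, 12]
```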

In [None]:
class EncoderLayer(torch.nn.Module):
    def __init__(
        self,
        d_model=768,
        heads=12,
        feed_forward_hidden=768 * 4,
        dropout=0.1
        ):
        super(EncoderLayer, self).__init__()
        # separate LayerNorms for the attention and feed-forward sublayers
        self.attn_layernorm = torch.nn.LayerNorm(d_model)
        self.ff_layernorm = torch.nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, max_len, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual connection + LayerNorm (post-LN)
        interacted = self.attn_layernorm(interacted + embeddings)
        # position-wise feed-forward sublayer, again with residual + LayerNorm
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.ff_layernorm(feed_forward_out + interacted)
        return encoded

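Finally, the BERT model is assembled as follows.

- BERT
    - BERTEmbedding
    - EncoderLayer * n_layers (12 in BERT-base; 2 in the test run below)

- Pretraining heads: MaskedLanguageModel (MLM) and NextSentencePrediction (NSP)

- BERTLM
    - BERT with the MLM and NSP heads on top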

In [None]:
class BERT(torch.nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """
    def __init__(self, vocab_size, d_model=768, n_layers=12, heads=12, dropout=0.1):
        """
        :vocab_size: vocab_size of total words
        :d_model: BERT model hidden size
        :n_layers: number of Transformer blocks (layers)
        :heads: number of attention heads
        :dropout: dropout rate
        """
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.heads = heads

        # The paper uses 4 * hidden_size for the feed-forward hidden size;
        # this lab uses a smaller fixed size of 256 to keep the model light.
        self.feed_forward_hidden = 256

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model)

        # multi-layers transformer blocks, deep network
        self.encoder_blocks = torch.nn.ModuleList(
            [EncoderLayer(d_model, heads, self.feed_forward_hidden, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention masking for padded token
        # (batch_size, 1, seq_len, seq_len)
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x, segment_info)

        # running over multiple transformer blocks
        for encoder in self.encoder_blocks:
            x = encoder.forward(x, mask)

        return x

class NextSentencePrediction(torch.nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :hidden: BERT model output size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, 2)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # use only the first token which is the [CLS]
        return self.softmax(self.linear(x[:, 0]))

class MaskedLanguageModel(torch.nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :hidden: output size of BERT model
        :vocab_size: total vocab size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, vocab_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

class BERTLM(torch.nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :bert: BERT model which should be trained
        :vocab_size: total vocab size for masked_lm
        """
        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.d_model)
        self.mask_lm = MaskedLanguageModel(self.bert.d_model, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)

## 6. Training

Pretrain BERT using the BERTLM implemented above.

In [None]:
import numpy as np
from torch.optim import Adam

class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr
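`ScheduledOptim` implements the Transformer ("Noam") schedule: the learning rate grows linearly for `n_warmup_steps` steps, peaks at $d_{model}^{-0.5} \cdot warmup^{-0.5}$, and then decays as $step^{-0.5}$. (The BERT paper itself used linear warmup followed by linear decay.) Because the schedule overwrites `param_group['lr']` before every step, the `lr=1e-4` passed to `Adam` in `BERTTrainer` below is never used. The hypothetical `noam_lr` helper reproduces the formula so the values can be checked without a model.

```python
def noam_lr(step, d_model=128, warmup=10_000):
    """Learning rate ScheduledOptim sets at a given step (step >= 1)."""
    # lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for step in (1, 1_732, 8_660, 10_000, 40_000):
    print(step, noam_lr(step))
```

With the test-run settings (`d_model=128`, `warmup_steps=10000`) the peak is about 8.8e-4. At batch size 128 one epoch is 1,732 steps, so the 5-epoch test run ends at step 8,660, still inside the warmup phase.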

In [None]:
class BERTTrainer:
    def __init__(
        self,
        model,
        train_dataloader,
        test_dataloader=None,
        lr= 1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=10000,
        log_freq=10,
        device='cuda'
        ):

        self.device = device
        self.model = model
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps
            )

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = torch.nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        mode = "train" if train else "test"

        # progress bar
        data_iter = tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        for i, data in data_iter:

            # 0. batch_data will be sent into the device(GPU or cpu)
            data = {key: value.to(self.device) for key, value in data.items()}

            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])

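            # 2-1. NLL (negative log likelihood) loss of the is_next classification.
            # No ignore_index here: label 0 (NotNext) is a valid class, whereas
            # self.criterion (ignore_index=0) would silently drop every negative pair.
            next_loss = torch.nn.functional.nll_loss(next_sent_output, data["is_next"])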

            # 2-2. NLLLoss of predicting masked token word
            # transpose to (m, vocab_size, seq_len) vs (m, seq_len)
            # criterion(mask_lm_output.view(-1, mask_lm_output.size(-1)), data["bert_label"].view(-1))
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

            # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
            loss = next_loss + mask_loss

            # 3. backward and optimization only in train
            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next sentence prediction accuracy
            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["is_next"].nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))
        print(
            f"EP{epoch}, {mode}: "
            f"avg_loss={avg_loss / len(data_loader)}, "
            f"total_acc={total_correct * 100.0 / total_element}"
        )

In [None]:
'''test run'''

train_data = BERTDataset(
   pairs, seq_len=MAX_LEN, tokenizer=tokenizer)

train_loader = DataLoader(
   train_data, batch_size=128, shuffle=True, pin_memory=True)

bert_model = BERT(
  vocab_size=len(tokenizer.vocab),
  d_model=128,
  n_layers=2,
  heads=2,
  dropout=0.1
).to('cuda')

bert_lm = BERTLM(bert_model, len(tokenizer.vocab)).to('cuda')
bert_trainer = BERTTrainer(bert_lm, train_loader, device='cuda')
epochs = 5

for epoch in range(epochs):
  bert_trainer.train(epoch)

Total Parameters: 5703210


EP_train:0:   0%|| 1/1732 [00:01<38:22,  1.33s/it]

{'epoch': 0, 'iter': 0, 'avg_loss': 10.8682222366333, 'avg_acc': 52.34375, 'loss': 10.8682222366333}


EP_train:0:   1%|| 11/1732 [00:03<07:31,  3.81it/s]

{'epoch': 0, 'iter': 10, 'avg_loss': 10.883026903325861, 'avg_acc': 50.14204545454546, 'loss': 10.867362976074219}


EP_train:0:   1%|| 21/1732 [00:06<06:59,  4.08it/s]

{'epoch': 0, 'iter': 20, 'avg_loss': 10.875261125110445, 'avg_acc': 50.78125, 'loss': 10.807914733886719}


EP_train:0:   2%|| 31/1732 [00:08<07:20,  3.87it/s]

{'epoch': 0, 'iter': 30, 'avg_loss': 10.869547966987856, 'avg_acc': 50.403225806451616, 'loss': 10.823481559753418}


EP_train:0:   2%|| 41/1732 [00:11<07:01,  4.01it/s]

{'epoch': 0, 'iter': 40, 'avg_loss': 10.852993104516006, 'avg_acc': 50.64786585365854, 'loss': 10.795344352722168}


EP_train:0:   3%|| 51/1732 [00:13<06:42,  4.18it/s]

{'epoch': 0, 'iter': 50, 'avg_loss': 10.836528927672143, 'avg_acc': 50.474877450980394, 'loss': 10.744404792785645}


EP_train:0:   4%|| 61/1732 [00:16<07:02,  3.95it/s]

{'epoch': 0, 'iter': 60, 'avg_loss': 10.820503047255219, 'avg_acc': 50.371413934426236, 'loss': 10.704460144042969}


EP_train:0:   4%|| 71/1732 [00:18<07:07,  3.88it/s]

{'epoch': 0, 'iter': 70, 'avg_loss': 10.798840616790342, 'avg_acc': 50.3631161971831, 'loss': 10.6514892578125}


EP_train:0:   5%|| 81/1732 [00:21<07:03,  3.90it/s]

{'epoch': 0, 'iter': 80, 'avg_loss': 10.773246576756607, 'avg_acc': 50.057870370370374, 'loss': 10.58171272277832}


EP_train:0:   5%|| 91/1732 [00:24<07:09,  3.82it/s]

{'epoch': 0, 'iter': 90, 'avg_loss': 10.748712854070979, 'avg_acc': 50.10302197802198, 'loss': 10.557393074035645}


EP_train:0:   6%|| 101/1732 [00:26<06:47,  4.00it/s]

{'epoch': 0, 'iter': 100, 'avg_loss': 10.723172962075413, 'avg_acc': 49.95358910891089, 'loss': 10.446120262145996}


EP_train:0:   6%|| 111/1732 [00:29<06:53,  3.92it/s]

{'epoch': 0, 'iter': 110, 'avg_loss': 10.698180112752828, 'avg_acc': 49.84515765765766, 'loss': 10.3860502243042}


EP_train:0:   7%|| 121/1732 [00:31<06:42,  4.00it/s]

{'epoch': 0, 'iter': 120, 'avg_loss': 10.672383198068161, 'avg_acc': 49.799845041322314, 'loss': 10.370305061340332}


EP_train:0:   8%|| 131/1732 [00:34<06:38,  4.01it/s]

{'epoch': 0, 'iter': 130, 'avg_loss': 10.645600690186479, 'avg_acc': 49.904580152671755, 'loss': 10.261902809143066}


EP_train:0:   8%|| 141/1732 [00:36<06:34,  4.03it/s]

{'epoch': 0, 'iter': 140, 'avg_loss': 10.620538995621052, 'avg_acc': 50.03324468085106, 'loss': 10.276094436645508}


EP_train:0:   9%|| 151/1732 [00:39<06:17,  4.19it/s]

{'epoch': 0, 'iter': 150, 'avg_loss': 10.595666550642607, 'avg_acc': 49.974130794701985, 'loss': 10.216792106628418}


EP_train:0:   9%|| 161/1732 [00:41<06:27,  4.06it/s]

{'epoch': 0, 'iter': 160, 'avg_loss': 10.572175760446868, 'avg_acc': 49.9902950310559, 'loss': 10.18545913696289}


EP_train:0:  10%|| 171/1732 [00:44<06:27,  4.02it/s]

{'epoch': 0, 'iter': 170, 'avg_loss': 10.548281725387127, 'avg_acc': 49.99543128654971, 'loss': 10.136054992675781}


EP_train:0:  10%|| 181/1732 [00:46<06:33,  3.94it/s]

{'epoch': 0, 'iter': 180, 'avg_loss': 10.524227347821821, 'avg_acc': 50.094958563535904, 'loss': 10.158782958984375}


EP_train:0:  11%|| 191/1732 [00:49<06:50,  3.76it/s]

{'epoch': 0, 'iter': 190, 'avg_loss': 10.50132742976643, 'avg_acc': 50.049083769633505, 'loss': 10.111011505126953}


EP_train:0:  12%|| 201/1732 [00:51<06:20,  4.03it/s]

{'epoch': 0, 'iter': 200, 'avg_loss': 10.479387667641712, 'avg_acc': 50.077736318407965, 'loss': 10.047221183776855}


EP_train:0:  12%|| 211/1732 [00:54<06:25,  3.94it/s]

{'epoch': 0, 'iter': 210, 'avg_loss': 10.457694207322541, 'avg_acc': 50.03332345971564, 'loss': 10.03204345703125}


EP_train:0:  13%|| 221/1732 [00:56<06:17,  4.00it/s]

{'epoch': 0, 'iter': 220, 'avg_loss': 10.435909465427313, 'avg_acc': 49.96111425339366, 'loss': 9.956625938415527}


EP_train:0:  13%|| 231/1732 [00:59<06:13,  4.02it/s]

{'epoch': 0, 'iter': 230, 'avg_loss': 10.413620531817019, 'avg_acc': 49.969561688311686, 'loss': 9.938676834106445}


EP_train:0:  14%|| 241/1732 [01:01<06:27,  3.85it/s]

{'epoch': 0, 'iter': 240, 'avg_loss': 10.392407088853512, 'avg_acc': 50.00972510373444, 'loss': 9.891407012939453}


EP_train:0:  14%|| 251/1732 [01:04<06:12,  3.98it/s]

{'epoch': 0, 'iter': 250, 'avg_loss': 10.370243764018632, 'avg_acc': 50.040463147410364, 'loss': 9.801039695739746}


EP_train:0:  15%|| 261/1732 [01:06<06:19,  3.87it/s]

{'epoch': 0, 'iter': 260, 'avg_loss': 10.34864777532117, 'avg_acc': 50.08680555555556, 'loss': 9.741663932800293}


EP_train:0:  16%|| 271/1732 [01:09<06:09,  3.95it/s]

{'epoch': 0, 'iter': 270, 'avg_loss': 10.326367135417417, 'avg_acc': 50.07495387453874, 'loss': 9.695479393005371}


EP_train:0:  16%|| 281/1732 [01:11<05:57,  4.06it/s]

{'epoch': 0, 'iter': 280, 'avg_loss': 10.302847278499943, 'avg_acc': 50.03336298932385, 'loss': 9.644225120544434}


EP_train:0:  17%|| 291/1732 [01:14<05:54,  4.06it/s]

{'epoch': 0, 'iter': 290, 'avg_loss': 10.279974291824393, 'avg_acc': 49.99731529209622, 'loss': 9.568742752075195}


EP_train:0:  17%|| 301/1732 [01:16<06:08,  3.89it/s]

{'epoch': 0, 'iter': 300, 'avg_loss': 10.256881751887427, 'avg_acc': 50.00519102990033, 'loss': 9.546987533569336}


EP_train:0:  18%|| 311/1732 [01:19<06:12,  3.81it/s]

{'epoch': 0, 'iter': 310, 'avg_loss': 10.233888294919128, 'avg_acc': 49.989951768488744, 'loss': 9.492486953735352}


EP_train:0:  19%|| 321/1732 [01:22<06:06,  3.85it/s]

{'epoch': 0, 'iter': 320, 'avg_loss': 10.20970635102174, 'avg_acc': 49.900214174454824, 'loss': 9.452032089233398}


EP_train:0:  19%|| 331/1732 [01:24<05:58,  3.91it/s]

{'epoch': 0, 'iter': 330, 'avg_loss': 10.185199028775772, 'avg_acc': 49.846582326283986, 'loss': 9.359602928161621}


EP_train:0:  20%|| 341/1732 [01:27<05:44,  4.04it/s]

{'epoch': 0, 'iter': 340, 'avg_loss': 10.161056096253157, 'avg_acc': 49.84879032258064, 'loss': 9.348237991333008}


EP_train:0:  20%|| 351/1732 [01:29<05:46,  3.99it/s]

{'epoch': 0, 'iter': 350, 'avg_loss': 10.13718383332603, 'avg_acc': 49.85754985754986, 'loss': 9.286320686340332}


EP_train:0:  21%|| 361/1732 [01:32<05:49,  3.92it/s]

{'epoch': 0, 'iter': 360, 'avg_loss': 10.11197671071314, 'avg_acc': 49.874480609418285, 'loss': 9.189637184143066}


EP_train:0:  21%|| 371/1732 [01:34<05:43,  3.96it/s]

{'epoch': 0, 'iter': 370, 'avg_loss': 10.086940184436397, 'avg_acc': 49.84417115902965, 'loss': 9.205728530883789}


EP_train:0:  22%|| 381/1732 [01:37<05:38,  3.99it/s]

{'epoch': 0, 'iter': 380, 'avg_loss': 10.062340628756626, 'avg_acc': 49.82365485564305, 'loss': 9.158561706542969}


EP_train:0:  23%|| 391/1732 [01:39<05:41,  3.93it/s]

{'epoch': 0, 'iter': 390, 'avg_loss': 10.03772706814739, 'avg_acc': 49.78021099744245, 'loss': 9.146259307861328}


EP_train:0:  23%|| 401/1732 [01:42<05:30,  4.03it/s]

{'epoch': 0, 'iter': 400, 'avg_loss': 10.012900925633913, 'avg_acc': 49.76036471321696, 'loss': 9.087571144104004}


EP_train:0:  24%|| 411/1732 [01:44<05:20,  4.12it/s]

{'epoch': 0, 'iter': 410, 'avg_loss': 9.98676299584753, 'avg_acc': 49.78520377128954, 'loss': 8.97130012512207}


EP_train:0:  24%|| 421/1732 [01:47<05:28,  3.99it/s]

{'epoch': 0, 'iter': 420, 'avg_loss': 9.961014949227739, 'avg_acc': 49.79030581947744, 'loss': 8.861788749694824}


EP_train:0:  25%|| 431/1732 [01:49<05:22,  4.03it/s]

{'epoch': 0, 'iter': 430, 'avg_loss': 9.93586506367836, 'avg_acc': 49.831424013921115, 'loss': 8.799809455871582}


EP_train:0:  25%|| 441/1732 [01:52<05:25,  3.97it/s]

{'epoch': 0, 'iter': 440, 'avg_loss': 9.910720989547348, 'avg_acc': 49.81753117913833, 'loss': 8.757440567016602}


EP_train:0:  26%|| 451/1732 [01:54<05:15,  4.05it/s]

{'epoch': 0, 'iter': 450, 'avg_loss': 9.885374900242706, 'avg_acc': 49.85275776053215, 'loss': 8.779256820678711}


EP_train:0:  27%|| 461/1732 [01:57<05:11,  4.09it/s]

{'epoch': 0, 'iter': 460, 'avg_loss': 9.860567860386118, 'avg_acc': 49.8949295010846, 'loss': 8.669825553894043}


EP_train:0:  27%|| 471/1732 [01:59<05:17,  3.97it/s]

{'epoch': 0, 'iter': 470, 'avg_loss': 9.835647609330033, 'avg_acc': 49.917064755838645, 'loss': 8.7075834274292}


EP_train:0:  28%|| 481/1732 [02:02<05:18,  3.92it/s]

{'epoch': 0, 'iter': 480, 'avg_loss': 9.811093711059952, 'avg_acc': 49.90579521829522, 'loss': 8.618895530700684}


EP_train:0:  28%|| 491/1732 [02:04<05:18,  3.90it/s]

{'epoch': 0, 'iter': 490, 'avg_loss': 9.786379988955868, 'avg_acc': 49.90453156822811, 'loss': 8.587145805358887}


EP_train:0:  29%|| 501/1732 [02:07<05:15,  3.91it/s]

{'epoch': 0, 'iter': 500, 'avg_loss': 9.762040735956676, 'avg_acc': 49.90799650698603, 'loss': 8.546481132507324}


EP_train:0:  30%|| 511/1732 [02:09<05:08,  3.96it/s]

{'epoch': 0, 'iter': 510, 'avg_loss': 9.73757607036374, 'avg_acc': 49.91285469667319, 'loss': 8.507669448852539}


EP_train:0:  30%|| 521/1732 [02:12<05:12,  3.87it/s]

{'epoch': 0, 'iter': 520, 'avg_loss': 9.712006740972772, 'avg_acc': 49.83805182341651, 'loss': 8.407122611999512}


EP_train:0:  31%|| 531/1732 [02:14<04:54,  4.07it/s]

{'epoch': 0, 'iter': 530, 'avg_loss': 9.686387590095821, 'avg_acc': 49.920550847457626, 'loss': 8.351553916931152}


EP_train:0:  31%|| 541/1732 [02:17<04:52,  4.07it/s]

{'epoch': 0, 'iter': 540, 'avg_loss': 9.662198800035855, 'avg_acc': 49.97400646950093, 'loss': 8.335844993591309}


EP_train:0:  32%|| 551/1732 [02:19<04:47,  4.11it/s]

{'epoch': 0, 'iter': 550, 'avg_loss': 9.63757852807019, 'avg_acc': 49.99432849364791, 'loss': 8.146554946899414}


EP_train:0:  32%|| 561/1732 [02:22<05:02,  3.87it/s]

{'epoch': 0, 'iter': 560, 'avg_loss': 9.612758988365133, 'avg_acc': 49.97632575757576, 'loss': 8.105886459350586}


EP_train:0:  33%|| 571/1732 [02:24<05:03,  3.83it/s]

{'epoch': 0, 'iter': 570, 'avg_loss': 9.588862546062302, 'avg_acc': 49.9425350262697, 'loss': 8.256758689880371}


EP_train:0:  34%|| 581/1732 [02:27<04:51,  3.94it/s]

{'epoch': 0, 'iter': 580, 'avg_loss': 9.565645219536943, 'avg_acc': 49.96234939759036, 'loss': 8.181739807128906}


EP_train:0:  34%|| 591/1732 [02:29<04:35,  4.13it/s]

{'epoch': 0, 'iter': 590, 'avg_loss': 9.541344804812203, 'avg_acc': 49.98281514382403, 'loss': 8.063505172729492}


EP_train:0:  35%|| 601/1732 [02:32<04:54,  3.84it/s]

{'epoch': 0, 'iter': 600, 'avg_loss': 9.517279214747932, 'avg_acc': 49.98570091514143, 'loss': 8.050545692443848}


EP_train:0:  35%|| 611/1732 [02:35<04:48,  3.88it/s]

{'epoch': 0, 'iter': 610, 'avg_loss': 9.492827822846008, 'avg_acc': 49.933510638297875, 'loss': 8.0579195022583}


EP_train:0:  36%|| 621/1732 [02:37<04:38,  3.98it/s]

{'epoch': 0, 'iter': 620, 'avg_loss': 9.468425483519328, 'avg_acc': 49.95345209339774, 'loss': 7.973042011260986}


EP_train:0:  36%|| 631/1732 [02:40<04:47,  3.82it/s]

{'epoch': 0, 'iter': 630, 'avg_loss': 9.443858647686556, 'avg_acc': 49.95171354992076, 'loss': 7.873818397521973}


EP_train:0:  37%|| 641/1732 [02:42<04:30,  4.03it/s]

{'epoch': 0, 'iter': 640, 'avg_loss': 9.418672029760065, 'avg_acc': 49.941497659906396, 'loss': 7.775134086608887}


EP_train:0:  38%|| 651/1732 [02:45<04:34,  3.94it/s]

{'epoch': 0, 'iter': 650, 'avg_loss': 9.394265701510756, 'avg_acc': 49.918394777265746, 'loss': 7.79007625579834}


EP_train:0:  38%|| 661/1732 [02:47<04:37,  3.86it/s]

{'epoch': 0, 'iter': 660, 'avg_loss': 9.370286963891333, 'avg_acc': 49.919629349470505, 'loss': 7.825533866882324}


EP_train:0:  39%|| 671/1732 [02:50<04:27,  3.96it/s]

{'epoch': 0, 'iter': 670, 'avg_loss': 9.345527378173237, 'avg_acc': 49.911512667660205, 'loss': 7.7185893058776855}


EP_train:0:  39%|| 681/1732 [02:52<04:23,  3.98it/s]

{'epoch': 0, 'iter': 680, 'avg_loss': 9.32065171887346, 'avg_acc': 49.87724853157122, 'loss': 7.559593677520752}


EP_train:0:  40%|| 691/1732 [02:55<04:24,  3.94it/s]

{'epoch': 0, 'iter': 690, 'avg_loss': 9.2958626885145, 'avg_acc': 49.87789435600579, 'loss': 7.699054718017578}


EP_train:0:  40%|| 701/1732 [02:57<04:22,  3.92it/s]

{'epoch': 0, 'iter': 700, 'avg_loss': 9.27086821024156, 'avg_acc': 49.86514800285307, 'loss': 7.588382244110107}


EP_train:0:  41%|| 711/1732 [03:00<04:20,  3.91it/s]

{'epoch': 0, 'iter': 710, 'avg_loss': 9.245951402371778, 'avg_acc': 49.856056610407876, 'loss': 7.705009937286377}


EP_train:0:  42%|| 721/1732 [03:02<04:19,  3.90it/s]

{'epoch': 0, 'iter': 720, 'avg_loss': 9.221375934955317, 'avg_acc': 49.85696948682386, 'loss': 7.580085754394531}


EP_train:0:  42%|| 731/1732 [03:05<04:14,  3.93it/s]

{'epoch': 0, 'iter': 730, 'avg_loss': 9.195904023996293, 'avg_acc': 49.851444938440494, 'loss': 7.371208190917969}


EP_train:0:  43%|| 741/1732 [03:07<04:03,  4.07it/s]

{'epoch': 0, 'iter': 740, 'avg_loss': 9.171564699345433, 'avg_acc': 49.863992914979754, 'loss': 7.544313907623291}


EP_train:0:  43%|| 751/1732 [03:10<04:10,  3.92it/s]

{'epoch': 0, 'iter': 750, 'avg_loss': 9.147520245629526, 'avg_acc': 49.85852197070572, 'loss': 7.307809829711914}


EP_train:0:  44%|| 761/1732 [03:12<03:59,  4.05it/s]

{'epoch': 0, 'iter': 760, 'avg_loss': 9.12247441162731, 'avg_acc': 49.85524802890933, 'loss': 7.249027729034424}


EP_train:0:  45%|| 771/1732 [03:15<03:59,  4.01it/s]

{'epoch': 0, 'iter': 770, 'avg_loss': 9.097367618798282, 'avg_acc': 49.852059014267184, 'loss': 7.308106422424316}


EP_train:0:  45%|| 781/1732 [03:17<04:02,  3.92it/s]

{'epoch': 0, 'iter': 780, 'avg_loss': 9.072371446826851, 'avg_acc': 49.81394046094751, 'loss': 7.104928016662598}


EP_train:0:  46%|| 791/1732 [03:20<03:52,  4.05it/s]

{'epoch': 0, 'iter': 790, 'avg_loss': 9.047206297535661, 'avg_acc': 49.81925568900126, 'loss': 7.0086750984191895}


EP_train:0:  46%|| 801/1732 [03:22<03:51,  4.02it/s]

{'epoch': 0, 'iter': 800, 'avg_loss': 9.022860637169503, 'avg_acc': 49.84394506866417, 'loss': 7.004143714904785}


EP_train:0:  47%|| 811/1732 [03:25<03:55,  3.91it/s]

{'epoch': 0, 'iter': 810, 'avg_loss': 8.998642069490858, 'avg_acc': 49.8497225647349, 'loss': 6.962000370025635}


EP_train:0:  47%|| 821/1732 [03:27<03:47,  4.00it/s]

{'epoch': 0, 'iter': 820, 'avg_loss': 8.97400939479089, 'avg_acc': 49.84584348355664, 'loss': 6.915220260620117}


EP_train:0:  48%|| 831/1732 [03:30<03:43,  4.03it/s]

{'epoch': 0, 'iter': 830, 'avg_loss': 8.95013524192048, 'avg_acc': 49.853339350180505, 'loss': 7.013818740844727}


EP_train:0:  49%|| 841/1732 [03:32<03:42,  4.00it/s]

{'epoch': 0, 'iter': 840, 'avg_loss': 8.925784933155981, 'avg_acc': 49.85601218787158, 'loss': 6.962301254272461}


EP_train:0:  49%|| 851/1732 [03:35<03:49,  3.84it/s]

{'epoch': 0, 'iter': 850, 'avg_loss': 8.901989307022543, 'avg_acc': 49.86963866039953, 'loss': 6.8448052406311035}


EP_train:0:  50%|| 861/1732 [03:37<03:26,  4.22it/s]

{'epoch': 0, 'iter': 860, 'avg_loss': 8.87779458086942, 'avg_acc': 49.88204123112659, 'loss': 6.753905773162842}


EP_train:0:  50%|| 871/1732 [03:40<03:31,  4.07it/s]

{'epoch': 0, 'iter': 870, 'avg_loss': 8.853487591245282, 'avg_acc': 49.9228616532721, 'loss': 6.838926792144775}


EP_train:0:  51%|| 881/1732 [03:42<03:31,  4.02it/s]

{'epoch': 0, 'iter': 880, 'avg_loss': 8.829504916400023, 'avg_acc': 49.91398269012486, 'loss': 6.6990556716918945}


EP_train:0:  51%|| 891/1732 [03:45<03:35,  3.90it/s]

{'epoch': 0, 'iter': 890, 'avg_loss': 8.80521056413918, 'avg_acc': 49.91582491582491, 'loss': 6.617224216461182}


EP_train:0:  52%|| 901/1732 [03:47<03:24,  4.07it/s]

{'epoch': 0, 'iter': 900, 'avg_loss': 8.781322792552817, 'avg_acc': 49.908955327413985, 'loss': 6.488630294799805}


EP_train:0:  53%|| 911/1732 [03:50<03:23,  4.03it/s]

{'epoch': 0, 'iter': 910, 'avg_loss': 8.757506525477254, 'avg_acc': 49.903951701427005, 'loss': 6.72423791885376}


EP_train:0:  53%|| 921/1732 [03:52<03:29,  3.87it/s]

{'epoch': 0, 'iter': 920, 'avg_loss': 8.73493476540465, 'avg_acc': 49.903298045602604, 'loss': 6.605984687805176}


EP_train:0:  54%|| 931/1732 [03:55<03:30,  3.81it/s]

{'epoch': 0, 'iter': 930, 'avg_loss': 8.711434553312309, 'avg_acc': 49.8984626745435, 'loss': 6.532863140106201}


EP_train:0:  54%|| 941/1732 [03:57<03:17,  4.00it/s]

{'epoch': 0, 'iter': 940, 'avg_loss': 8.68933296051593, 'avg_acc': 49.90452311370882, 'loss': 6.855358123779297}


EP_train:0:  55%|| 951/1732 [04:00<03:22,  3.86it/s]

{'epoch': 0, 'iter': 950, 'avg_loss': 8.666533954512058, 'avg_acc': 49.89238301787592, 'loss': 6.778086185455322}


EP_train:0:  55%|| 961/1732 [04:02<03:17,  3.90it/s]

{'epoch': 0, 'iter': 960, 'avg_loss': 8.644473187509114, 'avg_acc': 49.88293444328824, 'loss': 6.280248165130615}


EP_train:0:  56%|| 971/1732 [04:05<03:13,  3.93it/s]

{'epoch': 0, 'iter': 970, 'avg_loss': 8.622027213983506, 'avg_acc': 49.8744850669413, 'loss': 6.594397068023682}


EP_train:0:  57%|| 981/1732 [04:07<03:09,  3.96it/s]

{'epoch': 0, 'iter': 980, 'avg_loss': 8.599965600063312, 'avg_acc': 49.89726681957187, 'loss': 6.430147171020508}


EP_train:0:  57%|| 991/1732 [04:10<03:07,  3.96it/s]

{'epoch': 0, 'iter': 990, 'avg_loss': 8.578734472469172, 'avg_acc': 49.90382189707366, 'loss': 6.362616539001465}


EP_train:0:  58%|| 1001/1732 [04:12<02:58,  4.09it/s]

{'epoch': 0, 'iter': 1000, 'avg_loss': 8.556679011105777, 'avg_acc': 49.898538961038966, 'loss': 6.162772178649902}


EP_train:0:  58%|| 1011/1732 [04:15<02:59,  4.02it/s]

{'epoch': 0, 'iter': 1010, 'avg_loss': 8.535512895895394, 'avg_acc': 49.90417903066271, 'loss': 6.4303460121154785}


EP_train:0:  59%|| 1021/1732 [04:17<03:00,  3.93it/s]

{'epoch': 0, 'iter': 1020, 'avg_loss': 8.514905172977578, 'avg_acc': 49.9127693437806, 'loss': 6.444796085357666}


EP_train:0:  60%|| 1031/1732 [04:20<02:51,  4.08it/s]

{'epoch': 0, 'iter': 1030, 'avg_loss': 8.49437588008137, 'avg_acc': 49.90906886517944, 'loss': 6.514904975891113}


EP_train:0:  60%|| 1041/1732 [04:22<02:55,  3.93it/s]

{'epoch': 0, 'iter': 1040, 'avg_loss': 8.473327381581097, 'avg_acc': 49.897184197886645, 'loss': 6.299520969390869}


EP_train:0:  61%|| 1051/1732 [04:25<02:56,  3.86it/s]

{'epoch': 0, 'iter': 1050, 'avg_loss': 8.452557483477097, 'avg_acc': 49.86842887725975, 'loss': 6.027676582336426}


EP_train:0:  61%|| 1061/1732 [04:28<02:49,  3.96it/s]

{'epoch': 0, 'iter': 1060, 'avg_loss': 8.432519530710692, 'avg_acc': 49.86819627709708, 'loss': 6.453990936279297}


EP_train:0:  62%|| 1071/1732 [04:30<02:47,  3.95it/s]

{'epoch': 0, 'iter': 1070, 'avg_loss': 8.41201435172814, 'avg_acc': 49.870885854341736, 'loss': 6.323964595794678}


EP_train:0:  62%|| 1081/1732 [04:33<02:39,  4.08it/s]

{'epoch': 0, 'iter': 1080, 'avg_loss': 8.392255495037887, 'avg_acc': 49.87280296022202, 'loss': 6.335777282714844}


EP_train:0:  63%|| 1091/1732 [04:35<02:41,  3.97it/s]

{'epoch': 0, 'iter': 1090, 'avg_loss': 8.372463558487231, 'avg_acc': 49.85678276810266, 'loss': 6.190007209777832}


EP_train:0:  64%|| 1101/1732 [04:37<02:38,  3.98it/s]

{'epoch': 0, 'iter': 1100, 'avg_loss': 8.35339644299541, 'avg_acc': 49.86588896457766, 'loss': 6.345950126647949}


EP_train:0:  64%|| 1111/1732 [04:40<02:28,  4.18it/s]

{'epoch': 0, 'iter': 1110, 'avg_loss': 8.333659896970762, 'avg_acc': 49.87342484248425, 'loss': 6.248504638671875}


EP_train:0:  65%|| 1121/1732 [04:42<02:35,  3.93it/s]

{'epoch': 0, 'iter': 1120, 'avg_loss': 8.31485935537864, 'avg_acc': 49.86688782337199, 'loss': 6.090513229370117}


EP_train:0:  65%|| 1131/1732 [04:45<02:30,  3.99it/s]

{'epoch': 0, 'iter': 1130, 'avg_loss': 8.296641056354229, 'avg_acc': 49.862538682581786, 'loss': 6.178784370422363}


EP_train:0:  66%|| 1141/1732 [04:47<02:24,  4.10it/s]

{'epoch': 0, 'iter': 1140, 'avg_loss': 8.27831093414534, 'avg_acc': 49.882230499561786, 'loss': 6.206490993499756}


EP_train:0:  66%|| 1151/1732 [04:50<02:36,  3.70it/s]

{'epoch': 0, 'iter': 1150, 'avg_loss': 8.260737765881416, 'avg_acc': 49.88664748045178, 'loss': 6.163495063781738}


EP_train:0:  67%|| 1161/1732 [04:53<02:22,  4.01it/s]

{'epoch': 0, 'iter': 1160, 'avg_loss': 8.243638263705678, 'avg_acc': 49.88762381567614, 'loss': 6.261890888214111}


EP_train:0:  68%|| 1171/1732 [04:55<02:18,  4.06it/s]

{'epoch': 0, 'iter': 1170, 'avg_loss': 8.226288612228903, 'avg_acc': 49.89191929974381, 'loss': 6.272753715515137}


EP_train:0:  68%|| 1181/1732 [04:58<02:18,  3.97it/s]

{'epoch': 0, 'iter': 1180, 'avg_loss': 8.208732430557394, 'avg_acc': 49.89878810330229, 'loss': 6.294406414031982}


EP_train:0:  69%|| 1191/1732 [05:00<02:12,  4.09it/s]

{'epoch': 0, 'iter': 1190, 'avg_loss': 8.191702846315707, 'avg_acc': 49.8996379093199, 'loss': 6.2765092849731445}


EP_train:0:  69%|| 1201/1732 [05:03<02:18,  3.84it/s]

{'epoch': 0, 'iter': 1200, 'avg_loss': 8.174374616910377, 'avg_acc': 49.90242506244796, 'loss': 6.102795124053955}


EP_train:0:  70%|| 1211/1732 [05:05<02:07,  4.10it/s]

{'epoch': 0, 'iter': 1210, 'avg_loss': 8.158407594033454, 'avg_acc': 49.902585672997525, 'loss': 6.295603275299072}


EP_train:0:  70%|| 1221/1732 [05:08<02:15,  3.78it/s]

{'epoch': 0, 'iter': 1220, 'avg_loss': 8.14230120328486, 'avg_acc': 49.8745904995905, 'loss': 6.404106616973877}


EP_train:0:  71%|| 1231/1732 [05:10<02:09,  3.87it/s]

{'epoch': 0, 'iter': 1230, 'avg_loss': 8.12618698713549, 'avg_acc': 49.871166734362305, 'loss': 6.026593208312988}


EP_train:0:  72%|| 1241/1732 [05:13<02:03,  3.98it/s]

{'epoch': 0, 'iter': 1240, 'avg_loss': 8.110652450973424, 'avg_acc': 49.86087328767123, 'loss': 6.18991231918335}


EP_train:0:  72%|| 1251/1732 [05:15<01:58,  4.06it/s]

{'epoch': 0, 'iter': 1250, 'avg_loss': 8.094596545282695, 'avg_acc': 49.83887889688249, 'loss': 6.0659966468811035}


EP_train:0:  73%|| 1261/1732 [05:18<01:53,  4.15it/s]

{'epoch': 0, 'iter': 1260, 'avg_loss': 8.079068374482532, 'avg_acc': 49.84077616970658, 'loss': 6.295452117919922}


EP_train:0:  73%|| 1271/1732 [05:20<01:57,  3.92it/s]

{'epoch': 0, 'iter': 1270, 'avg_loss': 8.064514381504734, 'avg_acc': 49.843258261211645, 'loss': 6.193263053894043}


EP_train:0:  74%|| 1281/1732 [05:23<01:51,  4.03it/s]

{'epoch': 0, 'iter': 1280, 'avg_loss': 8.04872107598858, 'avg_acc': 49.823746096799375, 'loss': 6.113572120666504}


EP_train:0:  75%|| 1291/1732 [05:25<01:50,  4.00it/s]

{'epoch': 0, 'iter': 1290, 'avg_loss': 8.034180412543641, 'avg_acc': 49.80514136328428, 'loss': 6.258183479309082}


EP_train:0:  75%|| 1301/1732 [05:28<01:49,  3.92it/s]

{'epoch': 0, 'iter': 1300, 'avg_loss': 8.01981739700986, 'avg_acc': 49.82765661029977, 'loss': 6.2536773681640625}


EP_train:0:  76%|| 1311/1732 [05:30<01:46,  3.96it/s]

{'epoch': 0, 'iter': 1310, 'avg_loss': 8.005659028467198, 'avg_acc': 49.820032418001524, 'loss': 6.066373825073242}


EP_train:0:  76%|| 1321/1732 [05:33<01:40,  4.08it/s]

{'epoch': 0, 'iter': 1320, 'avg_loss': 7.9914517991024825, 'avg_acc': 49.81962055261166, 'loss': 6.036427021026611}


EP_train:0:  77%|| 1331/1732 [05:35<01:43,  3.88it/s]

{'epoch': 0, 'iter': 1330, 'avg_loss': 7.9771243743122655, 'avg_acc': 49.813932193839214, 'loss': 6.121427536010742}


EP_train:0:  77%|| 1341/1732 [05:38<01:36,  4.06it/s]

{'epoch': 0, 'iter': 1340, 'avg_loss': 7.962941571903442, 'avg_acc': 49.82289336316182, 'loss': 6.0087056159973145}


EP_train:0:  78%|| 1351/1732 [05:40<01:35,  3.98it/s]

{'epoch': 0, 'iter': 1350, 'avg_loss': 7.949036023247957, 'avg_acc': 49.815530162842336, 'loss': 5.939061164855957}


EP_train:0:  79%|| 1361/1732 [05:43<01:36,  3.86it/s]

{'epoch': 0, 'iter': 1360, 'avg_loss': 7.935601857022097, 'avg_acc': 49.809423218221895, 'loss': 6.140861988067627}


EP_train:0:  79%|| 1371/1732 [05:45<01:36,  3.73it/s]

{'epoch': 0, 'iter': 1370, 'avg_loss': 7.922796652144885, 'avg_acc': 49.81252279358132, 'loss': 6.149921417236328}


EP_train:0:  80%|| 1381/1732 [05:48<01:26,  4.04it/s]

{'epoch': 0, 'iter': 1380, 'avg_loss': 7.9098888917905015, 'avg_acc': 49.81105177407675, 'loss': 6.27725887298584}


EP_train:0:  80%|| 1391/1732 [05:50<01:23,  4.06it/s]

{'epoch': 0, 'iter': 1390, 'avg_loss': 7.897265965264963, 'avg_acc': 49.81858824586629, 'loss': 6.1098504066467285}


EP_train:0:  81%|| 1401/1732 [05:53<01:23,  3.97it/s]

{'epoch': 0, 'iter': 1400, 'avg_loss': 7.883966775386355, 'avg_acc': 49.80482690935047, 'loss': 5.9688873291015625}


EP_train:0:  81%|| 1411/1732 [05:55<01:19,  4.03it/s]

{'epoch': 0, 'iter': 1410, 'avg_loss': 7.871164088718947, 'avg_acc': 49.80510276399717, 'loss': 5.997461795806885}


EP_train:0:  82%|| 1421/1732 [05:58<01:16,  4.07it/s]

{'epoch': 0, 'iter': 1420, 'avg_loss': 7.858217616218819, 'avg_acc': 49.80702410274455, 'loss': 6.026021957397461}


EP_train:0:  83%|| 1431/1732 [06:00<01:14,  4.03it/s]

{'epoch': 0, 'iter': 1430, 'avg_loss': 7.8452927641398755, 'avg_acc': 49.814378057302584, 'loss': 5.877905368804932}


EP_train:0:  83%|| 1441/1732 [06:03<01:14,  3.89it/s]

{'epoch': 0, 'iter': 1440, 'avg_loss': 7.833533882350246, 'avg_acc': 49.80699167244968, 'loss': 6.162201881408691}


EP_train:0:  84%|| 1451/1732 [06:05<01:13,  3.84it/s]

{'epoch': 0, 'iter': 1450, 'avg_loss': 7.8211291518234365, 'avg_acc': 49.79863025499655, 'loss': 6.025408744812012}


EP_train:0:  84%|| 1461/1732 [06:08<01:08,  3.98it/s]

{'epoch': 0, 'iter': 1460, 'avg_loss': 7.809371039607931, 'avg_acc': 49.81925906913073, 'loss': 6.362613201141357}


EP_train:0:  85%|| 1471/1732 [06:11<01:06,  3.90it/s]

{'epoch': 0, 'iter': 1470, 'avg_loss': 7.797669028846849, 'avg_acc': 49.8199566621346, 'loss': 6.136240005493164}


EP_train:0:  86%|| 1481/1732 [06:13<01:01,  4.08it/s]

{'epoch': 0, 'iter': 1480, 'avg_loss': 7.785978023302386, 'avg_acc': 49.83910786630655, 'loss': 5.967618465423584}


EP_train:0:  86%|| 1491/1732 [06:15<01:01,  3.95it/s]

{'epoch': 0, 'iter': 1490, 'avg_loss': 7.774749073704004, 'avg_acc': 49.842806841046276, 'loss': 6.0305495262146}


EP_train:0:  87%|| 1501/1732 [06:18<00:55,  4.14it/s]

{'epoch': 0, 'iter': 1500, 'avg_loss': 7.763193294098185, 'avg_acc': 49.84749750166556, 'loss': 6.057958126068115}


EP_train:0:  87%|| 1511/1732 [06:20<00:55,  3.99it/s]

{'epoch': 0, 'iter': 1510, 'avg_loss': 7.752612918674512, 'avg_acc': 49.852126075446726, 'loss': 6.079247951507568}


EP_train:0:  88%|| 1521/1732 [06:23<00:52,  4.03it/s]

{'epoch': 0, 'iter': 1520, 'avg_loss': 7.741167152813592, 'avg_acc': 49.845393655489815, 'loss': 6.143087863922119}


EP_train:0:  88%|| 1531/1732 [06:25<00:50,  3.98it/s]

{'epoch': 0, 'iter': 1530, 'avg_loss': 7.73006383552028, 'avg_acc': 49.840280045721755, 'loss': 6.093453407287598}


EP_train:0:  89%|| 1541/1732 [06:28<00:46,  4.09it/s]

{'epoch': 0, 'iter': 1540, 'avg_loss': 7.719176897052663, 'avg_acc': 49.8519630110318, 'loss': 5.973392009735107}


EP_train:0:  90%|| 1551/1732 [06:30<00:44,  4.08it/s]

{'epoch': 0, 'iter': 1550, 'avg_loss': 7.708517464109577, 'avg_acc': 49.85795454545455, 'loss': 6.103730201721191}


EP_train:0:  90%|| 1561/1732 [06:33<00:43,  3.96it/s]

{'epoch': 0, 'iter': 1560, 'avg_loss': 7.697977494918559, 'avg_acc': 49.85786354900705, 'loss': 6.014070987701416}


EP_train:0:  91%|| 1571/1732 [06:35<00:40,  3.93it/s]

{'epoch': 0, 'iter': 1570, 'avg_loss': 7.687655278636269, 'avg_acc': 49.87617361553151, 'loss': 6.2637224197387695}


EP_train:0:  91%|| 1581/1732 [06:38<00:38,  3.90it/s]

{'epoch': 0, 'iter': 1580, 'avg_loss': 7.676997898651933, 'avg_acc': 49.87794512966477, 'loss': 5.9766621589660645}


EP_train:0:  92%|| 1591/1732 [06:41<00:36,  3.91it/s]

{'epoch': 0, 'iter': 1590, 'avg_loss': 7.66713201048088, 'avg_acc': 49.883622721558766, 'loss': 6.081023693084717}


EP_train:0:  92%|| 1601/1732 [06:43<00:33,  3.93it/s]

{'epoch': 0, 'iter': 1600, 'avg_loss': 7.657480761678721, 'avg_acc': 49.89069331667707, 'loss': 6.109538555145264}


EP_train:0:  93%|| 1611/1732 [06:46<00:30,  3.92it/s]

{'epoch': 0, 'iter': 1610, 'avg_loss': 7.647580933822452, 'avg_acc': 49.886522346368714, 'loss': 5.937328815460205}


EP_train:0:  94%|| 1621/1732 [06:48<00:27,  4.01it/s]

{'epoch': 0, 'iter': 1620, 'avg_loss': 7.637366769933024, 'avg_acc': 49.89686150524368, 'loss': 6.023506164550781}


EP_train:0:  94%|| 1631/1732 [06:51<00:25,  4.00it/s]

{'epoch': 0, 'iter': 1630, 'avg_loss': 7.627376148848355, 'avg_acc': 49.8869558553035, 'loss': 6.068386554718018}


EP_train:0:  95%|| 1641/1732 [06:53<00:23,  3.90it/s]

{'epoch': 0, 'iter': 1640, 'avg_loss': 7.6174840168717575, 'avg_acc': 49.89097730042657, 'loss': 6.051401138305664}


EP_train:0:  95%|| 1651/1732 [06:56<00:19,  4.06it/s]

{'epoch': 0, 'iter': 1650, 'avg_loss': 7.607809262879324, 'avg_acc': 49.89873561477892, 'loss': 6.1475510597229}


EP_train:0:  96%|| 1661/1732 [06:58<00:18,  3.89it/s]

{'epoch': 0, 'iter': 1660, 'avg_loss': 7.598576399592446, 'avg_acc': 49.895582480433475, 'loss': 6.24362850189209}


EP_train:0:  96%|| 1671/1732 [07:01<00:14,  4.09it/s]

{'epoch': 0, 'iter': 1670, 'avg_loss': 7.589632255741396, 'avg_acc': 49.90929832435667, 'loss': 5.87717866897583}


EP_train:0:  97%|| 1681/1732 [07:03<00:12,  4.11it/s]

{'epoch': 0, 'iter': 1680, 'avg_loss': 7.579833668658309, 'avg_acc': 49.90426085663296, 'loss': 5.818462371826172}


EP_train:0:  98%|| 1691/1732 [07:06<00:10,  3.97it/s]

{'epoch': 0, 'iter': 1690, 'avg_loss': 7.570617066129011, 'avg_acc': 49.91037108219988, 'loss': 6.183598518371582}


EP_train:0:  98%|| 1701/1732 [07:08<00:07,  3.88it/s]

{'epoch': 0, 'iter': 1700, 'avg_loss': 7.5618439446751475, 'avg_acc': 49.91457231040564, 'loss': 6.129424571990967}


EP_train:0:  99%|| 1711/1732 [07:11<00:05,  3.99it/s]

{'epoch': 0, 'iter': 1710, 'avg_loss': 7.553257980380122, 'avg_acc': 49.91826782583285, 'loss': 6.094128131866455}


EP_train:0:  99%|| 1721/1732 [07:13<00:02,  3.99it/s]

{'epoch': 0, 'iter': 1720, 'avg_loss': 7.5440268056604625, 'avg_acc': 49.937354735618825, 'loss': 5.982414722442627}


EP_train:0: 100%|| 1732/1732 [07:16<00:00,  3.97it/s]


{'epoch': 0, 'iter': 1730, 'avg_loss': 7.534978277348563, 'avg_acc': 49.93545999422299, 'loss': 6.0526580810546875}
EP0, train:             avg_loss=7.533935430694122,             total_acc=49.932766587250015


EP_train:1:   0%|| 1/1732 [00:00<07:10,  4.02it/s]

{'epoch': 1, 'iter': 0, 'avg_loss': 6.090795516967773, 'avg_acc': 53.90625, 'loss': 6.090795516967773}


EP_train:1:   1%|| 11/1732 [00:02<07:05,  4.04it/s]

{'epoch': 1, 'iter': 10, 'avg_loss': 6.0440991575067695, 'avg_acc': 48.29545454545455, 'loss': 6.13851261138916}


EP_train:1:   1%|| 21/1732 [00:05<07:04,  4.03it/s]

{'epoch': 1, 'iter': 20, 'avg_loss': 6.0548239435468405, 'avg_acc': 48.549107142857146, 'loss': 6.018393039703369}


EP_train:1:   2%|| 31/1732 [00:07<06:52,  4.12it/s]

{'epoch': 1, 'iter': 30, 'avg_loss': 6.069265257927679, 'avg_acc': 48.689516129032256, 'loss': 6.120303153991699}


EP_train:1:   2%|| 41/1732 [00:10<06:57,  4.05it/s]

{'epoch': 1, 'iter': 40, 'avg_loss': 6.062081953374351, 'avg_acc': 48.666158536585364, 'loss': 6.000753402709961}


EP_train:1:   3%|| 51/1732 [00:12<06:49,  4.10it/s]

{'epoch': 1, 'iter': 50, 'avg_loss': 6.050010157566445, 'avg_acc': 49.28002450980392, 'loss': 5.825657367706299}


EP_train:1:   4%|| 61/1732 [00:15<07:06,  3.92it/s]

{'epoch': 1, 'iter': 60, 'avg_loss': 6.052840553346227, 'avg_acc': 49.84631147540984, 'loss': 6.224813461303711}


EP_train:1:   4%|| 71/1732 [00:17<06:47,  4.08it/s]

{'epoch': 1, 'iter': 70, 'avg_loss': 6.048236215618295, 'avg_acc': 49.6368838028169, 'loss': 5.969658851623535}


EP_train:1:   5%|| 81/1732 [00:20<07:01,  3.92it/s]

{'epoch': 1, 'iter': 80, 'avg_loss': 6.045857600223871, 'avg_acc': 49.70100308641975, 'loss': 5.990528583526611}


EP_train:1:   5%|| 91/1732 [00:22<07:00,  3.90it/s]

{'epoch': 1, 'iter': 90, 'avg_loss': 6.042276927403042, 'avg_acc': 49.708104395604394, 'loss': 5.99607515335083}


EP_train:1:   6%|| 101/1732 [00:25<06:47,  4.00it/s]

{'epoch': 1, 'iter': 100, 'avg_loss': 6.040480396535137, 'avg_acc': 49.76794554455445, 'loss': 6.114190101623535}


EP_train:1:   6%|| 111/1732 [00:27<06:52,  3.93it/s]

{'epoch': 1, 'iter': 110, 'avg_loss': 6.038986291971293, 'avg_acc': 49.84515765765766, 'loss': 5.982227802276611}


EP_train:1:   7%|| 121/1732 [00:30<06:48,  3.95it/s]

{'epoch': 1, 'iter': 120, 'avg_loss': 6.038712371479381, 'avg_acc': 49.767561983471076, 'loss': 6.115499496459961}


EP_train:1:   8%|| 131/1732 [00:32<06:49,  3.91it/s]

{'epoch': 1, 'iter': 130, 'avg_loss': 6.035680144797754, 'avg_acc': 49.844942748091604, 'loss': 6.156992435455322}


EP_train:1:   8%|| 141/1732 [00:35<06:51,  3.87it/s]

{'epoch': 1, 'iter': 140, 'avg_loss': 6.037677064855048, 'avg_acc': 49.85593971631206, 'loss': 6.224249839782715}


EP_train:1:   9%|| 151/1732 [00:37<06:25,  4.10it/s]

{'epoch': 1, 'iter': 150, 'avg_loss': 6.040375160065707, 'avg_acc': 49.9948261589404, 'loss': 5.974771499633789}


EP_train:1:   9%|| 161/1732 [00:40<06:28,  4.04it/s]

{'epoch': 1, 'iter': 160, 'avg_loss': 6.040303979601179, 'avg_acc': 50.09219720496895, 'loss': 5.925529479980469}


EP_train:1:  10%|| 171/1732 [00:42<06:20,  4.11it/s]

{'epoch': 1, 'iter': 170, 'avg_loss': 6.041418532879032, 'avg_acc': 50.19645467836257, 'loss': 6.0146894454956055}


EP_train:1:  10%|| 181/1732 [00:45<06:43,  3.85it/s]

{'epoch': 1, 'iter': 180, 'avg_loss': 6.037519175703355, 'avg_acc': 50.0, 'loss': 5.906769752502441}


EP_train:1:  11%|| 191/1732 [00:47<06:33,  3.91it/s]

{'epoch': 1, 'iter': 190, 'avg_loss': 6.0384052785903375, 'avg_acc': 49.93046465968586, 'loss': 6.1203107833862305}


EP_train:1:  12%|| 201/1732 [00:50<06:24,  3.98it/s]

{'epoch': 1, 'iter': 200, 'avg_loss': 6.037369630823088, 'avg_acc': 50.019434079601986, 'loss': 5.916537284851074}


EP_train:1:  12%|| 211/1732 [00:52<06:20,  4.00it/s]

{'epoch': 1, 'iter': 210, 'avg_loss': 6.036508603118607, 'avg_acc': 49.96667654028436, 'loss': 6.0396599769592285}


EP_train:1:  13%|| 221/1732 [00:55<06:22,  3.95it/s]

{'epoch': 1, 'iter': 220, 'avg_loss': 6.033929097706376, 'avg_acc': 49.95050904977376, 'loss': 5.985090255737305}


EP_train:1:  13%|| 231/1732 [00:57<06:24,  3.90it/s]

{'epoch': 1, 'iter': 230, 'avg_loss': 6.036081448261872, 'avg_acc': 50.02367424242424, 'loss': 6.114626884460449}


EP_train:1:  14%|| 241/1732 [01:00<06:09,  4.04it/s]

{'epoch': 1, 'iter': 240, 'avg_loss': 6.033251790090221, 'avg_acc': 49.970824688796675, 'loss': 6.014585018157959}


EP_train:1:  14%|| 251/1732 [01:02<05:54,  4.18it/s]

{'epoch': 1, 'iter': 250, 'avg_loss': 6.034671992419725, 'avg_acc': 49.97509960159363, 'loss': 5.838320732116699}


EP_train:1:  15%|| 261/1732 [01:05<06:13,  3.94it/s]

{'epoch': 1, 'iter': 260, 'avg_loss': 6.033844370494857, 'avg_acc': 49.97605363984675, 'loss': 6.082959175109863}


EP_train:1:  16%|| 271/1732 [01:07<06:13,  3.92it/s]

{'epoch': 1, 'iter': 270, 'avg_loss': 6.033593265772746, 'avg_acc': 49.94810885608857, 'loss': 6.039376258850098}


EP_train:1:  16%|| 281/1732 [01:10<05:58,  4.05it/s]

{'epoch': 1, 'iter': 280, 'avg_loss': 6.034744722563177, 'avg_acc': 49.95273576512456, 'loss': 5.865879058837891}


EP_train:1:  17%|| 291/1732 [01:12<05:58,  4.02it/s]

{'epoch': 1, 'iter': 290, 'avg_loss': 6.036499731319466, 'avg_acc': 49.922143470790374, 'loss': 6.1734089851379395}


EP_train:1:  17%|| 301/1732 [01:15<06:01,  3.96it/s]

{'epoch': 1, 'iter': 300, 'avg_loss': 6.034835449484891, 'avg_acc': 49.870224252491695, 'loss': 5.932614326477051}


EP_train:1:  18%|| 311/1732 [01:17<05:51,  4.04it/s]

{'epoch': 1, 'iter': 310, 'avg_loss': 6.036176341139619, 'avg_acc': 49.85681270096463, 'loss': 6.000868320465088}


EP_train:1:  19%|| 321/1732 [01:20<05:54,  3.98it/s]

{'epoch': 1, 'iter': 320, 'avg_loss': 6.036977825877822, 'avg_acc': 49.91238317757009, 'loss': 5.953888416290283}


EP_train:1:  19%|| 331/1732 [01:22<05:56,  3.93it/s]

{'epoch': 1, 'iter': 330, 'avg_loss': 6.037438320609378, 'avg_acc': 50.00472054380665, 'loss': 5.886790752410889}


EP_train:1:  20%|| 341/1732 [01:25<05:45,  4.03it/s]

{'epoch': 1, 'iter': 340, 'avg_loss': 6.035425636425745, 'avg_acc': 49.95646994134898, 'loss': 5.871550559997559}


EP_train:1:  20%|| 351/1732 [01:27<05:34,  4.13it/s]

{'epoch': 1, 'iter': 350, 'avg_loss': 6.035702755648186, 'avg_acc': 49.933226495726494, 'loss': 6.173605442047119}


EP_train:1:  21%|| 361/1732 [01:30<05:44,  3.98it/s]

{'epoch': 1, 'iter': 360, 'avg_loss': 6.034999161876139, 'avg_acc': 49.941568559556785, 'loss': 6.1601409912109375}


EP_train:1:  21%|| 371/1732 [01:32<05:36,  4.04it/s]

{'epoch': 1, 'iter': 370, 'avg_loss': 6.033570235630251, 'avg_acc': 50.031586927223714, 'loss': 5.7998127937316895}


EP_train:1:  22%|| 381/1732 [01:35<05:40,  3.97it/s]

{'epoch': 1, 'iter': 380, 'avg_loss': 6.033677422781316, 'avg_acc': 49.97744422572178, 'loss': 6.0194501876831055}


EP_train:1:  23%|| 391/1732 [01:37<05:39,  3.95it/s]

{'epoch': 1, 'iter': 390, 'avg_loss': 6.0324642883847135, 'avg_acc': 49.95604219948849, 'loss': 5.978794574737549}


EP_train:1:  23%|| 401/1732 [01:40<05:33,  3.99it/s]

{'epoch': 1, 'iter': 400, 'avg_loss': 6.0335061068546745, 'avg_acc': 49.97856920199501, 'loss': 6.0435791015625}


EP_train:1:  24%|| 411/1732 [01:42<05:37,  3.92it/s]

{'epoch': 1, 'iter': 410, 'avg_loss': 6.035512071456353, 'avg_acc': 49.97909063260341, 'loss': 6.256180286407471}


EP_train:1:  24%|| 421/1732 [01:45<05:41,  3.84it/s]

{'epoch': 1, 'iter': 420, 'avg_loss': 6.036213336817726, 'avg_acc': 50.02783551068883, 'loss': 6.061382293701172}


EP_train:1:  25%|| 431/1732 [01:47<05:21,  4.04it/s]

{'epoch': 1, 'iter': 430, 'avg_loss': 6.036569969283455, 'avg_acc': 49.969185034802784, 'loss': 6.003904819488525}


EP_train:1:  25%|| 441/1732 [01:50<05:30,  3.91it/s]

{'epoch': 1, 'iter': 440, 'avg_loss': 6.03588580112068, 'avg_acc': 49.96811224489796, 'loss': 6.067135334014893}


EP_train:1:  26%|| 451/1732 [01:52<05:30,  3.88it/s]

{'epoch': 1, 'iter': 450, 'avg_loss': 6.035889805288378, 'avg_acc': 49.96189024390244, 'loss': 6.123373985290527}


EP_train:1:  27%|| 461/1732 [01:55<05:20,  3.96it/s]

{'epoch': 1, 'iter': 460, 'avg_loss': 6.035534264984462, 'avg_acc': 49.99322125813449, 'loss': 6.106454849243164}


EP_train:1:  27%|| 471/1732 [01:58<05:20,  3.93it/s]

{'epoch': 1, 'iter': 470, 'avg_loss': 6.036544502161111, 'avg_acc': 49.98673036093418, 'loss': 5.941934108734131}


EP_train:1:  28%|| 481/1732 [02:00<05:27,  3.82it/s]

{'epoch': 1, 'iter': 480, 'avg_loss': 6.034648420409204, 'avg_acc': 49.95614604989605, 'loss': 5.8506178855896}


EP_train:1:  28%|| 491/1732 [02:03<05:11,  3.98it/s]

{'epoch': 1, 'iter': 490, 'avg_loss': 6.034935995906046, 'avg_acc': 49.92998981670061, 'loss': 6.019662380218506}


EP_train:1:  29%|| 501/1732 [02:05<05:12,  3.94it/s]

{'epoch': 1, 'iter': 500, 'avg_loss': 6.033513064393978, 'avg_acc': 49.90955588822355, 'loss': 5.882952690124512}


EP_train:1:  30%|| 511/1732 [02:07<04:56,  4.11it/s]

{'epoch': 1, 'iter': 510, 'avg_loss': 6.032675563939163, 'avg_acc': 49.882277397260275, 'loss': 5.812644004821777}


EP_train:1:  30%|| 521/1732 [02:10<05:10,  3.90it/s]

{'epoch': 1, 'iter': 520, 'avg_loss': 6.033132821264285, 'avg_acc': 49.899532149712094, 'loss': 6.099174499511719}


EP_train:1:  31%|| 531/1732 [02:13<05:01,  3.98it/s]

{'epoch': 1, 'iter': 530, 'avg_loss': 6.03348179635804, 'avg_acc': 49.852871939736346, 'loss': 6.150146007537842}


EP_train:1:  31%|| 541/1732 [02:15<04:57,  4.01it/s]

{'epoch': 1, 'iter': 540, 'avg_loss': 6.033675841615293, 'avg_acc': 49.822377541589645, 'loss': 6.013599872589111}


EP_train:1:  32%|| 551/1732 [02:18<05:16,  3.73it/s]

{'epoch': 1, 'iter': 550, 'avg_loss': 6.033916713105356, 'avg_acc': 49.8114224137931, 'loss': 6.193902492523193}


EP_train:1:  32%|| 561/1732 [02:20<05:07,  3.80it/s]

{'epoch': 1, 'iter': 560, 'avg_loss': 6.034191616078749, 'avg_acc': 49.8147838680927, 'loss': 6.1384382247924805}


EP_train:1:  33%|| 571/1732 [02:23<05:08,  3.77it/s]

{'epoch': 1, 'iter': 570, 'avg_loss': 6.034772616551762, 'avg_acc': 49.83855078809107, 'loss': 6.033595561981201}


EP_train:1:  34%|| 581/1732 [02:25<04:54,  3.91it/s]

{'epoch': 1, 'iter': 580, 'avg_loss': 6.035691063498466, 'avg_acc': 49.8171256454389, 'loss': 6.229894161224365}


EP_train:1:  34%|| 588/1732 [02:27<04:47,  3.98it/s]


KeyboardInterrupt: 
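Across the logged iterations the metrics barely move. `avg_acc` stays between roughly 49.7 and 50.2 from iteration 80 to iteration 580, and `avg_loss` only drifts from about 6.046 to 6.036. The run was then interrupted at 588/1732.

This reading rests on an assumption not confirmed by the log itself. As in common BERT-pytorch-style trainers, `avg_acc` would be the running Next Sentence Prediction accuracy in percent, and `avg_loss` the mean of the per-batch losses seen so far. Under that assumption, a flat value near 50 is exactly what random guessing on balanced IsNext/NotNext labels produces.

The sketch below is a hypothetical re-creation of that running-average bookkeeping. `running_metrics` is an illustrative name, not the lab's trainer code. It feeds in coin-flip predictions to show the chance-level baseline.

```python
import random


def running_metrics(batch_losses, batch_correct, batch_sizes):
    """Yield per-iteration dicts shaped like the log lines above.

    Assumption (hypothetical bookkeeping, not the lab's code):
    - avg_loss = mean of the per-batch losses seen so far
    - avg_acc  = 100 * total correct predictions / total examples seen so far
    """
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    for i, (loss, correct, size) in enumerate(zip(batch_losses, batch_correct, batch_sizes)):
        total_loss += loss
        total_correct += correct
        total_seen += size
        yield {
            "iter": i,
            "avg_loss": total_loss / (i + 1),
            "avg_acc": total_correct * 100.0 / total_seen,
            "loss": loss,
        }


# Simulate a binary task with balanced labels and a model that guesses at random.
random.seed(0)
batch_size = 32
n_batches = 600
labels = [[random.randint(0, 1) for _ in range(batch_size)] for _ in range(n_batches)]
guesses = [[random.randint(0, 1) for _ in range(batch_size)] for _ in range(n_batches)]
correct_per_batch = [
    sum(int(g == y) for g, y in zip(batch_guesses, batch_labels))
    for batch_guesses, batch_labels in zip(guesses, labels)
]
losses = [6.0] * n_batches  # hypothetical constant loss; only the accuracy matters here

last = None
for last in running_metrics(losses, correct_per_batch, [batch_size] * n_batches):
    pass

print(last["iter"], round(last["avg_acc"], 2))  # avg_acc lands close to 50
```

With 19,200 simulated predictions, the running accuracy settles within a fraction of a point of 50. Under the assumption above, the logged `avg_acc` therefore suggests the NSP head had not yet learned beyond chance when the run was stopped.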