In [1]:
# !pip install ipywidgets  # for vscode
# !pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
import pandas as pd
import json

#transformers
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
from transformers import BertModel
from transformers import AutoTokenizer, AutoModel

from sklearn.utils.class_weight import compute_class_weight

In [3]:
tokenizer = AutoTokenizer.from_pretrained("monologg/kobert")
model = AutoModel.from_pretrained("monologg/kobert")

In [4]:
class sentencEmojiDataset(Dataset):
    def __init__(self, directory, tokenizer):
        
        data = pd.read_csv(directory, encoding='UTF-8')
        
        self.tokenizer = tokenizer
        self.sentences = list(data.iloc[:,0])
       
        emojis = list(data.iloc[:,1])
        emojis_unique = list(set(emojis))
        
        self.labels = [emojis_unique.index(i) for i in emojis]
 
        self.labels_dict = {'key': range(len(emojis_unique)), 'value': emojis_unique}
        
    def __getitem__(self, i): #collate 이전 미리 tokenize를 시켜주자
        tokenized = self.tokenizer(str(self.sentences[i]), return_tensors='pt')
        #아래 세 개는 tokenizer가 기본적으로 반환하는 정보. BERT의 input이기도 함
        input_ids = tokenized['input_ids']
        token_type_ids = tokenized['token_type_ids']
        attention_mask = tokenized['attention_mask']
        
        return {'input_ids': input_ids, 'token_type_ids': token_type_ids, 
                'attention_mask': attention_mask, 'label': self.labels[i]}
         
    def __len__(self): #data loader가 필요로 하여 필수적으로 있어야 하는 함수
        return len(self.sentences)

In [5]:
class collate_fn:
    def __init__(self, labels_dict):
        self.num_labels = len(labels_dict)
        
    def __call__(self, batch): #batch는 dataset.getitem의 return 값의 List. eg. [{}, {}. ...]
        #batch내 최대 문장 길이(토큰 개수)를 먼저 구해서 padding할 수 있도록 하기
        batchlen = [sample['input_ids'].size(1) for sample in batch] #tensor값을 반환하기 때문에 1번째 차원의 길이를 구함
        maxlen = max(batchlen)
        input_ids = []
        token_type_ids = []
        attention_mask = []
        #padding: [5, 6] [0, 0,  ...]을 concatenate 하는 방식으로 패딩
        for sample in batch:
            pad_len = maxlen - sample['input_ids'].size(1)
            pad = torch.zeros((1, pad_len), dtype=torch.int)            
            input_ids.append(torch.cat([sample['input_ids'], pad], dim=1))
            token_type_ids.append(torch.cat([sample['token_type_ids'], pad], dim=1))
            attention_mask.append(torch.cat([sample['attention_mask'], pad], dim=1))
        #batch 구성
        input_ids = torch.cat(input_ids, dim=0)
        token_type_ids = torch.cat(token_type_ids, dim=0)
        attention_mask = torch.cat(attention_mask, dim=0)
        
        #one-hot encoding
        #batch 내 라벨을 tensor로 변환
        tensor_label = torch.tensor([sample['label'] for sample in batch])
        
        return input_ids, token_type_ids, attention_mask, tensor_label

In [6]:
df = pd.read_csv('data/twitter_clean.csv', encoding="UTF-8")
print(len(df['y'].value_counts()))
df['y'].value_counts(sort = True).head(10)

30


N    235
😂     80
❤     53
😭     49
😍     32
💕     28
💜     26
🔥     25
👍     24
🥺     24
Name: y, dtype: int64

In [7]:
df['split'] = np.random.randn(df.shape[0], 1)
msk = np.random.rand(len(df)) <= 0.7

train = df[msk]
test = df[~msk]

train.to_csv('data/train.csv', index=False)
test.to_csv('data/test.csv', index=False)

In [8]:
train = sentencEmojiDataset('data/train.csv', tokenizer)
test = sentencEmojiDataset('data/test.csv', tokenizer)

train_collate_fn = collate_fn(train.labels_dict)
test_collate_fn = collate_fn(test.labels_dict)

train_collate_fn

<__main__.collate_fn at 0x2521223ddd8>

In [9]:
# Setting parameters
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 20  
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-3

In [10]:
train_dataloader = DataLoader(train, batch_size=batch_size, collate_fn=train_collate_fn, shuffle = True, drop_last = True)
test_dataloader = DataLoader(test, batch_size=batch_size, collate_fn=test_collate_fn, shuffle = False, drop_last = False)
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x2523e5368d0>

In [11]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        # do not train bert parameters
        for p in self.bert.parameters():
            p.requires_grad = False
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def forward(self, input_ids, token_type_ids, attention_mask):
        #eval: drop out 중지, batch norm 고정과 같이 evaluation으로 모델 변경
        self.bert.eval()
        #gradient 계산을 중지
        with torch.no_grad():
            x = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output
#         x = self.dropout(pooler)
        return self.classifier(x)

In [12]:
label = list(set(list(df.iloc[:,1])))

In [13]:
model = BERTClassifier(model,  dr_rate=0.5, num_classes = len(label))

In [14]:
#optimizer와 schedule 설정
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [15]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
# optimizer = AdamW(model.parameters(), lr=learning_rate)

In [16]:
np.unique(train.labels)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])

In [17]:
train.labels[:20]

[12, 3, 12, 12, 12, 12, 19, 5, 3, 12, 14, 12, 3, 9, 12, 3, 12, 16, 14, 12]

In [18]:
#Class Imbalance 문제 해결을 위한 weighted cross entropy 
class_weights = compute_class_weight(class_weight = 'balanced', classes = np.unique(train.labels), y = train.labels)
class_weights = torch.tensor(class_weights, dtype=torch.float)
print(class_weights) #([1.0000, 1.0000, 4.0000, 1.0000, 0.5714])

tensor([1.4125, 1.4125, 2.8250, 0.3767, 2.0545, 1.4125, 1.6143, 3.2286, 1.1895,
        1.7385, 1.5067, 1.5067, 0.1361, 2.0545, 0.6278, 1.3294, 1.3294, 1.1895,
        0.6848, 1.2556, 1.1300, 1.0762, 1.8833, 1.5067, 1.7385, 1.4125, 1.4125,
        1.8833, 1.5067, 2.0545])


In [19]:
train.labels_dict

{'key': range(0, 30),
 'value': ['✨',
  '💙',
  '🐹',
  '😂',
  '🤤',
  '🥰',
  '😊',
  '✌',
  '🙏',
  '😁',
  '💖',
  '❣',
  'N',
  '👉',
  '❤',
  '💕',
  '🔥',
  '👍',
  '😭',
  '🥺',
  '😍',
  '💜',
  '💯',
  '🤭',
  '💚',
  '💗',
  '😎',
  '😘',
  '😆',
  '😑']}

In [20]:
loss_fn = nn.CrossEntropyLoss(weight = class_weights, reduction = 'mean') 

In [21]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [22]:
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [23]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [24]:
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    loss_sum = 0
    model.train()
    for batch_id, (input_ids, token_type_ids, attention_mask, tensor_label) in enumerate(train_dataloader):
        optimizer.zero_grad()
        
        out = model(input_ids, token_type_ids, attention_mask)
        loss = loss_fn(out, tensor_label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
#         scheduler.step()  # Update learning rate schedule
        batch_acc = calc_accuracy(out, tensor_label)
        train_acc += batch_acc
        loss_sum += loss.data.cpu().numpy()
        #f batch_id % log_interval == 0:
        print("epoch {} batch id {}/{} loss {} train acc {}".format(e+1, batch_id+1, len(train_dataloader), loss.data.cpu().numpy(), batch_acc))
    print("epoch {} train acc {} loss mean {}".format(e+1, train_acc / (batch_id+1), loss_sum / len(train_dataloader)))
    model.eval()
    with torch.no_grad():
        for batch_id, (input_ids, token_type_ids, attention_mask, tensor_label) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
            out = model(input_ids, token_type_ids, attention_mask)
            test_acc += calc_accuracy(out, tensor_label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

epoch 1 batch id 1/10 loss 3.4474871158599854 train acc 0.0
epoch 1 batch id 2/10 loss 3.432135581970215 train acc 0.046875
epoch 1 batch id 3/10 loss 3.4682517051696777 train acc 0.0
epoch 1 batch id 4/10 loss 3.402585983276367 train acc 0.046875
epoch 1 batch id 5/10 loss 3.46115779876709 train acc 0.03125
epoch 1 batch id 6/10 loss 3.402958393096924 train acc 0.09375
epoch 1 batch id 7/10 loss 3.4679598808288574 train acc 0.015625
epoch 1 batch id 8/10 loss 3.4406986236572266 train acc 0.015625
epoch 1 batch id 9/10 loss 3.43851375579834 train acc 0.046875
epoch 1 batch id 10/10 loss 3.4229142665863037 train acc 0.0
epoch 1 train acc 0.0296875 loss mean 3.4384663105010986


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.25s/it]


epoch 1 test acc 0.021875
epoch 2 batch id 1/10 loss 3.4318037033081055 train acc 0.015625
epoch 2 batch id 2/10 loss 3.410374164581299 train acc 0.015625
epoch 2 batch id 3/10 loss 3.4567105770111084 train acc 0.046875
epoch 2 batch id 4/10 loss 3.468426465988159 train acc 0.0625
epoch 2 batch id 5/10 loss 3.4304919242858887 train acc 0.03125
epoch 2 batch id 6/10 loss 3.4689078330993652 train acc 0.015625
epoch 2 batch id 7/10 loss 3.449918270111084 train acc 0.046875
epoch 2 batch id 8/10 loss 3.4582951068878174 train acc 0.03125
epoch 2 batch id 9/10 loss 3.437692880630493 train acc 0.03125
epoch 2 batch id 10/10 loss 3.406301736831665 train acc 0.03125
epoch 2 train acc 0.0328125 loss mean 3.4418922662734985


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.56s/it]


epoch 2 test acc 0.021875
epoch 3 batch id 1/10 loss 3.4071309566497803 train acc 0.015625
epoch 3 batch id 2/10 loss 3.356698513031006 train acc 0.015625
epoch 3 batch id 3/10 loss 3.414775848388672 train acc 0.0625
epoch 3 batch id 4/10 loss 3.523465633392334 train acc 0.03125
epoch 3 batch id 5/10 loss 3.525902032852173 train acc 0.0
epoch 3 batch id 6/10 loss 3.4190151691436768 train acc 0.078125
epoch 3 batch id 7/10 loss 3.4264016151428223 train acc 0.015625
epoch 3 batch id 8/10 loss 3.4083025455474854 train acc 0.03125
epoch 3 batch id 9/10 loss 3.4248745441436768 train acc 0.03125
epoch 3 batch id 10/10 loss 3.4372379779815674 train acc 0.03125
epoch 3 train acc 0.03125 loss mean 3.4343804836273195


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.26s/it]


epoch 3 test acc 0.021875
epoch 4 batch id 1/10 loss 3.4032931327819824 train acc 0.046875
epoch 4 batch id 2/10 loss 3.4954307079315186 train acc 0.015625
epoch 4 batch id 3/10 loss 3.37654447555542 train acc 0.078125
epoch 4 batch id 4/10 loss 3.5007572174072266 train acc 0.015625
epoch 4 batch id 5/10 loss 3.4747579097747803 train acc 0.015625
epoch 4 batch id 6/10 loss 3.4277358055114746 train acc 0.015625
epoch 4 batch id 7/10 loss 3.428326368331909 train acc 0.015625
epoch 4 batch id 8/10 loss 3.4422998428344727 train acc 0.015625
epoch 4 batch id 9/10 loss 3.4326870441436768 train acc 0.03125
epoch 4 batch id 10/10 loss 3.477846384048462 train acc 0.046875
epoch 4 train acc 0.0296875 loss mean 3.4459678888320924


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.25s/it]


epoch 4 test acc 0.021875
epoch 5 batch id 1/10 loss 3.4289050102233887 train acc 0.0
epoch 5 batch id 2/10 loss 3.460556745529175 train acc 0.03125
epoch 5 batch id 3/10 loss 3.4094104766845703 train acc 0.03125
epoch 5 batch id 4/10 loss 3.4069902896881104 train acc 0.015625
epoch 5 batch id 5/10 loss 3.4637951850891113 train acc 0.046875
epoch 5 batch id 6/10 loss 3.435253381729126 train acc 0.03125
epoch 5 batch id 7/10 loss 3.4508233070373535 train acc 0.03125
epoch 5 batch id 8/10 loss 3.4659371376037598 train acc 0.09375
epoch 5 batch id 9/10 loss 3.4560434818267822 train acc 0.03125
epoch 5 batch id 10/10 loss 3.4485247135162354 train acc 0.015625
epoch 5 train acc 0.0328125 loss mean 3.4426239728927612


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.26s/it]


epoch 5 test acc 0.021875
epoch 6 batch id 1/10 loss 3.3941690921783447 train acc 0.0625
epoch 6 batch id 2/10 loss 3.393348217010498 train acc 0.0625
epoch 6 batch id 3/10 loss 3.433828830718994 train acc 0.046875
epoch 6 batch id 4/10 loss 3.399130344390869 train acc 0.0
epoch 6 batch id 5/10 loss 3.4703218936920166 train acc 0.015625
epoch 6 batch id 6/10 loss 3.4520411491394043 train acc 0.0
epoch 6 batch id 7/10 loss 3.4359452724456787 train acc 0.015625
epoch 6 batch id 8/10 loss 3.474165678024292 train acc 0.0625
epoch 6 batch id 9/10 loss 3.43235182762146 train acc 0.0625
epoch 6 batch id 10/10 loss 3.501999855041504 train acc 0.0
epoch 6 train acc 0.0328125 loss mean 3.4387302160263062


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.42s/it]


epoch 6 test acc 0.021875
epoch 7 batch id 1/10 loss 3.4459872245788574 train acc 0.046875
epoch 7 batch id 2/10 loss 3.4314937591552734 train acc 0.015625
epoch 7 batch id 3/10 loss 3.449049472808838 train acc 0.03125
epoch 7 batch id 4/10 loss 3.4327011108398438 train acc 0.03125
epoch 7 batch id 5/10 loss 3.2939138412475586 train acc 0.046875
epoch 7 batch id 6/10 loss 3.4116458892822266 train acc 0.0
epoch 7 batch id 7/10 loss 3.5045902729034424 train acc 0.03125
epoch 7 batch id 8/10 loss 3.470829963684082 train acc 0.03125
epoch 7 batch id 9/10 loss 3.449956178665161 train acc 0.015625
epoch 7 batch id 10/10 loss 3.4839940071105957 train acc 0.0625
epoch 7 train acc 0.03125 loss mean 3.437416172027588


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.60s/it]


epoch 7 test acc 0.021875
epoch 8 batch id 1/10 loss 3.4305288791656494 train acc 0.03125
epoch 8 batch id 2/10 loss 3.413269281387329 train acc 0.078125
epoch 8 batch id 3/10 loss 3.4119839668273926 train acc 0.046875
epoch 8 batch id 4/10 loss 3.487715005874634 train acc 0.015625
epoch 8 batch id 5/10 loss 3.4536354541778564 train acc 0.0
epoch 8 batch id 6/10 loss 3.4127960205078125 train acc 0.046875
epoch 8 batch id 7/10 loss 3.446495771408081 train acc 0.03125
epoch 8 batch id 8/10 loss 3.4273464679718018 train acc 0.03125
epoch 8 batch id 9/10 loss 3.437107801437378 train acc 0.03125
epoch 8 batch id 10/10 loss 3.4712677001953125 train acc 0.0
epoch 8 train acc 0.03125 loss mean 3.4392146348953245


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.57s/it]


epoch 8 test acc 0.021875
epoch 9 batch id 1/10 loss 3.404717206954956 train acc 0.0
epoch 9 batch id 2/10 loss 3.483940601348877 train acc 0.03125
epoch 9 batch id 3/10 loss 3.4061059951782227 train acc 0.09375
epoch 9 batch id 4/10 loss 3.4476444721221924 train acc 0.015625
epoch 9 batch id 5/10 loss 3.4086928367614746 train acc 0.03125
epoch 9 batch id 6/10 loss 3.422592878341675 train acc 0.03125
epoch 9 batch id 7/10 loss 3.4889309406280518 train acc 0.015625
epoch 9 batch id 8/10 loss 3.397414207458496 train acc 0.03125
epoch 9 batch id 9/10 loss 3.4219415187835693 train acc 0.078125
epoch 9 batch id 10/10 loss 3.489124298095703 train acc 0.0
epoch 9 train acc 0.0328125 loss mean 3.437110495567322


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.31s/it]


epoch 9 test acc 0.021875
epoch 10 batch id 1/10 loss 3.358760118484497 train acc 0.015625
epoch 10 batch id 2/10 loss 3.4491844177246094 train acc 0.03125
epoch 10 batch id 3/10 loss 3.482325553894043 train acc 0.015625
epoch 10 batch id 4/10 loss 3.4741363525390625 train acc 0.0
epoch 10 batch id 5/10 loss 3.428377151489258 train acc 0.015625
epoch 10 batch id 6/10 loss 3.4620094299316406 train acc 0.03125
epoch 10 batch id 7/10 loss 3.43325138092041 train acc 0.046875
epoch 10 batch id 8/10 loss 3.406557559967041 train acc 0.109375
epoch 10 batch id 9/10 loss 3.4568490982055664 train acc 0.015625
epoch 10 batch id 10/10 loss 3.457197904586792 train acc 0.046875
epoch 10 train acc 0.0328125 loss mean 3.440864896774292


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.29s/it]


epoch 10 test acc 0.021875
epoch 11 batch id 1/10 loss 3.3790838718414307 train acc 0.03125
epoch 11 batch id 2/10 loss 3.4343838691711426 train acc 0.015625
epoch 11 batch id 3/10 loss 3.4959278106689453 train acc 0.0
epoch 11 batch id 4/10 loss 3.4279093742370605 train acc 0.046875
epoch 11 batch id 5/10 loss 3.4928133487701416 train acc 0.0625
epoch 11 batch id 6/10 loss 3.4173834323883057 train acc 0.03125
epoch 11 batch id 7/10 loss 3.433011054992676 train acc 0.046875
epoch 11 batch id 8/10 loss 3.473184585571289 train acc 0.03125
epoch 11 batch id 9/10 loss 3.4321813583374023 train acc 0.015625
epoch 11 batch id 10/10 loss 3.4340620040893555 train acc 0.03125
epoch 11 train acc 0.03125 loss mean 3.441994071006775


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.45s/it]


epoch 11 test acc 0.021875
epoch 12 batch id 1/10 loss 3.3819220066070557 train acc 0.03125
epoch 12 batch id 2/10 loss 3.467395067214966 train acc 0.015625
epoch 12 batch id 3/10 loss 3.4469618797302246 train acc 0.03125
epoch 12 batch id 4/10 loss 3.4211552143096924 train acc 0.0
epoch 12 batch id 5/10 loss 3.5153722763061523 train acc 0.03125
epoch 12 batch id 6/10 loss 3.453728675842285 train acc 0.03125
epoch 12 batch id 7/10 loss 3.3999712467193604 train acc 0.015625
epoch 12 batch id 8/10 loss 3.399916887283325 train acc 0.09375
epoch 12 batch id 9/10 loss 3.479884147644043 train acc 0.046875
epoch 12 batch id 10/10 loss 3.4142298698425293 train acc 0.015625
epoch 12 train acc 0.03125 loss mean 3.4380537271499634


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.42s/it]


epoch 12 test acc 0.021875
epoch 13 batch id 1/10 loss 3.4167866706848145 train acc 0.03125
epoch 13 batch id 2/10 loss 3.422137498855591 train acc 0.0625
epoch 13 batch id 3/10 loss 3.407938241958618 train acc 0.015625
epoch 13 batch id 4/10 loss 3.448625326156616 train acc 0.015625
epoch 13 batch id 5/10 loss 3.488831043243408 train acc 0.015625
epoch 13 batch id 6/10 loss 3.461484432220459 train acc 0.015625
epoch 13 batch id 7/10 loss 3.4149794578552246 train acc 0.0
epoch 13 batch id 8/10 loss 3.440863847732544 train acc 0.046875
epoch 13 batch id 9/10 loss 3.4340004920959473 train acc 0.0625
epoch 13 batch id 10/10 loss 3.477861166000366 train acc 0.03125
epoch 13 train acc 0.0296875 loss mean 3.441350817680359


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.35s/it]


epoch 13 test acc 0.021875
epoch 14 batch id 1/10 loss 3.4293272495269775 train acc 0.03125
epoch 14 batch id 2/10 loss 3.3776559829711914 train acc 0.03125
epoch 14 batch id 3/10 loss 3.4770469665527344 train acc 0.03125
epoch 14 batch id 4/10 loss 3.4068803787231445 train acc 0.078125
epoch 14 batch id 5/10 loss 3.4359946250915527 train acc 0.03125
epoch 14 batch id 6/10 loss 3.420745372772217 train acc 0.03125
epoch 14 batch id 7/10 loss 3.505894184112549 train acc 0.0
epoch 14 batch id 8/10 loss 3.4479596614837646 train acc 0.0
epoch 14 batch id 9/10 loss 3.4102373123168945 train acc 0.015625
epoch 14 batch id 10/10 loss 3.474088191986084 train acc 0.0625
epoch 14 train acc 0.03125 loss mean 3.438582992553711


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.31s/it]


epoch 14 test acc 0.021875
epoch 15 batch id 1/10 loss 3.457831382751465 train acc 0.0
epoch 15 batch id 2/10 loss 3.4266247749328613 train acc 0.0625
epoch 15 batch id 3/10 loss 3.4370646476745605 train acc 0.078125
epoch 15 batch id 4/10 loss 3.4649622440338135 train acc 0.0
epoch 15 batch id 5/10 loss 3.482205867767334 train acc 0.015625
epoch 15 batch id 6/10 loss 3.4365971088409424 train acc 0.046875
epoch 15 batch id 7/10 loss 3.462501049041748 train acc 0.03125
epoch 15 batch id 8/10 loss 3.420283555984497 train acc 0.046875
epoch 15 batch id 9/10 loss 3.4007177352905273 train acc 0.0
epoch 15 batch id 10/10 loss 3.4364709854125977 train acc 0.015625
epoch 15 train acc 0.0296875 loss mean 3.442525935173035


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:13<00:00,  2.62s/it]


epoch 15 test acc 0.021875
epoch 16 batch id 1/10 loss 3.4131081104278564 train acc 0.046875
epoch 16 batch id 2/10 loss 3.4064598083496094 train acc 0.046875
epoch 16 batch id 3/10 loss 3.3983700275421143 train acc 0.0625
epoch 16 batch id 4/10 loss 3.4764719009399414 train acc 0.015625
epoch 16 batch id 5/10 loss 3.436828136444092 train acc 0.03125
epoch 16 batch id 6/10 loss 3.4477643966674805 train acc 0.0
epoch 16 batch id 7/10 loss 3.440430164337158 train acc 0.03125
epoch 16 batch id 8/10 loss 3.423748254776001 train acc 0.078125
epoch 16 batch id 9/10 loss 3.4627881050109863 train acc 0.015625
epoch 16 batch id 10/10 loss 3.459860324859619 train acc 0.0
epoch 16 train acc 0.0328125 loss mean 3.436582922935486


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:13<00:00,  2.66s/it]


epoch 16 test acc 0.021875
epoch 17 batch id 1/10 loss 3.406405448913574 train acc 0.03125
epoch 17 batch id 2/10 loss 3.440988779067993 train acc 0.03125
epoch 17 batch id 3/10 loss 3.4944353103637695 train acc 0.015625
epoch 17 batch id 4/10 loss 3.390575647354126 train acc 0.078125
epoch 17 batch id 5/10 loss 3.43420672416687 train acc 0.03125
epoch 17 batch id 6/10 loss 3.445481061935425 train acc 0.015625
epoch 17 batch id 7/10 loss 3.4377050399780273 train acc 0.03125
epoch 17 batch id 8/10 loss 3.398327350616455 train acc 0.03125
epoch 17 batch id 9/10 loss 3.4273297786712646 train acc 0.0
epoch 17 batch id 10/10 loss 3.492462396621704 train acc 0.015625
epoch 17 train acc 0.028125 loss mean 3.436791753768921


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.59s/it]


epoch 17 test acc 0.021875
epoch 18 batch id 1/10 loss 3.409731149673462 train acc 0.046875
epoch 18 batch id 2/10 loss 3.4809210300445557 train acc 0.015625
epoch 18 batch id 3/10 loss 3.431920051574707 train acc 0.03125
epoch 18 batch id 4/10 loss 3.4941418170928955 train acc 0.03125
epoch 18 batch id 5/10 loss 3.377577543258667 train acc 0.046875
epoch 18 batch id 6/10 loss 3.44870662689209 train acc 0.015625
epoch 18 batch id 7/10 loss 3.4155759811401367 train acc 0.03125
epoch 18 batch id 8/10 loss 3.464240074157715 train acc 0.046875
epoch 18 batch id 9/10 loss 3.479848861694336 train acc 0.015625
epoch 18 batch id 10/10 loss 3.4314427375793457 train acc 0.015625
epoch 18 train acc 0.0296875 loss mean 3.443410587310791


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.54s/it]


epoch 18 test acc 0.021875
epoch 19 batch id 1/10 loss 3.426352024078369 train acc 0.015625
epoch 19 batch id 2/10 loss 3.4677891731262207 train acc 0.015625
epoch 19 batch id 3/10 loss 3.4131627082824707 train acc 0.0625
epoch 19 batch id 4/10 loss 3.4699504375457764 train acc 0.046875
epoch 19 batch id 5/10 loss 3.4094960689544678 train acc 0.03125
epoch 19 batch id 6/10 loss 3.40065598487854 train acc 0.015625
epoch 19 batch id 7/10 loss 3.4682583808898926 train acc 0.03125
epoch 19 batch id 8/10 loss 3.431318998336792 train acc 0.046875
epoch 19 batch id 9/10 loss 3.4944841861724854 train acc 0.0
epoch 19 batch id 10/10 loss 3.356519937515259 train acc 0.0625
epoch 19 train acc 0.0328125 loss mean 3.4337987899780273


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.37s/it]


epoch 19 test acc 0.021875
epoch 20 batch id 1/10 loss 3.4685728549957275 train acc 0.03125
epoch 20 batch id 2/10 loss 3.4762003421783447 train acc 0.0
epoch 20 batch id 3/10 loss 3.3904988765716553 train acc 0.015625
epoch 20 batch id 4/10 loss 3.442117929458618 train acc 0.046875
epoch 20 batch id 5/10 loss 3.4253952503204346 train acc 0.015625
epoch 20 batch id 6/10 loss 3.4351093769073486 train acc 0.0625
epoch 20 batch id 7/10 loss 3.4595260620117188 train acc 0.0625
epoch 20 batch id 8/10 loss 3.475236654281616 train acc 0.0
epoch 20 batch id 9/10 loss 3.4562385082244873 train acc 0.0
epoch 20 batch id 10/10 loss 3.406052589416504 train acc 0.03125
epoch 20 train acc 0.0265625 loss mean 3.4434948444366453


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.32s/it]

epoch 20 test acc 0.021875



