In [1]:
# !pip install ipywidgets  # for vscode
# !pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
import pandas as pd
import json

#transformers
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
from transformers import BertModel
from transformers import AutoTokenizer, AutoModel
from transformers import ElectraModel, ElectraTokenizer

from sklearn.utils.class_weight import compute_class_weight

In [3]:
#KoBERT
# tokenizer = AutoTokenizer.from_pretrained("monologg/kobert")
# model = AutoModel.from_pretrained("monologg/kobert")

# KoELECTRA-Base
tokenizer = ElectraTokenizer.from_pretrained("monologg/koelectra-base-discriminator")
model = ElectraModel.from_pretrained("monologg/koelectra-base-discriminator")

Some weights of the model checkpoint at monologg/koelectra-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
class sentencEmojiDataset(Dataset):
    def __init__(self, directory, tokenizer):
        
        data = pd.read_csv(directory, encoding='UTF-8')
        
        self.tokenizer = tokenizer
        self.sentences = list(data.iloc[:,0])
       
        emojis = list(data.iloc[:,1])
        emojis_unique = list(set(emojis))
        
        self.labels = [emojis_unique.index(i) for i in emojis]
 
        self.labels_dict = {'key': range(len(emojis_unique)), 'value': emojis_unique}
        
    def __getitem__(self, i): #collate 이전 미리 tokenize를 시켜주자
        tokenized = self.tokenizer(str(self.sentences[i]), return_tensors='pt')
        
        #아래 세 개는 tokenizer가 기본적으로 반환하는 정보. BERT의 input이기도 함
        input_ids = tokenized['input_ids']
        token_type_ids = tokenized['token_type_ids']
        attention_mask = tokenized['attention_mask']
        
#         print(str(self.sentences[i]) +' : ')
#         print(tokenized)
        
        return {'input_ids': input_ids, 'token_type_ids': token_type_ids, 
                'attention_mask': attention_mask, 'label': self.labels[i]}
         
    def __len__(self): #data loader가 필요로 하여 필수적으로 있어야 하는 함수
        return len(self.sentences)

In [5]:
class collate_fn:
    def __init__(self, labels_dict):
        self.num_labels = len(labels_dict)
        
    def __call__(self, batch): #batch는 dataset.getitem의 return 값의 List. eg. [{}, {}. ...]
        #batch내 최대 문장 길이(토큰 개수)를 먼저 구해서 padding할 수 있도록 하기
        batchlen = [sample['input_ids'].size(1) for sample in batch] #tensor값을 반환하기 때문에 1번째 차원의 길이를 구함
        maxlen = max(batchlen)
        input_ids = []
        token_type_ids = []
        attention_mask = []
        #padding: [5, 6] [0, 0,  ...]을 concatenate 하는 방식으로 패딩
        for sample in batch:
            pad_len = maxlen - sample['input_ids'].size(1)
            pad = torch.zeros((1, pad_len), dtype=torch.int)            
            input_ids.append(torch.cat([sample['input_ids'], pad], dim=1))
            token_type_ids.append(torch.cat([sample['token_type_ids'], pad], dim=1))
            attention_mask.append(torch.cat([sample['attention_mask'], pad], dim=1))
        #batch 구성
        input_ids = torch.cat(input_ids, dim=0)
        token_type_ids = torch.cat(token_type_ids, dim=0)
        attention_mask = torch.cat(attention_mask, dim=0)
        
        #one-hot encoding
        #batch 내 라벨을 tensor로 변환
        tensor_label = torch.tensor([sample['label'] for sample in batch])
        
        return input_ids, token_type_ids, attention_mask, tensor_label

In [6]:
df = pd.read_csv('data/twitter_clean.csv', encoding="UTF-8")
print(len(df['y'].value_counts()))
df['y'].value_counts(sort = True).head(10)

5


{love}                141
{kind-smile}          119
{laughing-out}         96
{open-mouth-smile}     84
{good-job}             81
Name: y, dtype: int64

In [7]:
df['split'] = np.random.randn(df.shape[0], 1)
msk = np.random.rand(len(df)) <= 0.7

train = df[msk]
test = df[~msk]

train.to_csv('data/train.csv', index=False)
test.to_csv('data/test.csv', index=False)

In [8]:
train = sentencEmojiDataset('data/train.csv', tokenizer)
test = sentencEmojiDataset('data/test.csv', tokenizer)

train_collate_fn = collate_fn(train.labels_dict)
test_collate_fn = collate_fn(test.labels_dict)

train_collate_fn

<__main__.collate_fn at 0x215614496d8>

In [9]:
# Setting parameters
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 20  
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-3

In [10]:
train_dataloader = DataLoader(train, batch_size=batch_size, collate_fn=train_collate_fn, shuffle = True, drop_last = True)
test_dataloader = DataLoader(test, batch_size=batch_size, collate_fn=test_collate_fn, shuffle = False, drop_last = False)
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x2156269f320>

In [11]:
# class BERTClassifier(nn.Module):
#     def __init__(self,
#                  bert,
#                  hidden_size = 768,
#                  num_classes=2,
#                  dr_rate=None,
#                  params=None):
#         super(BERTClassifier, self).__init__()
#         self.bert = bert
#         # do not train bert parameters
#         for p in self.bert.parameters():
#             p.requires_grad = False
#         self.dr_rate = dr_rate
                 
#         self.classifier = nn.Linear(hidden_size , num_classes)
        
#         if dr_rate:
#             self.dropout = nn.Dropout(p=dr_rate)

#     def forward(self, input_ids, token_type_ids, attention_mask):
#         #eval: drop out 중지, batch norm 고정과 같이 evaluation으로 모델 변경
#         self.bert.eval()
#         #gradient 계산을 중지
#         with torch.no_grad():
#             x = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output
# #         x = self.dropout(pooler)
#         return self.classifier(x)

In [12]:
class ELECTRAClassifier(nn.Module):
    def __init__(self,
                 electra,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        
        super(ELECTRAClassifier, self).__init__()
        
        self.electra = electra
        
        # do not train electra parameters
        for p in self.electra.parameters():
            p.requires_grad = False
        
        self.dr_rate = dr_rate
        
        
        #self.classifier = nn.Linear(hidden_size , num_classes)
        
#         # 방법 1 -> forward에서 처리해줘야 함. 
#         self.classifier1 = nn.Linear(hidden_size, 100) # y = Wx
#         self.classifier2 = nn.Linear(100, num_classes) # z = Uy
#         #layer 추가 시 activation function을 주지 않으면 의미가 없음. 
#         self.relu = nn.ReLU()
        
        #방법 2 -> forward에서 별도 처리 필요 X
        self.classifier = nn.Sequential(nn.Linear(hidden_size, 100), nn.ReLU(), nn.Linear(100, num_classes))
        
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def forward(self, input_ids, token_type_ids, attention_mask):
        #eval: drop out 중지, batch norm 고정과 같이 evaluation으로 모델 변경
        self.electra.eval()
        
        #gradient 계산을 중지
        with torch.no_grad():
            #ElectraModel은 pooled_output을 리턴하지 않는 것을 제외하고 BertModel과 유사합니다.
#            x = self.electra(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output

            x = self.electra(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :] 
            #.last_hidden_state[:, 0, :]: [batch , CLS 위치, depth]
            
            #Sentence Embedding으로 무엇을 넣을까? CLS, average, ... (Bert의 경우에는 Sentence BERT라는 게 제안되었다고 함)
                   
        x = self.dropout(x)
      
        return self.classifier(x)

In [13]:
label = list(set(list(df.iloc[:,1])))

In [14]:
# model = BERTClassifier(model,  dr_rate=0.5, num_classes = len(label))
model = ELECTRAClassifier(model,  dr_rate=0.2, num_classes = len(label))

In [15]:
#optimizer와 schedule 설정
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [16]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
# optimizer = AdamW(model.parameters(), lr=learning_rate)

In [17]:
np.unique(train.labels)

array([0, 1, 2, 3, 4])

In [18]:
train.labels[:20]

[4, 2, 0, 2, 1, 3, 4, 4, 4, 2, 0, 0, 4, 0, 0, 4, 4, 4, 4, 4]

In [19]:
#Class Imbalance 문제 해결을 위한 weighted cross entropy 
class_weights = compute_class_weight(class_weight = 'balanced', classes = np.unique(train.labels), y = train.labels)
class_weights = torch.tensor(class_weights, dtype=torch.float)
print(class_weights) #([1.0000, 1.0000, 4.0000, 1.0000, 0.5714])

tensor([0.9167, 1.1159, 1.1667, 1.5714, 0.6581])


In [20]:
train.labels_dict

{'key': range(0, 5),
 'value': ['{kind-smile}',
  '{open-mouth-smile}',
  '{laughing-out}',
  '{good-job}',
  '{love}']}

In [21]:
loss_fn = nn.CrossEntropyLoss(weight = class_weights, reduction = 'mean') 

In [22]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [23]:
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [24]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [25]:
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    loss_sum = 0
    model.train()
    for batch_id, (input_ids, token_type_ids, attention_mask, tensor_label) in enumerate(train_dataloader):
        optimizer.zero_grad()
        
        out = model(input_ids, token_type_ids, attention_mask)
        loss = loss_fn(out, tensor_label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
#         scheduler.step()  # Update learning rate schedule
        batch_acc = calc_accuracy(out, tensor_label)
        train_acc += batch_acc
        loss_sum += loss.data.cpu().numpy()
        #f batch_id % log_interval == 0:
        print("epoch {} batch id {}/{} loss {} train acc {}".format(e+1, batch_id+1, len(train_dataloader), loss.data.cpu().numpy(), batch_acc))
    print("epoch {} train acc {} loss mean {}".format(e+1, train_acc / (batch_id+1), loss_sum / len(train_dataloader)))
    model.eval()
    with torch.no_grad():
        for batch_id, (input_ids, token_type_ids, attention_mask, tensor_label) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
            out = model(input_ids, token_type_ids, attention_mask)
            test_acc += calc_accuracy(out, tensor_label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

epoch 1 batch id 1/6 loss 1.613451600074768 train acc 0.21875
epoch 1 batch id 2/6 loss 1.6458004713058472 train acc 0.125
epoch 1 batch id 3/6 loss 1.614746332168579 train acc 0.15625
epoch 1 batch id 4/6 loss 1.604219913482666 train acc 0.171875
epoch 1 batch id 5/6 loss 1.6407359838485718 train acc 0.15625
epoch 1 batch id 6/6 loss 1.6272358894348145 train acc 0.09375
epoch 1 train acc 0.15364583333333334 loss mean 1.6243650317192078


100%|███████████████████████████████████████████████████| 3/3 [00:16<00:00,  5.62s/it]


epoch 1 test acc 0.2604166666666667
epoch 2 batch id 1/6 loss 1.630483627319336 train acc 0.140625
epoch 2 batch id 2/6 loss 1.6245185136795044 train acc 0.140625
epoch 2 batch id 3/6 loss 1.6198183298110962 train acc 0.171875
epoch 2 batch id 4/6 loss 1.6111185550689697 train acc 0.234375
epoch 2 batch id 5/6 loss 1.609501838684082 train acc 0.203125
epoch 2 batch id 6/6 loss 1.6386713981628418 train acc 0.09375
epoch 2 train acc 0.1640625 loss mean 1.6223520437876384


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.69s/it]


epoch 2 test acc 0.2604166666666667
epoch 3 batch id 1/6 loss 1.61758553981781 train acc 0.140625
epoch 3 batch id 2/6 loss 1.6227327585220337 train acc 0.15625
epoch 3 batch id 3/6 loss 1.615502119064331 train acc 0.203125
epoch 3 batch id 4/6 loss 1.610782265663147 train acc 0.25
epoch 3 batch id 5/6 loss 1.636742353439331 train acc 0.171875
epoch 3 batch id 6/6 loss 1.6294193267822266 train acc 0.09375
epoch 3 train acc 0.16927083333333334 loss mean 1.62212739388148


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.75s/it]


epoch 3 test acc 0.2604166666666667
epoch 4 batch id 1/6 loss 1.616374135017395 train acc 0.171875
epoch 4 batch id 2/6 loss 1.6274409294128418 train acc 0.140625
epoch 4 batch id 3/6 loss 1.626684308052063 train acc 0.125
epoch 4 batch id 4/6 loss 1.6109684705734253 train acc 0.15625
epoch 4 batch id 5/6 loss 1.612525463104248 train acc 0.140625
epoch 4 batch id 6/6 loss 1.6460633277893066 train acc 0.125
epoch 4 train acc 0.14322916666666666 loss mean 1.62334277232488


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.73s/it]


epoch 4 test acc 0.2604166666666667
epoch 5 batch id 1/6 loss 1.629732608795166 train acc 0.15625
epoch 5 batch id 2/6 loss 1.6583685874938965 train acc 0.09375
epoch 5 batch id 3/6 loss 1.6367729902267456 train acc 0.15625
epoch 5 batch id 4/6 loss 1.5809884071350098 train acc 0.1875
epoch 5 batch id 5/6 loss 1.616137981414795 train acc 0.109375
epoch 5 batch id 6/6 loss 1.6407182216644287 train acc 0.171875
epoch 5 train acc 0.14583333333333334 loss mean 1.6271197994550068


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.70s/it]


epoch 5 test acc 0.2604166666666667
epoch 6 batch id 1/6 loss 1.6215732097625732 train acc 0.171875
epoch 6 batch id 2/6 loss 1.6046267747879028 train acc 0.171875
epoch 6 batch id 3/6 loss 1.6236131191253662 train acc 0.203125
epoch 6 batch id 4/6 loss 1.6322646141052246 train acc 0.15625
epoch 6 batch id 5/6 loss 1.6168535947799683 train acc 0.09375
epoch 6 batch id 6/6 loss 1.5985954999923706 train acc 0.1875
epoch 6 train acc 0.1640625 loss mean 1.6162544687589009


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.76s/it]


epoch 6 test acc 0.2604166666666667
epoch 7 batch id 1/6 loss 1.632094383239746 train acc 0.140625
epoch 7 batch id 2/6 loss 1.636955738067627 train acc 0.171875
epoch 7 batch id 3/6 loss 1.647245168685913 train acc 0.1875
epoch 7 batch id 4/6 loss 1.6175442934036255 train acc 0.203125
epoch 7 batch id 5/6 loss 1.61517333984375 train acc 0.171875
epoch 7 batch id 6/6 loss 1.581209659576416 train acc 0.203125
epoch 7 train acc 0.1796875 loss mean 1.6217037638028462


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.72s/it]


epoch 7 test acc 0.2604166666666667
epoch 8 batch id 1/6 loss 1.6502846479415894 train acc 0.078125
epoch 8 batch id 2/6 loss 1.618670105934143 train acc 0.125
epoch 8 batch id 3/6 loss 1.5987085103988647 train acc 0.1875
epoch 8 batch id 4/6 loss 1.600403904914856 train acc 0.140625
epoch 8 batch id 5/6 loss 1.6563019752502441 train acc 0.15625
epoch 8 batch id 6/6 loss 1.619942307472229 train acc 0.1875
epoch 8 train acc 0.14583333333333334 loss mean 1.6240519086519878


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.77s/it]


epoch 8 test acc 0.2604166666666667
epoch 9 batch id 1/6 loss 1.6265913248062134 train acc 0.125
epoch 9 batch id 2/6 loss 1.5972068309783936 train acc 0.234375
epoch 9 batch id 3/6 loss 1.6089309453964233 train acc 0.171875
epoch 9 batch id 4/6 loss 1.627313256263733 train acc 0.15625
epoch 9 batch id 5/6 loss 1.6358076333999634 train acc 0.140625
epoch 9 batch id 6/6 loss 1.6435877084732056 train acc 0.140625
epoch 9 train acc 0.16145833333333334 loss mean 1.6232396165529888


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.72s/it]


epoch 9 test acc 0.2604166666666667
epoch 10 batch id 1/6 loss 1.6345865726470947 train acc 0.140625
epoch 10 batch id 2/6 loss 1.5981851816177368 train acc 0.21875
epoch 10 batch id 3/6 loss 1.602709412574768 train acc 0.15625
epoch 10 batch id 4/6 loss 1.6229135990142822 train acc 0.171875
epoch 10 batch id 5/6 loss 1.6607722043991089 train acc 0.09375
epoch 10 batch id 6/6 loss 1.606272578239441 train acc 0.15625
epoch 10 train acc 0.15625 loss mean 1.6209065914154053


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.74s/it]


epoch 10 test acc 0.2604166666666667
epoch 11 batch id 1/6 loss 1.6454896926879883 train acc 0.140625
epoch 11 batch id 2/6 loss 1.637091040611267 train acc 0.125
epoch 11 batch id 3/6 loss 1.6321349143981934 train acc 0.109375
epoch 11 batch id 4/6 loss 1.6131805181503296 train acc 0.1875
epoch 11 batch id 5/6 loss 1.6089885234832764 train acc 0.265625
epoch 11 batch id 6/6 loss 1.618019461631775 train acc 0.15625
epoch 11 train acc 0.1640625 loss mean 1.625817358493805


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.77s/it]


epoch 11 test acc 0.2604166666666667
epoch 12 batch id 1/6 loss 1.6259502172470093 train acc 0.1875
epoch 12 batch id 2/6 loss 1.650724172592163 train acc 0.125
epoch 12 batch id 3/6 loss 1.6021134853363037 train acc 0.15625
epoch 12 batch id 4/6 loss 1.6224579811096191 train acc 0.1875
epoch 12 batch id 5/6 loss 1.6356079578399658 train acc 0.125
epoch 12 batch id 6/6 loss 1.5983006954193115 train acc 0.1875
epoch 12 train acc 0.16145833333333334 loss mean 1.6225257515907288


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.73s/it]


epoch 12 test acc 0.2604166666666667
epoch 13 batch id 1/6 loss 1.6393537521362305 train acc 0.140625
epoch 13 batch id 2/6 loss 1.5875513553619385 train acc 0.25
epoch 13 batch id 3/6 loss 1.6028742790222168 train acc 0.140625
epoch 13 batch id 4/6 loss 1.6598411798477173 train acc 0.078125
epoch 13 batch id 5/6 loss 1.6079858541488647 train acc 0.203125
epoch 13 batch id 6/6 loss 1.630764365196228 train acc 0.125
epoch 13 train acc 0.15625 loss mean 1.6213951309521992


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.71s/it]


epoch 13 test acc 0.2604166666666667
epoch 14 batch id 1/6 loss 1.6203601360321045 train acc 0.234375
epoch 14 batch id 2/6 loss 1.6065282821655273 train acc 0.171875
epoch 14 batch id 3/6 loss 1.6055279970169067 train acc 0.140625
epoch 14 batch id 4/6 loss 1.6063368320465088 train acc 0.109375
epoch 14 batch id 5/6 loss 1.6230372190475464 train acc 0.1875
epoch 14 batch id 6/6 loss 1.6529351472854614 train acc 0.125
epoch 14 train acc 0.16145833333333334 loss mean 1.6191209355990093


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.71s/it]


epoch 14 test acc 0.2604166666666667
epoch 15 batch id 1/6 loss 1.6443899869918823 train acc 0.125
epoch 15 batch id 2/6 loss 1.6480600833892822 train acc 0.109375
epoch 15 batch id 3/6 loss 1.5882917642593384 train acc 0.28125
epoch 15 batch id 4/6 loss 1.5993629693984985 train acc 0.1875
epoch 15 batch id 5/6 loss 1.6341460943222046 train acc 0.1875
epoch 15 batch id 6/6 loss 1.6171791553497314 train acc 0.15625
epoch 15 train acc 0.17447916666666666 loss mean 1.621905008951823


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.72s/it]


epoch 15 test acc 0.2604166666666667
epoch 16 batch id 1/6 loss 1.616754174232483 train acc 0.21875
epoch 16 batch id 2/6 loss 1.6311415433883667 train acc 0.125
epoch 16 batch id 3/6 loss 1.6653913259506226 train acc 0.0625
epoch 16 batch id 4/6 loss 1.6196776628494263 train acc 0.15625
epoch 16 batch id 5/6 loss 1.6096181869506836 train acc 0.234375
epoch 16 batch id 6/6 loss 1.6056464910507202 train acc 0.15625
epoch 16 train acc 0.15885416666666666 loss mean 1.624704897403717


100%|███████████████████████████████████████████████████| 3/3 [00:18<00:00,  6.03s/it]


epoch 16 test acc 0.2604166666666667
epoch 17 batch id 1/6 loss 1.6343644857406616 train acc 0.171875
epoch 17 batch id 2/6 loss 1.5977003574371338 train acc 0.203125
epoch 17 batch id 3/6 loss 1.6558666229248047 train acc 0.09375
epoch 17 batch id 4/6 loss 1.6220749616622925 train acc 0.15625
epoch 17 batch id 5/6 loss 1.6291532516479492 train acc 0.171875
epoch 17 batch id 6/6 loss 1.6026897430419922 train acc 0.1875
epoch 17 train acc 0.1640625 loss mean 1.623641570409139


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.81s/it]


epoch 17 test acc 0.2604166666666667
epoch 18 batch id 1/6 loss 1.600422739982605 train acc 0.1875
epoch 18 batch id 2/6 loss 1.6128394603729248 train acc 0.15625
epoch 18 batch id 3/6 loss 1.606170415878296 train acc 0.1875
epoch 18 batch id 4/6 loss 1.6453073024749756 train acc 0.140625
epoch 18 batch id 5/6 loss 1.6557445526123047 train acc 0.109375
epoch 18 batch id 6/6 loss 1.6419743299484253 train acc 0.109375
epoch 18 train acc 0.1484375 loss mean 1.6270764668782551


100%|███████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.75s/it]


epoch 18 test acc 0.2604166666666667
epoch 19 batch id 1/6 loss 1.6038633584976196 train acc 0.203125
epoch 19 batch id 2/6 loss 1.5889335870742798 train acc 0.15625
epoch 19 batch id 3/6 loss 1.6289522647857666 train acc 0.171875
epoch 19 batch id 4/6 loss 1.615651249885559 train acc 0.21875
epoch 19 batch id 5/6 loss 1.6239826679229736 train acc 0.15625
epoch 19 batch id 6/6 loss 1.6488990783691406 train acc 0.15625
epoch 19 train acc 0.17708333333333334 loss mean 1.61838036775589


100%|███████████████████████████████████████████████████| 3/3 [00:18<00:00,  6.22s/it]


epoch 19 test acc 0.2604166666666667
epoch 20 batch id 1/6 loss 1.6307857036590576 train acc 0.15625
epoch 20 batch id 2/6 loss 1.6268943548202515 train acc 0.15625
epoch 20 batch id 3/6 loss 1.622048020362854 train acc 0.078125
epoch 20 batch id 4/6 loss 1.6271264553070068 train acc 0.15625
epoch 20 batch id 5/6 loss 1.6194063425064087 train acc 0.140625
epoch 20 batch id 6/6 loss 1.6164839267730713 train acc 0.171875
epoch 20 train acc 0.14322916666666666 loss mean 1.6237908005714417


100%|███████████████████████████████████████████████████| 3/3 [00:16<00:00,  5.46s/it]

epoch 20 test acc 0.2604166666666667



