In [1]:
# !pip install ipywidgets  # for vscode
# !pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
import pandas as pd
import json

#transformers
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
from transformers import BertModel
from transformers import AutoTokenizer, AutoModel

from sklearn.utils.class_weight import compute_class_weight

In [3]:
tokenizer = AutoTokenizer.from_pretrained("monologg/kobert")
model = AutoModel.from_pretrained("monologg/kobert")

In [4]:
class sentencEmojiDataset(Dataset):
    def __init__(self, directory, tokenizer):
        
        data = pd.read_csv(directory, encoding='UTF-8')
        
        self.tokenizer = tokenizer
        self.sentences = list(data.iloc[:,0])
       
        emojis = list(data.iloc[:,1])
        emojis_unique = list(set(emojis))
        
        self.labels = [emojis_unique.index(i) for i in emojis]
 
        self.labels_dict = {'key': range(len(emojis_unique)), 'value': emojis_unique}
        
    def __getitem__(self, i): #collate 이전 미리 tokenize를 시켜주자
        tokenized = self.tokenizer(str(self.sentences[i]), return_tensors='pt')
        #아래 세 개는 tokenizer가 기본적으로 반환하는 정보. BERT의 input이기도 함
        input_ids = tokenized['input_ids']
        token_type_ids = tokenized['token_type_ids']
        attention_mask = tokenized['attention_mask']
        
        return {'input_ids': input_ids, 'token_type_ids': token_type_ids, 
                'attention_mask': attention_mask, 'label': self.labels[i]}
         
    def __len__(self): #data loader가 필요로 하여 필수적으로 있어야 하는 함수
        return len(self.sentences)

In [5]:
class collate_fn:
    def __init__(self, labels_dict):
        self.num_labels = len(labels_dict)
        
    def __call__(self, batch): #batch는 dataset.getitem의 return 값의 List. eg. [{}, {}. ...]
        #batch내 최대 문장 길이(토큰 개수)를 먼저 구해서 padding할 수 있도록 하기
        batchlen = [sample['input_ids'].size(1) for sample in batch] #tensor값을 반환하기 때문에 1번째 차원의 길이를 구함
        maxlen = max(batchlen)
        input_ids = []
        token_type_ids = []
        attention_mask = []
        #padding: [5, 6] [0, 0,  ...]을 concatenate 하는 방식으로 패딩
        for sample in batch:
            pad_len = maxlen - sample['input_ids'].size(1)
            pad = torch.zeros((1, pad_len), dtype=torch.int)            
            input_ids.append(torch.cat([sample['input_ids'], pad], dim=1))
            token_type_ids.append(torch.cat([sample['token_type_ids'], pad], dim=1))
            attention_mask.append(torch.cat([sample['attention_mask'], pad], dim=1))
        #batch 구성
        input_ids = torch.cat(input_ids, dim=0)
        token_type_ids = torch.cat(token_type_ids, dim=0)
        attention_mask = torch.cat(attention_mask, dim=0)
        
        #one-hot encoding
        #batch 내 라벨을 tensor로 변환
        tensor_label = torch.tensor([sample['label'] for sample in batch])
        
        return input_ids, token_type_ids, attention_mask, tensor_label

In [6]:
df = pd.read_csv('data/twitter_clean.csv', encoding="UTF-8")
print(len(df['y'].value_counts()))
df['y'].value_counts(sort = True).head(10)

30


N    1690
😂      76
😭      33
❤      32
😍      21
🔥      21
💜      19
✨      18
💗      16
📸      15
Name: y, dtype: int64

In [7]:
df['split'] = np.random.randn(df.shape[0], 1)
msk = np.random.rand(len(df)) <= 0.7

train = df[msk]
test = df[~msk]

train.to_csv('data/train.csv', index=False)
test.to_csv('data/test.csv', index=False)

In [8]:
train = sentencEmojiDataset('data/train.csv', tokenizer)
test = sentencEmojiDataset('data/test.csv', tokenizer)

train_collate_fn = collate_fn(train.labels_dict)
test_collate_fn = collate_fn(test.labels_dict)

train_collate_fn

<__main__.collate_fn at 0x1f71120d5f8>

In [9]:
# Setting parameters
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 20  
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-3

In [10]:
train_dataloader = DataLoader(train, batch_size=batch_size, collate_fn=train_collate_fn, shuffle = True, drop_last = True)
test_dataloader = DataLoader(test, batch_size=batch_size, collate_fn=test_collate_fn, shuffle = False, drop_last = False)
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x1f7112d22b0>

In [11]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        # do not train bert parameters
        for p in self.bert.parameters():
            p.requires_grad = False
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def forward(self, input_ids, token_type_ids, attention_mask):
        #eval: drop out 중지, batch norm 고정과 같이 evaluation으로 모델 변경
        self.bert.eval()
        #gradient 계산을 중지
        with torch.no_grad():
            x = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output
#         x = self.dropout(pooler)
        return self.classifier(x)

In [12]:
label = list(set(list(df.iloc[:,1])))

In [13]:
model = BERTClassifier(model,  dr_rate=0.5, num_classes = len(label))

In [14]:
#optimizer와 schedule 설정
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [15]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
# optimizer = AdamW(model.parameters(), lr=learning_rate)

In [16]:
np.unique(train.labels)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])

In [17]:
train.labels[:20]

[11, 5, 11, 11, 23, 8, 23, 11, 20, 23, 11, 23, 4, 11, 20, 11, 11, 17, 5, 20]

In [18]:
#Class Imbalance 문제 해결을 위한 weighted cross entropy 
class_weights = compute_class_weight(class_weight = 'balanced', classes = np.unique(train.labels), y = train.labels)
class_weights = torch.tensor(class_weights, dtype=torch.float)
print(class_weights) #([1.0000, 1.0000, 4.0000, 1.0000, 0.5714])

tensor([ 2.0427,  6.3833,  5.1067,  5.1067,  3.9282,  2.6877,  4.6424,  4.6424,
         4.6424,  8.5111, 10.2133,  0.0434,  8.5111,  5.6741,  5.1067,  7.2952,
         5.6741,  7.2952,  5.6741,  5.6741,  2.1278,  5.6741,  6.3833,  0.9457,
         5.6741,  5.1067,  3.1917,  6.3833,  3.6476,  5.6741])


In [19]:
train.labels_dict

{'key': range(0, 30),
 'value': ['😭',
  '🙄',
  '💕',
  '🤤',
  '🔥',
  '😍',
  '😘',
  '💙',
  '🥺',
  '😑',
  '🎉',
  'N',
  '🏻',
  '📸',
  '😩',
  '🤣',
  '💗',
  '👍',
  '✌',
  '💯',
  '❤',
  '❣',
  '😁',
  '😂',
  '😆',
  '😊',
  '💜',
  '💖',
  '✨',
  '🙏']}

In [20]:
loss_fn = nn.CrossEntropyLoss(weight = class_weights, reduction = 'mean') 

In [21]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [22]:
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [23]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [24]:
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    loss_sum = 0
    model.train()
    for batch_id, (input_ids, token_type_ids, attention_mask, tensor_label) in enumerate(train_dataloader):
        optimizer.zero_grad()
        
        out = model(input_ids, token_type_ids, attention_mask)
        loss = loss_fn(out, tensor_label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
#         scheduler.step()  # Update learning rate schedule
        batch_acc = calc_accuracy(out, tensor_label)
        train_acc += batch_acc
        loss_sum += loss.data.cpu().numpy()
        #f batch_id % log_interval == 0:
        print("epoch {} batch id {}/{} loss {} train acc {}".format(e+1, batch_id+1, len(train_dataloader), loss.data.cpu().numpy(), batch_acc))
    print("epoch {} train acc {} loss mean {}".format(e+1, train_acc / (batch_id+1), loss_sum / len(train_dataloader)))
    model.eval()
    with torch.no_grad():
        for batch_id, (input_ids, token_type_ids, attention_mask, tensor_label) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
            out = model(input_ids, token_type_ids, attention_mask)
            test_acc += calc_accuracy(out, tensor_label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

epoch 1 batch id 1/23 loss 3.4549572467803955 train acc 0.0
epoch 1 batch id 2/23 loss 3.4629199504852295 train acc 0.0
epoch 1 batch id 3/23 loss 3.3897719383239746 train acc 0.03125
epoch 1 batch id 4/23 loss 3.4274215698242188 train acc 0.0
epoch 1 batch id 5/23 loss 3.281050205230713 train acc 0.0
epoch 1 batch id 6/23 loss 3.3971595764160156 train acc 0.015625
epoch 1 batch id 7/23 loss 3.3611748218536377 train acc 0.015625
epoch 1 batch id 8/23 loss 3.5217983722686768 train acc 0.0
epoch 1 batch id 9/23 loss 3.4654290676116943 train acc 0.015625
epoch 1 batch id 10/23 loss 3.4421963691711426 train acc 0.0
epoch 1 batch id 11/23 loss 3.4965920448303223 train acc 0.0
epoch 1 batch id 12/23 loss 3.4427907466888428 train acc 0.0
epoch 1 batch id 13/23 loss 3.4203364849090576 train acc 0.015625
epoch 1 batch id 14/23 loss 3.410829782485962 train acc 0.03125
epoch 1 batch id 15/23 loss 3.4675986766815186 train acc 0.015625
epoch 1 batch id 16/23 loss 3.5574512481689453 train acc 0.0
ep

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:33<00:00,  3.06s/it]


epoch 1 test acc 0.004261363636363636
epoch 2 batch id 1/23 loss 3.5113120079040527 train acc 0.0
epoch 2 batch id 2/23 loss 3.3815982341766357 train acc 0.0
epoch 2 batch id 3/23 loss 3.4776947498321533 train acc 0.0
epoch 2 batch id 4/23 loss 3.3555214405059814 train acc 0.015625
epoch 2 batch id 5/23 loss 3.3867571353912354 train acc 0.015625
epoch 2 batch id 6/23 loss 3.466486930847168 train acc 0.015625
epoch 2 batch id 7/23 loss 3.5295748710632324 train acc 0.0
epoch 2 batch id 8/23 loss 3.4656667709350586 train acc 0.015625
epoch 2 batch id 9/23 loss 3.621302604675293 train acc 0.0
epoch 2 batch id 10/23 loss 3.429151773452759 train acc 0.0
epoch 2 batch id 11/23 loss 3.4153623580932617 train acc 0.015625
epoch 2 batch id 12/23 loss 3.4252240657806396 train acc 0.015625
epoch 2 batch id 13/23 loss 3.443650722503662 train acc 0.0
epoch 2 batch id 14/23 loss 3.34031343460083 train acc 0.0
epoch 2 batch id 15/23 loss 3.3701040744781494 train acc 0.03125
epoch 2 batch id 16/23 loss 

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:33<00:00,  3.03s/it]


epoch 2 test acc 0.004261363636363636
epoch 3 batch id 1/23 loss 3.5000338554382324 train acc 0.015625
epoch 3 batch id 2/23 loss 3.488783597946167 train acc 0.0
epoch 3 batch id 3/23 loss 3.5340259075164795 train acc 0.0
epoch 3 batch id 4/23 loss 3.380797863006592 train acc 0.015625
epoch 3 batch id 5/23 loss 3.4941303730010986 train acc 0.0
epoch 3 batch id 6/23 loss 3.434579372406006 train acc 0.0
epoch 3 batch id 7/23 loss 3.2965009212493896 train acc 0.015625
epoch 3 batch id 8/23 loss 3.4618799686431885 train acc 0.0
epoch 3 batch id 9/23 loss 3.3332927227020264 train acc 0.03125
epoch 3 batch id 10/23 loss 3.5063023567199707 train acc 0.015625
epoch 3 batch id 11/23 loss 3.3889665603637695 train acc 0.015625
epoch 3 batch id 12/23 loss 3.369523525238037 train acc 0.015625
epoch 3 batch id 13/23 loss 3.449453353881836 train acc 0.0
epoch 3 batch id 14/23 loss 3.473227024078369 train acc 0.015625
epoch 3 batch id 15/23 loss 3.365506172180176 train acc 0.015625
epoch 3 batch id 16

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:34<00:00,  3.13s/it]


epoch 3 test acc 0.004261363636363636
epoch 4 batch id 1/23 loss 3.4663219451904297 train acc 0.015625
epoch 4 batch id 2/23 loss 3.434060573577881 train acc 0.0
epoch 4 batch id 3/23 loss 3.3811848163604736 train acc 0.0
epoch 4 batch id 4/23 loss 3.375319242477417 train acc 0.015625
epoch 4 batch id 5/23 loss 3.318955898284912 train acc 0.03125
epoch 4 batch id 6/23 loss 3.4309093952178955 train acc 0.0
epoch 4 batch id 7/23 loss 3.2549614906311035 train acc 0.03125
epoch 4 batch id 8/23 loss 3.501028060913086 train acc 0.0
epoch 4 batch id 9/23 loss 3.624546766281128 train acc 0.0
epoch 4 batch id 10/23 loss 3.462674617767334 train acc 0.0
epoch 4 batch id 11/23 loss 3.408342123031616 train acc 0.0
epoch 4 batch id 12/23 loss 3.6264843940734863 train acc 0.0
epoch 4 batch id 13/23 loss 3.4862983226776123 train acc 0.0
epoch 4 batch id 14/23 loss 3.349769115447998 train acc 0.03125
epoch 4 batch id 15/23 loss 3.556368827819824 train acc 0.0
epoch 4 batch id 16/23 loss 3.3728137016296

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:37<00:00,  3.41s/it]


epoch 4 test acc 0.004261363636363636
epoch 5 batch id 1/23 loss 3.5253114700317383 train acc 0.0
epoch 5 batch id 2/23 loss 3.392733097076416 train acc 0.0
epoch 5 batch id 3/23 loss 3.4926090240478516 train acc 0.0
epoch 5 batch id 4/23 loss 3.52701997756958 train acc 0.015625
epoch 5 batch id 5/23 loss 3.433424711227417 train acc 0.0
epoch 5 batch id 6/23 loss 3.4745945930480957 train acc 0.015625
epoch 5 batch id 7/23 loss 3.363084077835083 train acc 0.015625
epoch 5 batch id 8/23 loss 3.4568803310394287 train acc 0.015625
epoch 5 batch id 9/23 loss 3.3389763832092285 train acc 0.0
epoch 5 batch id 10/23 loss 3.492065668106079 train acc 0.0
epoch 5 batch id 11/23 loss 3.3737099170684814 train acc 0.0
epoch 5 batch id 12/23 loss 3.533153533935547 train acc 0.0
epoch 5 batch id 13/23 loss 3.372884750366211 train acc 0.0
epoch 5 batch id 14/23 loss 3.486675262451172 train acc 0.015625
epoch 5 batch id 15/23 loss 3.3350839614868164 train acc 0.046875
epoch 5 batch id 16/23 loss 3.51122

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:35<00:00,  3.20s/it]


epoch 5 test acc 0.004261363636363636
epoch 6 batch id 1/23 loss 3.4605236053466797 train acc 0.03125
epoch 6 batch id 2/23 loss 3.5289387702941895 train acc 0.0
epoch 6 batch id 3/23 loss 3.4965386390686035 train acc 0.0
epoch 6 batch id 4/23 loss 3.388655662536621 train acc 0.0
epoch 6 batch id 5/23 loss 3.477313756942749 train acc 0.015625
epoch 6 batch id 6/23 loss 3.5159316062927246 train acc 0.0
epoch 6 batch id 7/23 loss 3.51442289352417 train acc 0.015625
epoch 6 batch id 8/23 loss 3.370028257369995 train acc 0.015625
epoch 6 batch id 9/23 loss 3.486786365509033 train acc 0.0
epoch 6 batch id 10/23 loss 3.4422662258148193 train acc 0.015625
epoch 6 batch id 11/23 loss 3.450613498687744 train acc 0.0
epoch 6 batch id 12/23 loss 3.423661470413208 train acc 0.0
epoch 6 batch id 13/23 loss 3.5033674240112305 train acc 0.0
epoch 6 batch id 14/23 loss 3.4997315406799316 train acc 0.0
epoch 6 batch id 15/23 loss 3.385765552520752 train acc 0.0
epoch 6 batch id 16/23 loss 3.42014431953

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:38<00:00,  3.49s/it]


epoch 6 test acc 0.004261363636363636
epoch 7 batch id 1/23 loss 3.5069327354431152 train acc 0.0
epoch 7 batch id 2/23 loss 3.47760272026062 train acc 0.015625
epoch 7 batch id 3/23 loss 3.2869694232940674 train acc 0.015625
epoch 7 batch id 4/23 loss 3.4166648387908936 train acc 0.0
epoch 7 batch id 5/23 loss 3.4913463592529297 train acc 0.0
epoch 7 batch id 6/23 loss 3.475839376449585 train acc 0.0
epoch 7 batch id 7/23 loss 3.5293960571289062 train acc 0.0
epoch 7 batch id 8/23 loss 3.5301096439361572 train acc 0.0
epoch 7 batch id 9/23 loss 3.362197160720825 train acc 0.03125
epoch 7 batch id 10/23 loss 3.374936103820801 train acc 0.015625
epoch 7 batch id 11/23 loss 3.377686023712158 train acc 0.015625
epoch 7 batch id 12/23 loss 3.447532892227173 train acc 0.0
epoch 7 batch id 13/23 loss 3.460367202758789 train acc 0.015625
epoch 7 batch id 14/23 loss 3.3135344982147217 train acc 0.0
epoch 7 batch id 15/23 loss 3.4966225624084473 train acc 0.015625
epoch 7 batch id 16/23 loss 3.

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:32<00:00,  2.98s/it]


epoch 7 test acc 0.004261363636363636
epoch 8 batch id 1/23 loss 3.394768714904785 train acc 0.015625
epoch 8 batch id 2/23 loss 3.5246450901031494 train acc 0.0
epoch 8 batch id 3/23 loss 3.4434659481048584 train acc 0.015625
epoch 8 batch id 4/23 loss 3.459472179412842 train acc 0.0
epoch 8 batch id 5/23 loss 3.3868846893310547 train acc 0.015625
epoch 8 batch id 6/23 loss 3.527589797973633 train acc 0.0
epoch 8 batch id 7/23 loss 3.512089252471924 train acc 0.0
epoch 8 batch id 8/23 loss 3.3622944355010986 train acc 0.0
epoch 8 batch id 9/23 loss 3.502382278442383 train acc 0.0
epoch 8 batch id 10/23 loss 3.4843075275421143 train acc 0.0
epoch 8 batch id 11/23 loss 3.400731086730957 train acc 0.015625
epoch 8 batch id 12/23 loss 3.3820183277130127 train acc 0.015625
epoch 8 batch id 13/23 loss 3.269564628601074 train acc 0.03125
epoch 8 batch id 14/23 loss 3.3954689502716064 train acc 0.0
epoch 8 batch id 15/23 loss 3.3508260250091553 train acc 0.0
epoch 8 batch id 16/23 loss 3.4976

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:33<00:00,  3.04s/it]


epoch 8 test acc 0.004261363636363636
epoch 9 batch id 1/23 loss 3.4672365188598633 train acc 0.0
epoch 9 batch id 2/23 loss 3.464799642562866 train acc 0.0
epoch 9 batch id 3/23 loss 3.554474353790283 train acc 0.0
epoch 9 batch id 4/23 loss 3.4775314331054688 train acc 0.0
epoch 9 batch id 5/23 loss 3.466677188873291 train acc 0.0
epoch 9 batch id 6/23 loss 3.2717065811157227 train acc 0.015625
epoch 9 batch id 7/23 loss 3.452512264251709 train acc 0.015625
epoch 9 batch id 8/23 loss 3.403635263442993 train acc 0.03125
epoch 9 batch id 9/23 loss 3.5392072200775146 train acc 0.0
epoch 9 batch id 10/23 loss 3.4982492923736572 train acc 0.0
epoch 9 batch id 11/23 loss 3.414479970932007 train acc 0.03125
epoch 9 batch id 12/23 loss 3.537642002105713 train acc 0.0
epoch 9 batch id 13/23 loss 3.497823715209961 train acc 0.015625
epoch 9 batch id 14/23 loss 3.4723386764526367 train acc 0.0
epoch 9 batch id 15/23 loss 3.3272593021392822 train acc 0.015625
epoch 9 batch id 16/23 loss 3.369804

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:33<00:00,  3.07s/it]


epoch 9 test acc 0.004261363636363636
epoch 10 batch id 1/23 loss 3.3917696475982666 train acc 0.015625
epoch 10 batch id 2/23 loss 3.437828302383423 train acc 0.015625
epoch 10 batch id 3/23 loss 3.477713108062744 train acc 0.0
epoch 10 batch id 4/23 loss 3.527576446533203 train acc 0.0
epoch 10 batch id 5/23 loss 3.505939245223999 train acc 0.0
epoch 10 batch id 6/23 loss 3.4437339305877686 train acc 0.03125
epoch 10 batch id 7/23 loss 3.560321569442749 train acc 0.0
epoch 10 batch id 8/23 loss 3.374302625656128 train acc 0.0
epoch 10 batch id 9/23 loss 3.4538912773132324 train acc 0.0
epoch 10 batch id 10/23 loss 3.367922306060791 train acc 0.015625
epoch 10 batch id 11/23 loss 3.3908307552337646 train acc 0.0
epoch 10 batch id 12/23 loss 3.3325626850128174 train acc 0.03125
epoch 10 batch id 13/23 loss 3.493860960006714 train acc 0.0
epoch 10 batch id 14/23 loss 3.408151388168335 train acc 0.0
epoch 10 batch id 15/23 loss 3.419408082962036 train acc 0.0
epoch 10 batch id 16/23 loss

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:32<00:00,  2.98s/it]


epoch 10 test acc 0.004261363636363636
epoch 11 batch id 1/23 loss 3.4039785861968994 train acc 0.0
epoch 11 batch id 2/23 loss 3.389486789703369 train acc 0.015625
epoch 11 batch id 3/23 loss 3.414552927017212 train acc 0.015625
epoch 11 batch id 4/23 loss 3.3949358463287354 train acc 0.0
epoch 11 batch id 5/23 loss 3.3499996662139893 train acc 0.015625
epoch 11 batch id 6/23 loss 3.3319458961486816 train acc 0.015625
epoch 11 batch id 7/23 loss 3.436568260192871 train acc 0.0
epoch 11 batch id 8/23 loss 3.361441135406494 train acc 0.0
epoch 11 batch id 9/23 loss 3.4828271865844727 train acc 0.015625
epoch 11 batch id 10/23 loss 3.469836711883545 train acc 0.0
epoch 11 batch id 11/23 loss 3.3270647525787354 train acc 0.015625
epoch 11 batch id 12/23 loss 3.4850196838378906 train acc 0.0
epoch 11 batch id 13/23 loss 3.451691150665283 train acc 0.015625
epoch 11 batch id 14/23 loss 3.4622883796691895 train acc 0.0
epoch 11 batch id 15/23 loss 3.5049760341644287 train acc 0.0
epoch 11 ba

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:33<00:00,  3.05s/it]


epoch 11 test acc 0.004261363636363636
epoch 12 batch id 1/23 loss 3.5552868843078613 train acc 0.0
epoch 12 batch id 2/23 loss 3.3624908924102783 train acc 0.015625
epoch 12 batch id 3/23 loss 3.2545957565307617 train acc 0.03125
epoch 12 batch id 4/23 loss 3.4618515968322754 train acc 0.015625
epoch 12 batch id 5/23 loss 3.377579689025879 train acc 0.0
epoch 12 batch id 6/23 loss 3.459718942642212 train acc 0.0
epoch 12 batch id 7/23 loss 3.5084497928619385 train acc 0.0
epoch 12 batch id 8/23 loss 3.4307398796081543 train acc 0.015625
epoch 12 batch id 9/23 loss 3.278613805770874 train acc 0.015625
epoch 12 batch id 10/23 loss 3.365391969680786 train acc 0.0
epoch 12 batch id 11/23 loss 3.4884591102600098 train acc 0.0
epoch 12 batch id 12/23 loss 3.3507776260375977 train acc 0.046875
epoch 12 batch id 13/23 loss 3.4205446243286133 train acc 0.0
epoch 12 batch id 14/23 loss 3.410128355026245 train acc 0.0
epoch 12 batch id 15/23 loss 3.4505090713500977 train acc 0.015625
epoch 12 ba

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:33<00:00,  3.06s/it]


epoch 12 test acc 0.004261363636363636
epoch 13 batch id 1/23 loss 3.488541841506958 train acc 0.0
epoch 13 batch id 2/23 loss 3.404971122741699 train acc 0.015625
epoch 13 batch id 3/23 loss 3.392439842224121 train acc 0.015625
epoch 13 batch id 4/23 loss 3.3900821208953857 train acc 0.015625
epoch 13 batch id 5/23 loss 3.3301470279693604 train acc 0.03125
epoch 13 batch id 6/23 loss 3.4485836029052734 train acc 0.015625
epoch 13 batch id 7/23 loss 3.4020133018493652 train acc 0.0
epoch 13 batch id 8/23 loss 3.4505491256713867 train acc 0.0
epoch 13 batch id 9/23 loss 3.504425525665283 train acc 0.015625
epoch 13 batch id 10/23 loss 3.3816325664520264 train acc 0.0
epoch 13 batch id 11/23 loss 3.4806630611419678 train acc 0.0
epoch 13 batch id 12/23 loss 3.4741270542144775 train acc 0.0
epoch 13 batch id 13/23 loss 3.52160906791687 train acc 0.0
epoch 13 batch id 14/23 loss 3.364490270614624 train acc 0.0
epoch 13 batch id 15/23 loss 3.417506217956543 train acc 0.015625
epoch 13 batch

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:36<00:00,  3.29s/it]


epoch 13 test acc 0.004261363636363636
epoch 14 batch id 1/23 loss 3.522768020629883 train acc 0.0
epoch 14 batch id 2/23 loss 3.576038360595703 train acc 0.0
epoch 14 batch id 3/23 loss 3.4250335693359375 train acc 0.0
epoch 14 batch id 4/23 loss 3.4230642318725586 train acc 0.015625
epoch 14 batch id 5/23 loss 3.382864236831665 train acc 0.0
epoch 14 batch id 6/23 loss 3.385202407836914 train acc 0.03125
epoch 14 batch id 7/23 loss 3.527888298034668 train acc 0.0
epoch 14 batch id 8/23 loss 3.3856496810913086 train acc 0.03125
epoch 14 batch id 9/23 loss 3.3890833854675293 train acc 0.0
epoch 14 batch id 10/23 loss 3.5570530891418457 train acc 0.0
epoch 14 batch id 11/23 loss 3.3731589317321777 train acc 0.0
epoch 14 batch id 12/23 loss 3.4137656688690186 train acc 0.015625
epoch 14 batch id 13/23 loss 3.3981640338897705 train acc 0.015625
epoch 14 batch id 14/23 loss 3.427125930786133 train acc 0.0
epoch 14 batch id 15/23 loss 3.441467523574829 train acc 0.0
epoch 14 batch id 16/23 

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:34<00:00,  3.09s/it]


epoch 14 test acc 0.004261363636363636
epoch 15 batch id 1/23 loss 3.4741806983947754 train acc 0.0
epoch 15 batch id 2/23 loss 3.382072925567627 train acc 0.0
epoch 15 batch id 3/23 loss 3.488633871078491 train acc 0.0
epoch 15 batch id 4/23 loss 3.4834721088409424 train acc 0.0
epoch 15 batch id 5/23 loss 3.4054224491119385 train acc 0.015625
epoch 15 batch id 6/23 loss 3.3311514854431152 train acc 0.0
epoch 15 batch id 7/23 loss 3.2816548347473145 train acc 0.015625
epoch 15 batch id 8/23 loss 3.463719606399536 train acc 0.0
epoch 15 batch id 9/23 loss 3.464935779571533 train acc 0.0
epoch 15 batch id 10/23 loss 3.407379388809204 train acc 0.0
epoch 15 batch id 11/23 loss 3.405430316925049 train acc 0.015625
epoch 15 batch id 12/23 loss 3.4853873252868652 train acc 0.0
epoch 15 batch id 13/23 loss 3.2772274017333984 train acc 0.03125
epoch 15 batch id 14/23 loss 3.4725019931793213 train acc 0.015625
epoch 15 batch id 15/23 loss 3.4384348392486572 train acc 0.0
epoch 15 batch id 16/2

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:34<00:00,  3.17s/it]


epoch 15 test acc 0.004261363636363636
epoch 16 batch id 1/23 loss 3.389923095703125 train acc 0.03125
epoch 16 batch id 2/23 loss 3.463848352432251 train acc 0.0
epoch 16 batch id 3/23 loss 3.550621509552002 train acc 0.0
epoch 16 batch id 4/23 loss 3.358124256134033 train acc 0.0
epoch 16 batch id 5/23 loss 3.4676437377929688 train acc 0.015625
epoch 16 batch id 6/23 loss 3.436434030532837 train acc 0.0
epoch 16 batch id 7/23 loss 3.4344887733459473 train acc 0.0
epoch 16 batch id 8/23 loss 3.5297868251800537 train acc 0.0
epoch 16 batch id 9/23 loss 3.4298031330108643 train acc 0.0
epoch 16 batch id 10/23 loss 3.405683994293213 train acc 0.015625
epoch 16 batch id 11/23 loss 3.484005928039551 train acc 0.0
epoch 16 batch id 12/23 loss 3.369858503341675 train acc 0.015625
epoch 16 batch id 13/23 loss 3.329671621322632 train acc 0.046875
epoch 16 batch id 14/23 loss 3.4010705947875977 train acc 0.0
epoch 16 batch id 15/23 loss 3.4960975646972656 train acc 0.0


KeyboardInterrupt: 