In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
import pandas as pd
from sklearn.model_selection import train_test_split
from kobert_tokenizer import KoBERTTokenizer
from transformers import BertModel
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup


# GPU 가 있는 경우 지정
device = torch.device("cuda:1")

# GPU 가 없는 경우 지정
# device = torch.device("cpu")

In [2]:
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
bertmodel = BertModel.from_pretrained('skt/kobert-base-v1', return_dict=False)
vocab = nlp.vocab.BERTVocab.from_sentencepiece(tokenizer.vocab_file, padding_token='[PAD]')
tok = tokenizer.tokenize

# Setting parameters
max_len = 64
batch_size = 256
warmup_ratio = 0.1
num_epochs = 120
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'XLNetTokenizer'. 
The class this function is called from is 'KoBERTTokenizer'.


In [3]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len,
                 pad, pair):
   
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len,vocab=vocab, pad=pad, pair=pair)
        
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))
         
    def __len__(self):
        return (len(self.labels))


In [4]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=58,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [5]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [6]:
def predict(sentence):
    dataset = [[sentence, '0']]
    test = BERTDataset(dataset, 0, 1, tok, vocab, max_len, True, False)
    test_dataloader = torch.utils.data.DataLoader(test, batch_size=batch_size, num_workers=2)
    model.eval()
    answer = 0
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        for logits in out:
            logits = logits.detach().cpu().numpy()
            answer = np.argmax(logits)
    return answer

In [7]:
train_set = pd.read_csv('./custom_chatbotdataset(Training).csv')
validation_set = pd.read_csv('./custom_chatbotdataset(Validation).csv')

In [8]:
train_set.head(20)

Unnamed: 0,label,Q,A
0,9,일은 왜 해도 해도 끝이 없을까? 화가 난다.,많이 힘드시겠어요. 주위에 의논할 상대가 있나요?
1,9,이번 달에 또 급여가 깎였어! 물가는 오르는데 월급만 자꾸 깎이니까 너무 화가 나.,급여가 줄어 속상하시겠어요. 월급이 줄어든 것을 어떻게 보완하실 건가요?
2,9,회사에 신입이 들어왔는데 말투가 거슬려. 그런 애를 매일 봐야 한다고 생각하니까 스...,회사 동료 때문에 스트레스를 많이 받는 것 같아요. 문제 해결을 위해 어떤 노력을 ...
3,9,직장에서 막내라는 이유로 나에게만 온갖 심부름을 시켜. 일도 많은 데 정말 분하고 ...,관련 없는 심부름을 모두 하게 되어서 노여우시군요. 어떤 것이 상황을 나아질 수 있...
4,9,얼마 전 입사한 신입사원이 나를 무시하는 것 같아서 너무 화가 나.,무시하는 것 같은 태도에 화가 나셨군요. 상대방의 어떤 행동이 그런 감정을 유발하는...
5,9,직장에 다니고 있지만 시간만 버리는 거 같아. 진지하게 진로에 대한 고민이 생겨.,진로에 대해서 고민하고 계시는군요. 어떤 점이 고민인가요?
6,9,성인인데도 진로를 아직도 못 정했다고 부모님이 노여워하셔. 나도 섭섭해.,부모님의 노여움에 섭섭하시군요. 이런 상황을 어떻게 해결하면 좋을까요?
7,11,퇴사한 지 얼마 안 됐지만 천천히 직장을 구해보려고.,천천히라도 직장을 구해 보려고 하시는군요. 특별한 이유가 있으신가요?
8,2,졸업반이라서 취업을 생각해야 하는데 지금 너무 느긋해서 이래도 되나 싶어.,취업에 대해 걱정이 되는군요.
9,11,요즘 직장생활이 너무 편하고 좋은 것 같아!,직장생활이 편하고 좋으시다니 좋아 보여요. 다니고 계신 회사만의 장점이 있나요?


In [9]:
train_set_data = [[i, str(j)] for i, j in zip(train_set['Q'], train_set['label'])]
validation_set_data = [[i, str(j)] for i, j in zip(validation_set['Q'], validation_set['label'])]

train_set_data = BERTDataset(train_set_data, 0, 1, tok, vocab, max_len, True, False)
validation_set_data = BERTDataset(validation_set_data, 0, 1, tok, vocab, max_len, True, False)
train_dataloader = torch.utils.data.DataLoader(train_set_data, batch_size=batch_size, num_workers=2)
validation_dataloader = torch.utils.data.DataLoader(validation_set_data, batch_size=batch_size, num_workers=2)

In [10]:
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss().to(device)
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    best_test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(validation_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    
    if(best_test_acc < test_acc):
        best_test_acc = test_acc
        torch.save({
            'model' : model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, './save_model/QtEmodel{}.pth'.format(e+1))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 4.083265781402588 train acc 0.03515625
epoch 1 batch id 201 loss 4.088296890258789 train acc 0.016616138059701493
epoch 1 batch id 401 loss 4.034827709197998 train acc 0.02041770573566085
epoch 1 train acc 0.02482301753888946


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(validation_dataloader)):


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 1 test acc 0.03332232981220657


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 4.078164577484131 train acc 0.01953125
epoch 2 batch id 201 loss 3.9560341835021973 train acc 0.03838230721393035
epoch 2 batch id 401 loss 3.8421578407287598 train acc 0.04573527119700748
epoch 2 train acc 0.04714864273204904


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 2 test acc 0.05503594483568075


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 4.029691219329834 train acc 0.0546875
epoch 3 batch id 201 loss 3.5870862007141113 train acc 0.10725668532338309
epoch 3 batch id 401 loss 3.5008864402770996 train acc 0.10767027743142145
epoch 3 train acc 0.09729769753785927


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 3 test acc 0.0991050469483568


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 3.8497931957244873 train acc 0.1171875
epoch 4 batch id 201 loss 3.142490863800049 train acc 0.2005985696517413
epoch 4 batch id 401 loss 3.228684663772583 train acc 0.18173316708229426
epoch 4 train acc 0.1558194151128052


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 4 test acc 0.12144219483568076


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 3.6559906005859375 train acc 0.19140625
epoch 5 batch id 201 loss 2.802623748779297 train acc 0.2572294776119403
epoch 5 batch id 401 loss 2.8942394256591797 train acc 0.22642612219451372
epoch 5 train acc 0.19183442039250026


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 5 test acc 0.12727406103286384


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 3.505790948867798 train acc 0.203125
epoch 6 batch id 201 loss 2.5150833129882812 train acc 0.2882851368159204
epoch 6 batch id 401 loss 2.692615270614624 train acc 0.25370168329177056
epoch 6 train acc 0.21496769406098692


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 6 test acc 0.1457599765258216


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 3.3104820251464844 train acc 0.2421875
epoch 7 batch id 201 loss 2.3041832447052 train acc 0.31689210199004975
epoch 7 batch id 401 loss 2.5111958980560303 train acc 0.27855166770573564
epoch 7 train acc 0.23673637581127024


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 7 test acc 0.15061986502347416


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 3.193145990371704 train acc 0.26953125
epoch 8 batch id 201 loss 2.1619184017181396 train acc 0.33704524253731344
epoch 8 batch id 401 loss 2.32182240486145 train acc 0.29795628117206985
epoch 8 train acc 0.2546197331822396


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 8 test acc 0.15584653755868544


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 3.0263261795043945 train acc 0.27734375
epoch 9 batch id 201 loss 2.0265865325927734 train acc 0.3583644278606965
epoch 9 batch id 401 loss 2.2375166416168213 train acc 0.3189779457605985
epoch 9 train acc 0.27485199147522404


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 9 test acc 0.1600462147887324


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 2.852454900741577 train acc 0.328125
epoch 10 batch id 201 loss 1.936285138130188 train acc 0.3801500310945274
epoch 10 batch id 401 loss 2.1005916595458984 train acc 0.3419381234413965
epoch 10 train acc 0.2979482428917276


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 10 test acc 0.16591475938967135


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 11 batch id 1 loss 2.62420916557312 train acc 0.33984375
epoch 11 batch id 201 loss 1.8490827083587646 train acc 0.40671641791044777
epoch 11 batch id 401 loss 1.9451539516448975 train acc 0.36919420199501246
epoch 11 train acc 0.32497569408674154


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 11 test acc 0.17403902582159625


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 12 batch id 1 loss 2.570908308029175 train acc 0.359375
epoch 12 batch id 201 loss 1.6821469068527222 train acc 0.43118392412935325
epoch 12 batch id 401 loss 1.807902455329895 train acc 0.3977750935162095
epoch 12 train acc 0.3535407759864016


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 12 test acc 0.1830619131455399


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 13 batch id 1 loss 2.384719133377075 train acc 0.39453125
epoch 13 batch id 201 loss 1.5118136405944824 train acc 0.46379430970149255
epoch 13 batch id 401 loss 1.6456633806228638 train acc 0.4305836970074813
epoch 13 train acc 0.3872946069846503


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 13 test acc 0.18181484741784038


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 14 batch id 1 loss 2.1899478435516357 train acc 0.41796875
epoch 14 batch id 201 loss 1.3514219522476196 train acc 0.49945584577114427
epoch 14 batch id 401 loss 1.5249251127243042 train acc 0.46507754052369077
epoch 14 train acc 0.4237144425414649


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 14 test acc 0.19342356220657275


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 15 batch id 1 loss 2.020589590072632 train acc 0.48828125
epoch 15 batch id 201 loss 1.2558103799819946 train acc 0.5338930348258707
epoch 15 batch id 401 loss 1.346158504486084 train acc 0.5004675810473815
epoch 15 train acc 0.4591817599412795


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 15 test acc 0.20495892018779344


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 16 batch id 1 loss 1.7936149835586548 train acc 0.52734375
epoch 16 batch id 201 loss 1.0933151245117188 train acc 0.5672613495024875
epoch 16 batch id 401 loss 1.1942733526229858 train acc 0.5354192643391521
epoch 16 train acc 0.4963939521736891


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 16 test acc 0.215687353286385


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 17 batch id 1 loss 1.5362250804901123 train acc 0.60546875
epoch 17 batch id 201 loss 1.0054728984832764 train acc 0.5966068097014925
epoch 17 batch id 401 loss 1.112313151359558 train acc 0.5644579956359103
epoch 17 train acc 0.527483790692284


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 17 test acc 0.2147887323943662


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 18 batch id 1 loss 1.5013151168823242 train acc 0.578125
epoch 18 batch id 201 loss 0.8019782900810242 train acc 0.6226290422885572
epoch 18 batch id 401 loss 1.0053176879882812 train acc 0.5929122506234414
epoch 18 train acc 0.5583064251313485


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 18 test acc 0.21376173708920188


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 19 batch id 1 loss 1.415585994720459 train acc 0.609375
epoch 19 batch id 201 loss 0.7984010577201843 train acc 0.6437538868159204
epoch 19 batch id 401 loss 0.966517984867096 train acc 0.62022677680798
epoch 19 train acc 0.5869564167611002


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 19 test acc 0.23200924295774647


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 20 batch id 1 loss 1.3015717267990112 train acc 0.671875
epoch 20 batch id 201 loss 0.677749752998352 train acc 0.6742653917910447
epoch 20 batch id 401 loss 0.7446978688240051 train acc 0.6483693110972568
epoch 20 train acc 0.6140269264448336


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 20 test acc 0.23290786384976525


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 21 batch id 1 loss 1.2348873615264893 train acc 0.66015625
epoch 21 batch id 201 loss 0.5327683687210083 train acc 0.6977223258706468
epoch 21 batch id 401 loss 0.7694499492645264 train acc 0.672439993765586
epoch 21 train acc 0.640625


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 21 test acc 0.24097711267605634


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 22 batch id 1 loss 1.0265262126922607 train acc 0.69140625
epoch 22 batch id 201 loss 0.5309775471687317 train acc 0.7156405472636815
epoch 22 batch id 401 loss 0.7322962880134583 train acc 0.6911529769326683
epoch 22 train acc 0.6638331230040178


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 22 test acc 0.2500733568075117


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 23 batch id 1 loss 0.8543027639389038 train acc 0.79296875
epoch 23 batch id 201 loss 0.4552684426307678 train acc 0.7374844527363185
epoch 23 batch id 401 loss 0.5304955244064331 train acc 0.7138988466334164
epoch 23 train acc 0.6873732390285362


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 23 test acc 0.255300029342723


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 24 batch id 1 loss 0.7428084015846252 train acc 0.78515625
epoch 24 batch id 201 loss 0.4665936827659607 train acc 0.7603389303482587
epoch 24 batch id 401 loss 0.5660426616668701 train acc 0.7396839931421446
epoch 24 train acc 0.7143511930823118


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 24 test acc 0.2728323063380282


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 25 batch id 1 loss 0.7238967418670654 train acc 0.7734375
epoch 25 batch id 201 loss 0.37682750821113586 train acc 0.7778684701492538
epoch 25 batch id 401 loss 0.48077908158302307 train acc 0.7579683603491272
epoch 25 train acc 0.7353943191768827


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 25 test acc 0.27956279342723


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 26 batch id 1 loss 0.5657145977020264 train acc 0.80859375
epoch 26 batch id 201 loss 0.3488236963748932 train acc 0.8008784203980099
epoch 26 batch id 401 loss 0.48957058787345886 train acc 0.7803148379052369
epoch 26 train acc 0.7570463003502627


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 26 test acc 0.28532130281690143


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 27 batch id 1 loss 0.6194716691970825 train acc 0.82421875
epoch 27 batch id 201 loss 0.2921890318393707 train acc 0.8138992537313433
epoch 27 batch id 401 loss 0.4006965458393097 train acc 0.7967093983790524
epoch 27 train acc 0.7760302648861647


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 27 test acc 0.27550982981220656


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 28 batch id 1 loss 0.5898421406745911 train acc 0.8125
epoch 28 batch id 201 loss 0.2718110680580139 train acc 0.8256957400497512
epoch 28 batch id 401 loss 0.39935553073883057 train acc 0.8127045667082294
epoch 28 train acc 0.7904544561914083


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 28 test acc 0.27644512910798125


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 29 batch id 1 loss 0.4001552164554596 train acc 0.87890625
epoch 29 batch id 201 loss 0.30797407031059265 train acc 0.8417677238805971
epoch 29 batch id 401 loss 0.3737815320491791 train acc 0.827365180798005
epoch 29 train acc 0.8094762478108581


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 29 test acc 0.29256528755868544


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 30 batch id 1 loss 0.3702283799648285 train acc 0.89453125
epoch 30 batch id 201 loss 0.23751981556415558 train acc 0.8543221393034826
epoch 30 batch id 401 loss 0.33891111612319946 train acc 0.8409250311720698
epoch 30 train acc 0.8236030538528897


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 30 test acc 0.30166153169014087


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 31 batch id 1 loss 0.37613698840141296 train acc 0.86328125
epoch 31 batch id 201 loss 0.22029809653759003 train acc 0.8641363495024875
epoch 31 batch id 401 loss 0.30978962779045105 train acc 0.8540660068578554
epoch 31 train acc 0.8377161777583187


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 31 test acc 0.3104460093896714


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 32 batch id 1 loss 0.3978423476219177 train acc 0.890625
epoch 32 batch id 201 loss 0.24388854205608368 train acc 0.8744364116915423
epoch 32 batch id 401 loss 0.4149446189403534 train acc 0.8657165679551122
epoch 32 train acc 0.8496196366024519


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 32 test acc 0.3131968896713615


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 33 batch id 1 loss 0.3337060809135437 train acc 0.88671875
epoch 33 batch id 201 loss 0.15507900714874268 train acc 0.8814715485074627
epoch 33 batch id 401 loss 0.22977766394615173 train acc 0.8740161315461347
epoch 33 train acc 0.8585745950087565


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 33 test acc 0.322824970657277


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 34 batch id 1 loss 0.30138543248176575 train acc 0.89453125
epoch 34 batch id 201 loss 0.11492151767015457 train acc 0.8935012437810945
epoch 34 batch id 401 loss 0.24438033998012543 train acc 0.8848873908977556
epoch 34 train acc 0.8711279553415061


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 34 test acc 0.3374963321596244


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 35 batch id 1 loss 0.2740660309791565 train acc 0.921875
epoch 35 batch id 201 loss 0.18109743297100067 train acc 0.8935595460199005
epoch 35 batch id 401 loss 0.1964835375547409 train acc 0.8876246882793017
epoch 35 train acc 0.8751537228288864


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 35 test acc 0.3403755868544601


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 36 batch id 1 loss 0.19911588728427887 train acc 0.9296875
epoch 36 batch id 201 loss 0.15091833472251892 train acc 0.9032571517412935
epoch 36 batch id 401 loss 0.18197666108608246 train acc 0.8969568266832918
epoch 36 train acc 0.8842559654115587


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 36 test acc 0.3455289025821596


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 37 batch id 1 loss 0.29263433814048767 train acc 0.91015625
epoch 37 batch id 201 loss 0.15403716266155243 train acc 0.9065220771144279
epoch 37 batch id 401 loss 0.15862496197223663 train acc 0.9012722100997507
epoch 37 train acc 0.8910592066292367


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 37 test acc 0.34499706572769956


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 38 batch id 1 loss 0.1933164745569229 train acc 0.94921875
epoch 38 batch id 201 loss 0.2165534347295761 train acc 0.9142179726368159
epoch 38 batch id 401 loss 0.24840936064720154 train acc 0.9087632481296758
epoch 38 train acc 0.8986358909807356


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 38 test acc 0.36241930751173707


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 39 batch id 1 loss 0.17648673057556152 train acc 0.93359375
epoch 39 batch id 201 loss 0.12907296419143677 train acc 0.9172108208955224
epoch 39 batch id 401 loss 0.08711203187704086 train acc 0.9129130299251871
epoch 39 train acc 0.9048338988616462


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 39 test acc 0.36735255281690143


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 40 batch id 1 loss 0.10711000859737396 train acc 0.96875
epoch 40 batch id 201 loss 0.09336443990468979 train acc 0.9234491604477612
epoch 40 batch id 401 loss 0.18404942750930786 train acc 0.917881078553616
epoch 40 train acc 0.9103888463222417


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 40 test acc 0.3738446302816901


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 41 batch id 1 loss 0.21050387620925903 train acc 0.9453125
epoch 41 batch id 201 loss 0.16358034312725067 train acc 0.9288907027363185
epoch 41 batch id 401 loss 0.21090616285800934 train acc 0.9221964619700748
epoch 41 train acc 0.9154786011383538


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 41 test acc 0.36630721830985913


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 42 batch id 1 loss 0.13739819824695587 train acc 0.953125
epoch 42 batch id 201 loss 0.13704127073287964 train acc 0.9294542910447762
epoch 42 batch id 401 loss 0.15252196788787842 train acc 0.9256546134663342
epoch 42 train acc 0.9190701619964974


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 42 test acc 0.36837954812206575


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 43 batch id 1 loss 0.11808006465435028 train acc 0.96875
epoch 43 batch id 201 loss 0.11844206601381302 train acc 0.9358286691542289
epoch 43 batch id 401 loss 0.12420815229415894 train acc 0.9298141365336658
epoch 43 train acc 0.922812226357268


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 43 test acc 0.36870965375586856


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 44 batch id 1 loss 0.1602238267660141 train acc 0.9453125
epoch 44 batch id 201 loss 0.10741034895181656 train acc 0.9385494402985075
epoch 44 batch id 401 loss 0.22939543426036835 train acc 0.9328144482543641
epoch 44 train acc 0.9275599277583187


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 44 test acc 0.38191387910798125


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 45 batch id 1 loss 0.12739011645317078 train acc 0.953125
epoch 45 batch id 201 loss 0.0785069614648819 train acc 0.9423779539800995
epoch 45 batch id 401 loss 0.15845851600170135 train acc 0.9366719918952618
epoch 45 train acc 0.9329233253064798


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 45 test acc 0.3862602699530517


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 46 batch id 1 loss 0.08957088738679886 train acc 0.97265625
epoch 46 batch id 201 loss 0.10914883762598038 train acc 0.9450404228855721
epoch 46 batch id 401 loss 0.20106439292430878 train acc 0.939496960723192
epoch 46 train acc 0.9351808778458844


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 46 test acc 0.3882775821596244


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 47 batch id 1 loss 0.16324204206466675 train acc 0.94921875
epoch 47 batch id 201 loss 0.05922670662403107 train acc 0.9464202425373134
epoch 47 batch id 401 loss 0.11799628287553787 train acc 0.941708229426434
epoch 47 train acc 0.9376710267075307


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 47 test acc 0.3868654636150235


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 48 batch id 1 loss 0.09505659341812134 train acc 0.97265625
epoch 48 batch id 201 loss 0.06971172988414764 train acc 0.9496657338308457
epoch 48 batch id 401 loss 0.05384187400341034 train acc 0.9456631857855362
epoch 48 train acc 0.941768826619965


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 48 test acc 0.37773254107981225


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 49 batch id 1 loss 0.15324154496192932 train acc 0.953125
epoch 49 batch id 201 loss 0.05634866654872894 train acc 0.9517063121890548
epoch 49 batch id 401 loss 0.11851541697978973 train acc 0.9470172225685786
epoch 49 train acc 0.943650120402802


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 49 test acc 0.3922388497652582


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 50 batch id 1 loss 0.11813797056674957 train acc 0.9765625
epoch 50 batch id 201 loss 0.13471773266792297 train acc 0.9539606654228856
epoch 50 batch id 401 loss 0.09298986196517944 train acc 0.949608400872818
epoch 50 train acc 0.9466054619089317


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 50 test acc 0.3951914612676056


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 51 batch id 1 loss 0.06220278516411781 train acc 0.97265625
epoch 51 batch id 201 loss 0.06649383902549744 train acc 0.9556902985074627
epoch 51 batch id 401 loss 0.10105253010988235 train acc 0.9501052057356608
epoch 51 train acc 0.9476179400175131


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 51 test acc 0.39311913145539906


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 52 batch id 1 loss 0.12714733183383942 train acc 0.95703125
epoch 52 batch id 201 loss 0.07626638561487198 train acc 0.9583138992537313
epoch 52 batch id 401 loss 0.0661502480506897 train acc 0.9537581826683291
epoch 52 train acc 0.9510589973730298


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 52 test acc 0.4063233568075117


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 53 batch id 1 loss 0.028043143451213837 train acc 0.99609375
epoch 53 batch id 201 loss 0.039955493062734604 train acc 0.958877487562189
epoch 53 batch id 401 loss 0.09627261757850647 train acc 0.9548297225685786
epoch 53 train acc 0.9524135288966725


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 53 test acc 0.408487382629108


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 54 batch id 1 loss 0.11871993541717529 train acc 0.96875
epoch 54 batch id 201 loss 0.07240745425224304 train acc 0.9614233519900498
epoch 54 batch id 401 loss 0.09415408968925476 train acc 0.9562422069825436
epoch 54 train acc 0.9537064908056042


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 54 test acc 0.4076987969483568


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 55 batch id 1 loss 0.06849879771471024 train acc 0.9765625
epoch 55 batch id 201 loss 0.06565308570861816 train acc 0.9614622201492538
epoch 55 batch id 401 loss 0.05803997442126274 train acc 0.9578397755610972
epoch 55 train acc 0.9555946256567426


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 55 test acc 0.4134206279342723


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 56 batch id 1 loss 0.06138638034462929 train acc 0.97265625
epoch 56 batch id 201 loss 0.06342308223247528 train acc 0.964804881840796
epoch 56 batch id 401 loss 0.07977832853794098 train acc 0.9597880299251871
epoch 56 train acc 0.9573322570052539


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 56 test acc 0.41125660211267606


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 57 batch id 1 loss 0.08100521564483643 train acc 0.96875
epoch 57 batch id 201 loss 0.06054133176803589 train acc 0.9650769589552238
epoch 57 batch id 401 loss 0.11729680001735687 train acc 0.9611810317955112
epoch 57 train acc 0.9587757224168126


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 57 test acc 0.41859228286384975


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 58 batch id 1 loss 0.09522171318531036 train acc 0.96875
epoch 58 batch id 201 loss 0.0745605081319809 train acc 0.9664373445273632
epoch 58 batch id 401 loss 0.05822492018342018 train acc 0.9623889495012469
epoch 58 train acc 0.9604449430823118


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 58 test acc 0.4135306631455399


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 59 batch id 1 loss 0.07691887021064758 train acc 0.96875
epoch 59 batch id 201 loss 0.041947055608034134 train acc 0.9683224502487562
epoch 59 batch id 401 loss 0.12087118625640869 train acc 0.9636358322942643
epoch 59 train acc 0.9619910245183888


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 59 test acc 0.4268632629107981


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 60 batch id 1 loss 0.03526700660586357 train acc 0.98828125
epoch 60 batch id 201 loss 0.039717208594083786 train acc 0.9699354788557214
epoch 60 batch id 401 loss 0.043531592935323715 train acc 0.964454099127182
epoch 60 train acc 0.9634618542031523


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 60 test acc 0.4221867664319249


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 61 batch id 1 loss 0.03065330907702446 train acc 0.984375
epoch 61 batch id 201 loss 0.04426528513431549 train acc 0.9707322761194029
epoch 61 batch id 401 loss 0.05700656771659851 train acc 0.9668017456359103
epoch 61 train acc 0.9657946584938704


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 61 test acc 0.429137323943662


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 62 batch id 1 loss 0.03020462952554226 train acc 0.98828125
epoch 62 batch id 201 loss 0.027803895995020866 train acc 0.9740555037313433
epoch 62 batch id 401 loss 0.09027642756700516 train acc 0.9687402587281796
epoch 62 train acc 0.9673475809982487


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 62 test acc 0.42297535211267606


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 63 batch id 1 loss 0.08149438351392746 train acc 0.9765625
epoch 63 batch id 201 loss 0.04352555423974991 train acc 0.9741526741293532
epoch 63 batch id 401 loss 0.09244642406702042 train acc 0.969051979426434
epoch 63 train acc 0.9682848073555166


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 63 test acc 0.4271200117370892


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 64 batch id 1 loss 0.050388772040605545 train acc 0.9921875
epoch 64 batch id 201 loss 0.052458424121141434 train acc 0.9753575870646766
epoch 64 batch id 401 loss 0.06737179309129715 train acc 0.9709807512468828
epoch 64 train acc 0.9703371278458844


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 64 test acc 0.4366930751173709


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 65 batch id 1 loss 0.04165555536746979 train acc 0.98828125
epoch 65 batch id 201 loss 0.027341479435563087 train acc 0.9756685323383084
epoch 65 batch id 401 loss 0.0352577306330204 train acc 0.9716821228179551
epoch 65 train acc 0.9710622810858144


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 65 test acc 0.440012470657277


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 66 batch id 1 loss 0.02433394081890583 train acc 0.9921875
epoch 66 batch id 201 loss 0.032025884836912155 train acc 0.9777091106965174
epoch 66 batch id 401 loss 0.020268283784389496 train acc 0.9730166770573566
epoch 66 train acc 0.9722252626970228


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 66 test acc 0.43484081572769956


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 67 batch id 1 loss 0.02385527640581131 train acc 0.98828125
epoch 67 batch id 201 loss 0.016237309202551842 train acc 0.9793221393034826
epoch 67 batch id 401 loss 0.06965070962905884 train acc 0.9748870012468828
epoch 67 train acc 0.9742502189141856


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 67 test acc 0.4372982687793427


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 68 batch id 1 loss 0.036192767322063446 train acc 0.9921875
epoch 68 batch id 201 loss 0.071394182741642 train acc 0.9788362873134329
epoch 68 batch id 401 loss 0.09877200424671173 train acc 0.9747895885286783
epoch 68 train acc 0.9744280866900175


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 68 test acc 0.4436803110328638


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 69 batch id 1 loss 0.06002851575613022 train acc 0.98046875
epoch 69 batch id 201 loss 0.030539769679307938 train acc 0.9803327114427861
epoch 69 batch id 401 loss 0.031748153269290924 train acc 0.9763189682044888
epoch 69 train acc 0.9761588769702276


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 69 test acc 0.44843016431924887


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 70 batch id 1 loss 0.04750017076730728 train acc 0.984375
epoch 70 batch id 201 loss 0.02391498163342476 train acc 0.9815570584577115
epoch 70 batch id 401 loss 0.03895651921629906 train acc 0.9771664588528678
epoch 70 train acc 0.9769798051663747


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 70 test acc 0.4485218603286385


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 71 batch id 1 loss 0.05840767174959183 train acc 0.98828125
epoch 71 batch id 201 loss 0.011827567592263222 train acc 0.982548196517413
epoch 71 batch id 401 loss 0.02418936975300312 train acc 0.9783938591022444
epoch 71 train acc 0.9786490258318739


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 71 test acc 0.44775161384976525


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 72 batch id 1 loss 0.022333340719342232 train acc 0.9921875
epoch 72 batch id 201 loss 0.03693222999572754 train acc 0.9829951803482587
epoch 72 batch id 401 loss 0.029642274603247643 train acc 0.9784425654613467
epoch 72 train acc 0.9783138134851138


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 72 test acc 0.45286825117370894


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 73 batch id 1 loss 0.025337662547826767 train acc 0.9921875
epoch 73 batch id 201 loss 0.03162912279367447 train acc 0.9839668843283582
epoch 73 batch id 401 loss 0.019901171326637268 train acc 0.9798453086034913
epoch 73 train acc 0.979346814798599


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 73 test acc 0.4515845070422535


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 74 batch id 1 loss 0.011460144072771072 train acc 0.99609375
epoch 74 batch id 201 loss 0.010768814943730831 train acc 0.9854633084577115
epoch 74 batch id 401 loss 0.021874068304896355 train acc 0.9808194357855362
epoch 74 train acc 0.9806293138971877


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 74 test acc 0.4553257042253521


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 75 batch id 1 loss 0.013295423239469528 train acc 0.9921875
epoch 75 batch id 201 loss 0.035362306982278824 train acc 0.9856965174129353
epoch 75 batch id 401 loss 0.0313744843006134 train acc 0.981228569201995
epoch 75 train acc 0.9806671409807356


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 75 test acc 0.4579482100938967


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 76 batch id 1 loss 0.029969822615385056 train acc 0.9921875
epoch 76 batch id 201 loss 0.01636248081922531 train acc 0.9855410447761194
epoch 76 batch id 401 loss 0.0337664894759655 train acc 0.981705891521197
epoch 76 train acc 0.9814128174255692


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 76 test acc 0.4558208626760563


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 77 batch id 1 loss 0.0050796205177903175 train acc 1.0
epoch 77 batch id 201 loss 0.014868375845253468 train acc 0.9855410447761194
epoch 77 batch id 401 loss 0.03906727954745293 train acc 0.9822221789276808
epoch 77 train acc 0.9821242885288967


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 77 test acc 0.46108421361502344


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 78 batch id 1 loss 0.03405764698982239 train acc 0.9921875
epoch 78 batch id 201 loss 0.036805201321840286 train acc 0.9874455845771144
epoch 78 batch id 401 loss 0.016933556646108627 train acc 0.98285536159601
epoch 78 train acc 0.9822063813485113


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 78 test acc 0.46566901408450706


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 79 batch id 1 loss 0.033105406910181046 train acc 0.9921875
epoch 79 batch id 201 loss 0.020324978977441788 train acc 0.9881452114427861
epoch 79 batch id 401 loss 0.03439510986208916 train acc 0.9841606920199502
epoch 79 train acc 0.9835130253940455


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 79 test acc 0.46651261737089206


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 80 batch id 1 loss 0.030781211331486702 train acc 0.98828125
epoch 80 batch id 201 loss 0.00859143491834402 train acc 0.9881452114427861
epoch 80 batch id 401 loss 0.05182848870754242 train acc 0.983420355361596
epoch 80 train acc 0.9834719789842382


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 80 test acc 0.4677963615023474


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 81 batch id 1 loss 0.03835470974445343 train acc 0.98046875
epoch 81 batch id 201 loss 0.014789344742894173 train acc 0.9879120024875622
epoch 81 batch id 401 loss 0.011390969157218933 train acc 0.9840145729426434
epoch 81 train acc 0.9840329465849387


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 81 test acc 0.4675396126760563


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 82 batch id 1 loss 0.027459483593702316 train acc 0.9921875
epoch 82 batch id 201 loss 0.03364093601703644 train acc 0.9896999378109452
epoch 82 batch id 401 loss 0.04489676654338837 train acc 0.9854952462593516
epoch 82 train acc 0.9851069943082311


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 82 test acc 0.47109741784037557


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 83 batch id 1 loss 0.01139016728848219 train acc 0.99609375
epoch 83 batch id 201 loss 0.016587423160672188 train acc 0.9898359763681592
epoch 83 batch id 401 loss 0.06729397922754288 train acc 0.9856413653366584
epoch 83 train acc 0.9858116243432574


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 83 test acc 0.4732797828638498


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 84 batch id 1 loss 0.02651204913854599 train acc 0.99609375
epoch 84 batch id 201 loss 0.04673678055405617 train acc 0.9906327736318408
epoch 84 batch id 401 loss 0.036204058676958084 train acc 0.9867031639650873
epoch 84 train acc 0.9866393936077058


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 84 test acc 0.47183098591549294


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 85 batch id 1 loss 0.017416104674339294 train acc 0.99609375
epoch 85 batch id 201 loss 0.009548633359372616 train acc 0.9910214552238806
epoch 85 batch id 401 loss 0.04315720871090889 train acc 0.987044108478803
epoch 85 train acc 0.9864820490367776


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 85 test acc 0.4756455399061033


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 86 batch id 1 loss 0.027098210528492928 train acc 0.98828125
epoch 86 batch id 201 loss 0.054471705108881 train acc 0.9913518345771144
epoch 86 batch id 401 loss 0.0072231292724609375 train acc 0.9872778990024937
epoch 86 train acc 0.9870635398423818


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 86 test acc 0.47729606807511743


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 87 batch id 1 loss 0.011103828437626362 train acc 0.99609375
epoch 87 batch id 201 loss 0.013006070628762245 train acc 0.9917405161691543
epoch 87 batch id 401 loss 0.05858512222766876 train acc 0.9880182356608479
epoch 87 train acc 0.9880554947460596


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 87 test acc 0.47698430164319244


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 88 batch id 1 loss 0.03917921334505081 train acc 0.9921875
epoch 88 batch id 201 loss 0.015091443434357643 train acc 0.992129197761194
epoch 88 batch id 401 loss 0.027788346633315086 train acc 0.9882520261845387
epoch 88 train acc 0.9879870840630472


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 88 test acc 0.4789832746478873


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 89 batch id 1 loss 0.027191750705242157 train acc 0.99609375
epoch 89 batch id 201 loss 0.024708181619644165 train acc 0.992168065920398
epoch 89 batch id 401 loss 0.034733764827251434 train acc 0.9884273690773068
epoch 89 train acc 0.9882880910683012


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 89 test acc 0.47845143779342725


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 90 batch id 1 loss 0.003519469639286399 train acc 1.0
epoch 90 batch id 201 loss 0.006113994400948286 train acc 0.9926927860696517
epoch 90 batch id 401 loss 0.02333684265613556 train acc 0.9889046913965087
epoch 90 train acc 0.9891090192644484


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 90 test acc 0.48397153755868544


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 91 batch id 1 loss 0.01332548912614584 train acc 0.9921875
epoch 91 batch id 201 loss 0.016553092747926712 train acc 0.9931203358208955
epoch 91 batch id 401 loss 0.03950962796807289 train acc 0.989235894638404
epoch 91 train acc 0.9890269264448336


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 91 test acc 0.4829628814553991


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 92 batch id 1 loss 0.025156468152999878 train acc 0.9921875
epoch 92 batch id 201 loss 0.002667473629117012 train acc 0.9936644900497512
epoch 92 batch id 401 loss 0.04217345267534256 train acc 0.9892846009975063
epoch 92 train acc 0.9892047942206655


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 92 test acc 0.4835130575117371


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 93 batch id 1 loss 0.009091775864362717 train acc 0.99609375
epoch 93 batch id 201 loss 0.003493569092825055 train acc 0.9937422263681592
epoch 93 batch id 401 loss 0.00793592818081379 train acc 0.98980088840399
epoch 93 train acc 0.9897999671628721


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 93 test acc 0.4873276115023474


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 94 batch id 1 loss 0.035508736968040466 train acc 0.98828125
epoch 94 batch id 201 loss 0.056384939700365067 train acc 0.9939365671641791
epoch 94 batch id 401 loss 0.022524822503328323 train acc 0.9900249376558603
epoch 94 train acc 0.9900599277583187


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 94 test acc 0.48576877934272306


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 95 batch id 1 loss 0.01610533893108368 train acc 0.9921875
epoch 95 batch id 201 loss 0.015560867264866829 train acc 0.994130907960199
epoch 95 batch id 401 loss 0.0234413743019104 train acc 0.9904145885286783
epoch 95 train acc 0.9903677758318739


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 95 test acc 0.48571376173708924


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 96 batch id 1 loss 0.009488388895988464 train acc 1.0
epoch 96 batch id 201 loss 0.011491245590150356 train acc 0.994072605721393
epoch 96 batch id 401 loss 0.017293116077780724 train acc 0.9902976932668329
epoch 96 train acc 0.9905866900175131


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 96 test acc 0.48633729460093894


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 97 batch id 1 loss 0.011571499519050121 train acc 0.99609375
epoch 97 batch id 201 loss 0.017280692234635353 train acc 0.9946750621890548
epoch 97 batch id 401 loss 0.0038100741803646088 train acc 0.9909893235660848
epoch 97 train acc 0.9910039951838879


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 97 test acc 0.4874926643192488


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 98 batch id 1 loss 0.017861830070614815 train acc 0.99609375
epoch 98 batch id 201 loss 0.04846971109509468 train acc 0.9951026119402985
epoch 98 batch id 401 loss 0.010461388155817986 train acc 0.991232855361596
epoch 98 train acc 0.9913597307355516


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 98 test acc 0.48870305164319244


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 99 batch id 1 loss 0.03862973675131798 train acc 0.984375
epoch 99 batch id 201 loss 0.049384962767362595 train acc 0.9948111007462687
epoch 99 batch id 401 loss 0.016195004805922508 train acc 0.9909406172069826
epoch 99 train acc 0.9911818629597198


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 99 test acc 0.48928990610328643


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 100 batch id 1 loss 0.0042891851626336575 train acc 1.0
epoch 100 batch id 201 loss 0.013006637804210186 train acc 0.9950637437810945
epoch 100 batch id 401 loss 0.011097881942987442 train acc 0.9913692331670823
epoch 100 train acc 0.9914555056917689


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 100 test acc 0.49062866784037557


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 101 batch id 1 loss 0.018272971734404564 train acc 0.9921875
epoch 101 batch id 201 loss 0.003364717587828636 train acc 0.9952969527363185
epoch 101 batch id 401 loss 0.03015202097594738 train acc 0.9917881078553616
epoch 101 train acc 0.9920643607705779


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 101 test acc 0.4909404342723005


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 102 batch id 1 loss 0.03205287456512451 train acc 0.99609375
epoch 102 batch id 201 loss 0.00298166717402637 train acc 0.996054881840796
epoch 102 batch id 401 loss 0.004470945335924625 train acc 0.9922556889027432
epoch 102 train acc 0.9923722088441331


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 102 test acc 0.4917290199530516


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 103 batch id 1 loss 0.032876886427402496 train acc 0.9921875
epoch 103 batch id 201 loss 0.0171615332365036 train acc 0.9956078980099502
epoch 103 batch id 401 loss 0.028477365151047707 train acc 0.9918757793017456
epoch 103 train acc 0.9921122482486865


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 103 test acc 0.4934345657276995


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 104 batch id 1 loss 0.005764907691627741 train acc 0.99609375
epoch 104 batch id 201 loss 0.007229337934404612 train acc 0.996268656716418
epoch 104 batch id 401 loss 0.002915823133662343 train acc 0.9924407730673317
epoch 104 train acc 0.992639010507881


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 104 test acc 0.4933245305164319


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 105 batch id 1 loss 0.0069125257432460785 train acc 0.99609375
epoch 105 batch id 201 loss 0.03612801805138588 train acc 0.9959382773631841
epoch 105 batch id 401 loss 0.007510222494602203 train acc 0.9926940461346634
epoch 105 train acc 0.9929536996497373


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 105 test acc 0.49409477699530513


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 106 batch id 1 loss 0.011751909740269184 train acc 0.99609375
epoch 106 batch id 201 loss 0.0077129630371928215 train acc 0.9961714863184079
epoch 106 batch id 401 loss 0.0013227941235527396 train acc 0.9924310317955112
epoch 106 train acc 0.9927553086690017


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 106 test acc 0.4951951291079812


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 107 batch id 1 loss 0.009542961604893208 train acc 0.99609375
epoch 107 batch id 201 loss 0.016625162214040756 train acc 0.9957828047263682
epoch 107 batch id 401 loss 0.0029496815986931324 train acc 0.9928499064837906
epoch 107 train acc 0.9931520906304728


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 107 test acc 0.4957453051643192


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 108 batch id 1 loss 0.0064863539300858974 train acc 0.99609375
epoch 108 batch id 201 loss 0.0452042780816555 train acc 0.9966767723880597
epoch 108 batch id 401 loss 0.0039053841028362513 train acc 0.9928693890274314
epoch 108 train acc 0.9933025941330998


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 108 test acc 0.49684565727699526


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 109 batch id 1 loss 0.01657538115978241 train acc 0.99609375
epoch 109 batch id 201 loss 0.011562311090528965 train acc 0.9968711131840796
epoch 109 batch id 401 loss 0.003393056569620967 train acc 0.9931518859102244
epoch 109 train acc 0.993398369089317


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 109 test acc 0.49772593896713613


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 110 batch id 1 loss 0.013647891581058502 train acc 0.99609375
epoch 110 batch id 201 loss 0.0030763400718569756 train acc 0.9965796019900498
epoch 110 batch id 401 loss 0.0017527726013213396 train acc 0.9930836970074813
epoch 110 train acc 0.9933915280210157


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 110 test acc 0.4972857981220657


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 111 batch id 1 loss 0.015901723876595497 train acc 0.99609375
epoch 111 batch id 201 loss 0.02960837073624134 train acc 0.9966962064676617
epoch 111 batch id 401 loss 0.0043074446730315685 train acc 0.9930447319201995
epoch 111 train acc 0.9935657740548058


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 111 test acc 0.49827611502347413


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 112 batch id 1 loss 0.015002750791609287 train acc 0.99609375
epoch 112 batch id 201 loss 0.0021541593596339226 train acc 0.9969294154228856
epoch 112 batch id 401 loss 0.005178007762879133 train acc 0.993424641521197
epoch 112 train acc 0.9937677867775832


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 112 test acc 0.4979460093896713


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 113 batch id 1 loss 0.007607935927808285 train acc 0.99609375
epoch 113 batch id 201 loss 0.006596619263291359 train acc 0.9971237562189055
epoch 113 batch id 401 loss 0.026949351653456688 train acc 0.9932492986284289
epoch 113 train acc 0.993733581436077


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 113 test acc 0.4982210974178403


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 114 batch id 1 loss 0.004181691445410252 train acc 1.0
epoch 114 batch id 201 loss 0.024462701752781868 train acc 0.9971820584577115
epoch 114 batch id 401 loss 0.004300333559513092 train acc 0.9935025716957606
epoch 114 train acc 0.9939388134851138


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 114 test acc 0.49838615023474175


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 115 batch id 1 loss 0.006664813030511141 train acc 0.99609375
epoch 115 batch id 201 loss 0.0013954469468444586 train acc 0.9971431902985075
epoch 115 batch id 401 loss 0.010385905392467976 train acc 0.9938240336658354
epoch 115 train acc 0.9941098401926445


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 115 test acc 0.4988262910798122


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 116 batch id 1 loss 0.03601163625717163 train acc 0.99609375
epoch 116 batch id 201 loss 0.02860182337462902 train acc 0.9972986629353234
epoch 116 batch id 401 loss 0.00484533840790391 train acc 0.9936584320448878
epoch 116 train acc 0.9940756348511384


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 116 test acc 0.49844116784037557


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 117 batch id 1 loss 0.004938093945384026 train acc 1.0
epoch 117 batch id 201 loss 0.00224431324750185 train acc 0.9972014925373134
epoch 117 batch id 401 loss 0.0010604605777189136 train acc 0.993794809850374
epoch 117 train acc 0.9940277473730298


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 117 test acc 0.49838615023474175


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 118 batch id 1 loss 0.012295648455619812 train acc 0.99609375
epoch 118 batch id 201 loss 0.005576770752668381 train acc 0.9973375310945274
epoch 118 batch id 401 loss 0.024220656603574753 train acc 0.9936292082294265
epoch 118 train acc 0.9939319724168126


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 118 test acc 0.49871625586854457


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 119 batch id 1 loss 0.005406916607171297 train acc 1.0
epoch 119 batch id 201 loss 0.0020609956700354815 train acc 0.9975707400497512
epoch 119 batch id 401 loss 0.004263220354914665 train acc 0.99370713840399
epoch 119 train acc 0.994096158056042


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 119 test acc 0.49893632629107976


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 120 batch id 1 loss 0.0058753290213644505 train acc 1.0
epoch 120 batch id 201 loss 0.020781101658940315 train acc 0.9971626243781094
epoch 120 batch id 401 loss 0.005928573198616505 train acc 0.9936681733167082
epoch 120 train acc 0.9939935420315237


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 120 test acc 0.49893632629107976


In [11]:
def predict(predict_sentence):
    data = [predict_sentence, '0']
    dataset_another = [data]
    another_test = BERTDataset(dataset_another, 0, 1, tok, vocab, max_len, True, False)
    test_dataloader = torch.utils.data.DataLoader(another_test, batch_size = batch_size, num_workers = 4)
    
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        
        valid_length = valid_length
        label = label.long().to(device)
        
        out = model(token_ids, valid_length, segment_ids)
        test_eval = []
        for i in out:
            logits = i
            logits = logits.detach().cpu().numpy()
            test_eval.append([np.argmax(logits)])
            
        print(">> 입력하신 내용은 " , test_eval[0] , "입니다.")

In [12]:
predict('나 지금 매우 신난다')

>> 입력하신 내용은  [29] 입니다.
