In [None]:
import pandas as pd

def df_newlabel(df):
    """Reshape competency columns from wide to long format.

    Every column whose name contains ``'competency'`` is melted into two
    columns: ``label`` (the integer at the end of the original column name)
    and ``score`` (the value that column held). All other columns are kept
    as identifier columns and repeated for each competency.

    Parameters
    ----------
    df : pandas.DataFrame
        Wide frame with one column per competency.

    Returns
    -------
    pandas.DataFrame
        New long-format frame; ``df`` itself is not modified.

    Raises
    ------
    ValueError
        If a competency column name does not end with a digit.
    """
    import re

    is_competency = df.columns.str.contains('competency')
    melted = pd.melt(df, id_vars=df.columns[~is_competency].tolist())

    def _trailing_number(column_name):
        # Parse the full trailing number so 'competency12' maps to 12, not 2.
        match = re.search(r'(\d+)$', str(column_name))
        if match is None:
            raise ValueError(
                "competency column name must end with a number: {!r}".format(column_name))
        return int(match.group(1))

    # Replace the competency column name with its numeric id.
    melted['variable'] = [_trailing_number(name) for name in melted['variable']]

    return melted.rename(columns={'variable': 'label', 'value': 'score'})

def df_preprocessing(df, column_list: list):
    """Build ``[text, label]`` training pairs from the long-format frame.

    The ``student_assessment`` text is cleaned (every character that is not a
    digit, a Hangul character, or a space becomes a space; runs of spaces are
    then collapsed to one). The selected columns of each row are stringified
    and joined with tab characters to form the model input text. The label is
    ``score - 1``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing every column in ``column_list`` plus a ``score`` column.
    column_list : list
        Columns to join into the input text, in order. Must include
        ``'student_assessment'``.

    Returns
    -------
    list
        One ``[joined_text, score - 1]`` pair per row of ``df``, in row order.
    """
    # Work on an explicit copy so the caller's frame is never written to.
    df1 = df.loc[:, column_list].copy()

    # Strip special characters and collapse whitespace. `regex=True` is required:
    # since pandas 2.0 the default is a literal (non-regex) replacement.
    cleaned = df1['student_assessment'].str.replace("[^0-9ㄱ-ㅎㅏ-ㅣ가-힣 ]", " ", regex=True)
    df1['student_assessment'] = cleaned.str.replace(" +", " ", regex=True)

    # Iterate positionally so the result does not depend on the frame's index labels.
    column_values = [df1[name].tolist() for name in column_list]
    add_column = ['\t'.join(map(str, row)) for row in zip(*column_values)]
    score_list = (df['score'] - 1).to_numpy()

    return [[text, score] for text, score in zip(add_column, score_list)]

In [None]:
# NOTE(review): absolute Google Drive path — presumably requires Drive to be
# mounted at /content/drive in Colab; confirm before running elsewhere.
df = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/df_3points_subject_concepts.csv")

# Columns that are joined into the model input text.
column_list = ['label', 'program_category', 'mission_category', 'student_assessment']

# Competency number to use.
# NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
target_number = 1

df2 = df_newlabel(df)
dataset = df_preprocessing(df2, column_list)

# Preview the dataset; show only the first rows instead of dumping the whole list.
dataset[:5]

In [None]:
# Install the third-party packages imported by the cells below
# (gluonnlp, sentencepiece, transformers, torch, and the KoBERT tokenizer from GitHub).
# NOTE(review): only gluonnlp is version-pinned; the other packages resolve to
# whatever is current at install time, so re-runs may not be reproducible.
!pip install mxnet
!pip install gluonnlp==0.8.0
!pip install tqdm pandas
!pip install sentencepiece
!pip install transformers
!pip install torch
!pip install 'git+https://github.com/SKTBrain/KoBERT.git#egg=kobert_tokenizer&subdirectory=kobert_hf'

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

In [None]:
from kobert_tokenizer import KoBERTTokenizer
from transformers import BertModel

from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on a runtime without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class BERTSentenceTransform:
    r"""BERT style data transformation.

    Parameters
    ----------
    tokenizer : BERTTokenizer.
        Tokenizer for the sentences. Must provide ``tokenize(text)`` and
        ``convert_tokens_to_ids(tokens)``.
    max_seq_length : int.
        Maximum sequence length of the sentences.
    vocab : object
        Vocabulary exposing ``cls_token``, ``sep_token``, ``padding_token`` and
        item lookup (``vocab[token]`` -> id).
    pad : bool, default True
        Whether to pad the sentences to maximum length.
    pair : bool, default True
        Whether to transform sentences or sentence pairs.
    """

    def __init__(self, tokenizer, max_seq_length,vocab, pad=True, pair=True):
        self._tokenizer = tokenizer
        self._max_seq_length = max_seq_length
        self._pad = pad
        self._pair = pair
        self._vocab = vocab

    def __call__(self, line):
        """Perform transformation for sequence pairs or single sequences.

        The transformation is processed in the following steps:
        - tokenize the input sequences
        - insert [CLS], [SEP] as necessary
        - generate type ids to indicate whether a token belongs to the first
        sequence or the second sequence.
        - generate valid length

        For sequence pairs, the input is a tuple of 2 strings:
        text_a, text_b.

        Inputs:
            text_a: 'is this jacksonville ?'
            text_b: 'no it is not'
        Tokenization:
            text_a: 'is this jack ##son ##ville ?'
            text_b: 'no it is not .'
        Processed:
            tokens: '[CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]'
            type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
            valid_length: 14

        For single sequences, the input is a tuple of single string:
        text_a.

        Inputs:
            text_a: 'the dog is hairy .'
        Tokenization:
            text_a: 'the dog is hairy .'
        Processed:
            text_a: '[CLS] the dog is hairy . [SEP]'
            type_ids: 0     0   0   0  0     0 0
            valid_length: 7

        Parameters
        ----------
        line: tuple of str
            Input strings. For sequence pairs, the input is a tuple of 2 strings:
            (text_a, text_b). For single sequences, the input is a tuple of single
            string: (text_a,).

        Returns
        -------
        np.array: input token ids in 'int32'; length ``max_seq_length`` when
            padding is enabled, otherwise the valid length
        np.array: valid length as a 0-d 'int32' array
        np.array: input token type ids in 'int32', same length as the token ids

        """

        text_a = line[0]
        text_b = None
        if self._pair:
            assert len(line) == 2
            text_b = line[1]

        tokens_a = self._tokenizer.tokenize(text_a)
        tokens_b = None

        if self._pair:
            # Tokenize with `.tokenize` like text_a; calling the tokenizer object
            # directly does not return a plain list of tokens.
            tokens_b = self._tokenizer.tokenize(text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            self._truncate_seq_pair(tokens_a, tokens_b,
                                    self._max_seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > self._max_seq_length - 2:
                tokens_a = tokens_a[0:(self._max_seq_length - 2)]

        # The embedding vectors for `type=0` and `type=1` were learned during
        # pre-training and are added to the wordpiece embedding vector
        # (and position vector). This is not *strictly* necessary since
        # the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.

        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        vocab = self._vocab
        tokens = []
        tokens.append(vocab.cls_token)
        tokens.extend(tokens_a)
        tokens.append(vocab.sep_token)
        segment_ids = [0] * len(tokens)

        if tokens_b:
            tokens.extend(tokens_b)
            tokens.append(vocab.sep_token)
            segment_ids.extend([1] * (len(tokens) - len(segment_ids)))

        input_ids = self._tokenizer.convert_tokens_to_ids(tokens)

        # The valid length of sentences. Only real  tokens are attended to.
        valid_length = len(input_ids)

        if self._pad:
            # Pad up to the sequence length with the vocabulary's padding id.
            padding_length = self._max_seq_length - valid_length
            input_ids.extend([vocab[vocab.padding_token]] * padding_length)
            segment_ids.extend([0] * padding_length)

        return np.array(input_ids, dtype='int32'), np.array(valid_length, dtype='int32'),\
            np.array(segment_ids, dtype='int32')

    def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        """Truncate a token pair in place until their total length fits ``max_length``.

        Tokens are removed one at a time from the end of whichever list is
        currently longer (from ``tokens_b`` on ties).
        """
        while len(tokens_a) + len(tokens_b) > max_length:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

# https://blog.naver.com/newyearchive/223097878715

In [None]:
class BERTDataset(Dataset):
    """Map-style dataset of BERT-encoded sentences with their labels.

    Each item of ``dataset`` is indexed with ``sent_idx`` to obtain the text and
    with ``label_idx`` to obtain the label. The text is encoded once, at
    construction time, with ``BERTSentenceTransform``.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        encode = BERTSentenceTransform(
            bert_tokenizer,
            max_seq_length=max_len,
            vocab=vocab,
            pad=pad,
            pair=pair,
        )

        # Each text is wrapped in a one-element list, the input form the transform indexes into.
        self.sentences = []
        for row in dataset:
            self.sentences.append(encode([row[sent_idx]]))

        self.labels = []
        for row in dataset:
            self.labels.append(np.int32(row[label_idx]))

    def __getitem__(self, i):
        # Encoded features followed by the label as the last element.
        return (*self.sentences[i], self.labels[i])

    def __len__(self):
        return len(self.labels)


In [None]:
from imblearn.over_sampling import SMOTE

class BERTOverDataset(Dataset):
    """BERT-encoded dataset that is oversampled with SMOTE at construction time.

    The texts are encoded with ``BERTSentenceTransform``, then the token-id
    vectors and labels are passed to ``SMOTE.fit_resample``. For every resampled
    vector the valid length is recomputed as the number of entries that differ
    from the padding id, and the segment ids are all zeros.

    NOTE(review): the token-id vectors are presumably expected to share one
    length (``pad=True``) for resampling to work — confirm. Resampled vectors
    may also hold non-integer values; the training loop casts them with
    ``.long()`` — confirm this is intended.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        transform = BERTSentenceTransform(bert_tokenizer, max_seq_length=max_len,vocab=vocab, pad=pad, pair=pair)
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

        # Only the token-id vectors (first element of each encoding) are resampled.
        self.sentences1 = [encoded[0] for encoded in self.sentences]

        smote = SMOTE(random_state = 2022)
        self.sen1_over, self.labels_over = smote.fit_resample(self.sentences1, self.labels)

        # Use the same padding id the transform pads with instead of a hard-coded 1.
        pad_id = vocab[vocab.padding_token]

        self.sen1_over = [np.array(sen) for sen in self.sen1_over]
        # Valid length = number of non-padding positions.
        self.sen2_over = [np.count_nonzero(sen != pad_id) for sen in self.sen1_over]
        self.sen3_over = [np.zeros(max_len) for _ in self.sen1_over]

        self.sen_over = [
            (token_ids, np.array(length), segment_ids)
            for token_ids, length, segment_ids in zip(self.sen1_over, self.sen2_over, self.sen3_over)
        ]

    def __getitem__(self, i):
        # (token_ids, valid_length, segment_ids, label)
        return (self.sen_over[i] + (self.labels_over[i], ))

    def __len__(self):
        return (len(self.labels_over))

In [None]:
# Load the KoBERT tokenizer and the pretrained BERT weights from the
# 'skt/kobert-base-v1' checkpoint.
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
# return_dict=False: BERTClassifier.forward unpacks the model output as a
# two-element tuple (`_, pooler`).
bertmodel = BertModel.from_pretrained('skt/kobert-base-v1', return_dict=False)
# Build the vocabulary from the tokenizer's SentencePiece file, with '[PAD]' as padding token.
vocab = nlp.vocab.BERTVocab.from_sentencepiece(tokenizer.vocab_file, padding_token='[PAD]')

In [None]:
# Setting parameters
max_len = 256  # maximum sequence length passed to the dataset classes
batch_size = 40  # batch size of both DataLoaders
warmup_ratio = 0.1  # fraction of all training steps used as scheduler warmup
num_epochs = 50  # number of passes over the training DataLoader
max_grad_norm = 1  # gradient-norm clipping threshold used in the training loop
log_interval = 200  # print training loss/accuracy every this many batches
learning_rate =  5e-5  # learning rate passed to the optimizer

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the [text, label] pairs for evaluation; fixed seed for a repeatable split.
# NOTE(review): the split is not stratified by label — confirm this is intended.
dataset_train, dataset_test = train_test_split(dataset, test_size=0.2, shuffle=True, random_state=34)

In [None]:
# Positional arguments: sent_idx=0 (text), label_idx=1 (label), pad=True, pair=False.
# Only the training split is oversampled (BERTOverDataset); the test split is encoded as-is.
data_train = BERTOverDataset(dataset_train, 0, 1, tokenizer, vocab, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tokenizer, vocab, max_len, True, False)

In [None]:
# Batch both datasets with 5 worker processes each.
# NOTE(review): the test loader is shuffled as well; shuffling is presumably
# unnecessary for evaluation — confirm.
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5, shuffle=True)

In [None]:
class BERTClassifier(nn.Module):
    """Sentence classifier: a BERT encoder followed by optional dropout and a linear layer.

    Parameters
    ----------
    bert : nn.Module
        Encoder called with ``input_ids``, ``token_type_ids`` and
        ``attention_mask``; its output is unpacked as a two-element tuple whose
        second element (the pooled output) is classified.
    hidden_size : int, default 768
        Size of the pooled output, i.e. the linear layer's input features.
    num_classes : int, default 3
        Number of output classes.
    dr_rate : float or None, default None
        Dropout probability applied to the pooled output; falsy disables dropout.
    params : any, default None
        Unused; kept for interface compatibility.
    """

    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=3,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Return a float mask shaped like ``token_ids``: 1 for the first
        ``valid_length[i]`` positions of row ``i``, 0 elsewhere."""
        lengths = torch.as_tensor(valid_length, device=token_ids.device).long().view(-1, 1)
        positions = torch.arange(token_ids.size(1), device=token_ids.device).unsqueeze(0)
        return (positions < lengths).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return the class logits for a batch of encoded sentences."""
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask)
        # Without dropout the pooled output is classified directly
        # (previously `out` was left undefined in that case).
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)

In [None]:
# Wrap the pretrained encoder in the classifier (dropout 0.5 on the pooled output)
# and move the whole model to the selected device.
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)

In [None]:
# Prepare the optimizer parameter groups (the schedule itself is a cosine schedule with warmup).
# Parameters whose name contains 'bias' or 'LayerNorm.weight' get no weight decay;
# every other parameter uses a weight decay of 0.01.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [None]:
# NOTE(review): AdamW is imported from `transformers` here; that class is
# reportedly deprecated in recent transformers releases — confirm, and consider
# `torch.optim.AdamW`.
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Total number of optimizer steps: one per training batch, for every epoch.
t_total = len(train_dataloader) * num_epochs
# The first `warmup_ratio` fraction of those steps is used for warmup.
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [None]:
def calc_accuracy(X,Y):
    """Return the fraction of rows of ``X`` whose highest-scoring column equals ``Y``.

    ``X`` holds one row of class scores per sample and ``Y`` the true class
    indices. The result is a NumPy scalar.
    """
    _, predicted = X.max(dim=1)
    n_samples = predicted.size(0)
    n_correct = (predicted == Y).sum().detach().cpu().numpy()
    return n_correct / n_samples

In [None]:
# `tqdm.tqdm_notebook` is deprecated in favour of `tqdm.notebook.tqdm`.
from tqdm.notebook import tqdm as notebook_tqdm

# Per-epoch metrics, one entry appended per epoch.
train_history = []        # mean of the per-batch training accuracies
test_history = []         # mean of the per-batch test accuracies
train_loss_history = []   # loss of the LAST training batch of the epoch
test_loss_history = []    # mean of the per-batch test losses

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    test_loss = 0.0

    # ---- training ----
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(notebook_tqdm(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    train_history.append(train_acc / (batch_id+1))
    train_loss_history.append(loss.data.cpu().numpy())

    # ---- evaluation ----
    model.eval()
    # No gradients are needed for evaluation; skipping them avoids building autograd graphs.
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(notebook_tqdm(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            # Compute the loss on the test batch; previously the last training
            # loss was recorded as the test loss.
            test_loss += loss_fn(out, label).item()
            test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    test_history.append(test_acc / (batch_id+1))
    test_loss_history.append(test_loss / (batch_id+1))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 1.2036504745483398 train acc 0.275
epoch 1 batch id 201 loss 0.9596861600875854 train acc 0.41293532338308453
epoch 1 batch id 401 loss 0.9353935122489929 train acc 0.4930798004987531
epoch 1 batch id 601 loss 0.755039393901825 train acc 0.5271214642262895
epoch 1 batch id 801 loss 0.7166288495063782 train acc 0.5612671660424466
epoch 1 batch id 1001 loss 0.6408514976501465 train acc 0.5860639360639364
epoch 1 train acc 0.6034582882774645


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 1 test acc 0.6941397849462366


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 0.5409797430038452 train acc 0.775
epoch 2 batch id 201 loss 0.6779986023902893 train acc 0.7094527363184081
epoch 2 batch id 401 loss 0.5835703611373901 train acc 0.7084164588528675
epoch 2 batch id 601 loss 0.532719075679779 train acc 0.7087354409317795
epoch 2 batch id 801 loss 0.5851322412490845 train acc 0.710923845193507
epoch 2 batch id 1001 loss 0.6342829465866089 train acc 0.7137612387612379
epoch 2 train acc 0.7160969602306854


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 2 test acc 0.7483333333333335


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 0.4878620207309723 train acc 0.725
epoch 3 batch id 201 loss 0.4472256600856781 train acc 0.7398009950248763
epoch 3 batch id 401 loss 0.5891748070716858 train acc 0.7440149625935165
epoch 3 batch id 601 loss 0.617524266242981 train acc 0.7468386023294511
epoch 3 batch id 801 loss 0.5051108002662659 train acc 0.7469725343320845
epoch 3 batch id 1001 loss 0.3790585994720459 train acc 0.7491258741258733
epoch 3 train acc 0.750598742440626


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 3 test acc 0.7739784946236556


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 0.511139988899231 train acc 0.75
epoch 4 batch id 201 loss 0.4526115357875824 train acc 0.7695273631840797
epoch 4 batch id 401 loss 0.4833505153656006 train acc 0.773628428927681
epoch 4 batch id 601 loss 0.6087675094604492 train acc 0.7678036605657241
epoch 4 batch id 801 loss 0.5608905553817749 train acc 0.7672908863920107
epoch 4 batch id 1001 loss 0.5688937306404114 train acc 0.7667582417582423
epoch 4 train acc 0.7672894389042422


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 4 test acc 0.7771505376344089


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 0.5206261873245239 train acc 0.775
epoch 5 batch id 201 loss 0.6392233967781067 train acc 0.7853233830845774
epoch 5 batch id 401 loss 0.45408838987350464 train acc 0.7870947630922694
epoch 5 batch id 601 loss 0.6224619150161743 train acc 0.7838186356073211
epoch 5 batch id 801 loss 0.41615191102027893 train acc 0.784987515605493
epoch 5 batch id 1001 loss 0.4623939096927643 train acc 0.7805194805194802
epoch 5 train acc 0.7806239737274215


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 5 test acc 0.7868817204301076


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 0.32663583755493164 train acc 0.95
epoch 6 batch id 201 loss 0.45207539200782776 train acc 0.8092039800995023
epoch 6 batch id 401 loss 0.35141631960868835 train acc 0.8084164588528677
epoch 6 batch id 601 loss 0.32077234983444214 train acc 0.8057404326123123
epoch 6 batch id 801 loss 0.38505715131759644 train acc 0.8046192259675415
epoch 6 batch id 1001 loss 0.477282851934433 train acc 0.8044455544455555
epoch 6 train acc 0.8040199847811296


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 6 test acc 0.7781720430107526


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 0.5208407640457153 train acc 0.775
epoch 7 batch id 201 loss 0.2501903772354126 train acc 0.846766169154229
epoch 7 batch id 401 loss 0.3872324824333191 train acc 0.8415211970074815
epoch 7 batch id 601 loss 0.45837411284446716 train acc 0.8389767054908488
epoch 7 batch id 801 loss 0.5234832167625427 train acc 0.836173533083645
epoch 7 batch id 1001 loss 0.2841983735561371 train acc 0.8330919080919077
epoch 7 train acc 0.8323841563538783


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 7 test acc 0.7725806451612903


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 0.17031648755073547 train acc 0.95
epoch 8 batch id 201 loss 0.35196664929389954 train acc 0.8631840796019902
epoch 8 batch id 401 loss 0.35347187519073486 train acc 0.8637780548628433
epoch 8 batch id 601 loss 0.39689499139785767 train acc 0.8648918469217975
epoch 8 batch id 801 loss 0.3592988848686218 train acc 0.862921348314607
epoch 8 batch id 1001 loss 0.2805764377117157 train acc 0.8623126873126874
epoch 8 train acc 0.8621590772557974


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 8 test acc 0.7796236559139786


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 0.1309303641319275 train acc 1.0
epoch 9 batch id 201 loss 0.16295352578163147 train acc 0.8919154228855725
epoch 9 batch id 401 loss 0.19008783996105194 train acc 0.8923316708229426
epoch 9 batch id 601 loss 0.4680912494659424 train acc 0.8932612312811975
epoch 9 batch id 801 loss 0.35133832693099976 train acc 0.8926966292134828
epoch 9 batch id 1001 loss 0.15099407732486725 train acc 0.8908591408591403
epoch 9 train acc 0.8905632984901265


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 9 test acc 0.759408602150538


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 0.4008867144584656 train acc 0.85
epoch 10 batch id 201 loss 0.20446555316448212 train acc 0.926990049751244
epoch 10 batch id 401 loss 0.07428698986768723 train acc 0.9241895261845383
epoch 10 batch id 601 loss 0.1238425150513649 train acc 0.9247088186356073
epoch 10 batch id 801 loss 0.15516944229602814 train acc 0.9224406991260915
epoch 10 batch id 1001 loss 0.3347626030445099 train acc 0.9204295704295694
epoch 10 train acc 0.9192638872201516


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 10 test acc 0.7813978494623657


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 11 batch id 1 loss 0.12113456428050995 train acc 0.95
epoch 11 batch id 201 loss 0.06698112934827805 train acc 0.9404228855721387
epoch 11 batch id 401 loss 0.1512073427438736 train acc 0.9380922693266832
epoch 11 batch id 601 loss 0.0846952572464943 train acc 0.9376871880199678
epoch 11 batch id 801 loss 0.1299075335264206 train acc 0.9351747815230969
epoch 11 batch id 1001 loss 0.14926239848136902 train acc 0.9346653346653365
epoch 11 train acc 0.933877007489288


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 11 test acc 0.7394623655913981


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 12 batch id 1 loss 0.14149686694145203 train acc 0.975
epoch 12 batch id 201 loss 0.039885155856609344 train acc 0.953606965174128
epoch 12 batch id 401 loss 0.19454821944236755 train acc 0.95286783042394
epoch 12 batch id 601 loss 0.168752983212471 train acc 0.9503327787021645
epoch 12 batch id 801 loss 0.08864608407020569 train acc 0.9490012484394537
epoch 12 batch id 1001 loss 0.24766261875629425 train acc 0.9477272727272757
epoch 12 train acc 0.9472475870078921


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 12 test acc 0.772311827956989


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 13 batch id 1 loss 0.0976996123790741 train acc 0.975
epoch 13 batch id 201 loss 0.09068536013364792 train acc 0.9610696517412924
epoch 13 batch id 401 loss 0.1145225390791893 train acc 0.9603491271820457
epoch 13 batch id 601 loss 0.251123309135437 train acc 0.9579866888519167
epoch 13 batch id 801 loss 0.05125683546066284 train acc 0.9568352059925151
epoch 13 batch id 1001 loss 0.04151182994246483 train acc 0.9554695304695376
epoch 13 train acc 0.9553886819656427


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 13 test acc 0.7546236559139783


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 14 batch id 1 loss 0.02200383134186268 train acc 1.0
epoch 14 batch id 201 loss 0.19382189214229584 train acc 0.9648009950248737
epoch 14 batch id 401 loss 0.02460658736526966 train acc 0.963715710723192
epoch 14 batch id 601 loss 0.07713816314935684 train acc 0.9623128119800362
epoch 14 batch id 801 loss 0.029821038246154785 train acc 0.961548064918857
epoch 14 batch id 1001 loss 0.19607213139533997 train acc 0.9599650349650412
epoch 14 train acc 0.9597560975609801


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 14 test acc 0.7771505376344089


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 15 batch id 1 loss 0.11177031695842743 train acc 0.95
epoch 15 batch id 201 loss 0.07705436646938324 train acc 0.9692786069651724
epoch 15 batch id 401 loss 0.02000303566455841 train acc 0.9668329177057364
epoch 15 batch id 601 loss 0.07072930037975311 train acc 0.9672628951747123
epoch 15 batch id 801 loss 0.2073136270046234 train acc 0.9669475655430767
epoch 15 batch id 1001 loss 0.2438034564256668 train acc 0.9658341658341726
epoch 15 train acc 0.9654962153069846


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 15 test acc 0.7675268817204303


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 16 batch id 1 loss 0.0732702910900116 train acc 0.975
epoch 16 batch id 201 loss 0.008828194811940193 train acc 0.9743781094527352
epoch 16 batch id 401 loss 0.07232610881328583 train acc 0.9726309226932678
epoch 16 batch id 601 loss 0.09124588221311569 train acc 0.970923460898506
epoch 16 batch id 801 loss 0.1112293154001236 train acc 0.9691947565543135
epoch 16 batch id 1001 loss 0.05422382429242134 train acc 0.9688311688311769
epoch 16 train acc 0.9678531378909904


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 16 test acc 0.7610215053763442


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 17 batch id 1 loss 0.03753627836704254 train acc 0.975
epoch 17 batch id 201 loss 0.05875714495778084 train acc 0.9766169154228841
epoch 17 batch id 401 loss 0.006132944021373987 train acc 0.9738154613466342
epoch 17 batch id 601 loss 0.12426288425922394 train acc 0.9747088186356104
epoch 17 batch id 801 loss 0.09371401369571686 train acc 0.9746254681648003
epoch 17 batch id 1001 loss 0.19432172179222107 train acc 0.9737262737262814
epoch 17 train acc 0.9731917577796502


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 17 test acc 0.7690860215053762


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 18 batch id 1 loss 0.015900786966085434 train acc 1.0
epoch 18 batch id 201 loss 0.003027424681931734 train acc 0.9808457711442772
epoch 18 batch id 401 loss 0.04822690784931183 train acc 0.9779301745635924
epoch 18 batch id 601 loss 0.021907588467001915 train acc 0.9779534109817019
epoch 18 batch id 801 loss 0.17949743568897247 train acc 0.9767166042447006
epoch 18 batch id 1001 loss 0.014859241433441639 train acc 0.9761988011988086
epoch 18 train acc 0.975567703952905


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 18 test acc 0.7680107526881718


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 19 batch id 1 loss 0.025744367390871048 train acc 1.0
epoch 19 batch id 201 loss 0.19271469116210938 train acc 0.9778606965174115
epoch 19 batch id 401 loss 0.020069705322384834 train acc 0.9785536159601009
epoch 19 batch id 601 loss 0.006051437463611364 train acc 0.9792429284525832
epoch 19 batch id 801 loss 0.06414724886417389 train acc 0.9795568039950127
epoch 19 batch id 1001 loss 0.22443966567516327 train acc 0.9788461538461611
epoch 19 train acc 0.9784272497897428


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 19 test acc 0.7584946236559139


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 20 batch id 1 loss 0.09310594201087952 train acc 0.95
epoch 20 batch id 201 loss 0.010314619168639183 train acc 0.9809701492537299
epoch 20 batch id 401 loss 0.002484747441485524 train acc 0.9821695760598511
epoch 20 batch id 601 loss 0.07194430381059647 train acc 0.9819467554076576
epoch 20 batch id 801 loss 0.07378576695919037 train acc 0.9817103620474464
epoch 20 batch id 1001 loss 0.21043142676353455 train acc 0.9813936063936136
epoch 20 train acc 0.9805508830950412


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 20 test acc 0.7715591397849458


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 21 batch id 1 loss 0.012612788006663322 train acc 1.0
epoch 21 batch id 201 loss 0.0025513870641589165 train acc 0.9856965174129342
epoch 21 batch id 401 loss 0.041329219937324524 train acc 0.984413965087283
epoch 21 batch id 601 loss 0.19196437299251556 train acc 0.9839434276206365
epoch 21 batch id 801 loss 0.013666734099388123 train acc 0.9830524344569355
epoch 21 batch id 1001 loss 0.2551007866859436 train acc 0.9824675324675403
epoch 21 train acc 0.982590412111022


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 21 test acc 0.772741935483871


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 22 batch id 1 loss 0.0014850596198812127 train acc 1.0
epoch 22 batch id 201 loss 0.007311646826565266 train acc 0.9883084577114415
epoch 22 batch id 401 loss 0.04858039319515228 train acc 0.9870947630922706
epoch 22 batch id 601 loss 0.016771910712122917 train acc 0.9869800332778739
epoch 22 batch id 801 loss 0.04946397989988327 train acc 0.9865480649188574
epoch 22 batch id 1001 loss 0.10990983247756958 train acc 0.9862137862137934
epoch 22 train acc 0.9855550883095072


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 22 test acc 0.766989247311828


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 23 batch id 1 loss 0.005126496776938438 train acc 1.0
epoch 23 batch id 201 loss 0.013599099591374397 train acc 0.988184079601989
epoch 23 batch id 401 loss 0.13753211498260498 train acc 0.9869700748129688
epoch 23 batch id 601 loss 0.16062749922275543 train acc 0.9865224625623994
epoch 23 batch id 801 loss 0.000981172313913703 train acc 0.9864232209737878
epoch 23 batch id 1001 loss 0.0186141524463892 train acc 0.9856643356643418
epoch 23 train acc 0.9855130361648468


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 23 test acc 0.7693548387096778


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 24 batch id 1 loss 0.005518420599400997 train acc 1.0
epoch 24 batch id 201 loss 0.0775839313864708 train acc 0.9888059701492526
epoch 24 batch id 401 loss 0.004842668771743774 train acc 0.9889650872817962
epoch 24 batch id 601 loss 0.09618563205003738 train acc 0.9889351081530813
epoch 24 batch id 801 loss 0.0011709454702213407 train acc 0.988171036204749
epoch 24 batch id 1001 loss 0.008606022223830223 train acc 0.9879620379620443
epoch 24 train acc 0.9880992430613987


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 24 test acc 0.7612365591397849


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 25 batch id 1 loss 0.004470780026167631 train acc 1.0
epoch 25 batch id 201 loss 0.0009230317664332688 train acc 0.9895522388059691
epoch 25 batch id 401 loss 0.11395076662302017 train acc 0.9889650872817969
epoch 25 batch id 601 loss 0.0022625592537224293 train acc 0.9887271214642299
epoch 25 batch id 801 loss 0.0018816431984305382 train acc 0.9887016229712909
epoch 25 batch id 1001 loss 0.00276264944113791 train acc 0.9883366633366697
epoch 25 train acc 0.9882063759061253


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 25 test acc 0.7718817204301072


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 26 batch id 1 loss 0.001950285048224032 train acc 1.0
epoch 26 batch id 201 loss 0.01277661882340908 train acc 0.9939054726368153
epoch 26 batch id 401 loss 0.0007568714208900928 train acc 0.9923940149625949
epoch 26 batch id 601 loss 0.0481402613222599 train acc 0.9915557404326159
epoch 26 batch id 801 loss 0.006924021057784557 train acc 0.9917602996254725
epoch 26 batch id 1001 loss 0.004771282896399498 train acc 0.9913086913086965
epoch 26 train acc 0.9908536585365869


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 26 test acc 0.7598924731182792


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 27 batch id 1 loss 0.002027668757364154 train acc 1.0
epoch 27 batch id 201 loss 0.0006448364583775401 train acc 0.9919154228855711
epoch 27 batch id 401 loss 0.0832677111029625 train acc 0.9921446384039905
epoch 27 batch id 601 loss 0.00165369245223701 train acc 0.9921381031613997
epoch 27 batch id 801 loss 0.1346733570098877 train acc 0.9922908863920131
epoch 27 batch id 1001 loss 0.005566856358200312 train acc 0.992332667332671
epoch 27 train acc 0.9924306139613136


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 27 test acc 0.7701075268817204


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 28 batch id 1 loss 0.0012072976678609848 train acc 1.0
epoch 28 batch id 201 loss 0.00040104746585711837 train acc 0.9911691542288549
epoch 28 batch id 401 loss 0.0004010201955679804 train acc 0.9917705735660858
epoch 28 batch id 601 loss 0.044019270688295364 train acc 0.9915557404326153
epoch 28 batch id 801 loss 0.01639929786324501 train acc 0.9914169787765335
epoch 28 batch id 1001 loss 0.012869101949036121 train acc 0.9913586413586462
epoch 28 train acc 0.9918839360807425


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 28 test acc 0.7745698924731185


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 29 batch id 1 loss 0.0003902489843312651 train acc 1.0
epoch 29 batch id 201 loss 0.00035533306072466075 train acc 0.9962686567164175
epoch 29 batch id 401 loss 0.0018754201009869576 train acc 0.9953241895261858
epoch 29 batch id 601 loss 0.0008952564676292241 train acc 0.9948419301164749
epoch 29 batch id 801 loss 0.17259788513183594 train acc 0.9948501872659206
epoch 29 batch id 1001 loss 0.0009158904431387782 train acc 0.994630369630373
epoch 29 train acc 0.9945332211942814


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 29 test acc 0.7775268817204299


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 30 batch id 1 loss 0.00046018854482099414 train acc 1.0
epoch 30 batch id 201 loss 0.000524514471180737 train acc 0.9937810945273623
epoch 30 batch id 401 loss 0.06485003978013992 train acc 0.9935162094763101
epoch 30 batch id 601 loss 0.0002817701024468988 train acc 0.9941763727121481
epoch 30 batch id 801 loss 0.0021368064917623997 train acc 0.9937578027465697
epoch 30 batch id 1001 loss 0.1301524043083191 train acc 0.9937812187812222
epoch 30 train acc 0.9937762825904128


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 30 test acc 0.7736021505376344


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 31 batch id 1 loss 0.022804081439971924 train acc 0.975
epoch 31 batch id 201 loss 0.0005275590810924768 train acc 0.9958955223880593
epoch 31 batch id 401 loss 0.08976946771144867 train acc 0.9948877805486293
epoch 31 batch id 601 loss 0.000279759697150439 train acc 0.9948003327787043
epoch 31 batch id 801 loss 0.0006433557136915624 train acc 0.9948813982521875
epoch 31 batch id 1001 loss 0.0024276121985167265 train acc 0.9948051948051979
epoch 31 train acc 0.9945962994112704


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 31 test acc 0.7733333333333338


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 32 batch id 1 loss 0.007338608615100384 train acc 1.0
epoch 32 batch id 201 loss 0.0003093223203904927 train acc 0.9962686567164174
epoch 32 batch id 401 loss 0.00014772245776839554 train acc 0.9962593516209484
epoch 32 batch id 601 loss 0.00027103902539238334 train acc 0.9957986688851933
epoch 32 batch id 801 loss 0.00019064932712353766 train acc 0.9956616729088663
epoch 32 batch id 1001 loss 0.0022347024641931057 train acc 0.9955294705294734
epoch 32 train acc 0.9954183187152076


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 32 test acc 0.7722580645161291


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 33 batch id 1 loss 0.004656774457544088 train acc 1.0
epoch 33 batch id 201 loss 0.00023477683134842664 train acc 0.995522388059701
epoch 33 batch id 401 loss 0.00013717389083467424 train acc 0.9958229426433921
epoch 33 batch id 601 loss 0.012105573900043964 train acc 0.9960898502495851
epoch 33 batch id 801 loss 0.0032874825410544872 train acc 0.996098626716606
epoch 33 batch id 1001 loss 0.0002802274248097092 train acc 0.9960789210789232
epoch 33 train acc 0.9961942809083266


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 33 test acc 0.7761827956989249


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 34 batch id 1 loss 0.00017907431174535304 train acc 1.0
epoch 34 batch id 201 loss 0.0001551502209622413 train acc 0.9957711442786062
epoch 34 batch id 401 loss 0.1940433233976364 train acc 0.9956982543640905
epoch 34 batch id 601 loss 0.0003089165547862649 train acc 0.9957570715474228
epoch 34 batch id 801 loss 0.00014437919890042394 train acc 0.9956616729088663
epoch 34 batch id 1001 loss 9.660990326665342e-05 train acc 0.9959040959040983
epoch 34 train acc 0.9956896551724141


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 34 test acc 0.7750537634408603


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 35 batch id 1 loss 0.001206702901981771 train acc 1.0
epoch 35 batch id 201 loss 0.0002384663966950029 train acc 0.997512437810945
epoch 35 batch id 401 loss 0.00015314322081394494 train acc 0.9968827930174571
epoch 35 batch id 601 loss 0.04460858926177025 train acc 0.9966306156406005
epoch 35 batch id 801 loss 0.0001451468124287203 train acc 0.9964731585518123
epoch 35 batch id 1001 loss 0.00503289420157671 train acc 0.9964535464535488
epoch 35 train acc 0.9964465937762836


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 35 test acc 0.7756451612903229


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 36 batch id 1 loss 0.18620240688323975 train acc 0.975
epoch 36 batch id 201 loss 0.0001274961541639641 train acc 0.9971393034825868
epoch 36 batch id 401 loss 9.64366554399021e-05 train acc 0.9968204488778063
epoch 36 batch id 601 loss 8.99368678801693e-05 train acc 0.9970881863560743
epoch 36 batch id 801 loss 0.00020448029681574553 train acc 0.9970349563046208
epoch 36 batch id 1001 loss 0.00013263363507576287 train acc 0.9968781218781239
epoch 36 train acc 0.9969512195121957


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 36 test acc 0.7780107526881721


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 37 batch id 1 loss 8.64663888933137e-05 train acc 1.0
epoch 37 batch id 201 loss 0.00012455807882361114 train acc 0.9981343283582088
epoch 37 batch id 401 loss 0.0002886579604819417 train acc 0.9973192019950133
epoch 37 batch id 601 loss 8.638575673103333e-05 train acc 0.9973377703826968
epoch 37 batch id 801 loss 0.0002710545377340168 train acc 0.9971285892634225
epoch 37 batch id 1001 loss 8.498736133333296e-05 train acc 0.997102897102899
epoch 37 train acc 0.9971614802354924


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 37 test acc 0.7731182795698922


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 38 batch id 1 loss 0.0018535822164267302 train acc 1.0
epoch 38 batch id 201 loss 0.0001248065527761355 train acc 0.9968905472636812
epoch 38 batch id 401 loss 7.417081360472366e-05 train acc 0.9972568578553621
epoch 38 batch id 601 loss 7.967841520439833e-05 train acc 0.9974209650582373
epoch 38 batch id 801 loss 7.431393896695226e-05 train acc 0.997284644194758
epoch 38 batch id 1001 loss 8.568206976633519e-05 train acc 0.9971028971028989
epoch 38 train acc 0.9972245584524815


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 38 test acc 0.7752150537634409


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 39 batch id 1 loss 0.00017027980356942862 train acc 1.0
epoch 39 batch id 201 loss 6.160798511700705e-05 train acc 0.9976368159203978
epoch 39 batch id 401 loss 0.008487144485116005 train acc 0.9973815461346643
epoch 39 batch id 601 loss 6.759454117855057e-05 train acc 0.9975457570715487
epoch 39 batch id 801 loss 0.00791633129119873 train acc 0.9974094881398269
epoch 39 batch id 1001 loss 0.0089622363448143 train acc 0.9974275724275742
epoch 39 train acc 0.9974348191757786


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 39 test acc 0.7817741935483874


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 40 batch id 1 loss 6.339223182294518e-05 train acc 1.0
epoch 40 batch id 201 loss 6.536691944347695e-05 train acc 0.998009950248756
epoch 40 batch id 401 loss 5.743841757066548e-05 train acc 0.9979426433915216
epoch 40 batch id 601 loss 0.00035268370993435383 train acc 0.9980033277870225
epoch 40 batch id 801 loss 0.00020731303084176034 train acc 0.9980649188514368
epoch 40 batch id 1001 loss 0.00018962760805152357 train acc 0.9980269730269743
epoch 40 train acc 0.9980866274179986


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 40 test acc 0.7712903225806451


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 41 batch id 1 loss 0.00017333266441710293 train acc 1.0
epoch 41 batch id 201 loss 5.579069329542108e-05 train acc 0.9981343283582087
epoch 41 batch id 401 loss 5.615159898297861e-05 train acc 0.99781795511222
epoch 41 batch id 601 loss 0.0002799284993670881 train acc 0.9974625623960081
epoch 41 batch id 801 loss 7.389960228465497e-05 train acc 0.9977215980024983
epoch 41 batch id 1001 loss 8.459585660602897e-05 train acc 0.9977022977022992
epoch 41 train acc 0.9977712363330535


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 41 test acc 0.7754301075268815


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 42 batch id 1 loss 5.933234206167981e-05 train acc 1.0
epoch 42 batch id 201 loss 4.748860374093056e-05 train acc 0.9977611940298505
epoch 42 batch id 401 loss 5.851979221915826e-05 train acc 0.9973815461346639
epoch 42 batch id 601 loss 7.297632691916078e-05 train acc 0.9973377703826966
epoch 42 batch id 801 loss 5.0432739953976125e-05 train acc 0.9977528089887651
epoch 42 batch id 1001 loss 4.995261042495258e-05 train acc 0.997827172827174
epoch 42 train acc 0.9978763666947017


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 42 test acc 0.7801075268817202


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 43 batch id 1 loss 5.118681656313129e-05 train acc 1.0
epoch 43 batch id 201 loss 8.357506158063188e-05 train acc 0.9986318407960196
epoch 43 batch id 401 loss 5.016758950660005e-05 train acc 0.9985037406483794
epoch 43 batch id 601 loss 0.0025973371230065823 train acc 0.9981697171381041
epoch 43 batch id 801 loss 5.5868040362838656e-05 train acc 0.9982521847690398
epoch 43 batch id 1001 loss 5.7613884564489126e-05 train acc 0.9983266733266744
epoch 43 train acc 0.9983809924306142


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 43 test acc 0.7800537634408605


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 44 batch id 1 loss 5.075766966911033e-05 train acc 1.0
epoch 44 batch id 201 loss 9.597877215128392e-05 train acc 0.9980099502487559
epoch 44 batch id 401 loss 5.062384298071265e-05 train acc 0.9981296758104743
epoch 44 batch id 601 loss 5.018790398025885e-05 train acc 0.9982945091514149
epoch 44 batch id 801 loss 4.61713680124376e-05 train acc 0.9984082397003753
epoch 44 batch id 1001 loss 4.6424625907093287e-05 train acc 0.9984265734265743
epoch 44 train acc 0.9982968881412949


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 44 test acc 0.7804301075268815


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 45 batch id 1 loss 6.698582001263276e-05 train acc 1.0
epoch 45 batch id 201 loss 4.568247095448896e-05 train acc 0.9985074626865671
epoch 45 batch id 401 loss 0.015364525839686394 train acc 0.998316708229427
epoch 45 batch id 601 loss 3.5309072700329125e-05 train acc 0.9981281198003338
epoch 45 batch id 801 loss 5.6733562814770266e-05 train acc 0.9981897627965054
epoch 45 batch id 1001 loss 0.02054182067513466 train acc 0.9980519480519494
epoch 45 train acc 0.9980235492010093


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 45 test acc 0.7849462365591398


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 46 batch id 1 loss 6.200872303452343e-05 train acc 1.0
epoch 46 batch id 201 loss 4.24045356339775e-05 train acc 0.9976368159203977
epoch 46 batch id 401 loss 4.617689046426676e-05 train acc 0.9977556109725686
epoch 46 batch id 601 loss 3.743976049008779e-05 train acc 0.9979201331114814
epoch 46 batch id 801 loss 4.0899722080212086e-05 train acc 0.9978776529338337
epoch 46 batch id 1001 loss 5.2210605645086616e-05 train acc 0.997902097902099
epoch 46 train acc 0.9979604709840202


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 46 test acc 0.7829569892473122


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 47 batch id 1 loss 3.293389090686105e-05 train acc 1.0
epoch 47 batch id 201 loss 0.018310489133000374 train acc 0.998507462686567
epoch 47 batch id 401 loss 6.006609328323975e-05 train acc 0.9984413965087282
epoch 47 batch id 601 loss 3.79195626010187e-05 train acc 0.9983361064891852
epoch 47 batch id 801 loss 0.018556304275989532 train acc 0.9982833957553067
epoch 47 batch id 1001 loss 4.7267971240216866e-05 train acc 0.9982767232767242
epoch 47 train acc 0.9982968881412954


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 47 test acc 0.7813440860215053


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 48 batch id 1 loss 4.6713386836927384e-05 train acc 1.0
epoch 48 batch id 201 loss 3.9305050449911505e-05 train acc 0.998507462686567
epoch 48 batch id 401 loss 3.557419404387474e-05 train acc 0.9986284289276809
epoch 48 batch id 601 loss 3.710306191351265e-05 train acc 0.9984608985024965
epoch 48 batch id 801 loss 3.889395884471014e-05 train acc 0.9982209737827726
epoch 48 batch id 1001 loss 4.0687922592042014e-05 train acc 0.9983516483516494
epoch 48 train acc 0.9983389402859547


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 48 test acc 0.7817204301075265


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 49 batch id 1 loss 3.9746268157614395e-05 train acc 1.0
epoch 49 batch id 201 loss 3.611657302826643e-05 train acc 0.9983830845771141
epoch 49 batch id 401 loss 0.02048661932349205 train acc 0.9986284289276809
epoch 49 batch id 601 loss 3.946619472117163e-05 train acc 0.9985856905158075
epoch 49 batch id 801 loss 3.2012998417485505e-05 train acc 0.9984706616729097
epoch 49 batch id 1001 loss 3.85364874091465e-05 train acc 0.9984765234765244
epoch 49 train acc 0.9984440706476029


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 49 test acc 0.7820430107526877


  0%|          | 0/1189 [00:00<?, ?it/s]

epoch 50 batch id 1 loss 3.32765412167646e-05 train acc 1.0
epoch 50 batch id 201 loss 0.019973253831267357 train acc 0.9982587064676615
epoch 50 batch id 401 loss 3.6748482671100646e-05 train acc 0.9981296758104743
epoch 50 batch id 601 loss 3.738012310350314e-05 train acc 0.9982113144758744
epoch 50 batch id 801 loss 3.968380769947544e-05 train acc 0.9982521847690398
epoch 50 batch id 1001 loss 4.136448114877567e-05 train acc 0.9983516483516494
epoch 50 train acc 0.9983599663582846


  0%|          | 0/155 [00:00<?, ?it/s]

epoch 50 test acc 0.782043010752688


In [None]:
# Display the list of training accuracies, one entry per epoch; the values
# match the "epoch N train acc" figures printed in the training log.
train_history

[0.6034582882774645,
 0.7160969602306854,
 0.750598742440626,
 0.7672894389042422,
 0.7806239737274215,
 0.8040199847811296,
 0.8323841563538783,
 0.8621590772557974,
 0.8905632984901265,
 0.9192638872201516,
 0.933877007489288,
 0.9472475870078921,
 0.9553886819656427,
 0.9597560975609801,
 0.9654962153069846,
 0.9678531378909904,
 0.9731917577796502,
 0.975567703952905,
 0.9784272497897428,
 0.9805508830950412,
 0.982590412111022,
 0.9855550883095072,
 0.9855130361648468,
 0.9880992430613987,
 0.9882063759061253,
 0.9908536585365869,
 0.9924306139613136,
 0.9918839360807425,
 0.9945332211942814,
 0.9937762825904128,
 0.9945962994112704,
 0.9954183187152076,
 0.9961942809083266,
 0.9956896551724141,
 0.9964465937762836,
 0.9969512195121957,
 0.9971614802354924,
 0.9972245584524815,
 0.9974348191757786,
 0.9980866274179986,
 0.9977712363330535,
 0.9978763666947017,
 0.9983809924306142,
 0.9982968881412949,
 0.9980235492010093,
 0.9979604709840202,
 0.9982968881412954,
 0.99833894028595

In [None]:
# Display the list of test accuracies, one entry per epoch; the values
# match the "epoch N test acc" figures printed in the training log.
test_history

[0.6941397849462366,
 0.7483333333333335,
 0.7739784946236556,
 0.7771505376344089,
 0.7868817204301076,
 0.7781720430107526,
 0.7725806451612903,
 0.7796236559139786,
 0.759408602150538,
 0.7813978494623657,
 0.7394623655913981,
 0.772311827956989,
 0.7546236559139783,
 0.7771505376344089,
 0.7675268817204303,
 0.7610215053763442,
 0.7690860215053762,
 0.7680107526881718,
 0.7584946236559139,
 0.7715591397849458,
 0.772741935483871,
 0.766989247311828,
 0.7693548387096778,
 0.7612365591397849,
 0.7718817204301072,
 0.7598924731182792,
 0.7701075268817204,
 0.7745698924731185,
 0.7775268817204299,
 0.7736021505376344,
 0.7733333333333338,
 0.7722580645161291,
 0.7761827956989249,
 0.7750537634408603,
 0.7756451612903229,
 0.7780107526881721,
 0.7731182795698922,
 0.7752150537634409,
 0.7817741935483874,
 0.7712903225806451,
 0.7754301075268815,
 0.7801075268817202,
 0.7800537634408605,
 0.7804301075268815,
 0.7849462365591398,
 0.7829569892473122,
 0.7813440860215053,
 0.78172043010752

In [None]:
# Display the recorded training-loss values; each entry is shown as a 0-d
# float32 NumPy array.
# NOTE(review): the values fluctuate strongly from entry to entry, so each is
# presumably a single batch loss rather than an epoch average -- confirm
# against the training loop.
train_loss_history

[array(0.49824136, dtype=float32),
 array(0.4849908, dtype=float32),
 array(0.38496912, dtype=float32),
 array(0.43650296, dtype=float32),
 array(0.42385706, dtype=float32),
 array(0.20700023, dtype=float32),
 array(0.23003152, dtype=float32),
 array(0.3109858, dtype=float32),
 array(0.4040987, dtype=float32),
 array(0.27382204, dtype=float32),
 array(0.17381284, dtype=float32),
 array(0.04715471, dtype=float32),
 array(0.2666063, dtype=float32),
 array(0.01195353, dtype=float32),
 array(0.00885608, dtype=float32),
 array(0.10372794, dtype=float32),
 array(0.02053358, dtype=float32),
 array(0.00463617, dtype=float32),
 array(0.00477684, dtype=float32),
 array(0.00363421, dtype=float32),
 array(0.00127381, dtype=float32),
 array(0.00591659, dtype=float32),
 array(0.00455574, dtype=float32),
 array(0.00082099, dtype=float32),
 array(0.13319547, dtype=float32),
 array(0.00306956, dtype=float32),
 array(0.00976817, dtype=float32),
 array(0.00063378, dtype=float32),
 array(0.00118564, dtype

In [None]:
# Display the recorded test-loss values; each entry is shown as a 0-d
# float32 NumPy array.
# NOTE(review): the saved output of this cell is identical, value for value,
# to the output of the train_loss_history cell. Verify that the evaluation
# loop appends a loss computed on the test batches rather than re-appending
# the training loss.
test_loss_history

[array(0.49824136, dtype=float32),
 array(0.4849908, dtype=float32),
 array(0.38496912, dtype=float32),
 array(0.43650296, dtype=float32),
 array(0.42385706, dtype=float32),
 array(0.20700023, dtype=float32),
 array(0.23003152, dtype=float32),
 array(0.3109858, dtype=float32),
 array(0.4040987, dtype=float32),
 array(0.27382204, dtype=float32),
 array(0.17381284, dtype=float32),
 array(0.04715471, dtype=float32),
 array(0.2666063, dtype=float32),
 array(0.01195353, dtype=float32),
 array(0.00885608, dtype=float32),
 array(0.10372794, dtype=float32),
 array(0.02053358, dtype=float32),
 array(0.00463617, dtype=float32),
 array(0.00477684, dtype=float32),
 array(0.00363421, dtype=float32),
 array(0.00127381, dtype=float32),
 array(0.00591659, dtype=float32),
 array(0.00455574, dtype=float32),
 array(0.00082099, dtype=float32),
 array(0.13319547, dtype=float32),
 array(0.00306956, dtype=float32),
 array(0.00976817, dtype=float32),
 array(0.00063378, dtype=float32),
 array(0.00118564, dtype