In [5]:
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import pandas as pd
import numpy as np
import os
from tqdm import tqdm

def encode_text(sent, tokenizer, max_len):
    '''
    Encode text(arg:sent) by tokenizer. Padding and truncation is true.
    '''
    encoded = tokenizer.encode_plus(
        sent,
        add_special_tokens=True,
        max_length=max_len,
        padding='max_length',
        return_attention_mask=True,
        truncation=True
    )
    return encoded['input_ids'], encoded['attention_mask']

def bert_encode(data, tokenizer, max_len):
    '''
    Encode data(arg:data) with tokenizer.
    '''
    input_ids = []
    attention_masks = []

    for sent in tqdm(data, desc="Encoding"):
        ids, masks = encode_text(sent, tokenizer, max_len)
        input_ids.append(ids)
        attention_masks.append(masks)

    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    return input_ids, attention_masks

def train(model, train_dataloader, optimizer, device):
    '''
    Train model.
    '''
    model.train()
    total_loss = 0

    for batch in tqdm(train_dataloader, desc="Training"):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        model.zero_grad()
        outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()

        optimizer.step()

    avg_loss = total_loss / len(train_dataloader)

    return avg_loss

def extract_dimension_labels(encoded_labels, label_encoder):
    '''
    Evaluate accuracy by type.
    '''
    decoded_labels = label_encoder.inverse_transform(encoded_labels)
    # decoded_labels의 각 요소를 문자열로 변환
    decoded_labels = [str(label) for label in decoded_labels]
    print("Decoded labels:", decoded_labels)  # 디버깅용 출력 코드 추가
    
    # 길이가 4인 문자열만 유효한 레이블로 선택
    valid_labels = [label for label in decoded_labels if len(label) == 4]
    
    dimensions = {
        'EI': [1 if label[0] == 'E' else 0 for label in valid_labels],
        'NS': [1 if label[1] == 'N' else 0 for label in valid_labels],
        'FT': [1 if label[2] == 'F' else 0 for label in valid_labels],
        'JP': [1 if label[3] == 'J' else 0 for label in valid_labels]
    }
    return dimensions


def evaluate(model, validation_dataloader, label_encoder, device):
    '''
    Evaluate model. This func output is accuracy, precision, recall, f1 score, accuracy by type.
    '''
    model.eval()
    predictions, true_labels = [], []

    for batch in tqdm(validation_dataloader, desc="Evaluating"):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        with torch.no_grad():
            outputs = model(b_input_ids, attention_mask=b_input_mask)
        
        logits = outputs.logits
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        
        predictions.extend(np.argmax(logits, axis=1).flatten())
        true_labels.extend(label_ids.flatten())

    overall_accuracy = accuracy_score(true_labels, predictions)
    overall_precision, overall_recall, overall_f1, _ = precision_recall_fscore_support(true_labels, predictions, average='weighted')

    true_dimensions = extract_dimension_labels(true_labels, label_encoder)
    pred_dimensions = extract_dimension_labels(predictions, label_encoder)
    dimension_accuracies = {}

    for dimension, true_dim_labels in true_dimensions.items():
        dimension_accuracies[dimension] = accuracy_score(true_dim_labels, pred_dimensions[dimension])

    return overall_accuracy, overall_precision, overall_recall, overall_f1, dimension_accuracies

if __name__ == '__main__':
    data_upsampled = pd.read_csv('C:/Users/user/Desktop/dataset/JNU_OOP_project/MBTI_pred/filtered_dataset.csv')
    #data_upsampled = data_upsampled.sample(frac=1/10, random_state=2018)
    # 컬럼 이름 확인
    print("Columns in the dataset:", data_upsampled.columns)

    # 'encoded_labels' 컬럼이 있는지 확인
    if 'encoded_labels' not in data_upsampled.columns:
        raise KeyError("'encoded_labels' column is not found in the dataset.")
    
    # DistilBERT 토크나이저 및 모델 불러오기
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    label_encoder = LabelEncoder()
    data_upsampled['encoded_labels'] = label_encoder.fit_transform(data_upsampled['encoded_labels'])
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=len(label_encoder.classes_))

    # 데이터 인코딩
    input_ids, attention_masks = bert_encode(data_upsampled['filtered_posts'], tokenizer, max_len=64)
    labels = torch.tensor(data_upsampled['encoded_labels'].values)

    # 데이터셋 분할
    train_inputs, temp_inputs, train_labels, temp_labels = train_test_split(input_ids, labels, random_state=2018, test_size=0.2)
    validation_inputs, test_inputs, validation_labels, test_labels = train_test_split(temp_inputs, temp_labels, random_state=2018, test_size=0.5)
    train_masks, temp_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.2)
    validation_masks, test_masks, _, _ = train_test_split(temp_masks, temp_labels, random_state=2018, test_size=0.5)

    # 데이터셋 및 데이터로더 정의
    train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
    validation_dataset = TensorDataset(validation_inputs, validation_masks, validation_labels)
    test_dataset = TensorDataset(test_inputs, test_masks, test_labels)

    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=32)
    validation_dataloader = DataLoader(validation_dataset, sampler=SequentialSampler(validation_dataset), batch_size=32)
    test_dataloader = DataLoader(test_dataset, sampler=SequentialSampler(test_dataset), batch_size=32)

    # 모델을 GPU 또는 CPU로 이동
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Using device: {device}")

    # 옵티마이저 정의
    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

    # Define training parameters
    epochs = 10
    best_accuracy = 0.0
    no_improve_epochs = 0
    early_stopping_patience = 2
    model_save_path = './model'

    # 모델 학습
    for epoch in range(epochs):
        print(f'Epoch {epoch+1}/{epochs}')
        print('-' * 10)

        train_loss = train(model, train_dataloader, optimizer, device)
        print(f'Training loss: {train_loss}')

        accuracy, precision, recall, f1, dimension_accuracies = evaluate(model, validation_dataloader, label_encoder, device)
        print(f'Accuracy: {accuracy}')
        print(f'Precision: {precision}')
        print(f'Recall: {recall}')
        print(f'F1 Score: {f1}')

        for dim, acc in dimension_accuracies.items():
            print(f'{dim} Accuracy: {acc}')

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            no_improve_epochs = 0
            torch.save(model.state_dict(), os.path.join(model_save_path, 'C:/Users/user/Desktop/dataset/JNU_OOP_project/MBTI_pred/bestmodel.pth'))
            print("Improved validation accuracy. Model saved.")
        else:
            no_improve_epochs += 1
            print("No improvement in validation accuracy.")
            if no_improve_epochs >= early_stopping_patience:
                print("Stopping early due to no improvement.")
                break

    # 테스트 데이터셋으로 모델 평가
    best_model_path = os.path.join(model_save_path, 'C:/Users/user/Desktop/dataset/JNU_OOP_project/MBTI_pred/bestmodel.pth')
    model.load_state_dict(torch.load(best_model_path))
    model.to(device)

    test_accuracy, test_precision, test_recall, test_f1, dimension_accuracies = evaluate(model, test_dataloader, label_encoder, device)

    print(f'Test Accuracy: {test_accuracy}')
    print(f'Test Precision: {test_precision}')
    print(f'Test Recall: {test_recall}')
    print(f'Test F1 Score: {test_f1}')

    print("MBTI Dimension Accuracies:")
    for dimension, accuracy in dimension_accuracies.items():
        print(f"{dimension} Accuracy: {accuracy}")


Columns in the dataset: Index(['posts', 'type', 'filtered_posts', 'encoded_labels'], dtype='object')


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Encoding: 100%|███████████████████████████████████████████████████████████████| 199536/199536 [03:21<00:00, 991.54it/s]


Using device: cuda
Epoch 1/10
----------


Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:38<00:00, 12.52it/s]


Training loss: 1.2258854164768254


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.84it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:40<00:00, 12.45it/s]


Training loss: 0.5319704702446948


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.73it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:43<00:00, 12.36it/s]


Training loss: 0.33483469905882424


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.53it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:44<00:00, 12.32it/s]


Training loss: 0.21480729297459825


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.42it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:45<00:00, 12.31it/s]


Training loss: 0.14358946068895995


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.43it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:45<00:00, 12.32it/s]


Training loss: 0.10123847061277685


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.49it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:45<00:00, 12.32it/s]


Training loss: 0.07744127368245246


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.46it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:45<00:00, 12.31it/s]


Training loss: 0.06048878685539845


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.49it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:45<00:00, 12.32it/s]


Training loss: 0.05039992369370133


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.49it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Training: 100%|████████████████████████████████████████████████████████████████████| 4989/4989 [06:44<00:00, 12.33it/s]


Training loss: 0.04330715250918272


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.42it/s]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


Decoded labels: ['3', '5', '9', '11', '12', '15', '10', '9', '1', '5', '11', '8', '6', '4', '7', '10', '0', '6', '8', '11', '2', '2', '9', '14', '10', '10', '12', '2', '10', '4', '9', '7', '2', '0', '3', '9', '8', '2', '13', '3', '11', '3', '10', '9', '1', '2', '8', '10', '5', '7', '13', '9', '11', '9', '15', '8', '11', '4', '11', '13', '5', '12', '4', '2', '12', '6', '0', '6', '8', '1', '11', '5', '11', '9', '9', '7', '4', '6', '10', '3', '4', '11', '2', '8', '3', '9', '7', '11', '2', '3', '1', '5', '2', '12', '0', '1', '4', '12', '3', '2', '2', '0', '3', '6', '13', '11', '6', '10', '7', '9', '10', '5', '13', '11', '6', '10', '11', '6', '2', '14', '10', '0', '15', '11', '5', '0', '1', '5', '3', '2', '10', '11', '14', '3', '13', '14', '10', '12', '2', '8', '12', '0', '14', '3', '12', '10', '6', '2', '14', '0', '12', '14', '0', '6', '6', '9', '8', '2', '6', '14', '10', '13', '12', '3', '11', '7', '0', '13', '6', '0', '8', '5', '6', '4', '15', '13', '1', '5', '7', '10', '1', '3', '9', '1

Evaluating: 100%|████████████████████████████████████████████████████████████████████| 624/624 [00:14<00:00, 42.41it/s]


Decoded labels: ['11', '13', '15', '7', '12', '13', '10', '13', '15', '3', '6', '1', '12', '6', '8', '9', '7', '2', '12', '10', '6', '15', '10', '11', '2', '2', '12', '8', '2', '3', '3', '4', '11', '9', '1', '15', '3', '3', '10', '7', '13', '3', '3', '15', '14', '8', '4', '5', '13', '15', '0', '13', '13', '0', '3', '6', '4', '11', '5', '5', '13', '4', '6', '6', '15', '15', '12', '15', '10', '9', '4', '7', '13', '9', '15', '1', '0', '0', '6', '12', '6', '1', '11', '12', '7', '1', '14', '9', '15', '11', '2', '10', '12', '3', '12', '4', '4', '10', '5', '4', '14', '3', '6', '12', '8', '15', '2', '14', '9', '0', '14', '9', '3', '13', '14', '3', '7', '3', '13', '15', '2', '10', '3', '10', '2', '5', '11', '3', '6', '7', '14', '0', '8', '1', '15', '10', '10', '10', '12', '6', '4', '5', '5', '6', '0', '3', '15', '0', '6', '1', '0', '9', '8', '12', '6', '14', '8', '2', '10', '6', '2', '4', '12', '9', '8', '0', '14', '3', '5', '14', '12', '2', '11', '10', '8', '5', '2', '10', '3', '4', '12', '7',

  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
