In [1]:
import torch

# Sanity check before the device-dependent cells further down.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [2]:
import os

from dotenv import dotenv_values
from huggingface_hub import login

# Read the Hub token from the local .env file, falling back to the process
# environment so the cell also works when no .env file is present.
# Fail with an explicit message instead of a bare KeyError when neither has it.
HF_TOKEN_KEY = 'HUGGING_FACE_TOKEN'
hf_token = dotenv_values('.env').get(HF_TOKEN_KEY) or os.environ.get(HF_TOKEN_KEY)
if not hf_token:
    raise RuntimeError(f'{HF_TOKEN_KEY} is not set in .env or in the environment')

login(token=hf_token)

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/hakatashi/.cache/huggingface/token
Login successful


In [3]:
from datasets import load_dataset

# Dataset repository on the Hugging Face Hub (private per its name, hence the
# login above) and the local directory used as its download cache.
# NOTE(review): the cache path is an absolute, machine-specific location.
DATASET_REPO = "hakatashi/hakatashi-pixiv-bookmark-deepdanbooru-private"
DATASET_CACHE_DIR = '/mnt/f/.cache'

dataset = load_dataset(DATASET_REPO, cache_dir=DATASET_CACHE_DIR)

Found cached dataset parquet (/mnt/f/.cache/hakatashi___parquet/hakatashi--hakatashi-pixiv-bookmark-deepdanbooru-private-dc6bd44c53eea7d4/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# Show each split's row count and feature columns (key, tag_probs, class).
dataset

DatasetDict({
    train: Dataset({
        features: ['key', 'tag_probs', 'class'],
        num_rows: 179121
    })
    validation: Dataset({
        features: ['key', 'tag_probs', 'class'],
        num_rows: 59708
    })
    test: Dataset({
        features: ['key', 'tag_probs', 'class'],
        num_rows: 59707
    })
})

In [5]:
# Run on the GPU when one is available; otherwise fall back to CPU so the
# remaining cells still execute (presumably much more slowly).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Columns read from torch_dataset come back as torch tensors placed on `device`.
torch_dataset = dataset.with_format(type='torch', device=device)

In [6]:
import torch.nn.functional as F

In [7]:
# Whole training split as tensors: features, and one-hot targets for the MSE
# objective used in training. num_classes is pinned to 3 (the model's output
# width) because F.one_hot otherwise infers the width from the largest label
# present, which would shrink the encoding if the top class were absent.
x = torch_dataset['train']['tag_probs']
y = F.one_hot(torch_dataset['train']['class'], num_classes=3).float()

In [8]:
%%time

import torch.nn as nn
import torch.optim as optim

class Network(nn.Module):
    """Shallow MLP classifier: Linear -> ReLU -> Linear, returning raw scores.

    No softmax is applied; the output has one unnormalised score per class.
    Both layers are created directly on the notebook-global ``device``.
    The attribute names ``middle_layer`` / ``out_layer`` define the
    ``state_dict`` keys of saved checkpoints, so keep them stable.

    Args:
        in_features: input width; the default 6000 matches the ``tag_probs``
            vectors this notebook trains on.
        hidden_features: width of the single hidden layer.
        out_features: number of output scores (one per class).
    """

    def __init__(self, in_features=6000, hidden_features=128, out_features=3):
        super().__init__()
        self.middle_layer = nn.Linear(in_features, hidden_features, device=device)
        self.out_layer = nn.Linear(hidden_features, out_features, device=device)

    def forward(self, x):
        """Map a ``(..., in_features)`` tensor to ``(..., out_features)`` scores."""
        hidden = F.relu(self.middle_layer(x))
        return self.out_layer(hidden)

# Full-batch gradient descent: every step uses the whole training split, so
# each step is one pass ("epoch") over the data.
SEED = 0
NUM_STEPS = 3000
LEARNING_RATE = 0.01
LOG_EVERY = 100

# Seed before building the model so the weight initialisation is repeatable.
# NOTE(review): some CUDA kernels can still cause small run-to-run differences.
torch.manual_seed(SEED)

network = Network()
optimizer = optim.SGD(network.parameters(), lr=LEARNING_RATE)
# MSE between raw scores and one-hot targets (this notebook's one-hot variant);
# nn.CrossEntropyLoss on class indices is the conventional alternative.
criterion = nn.MSELoss()

for i in range(NUM_STEPS):
    optimizer.zero_grad()
    output = network(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

    # Also report the last step, which the fixed interval alone would skip.
    if i % LOG_EVERY == 0 or i == NUM_STEPS - 1:
        print(f'[Epoch {i}] Loss: {loss.item():.3f}')

[Epoch 0] Loss: 0.359
[Epoch 100] Loss: 0.165
[Epoch 200] Loss: 0.138
[Epoch 300] Loss: 0.128
[Epoch 400] Loss: 0.120
[Epoch 500] Loss: 0.113
[Epoch 600] Loss: 0.109
[Epoch 700] Loss: 0.105
[Epoch 800] Loss: 0.103
[Epoch 900] Loss: 0.101
[Epoch 1000] Loss: 0.099
[Epoch 1100] Loss: 0.098
[Epoch 1200] Loss: 0.097
[Epoch 1300] Loss: 0.097
[Epoch 1400] Loss: 0.096
[Epoch 1500] Loss: 0.095
[Epoch 1600] Loss: 0.094
[Epoch 1700] Loss: 0.094
[Epoch 1800] Loss: 0.093
[Epoch 1900] Loss: 0.093
[Epoch 2000] Loss: 0.092
[Epoch 2100] Loss: 0.092
[Epoch 2200] Loss: 0.091
[Epoch 2300] Loss: 0.091
[Epoch 2400] Loss: 0.090
[Epoch 2500] Loss: 0.090
[Epoch 2600] Loss: 0.090
[Epoch 2700] Loss: 0.089
[Epoch 2800] Loss: 0.089
[Epoch 2900] Loss: 0.089
CPU times: user 3min 59s, sys: 1min 57s, total: 5min 57s
Wall time: 5min 56s


In [9]:
# Persist only the learned parameters; restore with
# Network().load_state_dict(torch.load(model_path)).
model_path = 'torch-multiclass-onehot-shallow-network'
torch.save(network.state_dict(), model_path)

In [10]:
# Held-out test split: features, plus integer class labels (not one-hot),
# which is the form the torcheval metrics below compare against.
test_split = torch_dataset['test']
x_test = test_split['tag_probs']
y_test = test_split['class']

In [11]:
# Inference only: disable autograd so no computation graph or saved
# activations are kept for the entire test split.
with torch.no_grad():
    y_test_predict = network(x_test)

In [12]:
# Predicted class = index of the highest score in each row (first index on
# ties, as with torch.max). argmax avoids the discouraged `.data` access and
# does not clobber IPython's `_` output variable.
y_test_predict_class = y_test_predict.argmax(dim=1)

In [13]:
from torcheval.metrics.functional import (
    multiclass_accuracy,
    multiclass_confusion_matrix,
    multiclass_f1_score,
    multiclass_precision,
    multiclass_recall,
)

n_classes = 3

# Rows are true classes, columns are predicted classes.
print('confusion_matrix:')
print(multiclass_confusion_matrix(y_test_predict_class, y_test, num_classes=n_classes))

# No `average` is passed for accuracy, so torcheval's default applies.
print(f'accuracy_score: {multiclass_accuracy(y_test_predict_class, y_test)}')

# The remaining scores are macro averages: the unweighted mean over classes.
macro_metrics = {
    'precision_score': multiclass_precision,
    'recall_score': multiclass_recall,
    'f1_score': multiclass_f1_score,
}
for metric_name, metric_fn in macro_metrics.items():
    score = metric_fn(y_test_predict_class, y_test, average="macro", num_classes=n_classes)
    print(f'{metric_name}: {score}')

confusion_matrix:
tensor([[38134,  2488,    76],
        [ 4352, 10143,    24],
        [ 2726,  1001,   763]], device='cuda:0')
accuracy_score: 0.8213441967964172
precision_score: 0.8238773345947266
recall_score: 0.6018447875976562
f1_score: 0.6311513781547546
