## MLPスクラッチ実装

In [4]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split

In [5]:
# 1. データロード
dataset = datasets.load_digits()
images = dataset['images']
target = dataset['target']

# 学習データと検証データに分割
X_train, X_val, y_train, y_val = train_test_split(images, target, test_size=0.2, random_state=42)

# 2. 前処理
# 2-1. ラベルのonehotencoding
y_train = F.one_hot(torch.tensor(y_train), num_classes=10)
X_train = torch.tensor(X_train, dtype=torch.float32).reshape(-1, 64)
y_val = F.one_hot(torch.tensor(y_val), num_classes=10)
X_val = torch.tensor(X_val, dtype=torch.float32).reshape(-1, 64)

# 2-2. 画像の標準化
X_train = (X_train - X_train.mean()) / X_train.std()
X_val = (X_val - X_train.mean()) / X_train.std()

### 順伝播

In [6]:
nh = 30
m, n = X_train.shape
class_num = 10

W1 = torch.randn((nh, n), requires_grad=True)   # 出力 x 入力
b1 = torch.randn((1, nh), requires_grad=True)   # 1 x nh

W2 = torch.randn((class_num, nh), requires_grad=True)   # 出力 x 入力
b2 = torch.randn((1, class_num), requires_grad=True)   # 1 x nh

In [7]:
def linear(X, W, b):
    return X @ W.T + b

def relu(Z):
    return Z.clamp_min(0.)

def softmax(x):
    e_x = torch.exp(x - torch.max(x, dim=1, keepdim=True)[0])
    return e_x / (torch.sum(e_x, dim=1, keepdim=True) + 1e-10)

def model(X):
    Z1 = linear(X, W1, b1)
    A1 = relu(Z1)
    Z2 = linear(A1, W2, b2)
    A2 = softmax(Z2)
    return A2

In [8]:
y_train_pred = model(X_train)
y_train_pred

tensor([[3.0837e-30, 6.0695e-30, 3.3349e-23,  ..., 2.1700e-27, 2.1386e-15,
         9.6923e-14],
        [2.9962e-39, 5.8498e-33, 1.4970e-09,  ..., 2.8671e-02, 1.3182e-28,
         2.0864e-13],
        [1.7762e-21, 2.0422e-22, 1.5922e-13,  ..., 1.5471e-13, 2.2273e-05,
         3.3669e-06],
        ...,
        [4.8975e-42, 9.7892e-27, 3.1146e-03,  ..., 2.1735e-24, 1.4130e-36,
         4.0039e-31],
        [4.9934e-31, 5.4085e-22, 1.5211e-07,  ..., 2.5253e-05, 1.7913e-11,
         6.9210e-16],
        [1.4396e-11, 4.6489e-20, 1.2519e-10,  ..., 4.0360e-16, 2.5905e-15,
         2.6523e-21]], grad_fn=<DivBackward0>)