In [1]:
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader

## **Dataset**

In [2]:
trans = transforms.Compose([transforms.Resize((32, 32)),  # upscale
                            transforms.ToTensor()])

data_train = torchvision.datasets.FashionMNIST(
    root='./data', train=True, transform=trans, download=False 
)
data_val = torchvision.datasets.FashionMNIST(
    root='./data', train=False, transform=trans, download=False
)

In [3]:
batch_size = 256
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(data_val, batch_size=batch_size, shuffle=False)

In [4]:
X, y = next(iter(train_loader))

In [5]:
print(X.shape)
print(y.shape)

torch.Size([256, 1, 32, 32])
torch.Size([256])


## **Softmax Function**

In [6]:
X = torch.tensor([[1.0, 2, 3,], [4, 5, 6]])

In [7]:
X.shape  # (2, 3)

torch.Size([2, 3])

In [8]:
X.sum(axis=0, keepdims=True)  # shape: (1, 3)

tensor([[5., 7., 9.]])

In [9]:
X.sum(axis=1, keepdims=True)  # shape: (2, 1)

tensor([[ 6.],
        [15.]])

In [10]:
def softmax(X):  # X.shape = (n, d)
    X_exp = torch.exp(X)  # elementwise
    partition = X_exp.sum(1, keepdims=True)  # shape: (n, 1)
    return X_exp / partition  # shape: (n, 1)

In [11]:
softmax(X)  # each row sums up to 1, as is required for a probability

tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])

## **Model**

In [12]:
# O = XW + b
# d = num_inputs
# q = num_outputs
# Y: (n, q)
# X: (n, d)
# W: (d, q)
# b: (q)

class SoftmaxRegressionScratch(nn.Module):
    def __init__(self, num_inputs, num_outputs, sigma=0.01):
        super().__init__()
        self.W = torch.normal(0, sigma, size=(num_inputs, num_outputs), requires_grad=True)
        self.b = torch.zeros(num_outputs, requires_grad=True)

    def forward(self, X):  # (B, c, h, w), d = c*h*w
        X = X.reshape((-1, self.W.shape[0]))  # (-1, d) = (B, d)
        return softmax(torch.matmul(X, self.W) + self.b)

## **Loss**

In [13]:
def cross_entropy(y_hat, y):
    # y_hat: (B, q)
    # y: (B)
    # sum -y_i*log(y_hat_i)
    return -torch.log(y_hat[list(range(y_hat.shape[0])), y]).mean()  # 정의는 sum()인데 batch_size로 나눠주려고 mean() 씀

In [14]:
def accuracy(y_hat, y):
    # y_hat: (B, q)
    # y: (B)
    preds = y_hat.argmax(axis=1).type(y.dtype)  # (B)
    compare = (preds == y).type(torch.float32)  # (B)
    return compare.mean()

## **Training**

In [15]:
lr = 0.1
model = SoftmaxRegressionScratch(num_inputs=1*32*32, num_outputs=10)

In [16]:
optimizer = torch.optim.SGD(params=[model.W, model.b], lr=lr)

In [17]:
%%time
max_epochs = 10

for i in range(max_epochs):
    train_loss = 0
    num_train_batches = 0
    
    for X, y in train_loader:
        optimizer.zero_grad()
        y_hat = model(X)
        loss = cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        num_train_batches += 1

    val_loss = 0
    val_acc = 0
    num_val_batches = 0
    with torch.no_grad():
        for X, y in val_loader:
            y_hat = model(X)
            loss = cross_entropy(y_hat, y)
            val_loss += loss.item()
            num_val_batches += 1
            val_acc += accuracy(y_hat, y)

    print(f'epoch={i:02d} | train_loss={train_loss/num_train_batches:.4f} | val_loss={val_loss/num_val_batches:.4f} | val_acc={val_acc/num_val_batches:.4f}')

epoch=00 | train_loss=0.7793 | val_loss=0.6283 | val_acc=0.7865
epoch=01 | train_loss=0.5724 | val_loss=0.5629 | val_acc=0.8080
epoch=02 | train_loss=0.5282 | val_loss=0.6079 | val_acc=0.7795
epoch=03 | train_loss=0.5075 | val_loss=0.5775 | val_acc=0.7907
epoch=04 | train_loss=0.4890 | val_loss=0.5138 | val_acc=0.8242
epoch=05 | train_loss=0.4807 | val_loss=0.5384 | val_acc=0.8138
epoch=06 | train_loss=0.4701 | val_loss=0.4974 | val_acc=0.8261
epoch=07 | train_loss=0.4628 | val_loss=0.4929 | val_acc=0.8245
epoch=08 | train_loss=0.4568 | val_loss=0.4860 | val_acc=0.8345
epoch=09 | train_loss=0.4535 | val_loss=0.4859 | val_acc=0.8318
CPU times: total: 8min 47s
Wall time: 1min 29s
