In [None]:
#!pip install -q wandb

In [None]:
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader

from torchvision import datasets
import torchvision.transforms.v2 as T

from torchinfo import summary
from torchmetrics.classification import Accuracy

import wandb

In [None]:
# Authenticate with Weights & Biases so the run started below can log metrics.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmgjeon[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [None]:
# Preprocessing pipeline applied to every FashionMNIST sample:
# image -> tv_tensors Image -> float32 (scale=True rescales integer pixels into [0, 1])
# -> plain torch.Tensor (drops the tv_tensors wrapper).
_transform_steps = [
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.ToPureTensor(),
]
transform = T.Compose(_transform_steps)

In [None]:
# FashionMNIST train/test splits stored under root="data" (relative to the working
# directory), downloaded if missing, and preprocessed with the shared `transform`.
_fashion_mnist_kwargs = dict(root="data", download=True, transform=transform)

training_data = datasets.FashionMNIST(train=True, **_fashion_mnist_kwargs)
test_data = datasets.FashionMNIST(train=False, **_fashion_mnist_kwargs)

In [None]:
class ImageClassifier(nn.Module):
    """Small CNN image classifier.

    Four 3x3 Conv -> BatchNorm -> LeakyReLU stages (the 2nd and 4th use stride 2,
    each shrinking the spatial size to ceil(H/2) x ceil(W/2)), followed by a flatten
    and a linear head producing one logit per class. For 1x28x28 FashionMNIST inputs
    the flattened feature size is 64 * 7 * 7.

    Args:
        in_channels: Number of channels in the input images (defaults to 1, grayscale).
        num_classes: Number of output logits (defaults to 10).

    Note:
        The head is an ``nn.LazyLinear`` whose ``in_features`` is inferred on the
        first forward pass. Run one batch through the model ("dry run") before
        inspecting parameter shapes or building an optimizer from its parameters.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.model = nn.Sequential(
            *self._conv_block(in_channels, 32, stride=1),
            *self._conv_block(32, 32, stride=2),
            *self._conv_block(32, 64, stride=1),
            *self._conv_block(64, 64, stride=2),
            nn.Flatten(),
            # in_features is inferred from the first batch, so any input size the
            # conv stack accepts works without hand-computing 64 * H' * W'.
            nn.LazyLinear(num_classes),
        )

    @staticmethod
    def _conv_block(in_channels: int, out_channels: int, stride: int) -> list:
        """Return the layers of one 3x3 Conv -> BatchNorm -> LeakyReLU stage.

        The layers are returned as a flat list (not a nested ``nn.Sequential``) so the
        parent ``Sequential`` keeps flat indices and unchanged ``state_dict`` keys.
        """
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(),
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch to unnormalized class logits.

        Args:
            x: Image batch, expected as (batch, in_channels, H, W).

        Returns:
            Logits of shape (batch, num_classes).
        """
        return self.model(x)

In [None]:
# Hyperparameters (also logged to W&B through `config` below).
num_epoch = 5
learning_rate = 1e-3
batch_size = 64
seed = 42

# Seed torch's global RNG so weight initialization and the training DataLoader's
# shuffling order are reproducible on Restart & Run All.
torch.manual_seed(seed)

# shuffle=True draws a fresh sample order every epoch (without it the model sees the
# training set in the same fixed order each epoch); drop_last discards the final
# partial training batch.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

loss_fn = nn.CrossEntropyLoss()
metric_fn = Accuracy(task="multiclass", num_classes=10).to(device)
model = ImageClassifier().to(device)

# ImageClassifier ends in an nn.LazyLinear whose weight only exists after the first
# forward pass. Do a single-sample dry run before building the optimizer, as the
# PyTorch lazy-module docs recommend; eval() + no_grad() keep BatchNorm running
# statistics and autograd untouched by this dummy pass.
model.eval()
with torch.no_grad():
    sample_image, _ = training_data[0]
    model(sample_image.unsqueeze(0).to(device))
model.train()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

config = {
    "num_epoch": num_epoch,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "seed": seed,
    "loss_function": loss_fn.__class__.__name__,
    "metric_function": metric_fn.__class__.__name__,
    "optimizer": optimizer.__class__.__name__,
}
config

{'num_epoch': 5,
 'learning_rate': 0.001,
 'batch_size': 64,
 'loss_function': 'CrossEntropyLoss',
 'metric_function': 'MulticlassAccuracy',
 'optimizer': 'SGD'}

In [None]:
# Start a W&B run in project "dl-khu"; the hyperparameter dict is attached as its config.
run_kwargs = dict(
    project="dl-khu",
    name="fashion-mnist",
    config=config,
)
wandb.init(**run_kwargs)

In [None]:
log_interval = 100  # log training loss/accuracy every `log_interval` batches

global_step = -1
for epoch in range(num_epoch):
    # ---- Train ----
    model.train()
    for batch_idx, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device), y.to(device)

        y_pred = model(X)
        loss = loss_fn(y_pred, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        global_step += 1

        if batch_idx % log_interval == 0:
            train_loss = loss.item()
            # forward() returns the accuracy of this batch only; computed just on
            # logging steps instead of every batch.
            train_accuracy = metric_fn(y_pred, y).item()
            wandb.log(
                {'loss': train_loss, 'accuracy': train_accuracy},
                step=global_step,
            )
            print(f"loss: {train_loss:.4f} accuracy: {train_accuracy:.4f} [{batch_idx} / {len(train_dataloader)}]")

    # ---- Evaluate ----
    model.eval()
    # Drop state accumulated by the training-time forward() calls so compute()
    # below reflects the test set only.
    metric_fn.reset()
    eval_loss_sum, num_eval_samples = 0.0, 0
    with torch.inference_mode():
        for X, y in test_dataloader:
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            # CrossEntropyLoss averages over the batch; weight by batch size so the
            # smaller final test batch does not skew the epoch mean.
            eval_loss_sum += loss_fn(y_pred, y).item() * y.size(0)
            num_eval_samples += y.size(0)
            metric_fn.update(y_pred, y)

    eval_loss = eval_loss_sum / num_eval_samples
    # Accuracy over all test samples (not a mean of per-batch accuracies).
    eval_accuracy = metric_fn.compute().item()
    metric_fn.reset()
    wandb.log(
        {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy},
        step=global_step,
    )
    print(f"Eval metrics: \nAccuracy: {eval_accuracy:.4f}, Avg loss: {eval_loss:.4f} \n")

loss: 2.4196 accuracy: 0.1094 [0 / 937]
loss: 1.2822 accuracy: 0.6719 [100 / 937]
loss: 0.6943 accuracy: 0.8906 [200 / 937]
loss: 0.8523 accuracy: 0.7656 [300 / 937]
loss: 0.6692 accuracy: 0.7812 [400 / 937]
loss: 0.6267 accuracy: 0.7812 [500 / 937]
loss: 0.6172 accuracy: 0.7812 [600 / 937]
loss: 0.6195 accuracy: 0.7812 [700 / 937]
loss: 0.6373 accuracy: 0.7656 [800 / 937]
loss: 0.4813 accuracy: 0.8438 [900 / 937]
Eval metrics: 
Accuracy: 0.8149, Avg loss: 0.5368 

loss: 0.4431 accuracy: 0.8438 [0 / 937]
loss: 0.5552 accuracy: 0.7969 [100 / 937]
loss: 0.2950 accuracy: 0.8906 [200 / 937]
loss: 0.5558 accuracy: 0.8125 [300 / 937]
loss: 0.5211 accuracy: 0.8125 [400 / 937]
loss: 0.4395 accuracy: 0.8594 [500 / 937]
loss: 0.4480 accuracy: 0.7969 [600 / 937]
loss: 0.5670 accuracy: 0.7969 [700 / 937]
loss: 0.5626 accuracy: 0.8125 [800 / 937]
loss: 0.3991 accuracy: 0.8906 [900 / 937]
Eval metrics: 
Accuracy: 0.8384, Avg loss: 0.4573 

loss: 0.3452 accuracy: 0.8750 [0 / 937]
loss: 0.4600 accurac

In [None]:
# Mark the W&B run started by wandb.init() as finished and upload any remaining logged data.
wandb.finish()