In [None]:
import os
from timeit import default_timer as timer
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
# from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
from torchvision.transforms import v2

import models as M
import trainer as T

from torch.utils.tensorboard import SummaryWriter

from plot_lib import set_default


In [None]:
# Presumably applies plot_lib's default matplotlib styling (inferred from the
# name only — confirm in plot_lib) before any figures are drawn below.
set_default()

In [None]:
# Pick the compute device in order of preference: NVIDIA CUDA, then Apple's
# Metal (MPS) backend, falling back to the CPU when neither is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Root folder for the torchvision datasets. Set the DATADIR environment variable
# to point elsewhere instead of editing this machine-specific default path.
DATADIR = os.environ.get("DATADIR", "/Users/mghifary/Work/Code/AI/data")
MODEL_DIR = "models"
# Experiment tag: names both the TensorBoard run directory and the checkpoint file.
# Earlier variants: convnet-exp2, convnet-exp3, mlp3-bn-randaug-exp1, convnet-randaug-exp1.
MODEL_SUFFIX = "mlp-bn-exp2"
BATCH_SIZE = 256
EPOCHS = 50
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All (does not force deterministic GPU kernels).
SEED = 42
torch.manual_seed(SEED)

In [None]:
# One TensorBoard run directory per experiment, under logs/ and keyed by MODEL_SUFFIX.
# Re-running with the same suffix logs into the same run directory.
writer = SummaryWriter(log_dir=f"logs/fashion-mnist_{MODEL_SUFFIX}")

In [None]:
# Training pipeline: convert each image to a float tensor. RandAugment is kept
# here but disabled; uncomment it to turn on augmentation for an experiment.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    # v2.RandAugment(),
])
# Evaluation pipeline: tensor conversion only, never augmented.
inference_transform = transforms.Compose([transforms.ToTensor()])

# Build the FashionMNIST train and test splits under DATADIR (downloaded on
# first use), each paired with its own transform pipeline.
train_data, test_data = (
    datasets.FashionMNIST(
        root=DATADIR,
        train=is_train,
        download=True,
        transform=split_transform,
    )
    for is_train, split_transform in (
        (True, train_transform),
        (False, inference_transform),
    )
)

In [None]:
# Create data loaders: reshuffle the training split every epoch, keep the
# evaluation order fixed so test metrics are comparable across epochs.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Peek at one training batch to read the input geometry (channels, height,
# width) that the model constructor below needs.
X, y = next(iter(train_dataloader))
_, c, dx1, dx2 = X.shape
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape}, {y.dtype}")

# Count distinct labels in the training split. `targets` replaces the deprecated
# `train_labels` alias, which emits a UserWarning on every access.
num_classes = len(torch.unique(train_data.targets))

In [None]:
# Helper to display an image tensor inline; used below to preview a training batch.
def matplotlib_imshow(img, one_channel=False):
    """Draw a (C, H, W) image tensor, such as a ``make_grid`` output, with matplotlib.

    Args:
        img: image tensor laid out as (C, H, W). It may live on any device and
            may require grad; it is detached and copied to the CPU before the
            NumPy conversion (``Tensor.numpy`` rejects both otherwise).
        one_channel: if True, average over the channel axis and draw the single
            plane with the ``"Greys"`` colormap (high values render dark);
            otherwise transpose to (H, W, C) and draw as a colour image.

    Note:
        ``img / 2 + 0.5`` undoes a ``Normalize(mean=0.5, std=0.5)`` transform.
        This notebook only applies ``ToTensor`` (values already in [0, 1]), so on
        the colour path the image is compressed into [0.5, 1] and looks washed
        out. The greyscale path is unaffected because ``imshow`` rescales
        colormapped 2-D data to its own min/max.
    """
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize (assumes Normalize(0.5, 0.5) — see docstring)
    npimg = img.detach().cpu().numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Pull one shuffled batch from the training loader. `images` is reused later
# for the TensorBoard graph trace and the torchviz plot.
dataiter = iter(train_dataloader)
images, labels = next(dataiter)

# Tile the batch into a single 3-D grid tensor and preview it inline in greyscale.
img_grid = torchvision.utils.make_grid(tensor=images)
matplotlib_imshow(img_grid, one_channel=True)

In [None]:
# Log the sample training grid under a fixed tag (shown in TensorBoard's Images tab).
writer.add_image(tag="fashion_mnist_images", img_tensor=img_grid)

In [None]:
# Build the network from the input geometry (c, dx1, dx2) and class count
# measured from the training data, then move it to the selected device.
# `with_bn=True` presumably enables batch-norm layers — confirm in models.py.
# Other architectures tried with this notebook:
#   M.ResNet(1, 18, M.ResidualBlock, num_classes=num_classes)
#   M.TinyResnet(c, M.ResidualBlock, num_classes=num_classes)
#   M.TinyResnetV2(c, M.ResidualBlock, num_classes=num_classes)
#   M.ConvNet(c, dx1, dx2, num_classes=num_classes)
#   M.MLP(c, dx1, dx2, 512, num_classes)
model = M.NeuralNetwork(c, dx1, dx2, num_classes, with_bn=True).to(device)
print(model)

In [None]:
# CrossEntropyLoss applies log-softmax itself, so the model is presumably
# expected to return raw per-class scores (logits) — confirm in models.py.
loss_fn = nn.CrossEntropyLoss()
# Adam over all model parameters; plain SGD with lr=1e-3 was the earlier choice.
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)

In [None]:
# Trace the model on the sample batch so TensorBoard can render its graph.
# `images` is deliberately rebound to the device copy: the torchviz cells
# further down feed it straight into `model`.
images = images.to(device)
writer.add_graph(model=model, input_to_model=images)

In [13]:
# Checkpoints for every experiment live under MODEL_DIR/fashion-mnist/<MODEL_SUFFIX>.pth.
checkpoint_dir = os.path.join(MODEL_DIR, "fashion-mnist")
# exist_ok replaces the exists()/makedirs() pair: no check-then-create race,
# and re-running the cell is a no-op for the directory.
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{MODEL_SUFFIX}.pth")

# Train and evaluate. trainer.fit receives the checkpoint path and the
# TensorBoard writer.
# NOTE(review): which weights get saved (best vs. last epoch) is decided inside
# trainer.fit — confirm there.
try:
    T.fit(
        model,
        train_dataloader,
        test_dataloader,
        loss_fn,
        optimizer,
        n_epochs=EPOCHS,
        checkpoint_path=checkpoint_path,
        writer=writer,
        device=device,
    )
finally:
    # Push buffered TensorBoard events to disk even when training is stopped
    # early (e.g. KeyboardInterrupt), rather than waiting for the periodic flush.
    writer.flush()

print("Done!")

100%|██████████| 235/235 [00:03<00:00, 64.35batch/s, loss=0.354]


Training performance:


KeyboardInterrupt: 

In [None]:
from torchviz import make_dot

In [None]:
# Forward the sample batch to build an autograd graph for torchviz.
# The pass runs in eval mode so that visualising does not update BatchNorm running
# statistics (with_bn=True presumably adds BN layers) or sample dropout in the
# trained model. The previous train/eval mode is restored afterwards. Autograd
# stays enabled because make_dot walks the gradient graph.
# NOTE(review): despite the name, these are presumably raw per-class scores, not
# class indices.
_was_training = model.training
model.eval()
try:
    pred_labels = model(images)
finally:
    model.train(_was_training)

In [None]:
# Render the autograd graph rooted at a scalar summary of the predictions,
# labelling parameter nodes with their qualified names from the model.
named_params = dict(model.named_parameters())
make_dot(pred_labels.mean(), params=named_params)