In [None]:
import torch
import numpy as np
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchsampler import ImbalancedDatasetSampler
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [None]:
DATA_DIR = "./data/mikans"
TORCH_MODEL_PATH = "./models/mikan_classifier.pth"
ONNX_MODEL_PATH = "./models/mikan_classifier.onnx"

IMAGE_SIZE = 96  # images are resized to IMAGE_SIZE x IMAGE_SIZE before use

# Seed the global RNGs used later in the notebook (torch: random samplers,
# augmentations and weight init; numpy: sklearn's train_test_split falls back to
# the global numpy RNG when no random_state is given) so that
# Restart & Run All reproduces the same split, statistics and training run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Estimate per-channel normalization statistics on [0, 1] images.
# Images are drawn through the same class-balancing sampler used for training,
# so the statistics follow the balanced distribution the network sees. The
# sampler is random, so results vary between runs unless the RNGs are seeded.
dataset = ImageFolder(
    DATA_DIR,
    transform=transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
    ])
)

# batch_size=1 is kept: `loader` is also used later in the notebook.
loader = DataLoader(
    dataset,
    sampler=ImbalancedDatasetSampler(dataset),
    batch_size=1
)

channel_mean_sum = np.zeros(3)
channel_sq_mean_sum = np.zeros(3)

for image, label in loader:
    im = image[0].numpy().astype(np.float64)  # drop the batch dim -> (C, H, W)
    channel_mean_sum += np.mean(im, axis=(1, 2))
    channel_sq_mean_sum += np.mean(np.square(im), axis=(1, 2))

# Every image has the same H x W, so averaging per-image means gives the
# pooled per-pixel mean.
mean = channel_mean_sum / len(loader)
# Pooled std over all pixels: sqrt(E[x^2] - E[x]^2). Averaging per-image stds
# (the previous approach) underestimates it because it drops the variance
# between image means. Clamp guards against tiny negative rounding error.
std = np.sqrt(np.maximum(channel_sq_mean_sum / len(loader) - mean ** 2, 0.0))

# The constants hard-coded in the `normalize` cell were recorded from an earlier
# run: mean [0.63526792 0.57570206 0.48665065] and a per-image-averaged std of
# [0.22508891 0.20648694 0.26393888]. Update them from the output below.
print("mean:", mean)
print("std:", std)

In [None]:
# Per-channel mean / std recorded from the statistics cell. They are copied so
# that this cell does not depend on re-running that slow, randomly sampled pass.
NORM_MEAN = (0.63526792, 0.57570206, 0.48665065)
NORM_STD = (0.22508891, 0.20648694, 0.26393888)

# [0, 1] image tensor -> per-channel standardized tensor.
normalize = transforms.Normalize(NORM_MEAN, NORM_STD)

# Inverse of `normalize` for display: first multiply by std (dividing by 1/std),
# then add the mean back (subtracting -mean).
unnormalize = transforms.Compose([
    transforms.Normalize((0.0, 0.0, 0.0), tuple(1 / s for s in NORM_STD)),
    transforms.Normalize(tuple(-m for m in NORM_MEAN), (1.0, 1.0, 1.0)),
])

In [None]:
# Evaluation pipeline: deterministic resize + normalization.
dataset = ImageFolder(
    DATA_DIR,
    transform=transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        normalize
    ])
)
# `classes` is already ordered by class index, so no reliance on dict key order.
idx_to_label = list(dataset.classes)
print("labels:", idx_to_label)

# Training pipeline: the evaluation pipeline plus mild geometric/colour jitter.
augmentated_dataset = ImageFolder(
    DATA_DIR,
    transform=transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        # ratio was previously written `(1.0, 1,0)` (a 3-tuple); square crops are intended.
        # NOTE(review): a scale above 1.0 cannot be realised by a crop, so such
        # draws are presumably rejected and retried by RandomResizedCrop — confirm.
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.9, 1.1), ratio=(1.0, 1.0)),
        # Translation only (no rotation). This runs on the PIL image before
        # ToTensor, so `fill` is an RGB colour in 0-255.
        transforms.RandomAffine(degrees=(0.0, 0.0), translate=(0.05, 0.05), fill=(122, 111, 95)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        normalize
    ])
)

# A fixed split keeps the test set identical across kernel restarts. Otherwise,
# evaluating a saved checkpoint after a restart would test on images it trained on.
SPLIT_SEED = 0
train_val_indices, test_indices = train_test_split(
    list(range(len(dataset.targets))),
    test_size=0.2,
    stratify=dataset.targets,
    random_state=SPLIT_SEED,
)
# 25% of the remaining 80% -> 60 / 20 / 20 train / val / test overall.
train_indices, val_indices = train_test_split(
    train_val_indices,
    test_size=0.25,
    stratify=np.array(dataset.targets)[train_val_indices],
    random_state=SPLIT_SEED,
)
train_dataset = Subset(augmentated_dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

def create_loader(subset, balanced=True, batch_size=10):
    """Build a DataLoader over a Subset of an ImageFolder.

    Args:
        subset: a torch.utils.data.Subset whose `.dataset` has `.targets`.
        balanced: if True, draw samples with ImbalancedDatasetSampler
            (class-balanced and random), which suits training. If False,
            iterate the subset once in order, so every evaluation image is seen
            exactly once and val/test metrics are reproducible.
        batch_size: samples per batch.
    """
    if not balanced:
        return DataLoader(subset, batch_size=batch_size, shuffle=False)
    return DataLoader(
        subset,
        batch_size=batch_size,
        sampler=ImbalancedDatasetSampler(
            subset,
            labels=[subset.dataset.targets[i] for i in subset.indices]
        )
    )

train_loader = create_loader(train_dataset)
val_loader = create_loader(val_dataset, balanced=False)
test_loader = create_loader(test_dataset, balanced=False)

In [None]:
# Display the first image of one augmented training batch, with normalization
# undone so the colours look natural.
images, targets = next(iter(train_loader))
image, target = images[0], targets[0]

# (C, H, W) tensor -> (H, W, C) array, the layout imshow expects.
display_array = unnormalize(image).numpy().transpose(1, 2, 0)

plt.title(idx_to_label[target])
plt.imshow(display_array)
plt.axis('off')
plt.show()

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small CNN classifier for square RGB images.

    Architecture: two stages of (conv5 -> conv5 -> 2x2 max-pool). Each conv is
    unpadded and followed by BatchNorm and ReLU. A single linear layer is
    applied to the flattened feature map.

    Args:
        num_classes: number of output logits. Defaults to ``len(idx_to_label)``,
            looked up when the model is constructed.
        image_size: side length of the square input in pixels. Defaults to
            ``IMAGE_SIZE``. It is used to size ``fc1``; the previously
            hard-coded 10368 (= 32 * 18 * 18) is the value for a 96x96 input.

    Raises:
        ValueError: if ``image_size`` is too small for the four unpadded
            convolutions and two poolings.
    """

    _KERNEL = 5

    def __init__(self, num_classes=None, image_size=None):
        super().__init__()
        if num_classes is None:
            num_classes = len(idx_to_label)
        if image_size is None:
            image_size = IMAGE_SIZE

        # Submodule names and in-block layer order are unchanged, so existing
        # state_dicts and checkpoints still load.
        self.conv1 = self._conv_block(3, 16)
        self.conv2 = self._conv_block(16, 16, pool=True)
        self.conv3 = self._conv_block(16, 32)
        self.conv4 = self._conv_block(32, 32, pool=True)

        side = self._feature_side(image_size)
        if side < 1:
            raise ValueError(f"image_size={image_size} is too small for this network")
        self.fc1 = nn.Linear(32 * side * side, num_classes)

    @classmethod
    def _conv_block(cls, in_channels, out_channels, pool=False):
        """Conv2d(kernel 5, no padding) -> BatchNorm2d -> ReLU, optionally -> MaxPool2d(2, 2)."""
        layers = [
            nn.Conv2d(in_channels, out_channels, cls._KERNEL),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        if pool:
            layers.append(nn.MaxPool2d(2, 2))
        return nn.Sequential(*layers)

    @classmethod
    def _feature_side(cls, image_size):
        """Spatial side length of the conv4 output for a square input of `image_size`."""
        side = image_size
        for _ in range(2):  # each stage: two unpadded convs, then a floor-halving pool
            side = (side - 2 * (cls._KERNEL - 1)) // 2
        return side

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

In [None]:
from torchsummary import summary
# torchsummary defaults to device="cuda" and would feed CUDA inputs to this
# CPU-resident Net() whenever a GPU is present, so summarise on CPU explicitly.
summary(Net(), (3, IMAGE_SIZE, IMAGE_SIZE), device="cpu")

In [None]:
from typing import Any
import lightning as L
from lightning.pytorch.utilities.types import STEP_OUTPUT

class PieceClassifier(L.LightningModule):
    """Lightning wrapper that trains `Net` with cross-entropy loss and SGD.

    Logs `loss` during training, `val_loss`/`val_acc` during validation and
    `test_loss`/`test_acc` during testing, all aggregated per epoch.

    Args:
        lr: SGD learning rate (default 0.1, the value previously hard-coded).
            It is stored via `save_hyperparameters`, so it is restored by
            `load_from_checkpoint`.
    """

    def __init__(self, lr=0.1):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs):
        return self.model(inputs)

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs)
        loss = self.loss_fn(output, target)
        metrics = {"loss": loss}
        self.log_dict(metrics, prog_bar=True, logger=True, on_epoch=True, on_step=False)
        return loss

    def _evaluate(self, batch, prefix):
        """Compute and log `<prefix>_loss` and `<prefix>_acc` for one batch."""
        inputs, target = batch
        output = self(inputs)
        loss = self.loss_fn(output, target)
        acc = (torch.argmax(output, dim=1) == target).float().mean()
        metrics = {f"{prefix}_loss": loss, f"{prefix}_acc": acc}
        self.log_dict(metrics, prog_bar=True, logger=True, on_epoch=True, on_step=False)
        return metrics

    def validation_step(self, batch, batch_idx):
        return self._evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._evaluate(batch, "test")

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr)

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint

model = PieceClassifier()

# Keep the three checkpoints with the highest validation accuracy.
checkpoint_callback = ModelCheckpoint(
    dirpath="./models/checkpoints",
    filename="mikan-{epoch:02d}-{val_acc:.02f}",
    monitor="val_acc",
    mode="max",
    save_top_k=3,
)

trainer = L.Trainer(
    max_epochs=100,
    limit_train_batches=10,  # cap the number of training batches per epoch
    callbacks=[checkpoint_callback],
)
trainer.fit(model, train_loader, val_loader)

In [None]:
# Reload the best checkpoint from the training run above. The path comes from
# the trainer's checkpoint callback; a hard-coded filename (epoch/accuracy baked
# in) stops matching as soon as the model is retrained.
best_checkpoint_path = trainer.checkpoint_callback.best_model_path
if not best_checkpoint_path:
    raise RuntimeError("No checkpoint was saved; run the training cell first.")
best_model = PieceClassifier.load_from_checkpoint(best_checkpoint_path)
best_model.eval()
trainer.test(best_model, test_loader)


In [None]:
# Persist only the bare `Net` weights (without the Lightning wrapper).
net_weights = best_model.model.state_dict()
torch.save(net_weights, TORCH_MODEL_PATH)

In [None]:
# Rebuild the bare network from the saved weights and export it to ONNX.
torch_model = Net()
# map_location="cpu" lets the weights load even if they were saved from another device.
torch_model.load_state_dict(torch.load(TORCH_MODEL_PATH, map_location="cpu"))
# Explicit inference mode so BatchNorm uses its running statistics in the graph.
torch_model.eval()
# Export only needs an input with the right shape and dtype: (1, 3, IMAGE_SIZE,
# IMAGE_SIZE) float32, the same as one batch of the statistics `loader`.
# Using a dummy tensor avoids depending on that loader, which yields
# un-normalized images.
torch_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)
torch.onnx.export(torch_model, torch_input, ONNX_MODEL_PATH)