In [62]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [63]:
from torchvision.datasets import ImageFolder

from classification.network.resnet import SeResNet
from classification.utils.loaders import VehicleDataLoader, read_image
from classification.utils.transforms import VehicleTransform

# Build the train/valid datasets and loaders, then push one batch through the model as a smoke test.
train_dataset = ImageFolder("../data/vehicles/train", loader=read_image)
valid_dataset = ImageFolder("../data/vehicles/valid", loader=read_image)

# ImageFolder exposes one class per sub-directory of the root folder.
num_classes = len(train_dataset.classes)
transform = VehicleTransform(size=(224, 224))

train_loader = VehicleDataLoader(
    train_dataset,
    train_transform=transform.train_transform,
    eval_transform=transform.eval_transform,
    batch_size=64,
    shuffle=True,
)
valid_loader = VehicleDataLoader(
    valid_dataset,
    eval_transform=transform.eval_transform,
    batch_size=64,
    shuffle=False,
)

# NOTE(review): presumably switches the loader to the augmenting train transform — confirm in VehicleDataLoader.
train_loader.train()
# Size the classifier head from the data instead of hard-coding 6, so adding/removing a
# class folder cannot silently mismatch the model output.
model = SeResNet(num_classes=num_classes)

# Smoke test: a single batch through the (untrained) model.
x, y = next(iter(train_loader))
pred = model(x)

In [65]:
# Show the transform currently attached to the training dataset. The displayed pipeline contains
# random augmentations (flip, rotation, colour jitter, grayscale), i.e. the training transform.
# NOTE(review): presumably set on the dataset by VehicleDataLoader / train_loader.train() — confirm.
train_dataset.transform

Compose(
    ToTensor()
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
    RandomHorizontalFlip(p=0.5)
    RandomRotation(degrees=[-15.0, 15.0], interpolation=nearest, expand=False, fill=0)
    ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=None)
    RandomGrayscale(p=0.2)
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
)

In [None]:
# Iterate the whole training loader once to confirm every batch loads, and summarise the
# (images, labels) shapes instead of printing one line per batch.
# Distinct loop names are used so the `x`, `y` batch behind `pred` is not overwritten —
# the metric cell below pairs `pred` with `y`.
from collections import Counter

batch_shapes = Counter(
    (tuple(batch_images.shape), tuple(batch_labels.shape))
    for batch_images, batch_labels in train_loader
)
batch_shapes

In [28]:
# Model outputs for the smoke-test batch (presumably raw logits; softmax is applied in a later cell).
# NOTE(review): the saved output has 8 rows although batch_size=64 and comes from In [28], which
# predates the setup cell (In [63]) — it is stale; re-run in order.
# NOTE(review): several rows are identical (0-1, 5-7). This is consistent with whole samples being
# zeroed by nn.Dropout1d applied to 2-D input inside SeResNet — TODO check the model's head.
pred

tensor([[ 0.0467,  0.0160,  0.0514,  0.0083,  0.0402,  0.0492],
        [ 0.0467,  0.0160,  0.0514,  0.0083,  0.0402,  0.0492],
        [-0.9165,  0.9254, -0.0039, -0.5991,  0.5991, -0.7124],
        [-1.4088, -0.8216,  1.7251,  0.6593,  0.6106, -0.0317],
        [ 0.5141, -0.3288,  1.9802, -0.3713,  0.6967, -1.9077],
        [ 0.3654,  0.0222, -0.0081, -0.0517,  0.2052,  0.1017],
        [ 0.3654,  0.0222, -0.0081, -0.0517,  0.2052,  0.1017],
        [ 0.3654,  0.0222, -0.0081, -0.0517,  0.2052,  0.1017]])

In [33]:
import torch

# Convert each row of model outputs into class probabilities (normalised over the last dim).
pred.softmax(dim=-1)

tensor([[0.1686, 0.1635, 0.1694, 0.1622, 0.1675, 0.1690],
        [0.1686, 0.1635, 0.1694, 0.1622, 0.1675, 0.1690],
        [0.0590, 0.3721, 0.1469, 0.0810, 0.2685, 0.0724],
        [0.0221, 0.0398, 0.5084, 0.1751, 0.1668, 0.0877],
        [0.1340, 0.0577, 0.5804, 0.0553, 0.1608, 0.0119],
        [0.2138, 0.1517, 0.1472, 0.1409, 0.1822, 0.1642],
        [0.2138, 0.1517, 0.1472, 0.1409, 0.1822, 0.1642],
        [0.2138, 0.1517, 0.1472, 0.1409, 0.1822, 0.1642]])

In [30]:
from torcheval.metrics import MulticlassAccuracy

# Per-class accuracy (average=None) on the smoke-test batch. The class count comes from the
# dataset rather than a hard-coded 6, so it stays consistent with the model head.
metric = MulticlassAccuracy(average=None, num_classes=num_classes)
# update() returns the metric object itself; the trailing ';' suppresses its uninformative repr.
metric.update(pred, y);

<torcheval.metrics.classification.accuracy.MulticlassAccuracy at 0x1f696978650>

In [25]:
# Per-class accuracy accumulated by `metric` (one value per class, since average=None).
# NOTE(review): the saved output came from In [25], which predates the update() cell (In [30]),
# so it does not reflect the current `pred`/`y` — re-run cells in order.
metric.compute()

tensor([0.1429, 0.3333, 0.6667, 0.0000, 0.0000, 0.0000])

In [1]:
import torch.nn as nn

from classification.network import layers

In [None]:
# Prototype of a compact SE-ResNet-style classifier, plus an output-shape check on random input.
# Spatial sizes in the comments assume SEResidualBlock(stride=1) and MaxDepthPool2d keep H and W
# unchanged — TODO confirm against classification.network.layers.

# Square input matching VehicleTransform(size=(224, 224)); the head assumes a 7x7 final feature map.
img_size = (224, 224)
n_imgs = 10
imgs = torch.randn(n_imgs, 3, *img_size)

feed_forward = nn.Sequential(
    # Stem: 224 -> 112 (stride-2 conv) -> 56 (max pool)
    nn.Conv2d(3, 32, kernel_size=5, stride=2, padding=2),
    nn.BatchNorm2d(num_features=32),
    nn.Mish(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Stage 1: 64 channels; MaxDepthPool2d(2) presumably halves channels (64 -> 32, matching the
    # next block's in_channels); spatial 56 -> 28
    layers.SEResidualBlock(32, 64, kernel_size=3, stride=1, squeeze_active=True),
    layers.SEResidualBlock(64, 64, kernel_size=3, stride=1, squeeze_active=True),
    layers.MaxDepthPool2d(pool_size=2),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Stage 2: 96 channels -> 48 after depth pooling; spatial 28 -> 14
    layers.SEResidualBlock(32, 96, kernel_size=5, stride=1, squeeze_active=True),
    layers.SEResidualBlock(96, 96, kernel_size=5, stride=1, squeeze_active=True),
    layers.MaxDepthPool2d(pool_size=2),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Stage 3: 128 channels -> 32 after depth pooling (pool_size=4); spatial 14 -> 7
    layers.SEResidualBlock(48, 128, kernel_size=3, stride=1, squeeze_active=True),
    layers.SEResidualBlock(128, 128, kernel_size=3, stride=1, squeeze_active=True),
    layers.MaxDepthPool2d(pool_size=4),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    # Head on 32 x 7 x 7 features; bias=False because the following BatchNorm1d adds its own shift.
    nn.Linear(32 * 7 * 7, 256, bias=False),
    nn.BatchNorm1d(num_features=256),
    nn.Mish(),
    # Plain nn.Dropout: nn.Dropout1d would treat this 2-D (N, 256) tensor as unbatched (C, L)
    # and zero entire samples instead of individual features.
    nn.Dropout(0.4),
    nn.Linear(256, 256, bias=False),
    nn.BatchNorm1d(num_features=256),
    nn.Mish(),
    nn.Dropout(0.4),
    nn.Linear(256, num_classes),
)

# Shape check: one row of class scores per input image.
with torch.no_grad():
    logits = feed_forward(imgs)
assert logits.shape == (n_imgs, num_classes), logits.shape
logits.shape