# Image classification

In [1]:
import torch
from torch import Generator

# Set the random seed for reproducibility
random_state = 42

In [2]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

Using device: mps


## Data

In [3]:
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the main training dataset
base_dir = './data/train'
base_ds = ImageFolder(base_dir, transform=transform)

total_samples = len(base_ds)
total_classes = len(base_ds.classes)

print(f"Number of classes: {total_classes}")
print(f"Number of samples: {total_samples}")

Number of classes: 50
Number of samples: 88011


In [4]:
sample_img, sample_label = base_ds[0]

print(f"Sample image shape: {sample_img.size}")
print(f"Sample label: {sample_label}")

Sample image shape: <built-in method size of Tensor object at 0x167d02d00>
Sample label: 0


In [5]:
from torch.utils.data import random_split, DataLoader

val_ratio = 0.2
total_size = len(base_ds)
val_size = int(total_size * val_ratio)
train_size = total_size - val_size

# train validation split
train_ds, val_ds = random_split(base_ds, [train_size, val_size], generator=Generator().manual_seed(random_state))

print(f"Training size: {len(train_ds)}")
print(f"Validation size: {len(val_ds)}")

batch_size = 256
num_workers = 0

# setup data loaders
train_dl = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, generator=Generator().manual_seed(random_state))
val_dl = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers, shuffle=False)

Training size: 70409
Validation size: 17602


## Model

In [None]:
from models import BasicNet

model = BasicNet(num_classes=total_classes)

## Training

In [None]:
from training import Trainer

trainer = Trainer(model, train_dl, val_dl, device=device, lr=1e-3)
trainer.train(num_epochs=10)

Epoch 1/10 (training): 100%|██████████| 276/276 [00:41<00:00,  6.64it/s]
Epoch 1/10 (validation): 100%|██████████| 69/69 [00:13<00:00,  5.15it/s]


Epoch 1/10: Train Loss: 3.0815, Val Loss: 2.7171, Avg Class Accuracy: 0.2611


Epoch 2/10 (training): 100%|██████████| 276/276 [00:36<00:00,  7.60it/s]
Epoch 2/10 (validation): 100%|██████████| 69/69 [00:13<00:00,  5.08it/s]


Epoch 2/10: Train Loss: 2.6197, Val Loss: 2.3970, Avg Class Accuracy: 0.3446


Epoch 3/10 (training): 100%|██████████| 276/276 [00:41<00:00,  6.68it/s]
Epoch 3/10 (validation): 100%|██████████| 69/69 [00:14<00:00,  4.76it/s]


Epoch 3/10: Train Loss: 2.4312, Val Loss: 2.3280, Avg Class Accuracy: 0.3588


Epoch 4/10 (training): 100%|██████████| 276/276 [00:40<00:00,  6.74it/s]
Epoch 4/10 (validation): 100%|██████████| 69/69 [00:14<00:00,  4.69it/s]


Epoch 4/10: Train Loss: 2.3118, Val Loss: 2.1310, Avg Class Accuracy: 0.4070


Epoch 5/10 (training): 100%|██████████| 276/276 [00:41<00:00,  6.70it/s]
Epoch 5/10 (validation): 100%|██████████| 69/69 [00:14<00:00,  4.64it/s]


Epoch 5/10: Train Loss: 2.1967, Val Loss: 1.9932, Avg Class Accuracy: 0.4404


Epoch 6/10 (training): 100%|██████████| 276/276 [00:41<00:00,  6.69it/s]
Epoch 6/10 (validation): 100%|██████████| 69/69 [00:15<00:00,  4.52it/s]


Epoch 6/10: Train Loss: 2.1149, Val Loss: 2.0325, Avg Class Accuracy: 0.4430


Epoch 7/10 (training): 100%|██████████| 276/276 [00:42<00:00,  6.49it/s]
Epoch 7/10 (validation): 100%|██████████| 69/69 [00:14<00:00,  4.64it/s]


Epoch 7/10: Train Loss: 2.0548, Val Loss: 1.8883, Avg Class Accuracy: 0.4769


Epoch 8/10 (training): 100%|██████████| 276/276 [00:41<00:00,  6.64it/s]
Epoch 8/10 (validation): 100%|██████████| 69/69 [00:15<00:00,  4.55it/s]


Epoch 8/10: Train Loss: 1.9718, Val Loss: 1.8377, Avg Class Accuracy: 0.4911


Epoch 9/10 (training): 100%|██████████| 276/276 [00:41<00:00,  6.69it/s]
Epoch 9/10 (validation): 100%|██████████| 69/69 [00:15<00:00,  4.56it/s]


Epoch 9/10: Train Loss: 1.9128, Val Loss: 1.9217, Avg Class Accuracy: 0.4704


Epoch 10/10 (training): 100%|██████████| 276/276 [00:41<00:00,  6.66it/s]
Epoch 10/10 (validation): 100%|██████████| 69/69 [00:14<00:00,  4.60it/s]

Epoch 10/10: Train Loss: 1.8612, Val Loss: 1.9441, Avg Class Accuracy: 0.4673



