In [1]:
# unzip
# unzip <src of zip> -d <where to unip>

# check num files
# ls <dir> -1 | wc -l

In [2]:
import os
import math
import shutil
import random
from pathlib import Path


data_path = Path("../data/dataset")
excl_path = ["train", "test"]

# create test and train dirs
for p in excl_path:
    pth = data_path / p
    if not pth.is_dir():
        os.makedirs(pth)

for path in data_path.iterdir():
    if path.is_dir() and path.stem not in excl_path:
        all_images = list(path.glob("*.jpg"))
        test_images = random.sample(population=all_images, k=math.floor(len(all_images)*0.2)) # 20% test
        test_subdir = data_path / "test" / path.stem
        train_subdir = data_path / "train" / path.stem

        # move test images to each respective class (/test/class_name)
        if not test_subdir.is_dir():
            # create subdir for each class
            os.makedirs(test_subdir)

            for img_path_str in test_images:
                img_path = Path(img_path_str)
                shutil.move(img_path_str.absolute(), test_subdir.absolute())
                
        # move the remaining images to train directory
        if not train_subdir.is_dir():
            shutil.move(path, train_subdir)


In [3]:
test_path = data_path / "test"
for d in test_path.iterdir():
    print(f"test_{d.stem} => {len(list(d.glob('*.jpg')))}")

test_dew => 139
test_fogsmog => 170
test_frost => 95
test_glaze => 127
test_hail => 118
test_lightning => 75
test_rain => 105
test_rainbow => 46
test_rime => 232
test_sandstorm => 138
test_snow => 124


In [4]:
train_path = data_path / "train"
for d in train_path.iterdir():
    print(f"train_{d.stem} => {len(list(d.glob('*.jpg')))}")

train_dew => 559
train_fogsmog => 681
train_frost => 380
train_glaze => 512
train_hail => 473
train_lightning => 302
train_rain => 421
train_rainbow => 186
train_rime => 928
train_sandstorm => 554
train_snow => 497


In [5]:
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_WORKERS = 8

T = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


train_data = ImageFolder(root=train_path, transform=T)
test_data = ImageFolder(root=test_path, transform=T)

train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              shuffle=True)
test_dataloader = DataLoader(dataset=test_data,
                             batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS,
                             shuffle=False)

sum([len(train_data), len(test_data)])


6862

In [6]:
class_names = train_data.classes
class_names

['dew',
 'fogsmog',
 'frost',
 'glaze',
 'hail',
 'lightning',
 'rain',
 'rainbow',
 'rime',
 'sandstorm',
 'snow']

In [7]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

In [8]:
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

model_v0 = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
for param in model_v0.features.parameters():
    param.requires_grad = False

model_v0.classifier = nn.Sequential(
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1280,
              out_features=len(class_names),
              bias=True)
).to(device)

In [9]:
image, label = next(iter(train_data))
image.shape, label

(torch.Size([3, 224, 224]), 0)

In [10]:
from torchinfo import summary

summary(model=model_v0,
        input_size=(1, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20)

Layer (type:depth-idx)                                  Input Shape          Output Shape         Param #              Trainable
EfficientNet                                            [1, 3, 224, 224]     [1, 11]              --                   Partial
├─Sequential: 1-1                                       [1, 3, 224, 224]     [1, 1280, 7, 7]      --                   False
│    └─Conv2dNormActivation: 2-1                        [1, 3, 224, 224]     [1, 32, 112, 112]    --                   False
│    │    └─Conv2d: 3-1                                 [1, 3, 224, 224]     [1, 32, 112, 112]    (864)                False
│    │    └─BatchNorm2d: 3-2                            [1, 32, 112, 112]    [1, 32, 112, 112]    (64)                 False
│    │    └─SiLU: 3-3                                   [1, 32, 112, 112]    [1, 32, 112, 112]    --                   --
│    └─Sequential: 2-2                                  [1, 32, 112, 112]    [1, 16, 112, 112]    --                   Fal

In [11]:
from torch.optim import Adam

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(params=model_v0.parameters(), lr=0.001)

In [12]:
model_v0.train()

for  epoch in range(5):
    train_loss = 0
    correct = 0
    total = 0
    
    for batch, (input, label) in enumerate(train_dataloader):
        input, label = input.to(device), label.to(device)
        optimizer.zero_grad()

        pred = model_v0(input)

        loss = loss_fn(pred, label)
        loss.backward()

        optimizer.step()
    
        # 
        _, prediction = pred.max(dim=1)
        total += label.size(0)
        correct += prediction.eq(label).sum().item()
        train_loss += loss.item() 

    print(f"Epoch: {epoch} ~ Acc: {(correct/total)*100:.4f} | Loss: {(train_loss / len(train_dataloader)):.4f}")

Epoch: 0 ~ Acc: 70.9995 | Loss: 1.0918
Epoch: 1 ~ Acc: 82.6324 | Loss: 0.6002
Epoch: 2 ~ Acc: 84.7988 | Loss: 0.5051
Epoch: 3 ~ Acc: 85.9457 | Loss: 0.4624
Epoch: 4 ~ Acc: 86.0732 | Loss: 0.4349


In [13]:
model_v0.eval()
correct = 0
total = 0

with torch.inference_mode():
    for inputs, labels in test_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        test_pred = model_v0(inputs)
        _, prediction = test_pred.max(dim=1)

        total += labels.size(0)
        correct += prediction.eq(labels).sum().item()
    
    print(f"Test Acc: {(correct/total)*100:.4f}")

Test Acc: 86.5595
