In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchsummary import summary

from src.dataset import get_data
from src.modeling.train import train_model

Load the dataset and the pretrained model

In [2]:
# Seed torch's global RNG before any stochastic step so that the new head's
# weight initialisation, dropout masks and anything else drawing from the
# global generator are repeatable under Restart Kernel -> Run All.
# NOTE(review): get_data() may manage its own shuffling/generator - confirm.
SEED = 42
torch.manual_seed(SEED)

# get_data() (imported in the top import cell) supplies the train,
# validation and test loaders used by the rest of the notebook.
train_loader, val_loader, test_loader = get_data()

[32m2025-06-04 10:47:30.764[0m | [1mINFO    [0m | [36msrc.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: C:\Git\fmnist-classification[0m


In [3]:
# Start from the ImageNet-pretrained EfficientNet-B0 checkpoint.
model = models.efficientnet_b0(weights="IMAGENET1K_V1")

# summary() prints the layer table itself. The trailing semicolon stops
# Jupyter from also echoing the call's return value as the cell output, which
# is what rendered the same table twice in the saved output.
summary(model, (3, 224, 224));

Layer (type:depth-idx)                        Output Shape              Param #
├─Sequential: 1-1                             [-1, 1280, 7, 7]          --
|    └─Conv2dNormActivation: 2-1              [-1, 32, 112, 112]        --
|    |    └─Conv2d: 3-1                       [-1, 32, 112, 112]        864
|    |    └─BatchNorm2d: 3-2                  [-1, 32, 112, 112]        64
|    |    └─SiLU: 3-3                         [-1, 32, 112, 112]        --
|    └─Sequential: 2-2                        [-1, 16, 112, 112]        --
|    |    └─MBConv: 3-4                       [-1, 16, 112, 112]        1,448
|    └─Sequential: 2-3                        [-1, 24, 56, 56]          --
|    |    └─MBConv: 3-5                       [-1, 24, 56, 56]          6,004
|    |    └─MBConv: 3-6                       [-1, 24, 56, 56]          10,710
|    └─Sequential: 2-4                        [-1, 40, 28, 28]          --
|    |    └─MBConv: 3-7                       [-1, 40, 28, 28]          15,350
|    

Layer (type:depth-idx)                        Output Shape              Param #
├─Sequential: 1-1                             [-1, 1280, 7, 7]          --
|    └─Conv2dNormActivation: 2-1              [-1, 32, 112, 112]        --
|    |    └─Conv2d: 3-1                       [-1, 32, 112, 112]        864
|    |    └─BatchNorm2d: 3-2                  [-1, 32, 112, 112]        64
|    |    └─SiLU: 3-3                         [-1, 32, 112, 112]        --
|    └─Sequential: 2-2                        [-1, 16, 112, 112]        --
|    |    └─MBConv: 3-4                       [-1, 16, 112, 112]        1,448
|    └─Sequential: 2-3                        [-1, 24, 56, 56]          --
|    |    └─MBConv: 3-5                       [-1, 24, 56, 56]          6,004
|    |    └─MBConv: 3-6                       [-1, 24, 56, 56]          10,710
|    └─Sequential: 2-4                        [-1, 40, 28, 28]          --
|    |    └─MBConv: 3-7                       [-1, 40, 28, 28]          15,350
|    

In [4]:
# Replace the pretrained classification head with a fresh one sized for this
# task. The input width is read from the existing Linear layer (index 1 of
# the classifier) so it always matches the backbone's feature dimension.
num_features = model.classifier[1].in_features
num_classes = 10

model.classifier = nn.Sequential(
    nn.Dropout(p=0.2),
    nn.Linear(in_features=num_features, out_features=num_classes),
)

Stage 1: Freeze base model parameters and train only the classifier head

In [5]:
# Stage 1 setup: turn gradients off for every parameter in the network, then
# switch them back on for the classifier head only, so the head is the sole
# trainable part.
# NOTE(review): requires_grad=False does not stop BatchNorm running statistics
# from updating while the module is in train mode - confirm how train_model
# handles this if the backbone is meant to stay fully fixed.
model.requires_grad_(False)
model.classifier.requires_grad_(True)

In [6]:
# Stage 1 run: 5 epochs with only the classifier head trainable (set up in
# the previous cell). No lr is passed, so train_model's own default applies.
# train_model is imported in the top import cell.
# NOTE: `model` is rebound to the returned model, so re-running this cell
# continues training from the current weights rather than starting over.
model, history1 = train_model(model, train_loader, val_loader, num_epochs=5)

Epoch 1/5: 100%|██████████| 750/750 [00:56<00:00, 13.30it/s]


Epoch [1/5]:
  Train Loss: 0.6627, Train Acc: 79.02%
  Val Loss: 0.4350, Val Acc: 84.67%


Epoch 2/5: 100%|██████████| 750/750 [00:57<00:00, 12.96it/s]


Epoch [2/5]:
  Train Loss: 0.4788, Train Acc: 83.29%
  Val Loss: 0.4024, Val Acc: 85.91%


Epoch 3/5: 100%|██████████| 750/750 [00:56<00:00, 13.31it/s]


Epoch [3/5]:
  Train Loss: 0.4490, Train Acc: 84.10%
  Val Loss: 0.3870, Val Acc: 86.12%


Epoch 4/5: 100%|██████████| 750/750 [00:58<00:00, 12.90it/s]


Epoch [4/5]:
  Train Loss: 0.4407, Train Acc: 84.22%
  Val Loss: 0.3870, Val Acc: 86.08%


Epoch 5/5: 100%|██████████| 750/750 [00:56<00:00, 13.29it/s]


Epoch [5/5]:
  Train Loss: 0.4240, Train Acc: 84.88%
  Val Loss: 0.3703, Val Acc: 86.72%


Stage 2: Unfreeze and fine-tune the entire model

In [7]:
# Stage 2 setup: re-enable gradients for every parameter so the backbone is
# fine-tuned together with the classifier head.
model.requires_grad_(True)

In [8]:
# Stage 2 run: 5 more epochs over the whole network with an explicit learning
# rate of 1e-4. In the logged run below, validation loss bottoms out at
# epoch 3 while training loss keeps falling, so later epochs may be starting
# to overfit.
model, history2 = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=5,
    lr=1e-4,
)

Epoch 1/5: 100%|██████████| 750/750 [01:56<00:00,  6.46it/s]


Epoch [1/5]:
  Train Loss: 0.2823, Train Acc: 89.97%
  Val Loss: 0.2005, Val Acc: 92.81%


Epoch 2/5: 100%|██████████| 750/750 [01:57<00:00,  6.38it/s]


Epoch [2/5]:
  Train Loss: 0.1716, Train Acc: 93.79%
  Val Loss: 0.1896, Val Acc: 93.36%


Epoch 3/5: 100%|██████████| 750/750 [01:57<00:00,  6.39it/s]


Epoch [3/5]:
  Train Loss: 0.1269, Train Acc: 95.40%
  Val Loss: 0.1795, Val Acc: 93.83%


Epoch 4/5: 100%|██████████| 750/750 [01:56<00:00,  6.42it/s]


Epoch [4/5]:
  Train Loss: 0.0969, Train Acc: 96.55%
  Val Loss: 0.1803, Val Acc: 93.83%


Epoch 5/5: 100%|██████████| 750/750 [01:55<00:00,  6.50it/s]


Epoch [5/5]:
  Train Loss: 0.0689, Train Acc: 97.44%
  Val Loss: 0.1838, Val Acc: 94.27%
