In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torch.optim as optim
from tqdm import tqdm

In [3]:
USE_GPU = True

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


In [4]:
class VehicleDataset(Dataset):
    
    def __init__(self, path):
        data = np.load(path)
        self.images = data["images"]
        self.labels = data["labels"]
        print("Images shape:", self.images.shape)
        print("Labels shape:", self.labels.shape)
        self.__combinelabels__()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
        label = torch.tensor(label, dtype=torch.long)

        #mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        #std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        #image = (image - mean) / std
        
        return image, label
    def __combinelabels__(self):
        for idx, label in enumerate(self.labels):
            if label == 1:
                self.labels[idx] = 1
            if label >= 2 and label <= 7:
                self.labels[idx] = 2
            if label >= 8 and label <= 11:
                self.labels[idx] = 3
            if label >= 12 and label <= 25:
                self.labels[idx] = 4
            if label >= 26 and label <= 38:
               self.labels[idx] = 5
            if label >= 39 and label <= 44:
                self.labels[idx] = 6
            if label >= 45 and label <= 50:
                self.labels[idx] = 7
            if label >= 51 and label <= 53:
                self.labels[idx] = 8
            if label >= 54 and label <= 75:
                self.labels[idx] = 9
            if label >= 76 and label <= 81:
                self.labels[idx] = 10
            if label == 82:
                self.labels[idx] = 11
            if label >= 83 and label <= 97:
                self.labels[idx] = 12
            if label == 98:
                self.labels[idx] = 13
            if label >= 99 and label <= 100:
                self.labels[idx] = 14
            if label >= 101 and label <= 104:
                self.labels[idx] = 15
            if label == 105:
                self.labels[idx] = 16
            if label >= 106 and label <= 117:
                self.labels[idx] = 17
            if label >= 118 and label <= 122:
                self.labels[idx] = 18
            if label == 123:
                self.labels[idx] = 19
            if label >= 124 and label <= 125:
                self.labels[idx] = 20
            if label >= 126 and label <= 129:
                self.labels[idx] = 21
            if label >= 130 and label <= 140:
                self.labels[idx] = 22
            if label >= 141 and label <= 142:
                self.labels[idx] = 23
            if label == 143:
                self.labels[idx] = 24
            if label == 144:
                self.labels[idx] = 25
            if label >= 145 and label <= 149:
                self.labels[idx] = 26
            if label >= 150 and label <= 153:
                self.labels[idx] = 27
            if label >= 154 and label <= 155:
                self.labels[idx] = 28
            if label == 156:
                self.labels[idx] = 29
            if label == 157:
                self.labels[idx] = 30
            if label == 158:
                self.labels[idx] = 31
            if label == 159:
                self.labels[idx] = 32
            if label == 160:
                self.labels[idx] = 33
            if label >= 161 and label <= 166:
                self.labels[idx] = 34
            if label == 167:
                self.labels[idx] = 35
            if label >= 168 and label <= 171:
                self.labels[idx] = 36
            if label == 172:
                self.labels[idx] = 37
            if label == 173:
                self.labels[idx] = 38
            if label == 174:
                self.labels[idx] = 39
            if label >= 175 and label <= 177:
                self.labels[idx] = 40
            if label == 178:
                self.labels[idx] = 41
            if label >= 179 and label <= 180:
                self.labels[idx] = 42
            if label >= 181 and label <= 184:
                self.labels[idx] = 43
            if label == 185:
                self.labels[idx] = 44
            if label >= 186 and label <= 189:
                self.labels[idx] = 45
            if label >= 190 and label <= 192:
                self.labels[idx] = 46
            if label >= 193 and label <= 195:
                self.labels[idx] = 47
            if label == 196:
                self.labels[idx] = 48

In [5]:
dataset = VehicleDataset('../dataset/stanford_cars_dataset.npz')

batch_size = 32
train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = int(len(dataset) - train_size - val_size)

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Images shape: (8144, 112, 112, 3)
Labels shape: (8144, 1)


In [25]:
def train_model(model):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    num_epochs = 30
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
    
        for images, labels in tqdm(train_loader):
            images = images.to(device)
            labels = labels[:,0].to(device) - 1

            # Forward
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            running_loss += loss.item() * images.size(0)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / total
        epoch_acc = correct / total
        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Accuracy: {epoch_acc:.4f}")
    
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels[:,0].to(device) - 1
            
                outputs = model(images)
                preds = torch.argmax(outputs, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        val_acc = correct / total
        print(f'Validation Acc: {val_acc:.4f}')

In [26]:
def test_model(model, test_loader, device):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels[:,0].to(device) - 1 
            
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    accuracy = correct / total
    print(f'Test Accuracy: {accuracy:.4f}')

    # Concatenate all predictions and labels if needed for further analysis
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    
    return accuracy, all_preds, all_labels

In [27]:
class DenseLayer(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(DenseLayer, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv(self.relu(self.bn(x)))
        return torch.cat([x, out], 1) 

class Block(nn.Module):
    def __init__(self, num_layers, in_channels, growth_rate=32):
        super(Block, self).__init__()
        layers = []
        for i in range(num_layers):
            layers.append(DenseLayer(in_channels + (i * growth_rate), growth_rate))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class Transition(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.avgpool = nn.AvgPool2d(2)

    def forward(self, x):
        x = self.conv(self.relu(self.bn(x)))
        return self.avgpool(x)

class DenseNet(nn.Module):
    def __init__(self, num_classes=96, growth_rate=32):
        super(DenseNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        num_features = 64
        config_list = [6, 12, 24, 16] 

        self.dense_blocks = nn.ModuleList()
        self.transitions = nn.ModuleList()

        for i, num_layers in enumerate(config_list):
            block = Block(num_layers, num_features, growth_rate)
            self.dense_blocks.append(block)
            num_features += num_layers * growth_rate

            if i != len(config_list) - 1:
                transition = Transition(num_features, num_features // 2)
                self.transitions.append(transition)
                num_features = num_features // 2

        self.bn = nn.BatchNorm2d(num_features)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(num_features, 48)

    def forward(self, x):
        x = self.features(x)
        for i in range(len(self.dense_blocks)):
            x = self.dense_blocks[i](x)
            if i < len(self.transitions):
                x = self.transitions[i](x)
        x = self.avgpool(self.relu(self.bn(x)))
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [28]:
model = DenseNet().to(device)
train_model(model)

100%|██████████| 179/179 [00:20<00:00,  8.53it/s]


Epoch 1/30 - Loss: 3.4009 - Accuracy: 0.1153
Validation Acc: 0.1290


100%|██████████| 179/179 [00:20<00:00,  8.67it/s]


Epoch 2/30 - Loss: 3.1597 - Accuracy: 0.1589
Validation Acc: 0.1511


100%|██████████| 179/179 [00:20<00:00,  8.61it/s]


Epoch 3/30 - Loss: 2.9919 - Accuracy: 0.1842
Validation Acc: 0.1720


100%|██████████| 179/179 [00:20<00:00,  8.76it/s]


Epoch 4/30 - Loss: 2.7893 - Accuracy: 0.2382
Validation Acc: 0.1536


100%|██████████| 179/179 [00:20<00:00,  8.94it/s]


Epoch 5/30 - Loss: 2.5498 - Accuracy: 0.2946
Validation Acc: 0.2045


100%|██████████| 179/179 [00:21<00:00,  8.42it/s]


Epoch 6/30 - Loss: 2.2859 - Accuracy: 0.3688
Validation Acc: 0.1701


100%|██████████| 179/179 [00:20<00:00,  8.79it/s]


Epoch 7/30 - Loss: 1.9971 - Accuracy: 0.4416
Validation Acc: 0.2346


100%|██████████| 179/179 [00:20<00:00,  8.62it/s]


Epoch 8/30 - Loss: 1.6719 - Accuracy: 0.5358
Validation Acc: 0.2242


100%|██████████| 179/179 [00:20<00:00,  8.75it/s]


Epoch 9/30 - Loss: 1.3269 - Accuracy: 0.6396
Validation Acc: 0.2297


100%|██████████| 179/179 [00:20<00:00,  8.84it/s]


Epoch 10/30 - Loss: 0.9420 - Accuracy: 0.7532
Validation Acc: 0.2359


100%|██████████| 179/179 [00:20<00:00,  8.89it/s]


Epoch 11/30 - Loss: 0.4927 - Accuracy: 0.9147
Validation Acc: 0.3305


100%|██████████| 179/179 [00:22<00:00,  8.12it/s]


Epoch 12/30 - Loss: 0.3249 - Accuracy: 0.9593
Validation Acc: 0.3391


100%|██████████| 179/179 [00:22<00:00,  7.98it/s]


Epoch 13/30 - Loss: 0.2662 - Accuracy: 0.9728
Validation Acc: 0.3385


100%|██████████| 179/179 [00:20<00:00,  8.78it/s]


Epoch 14/30 - Loss: 0.2232 - Accuracy: 0.9833
Validation Acc: 0.3421


100%|██████████| 179/179 [00:20<00:00,  8.85it/s]


Epoch 15/30 - Loss: 0.1902 - Accuracy: 0.9896
Validation Acc: 0.3458


100%|██████████| 179/179 [00:20<00:00,  8.67it/s]


Epoch 16/30 - Loss: 0.1650 - Accuracy: 0.9933
Validation Acc: 0.3440


100%|██████████| 179/179 [00:21<00:00,  8.38it/s]


Epoch 17/30 - Loss: 0.1486 - Accuracy: 0.9940
Validation Acc: 0.3458


100%|██████████| 179/179 [00:20<00:00,  8.58it/s]


Epoch 18/30 - Loss: 0.1234 - Accuracy: 0.9972
Validation Acc: 0.3458


100%|██████████| 179/179 [00:20<00:00,  8.83it/s]


Epoch 19/30 - Loss: 0.1087 - Accuracy: 0.9984
Validation Acc: 0.3458


100%|██████████| 179/179 [00:20<00:00,  8.67it/s]


Epoch 20/30 - Loss: 0.0944 - Accuracy: 0.9993
Validation Acc: 0.3532


100%|██████████| 179/179 [00:19<00:00,  9.10it/s]


Epoch 21/30 - Loss: 0.0837 - Accuracy: 0.9993
Validation Acc: 0.3489


100%|██████████| 179/179 [00:19<00:00,  9.07it/s]


Epoch 22/30 - Loss: 0.0797 - Accuracy: 0.9993
Validation Acc: 0.3428


100%|██████████| 179/179 [00:20<00:00,  8.91it/s]


Epoch 23/30 - Loss: 0.0766 - Accuracy: 0.9998
Validation Acc: 0.3477


100%|██████████| 179/179 [00:20<00:00,  8.70it/s]


Epoch 24/30 - Loss: 0.0773 - Accuracy: 0.9998
Validation Acc: 0.3526


100%|██████████| 179/179 [00:20<00:00,  8.89it/s]


Epoch 25/30 - Loss: 0.0746 - Accuracy: 0.9998
Validation Acc: 0.3538


100%|██████████| 179/179 [00:19<00:00,  9.30it/s]


Epoch 26/30 - Loss: 0.0723 - Accuracy: 0.9998
Validation Acc: 0.3507


100%|██████████| 179/179 [00:20<00:00,  8.82it/s]


Epoch 27/30 - Loss: 0.0745 - Accuracy: 0.9996
Validation Acc: 0.3464


100%|██████████| 179/179 [00:19<00:00,  9.23it/s]


Epoch 28/30 - Loss: 0.0733 - Accuracy: 0.9998
Validation Acc: 0.3526


100%|██████████| 179/179 [00:20<00:00,  8.71it/s]


Epoch 29/30 - Loss: 0.0731 - Accuracy: 0.9993
Validation Acc: 0.3464


100%|██████████| 179/179 [00:20<00:00,  8.89it/s]


Epoch 30/30 - Loss: 0.0694 - Accuracy: 0.9995
Validation Acc: 0.3434


In [30]:
torch.save(model.state_dict(), "DenseNet121")

In [29]:
test_model(model, test_loader, device)

Test Accuracy: 0.3150


(0.31495098039215685,
 tensor([22,  8, 25, 27, 44, 21, 28, 11,  2, 13, 41,  4, 42,  8,  3, 39,  8,  4,
          4,  8,  1,  4,  6, 11, 17, 25,  3,  0, 11, 11, 11, 17, 22, 45, 16, 14,
         11,  3,  4, 11,  3, 16, 35,  8,  6, 16,  4, 16, 33,  8,  4, 17, 21, 16,
          4,  3, 33,  4, 11,  4, 40,  6,  8, 15,  4, 44, 16,  8, 35,  8,  3, 39,
         20, 45,  8,  8, 11,  8, 16,  8,  4,  8, 39, 44, 21,  3, 11, 33, 11, 16,
         37, 11, 33,  7, 21,  0,  8, 16,  4,  3,  4, 11,  3,  8,  8,  8, 18, 16,
          3, 16, 37, 34,  3,  5, 20, 17,  8, 36,  8,  3,  4,  5,  8,  8, 33, 16,
         16,  8,  6, 11,  9, 16,  7,  3,  4,  9,  6, 32,  8, 46,  6,  6,  4, 10,
          3, 18,  8, 33, 39,  4,  8, 26, 14, 16,  5,  4, 22, 35,  5, 11, 38, 11,
         42,  3,  4,  1, 16,  3, 20,  9, 21,  8, 21,  5, 24,  8, 35,  4,  7, 13,
          8,  8,  4,  3,  8, 21,  8,  3, 11,  8, 44, 16,  8, 16,  6, 11,  1, 19,
         11, 29,  3,  4,  4, 25,  6, 11,  4, 42, 20, 35,  4,  8, 16,  8, 44,  3,
      