In [1]:
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import Dataset, Subset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import torch.optim as optim
import torch.nn as nn
import torch

In [2]:
import torch_directml
dml = torch_directml.device()
print(dml)

privateuseone:0


In [3]:
# Load the dataset with data augmentation and normalizing
class Cifar10(Dataset):
    def __init__(self):
        super().__init__()
        transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.RandomAutocontrast(),
            transforms.ToTensor(),
            transforms.Resize((128, 128)), # B2 for resizing image
            transforms.Normalize(mean=[0.4854, 0.4567, 0.4062], std=[0.2291, 0.2249, 0.2253])
        ])
        self.data = ImageFolder("clean_train", transform=transform_train)

    def __len__(self):
        return self.data.__len__()

    def __getitem__(self, idx):
        features, label = self.data[idx]
        return features, label

    def split_data(self):
        indices = np.arange(len(self.data))
        targets = np.array(self.data.targets)
        train_indices, val_indices = train_test_split(indices,
                                                       test_size=0.2,
                                                       random_state=42,
                                                       stratify=targets)
        final_train_set = Subset(self.data, train_indices)
        final_val_set = Subset(self.data, val_indices)
        return final_train_set, final_val_set
        
train_data = Cifar10()

In [4]:
# Split the dataset
train_set, val_set = train_data.split_data()

In [5]:
# Build the CNN
class CnnCifar(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), # B4. Padding in each CONV layer
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.01),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), # B4. Padding in each CONV layer
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.MaxPool2d(kernel_size=2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2)
        )
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x


In [6]:
# Feed training data into the model
batch_size = 128
train_load = DataLoader(train_set, shuffle=True,  batch_size=batch_size,
                        num_workers=4, pin_memory=True, persistent_workers=True)
val_load   = DataLoader(val_set,   shuffle=False, batch_size=batch_size,
                        num_workers=4, pin_memory=True, persistent_workers=True)
cnn = CnnCifar().to(dml)
criterion = nn.CrossEntropyLoss().to(dml)
optimizer = optim.Adam(cnn.parameters(), lr=0.001)

for epoch in range(30):
    run_amount = 0
    for images, labels in train_load:
        images = images.to(memory_format=torch.channels_last).to(dml, non_blocking=True)
        labels = labels.to(dml, non_blocking=True)
        
        optimizer.zero_grad()
        outputs = cnn(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        run_amount += loss.item()
    avg_loss = run_amount / len(train_load)
    print(f"Epoch [{epoch+1}/30], Loss: {avg_loss:.4}")

print("Saving baseline model so no need to retrain anymore.")
path = "./cnn_cifar_base.pth"
torch.save(cnn.state_dict(), path)

  torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)


Epoch [1/30], Loss: 1.969
Epoch [2/30], Loss: 1.82
Epoch [3/30], Loss: 1.763
Epoch [4/30], Loss: 1.729
Epoch [5/30], Loss: 1.705
Epoch [6/30], Loss: 1.674
Epoch [7/30], Loss: 1.65
Epoch [8/30], Loss: 1.626
Epoch [9/30], Loss: 1.617
Epoch [10/30], Loss: 1.588
Epoch [11/30], Loss: 1.583
Epoch [12/30], Loss: 1.571
Epoch [13/30], Loss: 1.555
Epoch [14/30], Loss: 1.539
Epoch [15/30], Loss: 1.537
Epoch [16/30], Loss: 1.522
Epoch [17/30], Loss: 1.514
Epoch [18/30], Loss: 1.501
Epoch [19/30], Loss: 1.501
Epoch [20/30], Loss: 1.495
Epoch [21/30], Loss: 1.484
Epoch [22/30], Loss: 1.474
Epoch [23/30], Loss: 1.465
Epoch [24/30], Loss: 1.458
Epoch [25/30], Loss: 1.456
Epoch [26/30], Loss: 1.449
Epoch [27/30], Loss: 1.445
Epoch [28/30], Loss: 1.44
Epoch [29/30], Loss: 1.438
Epoch [30/30], Loss: 1.431
Saving baseline model so no need to retrain anymore.


In [7]:
# Evaluate
from torchmetrics import Recall, Precision

cnn.eval()
metric_recall = Recall(
    task = "multiclass", num_classes=10, average=None
).to(dml)

metric_precision = Precision(
    task = "multiclass", num_classes=10, average=None
).to(dml)

with torch.no_grad():
    for images, labels in val_load:
        images = images.to(memory_format=torch.channels_last).to(dml, non_blocking=True)
        labels = labels.to(dml, non_blocking=True)

        outputs = cnn(images)
        _, preds = torch.max(outputs, 1)

        metric_recall(preds, labels)
        metric_precision(preds, labels)

recall = metric_recall.compute()
precision = metric_precision.compute()

recall_per_class = {
    k: recall[v].item()
    for k, v
    in train_data.data.class_to_idx.items()
}

precision_per_class = {
    k: precision[v].item()
    for k, v
    in train_data.data.class_to_idx.items()
}

print(recall_per_class)
print(precision_per_class)

{'airplane': 0.534000039100647, 'automobile': 0.6160000562667847, 'bird': 0.47700002789497375, 'cat': 0.37700000405311584, 'deer': 0.2750000059604645, 'dog': 0.38200002908706665, 'frog': 0.6990000605583191, 'horse': 0.5520000457763672, 'ship': 0.6520000100135803, 'truck': 0.6340000033378601}
{'airplane': 0.5585774183273315, 'automobile': 0.5767790079116821, 'bird': 0.3638443946838379, 'cat': 0.37252965569496155, 'deer': 0.5478087663650513, 'dog': 0.5268965363502502, 'frog': 0.5659918785095215, 'horse': 0.5696594715118408, 'ship': 0.5863309502601624, 'truck': 0.5711711645126343}


In [8]:
metric_recall_weighted = Recall(
    task = "multiclass", num_classes=10, average='weighted'
).to(dml)

metric_precision_weighted = Precision(
    task = "multiclass", num_classes=10, average='weighted'
).to(dml)

with torch.no_grad():
    for images, labels in val_load:
        images = images.to(memory_format=torch.channels_last).to(dml, non_blocking=True)
        labels = labels.to(dml, non_blocking=True)

        outputs = cnn(images)
        _, preds = torch.max(outputs, 1)

        metric_recall_weighted(preds, labels)
        metric_precision_weighted(preds, labels)

recall_weighted = metric_recall_weighted.compute()
precision_weighted = metric_precision_weighted.compute()
print(recall_weighted)
print(precision_weighted)

tensor(0.5183, device='privateuseone:0')
tensor(0.5234, device='privateuseone:0')


In [9]:
metric_recall_macro = Recall(
    task = "multiclass", num_classes=10, average='macro'
).to(dml)

metric_precision_macro = Precision(
    task = "multiclass", num_classes=10, average='macro'
).to(dml)

with torch.no_grad():
    for images, labels in val_load:
        images = images.to(memory_format=torch.channels_last).to(dml, non_blocking=True)
        labels = labels.to(dml, non_blocking=True)

        outputs = cnn(images)
        _, preds = torch.max(outputs, 1)

        metric_recall_macro(preds, labels)
        metric_precision_macro(preds, labels)

recall_macro = metric_recall_macro.compute()
precision_macro = metric_precision_macro.compute()
print(recall_macro)
print(precision_macro)

tensor(0.5235, device='privateuseone:0')
tensor(0.5298, device='privateuseone:0')


In [11]:
val_correct = 0
val_total = 0
with torch.no_grad():
    for images, labels in val_load:
        images = images.to(memory_format=torch.channels_last).to(dml, non_blocking=True)
        labels = labels.to(dml, non_blocking=True)

        outputs = cnn(images)
        _, preds = torch.max(outputs, 1)

        val_correct += (preds == labels).sum().item()
        val_total += labels.size(0)
val_accuracy = val_correct / val_total

train_correct = 0
train_total = 0
with torch.no_grad():
     for images, labels in train_load:
        images = images.to(memory_format=torch.channels_last).to(dml, non_blocking=True)
        labels = labels.to(dml, non_blocking=True)

        outputs = cnn(images)
        _, preds = torch.max(outputs, 1)

        train_correct += (preds == labels).sum().item()
        train_total += labels.size(0)
train_accuracy = train_correct / train_total

print(f'Train Accuracy: {train_accuracy:.4f}')
print(f'Validation Accuracy: {val_accuracy:.4f}')

Train Accuracy: 0.5270
Validation Accuracy: 0.5212
