## Imports

In [None]:
import os
import time

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision import transforms

import matplotlib.pyplot as plt
from PIL import Image

from tqdm import tqdm

# When a CUDA device is present, ask cuDNN to use deterministic algorithms
# so that repeated runs are easier to compare.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

## Model Settings

In [None]:
# Seed passed to torch.manual_seed before the model is built / reloaded.
RANDOM_SEED = 1
# Adam step size; also embedded in the names of the log and checkpoint files.
LEARNING_RATE = 0.0001
BATCH_SIZE = 32
NUM_EPOCHS = 20

# Images are resized and centre-cropped to RESOLUTION x RESOLUTION pixels.
RESOLUTION = 200

NUM_FEATURES = RESOLUTION*RESOLUTION
# Size of the network's output layer.
NUM_CLASSES = 66

# Fall back to the CPU when no CUDA device is available instead of failing
# on the first .to(DEVICE) call.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# False -> the network's first convolution expects 3-channel input.
GRAYSCALE = False

In [None]:
# Resize the shorter side to RESOLUTION, crop the centre square, and convert
# the image to a float tensor.
transform = transforms.Compose([
    transforms.Resize(RESOLUTION),
    transforms.CenterCrop(RESOLUTION),
    transforms.ToTensor()
])

path = 'CV_data/Extracted_dataset_input_combined'
dataset = datasets.ImageFolder(root=path, transform=transform)


# Get class to index mapping
class_to_idx = dataset.class_to_idx

# Define index to class mapping
idx_to_class = {v: k for k, v in class_to_idx.items()}

# 60 % train / 20 % validation / remainder test.
train_size = int(0.6 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Split the dataset into training, validation and test sets.
# The split is drawn from its own seeded generator so that every session
# produces the same partition. With an unseeded split, a checkpoint reloaded
# in a new session would be evaluated on a "test" set that can contain
# images it was trained on.
split_generator = torch.Generator().manual_seed(RANDOM_SEED)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator)


train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)

val_loader = DataLoader(dataset=val_dataset,
                        batch_size=BATCH_SIZE,
                        shuffle=False)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=True)

# Checking the dataset: print the shapes and labels of one training batch.
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    print('Labels:', labels)
    break


In [None]:
device = torch.device(DEVICE)
torch.manual_seed(0)

# Smoke test: for two passes, pull only the first training batch, report its
# size and check that it can be moved to the target device.
for epoch in range(2):
    for batch_idx, (x, y) in enumerate(train_loader):
        print(f'Epoch: {epoch+1} | Batch index: {batch_idx} | Batch size: {y.size(0)}')

        x, y = x.to(device), y.to(device)
        break

# Model

In [None]:
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv -> 3x3 conv -> 1x1 conv plus a skip path.

    The first two convolutions work on ``planes`` channels and the last one
    widens the result to ``planes * 4`` channels. Each convolution is
    followed by batch normalisation; ReLU is applied after the first two
    and once more after the skip connection has been added.

    Args:
        inplanes: number of channels of the incoming tensor.
        planes: channel width of the two inner convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to produce the
            skip tensor; when ``None`` the input itself is used.
    """

    # Ratio between the block's output channels and ``planes``; read by
    # ResNet when it sizes the following layers and the classifier.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * 4

        # Reduce to `planes` channels.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # Spatial convolution; the only place where `stride` is applied.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Expand to the block's output width.
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + skip(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Use the input as the skip tensor unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)




class ResNet(nn.Module):
    """ResNet backbone with a linear classification head.

    The network is a 7x7 stem convolution and max-pool, four stages of
    residual blocks (64, 128, 256 and 512 base channels; stages 2-4 halve
    the spatial size), global average pooling and a fully connected layer.

    Args:
        block: residual block class. It is called as
            ``block(inplanes, planes, stride, downsample)`` and must expose
            an integer ``expansion`` attribute (output channels / planes).
        layers: four integers, the number of blocks in each stage.
        num_classes: size of the output layer.
        grayscale: if true the stem expects 1 input channel, otherwise 3.
    """

    def __init__(self, block, layers, num_classes, grayscale):
        super().__init__()
        # Channel count fed to the next block; updated by _make_layer.
        self.inplanes = 64
        in_dim = 1 if grayscale else 3

        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pooling. A fixed 7x7 window only worked when the
        # last feature map was exactly 7x7; the adaptive version gives the
        # same result in that case and also accepts other input sizes.
        # Pooling has no parameters, so existing checkpoints still load.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / n), n = kh * kw * out_channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0.0, (2. / n) ** .5)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage made of ``blocks`` consecutive residual blocks.

        Only the first block receives ``stride``. It also receives a 1x1
        convolution + batch-norm projection for the skip path whenever the
        stride or the channel count changes across the block.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``(logits, probas)``; ``probas`` is the softmax of ``logits`` over dim 1."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        # Collapse the pooled (N, C, 1, 1) map to (N, C) for the classifier.
        x = torch.flatten(x, 1)
        logits = self.fc(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas



def resnet50(num_classes):
    """Construct a ResNet-50 model.

    Uses Bottleneck blocks in a [3, 4, 6, 3] stage layout. The number of
    input channels follows the module-level ``GRAYSCALE`` flag.

    Args:
        num_classes: size of the classification layer. This argument was
            previously ignored in favour of the global ``NUM_CLASSES``.

    Returns:
        The newly constructed ``ResNet`` instance.
    """
    model = ResNet(block=Bottleneck,
                   layers=[3, 4, 6, 3],
                   num_classes=num_classes,
                   grayscale=GRAYSCALE)
    return model


In [None]:
# Seed first so the weight initialisation is repeatable.
torch.manual_seed(RANDOM_SEED)

# Build the network on the target device, then hand its parameters to Adam.
model = resnet50(NUM_CLASSES).to(DEVICE)

optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
def compute_accuracy(model, data_loader, device, mode="train"):
    """Measure classification accuracy over ``data_loader`` and log it to CSV.

    At batch indices 0, 50, 100, ... the accuracy (in percent) of the
    batches seen since the previous log line is appended to
    ``CV_data/models/{model_name}-{LEARNING_RATE}-{mode}-accuracy.csv`` as
    ``epoch, batch_idx, accuracy``. Batches after the last multiple of 50
    are not logged, but they still count toward the returned value.

    The function reads the module-level names ``model_name``,
    ``LEARNING_RATE`` and ``epoch`` (the training loop's counter). It does
    not change the model's train/eval mode or gradient tracking; the caller
    is expected to set those up.

    Args:
        model: network returning ``(logits, probas)`` for a batch.
        data_loader: iterable of ``(features, targets)`` batches.
        device: device the batches are moved to. Previously this argument
            was ignored and the global ``DEVICE`` was used instead.
        mode: tag embedded in the log file name (e.g. "train", "val").

    Returns:
        Accuracy in percent over every example in ``data_loader`` as a
        float, or 0.0 if the loader yields no examples.
    """
    log_path = f'CV_data/models/{model_name}-{LEARNING_RATE}-{mode}-accuracy.csv'
    window_correct, window_seen = 0, 0
    total_correct, total_seen = 0, 0
    for batch_idx, (features, targets) in tqdm(enumerate(data_loader)):

        features = features.to(device)
        targets = targets.to(device)

        logits, probas = model(features)
        predicted_labels = torch.argmax(probas, dim=1)
        # .item() keeps the counters as plain Python ints.
        batch_correct = (predicted_labels == targets).sum().item()
        batch_seen = targets.size(0)

        window_correct += batch_correct
        window_seen += batch_seen
        total_correct += batch_correct
        total_seen += batch_seen

        if not batch_idx % 50:
            accuracy = window_correct / window_seen * 100
            with open(log_path, 'a') as fp:
                fp.write(f"{epoch+1}, {batch_idx}, {accuracy}\n")
            window_correct, window_seen = 0, 0

    if not total_seen:
        return 0.0
    return total_correct / total_seen * 100



In [None]:
def compute_loss(model, data_loader, class_weights, device, mode="val"):
    """Compute the cross-entropy loss over ``data_loader`` and log it to CSV.

    At batch indices 0, 50, 100, ... the mean loss of the batches seen
    since the previous log line is appended to
    ``CV_data/models/{model_name}-{LEARNING_RATE}-{mode}-loss.csv`` as
    ``epoch, batch_idx, mean_loss``. Batches after the last multiple of 50
    are not logged, but they still count toward the returned value.

    The function reads the module-level names ``model_name``,
    ``LEARNING_RATE`` and ``epoch`` (the training loop's counter). It does
    not change the model's train/eval mode or gradient tracking; the caller
    is expected to set those up.

    Args:
        model: network returning ``(logits, probas)`` for a batch.
        data_loader: iterable of ``(features, targets)`` batches.
        class_weights: accepted for interface compatibility but not applied;
            the loss is unweighted, matching the training loop.
            NOTE(review): if a weighted loss was intended, pass it as
            ``weight=`` to ``F.cross_entropy`` here and in training.
        device: device the batches are moved to. Previously this argument
            was ignored and the global ``DEVICE`` was used instead.
        mode: tag embedded in the log file name (e.g. "val").

    Returns:
        Mean of the per-batch losses as a float, or NaN if the loader
        yields no batches.
    """
    log_path = f'CV_data/models/{model_name}-{LEARNING_RATE}-{mode}-loss.csv'
    window_losses = []
    loss_sum, num_batches = 0.0, 0
    for batch_idx, (features, targets) in tqdm(enumerate(data_loader)):

        features = features.to(device)
        targets = targets.to(device)

        logits, _ = model(features)
        # Store a plain float rather than the tensor.
        batch_loss = F.cross_entropy(logits, targets).item()
        window_losses.append(batch_loss)
        loss_sum += batch_loss
        num_batches += 1

        if not batch_idx % 50:
            mean_loss = sum(window_losses) / len(window_losses)
            with open(log_path, 'a') as fp:
                fp.write(f"{epoch+1}, {batch_idx}, {mean_loss}\n")
            window_losses.clear()

    if not num_batches:
        return float('nan')
    return loss_sum / num_batches

In [None]:
# Prefix embedded in the names of the CSV logs, checkpoints and
# confusion-matrix PDF written by this notebook.
model_name = "00-model"

## Training

In [11]:
start_time = time.time()

# Logs and checkpoints are written here; create the directory up front so
# the first write does not fail on a fresh checkout.
os.makedirs('CV_data/models', exist_ok=True)
train_loss_log = f'CV_data/models/{model_name}-{LEARNING_RATE}-train-loss.csv'

# NOTE(review): `class_weights` is not defined in any cell visible here; it
# must be created by an earlier cell. It is only forwarded to compute_loss
# (the training loss below is unweighted), so it is moved to the device once
# instead of on every batch.
class_weights = class_weights.to(DEVICE)

for epoch in range(NUM_EPOCHS):

    model.train()
    losses = []  # per-batch losses since the last log line
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        loss = F.cross_entropy(logits, targets)

        optimizer.zero_grad()

        loss.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        # Keep a plain float: storing the tensor would keep its autograd
        # history referenced until the list is cleared.
        losses.append(loss.item())

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'
                  % (epoch+1, NUM_EPOCHS, batch_idx,
                     len(train_loader), losses[-1]))

            mean_loss = sum(losses) / len(losses)
            with open(train_loss_log, 'a') as fp:
                fp.write(f"{epoch+1}, {batch_idx}, {mean_loss}\n")
            losses.clear()

    # End-of-epoch evaluation: eval mode, and no gradients to save memory
    # during inference.
    model.eval()
    with torch.no_grad():
        compute_accuracy(model, train_loader, device=DEVICE, mode="train")
        compute_loss(model, val_loader, class_weights, device=DEVICE, mode="val")
        compute_accuracy(model, val_loader, device=DEVICE, mode="val")

    # One checkpoint per epoch.
    file_path = f'CV_data/models/{model_name}-{LEARNING_RATE}-epoch{epoch+1}.pth'
    torch.save(model.state_dict(), file_path)

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))



print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

Epoch: 001/020 | Batch 0000/1284 | Cost: 4.2649
Epoch: 001/020 | Batch 0050/1284 | Cost: 3.8272
Epoch: 001/020 | Batch 0100/1284 | Cost: 3.5944
Epoch: 001/020 | Batch 0150/1284 | Cost: 3.7192
Epoch: 001/020 | Batch 0200/1284 | Cost: 3.3314




Epoch: 001/020 | Batch 0250/1284 | Cost: 3.0997
Epoch: 001/020 | Batch 0300/1284 | Cost: 3.1404
Epoch: 001/020 | Batch 0350/1284 | Cost: 2.5295
Epoch: 001/020 | Batch 0400/1284 | Cost: 2.7810
Epoch: 001/020 | Batch 0450/1284 | Cost: 2.8899
Epoch: 001/020 | Batch 0500/1284 | Cost: 2.6300
Epoch: 001/020 | Batch 0550/1284 | Cost: 2.3803
Epoch: 001/020 | Batch 0600/1284 | Cost: 2.2133
Epoch: 001/020 | Batch 0650/1284 | Cost: 2.6430
Epoch: 001/020 | Batch 0700/1284 | Cost: 1.9110
Epoch: 001/020 | Batch 0750/1284 | Cost: 1.8039
Epoch: 001/020 | Batch 0800/1284 | Cost: 2.4884
Epoch: 001/020 | Batch 0850/1284 | Cost: 1.8090
Epoch: 001/020 | Batch 0900/1284 | Cost: 1.6273
Epoch: 001/020 | Batch 0950/1284 | Cost: 1.8678
Epoch: 001/020 | Batch 1000/1284 | Cost: 1.6614
Epoch: 001/020 | Batch 1050/1284 | Cost: 1.4066
Epoch: 001/020 | Batch 1100/1284 | Cost: 1.3027
Epoch: 001/020 | Batch 1150/1284 | Cost: 1.4242
Epoch: 001/020 | Batch 1200/1284 | Cost: 1.4810
Epoch: 001/020 | Batch 1250/1284 | Cost:

1284it [02:12,  9.72it/s]
428it [00:55,  7.74it/s]
428it [00:40, 10.50it/s]


Time elapsed: 8.59 min
Epoch: 002/020 | Batch 0000/1284 | Cost: 1.7066
Epoch: 002/020 | Batch 0050/1284 | Cost: 1.1887
Epoch: 002/020 | Batch 0100/1284 | Cost: 1.5444
Epoch: 002/020 | Batch 0150/1284 | Cost: 1.0590
Epoch: 002/020 | Batch 0200/1284 | Cost: 1.4377
Epoch: 002/020 | Batch 0250/1284 | Cost: 1.1466
Epoch: 002/020 | Batch 0300/1284 | Cost: 0.7916
Epoch: 002/020 | Batch 0350/1284 | Cost: 1.1159
Epoch: 002/020 | Batch 0400/1284 | Cost: 1.1256
Epoch: 002/020 | Batch 0450/1284 | Cost: 0.9754
Epoch: 002/020 | Batch 0500/1284 | Cost: 1.2763
Epoch: 002/020 | Batch 0550/1284 | Cost: 0.6040
Epoch: 002/020 | Batch 0600/1284 | Cost: 0.7611
Epoch: 002/020 | Batch 0650/1284 | Cost: 1.0281
Epoch: 002/020 | Batch 0700/1284 | Cost: 0.5661
Epoch: 002/020 | Batch 0750/1284 | Cost: 1.0329
Epoch: 002/020 | Batch 0800/1284 | Cost: 0.8230
Epoch: 002/020 | Batch 0850/1284 | Cost: 0.5291
Epoch: 002/020 | Batch 0900/1284 | Cost: 0.4634
Epoch: 002/020 | Batch 0950/1284 | Cost: 0.5387
Epoch: 002/020 | 

1284it [02:08,  9.99it/s]
428it [00:40, 10.54it/s]
428it [00:40, 10.57it/s]


Time elapsed: 16.20 min
Epoch: 003/020 | Batch 0000/1284 | Cost: 0.7621
Epoch: 003/020 | Batch 0050/1284 | Cost: 1.0084
Epoch: 003/020 | Batch 0100/1284 | Cost: 0.5572
Epoch: 003/020 | Batch 0150/1284 | Cost: 0.4871
Epoch: 003/020 | Batch 0200/1284 | Cost: 0.3948
Epoch: 003/020 | Batch 0250/1284 | Cost: 0.5541
Epoch: 003/020 | Batch 0300/1284 | Cost: 0.5037
Epoch: 003/020 | Batch 0350/1284 | Cost: 0.3841
Epoch: 003/020 | Batch 0400/1284 | Cost: 0.2560
Epoch: 003/020 | Batch 0450/1284 | Cost: 0.4980
Epoch: 003/020 | Batch 0500/1284 | Cost: 0.4673
Epoch: 003/020 | Batch 0550/1284 | Cost: 0.3828
Epoch: 003/020 | Batch 0600/1284 | Cost: 0.5367
Epoch: 003/020 | Batch 0650/1284 | Cost: 0.1563
Epoch: 003/020 | Batch 0700/1284 | Cost: 0.4658
Epoch: 003/020 | Batch 0750/1284 | Cost: 0.8065
Epoch: 003/020 | Batch 0800/1284 | Cost: 0.3542
Epoch: 003/020 | Batch 0850/1284 | Cost: 0.5669
Epoch: 003/020 | Batch 0900/1284 | Cost: 0.2251
Epoch: 003/020 | Batch 0950/1284 | Cost: 0.4289
Epoch: 003/020 |

1284it [02:08,  9.98it/s]
428it [00:40, 10.62it/s]
428it [00:40, 10.62it/s]


Time elapsed: 23.77 min
Epoch: 004/020 | Batch 0000/1284 | Cost: 0.3249
Epoch: 004/020 | Batch 0050/1284 | Cost: 0.3785
Epoch: 004/020 | Batch 0100/1284 | Cost: 0.3090
Epoch: 004/020 | Batch 0150/1284 | Cost: 0.3581
Epoch: 004/020 | Batch 0200/1284 | Cost: 0.6255
Epoch: 004/020 | Batch 0250/1284 | Cost: 0.3037
Epoch: 004/020 | Batch 0300/1284 | Cost: 0.2664
Epoch: 004/020 | Batch 0350/1284 | Cost: 0.1478
Epoch: 004/020 | Batch 0400/1284 | Cost: 0.2993
Epoch: 004/020 | Batch 0450/1284 | Cost: 0.1546
Epoch: 004/020 | Batch 0500/1284 | Cost: 0.2511
Epoch: 004/020 | Batch 0550/1284 | Cost: 0.1720
Epoch: 004/020 | Batch 0600/1284 | Cost: 0.3923
Epoch: 004/020 | Batch 0650/1284 | Cost: 0.2034
Epoch: 004/020 | Batch 0700/1284 | Cost: 0.3683
Epoch: 004/020 | Batch 0750/1284 | Cost: 0.2094
Epoch: 004/020 | Batch 0800/1284 | Cost: 0.2378
Epoch: 004/020 | Batch 0850/1284 | Cost: 0.8335
Epoch: 004/020 | Batch 0900/1284 | Cost: 0.1392
Epoch: 004/020 | Batch 0950/1284 | Cost: 0.2377
Epoch: 004/020 |

1284it [02:07, 10.10it/s]
428it [00:40, 10.57it/s]
428it [00:40, 10.62it/s]


Time elapsed: 31.33 min
Epoch: 005/020 | Batch 0000/1284 | Cost: 0.1862
Epoch: 005/020 | Batch 0050/1284 | Cost: 0.2240
Epoch: 005/020 | Batch 0100/1284 | Cost: 0.1818
Epoch: 005/020 | Batch 0150/1284 | Cost: 0.5085
Epoch: 005/020 | Batch 0200/1284 | Cost: 0.3171
Epoch: 005/020 | Batch 0250/1284 | Cost: 0.1354
Epoch: 005/020 | Batch 0300/1284 | Cost: 0.2144
Epoch: 005/020 | Batch 0350/1284 | Cost: 0.2266
Epoch: 005/020 | Batch 0400/1284 | Cost: 0.6857
Epoch: 005/020 | Batch 0450/1284 | Cost: 0.2047
Epoch: 005/020 | Batch 0500/1284 | Cost: 0.0990
Epoch: 005/020 | Batch 0550/1284 | Cost: 0.2109
Epoch: 005/020 | Batch 0600/1284 | Cost: 0.2220
Epoch: 005/020 | Batch 0650/1284 | Cost: 0.0619
Epoch: 005/020 | Batch 0700/1284 | Cost: 0.1242
Epoch: 005/020 | Batch 0750/1284 | Cost: 0.2689
Epoch: 005/020 | Batch 0800/1284 | Cost: 0.0469
Epoch: 005/020 | Batch 0850/1284 | Cost: 0.0505
Epoch: 005/020 | Batch 0900/1284 | Cost: 0.1020
Epoch: 005/020 | Batch 0950/1284 | Cost: 0.1343
Epoch: 005/020 |

1284it [02:10,  9.81it/s]
428it [00:41, 10.37it/s]
428it [00:41, 10.36it/s]


Time elapsed: 39.06 min
Epoch: 006/020 | Batch 0000/1284 | Cost: 0.0849
Epoch: 006/020 | Batch 0050/1284 | Cost: 0.2723
Epoch: 006/020 | Batch 0100/1284 | Cost: 0.2554
Epoch: 006/020 | Batch 0150/1284 | Cost: 0.0976
Epoch: 006/020 | Batch 0200/1284 | Cost: 0.1160
Epoch: 006/020 | Batch 0250/1284 | Cost: 0.2171
Epoch: 006/020 | Batch 0300/1284 | Cost: 0.0869
Epoch: 006/020 | Batch 0350/1284 | Cost: 0.1036
Epoch: 006/020 | Batch 0400/1284 | Cost: 0.0573
Epoch: 006/020 | Batch 0450/1284 | Cost: 0.1514
Epoch: 006/020 | Batch 0500/1284 | Cost: 0.0293
Epoch: 006/020 | Batch 0550/1284 | Cost: 0.2593
Epoch: 006/020 | Batch 0600/1284 | Cost: 0.1931
Epoch: 006/020 | Batch 0650/1284 | Cost: 0.2018
Epoch: 006/020 | Batch 0700/1284 | Cost: 0.3463
Epoch: 006/020 | Batch 0750/1284 | Cost: 0.1285
Epoch: 006/020 | Batch 0800/1284 | Cost: 0.2038
Epoch: 006/020 | Batch 0850/1284 | Cost: 0.1111
Epoch: 006/020 | Batch 0900/1284 | Cost: 0.1785
Epoch: 006/020 | Batch 0950/1284 | Cost: 0.0879
Epoch: 006/020 |

1284it [01:57, 10.92it/s]
428it [00:36, 11.77it/s]
428it [00:36, 11.78it/s]


Time elapsed: 46.24 min
Epoch: 007/020 | Batch 0000/1284 | Cost: 0.0702
Epoch: 007/020 | Batch 0050/1284 | Cost: 0.3181
Epoch: 007/020 | Batch 0100/1284 | Cost: 0.1095
Epoch: 007/020 | Batch 0150/1284 | Cost: 0.3129
Epoch: 007/020 | Batch 0200/1284 | Cost: 0.0330
Epoch: 007/020 | Batch 0250/1284 | Cost: 0.0921
Epoch: 007/020 | Batch 0300/1284 | Cost: 0.0274
Epoch: 007/020 | Batch 0350/1284 | Cost: 0.0504
Epoch: 007/020 | Batch 0400/1284 | Cost: 0.0618
Epoch: 007/020 | Batch 0450/1284 | Cost: 0.2028
Epoch: 007/020 | Batch 0500/1284 | Cost: 0.1517
Epoch: 007/020 | Batch 0550/1284 | Cost: 0.1102
Epoch: 007/020 | Batch 0600/1284 | Cost: 0.4056
Epoch: 007/020 | Batch 0650/1284 | Cost: 0.1662
Epoch: 007/020 | Batch 0700/1284 | Cost: 0.0809
Epoch: 007/020 | Batch 0750/1284 | Cost: 0.1949
Epoch: 007/020 | Batch 0800/1284 | Cost: 0.0127
Epoch: 007/020 | Batch 0850/1284 | Cost: 0.0199
Epoch: 007/020 | Batch 0900/1284 | Cost: 0.0127
Epoch: 007/020 | Batch 0950/1284 | Cost: 0.1456
Epoch: 007/020 |

1284it [01:55, 11.08it/s]
428it [00:39, 10.95it/s]
428it [00:39, 10.93it/s]


Time elapsed: 53.37 min
Epoch: 008/020 | Batch 0000/1284 | Cost: 0.1306
Epoch: 008/020 | Batch 0050/1284 | Cost: 0.2516
Epoch: 008/020 | Batch 0100/1284 | Cost: 0.0365
Epoch: 008/020 | Batch 0150/1284 | Cost: 0.0162
Epoch: 008/020 | Batch 0200/1284 | Cost: 0.2405
Epoch: 008/020 | Batch 0250/1284 | Cost: 0.1442
Epoch: 008/020 | Batch 0300/1284 | Cost: 0.0915
Epoch: 008/020 | Batch 0350/1284 | Cost: 0.1776
Epoch: 008/020 | Batch 0400/1284 | Cost: 0.0171
Epoch: 008/020 | Batch 0450/1284 | Cost: 0.1440
Epoch: 008/020 | Batch 0500/1284 | Cost: 0.2032
Epoch: 008/020 | Batch 0550/1284 | Cost: 0.0694
Epoch: 008/020 | Batch 0600/1284 | Cost: 0.1180
Epoch: 008/020 | Batch 0650/1284 | Cost: 0.0135
Epoch: 008/020 | Batch 0700/1284 | Cost: 0.1228
Epoch: 008/020 | Batch 0750/1284 | Cost: 0.0130
Epoch: 008/020 | Batch 0800/1284 | Cost: 0.0053
Epoch: 008/020 | Batch 0850/1284 | Cost: 0.3276
Epoch: 008/020 | Batch 0900/1284 | Cost: 0.1128
Epoch: 008/020 | Batch 0950/1284 | Cost: 0.2027
Epoch: 008/020 |

1284it [01:55, 11.12it/s]
428it [00:36, 11.74it/s]
428it [00:36, 11.78it/s]


Time elapsed: 60.39 min
Epoch: 009/020 | Batch 0000/1284 | Cost: 0.0526
Epoch: 009/020 | Batch 0050/1284 | Cost: 0.1675
Epoch: 009/020 | Batch 0100/1284 | Cost: 0.1328
Epoch: 009/020 | Batch 0150/1284 | Cost: 0.1561
Epoch: 009/020 | Batch 0200/1284 | Cost: 0.1184
Epoch: 009/020 | Batch 0250/1284 | Cost: 0.1359
Epoch: 009/020 | Batch 0300/1284 | Cost: 0.0143
Epoch: 009/020 | Batch 0350/1284 | Cost: 0.0222
Epoch: 009/020 | Batch 0400/1284 | Cost: 0.0491
Epoch: 009/020 | Batch 0450/1284 | Cost: 0.0232
Epoch: 009/020 | Batch 0500/1284 | Cost: 0.0706
Epoch: 009/020 | Batch 0550/1284 | Cost: 0.1514
Epoch: 009/020 | Batch 0600/1284 | Cost: 0.1781
Epoch: 009/020 | Batch 0650/1284 | Cost: 0.0169
Epoch: 009/020 | Batch 0700/1284 | Cost: 0.0658
Epoch: 009/020 | Batch 0750/1284 | Cost: 0.0789
Epoch: 009/020 | Batch 0800/1284 | Cost: 0.0355
Epoch: 009/020 | Batch 0850/1284 | Cost: 0.1059
Epoch: 009/020 | Batch 0900/1284 | Cost: 0.2115
Epoch: 009/020 | Batch 0950/1284 | Cost: 0.1128
Epoch: 009/020 |

1284it [01:56, 11.05it/s]
428it [00:36, 11.76it/s]
428it [00:36, 11.77it/s]


Time elapsed: 67.44 min
Epoch: 010/020 | Batch 0000/1284 | Cost: 0.0655
Epoch: 010/020 | Batch 0050/1284 | Cost: 0.0308
Epoch: 010/020 | Batch 0100/1284 | Cost: 0.0674
Epoch: 010/020 | Batch 0150/1284 | Cost: 0.0869
Epoch: 010/020 | Batch 0200/1284 | Cost: 0.1852
Epoch: 010/020 | Batch 0250/1284 | Cost: 0.0578
Epoch: 010/020 | Batch 0300/1284 | Cost: 0.0904
Epoch: 010/020 | Batch 0350/1284 | Cost: 0.0178
Epoch: 010/020 | Batch 0400/1284 | Cost: 0.0300
Epoch: 010/020 | Batch 0450/1284 | Cost: 0.1328
Epoch: 010/020 | Batch 0500/1284 | Cost: 0.0339
Epoch: 010/020 | Batch 0550/1284 | Cost: 0.0614
Epoch: 010/020 | Batch 0600/1284 | Cost: 0.0375
Epoch: 010/020 | Batch 0650/1284 | Cost: 0.0056
Epoch: 010/020 | Batch 0700/1284 | Cost: 0.1356
Epoch: 010/020 | Batch 0750/1284 | Cost: 0.0768
Epoch: 010/020 | Batch 0800/1284 | Cost: 0.0971
Epoch: 010/020 | Batch 0850/1284 | Cost: 0.0850
Epoch: 010/020 | Batch 0900/1284 | Cost: 0.1391
Epoch: 010/020 | Batch 0950/1284 | Cost: 0.0113
Epoch: 010/020 |

1284it [01:55, 11.09it/s]
428it [00:36, 11.67it/s]
428it [00:36, 11.77it/s]


Time elapsed: 74.48 min
Epoch: 011/020 | Batch 0000/1284 | Cost: 0.0767
Epoch: 011/020 | Batch 0050/1284 | Cost: 0.0969
Epoch: 011/020 | Batch 0100/1284 | Cost: 0.0838
Epoch: 011/020 | Batch 0150/1284 | Cost: 0.0180
Epoch: 011/020 | Batch 0200/1284 | Cost: 0.0363
Epoch: 011/020 | Batch 0250/1284 | Cost: 0.0427
Epoch: 011/020 | Batch 0300/1284 | Cost: 0.1163
Epoch: 011/020 | Batch 0350/1284 | Cost: 0.0315
Epoch: 011/020 | Batch 0400/1284 | Cost: 0.0038
Epoch: 011/020 | Batch 0450/1284 | Cost: 0.0067
Epoch: 011/020 | Batch 0500/1284 | Cost: 0.0181
Epoch: 011/020 | Batch 0550/1284 | Cost: 0.0042
Epoch: 011/020 | Batch 0600/1284 | Cost: 0.2439
Epoch: 011/020 | Batch 0650/1284 | Cost: 0.1200
Epoch: 011/020 | Batch 0700/1284 | Cost: 0.0029
Epoch: 011/020 | Batch 0750/1284 | Cost: 0.0364
Epoch: 011/020 | Batch 0800/1284 | Cost: 0.4341
Epoch: 011/020 | Batch 0850/1284 | Cost: 0.1179
Epoch: 011/020 | Batch 0900/1284 | Cost: 0.0270
Epoch: 011/020 | Batch 0950/1284 | Cost: 0.0027
Epoch: 011/020 |

1284it [01:56, 11.06it/s]
428it [00:36, 11.85it/s]
428it [00:36, 11.82it/s]


Time elapsed: 81.52 min
Epoch: 012/020 | Batch 0000/1284 | Cost: 0.0549
Epoch: 012/020 | Batch 0050/1284 | Cost: 0.0740
Epoch: 012/020 | Batch 0100/1284 | Cost: 0.1096
Epoch: 012/020 | Batch 0150/1284 | Cost: 0.2158
Epoch: 012/020 | Batch 0200/1284 | Cost: 0.0264
Epoch: 012/020 | Batch 0250/1284 | Cost: 0.0895
Epoch: 012/020 | Batch 0300/1284 | Cost: 0.0248
Epoch: 012/020 | Batch 0350/1284 | Cost: 0.0113
Epoch: 012/020 | Batch 0400/1284 | Cost: 0.0056
Epoch: 012/020 | Batch 0450/1284 | Cost: 0.1463
Epoch: 012/020 | Batch 0500/1284 | Cost: 0.0107
Epoch: 012/020 | Batch 0550/1284 | Cost: 0.0321
Epoch: 012/020 | Batch 0600/1284 | Cost: 0.1168
Epoch: 012/020 | Batch 0650/1284 | Cost: 0.0371
Epoch: 012/020 | Batch 0700/1284 | Cost: 0.0031
Epoch: 012/020 | Batch 0750/1284 | Cost: 0.1276
Epoch: 012/020 | Batch 0800/1284 | Cost: 0.0808
Epoch: 012/020 | Batch 0850/1284 | Cost: 0.1123
Epoch: 012/020 | Batch 0900/1284 | Cost: 0.0400
Epoch: 012/020 | Batch 0950/1284 | Cost: 0.0627
Epoch: 012/020 |

1284it [01:55, 11.10it/s]
428it [00:36, 11.83it/s]
428it [00:36, 11.77it/s]


Time elapsed: 88.54 min
Epoch: 013/020 | Batch 0000/1284 | Cost: 0.2305
Epoch: 013/020 | Batch 0050/1284 | Cost: 0.0128
Epoch: 013/020 | Batch 0100/1284 | Cost: 0.1252
Epoch: 013/020 | Batch 0150/1284 | Cost: 0.0709
Epoch: 013/020 | Batch 0200/1284 | Cost: 0.0239
Epoch: 013/020 | Batch 0250/1284 | Cost: 0.0047
Epoch: 013/020 | Batch 0300/1284 | Cost: 0.0220
Epoch: 013/020 | Batch 0350/1284 | Cost: 0.1010
Epoch: 013/020 | Batch 0400/1284 | Cost: 0.1359
Epoch: 013/020 | Batch 0450/1284 | Cost: 0.1502
Epoch: 013/020 | Batch 0500/1284 | Cost: 0.0531
Epoch: 013/020 | Batch 0550/1284 | Cost: 0.1427
Epoch: 013/020 | Batch 0600/1284 | Cost: 0.0310
Epoch: 013/020 | Batch 0650/1284 | Cost: 0.0402
Epoch: 013/020 | Batch 0700/1284 | Cost: 0.0690
Epoch: 013/020 | Batch 0750/1284 | Cost: 0.0040
Epoch: 013/020 | Batch 0800/1284 | Cost: 0.0167
Epoch: 013/020 | Batch 0850/1284 | Cost: 0.0644
Epoch: 013/020 | Batch 0900/1284 | Cost: 0.0072
Epoch: 013/020 | Batch 0950/1284 | Cost: 0.0124
Epoch: 013/020 |

1284it [01:56, 11.03it/s]
428it [00:41, 10.35it/s]
428it [00:36, 11.73it/s]


Time elapsed: 95.72 min
Epoch: 014/020 | Batch 0000/1284 | Cost: 0.0048
Epoch: 014/020 | Batch 0050/1284 | Cost: 0.0053
Epoch: 014/020 | Batch 0100/1284 | Cost: 0.0475
Epoch: 014/020 | Batch 0150/1284 | Cost: 0.0250
Epoch: 014/020 | Batch 0200/1284 | Cost: 0.0225
Epoch: 014/020 | Batch 0250/1284 | Cost: 0.1615
Epoch: 014/020 | Batch 0300/1284 | Cost: 0.0526
Epoch: 014/020 | Batch 0350/1284 | Cost: 0.0040
Epoch: 014/020 | Batch 0400/1284 | Cost: 0.2202
Epoch: 014/020 | Batch 0450/1284 | Cost: 0.0095
Epoch: 014/020 | Batch 0500/1284 | Cost: 0.0168
Epoch: 014/020 | Batch 0550/1284 | Cost: 0.0087
Epoch: 014/020 | Batch 0600/1284 | Cost: 0.1243
Epoch: 014/020 | Batch 0650/1284 | Cost: 0.1112
Epoch: 014/020 | Batch 0700/1284 | Cost: 0.0838
Epoch: 014/020 | Batch 0750/1284 | Cost: 0.1928
Epoch: 014/020 | Batch 0800/1284 | Cost: 0.0667
Epoch: 014/020 | Batch 0850/1284 | Cost: 0.0717
Epoch: 014/020 | Batch 0900/1284 | Cost: 0.0549
Epoch: 014/020 | Batch 0950/1284 | Cost: 0.0793
Epoch: 014/020 |

1284it [02:13,  9.60it/s]
428it [00:46,  9.22it/s]
428it [00:45,  9.37it/s]


Time elapsed: 103.44 min
Epoch: 015/020 | Batch 0000/1284 | Cost: 0.0668
Epoch: 015/020 | Batch 0050/1284 | Cost: 0.0470
Epoch: 015/020 | Batch 0100/1284 | Cost: 0.1366
Epoch: 015/020 | Batch 0150/1284 | Cost: 0.0569
Epoch: 015/020 | Batch 0200/1284 | Cost: 0.0466
Epoch: 015/020 | Batch 0250/1284 | Cost: 0.1084
Epoch: 015/020 | Batch 0300/1284 | Cost: 0.0515
Epoch: 015/020 | Batch 0350/1284 | Cost: 0.0021
Epoch: 015/020 | Batch 0400/1284 | Cost: 0.1268
Epoch: 015/020 | Batch 0450/1284 | Cost: 0.0843
Epoch: 015/020 | Batch 0500/1284 | Cost: 0.0657
Epoch: 015/020 | Batch 0550/1284 | Cost: 0.0110
Epoch: 015/020 | Batch 0600/1284 | Cost: 0.0030
Epoch: 015/020 | Batch 0650/1284 | Cost: 0.0081
Epoch: 015/020 | Batch 0700/1284 | Cost: 0.0163
Epoch: 015/020 | Batch 0750/1284 | Cost: 0.0479
Epoch: 015/020 | Batch 0800/1284 | Cost: 0.0372
Epoch: 015/020 | Batch 0850/1284 | Cost: 0.1377
Epoch: 015/020 | Batch 0900/1284 | Cost: 0.1623
Epoch: 015/020 | Batch 0950/1284 | Cost: 0.0755
Epoch: 015/020 

1284it [03:04,  6.94it/s]
428it [00:38, 11.06it/s]
428it [00:37, 11.28it/s]


Time elapsed: 111.92 min
Epoch: 016/020 | Batch 0000/1284 | Cost: 0.0941
Epoch: 016/020 | Batch 0050/1284 | Cost: 0.0690
Epoch: 016/020 | Batch 0100/1284 | Cost: 0.0025
Epoch: 016/020 | Batch 0150/1284 | Cost: 0.2844
Epoch: 016/020 | Batch 0200/1284 | Cost: 0.0181
Epoch: 016/020 | Batch 0250/1284 | Cost: 0.0026
Epoch: 016/020 | Batch 0300/1284 | Cost: 0.0058
Epoch: 016/020 | Batch 0350/1284 | Cost: 0.0887
Epoch: 016/020 | Batch 0400/1284 | Cost: 0.0242
Epoch: 016/020 | Batch 0450/1284 | Cost: 0.0023
Epoch: 016/020 | Batch 0500/1284 | Cost: 0.0050
Epoch: 016/020 | Batch 0550/1284 | Cost: 0.1647
Epoch: 016/020 | Batch 0600/1284 | Cost: 0.0117
Epoch: 016/020 | Batch 0650/1284 | Cost: 0.1609
Epoch: 016/020 | Batch 0700/1284 | Cost: 0.0100
Epoch: 016/020 | Batch 0750/1284 | Cost: 0.0523
Epoch: 016/020 | Batch 0800/1284 | Cost: 0.0073
Epoch: 016/020 | Batch 0850/1284 | Cost: 0.0112
Epoch: 016/020 | Batch 0900/1284 | Cost: 0.3612
Epoch: 016/020 | Batch 0950/1284 | Cost: 0.0180
Epoch: 016/020 

1284it [01:56, 11.06it/s]
428it [00:36, 11.75it/s]
428it [00:36, 11.72it/s]


Time elapsed: 119.12 min
Epoch: 017/020 | Batch 0000/1284 | Cost: 0.0031
Epoch: 017/020 | Batch 0050/1284 | Cost: 0.0397
Epoch: 017/020 | Batch 0100/1284 | Cost: 0.0052
Epoch: 017/020 | Batch 0150/1284 | Cost: 0.0182
Epoch: 017/020 | Batch 0200/1284 | Cost: 0.1997
Epoch: 017/020 | Batch 0250/1284 | Cost: 0.0061
Epoch: 017/020 | Batch 0300/1284 | Cost: 0.0022
Epoch: 017/020 | Batch 0350/1284 | Cost: 0.0455
Epoch: 017/020 | Batch 0400/1284 | Cost: 0.0215
Epoch: 017/020 | Batch 0450/1284 | Cost: 0.0498
Epoch: 017/020 | Batch 0500/1284 | Cost: 0.0389
Epoch: 017/020 | Batch 0550/1284 | Cost: 0.0031
Epoch: 017/020 | Batch 0600/1284 | Cost: 0.2566
Epoch: 017/020 | Batch 0650/1284 | Cost: 0.0073
Epoch: 017/020 | Batch 0700/1284 | Cost: 0.0229
Epoch: 017/020 | Batch 0750/1284 | Cost: 0.0612
Epoch: 017/020 | Batch 0800/1284 | Cost: 0.0310
Epoch: 017/020 | Batch 0850/1284 | Cost: 0.0859
Epoch: 017/020 | Batch 0900/1284 | Cost: 0.1291
Epoch: 017/020 | Batch 0950/1284 | Cost: 0.0401
Epoch: 017/020 

1284it [01:57, 10.92it/s]
428it [00:37, 11.54it/s]
428it [00:36, 11.61it/s]


Time elapsed: 126.22 min
Epoch: 018/020 | Batch 0000/1284 | Cost: 0.0189
Epoch: 018/020 | Batch 0050/1284 | Cost: 0.0020
Epoch: 018/020 | Batch 0100/1284 | Cost: 0.0289
Epoch: 018/020 | Batch 0150/1284 | Cost: 0.0047
Epoch: 018/020 | Batch 0200/1284 | Cost: 0.0345
Epoch: 018/020 | Batch 0250/1284 | Cost: 0.0991
Epoch: 018/020 | Batch 0300/1284 | Cost: 0.0402
Epoch: 018/020 | Batch 0350/1284 | Cost: 0.0028
Epoch: 018/020 | Batch 0400/1284 | Cost: 0.0017
Epoch: 018/020 | Batch 0450/1284 | Cost: 0.0127
Epoch: 018/020 | Batch 0500/1284 | Cost: 0.0112
Epoch: 018/020 | Batch 0550/1284 | Cost: 0.2700
Epoch: 018/020 | Batch 0600/1284 | Cost: 0.0457
Epoch: 018/020 | Batch 0650/1284 | Cost: 0.0052
Epoch: 018/020 | Batch 0700/1284 | Cost: 0.0111
Epoch: 018/020 | Batch 0750/1284 | Cost: 0.0092
Epoch: 018/020 | Batch 0800/1284 | Cost: 0.0016
Epoch: 018/020 | Batch 0850/1284 | Cost: 0.0393
Epoch: 018/020 | Batch 0900/1284 | Cost: 0.0019
Epoch: 018/020 | Batch 0950/1284 | Cost: 0.0426
Epoch: 018/020 

1284it [01:56, 10.99it/s]
428it [00:36, 11.70it/s]
428it [00:36, 11.62it/s]


Time elapsed: 133.31 min
Epoch: 019/020 | Batch 0000/1284 | Cost: 0.0181
Epoch: 019/020 | Batch 0050/1284 | Cost: 0.0015
Epoch: 019/020 | Batch 0100/1284 | Cost: 0.0021
Epoch: 019/020 | Batch 0150/1284 | Cost: 0.0456
Epoch: 019/020 | Batch 0200/1284 | Cost: 0.0706
Epoch: 019/020 | Batch 0250/1284 | Cost: 0.0017
Epoch: 019/020 | Batch 0300/1284 | Cost: 0.0756
Epoch: 019/020 | Batch 0350/1284 | Cost: 0.1165
Epoch: 019/020 | Batch 0400/1284 | Cost: 0.0068
Epoch: 019/020 | Batch 0450/1284 | Cost: 0.0769
Epoch: 019/020 | Batch 0500/1284 | Cost: 0.0184
Epoch: 019/020 | Batch 0550/1284 | Cost: 0.1105
Epoch: 019/020 | Batch 0600/1284 | Cost: 0.0662
Epoch: 019/020 | Batch 0650/1284 | Cost: 0.0307
Epoch: 019/020 | Batch 0700/1284 | Cost: 0.0350
Epoch: 019/020 | Batch 0750/1284 | Cost: 0.0009
Epoch: 019/020 | Batch 0800/1284 | Cost: 0.0054
Epoch: 019/020 | Batch 0850/1284 | Cost: 0.0028
Epoch: 019/020 | Batch 0900/1284 | Cost: 0.0037
Epoch: 019/020 | Batch 0950/1284 | Cost: 0.0382
Epoch: 019/020 

1284it [01:56, 11.04it/s]
428it [00:36, 11.65it/s]
428it [00:36, 11.72it/s]


Time elapsed: 140.37 min
Epoch: 020/020 | Batch 0000/1284 | Cost: 0.0065
Epoch: 020/020 | Batch 0050/1284 | Cost: 0.0028
Epoch: 020/020 | Batch 0100/1284 | Cost: 0.0031
Epoch: 020/020 | Batch 0150/1284 | Cost: 0.0479
Epoch: 020/020 | Batch 0200/1284 | Cost: 0.0417
Epoch: 020/020 | Batch 0250/1284 | Cost: 0.0278
Epoch: 020/020 | Batch 0300/1284 | Cost: 0.0433
Epoch: 020/020 | Batch 0350/1284 | Cost: 0.0093
Epoch: 020/020 | Batch 0400/1284 | Cost: 0.1345
Epoch: 020/020 | Batch 0450/1284 | Cost: 0.1095
Epoch: 020/020 | Batch 0500/1284 | Cost: 0.1126
Epoch: 020/020 | Batch 0550/1284 | Cost: 0.0044
Epoch: 020/020 | Batch 0600/1284 | Cost: 0.2095
Epoch: 020/020 | Batch 0650/1284 | Cost: 0.0294
Epoch: 020/020 | Batch 0700/1284 | Cost: 0.0087
Epoch: 020/020 | Batch 0750/1284 | Cost: 0.1749
Epoch: 020/020 | Batch 0800/1284 | Cost: 0.3778
Epoch: 020/020 | Batch 0850/1284 | Cost: 0.0345
Epoch: 020/020 | Batch 0900/1284 | Cost: 0.0151
Epoch: 020/020 | Batch 0950/1284 | Cost: 0.0137
Epoch: 020/020 

1284it [01:55, 11.07it/s]
428it [00:36, 11.77it/s]
428it [00:36, 11.79it/s]


Time elapsed: 147.40 min
Total Training Time: 147.40 min


## Evaluation

In [13]:
# Evaluate on the GPU when one is available, otherwise on the CPU, so the
# checkpoint can also be inspected on a machine without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(DEVICE)

torch.manual_seed(RANDOM_SEED)

# Load the weights saved after epoch 20 onto the CPU first; the model is
# moved to the target device after the weights are in place.
model_weights = torch.load(f"CV_data/models/{model_name}-{LEARNING_RATE}-epoch20.pth", map_location=torch.device('cpu'))

model = resnet50(NUM_CLASSES)

# Loading the weights to the model
model.load_state_dict(model_weights)

model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

# Confusion matrix

!pip install seaborn

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

# Predicted and ground-truth class indices for the whole test loader,
# accumulated batch by batch and consumed by the confusion-matrix cell.
y_pred = []
y_true = []

model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(DEVICE)
        # The model returns a pair; only its second element is used here
        # (the example cells of this notebook call it `probas`).
        _, scores = model(inputs)

        # Take the arg-max on the scores directly: exp() is monotonic, so it
        # cannot change the winner and only risks merging near-equal values.
        y_pred.extend(scores.argmax(dim=1).cpu().numpy())  # Save Prediction
        y_true.extend(labels.cpu().numpy())  # Save Truth

In [None]:
# Build confusion matrix
# Class names ordered by their numeric index, so row/column i of the matrix is
# labelled with the class whose index is i.
classes = sorted(dataset.class_to_idx, key=dataset.class_to_idx.get)
label_ids = [dataset.class_to_idx[name] for name in classes]

# Passing `labels` pins the matrix to one row/column per class even when a
# class is absent from both y_true and y_pred; without it the matrix shrinks
# and no longer matches the class-name index below.
cf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)

# Normalise each row by the number of true samples of that class. A class
# without true samples has an all-zero row, so dividing by 1 keeps it at 0
# instead of producing NaN from 0/0.
row_totals = cf_matrix.sum(axis=1, keepdims=True)
df_cm = pd.DataFrame(cf_matrix / np.maximum(row_totals, 1),
                     index=classes, columns=classes)

plt.figure(figsize=(30, 25))
sn.heatmap(df_cm, annot=True)
plt.savefig(f'CV_data/conf_matrix/{model_name}-output.pdf')

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

# NOTE(review): this cell repeats the earlier prediction-collection cell;
# running it rebuilds y_pred / y_true from scratch.
# Predicted and ground-truth class indices for the whole test loader,
# accumulated batch by batch and consumed by the confusion-matrix cell.
y_pred = []
y_true = []

model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(DEVICE)
        # The model returns a pair; only its second element is used here
        # (the example cells of this notebook call it `probas`).
        _, scores = model(inputs)

        # Take the arg-max on the scores directly: exp() is monotonic, so it
        # cannot change the winner and only risks merging near-equal values.
        y_pred.extend(scores.argmax(dim=1).cpu().numpy())  # Save Prediction
        y_true.extend(labels.cpu().numpy())  # Save Truth



In [None]:
# Build confusion matrix
# NOTE(review): this cell repeats the earlier confusion-matrix cell and writes
# to the same output file.
# Class names ordered by their numeric index, so row/column i of the matrix is
# labelled with the class whose index is i.
classes = sorted(dataset.class_to_idx, key=dataset.class_to_idx.get)
label_ids = [dataset.class_to_idx[name] for name in classes]

# Passing `labels` pins the matrix to one row/column per class even when a
# class is absent from both y_true and y_pred; without it the matrix shrinks
# and no longer matches the class-name index below.
cf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)

# Normalise each row by the number of true samples of that class. A class
# without true samples has an all-zero row, so dividing by 1 keeps it at 0
# instead of producing NaN from 0/0.
row_totals = cf_matrix.sum(axis=1, keepdims=True)
df_cm = pd.DataFrame(cf_matrix / np.maximum(row_totals, 1),
                     index=classes, columns=classes)

plt.figure(figsize=(30, 25))
sn.heatmap(df_cm, annot=True)
plt.savefig(f'CV_data/conf_matrix/{model_name}-output.pdf')

# Test examples

In [None]:
# Grab the first batch of the test loader.
features, targets = next(iter(test_loader))

NUM_IMG = 4
# squeeze=False keeps `axs` an array even if NUM_IMG is set to 1.
fig, axs = plt.subplots(1, NUM_IMG, squeeze=False)
imgs = []
classes = []

# The batch may hold fewer than NUM_IMG images; only index what exists and let
# the remaining axes stay empty (the plotting loop below already allows that).
num_shown = min(NUM_IMG, len(features))

# One forward pass for all displayed images, in eval mode and without
# building an autograd graph.
model.eval()
with torch.no_grad():
    _, probas = model(features[:num_shown].to(device))
probas = probas.cpu().numpy()

for idx in range(num_shown):
    # Move the first axis last (same (1, 2, 0) reordering as before) so that
    # imshow receives the image with its channel axis at the end.
    imgs.append(features[idx].permute(1, 2, 0).numpy())

    # Predicted class name and the score of that class for image `idx`.
    classes.append(
        {"class": idx_to_class[int(np.argmax(probas[idx]))],
         "proba": np.max(probas[idx])})

for i, ax in enumerate(axs.flat):
    if i < len(imgs):
        ax.imshow(imgs[i])
        ax.set_title(f'Class: {classes[i]["class"]}, \nProbability:  {classes[i]["proba"]:.4f}', fontsize=8)
    ax.set_xticks([])
    ax.set_yticks([])


# Own photos

In [None]:
import os

path = "CV_data/own_photos/"

# Entries found directly under the photo folder, skipping dot-prefixed ones.
classes = [entry for entry in os.listdir(path) if not entry.startswith('.')]
print(classes)

# Preprocessing pipeline: resize, centre-crop to RESOLUTION, convert to tensor.
preprocessing_steps = [
    transforms.Resize(RESOLUTION),
    transforms.CenterCrop(RESOLUTION),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Image-folder dataset over the photo directory and a shuffled loader over it.
custom_test_dataset = datasets.ImageFolder(root=path, transform=transform)
custom_test_loader = DataLoader(
    dataset=custom_test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)


In [None]:
# Grab the first batch of the own-photos loader.
features, targets = next(iter(custom_test_loader))

NUM_IMG = 4
# squeeze=False keeps `axs` an array even if NUM_IMG is set to 1.
fig, axs = plt.subplots(1, NUM_IMG, squeeze=False)
imgs = []
classes = []

# The photo folder may hold fewer than NUM_IMG images; only index what exists
# and let the remaining axes stay empty (the plotting loop already allows that).
num_shown = min(NUM_IMG, len(features))

# One forward pass for all displayed images, in eval mode and without
# building an autograd graph.
model.eval()
with torch.no_grad():
    _, probas = model(features[:num_shown].to(device))
probas = probas.cpu().numpy()

for idx in range(num_shown):
    # Move the first axis last (same (1, 2, 0) reordering as before) so that
    # imshow receives the image with its channel axis at the end.
    imgs.append(features[idx].permute(1, 2, 0).numpy())

    # Predicted class name (looked up in the training dataset's index-to-class
    # mapping) and the score of that class for image `idx`.
    classes.append(
        {"class": idx_to_class[int(np.argmax(probas[idx]))],
         "proba": np.max(probas[idx])})

for i, ax in enumerate(axs.flat):
    if i < len(imgs):
        ax.imshow(imgs[i])
        ax.set_title(f'Class: {classes[i]["class"]}, \nProbability:  {classes[i]["proba"]:.4f}', fontsize=8)
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
from lib.nutrifacts import retrieve_nutrition_facts
import json
from IPython.core.display import HTML

# NOTE(review): this HTML object is neither passed to display() nor the last
# expression of the cell, so it presumably renders nothing - confirm intent.
HTML("&micro;")

# Look up the nutrition facts for the class predicted for the third displayed
# image (index 2 of `classes`) and pretty-print the result as JSON.
predicted_label = classes[2]["class"]
nutrition_facts = retrieve_nutrition_facts(predicted_label)
print(json.dumps(nutrition_facts, indent=4))