In [13]:
# load data
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import os
import cnn

batch_size = 8

# Preprocessing shared by every split: to PIL, resize to 224x224, to tensor,
# then normalise with the standard ImageNet per-channel mean/std.
# Named `data_transform` so it does not shadow the `torchvision.transforms`
# module imported above.
data_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_dataset = cnn.CustomImageDataset("annotations_file_training.csv", "All_training",
                                       transform=data_transform)

validation_dataset = cnn.CustomImageDataset("annotations_file_validation.csv", "All_validation",
                                            transform=data_transform)

test_dataset = cnn.CustomImageDataset("annotations_file_evaluation.csv", "All_evaluation",
                                      transform=data_transform)

# Only the training split needs shuffling; the evaluation splits keep a fixed
# order so repeated runs visit samples identically.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

dataloaders = {'train': train_loader, 'val': validation_loader}
dataset_sizes = {'train': len(train_dataset), 'val': len(validation_dataset)}

# Label names indexed by class id (position i is the name of label i).
# Adjacent string literals are implicitly concatenated in Python, so
# 'dairy products' must be a single literal (not 'dairy' 'products').
classes = ('bread', 'dairy products', 'dessert', 'eggs', 'fried food', 'meat', 'noodles-pasta',
           'rice', 'seafood', 'soup', 'vegetables-fruits')
print('Done')

Done


In [None]:
# train our model
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
import os
import cnn
import importlib
importlib.reload(cnn)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = cnn.ConvNet().to(device)

criterion = nn.CrossEntropyLoss()

# Optimise every parameter of the network. Optimising only `model.fc` is a
# recipe for fine-tuning a frozen pretrained backbone; ConvNet is built here
# without loading any weights (presumably trained from scratch) and exposes
# fc1/fc2/fc3 rather than a single `fc`. Parameters with requires_grad=False,
# if any, receive no gradient and are simply skipped by SGD.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

model = cnn.train_model(model, criterion, optimizer, exp_lr_scheduler,
                        dataloaders=dataloaders, dataset_sizes=dataset_sizes, num_epochs=15)

print('Finished Training')
PATH = './CustomNet/CustomNet.pth'
# Make sure the output directory exists so a long training run is not lost
# to a FileNotFoundError at save time.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
# The whole pickled module is saved (not just state_dict) because the test
# cell reloads it with a bare torch.load.
torch.save(model, PATH)

In [None]:
# Test the model: overall and per-class accuracy on the evaluation split.
if 'model' not in globals():
    # Fresh kernel: reload the network saved by the training cell.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
    # rejects a fully pickled module like this one; on such versions pass
    # weights_only=False (only for trusted files).
    model = torch.load('CustomNet/CustomNet.pth', map_location=device).to(device)
    print('model loaded')

num_classes = len(classes)

# Inference mode: dropout off and BatchNorm (if any) uses its running statistics.
model.eval()

n_correct = 0
n_samples = 0
n_class_correct = [0] * num_classes
n_class_samples = [0] * num_classes

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images.float())
        # max returns (value, index); the index is the predicted class id
        _, predicted = torch.max(outputs, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        # One host transfer per batch instead of a device sync per element.
        for label, pred in zip(labels.tolist(), predicted.tolist()):
            n_class_samples[label] += 1
            if label == pred:
                n_class_correct[label] += 1

if n_samples == 0:
    print('No test samples found; accuracy is undefined.')
else:
    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network: {acc} %')

for i in range(num_classes):
    if n_class_samples[i] == 0:
        # Avoid ZeroDivisionError for classes absent from the test split.
        print(f'Accuracy of {classes[i]}: n/a (no test samples)')
        continue
    acc = 100.0 * n_class_correct[i] / n_class_samples[i]
    print(f'Accuracy of {classes[i]}: {acc} %')

In [None]:
# Display the heatmaps
import cnn
from PIL import Image
import cv2
import importlib
import numpy as np
importlib.reload(cnn)

# Upscale each test image to 512x512 before display.
transform = torchvision.transforms.Resize((512, 512))

# Mean/std used by Normalize in the data-loading cell; reversed here so imshow
# receives RGB values in [0, 1] instead of clipped, normalised ones.
# Shape (3, 1, 1) broadcasts over a (B, 3, H, W) batch.
display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Set to an int to preview only the first few batches; None walks the whole test set.
max_batches = None

# Eval mode so BatchNorm/dropout (if any) behave as at inference and BatchNorm
# running statistics are not updated from test batches.
model.eval()

# Classifier-head weight shapes do not change between batches; print them once.
print(model.fc1.weight.shape)
print(model.fc2.weight.shape)
print(model.fc3.weight.shape)

# The CAMs are built from the fc weights, so no autograd graph is needed.
with torch.no_grad():
    for batch_idx, (images, _labels) in enumerate(test_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break

        images = images.to(device)
        output = model(images.float(), feature_conv=True)
        images = transform(images)
        cam = cnn.returnCAM(output, model.fc1.weight, model.fc2.weight, model.fc3.weight,
                            list(range(len(classes))))
        # One heatmap is drawn per class below, so at least that many CAMs are needed.
        assert len(cam) >= len(classes), f'expected {len(classes)} CAMs, got {len(cam)}'

        # Undo normalisation before tiling so make_grid's zero padding stays black.
        display_images = (images.cpu() * display_std + display_mean).clamp(0, 1)
        grid = torchvision.utils.make_grid(display_images).permute(1, 2, 0).numpy()

        fig, axs = plt.subplots(2, 6, figsize=(20, 6))
        # 12 panels in row-major order: one heatmap overlay per class, and any
        # remaining panel shows the plain image grid.
        for panel_idx, ax in enumerate(axs.flat):
            ax.imshow(grid)
            ax.axis('off')
            if panel_idx < len(classes):
                heatmap = np.squeeze(cam[panel_idx].detach().cpu().numpy())
                ax.imshow(heatmap, cmap='jet', alpha=0.5)
                ax.set_title(classes[panel_idx])

        plt.show()
        # Release the figure so looping over many batches does not accumulate memory.
        plt.close(fig)