# Code for the paper: 

## Efficient and Mobile Deep Learning Architectures for Fast Identification of BacterialStrains in Resource-Constrained Devices

### Architecture: SqueezeNet 1.0
### Data: Original + Cross-Validation

#### Loading libraries

In [1]:
# Imports here
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
import torch.nn.functional as F
import pandas as pd
import time
import numpy as np

from PIL import Image
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from sklearn.model_selection import KFold
from torch import nn
from torch import optim
from torch.utils.data import SubsetRandomSampler
from torchvision import datasets, transforms, models
from pytorch_model_summary import summary

# Archs not in Pytorch
from efficientnet_pytorch import EfficientNet

# External functions
from scripts.utils import *

In [2]:
print(torch.__version__)

import sys
print(sys.version)

1.5.1
3.8.5 (default, Jul 28 2020, 12:59:40) 
[GCC 9.3.0]


## Data paths and hyperparameters

#### Hyperparameters and dataset details. 

In [3]:
# Dataset details
dataset_version = 'original' # original or augmented
img_shape = (224,224)
img_size = str(img_shape[0])+"x"+str(img_shape[1])

# Root directory of dataset
data_dir = '/home/yibbtstll/venvs/pytorch_gpu/CySDeepBacterial/Dataset/DIBaS/'

train_batch_size = 32
val_test_batch_size = 32
feature_extract = False
pretrained = True
h_epochs = 15
kfolds = 10

## Data preparation and loading

#### Defining transforms and creating dataloaders

In [4]:
# Define transforms for input data
training_transforms = transforms.Compose([transforms.Resize((224,224), Image.LANCZOS),
                                          transforms.ToTensor(),
                                          transforms.Normalize([0.485, 0.456, 0.406], 
                                                               [0.229, 0.224, 0.225])])

# TODO: Load the datasets with ImageFolder
total_set = datasets.ImageFolder(data_dir, transform=training_transforms)

# Defining folds
splits = KFold(n_splits = kfolds, shuffle = True, random_state = 42)

#### Visualizing the target classes in dataset

In [5]:
train_labels = {value : key for (key, value) in total_set.class_to_idx.items()}
    
print(len(train_labels)) 
print(train_labels)

32
{0: 'Acinetobacter.baumanii', 1: 'Actinomyces.israeli', 2: 'Bacteroides.fragilis', 3: 'Bifidobacterium.spp', 4: 'Clostridium.perfringens', 5: 'Enterococcus.faecalis', 6: 'Enterococcus.faecium', 7: 'Escherichia.coli', 8: 'Fusobacterium', 9: 'Lactobacillus.casei', 10: 'Lactobacillus.crispatus', 11: 'Lactobacillus.delbrueckii', 12: 'Lactobacillus.gasseri', 13: 'Lactobacillus.jehnsenii', 14: 'Lactobacillus.johnsonii', 15: 'Lactobacillus.paracasei', 16: 'Lactobacillus.plantarum', 17: 'Lactobacillus.reuteri', 18: 'Lactobacillus.rhamnosus', 19: 'Lactobacillus.salivarius', 20: 'Listeria.monocytogenes', 21: 'Micrococcus.spp', 22: 'Neisseria.gonorrhoeae', 23: 'Porfyromonas.gingivalis', 24: 'Propionibacterium.acnes', 25: 'Proteus', 26: 'Pseudomonas.aeruginosa', 27: 'Staphylococcus.aureus', 28: 'Staphylococcus.epidermidis', 29: 'Staphylococcus.saprophiticus', 30: 'Streptococcus.agalactiae', 31: 'Veionella'}


## Model definition and inicialization

#### Freezing pre-trained parameters, finetunning the classifier to output 32 classes.

In [6]:
# Freeze pretrained model parameters to avoid backpropogating through them
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        print("Setting grad to false.")
        for param in model.parameters():
            param.requires_grad = False
    
    return model

def get_device():
    # Model and criterion to GPU
    if torch.cuda.is_available():
        return 'cuda'
    else:
        return 'cpu'

In [7]:
models.squeezenet1_1(pretrained=False)

SqueezeNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): Fire(
      (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (4): Fire(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (5): MaxPool2d

In [8]:
def load_model():
    # Transfer Learning
    model = models.squeezenet1_1(pretrained=pretrained)
    
    # Mode
    model = set_parameter_requires_grad(model, feature_extract)
    
    # Fine tuning
    # Build custom classifier
    # Build custom classifier
    model.classifier[1] = nn.Conv2d(512, 32, 
                                kernel_size=(1,1), stride=(1,1))
    return model

def create_optimizer(model):
    # Parameters to update
    params_to_update = model.parameters()

    if feature_extract:
        params_to_update = []
        for param in model.parameters():
            if param.requires_grad == True:
                params_to_update.append(param)

    else:
        n_params = 0
        for param in model.parameters():
            if param.requires_grad == True:
                n_params += 1


    # Loss function and gradient descent

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(params_to_update, 
                          lr=0.001, 
                          weight_decay=0.000004)
    
    return criterion.to(get_device()), model.to(get_device()), optimizer

## Training, validation and test functions

In [9]:
# Variables to store fold scores
train_acc = []
test_top1_acc = []
test_top5_acc = []
test_precision = []
test_recall = []
test_f1 = []
times = []

for fold, (train_idx, valid_idx) in enumerate(splits.split(total_set)):
    
    start_time = time.time()
    
    print('Fold : {}'.format(fold))
    
    # Train and val samplers
    train_sampler = SubsetRandomSampler(train_idx)
    print("Samples in training:", len(train_sampler))
    valid_sampler = SubsetRandomSampler(valid_idx)
    print("Samples in test:", len(valid_sampler))
    
    # Train and val loaders
    train_loader = torch.utils.data.DataLoader(
                      total_set, batch_size=train_batch_size, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
                      total_set, batch_size=1, sampler=valid_sampler)
    
    device = get_device()
    
    criterion, model, optimizer = create_optimizer(load_model())
    
    # Training
    for epoch in range(h_epochs):
        
        model.train()
        running_loss = 0.0
        running_corrects = 0
        trunning_corrects = 0
        
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum()
            trunning_corrects += preds.size(0)
            

        epoch_loss = running_loss / trunning_corrects
        epoch_acc = (running_corrects.double()*100) / trunning_corrects
        train_acc.append(epoch_acc.item())
        
        print('\t\t Training: Epoch({}) - Loss: {:.4f}, Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc))
        
        # Validation
        
        model.eval()  
        
        vrunning_loss = 0.0
        vrunning_corrects = 0
        num_samples = 0
        
        for data, labels in valid_loader:
            
            data = data.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            
            with torch.no_grad():
                outputs = model(data)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                
            vrunning_loss += loss.item() * data.size(0)
            vrunning_corrects += (preds == labels).sum()
            num_samples += preds.size(0)
            
        vepoch_loss = vrunning_loss/num_samples
        vepoch_acc = (vrunning_corrects.double() * 100)/num_samples
        
        print('\t\t Validation({}) - Loss: {:.4f}, Acc: {:.4f}'.format(epoch, vepoch_loss, vepoch_acc))
    
    # Calculating and appending scores to this fold
    model.class_to_idx = total_set.class_to_idx
    scores = get_scores(model, valid_loader)
    
    test_top1_acc.append(scores[0])
    test_top5_acc.append(scores[1])
    test_precision.append(scores[2])
    test_recall.append(scores[3])
    test_f1.append(scores[4])
    
    time_fold = time.time() - start_time
    times.append(time_fold)
    print("Total time per fold: %s seconds." %(time_fold))

Fold : 0
Samples in training: 602
Samples in test: 67
		 Training: Epoch(0) - Loss: 3.4969, Acc: 2.8239
		 Validation(0) - Loss: 3.4696, Acc: 4.4776
		 Training: Epoch(1) - Loss: 3.4255, Acc: 4.1528
		 Validation(1) - Loss: 3.4953, Acc: 0.0000
		 Training: Epoch(2) - Loss: 3.3677, Acc: 5.6478
		 Validation(2) - Loss: 3.3989, Acc: 2.9851
		 Training: Epoch(3) - Loss: 3.2932, Acc: 7.8073
		 Validation(3) - Loss: 3.4346, Acc: 1.4925
		 Training: Epoch(4) - Loss: 3.3071, Acc: 5.6478
		 Validation(4) - Loss: 3.3650, Acc: 5.9701
		 Training: Epoch(5) - Loss: 3.1392, Acc: 6.4784
		 Validation(5) - Loss: 3.2957, Acc: 5.9701
		 Training: Epoch(6) - Loss: 3.0170, Acc: 10.6312
		 Validation(6) - Loss: 3.0803, Acc: 7.4627
		 Training: Epoch(7) - Loss: 2.7584, Acc: 13.9535
		 Validation(7) - Loss: 2.7815, Acc: 14.9254
		 Training: Epoch(8) - Loss: 2.5780, Acc: 17.6080
		 Validation(8) - Loss: 2.6250, Acc: 26.8657
		 Training: Epoch(9) - Loss: 2.3974, Acc: 21.7608
		 Validation(9) - Loss: 2.6656, Ac

		 Validation(3) - Loss: 2.7977, Acc: 20.8955
		 Training: Epoch(4) - Loss: 2.3804, Acc: 25.4153
		 Validation(4) - Loss: 1.9747, Acc: 29.8507
		 Training: Epoch(5) - Loss: 1.8685, Acc: 37.3754
		 Validation(5) - Loss: 1.4690, Acc: 46.2687
		 Training: Epoch(6) - Loss: 1.4979, Acc: 46.8439
		 Validation(6) - Loss: 1.5777, Acc: 44.7761
		 Training: Epoch(7) - Loss: 1.1540, Acc: 57.1429
		 Validation(7) - Loss: 1.2125, Acc: 55.2239
		 Training: Epoch(8) - Loss: 1.0370, Acc: 59.4684
		 Validation(8) - Loss: 1.2635, Acc: 55.2239
		 Training: Epoch(9) - Loss: 0.8388, Acc: 68.9369
		 Validation(9) - Loss: 1.4789, Acc: 61.1940
		 Training: Epoch(10) - Loss: 1.1793, Acc: 58.9701
		 Validation(10) - Loss: 1.3192, Acc: 50.7463
		 Training: Epoch(11) - Loss: 0.8255, Acc: 71.2625
		 Validation(11) - Loss: 0.9842, Acc: 62.6866
		 Training: Epoch(12) - Loss: 0.7307, Acc: 74.0864
		 Validation(12) - Loss: 1.0678, Acc: 70.1493
		 Training: Epoch(13) - Loss: 0.5803, Acc: 76.7442
		 Validation(13) - Los

In [10]:
print("Train accuracy average: ", np.mean(train_acc) / 100)
print("Top-1 test accuracy average: ", np.mean(test_top1_acc))
print("Top-5 test accuracy average: ", np.mean(test_top5_acc))
print("Weighted Precision test accuracy average: ", np.mean(test_precision))
print("Weighted Recall test accuracy average: ", np.mean(test_recall))
print("Weighted F1 test accuracy average: ", np.mean(test_f1))
print("Average time per fold (seconds):", np.mean(times))

Train accuracy average:  0.32520645205129756
Top-1 test accuracy average:  0.5810719131614653
Top-5 test accuracy average:  0.9686114880144732
Weighted Precision test accuracy average:  0.5675010768667483
Weighted Recall test accuracy average:  0.5810719131614653
Weighted F1 test accuracy average:  0.5403571650722261
Average time per fold (seconds): 320.6574599027634
