# <center> Efficient detection of longitudinal bacteria fission using transfer learning in Deep Neural Networks
# <center> Supplemental Material

## Training code for longitudinal division classification

#### Requirements:
 * GPU
 * Nvidia driver
 * cuda version > 10
 * torch
 * torchvision
 * scikit-learn (imported as `sklearn`)
 * numpy
 * pandas

#### Folders
 * data: contains all image samples in the subfolders `train`, `val` and `test`; each of these holds one subfolder per class.
 * model: where the trained model will be saved
 
#### The file train_functions_sgd.py contains the helper functions called by the main training function. Do not rename or delete this file.

In [None]:
from __future__ import print_function, division
import torch
import torchvision
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
import os
import torch.nn as nn
import train_functions_sgd as trnfn
import time
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import sys

### Function that runs the training

In [None]:
def runtraining(data_dir='data/', model_path="model/trained_net_sgd.pth",
                num_epochs=25, batch_size=8, num_workers=8):
    """Fine-tune a pretrained ResNet-18 as a two-class classifier and evaluate it.

    The function builds image datasets from ``<data_dir>/train``,
    ``<data_dir>/val`` and ``<data_dir>/test``, replaces the final fully
    connected layer of a pretrained ResNet-18 with a 2-output layer, trains
    all parameters with SGD via ``trnfn.train_model``, saves the resulting
    ``state_dict`` to ``model_path`` and finally prints the accuracy,
    confusion matrix and classification report obtained on the test split.

    Args:
        data_dir: Root directory holding the ``train``, ``val`` and ``test``
            folders (each laid out as expected by ``datasets.ImageFolder``).
        model_path: File the trained weights are written to. Its parent
            directory is created if it does not exist.
        num_epochs: Number of epochs passed to ``trnfn.train_model``.
        batch_size: Batch size used by all three data loaders.
        num_workers: Number of worker processes used by each data loader.

    Returns:
        None. Results are printed and the weights are written to disk.
    """
    print('processing data')
    splits = ('train', 'val', 'test')

    # Normalisation statistics documented by torchvision for its
    # ImageNet-pretrained models.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Validation and test images are only resized and normalised; random
    # flips are applied to the training split alone.
    eval_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        normalize,
    ])
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': eval_transform,
        'test': eval_transform,
    }

    image_datasets = {
        split: datasets.ImageFolder(os.path.join(data_dir, split),
                                    data_transforms[split])
        for split in splits
    }
    # Only the training split needs shuffling; evaluation order is kept fixed.
    dataloaders = {
        split: torch.utils.data.DataLoader(image_datasets[split],
                                           batch_size=batch_size,
                                           shuffle=(split == 'train'),
                                           num_workers=num_workers)
        for split in splits
    }
    dataset_sizes = {split: len(image_datasets[split]) for split in splits}

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Device: {device} / Dataset_size: {dataset_sizes}')

    # Create the output directory up front so that a missing or unwritable
    # folder is reported before the training run instead of after it.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    # prepare for training
    # NOTE(review): newer torchvision releases deprecate `pretrained=True` in
    # favour of the `weights=` argument; kept as-is for older installations.
    model_ft = models.resnet18(pretrained=True)
    num_ftrs = model_ft.fc.in_features

    # Replace the classification head with a 2-output layer.
    model_ft.fc = nn.Linear(num_ftrs, 2)
    model_ft = model_ft.to(device)

    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    model_ft = trnfn.train_model(model_ft, criterion, optimizer_ft,
                                 exp_lr_scheduler, dataloaders,
                                 dataset_sizes, device, num_epochs=num_epochs)

    # Save model
    torch.save(model_ft.state_dict(), model_path)

    y_true, y_pred = trnfn.test_model(model_ft, criterion, device, dataloaders, dataset_sizes)
    # Swap the two predicted labels (0 <-> 1) before scoring.
    # NOTE(review): presumably compensates for how trnfn.test_model encodes
    # its predictions relative to y_true -- confirm against that helper.
    y_pred = (y_pred + 1) % 2

    print("accuracy score: ",accuracy_score(y_true, y_pred),"\n")
    print("confusion matrix:")
    print(confusion_matrix(y_true, y_pred),"\n")
    print("classification report:")
    print(classification_report(y_true, y_pred))


### Running the training 

In [None]:
# Guard the entry point: the data loaders use worker processes, and on
# platforms that start them with "spawn" each worker re-imports the main
# module, which would otherwise launch the training again.
if __name__ == "__main__":
    runtraining()