<a href="https://colab.research.google.com/github/gsrflo/deep-apple-learning/blob/dev/deep_apple_learning_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Import Packages


In [None]:
#If pytorch lightning not installed:
!pip install pytorch-lightning

import os
import csv
import time
import numpy as np
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd

# Module for Google Drive
from google.colab import drive

# Module for Importing Images
from PIL import Image 

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import collections


print(torch.__version__)

### Import Drive Content

In [None]:
# Mount the user's Google Drive at /content/drive (Colab only; prompts for
# authorisation). The dataset and model paths used below live under this mount.
drive.mount('/content/drive')

Mounted at /content/drive


### Definition of Dataset




In [None]:
class AppleDataset(Dataset):
    """Apple images of one modality ('rgb' or 'ir') for one data split.

    Images are listed from ``<data_path>/<img_type>/<set_type>/`` and labels are
    read from ``<data_path>/labels_<img_type>.txt``, where every non-empty row
    is ``<sample name> <integer label>`` and the image file of a sample is
    ``<sample name>.jpg``.

    Args:
        data_path: Root folder of the dataset.
        img_type: ``'rgb'`` or ``'ir'``.
        set_type: ``'train'``, ``'validate'``, ``'test'`` or ``'evaluation'``.
            Random augmentation is applied only when this is ``'train'``.

    Raises:
        ValueError: If an image in the split folder has no row in the labels
            file, or a label is not an integer.
    """

    def __init__(self, data_path, img_type, set_type):
        #choose between rgb or ir data: train, validate or test
        self.img_type = img_type # 'rgb' or 'ir'
        self.set_type = set_type # 'train', 'validate', 'test', 'evaluation'

        #set paths to training / validation / test folders:
        self.data_path = data_path
        self.target_path = os.path.join(data_path, self.img_type, self.set_type)

        #get labels and filenames from labels.txt:
        labels_path = os.path.join(data_path, 'labels_' + self.img_type + '.txt')
        label_by_name = self._read_labels(labels_path)

        #combine path, label and name of sample (sorted for a stable order):
        self.img_path_label = list()
        for fp in sorted(os.listdir(self.target_path)):
            if fp not in label_by_name:
                raise ValueError("no label for image '{}' in '{}'".format(fp, labels_path))
            full_fp = os.path.join(self.target_path, fp)
            self.img_path_label.append((full_fp, label_by_name[fp], fp))

        #data augmentation:
        self.tensor_transform = torchvision.transforms.ToTensor()

        self.random_flip = torchvision.transforms.RandomHorizontalFlip(p=0.5)
        self.random_affine = torchvision.transforms.RandomAffine(20)

        self.transform = torchvision.transforms.Compose([self.tensor_transform, self.random_flip, self.random_affine])

    @staticmethod
    def _read_labels(labels_path):
        """Parse a labels file into ``{'<sample name>.jpg': int_label}``.

        Rows with fewer than two whitespace-separated fields (e.g. the empty
        rows between apples) are skipped. If a name occurs more than once the
        first row wins, matching a first-match list lookup.
        """
        label_by_name = {}
        # `with` guarantees the file handle is closed again.
        with open(labels_path, 'r') as labels_txt:
            for row in labels_txt:
                txtrow = row.split()
                if len(txtrow) > 1:
                    label_by_name.setdefault(txtrow[0] + '.jpg', int(txtrow[1]))
        return label_by_name

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.img_path_label)

    def __getitem__(self, idx):
        """Return ``(sample, label, name, raw)`` for the image at ``idx``.

        ``sample`` is the (for 'train': randomly flipped/rotated) image tensor,
        standardised per channel to zero mean and unit std; ``label`` is the
        integer label; ``name`` is the file name; ``raw`` is the un-augmented,
        un-normalised image tensor. ``name`` and ``raw`` are only used for
        plotting.
        """
        #get filepath of sample image, label, and name:
        (fp, label, name) = self.img_path_label[idx]

        # Open inside `with` so the underlying file is closed once the pixel
        # data has been converted to tensors.
        with Image.open(fp) as img:
            raw = self.tensor_transform(img)

            #data augmentation only for train set:
            if self.set_type == 'train':
                sample = self.transform(img)
            else:
                sample = raw

        #normalization (per channel over height and width; creates a new tensor,
        #so `raw` is left untouched; epsilon avoids division by zero):
        mean = torch.mean(sample, dim=(1, 2), keepdim=True)
        std = torch.std(sample, dim=(1, 2), keepdim=True)
        sample = (sample - mean) / (std + 1e-11)

        return sample, label, name, raw

### Set DataSet and DataLoader

In [None]:
# Locations on the mounted Drive:
data_path = '/content/drive/MyDrive/deep-apple-learning/dataset/'              # dataset root
model_dir = '/content/drive/MyDrive/deep-apple-learning/checkpoint_models/'    # manual saves and checkpoints

# Batch size used for both training loaders:
batch_size = 64

# --- infrared: one dataset per split, then one loader per dataset ---
train_dataset_ir = AppleDataset(data_path, 'ir', 'train')
val_dataset_ir = AppleDataset(data_path, 'ir', 'validate')
test_dataset_ir = AppleDataset(data_path, 'ir', 'test')

train_dataloader_ir = DataLoader(train_dataset_ir, batch_size=batch_size, shuffle=True)
val_dataloader_ir = DataLoader(val_dataset_ir, batch_size=52, shuffle=False)      # number of val images = 52
test_dataloader_ir = DataLoader(test_dataset_ir, batch_size=55, shuffle=False)    # number of test images = 55

# --- rgb: same layout as infrared ---
train_dataset_rgb = AppleDataset(data_path, 'rgb', 'train')
val_dataset_rgb = AppleDataset(data_path, 'rgb', 'validate')
test_dataset_rgb = AppleDataset(data_path, 'rgb', 'test')

train_dataloader_rgb = DataLoader(train_dataset_rgb, batch_size=batch_size, shuffle=True)
val_dataloader_rgb = DataLoader(val_dataset_rgb, batch_size=64, shuffle=False)      # number of val images = 64
test_dataloader_rgb = DataLoader(test_dataset_rgb, batch_size=55, shuffle=False)    # number of test images = 55

### Take and show a sample

In [None]:
# Pull the first (shuffled) training batch of each modality. Every batch is a
# 4-tuple: normalised samples, labels, file names, raw images for plotting.
_first_batch_ir = next(iter(train_dataloader_ir))
sample_train_ir, label_train_ir, name_train_ir, img_train_ir = _first_batch_ir

_first_batch_rgb = next(iter(train_dataloader_rgb))
sample_train_rgb, label_train_rgb, name_train_rgb, img_train_rgb = _first_batch_rgb



In [None]:
#choose IR sample (raw, un-normalised images of the fetched training batch):
sample = img_train_ir
label = label_train_ir
name = name_train_ir

# Draw the index from the batch that was actually returned: it can hold fewer
# than `batch_size` images, in which case indexing up to batch_size would fail.
plot_idx = np.random.randint(0, len(sample))
img = sample[plot_idx]

# The image is single-channel; repeat it to 3 channels and move channels last
# so that imshow renders it as a grey RGB image.
plt.imshow(img.repeat(3,1,1).permute(1,2,0))
plt.title('Filename: ' + name[plot_idx] + ', Label: ' + str(label[plot_idx].detach().numpy()))

In [None]:
#choose RGB sample (raw, un-normalised images of the fetched training batch):
sample = img_train_rgb
label = label_train_rgb
name = name_train_rgb

# Draw the index from the batch that was actually returned: it can hold fewer
# than `batch_size` images, in which case indexing up to batch_size would fail.
plot_idx = np.random.randint(0, len(sample))
img = sample[plot_idx]

# Move the channel axis last, as expected by imshow.
plt.imshow(img.permute(1,2,0))
plt.title('Filename: ' + name[plot_idx] + ', Label: ' + str(label[plot_idx].detach().numpy()))

### Choose your device (GPU or CPU)

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU so
# that the cells using `device` do not crash on a CPU-only runtime.
# To force a device, overwrite the value: device = 'cpu' or device = 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current Device : {}'.format(device))

Current Device : cuda


### Define models

In [None]:
class IR_Model(pl.LightningModule):
    """Small CNN binary classifier for single-channel (infrared) images.

    Four conv / batch-norm / ELU stages (the first three followed by 2x2 max
    pooling), a 15x15 average pool and a three-layer fully connected head that
    ends in a sigmoid, so ``forward`` yields one probability per sample.

    NOTE(review): the 256-wide head only fits if the average pool reduces the
    feature map to 1x1 (e.g. 120x120 inputs) - the input size is not visible
    here, confirm against the dataset.
    """

    def __init__(self):
        super(IR_Model, self).__init__()

        self.cnn = nn.Sequential(collections.OrderedDict([
                  ('conv1', nn.Conv2d(in_channels=1, out_channels = 16, kernel_size = 3, stride = 1, padding = 1)),
                  ('bn1', nn.BatchNorm2d(16)),
                  ('elu1', nn.ELU()),
                  ('maxpool1', nn.MaxPool2d(kernel_size = 2, stride = 2)),
                  ('conv2', nn.Conv2d(in_channels=16, out_channels = 64, kernel_size = 3, stride = 1, padding = 1)),
                  ('bn2', nn.BatchNorm2d(64)),
                  ('elu2', nn.ELU()),
                  ('maxpool2', nn.MaxPool2d(kernel_size = 2, stride = 2)),
                  ('conv3', nn.Conv2d(in_channels=64, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)),
                  ('bn3', nn.BatchNorm2d(128)),
                  ('elu3', nn.ELU()),
                  ('maxpool3', nn.MaxPool2d(kernel_size = 2, stride = 2)),
                  ('conv4', nn.Conv2d(in_channels=128, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)),
                  ('bn4', nn.BatchNorm2d(256)),
                  ('elu4', nn.ELU()),
                  ('avgpool', nn.AvgPool2d(kernel_size=15)),
                  ('flatten', nn.Flatten())
                  ]))

        self.fc1 = nn.Linear(256,64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64,16)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(16,1)
        self.sig = nn.Sigmoid()

    def forward(self, img):
        """Return the predicted probability of class 1, shape ``(batch, 1)``."""
        # Move the input to wherever this module lives instead of relying on a
        # notebook-global `device` variable.
        out = self.cnn(img.to(self.device))
        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.sig(out)

        return out

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    @staticmethod
    def _loss_and_accuracy(pred, label):
        """Summed binary cross-entropy and accuracy of one batch.

        Args:
            pred: Probabilities from ``forward``, one per sample.
            label: Integer labels, one per sample. Samples whose label is
                neither 0 nor 1 add nothing to the loss and count as wrong.

        Returns:
            ``(loss, acc)`` as scalar tensors; ``loss`` is the sum (not the
            mean) of ``-log p`` for label 1 and ``-log(1 - p)`` for label 0.
        """
        # Work in float32: with 16-bit training the probabilities arrive in
        # half precision, where 1 - eps would round to exactly 1.
        probs = pred.reshape(-1).float()
        target = label.to(probs.device)

        # Clamp away from 0 and 1 so a saturated sigmoid cannot yield log(0).
        eps = 1e-7
        safe = probs.clamp(min=eps, max=1. - eps)
        loss = -(torch.log(safe[target == 1]).sum() + torch.log(1. - safe[target == 0]).sum())

        acc = ((probs > 0.5) == target).float().mean()
        return loss, acc

    def training_step(self, train_batch, batch_idx):
        """Log and return the training loss of one batch; also log accuracy."""
        batch, label, _ , _  = train_batch
        pred = self(batch)

        loss, acc = self._loss_and_accuracy(pred, label)
        self.log('train_loss', loss)
        self.log('train_accuracy', acc, prog_bar=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log validation loss and accuracy of one batch."""
        sample , label, _, _   = val_batch
        pred = self(sample)

        loss, acc = self._loss_and_accuracy(pred, label)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_accuracy', acc,prog_bar=True)


In [None]:
class RGB_Model(pl.LightningModule):
    """Small CNN binary classifier for three-channel (RGB) images.

    Four conv / batch-norm / ELU stages (the first three followed by 2x2 max
    pooling), a 15x15 average pool and a three-layer fully connected head that
    ends in a sigmoid, so ``forward`` yields one probability per sample.

    NOTE(review): the 256-wide head only fits if the average pool reduces the
    feature map to 1x1 (e.g. 120x120 inputs) - the input size is not visible
    here, confirm against the dataset.
    """

    def __init__(self):
        super(RGB_Model, self).__init__()

        self.cnn = nn.Sequential(collections.OrderedDict([
                  ('conv1', nn.Conv2d(in_channels=3, out_channels = 16, kernel_size = 3, stride = 1, padding = 1)),
                  ('bn1', nn.BatchNorm2d(16)),
                  ('elu1', nn.ELU()),
                  ('maxpool1', nn.MaxPool2d(kernel_size = 2, stride = 2)),
                  ('conv2', nn.Conv2d(in_channels=16, out_channels = 64, kernel_size = 3, stride = 1, padding = 1)),
                  ('bn2', nn.BatchNorm2d(64)),
                  ('elu2', nn.ELU()),
                  ('maxpool2', nn.MaxPool2d(kernel_size = 2, stride = 2)),
                  ('conv3', nn.Conv2d(in_channels=64, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)),
                  ('bn3', nn.BatchNorm2d(128)),
                  ('elu3', nn.ELU()),
                  ('maxpool3', nn.MaxPool2d(kernel_size = 2, stride = 2)),
                  ('conv4', nn.Conv2d(in_channels=128, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)),
                  ('bn4', nn.BatchNorm2d(256)),
                  ('elu4', nn.ELU()),
                  ('avgpool', nn.AvgPool2d(kernel_size=15)),
                  ('flatten', nn.Flatten())
                  ]))

        self.fc1 = nn.Linear(256,64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64,16)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(16,1)
        self.sig = nn.Sigmoid()

        # Not used by forward(); kept so the module's attributes are unchanged.
        self.dropout = nn.Dropout(p=0)

    def forward(self, img):
        """Return the predicted probability of class 1, shape ``(batch, 1)``."""
        # Move the input to wherever this module lives instead of relying on a
        # notebook-global `device` variable.
        out = self.cnn(img.to(self.device))
        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.sig(out)

        return out

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    @staticmethod
    def _loss_and_accuracy(pred, label):
        """Summed binary cross-entropy and accuracy of one batch.

        Args:
            pred: Probabilities from ``forward``, one per sample.
            label: Integer labels, one per sample. Samples whose label is
                neither 0 nor 1 add nothing to the loss and count as wrong.

        Returns:
            ``(loss, acc)`` as scalar tensors; ``loss`` is the sum (not the
            mean) of ``-log p`` for label 1 and ``-log(1 - p)`` for label 0.
        """
        # Work in float32: with 16-bit training the probabilities arrive in
        # half precision, where 1 - eps would round to exactly 1.
        probs = pred.reshape(-1).float()
        target = label.to(probs.device)

        # Clamp away from 0 and 1 so a saturated sigmoid cannot yield log(0).
        eps = 1e-7
        safe = probs.clamp(min=eps, max=1. - eps)
        loss = -(torch.log(safe[target == 1]).sum() + torch.log(1. - safe[target == 0]).sum())

        acc = ((probs > 0.5) == target).float().mean()
        return loss, acc

    def training_step(self, train_batch, batch_idx):
        """Log and return the training loss of one batch; also log accuracy."""
        batch, label, _ , _  = train_batch
        pred = self(batch)

        loss, acc = self._loss_and_accuracy(pred, label)
        self.log('train_loss', loss)
        self.log('train_accuracy', acc, prog_bar=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log validation loss and accuracy of one batch."""
        sample, label, _ , _  = val_batch
        pred = self(sample)

        loss, acc = self._loss_and_accuracy(pred, label)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_accuracy', acc, prog_bar=True)


In [None]:
class IR_Model_pretrained(pl.LightningModule):
    """Binary classifier for infrared images on a pretrained ResNet-50.

    The single IR channel is repeated to three channels, passed through
    ResNet-50 without its final fully connected layer, and classified by a new
    four-layer head ending in a sigmoid. All backbone children except the last
    three (before the original fc is dropped) are frozen.
    """

    def __init__(self):
        super(IR_Model_pretrained, self).__init__()

        #choose pretrained model:
        #self.backbone = torchvision.models.vgg19_bn(pretrained=True)
        self.backbone = torchvision.models.resnet50(pretrained=True)

        #fix initial layers of pretrained model:
        # requires_grad has to be cleared on the parameters; setting it on the
        # child modules themselves is a no-op and freezes nothing.
        # NOTE(review): frozen batch-norm layers still update their running
        # statistics while the module is in train mode.
        for child in list(self.backbone.children())[:-3]:
            for param in child.parameters():
                param.requires_grad = False

        # get the structure until the Fully Connected Layer
        modules = list(self.backbone.children())[:-1]
        self.backbone = nn.Sequential(*modules)

        #create new network:
        self.fc1 = nn.Linear(2048,512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512,128)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(128,16)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(16,1)
        self.sig = nn.Sigmoid()

        self.dropout = nn.Dropout(p=0.2)

    def _predict(self, img, use_dropout):
        """Shared forward path; ``use_dropout`` enables dropout in the head.

        Dropout is applied after fc1, fc2 and fc3 only. It is deliberately not
        applied to the single output logit: zeroing or rescaling that one value
        only injects noise into the prediction and the loss.
        """
        #Convert grayscale IR image to RGB image (on this module's device):
        img = img.to(self.device).repeat(1, 3, 1, 1)

        out = self.backbone(img)
        out = self.fc1(out.view(out.size(0), -1))
        if use_dropout:
            out = self.dropout(out)
        out = self.relu1(out)
        out = self.fc2(out)
        if use_dropout:
            out = self.dropout(out)
        out = self.relu2(out)
        out = self.fc3(out)
        if use_dropout:
            out = self.dropout(out)
        out = self.relu3(out)
        out = self.fc4(out)
        return self.sig(out)

    def forward(self, img):
        """Return the predicted probability of class 1, shape ``(batch, 1)``.

        No dropout is applied here, whatever the train/eval mode.
        """
        return self._predict(img, use_dropout=False)

    def configure_optimizers(self):
        """Adam with learning rate 1e-3; frozen parameters receive no updates."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    @staticmethod
    def _loss_and_accuracy(pred, label):
        """Summed binary cross-entropy and accuracy of one batch.

        Args:
            pred: Probabilities, one per sample.
            label: Integer labels, one per sample. Samples whose label is
                neither 0 nor 1 add nothing to the loss and count as wrong.

        Returns:
            ``(loss, acc)`` as scalar tensors; ``loss`` is the sum (not the
            mean) of ``-log p`` for label 1 and ``-log(1 - p)`` for label 0.
        """
        # Work in float32: with 16-bit training the probabilities arrive in
        # half precision, where 1 - eps would round to exactly 1.
        probs = pred.reshape(-1).float()
        target = label.to(probs.device)

        # Clamp away from 0 and 1 so a saturated sigmoid cannot yield log(0).
        eps = 1e-7
        safe = probs.clamp(min=eps, max=1. - eps)
        loss = -(torch.log(safe[target == 1]).sum() + torch.log(1. - safe[target == 0]).sum())

        acc = ((probs > 0.5) == target).float().mean()
        return loss, acc

    def training_step(self, train_batch, batch_idx):
        """Log and return the training loss of one batch (dropout enabled)."""
        batch, label, _ , _  = train_batch
        pred = self._predict(batch, use_dropout=True)

        loss, acc = self._loss_and_accuracy(pred, label)
        self.log('train_loss', loss)
        self.log('train_accuracy', acc, prog_bar=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log validation loss and accuracy of one batch."""
        sample, label, _ , _   = val_batch
        pred = self(sample)

        loss, acc = self._loss_and_accuracy(pred, label)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_accuracy', acc, prog_bar=True)

In [None]:
class RGB_Model_pretrained(pl.LightningModule):
    """Binary classifier for RGB images on a pretrained ResNet-50.

    Images are passed through ResNet-50 without its final fully connected
    layer and classified by a new four-layer head ending in a sigmoid. All
    backbone children except the last three (before the original fc is
    dropped) are frozen.
    """

    def __init__(self):
        super(RGB_Model_pretrained, self).__init__()

        #choose pretrained model:
        #self.backbone = torchvision.models.vgg19_bn(pretrained=True)
        self.backbone = torchvision.models.resnet50(pretrained=True)

        #fix initial layers of pretrained model:
        # requires_grad has to be cleared on the parameters; setting it on the
        # child modules themselves is a no-op and freezes nothing.
        # NOTE(review): frozen batch-norm layers still update their running
        # statistics while the module is in train mode.
        for child in list(self.backbone.children())[:-3]:
            for param in child.parameters():
                param.requires_grad = False

        # get the structure until the Fully Connected Layer
        modules = list(self.backbone.children())[:-1]
        self.backbone = nn.Sequential(*modules)

        #create new network:
        self.fc1 = nn.Linear(2048,512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512,128)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(128,16)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(16,1)
        self.sig = nn.Sigmoid()

        self.dropout = nn.Dropout(p=0.2)

    def _predict(self, img, use_dropout):
        """Shared forward path; ``use_dropout`` enables dropout in the head.

        Dropout is applied after fc1, fc2 and fc3 only. It is deliberately not
        applied to the single output logit: zeroing or rescaling that one value
        only injects noise into the prediction and the loss.
        """
        # Move the input to wherever this module lives instead of relying on a
        # notebook-global `device` variable.
        out = self.backbone(img.to(self.device))
        out = self.fc1(out.view(out.size(0), -1))
        if use_dropout:
            out = self.dropout(out)
        out = self.relu1(out)
        out = self.fc2(out)
        if use_dropout:
            out = self.dropout(out)
        out = self.relu2(out)
        out = self.fc3(out)
        if use_dropout:
            out = self.dropout(out)
        out = self.relu3(out)
        out = self.fc4(out)
        return self.sig(out)

    def forward(self, img):
        """Return the predicted probability of class 1, shape ``(batch, 1)``.

        No dropout is applied here, whatever the train/eval mode.
        """
        return self._predict(img, use_dropout=False)

    def configure_optimizers(self):
        """Adam with learning rate 1e-3; frozen parameters receive no updates."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    @staticmethod
    def _loss_and_accuracy(pred, label):
        """Summed binary cross-entropy and accuracy of one batch.

        Args:
            pred: Probabilities, one per sample.
            label: Integer labels, one per sample. Samples whose label is
                neither 0 nor 1 add nothing to the loss and count as wrong.

        Returns:
            ``(loss, acc)`` as scalar tensors; ``loss`` is the sum (not the
            mean) of ``-log p`` for label 1 and ``-log(1 - p)`` for label 0.
        """
        # Work in float32: with 16-bit training the probabilities arrive in
        # half precision, where 1 - eps would round to exactly 1.
        probs = pred.reshape(-1).float()
        target = label.to(probs.device)

        # Clamp away from 0 and 1 so a saturated sigmoid cannot yield log(0).
        eps = 1e-7
        safe = probs.clamp(min=eps, max=1. - eps)
        loss = -(torch.log(safe[target == 1]).sum() + torch.log(1. - safe[target == 0]).sum())

        acc = ((probs > 0.5) == target).float().mean()
        return loss, acc

    def training_step(self, train_batch, batch_idx):
        """Log and return the training loss of one batch (dropout enabled)."""
        batch, label, _ , _ = train_batch
        pred = self._predict(batch, use_dropout=True)

        loss, acc = self._loss_and_accuracy(pred, label)
        self.log('train_loss', loss)
        self.log('train_accuracy', acc, prog_bar=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log validation loss and accuracy of one batch."""
        sample, label, _ , _  = val_batch
        pred = self(sample)

        loss, acc = self._loss_and_accuracy(pred, label)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_accuracy', acc, prog_bar=True)


### Create model and train


In [None]:
# Build one model per modality and place it on `device`.
# Swap the constructor to pick the self-made network instead of the pretrained one.

#IR Model (choose selfmade or pretrained):
# model_ir = IR_Model().to(device)
model_ir = IR_Model_pretrained().to(device)

#RGB Model (choose selfmade or pretrained):
# model_rgb = RGB_Model().to(device)
model_rgb = RGB_Model_pretrained().to(device)

In [None]:
#Train IR Model:

# The checkpoint callback has to exist before the Trainer is built, otherwise
# it is never used. It monitors the logged 'val_loss' and writes files like
# <model_dir>/ir/sample-ir-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
     monitor='val_loss',
     dirpath=model_dir + 'ir/',
     filename='sample-ir-{epoch:02d}-{val_loss:.2f}'
)

# All available GPUs, 16-bit precision, at most 150 epochs.
trainer_ir = pl.Trainer(gpus=-1, precision=16, progress_bar_refresh_rate = 20, max_epochs=150, profiler = True, callbacks=[checkpoint_callback])
trainer_ir.fit(model_ir, train_dataloader_ir, val_dataloader_ir)


In [None]:
#Train RGB Model:

# The checkpoint callback has to exist before the Trainer is built, otherwise
# it is never used. It monitors the logged 'val_loss' and writes files like
# <model_dir>/rgb/sample-rgb-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
     monitor='val_loss',
     dirpath=model_dir + 'rgb/',
     filename='sample-rgb-{epoch:02d}-{val_loss:.2f}'
)

# All available GPUs, 16-bit precision, at most 150 epochs.
trainer_rgb = pl.Trainer(gpus=-1, precision=16, progress_bar_refresh_rate = 20, max_epochs=150, profiler = True, callbacks=[checkpoint_callback])
trainer_rgb.fit(model_rgb, train_dataloader_rgb, val_dataloader_rgb)


In [None]:
# Start tensorboard: load the notebook extension, then open it on the
# lightning_logs/ directory (relative to the current working directory;
# presumably where the Trainer runs above wrote their logs - verify).
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

### Load Checkpoint

In [None]:
#Load Checkpoint Model

#loading checkpoint model (from folder 'checkpoint_models'): flag_statedict = 0
#loading state-dict model (from folder 'models'): flag_statedict = 1

flag_statedict_ir = 0
flag_statedict_rgb = 0


#IR Model:
if flag_statedict_ir:
  #For manually saved state-dict models:
  model_path_ir = model_dir + 'ir/model_ir.pth'
  model_ir = IR_Model_pretrained()    #IR_Model() or IR_Model_pretrained()
  # map_location lets a state dict saved from the GPU load on a CPU-only runtime.
  model_ir.load_state_dict(torch.load(model_path_ir, map_location=device))
  model_ir = model_ir.to(device)

else:
  #For pytorch lightning checkpoint models: 
  model_path_ir = model_dir + 'ir/ir_pretrained.ckpt' 
  model_ir = IR_Model_pretrained.load_from_checkpoint(checkpoint_path=model_path_ir).to(device)     #use this line if the checkpoint model is pretrained
  #model_ir = IR_Model.load_from_checkpoint(checkpoint_path=model_path_ir).to(device)               #use this line if the checkpoint model is not the pretrained one.

#RGB Model:
if flag_statedict_rgb:
  #For manually saved state-dict models:
  model_path_rgb = model_dir + 'rgb/model_rgb.pth'
  model_rgb = RGB_Model_pretrained()    #RGB_Model() or RGB_Model_pretrained()
  # map_location lets a state dict saved from the GPU load on a CPU-only runtime.
  model_rgb.load_state_dict(torch.load(model_path_rgb, map_location=device))
  model_rgb = model_rgb.to(device)

else:
  #For pytorch lightning checkpoint models:
  model_path_rgb = model_dir + 'rgb/rgb_pretrained.ckpt' 
  model_rgb = RGB_Model_pretrained.load_from_checkpoint(checkpoint_path=model_path_rgb).to(device)  #use this line if the checkpoint model is pretrained
  #model_rgb = RGB_Model.load_from_checkpoint(checkpoint_path=model_path_rgb).to(device)            #use this line if the checkpoint model is not the pretrained one.



### Test Models

In [None]:
#Helpers shared by the test functions:
def _predict_proba(model, sample):
    """Return the model's probabilities for ``sample`` as a flat list of floats.

    The model is moved to ``device`` and run in eval mode without gradient
    tracking, so batch-norm uses its running statistics and is not updated by
    the test data. The previous train/eval mode is restored afterwards.
    """
    model = model.to(device)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            probs = model(sample)
    finally:
        model.train(was_training)
    return probs.detach().reshape(-1).float().cpu().tolist()


def _report_results(pred, label, acc, elapsed):
    """Print prediction, ground truth, accuracy and time; plot the confusion matrix."""
    print('Prediction  : ' + str(pred))
    print('Ground Truth: ' + str(label))
    print('Accuracy    : ' + str(acc))
    print('Time elapsed: ' + str(elapsed))
    print(' ')

    # Calculate Confusion Matrix (normalised over the ground-truth rows)
    target = torch.tensor(label)
    preds = torch.tensor(pred)
    confmat = pl.metrics.ConfusionMatrix(num_classes=2, normalize='true')
    confusion_matrix = confmat(preds, target)

    # Plot Confusion Matrix
    df_cm = pd.DataFrame((np.array(confusion_matrix)), range(2), range(2))
    plt.figure(figsize=(10,7))
    sn.set(font_scale=1.4) 
    sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}) # font size
    print('Confusion Matrix:')
    plt.xlabel('Prediction')
    plt.ylabel('Ground Truth')
    plt.show()


#Test function:
def test_model(model, sample, label):
    """Evaluate one model on one batch, print the results, plot the confusion matrix.

    Args:
        model: Model returning one probability per sample.
        sample: Batch of input images.
        label: Tensor of integer ground-truth labels (0 or 1).

    Returns:
        List of predicted labels (1 where the probability exceeds 0.5).
    """
    label = [int(i) for i in label.detach().cpu().numpy()]     # like this, it is easier to compare ground truth and prediction when printing both

    ## Start time (time.clock() no longer exists since Python 3.8):
    t0 = time.perf_counter()

    probs = _predict_proba(model, sample)
    pred = [int(p > 0.5) for p in probs]
    acc = np.sum(np.equal(pred, label))/len(pred)

    ## End time:
    t  = time.perf_counter() - t0

    _report_results(pred, label, acc, t)

    return pred

#Test function for combined prediction:
def test_combined(model_ir, model_rgb, sample_ir, sample_rgb, label_ir, label_rgb, threshold=0.5):
    """Evaluate the combined IR + RGB prediction on paired batches.

    IR and RGB samples must be in the same order. A sample counts as sound (1)
    only if both ground-truth labels are 1.

    Args:
        model_ir, model_rgb: Models returning one probability per sample.
        sample_ir, sample_rgb: Paired batches of IR and RGB images.
        label_ir, label_rgb: Tensors of integer ground-truth labels (0 or 1).
        threshold: Used only when the two models disagree: the sample is then
            classified 0 if either probability is below ``1 - threshold``,
            otherwise 1.

    Returns:
        List of combined predicted labels.

    Raises:
        ValueError: If the IR and RGB batches differ in length.
    """
    label_ir = [int(i) for i in label_ir.detach().cpu().numpy()]
    label_rgb = [int(i) for i in label_rgb.detach().cpu().numpy()]
    if len(label_ir) != len(label_rgb):
        raise ValueError('IR and RGB labels differ in length: {} vs {}'.format(len(label_ir), len(label_rgb)))
    # Element-wise AND; `label_rgb and label_ir` on two lists would simply
    # return label_ir.
    label_combined = [int(rgb == 1 and ir == 1) for rgb, ir in zip(label_rgb, label_ir)]

    ## Start time (time.clock() no longer exists since Python 3.8):
    t0 = time.perf_counter()

    #IR and RGB probabilities:
    probs_ir = _predict_proba(model_ir, sample_ir)
    probs_rgb = _predict_proba(model_rgb, sample_rgb)
    if len(probs_ir) != len(probs_rgb):
        raise ValueError('IR and RGB batches differ in length: {} vs {}'.format(len(probs_ir), len(probs_rgb)))

    ## Combined prediction (apple is sound --> '1'):
    pred_combined = []
    for p_ir, p_rgb in zip(probs_ir, probs_rgb):
        ir_sound = p_ir > 0.5
        rgb_sound = p_rgb > 0.5
        if ir_sound and rgb_sound:
            pred_combined.append(1)
        elif not ir_sound and not rgb_sound:
            pred_combined.append(0)
        elif p_ir < 1 - threshold or p_rgb < 1 - threshold:
            # The models disagree and one of them is confident about a defect.
            pred_combined.append(0)
        else:
            pred_combined.append(1)

    acc_combined = np.sum(np.equal(pred_combined, label_combined))/len(pred_combined)

    ## End time:
    t  = time.perf_counter() - t0

    _report_results(pred_combined, label_combined, acc_combined, t)

    return pred_combined

In [None]:
## Get test samples:
# Each batch is a 4-tuple; only the normalised samples and labels are needed.

#infrared sample:
_test_batch_ir = next(iter(test_dataloader_ir))
sample_ir, label_ir = _test_batch_ir[0], _test_batch_ir[1]

#rgb samples:
_test_batch_rgb = next(iter(test_dataloader_rgb))
sample_rgb, label_rgb = _test_batch_rgb[0], _test_batch_rgb[1]

In [None]:
## Test:
# Runs the IR model, the RGB model and their combination on the test batches
# and keeps each list of predictions for the evaluation below.

_separator = '#########################################################################'

#Test IR Model:
print('TEST: Model IR')
pred_ir = test_model(model_ir, sample_ir, label_ir)
print(_separator)

#Test RGB Model:
print('TEST: Model RGB')
pred_rgb = test_model(model_rgb, sample_rgb, label_rgb)
print(_separator)

#Test Combined:
print('TEST: Combined IR + RGB')
pred_combined = test_combined(model_ir, model_rgb, sample_ir, sample_rgb, label_ir, label_rgb)


### Evaluate Models on whole apple

In [None]:
# Evaluation model for a single apple due to combining IR & RGB metrics

pred_eval = pred_combined   #choose pred_ir, pred_rgb, or pred_combined

# An apple is "bad" once this many of its images are predicted defective:
defect_threshold = 5

# The first 28 predictions are treated as Apple 2, the remaining ones as Apple 30.
for headline, apple_preds in (
    ("Apple 2 was classified as ", pred_eval[:28]),
    ("Apple 30 was classified as ", pred_eval[28:]),
):
  # Count number of predicted defects (a prediction of 1 means sound):
  number_defects = len(apple_preds) - np.sum(apple_preds)

  apple_class = "sound" if (number_defects < defect_threshold) else "bad"

  print(headline, apple_class, " because ", number_defects, " defects were detected!")

Apple 2 was classified as  bad  because  15  defects were detected!
Apple 30 was classified as  bad  because  12  defects were detected!


### Save Model (State Dict)

In [None]:
## Manually save model (state-dict):
# Writes only the weights of each model into its folder under model_dir.

#Save IR Model:
_state_path_ir = os.path.join(model_dir, 'ir/', 'model_ir_new.pth')
torch.save(model_ir.state_dict(), _state_path_ir)

#Save RGB Model:
_state_path_rgb = os.path.join(model_dir, 'rgb/', 'model_rgb_new.pth')
torch.save(model_rgb.state_dict(), _state_path_rgb)
