In [None]:
#Install the oidv6 command-line tool, which a later cell uses to download the images
!pip install oidv6

In [None]:
#Upgrade oidv6 to the newest available release in case an older version was already installed
!pip install --upgrade oidv6

In [None]:
#Mount Google Drive at /content/drive so that the dataset folder used below is reachable from this runtime
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#Download training images of the four listed classes into drive/MyDrive/pics (--limit 344 caps the download, presumably per class - confirm with the oidv6 docs)
#NOTE(review): the target path is relative, while image_dir in a later cell is /content/drive/MyDrive/pics/train - this assumes the working directory is /content
!oidv6 downloader en --dataset drive/MyDrive/pics --type_data train --classes Banana Pizza Strawberry Pomegranate --limit 344

In [None]:
#Downloading the file with the class names list (saved as imagenet_classes.txt in the working directory);
#a later cell reads it and uses each line's position as that class's index
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

In [None]:
image_dir = "/content/drive/MyDrive/pics/train"
#Defining classes and their thresholds, which will be used for model evaluation
classes = ["banana", "pizza", "strawberry"]
#Derived from `classes` so that the count can never drift away from the list itself
number_of_classes = len(classes)
#other_classes value will be later set on images which do not belong to the positive classes
other_classes = "other"
#One confidence threshold per entry of `classes`, in the same order
thresholds = [0.6, 0.6, 0.6]
if len(thresholds) != len(classes):
    raise ValueError("thresholds must contain exactly one value per entry in classes")

#imagenet_classes.txt holds one class name per line; the 0-based line position is used as the class index
with open('imagenet_classes.txt', 'r') as f:
    class_names = [ln.strip() for ln in f]

#Lower-cased positive class name -> threshold, built once instead of on every loop iteration
#(setdefault keeps the first threshold if a class name were listed twice)
threshold_by_class = {}
for positive_class, threshold in zip(classes, thresholds):
    threshold_by_class.setdefault(positive_class.lower(), threshold)

#A dictionary is made by finding the index of every predefined class name in the class file
#and pairing it with its predefined threshold
class_mapping = {}
for i, class_name in enumerate(class_names):
    key = class_name.lower()
    if key in threshold_by_class:
        class_mapping[key] = {'index': i, 'threshold': threshold_by_class[key]}

#A class that is absent from the file would otherwise be silently treated as 'other' later on
missing_classes = [c for c in classes if c.lower() not in class_mapping]
if missing_classes:
    raise ValueError("classes not found in imagenet_classes.txt: {}".format(missing_classes))

#In case an image later provided to ImageDataset class doesn't correspond to any of predefined classes,
#it gets assigned with 'other' (index -1 cannot collide with a real class index)
class_mapping[other_classes] = {'index': -1, 'threshold': 0.5}

print(class_mapping)


{'strawberry': {'index': 949, 'threshold': 0.6}, 'banana': {'index': 954, 'threshold': 0.6}, 'pizza': {'index': 963, 'threshold': 0.6}, 'other': {'index': -1, 'threshold': 0.5}}


In [None]:
#@title Dataset
import numpy as np
import re
import glob
import torch
import torchvision
from PIL import Image
import PIL
from torchvision import transforms
import os

#Per-channel mean/std normalisation applied as the last preprocessing step
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

#Preprocessing steps, applied to every image in the listed order
#Disabled augmentations kept for reference:
#   torchvision.transforms.ColorJitter(hue=.05, saturation=.05)
#   torchvision.transforms.RandomRotation(20, resample=PIL.Image.BILINEAR)
#NOTE(review): RandomHorizontalFlip is random, so repeated evaluation runs can see different inputs - confirm this is intended
preprocessing_steps = [
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
transform = transforms.Compose(preprocessing_steps)

class ImageDataset(torch.utils.data.Dataset):
    """Images from the downloaded class folders paired with an integer class index.

    Images are collected from ``<image_dir>/<folder>/*.jpg`` for the folders named
    after ``classes[0]``, ``classes[1]``, ``classes[2]`` and ``self.other_pictures[0]``.
    The label of an image is read from ``<image_dir>/<folder>/labels/<image name>.txt``:
    the first space-separated token of the file's first line is taken as the class
    name and translated to an index through ``class_mapping``; names that are not in
    ``class_mapping`` receive the index of ``other_classes``.

    Args:
        image_dir: Root directory that contains one sub-folder per downloaded class.
        transforms: Callable applied to every PIL image in ``__getitem__``. When
            omitted, images are only converted to tensors.
    """

    def __init__(self, image_dir, transforms=None):
        self.image_dir = image_dir
        self.labels = []
        self.class_names = []
        #Default to plain tensor conversion so that samples can be batched even without a custom pipeline
        self.transforms = transforms if transforms is not None else torchvision.transforms.ToTensor()
        #other_pictures contains name of folder with images which do not belong to positive 3 classes
        self.other_pictures = ["pomegranate"]

        #Retrieving image files and their label paths, folder doesn't necessarily have to contain images of its name
        self.class1_files, self.class1_label_paths = self._collect_folder(classes[0])
        self.class2_files, self.class2_label_paths = self._collect_folder(classes[1])
        self.class3_files, self.class3_label_paths = self._collect_folder(classes[2])
        self.other_files, self.other_label_paths = self._collect_folder(self.other_pictures[0])

        self.class1 = len(self.class1_files)
        self.class2 = len(self.class2_files)
        self.class3 = len(self.class3_files)
        self.other = len(self.other_files)

        self.files = self.class1_files + self.class2_files + self.class3_files + self.other_files
        self.label_paths = (self.class1_label_paths + self.class2_label_paths
                            + self.class3_label_paths + self.other_label_paths)

        #Each image has a corresponding label text file, therefore class names are extracted into a list
        for label_file in self.label_paths:
            with open(label_file, 'r') as f:
                label_name = f.readline().strip().split(' ')[0]
            self.class_names.append(label_name)

        #In order to keep labels as integers (indexes, corresponding in class file),
        #class names are mapped to their index values in the dictionary
        other_entry = class_mapping[other_classes]
        self.labels = [class_mapping.get(name.lower(), other_entry)['index']
                       for name in self.class_names]

        #Image and label order is mixed; every per-sample list is reordered with the same
        #permutation so that files, label paths, class names and labels stay aligned
        self.order = np.random.permutation(len(self.labels)).tolist()
        self.files = [self.files[x] for x in self.order]
        self.label_paths = [self.label_paths[x] for x in self.order]
        self.class_names = [self.class_names[x] for x in self.order]
        self.labels = [self.labels[x] for x in self.order]

    def _collect_folder(self, folder_name):
        """Return ``(image_files, label_paths)`` of one class folder, aligned element by element.

        The label path is derived from each image's own file name instead of from a
        second, independently ordered directory listing, so an image can never be
        paired with another image's label. Images without a label file are skipped.
        """
        import warnings

        folder = os.path.join(self.image_dir, folder_name.lower())
        image_files = []
        label_paths = []
        skipped = 0
        #sorted() makes the collection order independent of the file system's listing order
        for image_file in sorted(glob.glob(os.path.join(folder, "*.jpg"))):
            stem = os.path.splitext(os.path.basename(image_file))[0]
            label_path = os.path.join(folder, "labels", stem + ".txt")
            if os.path.isfile(label_path):
                image_files.append(image_file)
                label_paths.append(label_path)
            else:
                skipped += 1
        if skipped:
            warnings.warn("{} image(s) in {} have no label file and were skipped".format(skipped, folder))
        return image_files, label_paths

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(transformed image, class index)`` for the sample at ``index``."""
        image_path = self.files[index]

        #Force three channels so that greyscale or RGBA files do not break the transforms
        img = Image.open(image_path).convert("RGB")
        img = self.transforms(img)

        label = self.labels[index]

        return img, label
      
#Build the dataset from the downloaded folders, applying the preprocessing pipeline defined above to every image
dataset = ImageDataset(image_dir=image_dir, transforms=transform)


In [None]:
#Number of labelled samples found by ImageDataset
len(dataset)

1227

In [None]:
from torchvision import models, transforms

In [None]:
#Check whether a CUDA GPU is visible to PyTorch in this runtime
torch.cuda.is_available()

True

In [None]:
#Loading an image classification model and setting it to evaluation mode
#NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of the `weights=` argument - confirm against the installed version
model = models.efficientnet_b7(pretrained=True)
model.eval()

In [None]:
#@title Dataloader
batch_size = 100
#Shuffling brings nothing to evaluation; a fixed iteration order keeps runs comparable
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

#Run inference on the GPU when one is available; the inputs are moved to the same device below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

#Index assigned to every image that is not confidently recognised as one of the positive classes
other_index = class_mapping[other_classes]['index']
#Model output index -> confidence threshold, for the positive classes only
threshold_by_index = {
    entry['index']: entry['threshold']
    for name, entry in class_mapping.items()
    if name != other_classes
}

total_images = 0
#Keeping track of predicted and ground truth indexes to later calculate statistics
predicted_list = []
ground_truth_list = []

with torch.no_grad():
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        #Probabilities are needed to decide whether the top prediction meets its threshold
        probabilities = torch.softmax(outputs, dim=1)
        #Highest probability per image and the class index it belongs to
        confidences, predicted = torch.max(probabilities, dim=1)

        #Ground truth indexes come straight from the labels
        ground_truth_list.extend(labels.tolist())

        for predicted_index, probability in zip(predicted.tolist(), confidences.tolist()):
            #The decision uses the model output only. Consulting the ground truth here would
            #turn every wrong prediction into 'other' and make false positives impossible.
            threshold = threshold_by_index.get(predicted_index)
            if threshold is not None and probability >= threshold:
                final_class = predicted_index
            else: #Not a positive class, or probability isn't high enough - final is 'other'
                final_class = other_index
            predicted_list.append(final_class)

        total_images += labels.size(0)

print(total_images)


1227


In [None]:
#Sanity check: both lists should hold one entry per evaluated image
print(len(predicted_list))
print(len(ground_truth_list))

1227
1227


In [None]:
#@title Statistics of each class
#TP TN FP FN statistics are saved separately for each class (the positive classes plus 'other')
stat_keys = [c.lower() for c in classes] + [other_classes.lower()]
tp = dict.fromkeys(stat_keys, 0)
tn = dict.fromkeys(stat_keys, 0)
fp = dict.fromkeys(stat_keys, 0)
fn = dict.fromkeys(stat_keys, 0)

other_index = class_mapping[other_classes]['index']
index_list = [v['index'] for v in class_mapping.values() if v['index'] != -1]
print(index_list)

#Index -> class name, built once so the loop below doesn't scan class_mapping for every sample
name_by_index = {}
for name, entry in class_mapping.items():
    name_by_index.setdefault(entry['index'], name)

#Going through the predicted and ground truth list to obtain TP TN FP FN for each class and later calculate metrics
#Every sample increments exactly one counter, so the four dictionaries add up to the number of samples
for current_ground_truth, current_predicted in zip(ground_truth_list, predicted_list):
    gt_class_name = name_by_index[current_ground_truth]
    pr_class_name = name_by_index[current_predicted]

    if current_predicted == other_index:
        if current_ground_truth in index_list:
            #A positive image was rejected
            fn[gt_class_name] += 1
        else:
            #A non-positive image was correctly rejected
            tn[gt_class_name] += 1
    elif current_predicted == current_ground_truth:
        tp[gt_class_name] += 1
    else:
        #A positive class was predicted for an image that is not of that class. This also
        #covers 'other' images, which would otherwise not be counted anywhere.
        fp[pr_class_name] += 1
print(tp, tn, fp, fn)


[949, 954, 963]
{'banana': 197, 'pizza': 216, 'strawberry': 179, 'other': 0} {'banana': 0, 'pizza': 0, 'strawberry': 0, 'other': 225} {'banana': 0, 'pizza': 0, 'strawberry': 0, 'other': 0} {'banana': 127, 'pizza': 118, 'strawberry': 165, 'other': 0}


In [None]:
#@title Metrics of each class and overall metrics
#Calculating metrics for each class separately, excluding 'other' because it can have true negatives only

def calculate_metrics(tp, tn, fp, fn):
  """Compute accuracy, recall, precision and F1 from confusion counts.

  Args:
    tp: Number of true positives.
    tn: Number of true negatives.
    fp: Number of false positives.
    fn: Number of false negatives.

  Returns:
    Dict with the keys 'accuracy', 'recall', 'precision' and 'f1'. A metric whose
    denominator is zero (e.g. no samples, or no positive predictions) is reported
    as 0.0 instead of raising ZeroDivisionError.
  """

  def ratio(numerator, denominator):
    #Guard against empty denominators so that a class without samples doesn't abort the report
    return numerator / denominator if denominator else 0.0

  metrics = {}
  metrics['accuracy'] = ratio(tp + tn, tp + fp + tn + fn)
  metrics['recall'] = ratio(tp, tp + fn)
  metrics['precision'] = ratio(tp, tp + fp)
  metrics['f1'] = ratio(2 * (metrics['precision'] * metrics['recall']),
                        metrics['precision'] + metrics['recall'])

  return metrics

#Report the metrics of every positive class, one line per class
for i in range(number_of_classes):
  class_key = classes[i].lower()
  metrics = calculate_metrics(tp[class_key], tn[class_key], fp[class_key], fn[class_key])
  print(class_key, ': ', metrics)

#Number of categories tracked in the statistics dictionaries
metrics_len = len(tp)

#Adding up separate class statistics
def add_up_statistics(t_p, t_n, f_p, f_n):
  """Merge the per-class counters into a single set of totals.

  Args:
    t_p: Dict of true positive counts keyed by lower-cased class name.
    t_n: Dict of true negative counts keyed by lower-cased class name.
    f_p: Dict of false positive counts keyed by lower-cased class name.
    f_n: Dict of false negative counts keyed by lower-cased class name.

  Returns:
    Dict with the keys 'tp', 'tn', 'fp' and 'fn'. 'tp', 'fp' and 'fn' are summed over
    the first `number_of_classes` entries of `classes`; 'tn' is taken from the
    'other' entry alone.
  """
  positive_keys = [classes[i].lower() for i in range(0, number_of_classes)]

  def total(per_class):
    #Sum one counter over the positive classes
    return sum(per_class[key] for key in positive_keys)

  return {
    'tp': total(t_p),
    'tn': t_n[other_classes.lower()],
    'fp': total(f_p),
    'fn': total(f_n),
  }

#Totals over all positive classes, with true negatives taken from the 'other' entry
conjoined_statistics = add_up_statistics(tp, tn, fp, fn)
print(conjoined_statistics)

#Calculating metrics for all classes, including 'other'
def calculate_overall_metrics(conjoined_statistics):
  """Compute accuracy, recall, precision and F1 from summed confusion counts.

  Args:
    conjoined_statistics: Dict with the integer counts 'tp', 'tn', 'fp' and 'fn'.

  Returns:
    Dict with the keys 'accuracy', 'recall', 'precision' and 'f1'. A metric whose
    denominator is zero is reported as 0.0 instead of raising ZeroDivisionError.

  Raises:
    KeyError: If one of the four counts is missing from `conjoined_statistics`.
  """

  tp = conjoined_statistics['tp']
  tn = conjoined_statistics['tn']
  fp = conjoined_statistics['fp']
  fn = conjoined_statistics['fn']

  def ratio(numerator, denominator):
    #Guard against empty denominators, e.g. when nothing was evaluated or nothing was predicted positive
    return numerator / denominator if denominator else 0.0

  metrics = {}
  metrics['accuracy'] = ratio(tp + tn, tp + fp + tn + fn)
  metrics['recall'] = ratio(tp, tp + fn)
  metrics['precision'] = ratio(tp, tp + fp)
  metrics['f1'] = ratio(2 * (metrics['precision'] * metrics['recall']),
                        metrics['precision'] + metrics['recall'])

  return metrics

#Metrics computed from the summed counts rather than per class
overall_metrics = calculate_overall_metrics(conjoined_statistics)
print(overall_metrics)

banana :  {'accuracy': 0.6080246913580247, 'recall': 0.6080246913580247, 'precision': 1.0, 'f1': 0.7562380038387716}
pizza :  {'accuracy': 0.6467065868263473, 'recall': 0.6467065868263473, 'precision': 1.0, 'f1': 0.7854545454545454}
strawberry :  {'accuracy': 0.5203488372093024, 'recall': 0.5203488372093024, 'precision': 1.0, 'f1': 0.6845124282982792}
{'tp': 592, 'tn': 225, 'fp': 0, 'fn': 410}
{'accuracy': 0.6658516707416463, 'recall': 0.590818363273453, 'precision': 1.0, 'f1': 0.7427854454203261}
