In [None]:
import decimal
from time import perf_counter
import torch
import numpy as np
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F
from torchvision.models import alexnet, AlexNet_Weights, resnet101, ResNet101_Weights, densenet121, DenseNet121_Weights

In [None]:
!rm -rf OID/Dataset/train/
!git clone https://github.com/EscVM/OIDv4_ToolKit.git
!pip install -r ./OIDv4_ToolKit/requirements.txt
!python ./OIDv4_ToolKit/main.py downloader --classes Fruit --type_csv train --limit 50 --multiclasses 1 -y

fatal: destination path 'OIDv4_ToolKit' already exists and is not an empty directory.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
[92m
		   ___   _____  ______            _    _    
		 .'   `.|_   _||_   _ `.         | |  | |   
		/  .-.  \ | |    | | `. \ _   __ | |__| |_  
		| |   | | | |    | |  | |[ \ [  ]|____   _| 
		\  `-'  /_| |_  _| |_.' / \ \/ /     _| |_  
		 `.___.'|_____||______.'   \__/     |_____|
	[0m
[92m
             _____                    _                 _             
            (____ \                  | |               | |            
             _   \ \ ___  _ _ _ ____ | | ___   ____  _ | | ____  ____ 
            | |   | / _ \| | | |  _ \| |/ _ \ / _  |/ || |/ _  )/ ___)
            | |__/ / |_| | | | | | | | | |_| ( ( | ( (_| ( (/ /| |    
            |_____/ \___/ \____|_| |_|_|\___/ \_||_|\____|\____)_|    
                                                          
        [0m
    [INFO] | Downl

In [None]:
alexnet = alexnet(AlexNet_Weights.DEFAULT).eval()
resnet101 = resnet101(ResNet101_Weights.DEFAULT).eval()
densenet121 = densenet121(DenseNet121_Weights.DEFAULT).eval()

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


  0%|          | 0.00/30.8M [00:00<?, ?B/s]

In [None]:
from os import listdir
from os.path import isfile, join

MYPATH = 'OID/Dataset/train/'
lables = [f for f in listdir(MYPATH) if not isfile(join(MYPATH, f))]

data_images = []
for l in lables:
  for d in listdir(join(MYPATH, l)):
    if join(MYPATH, l, d).endswith('.jpg'):
      data_images.append(join(MYPATH, l, d))

print(data_images)

['OID/Dataset/train/Fruit/659cd10bca70b317.jpg', 'OID/Dataset/train/Fruit/9edeea515cf89b2a.jpg', 'OID/Dataset/train/Fruit/31fffb3611ffa66b.jpg', 'OID/Dataset/train/Fruit/cff2073b6a9dfd40.jpg', 'OID/Dataset/train/Fruit/3d81c7d8acab67e2.jpg', 'OID/Dataset/train/Fruit/a5b53097a76d4fcf.jpg', 'OID/Dataset/train/Fruit/a7f4b96f07487cb3.jpg', 'OID/Dataset/train/Fruit/baceff9660f9973d.jpg', 'OID/Dataset/train/Fruit/026cbd2b207902fa.jpg', 'OID/Dataset/train/Fruit/eb25a2dd759fca6b.jpg', 'OID/Dataset/train/Fruit/2ca536d7496edf51.jpg', 'OID/Dataset/train/Fruit/bc293aa5a210e5a9.jpg', 'OID/Dataset/train/Fruit/fe03f065bc1f0719.jpg', 'OID/Dataset/train/Fruit/e43d4689076623ec.jpg', 'OID/Dataset/train/Fruit/8c2c839ca590477b.jpg', 'OID/Dataset/train/Fruit/ad9e88ea8944c05b.jpg', 'OID/Dataset/train/Fruit/1fffb1e23bb0f2fc.jpg', 'OID/Dataset/train/Fruit/a1e258cc4cf0016d.jpg', 'OID/Dataset/train/Fruit/3c19f3b03239dd45.jpg', 'OID/Dataset/train/Fruit/dedef72f7c731f30.jpg', 'OID/Dataset/train/Fruit/e27c0930a26f4a

In [None]:
def predict(weights, model, img):
    preprocess = weights.DEFAULT.transforms()
    batch = preprocess(img).unsqueeze(0)
    pred = model(batch).squeeze(0).softmax(0)
    val, index = torch.topk(pred, 5)
    for i in range(5):
        category_name = weights.DEFAULT.meta["categories"][index[i]]
        print(f"{category_name}: {100 * val[i]:.1f}%")

In [None]:
from torchvision.io import read_image

for data in data_images[:10]:
        print()
        img = read_image(data)
        print(data)
        print("AlexNet prediction")
        start = perf_counter()
        predict(AlexNet_Weights, alexnet, img)
        end = perf_counter()
        print("AlexNet frame time: " + str(end - start))
        print()
        print("ResNet101 prediction")
        start = perf_counter()
        predict(ResNet101_Weights, resnet101, img)
        end = perf_counter()
        print("ResNet101 frame time: " + str(end - start))
        print()
        print("DenseNet121 prediction")
        start = perf_counter()
        predict(DenseNet121_Weights, densenet121, img)
        end = perf_counter()
        print("DenseNet121 frame time: " + str(end - start))
        print()      


OID/Dataset/train/Fruit/659cd10bca70b317.jpg
AlexNet prediction
banjo: 16.3%
cello: 15.8%
acoustic guitar: 10.6%
wooden spoon: 9.5%
ladle: 8.6%
AlexNet frame time: 0.08038382699942304

ResNet101 prediction
jack-o'-lantern: 82.6%
power drill: 0.2%
swab: 0.2%
vacuum: 0.1%
EntleBucher: 0.1%
ResNet101 frame time: 0.4043337719995179

DenseNet121 prediction
jack-o'-lantern: 93.2%
soccer ball: 1.8%
ocarina: 0.8%
croquet ball: 0.8%
pomegranate: 0.5%
DenseNet121 frame time: 0.2023897830003989


OID/Dataset/train/Fruit/9edeea515cf89b2a.jpg
AlexNet prediction
strawberry: 31.7%
hip: 14.2%
pomegranate: 13.6%
thimble: 11.2%
agaric: 6.9%
AlexNet frame time: 0.06484755000019504

ResNet101 prediction
strawberry: 21.0%
pomegranate: 13.8%
tray: 4.9%
ice cream: 3.9%
saltshaker: 2.2%
ResNet101 frame time: 0.44994420500006527

DenseNet121 prediction
strawberry: 84.1%
hip: 2.7%
trifle: 1.7%
matchstick: 1.6%
fig: 1.5%
DenseNet121 frame time: 0.2766754680005761


OID/Dataset/train/Fruit/31fffb3611ffa66b.jpg
A