In [1]:
!pip install --quiet img2vec_pytorch
print('pip installed img2vec')

from warnings import filterwarnings
filterwarnings(action='ignore', category=FutureWarning) # quiet a plotly issue
filterwarnings(action='ignore', category=UserWarning) # quiet an img2vec issue

pip installed img2vec


In [2]:
from img2vec_pytorch import Img2Vec
from PIL import Image
from arrow import now
from glob import glob
import pandas as pd
from os.path import basename

# Length of each image embedding: passed to Img2Vec as layer_output_size and
# used to reshape every encoded vector to a flat array of this many values.
SIZE = 512

# https://stackoverflow.com/a/952952
def flatten(arg):
    """Join the items of every inner iterable of ``arg`` into a single list.

    Only one level of nesting is removed; the order of the items is kept.
    """
    flat = []
    for inner in arg:
        flat.extend(inner)
    return flat

# we're going to just read a few pictures while we're building
def get_from_glob(arg: str, tag: str, stop: int) -> list:
    """Encode the first ``stop`` files matching a glob pattern as embedding rows.

    Each file is opened with PIL and passed to the module-level ``img2vec``
    encoder; the resulting vector is reshaped to ``(SIZE,)``.

    Args:
        arg: glob pattern selecting the image files to read.
        tag: label stored with every row produced from this pattern.
        stop: maximum number of matching files to attempt. Files that fail
            still count against this limit, so fewer rows may be returned.

    Returns:
        A list of ``pd.Series`` with index ``['tag', 'name', 'value']``, where
        ``name`` is the file's base name and ``value`` is its embedding vector.

    A ``RuntimeError`` raised while opening or encoding a file is reported on
    stdout and that file is skipped. Progress and timing are printed as well.
    """
    time_get = now()
    result = []
    for index, input_file in enumerate(glob(pathname=arg)):
        if index >= stop:
            # every later index is also past the limit, so stop scanning
            break
        name = basename(input_file)
        try:
            with Image.open(fp=input_file, mode='r') as image:
                vector = img2vec.get_vec(image, tensor=True).numpy().reshape(SIZE,)
                result.append(pd.Series(data=[tag, name, vector], index=['tag', 'name', 'value']))
        except RuntimeError:
            # we only have a few failures so we're just going to discard them;
            # report both the class and the file so the bad image can be found
            print('runtime failure: {} {}'.format(tag, name))
    print('encoded {} data {} rows in {}'.format(tag, len(result), now() - time_get))
    return result

# Upper bound on the number of image files read from each class folder.
STOP = 500

# Encoder shared by every get_from_glob call.
img2vec = Img2Vec(layer_output_size=SIZE, layer='default', model='resnet-18', cuda=False)

time_start = now()
# Map each class folder's name to a glob pattern matching the JPEGs inside it.
train = dict(
    (basename(folder), folder + '/*.jpg')
    for folder in glob('/kaggle/input/fruits-dataset-for-classification/*')
)
# One list of rows per class folder, flattened into a single frame below.
train_data = [get_from_glob(arg=pattern, tag=label, stop=STOP) for label, pattern in train.items()]
df = pd.DataFrame(data=flatten(arg=train_data))
# The first two underscore-separated pieces of the tag become their own columns.
df['condition'] = df['tag'].map(lambda tag: tag.split('_')[0])
df['fruit'] = df['tag'].map(lambda tag: tag.split('_')[1])
print('done in {}'.format(now() - time_start))

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 81.8MB/s]


encoded rotten_peaches_done data 343 rows in 0:00:20.535942
encoded fresh_strawberries_done data 250 rows in 0:00:15.647590
encoded fresh_peaches_done data 250 rows in 0:00:14.998522
encoded rotten_pomegranates_done data 250 rows in 0:00:15.754468
encoded fresh_pomegranates_done data 311 rows in 0:00:19.658111
encoded rotten_strawberries_done data 251 rows in 0:00:15.590116
done in 0:01:42.357010


In [3]:
from arrow import now
from umap import UMAP

time_start = now()
# random_state is fixed so repeated runs use the same seed
umap = UMAP(n_epochs=500, low_memory=False, n_jobs=1, verbose=True, random_state=2024)
# Expand each embedding vector into one column per component before fitting.
coordinates = umap.fit_transform(X=df['value'].apply(pd.Series))
# Attach the projected coordinates to the original rows as columns x and y.
plot_df = pd.concat(objs=[df, pd.DataFrame(data=coordinates, columns=['x', 'y'])], axis=1)
print('done with UMAP in {}'.format(now() - time_start))

2024-02-23 19:04:36.188316: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-23 19:04:36.188463: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-23 19:04:36.360371: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


UMAP(low_memory=False, n_epochs=500, n_jobs=1, random_state=2024, verbose=True)
Fri Feb 23 19:04:50 2024 Construct fuzzy simplicial set
Fri Feb 23 19:04:54 2024 Finding Nearest Neighbors
Fri Feb 23 19:04:59 2024 Finished Nearest Neighbor Search
Fri Feb 23 19:05:02 2024 Construct embedding


Epochs completed:   0%|            0/500 [00:00]

	completed  0  /  500 epochs
	completed  50  /  500 epochs
	completed  100  /  500 epochs
	completed  150  /  500 epochs
	completed  200  /  500 epochs
	completed  250  /  500 epochs
	completed  300  /  500 epochs
	completed  350  /  500 epochs
	completed  400  /  500 epochs
	completed  450  /  500 epochs
Fri Feb 23 19:05:08 2024 Finished embedding
done with UMAP in 0:00:17.877170


In [4]:
from plotly.express import scatter
# The same projection drawn three times: coloured by condition, by fruit, and
# by the full tag (the last one is drawn taller).
for color_column, plot_height in (('condition', 450), ('fruit', 450), ('tag', 900)):
    scatter(data_frame=plot_df, x='x', y='y', color=color_column, hover_name='name', height=plot_height).show()

This is encouraging; we should have no trouble identifying strawberries, at least, but we're going to have some difficult cases. Let's build a simple model.

In [5]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# One column per embedding component; the tag is the class label.
features = df['value'].apply(pd.Series)
X_train, X_test, y_train, y_test = train_test_split(features, df['tag'], random_state=2024, test_size=0.25)
regression = LogisticRegression(max_iter=100000)
regression.fit(X_train, y_train)

# Accuracy on the held-out quarter of the data, reported as a percentage.
test_accuracy = accuracy_score(y_test, regression.predict(X_test))
print('accuracy: {:5.2f} pct'.format(100 * test_accuracy))

accuracy: 91.79 pct


In [6]:
from sklearn.metrics import classification_report
# Per-class precision/recall/F1 for the held-out split.
y_pred = regression.predict(X_test)
print(classification_report(y_true=y_test, y_pred=y_pred))

                          precision    recall  f1-score   support

      fresh_peaches_done       0.92      0.96      0.94        49
 fresh_pomegranates_done       0.97      0.98      0.97        85
 fresh_strawberries_done       0.96      0.93      0.94        72
     rotten_peaches_done       0.88      0.87      0.87        83
rotten_pomegranates_done       0.87      0.83      0.85        66
rotten_strawberries_done       0.90      0.95      0.93        59

                accuracy                           0.92       414
               macro avg       0.92      0.92      0.92       414
            weighted avg       0.92      0.92      0.92       414



Our results do not improve with KNN or random forests (not shown); for some reason our random forest classifier has a terrible time finding rotten pomegranates.