In [1]:
!pip install --quiet img2vec_pytorch
print('pip installed img2vec')

pip installed img2vec


In [2]:
from warnings import filterwarnings
filterwarnings(action='ignore', category=FutureWarning) # quiet a plotly issue
filterwarnings(action='ignore', category=UserWarning) # quiet an img2vec issue

Let's load our image data and encode it into vectors using ResNet-18. 

In [3]:
from img2vec_pytorch import Img2Vec
from PIL import Image
from arrow import now
from glob import glob
import pandas as pd
from os.path import basename


# one-level flatten; the comprehension idea comes from https://stackoverflow.com/a/952952
def flatten(arg):
    """Concatenate the items of every inner iterable in ``arg`` into one list.

    Only a single level of nesting is removed, e.g. ``flatten([[1, 2], [3]]) == [1, 2, 3]``.
    """
    flat = []
    for inner in arg:
        flat.extend(inner)
    return flat

# Encode at most `stop` images per class folder; a large enough `stop` reads every image.
def get_from_glob(arg: str, tag: str, stop: int, encoder=None) -> list:
    """Encode the images matching a glob pattern into labelled feature-vector rows.

    Parameters
    ----------
    arg : str
        Glob pattern selecting the image files (e.g. ``'<folder>/*.jpg'``).
    tag : str
        Class label written to the ``tag`` field of every row.
    stop : int
        Maximum number of files to read; files beyond this count are never opened.
    encoder : optional
        Object exposing ``get_vec(image, tensor=True)``. Defaults to the notebook's
        global ``img2vec`` (looked up at call time), so existing calls keep working.

    Returns
    -------
    list of pd.Series
        One Series per successfully encoded image, indexed by ``['tag', 'name', 'value']``;
        ``value`` is the encoder output flattened to a 1-D numpy array.

    Files whose encoding raises ``RuntimeError`` are reported on stdout and skipped.
    """
    if encoder is None:
        encoder = img2vec
    time_get = now()
    result = []
    # glob's order depends on the filesystem; sort so the first `stop` files and the row order are reproducible
    for index, input_file in enumerate(sorted(glob(pathname=arg))):
        if index >= stop:
            break  # nothing past this point would be read, so don't keep iterating
        name = basename(input_file)
        try:
            with Image.open(fp=input_file, mode='r') as image:
                # flatten whatever shape the encoder returns (512 values for resnet-18's default layer)
                vector = encoder.get_vec(image, tensor=True).numpy().reshape(-1)
        except RuntimeError:
            # we only have a few failures, so report which file failed and discard it
            print('runtime failure: {} {}'.format(tag, name))
            continue
        result.append(pd.Series(data=[tag, name, vector], index=['tag', 'name', 'value']))
    print('encoded {} data {} rows in {}'.format(tag, len(result), now() - time_get))
    return result

# Per-class read limit: 2500 covers every training image, though it may make the big scatter plot below sluggish.
STOP = 2500

img2vec = Img2Vec(cuda=False, model='resnet-18', layer='default', layer_output_size=512)

time_start = now()
# Map each class folder to its image pattern. The tag is the folder name after its first '_',
# with any remaining underscores turned into spaces.
train = {basename(folder).partition('_')[2].replace('_', ' '): folder + '/*.jpg'
         for folder in glob('/kaggle/input/aruzz22-5k-an-image-dataset-of-rice-varieties/1_TRAIN/*')}
train_data = [get_from_glob(arg=pattern, tag=label, stop=STOP) for label, pattern in train.items()]
df = pd.DataFrame(data=flatten(arg=train_data))

print('done in {}'.format(now() - time_start))

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 121MB/s]


encoded Bashmoti data 1125 rows in 0:01:12.745137
encoded Lal Binni data 1125 rows in 0:01:15.953323
encoded BR29 data 1125 rows in 0:01:10.257083
encoded Katarivog data 1125 rows in 0:01:09.896895
encoded Jirashail data 1125 rows in 0:01:12.810615
encoded Lal Aush data 1125 rows in 0:01:10.125005
encoded Shampakatari data 1125 rows in 0:01:09.915169
encoded Amon data 1125 rows in 0:01:09.659798
encoded Shorna5 data 1125 rows in 0:01:10.235863
encoded Subol Lota data 1125 rows in 0:01:10.478299
encoded Katari Polao data 1125 rows in 0:01:10.976463
encoded Najirshail data 1125 rows in 0:01:10.578382
encoded BR28 data 1125 rows in 0:01:10.145248
encoded Lal Biroi data 1125 rows in 0:01:12.244060
encoded Red Cargo data 1125 rows in 0:01:11.492684
encoded Paijam data 1125 rows in 0:01:10.540801
encoded Chinigura Polao data 1125 rows in 0:01:09.938640
encoded Gutisharna data 1125 rows in 0:01:10.329900
encoded Bashful data 1125 rows in 0:01:11.348497
encoded Ganjiya data 1125 rows in 0:01:1

In [4]:
from plotly.express import histogram
histogram(data_frame=df, x='tag')

Our classes are balanced by design, so, although we have a lot of classes, we don't have to worry about any of them being especially poorly represented in the training data.

Now let's visualize our image vectors with UMAP; this should tell us whether our image vectors capture features that will let us train a model effectively. The more clustered our UMAP scatter plot looks, the more optimistic we should be; the more random it looks, the more likely it is that we will need to find another way to encode our image data.

In [5]:
from arrow import now
from plotly.express import scatter
from umap import UMAP

time_start = now()
umap = UMAP(random_state=2024, verbose=True, n_jobs=1, low_memory=False, n_epochs=500)
# one row per image, one column per vector component, built in a single pass instead of a Series per row
features = pd.DataFrame(data=df['value'].tolist(), index=df.index)
embedding = umap.fit_transform(X=features)
# fit_transform returns a positional array; give it df's index so concat pairs each point with its own image
plot_df = pd.concat(objs=[df, pd.DataFrame(data=embedding, columns=['x', 'y'], index=df.index)], axis=1)
print('done with UMAP in {}'.format(now() - time_start))
scatter(data_frame=plot_df, x='x', y='y', color='tag', hover_name='name', height=900).show()

2024-02-22 01:59:08.733347: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-22 01:59:08.733494: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-22 01:59:08.900701: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


UMAP(low_memory=False, n_epochs=500, n_jobs=1, random_state=2024, verbose=True)
Thu Feb 22 01:59:25 2024 Construct fuzzy simplicial set
Thu Feb 22 01:59:25 2024 Finding Nearest Neighbors
Thu Feb 22 01:59:25 2024 Building RP forest with 13 trees
Thu Feb 22 01:59:31 2024 NN descent for 14 iterations
	 1  /  14
	 2  /  14
	 3  /  14
	 4  /  14
	 5  /  14
	Stopping threshold met -- exiting after 5 iterations
Thu Feb 22 01:59:51 2024 Finished Nearest Neighbor Search
Thu Feb 22 01:59:55 2024 Construct embedding


Epochs completed:   0%|            0/500 [00:00]

	completed  0  /  500 epochs
	completed  50  /  500 epochs
	completed  100  /  500 epochs
	completed  150  /  500 epochs
	completed  200  /  500 epochs
	completed  250  /  500 epochs
	completed  300  /  500 epochs
	completed  350  /  500 epochs
	completed  400  /  500 epochs
	completed  450  /  500 epochs
Thu Feb 22 02:00:38 2024 Finished embedding
done with UMAP in 0:01:16.063101


This is encouraging; we see a fair amount of clustering in the UMAP projection. Let's load the test data and build a model.

In [6]:
from arrow import now

# we don't have test data so let's use the validation set as our test data
# we don't have test data so let's use the validation set as our test data
# tag = the entry name after its first '_', with remaining '_' turned into spaces
valid_tags = {folder: ' '.join(basename(folder).split('_')[1:])
              for folder in glob('/kaggle/input/aruzz22-5k-an-image-dataset-of-rice-varieties/2_VALID/*')}
# An entry with no '_' in its name gets an empty tag. A previous run encoded one as
# "encoded  data 0 rows", so drop such entries instead of treating them as a class.
test = {tag: folder + '/*.jpg' for folder, tag in valid_tags.items() if tag}

time_start = now()
test_data = [get_from_glob(arg=value, tag=key, stop=STOP) for key, value in test.items()]
test_df = pd.DataFrame(data=flatten(arg=test_data))
print('done in {}'.format(now() - time_start))

encoded Bashmoti data 225 rows in 0:00:13.904339
encoded Lal Binni data 225 rows in 0:00:15.116335
encoded BR29 data 225 rows in 0:00:14.264077
encoded Katarivog data 225 rows in 0:00:15.015021
encoded Jirashail data 225 rows in 0:00:14.305633
encoded Lal Aush data 225 rows in 0:00:15.167268
encoded Shampakatari data 225 rows in 0:00:14.100340
encoded Amon data 225 rows in 0:00:15.250243
encoded Shorna5 data 225 rows in 0:00:14.997546
encoded Subol Lota data 225 rows in 0:00:14.371533
encoded Katari Polao data 225 rows in 0:00:15.035759
encoded Najirshail data 225 rows in 0:00:14.501098
encoded BR28 data 225 rows in 0:00:15.521916
encoded  data 0 rows in 0:00:00.000717
encoded Lal Biroi data 225 rows in 0:00:14.615792
encoded Red Cargo data 225 rows in 0:00:15.310368
encoded Paijam data 225 rows in 0:00:14.208501
encoded Chinigura Polao data 225 rows in 0:00:14.555599
encoded Gutisharna data 225 rows in 0:00:13.774482
encoded Bashful data 225 rows in 0:00:14.700023
encoded Ganjiya data

In [7]:
from sklearn.metrics import f1_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
from arrow import now

# Pick the number of neighbours k by weighted F1 on the held-out split, then report per-class results.
# NOTE(review): test_df comes from the 2_VALID split, and that same split is used both to choose k
# and to produce the final report, so the reported scores are optimistic. A clean estimate needs k
# chosen on training data (e.g. cross-validation) or a separate untouched test split.
K_CANDIDATES = range(2, 15)

# Build the feature matrices once (they don't change across k): one column per vector component.
X_train = pd.DataFrame(data=df['value'].tolist(), index=df.index)
y_train = df['tag']
X_test = pd.DataFrame(data=test_df['value'].tolist(), index=test_df.index)
y_test = test_df['tag']
labels = y_test.unique().tolist()

best_k = K_CANDIDATES[0]
best = -1.0  # below any F1 score, so best_k is always a k that was actually evaluated
for n_neighbors in K_CANDIDATES:
    current = KNeighborsClassifier(n_neighbors=n_neighbors)
    current.fit(X=X_train, y=y_train)
    score = f1_score(average='weighted', labels=labels, y_true=y_test, y_pred=current.predict(X=X_test))
    if score > best:  # strict '>' keeps the smallest k on ties
        best = score
        best_k = n_neighbors
    print('neighbors: {} score: {:5.4f}'.format(n_neighbors, score))

time_start = now()
print('building best-k model with k = {}'.format(best_k))
knn = KNeighborsClassifier(n_neighbors=best_k)
knn.fit(X=X_train, y=y_train)
print(classification_report(labels=labels, y_true=y_test, y_pred=knn.predict(X=X_test)))
print('model time: {}'.format(now() - time_start))


neighbors: 2 score: 0.8987
neighbors: 3 score: 0.9120
neighbors: 4 score: 0.8954
neighbors: 5 score: 0.8927
neighbors: 6 score: 0.8872
neighbors: 7 score: 0.8872
neighbors: 8 score: 0.8853
neighbors: 9 score: 0.8800
neighbors: 10 score: 0.8763
neighbors: 11 score: 0.8719
neighbors: 12 score: 0.8737
neighbors: 13 score: 0.8697
neighbors: 14 score: 0.8699
building best-k model with k = 3
                 precision    recall  f1-score   support

       Bashmoti       0.99      0.94      0.96       225
      Lal Binni       1.00      0.99      0.99       225
           BR29       0.86      0.80      0.83       225
      Katarivog       0.91      0.88      0.89       225
      Jirashail       0.87      0.76      0.81       225
       Lal Aush       0.97      0.98      0.97       225
   Shampakatari       0.90      0.77      0.83       225
           Amon       0.96      0.96      0.96       225
        Shorna5       1.00      0.98      0.99       225
     Subol Lota       1.00      0.80    