In [1]:
!pip install --quiet img2vec_pytorch
print('pip installed img2vec')

pip installed img2vec


In [2]:
from warnings import filterwarnings
filterwarnings(action='ignore', category=FutureWarning) # quiet a plotly issue
filterwarnings(action='ignore', category=UserWarning) # quiet an img2vec issue

In [3]:
from img2vec_pytorch import Img2Vec
from PIL import Image
from arrow import now
from glob import glob
import pandas as pd
from os.path import basename

img2vec = Img2Vec(cuda=False, model='resnet-18', layer='default', layer_output_size=512)

# https://stackoverflow.com/a/952952
def flatten(arg):
    return [x for xs in arg for x in xs]

# we're going to just read a few pictures while we're building
def get_from_glob(arg: str, tag: str, stop: int) -> list:
    time_get = now()
    result = []
    for index, input_file in enumerate(glob(pathname=arg)):
        if index < stop:
            name = basename(input_file)
            try:
                with Image.open(fp=input_file, mode='r') as image:
                    vector = img2vec.get_vec(image, tensor=True).numpy().reshape(512,)
                    result.append(pd.Series(data=[tag, name, vector], index=['tag', 'name', 'value']))
            except RuntimeError:
                # we only have a few failures so we're just going to discard them
                print('runtime failure: {}'.format(tag, name))
                pass
    print('encoded {} data in {}'.format(tag, now() - time_get))
    return result

STOP = 500

time_start = now()
train = {' '.join(basename(folder).split('_')[1:]) : folder + '/*.jpg' 
         for folder in glob('/kaggle/input/aruzz22-5k-an-image-dataset-of-rice-varieties/1_TRAIN/*')}
train_data = [get_from_glob(arg=value, tag=key, stop=STOP) for key, value in train.items()]
df = pd.DataFrame(data=flatten(arg=train_data))
    
print('done in {}'.format(now() - time_start))

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 101MB/s]


encoded Bashmoti data in 0:00:31.026763
encoded Lal Binni data in 0:00:31.264525
encoded BR29 data in 0:00:30.502142
encoded Katarivog data in 0:00:30.901413
encoded Jirashail data in 0:00:30.781849
encoded Lal Aush data in 0:00:30.860261
encoded Shampakatari data in 0:00:30.663795
encoded Amon data in 0:00:30.886748
encoded Shorna5 data in 0:00:30.950074
encoded Subol Lota data in 0:00:31.144161
encoded Katari Polao data in 0:00:30.859164
encoded Najirshail data in 0:00:30.919078
encoded BR28 data in 0:00:30.704100
encoded Lal Biroi data in 0:00:31.558565
encoded Red Cargo data in 0:00:31.469350
encoded Paijam data in 0:00:30.787222
encoded Chinigura Polao data in 0:00:30.876516
encoded Gutisharna data in 0:00:30.977527
encoded Bashful data in 0:00:31.149255
encoded Ganjiya data in 0:00:31.186404
done in 0:10:20.693003


In [4]:
from plotly.express import histogram
histogram(data_frame=df, x='tag')

In [5]:
from arrow import now
from umap import UMAP

time_start = now()
umap = UMAP(random_state=2024, verbose=True, n_jobs=1, low_memory=False, n_epochs=500)
plot_df = pd.concat(objs=[df, pd.DataFrame(data=umap.fit_transform(X=df['value'].apply(pd.Series)), columns=['x', 'y'])], axis=1)
print('done with UMAP in {}'.format(now() - time_start))

2024-02-22 01:10:14.920783: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-22 01:10:14.920925: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-22 01:10:15.116321: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


UMAP(low_memory=False, n_epochs=500, n_jobs=1, random_state=2024, verbose=True)
Thu Feb 22 01:10:30 2024 Construct fuzzy simplicial set
Thu Feb 22 01:10:30 2024 Finding Nearest Neighbors
Thu Feb 22 01:10:30 2024 Building RP forest with 10 trees
Thu Feb 22 01:10:37 2024 NN descent for 13 iterations
	 1  /  13
	 2  /  13
	 3  /  13
	 4  /  13
	 5  /  13
	Stopping threshold met -- exiting after 5 iterations
Thu Feb 22 01:10:57 2024 Finished Nearest Neighbor Search
Thu Feb 22 01:11:00 2024 Construct embedding


Epochs completed:   0%|            0/500 [00:00]

	completed  0  /  500 epochs
	completed  50  /  500 epochs
	completed  100  /  500 epochs
	completed  150  /  500 epochs
	completed  200  /  500 epochs
	completed  250  /  500 epochs
	completed  300  /  500 epochs
	completed  350  /  500 epochs
	completed  400  /  500 epochs
	completed  450  /  500 epochs
Thu Feb 22 01:11:19 2024 Finished embedding
done with UMAP in 0:00:50.009892


In [6]:
from plotly.express import scatter
scatter(data_frame=plot_df, x='x', y='y', color='tag', hover_name='name', height=900)

This is encouraging; we see a fair amount of clustering in this sample. Let's load the test data and build a model.

In [7]:
from arrow import now

# we don't have test data so let's use the validation set as our test data

test = {' '.join(basename(folder).split('_')[1:]) : folder + '/*.jpg' 
         for folder in glob('/kaggle/input/aruzz22-5k-an-image-dataset-of-rice-varieties/2_VALID/*')}

time_start = now()
test_data = [get_from_glob(arg=value, tag=key, stop=STOP) for key, value in test.items()]
test_df = pd.DataFrame(data=flatten(arg=test_data))
print('done in {}'.format(now() - time_start))

encoded Bashmoti data in 0:00:15.238382
encoded Lal Binni data in 0:00:15.031433
encoded BR29 data in 0:00:14.418623
encoded Katarivog data in 0:00:14.293381
encoded Jirashail data in 0:00:15.215069
encoded Lal Aush data in 0:00:14.572277
encoded Shampakatari data in 0:00:15.122455
encoded Amon data in 0:00:14.392582
encoded Shorna5 data in 0:00:15.162115
encoded Subol Lota data in 0:00:14.233722
encoded Katari Polao data in 0:00:14.963873
encoded Najirshail data in 0:00:14.554357
encoded BR28 data in 0:00:15.118734
encoded  data in 0:00:00.000588
encoded Lal Biroi data in 0:00:14.386492
encoded Red Cargo data in 0:00:14.773231
encoded Paijam data in 0:00:14.425408
encoded Chinigura Polao data in 0:00:14.309576
encoded Gutisharna data in 0:00:14.713568
encoded Bashful data in 0:00:14.352463
encoded Ganjiya data in 0:00:14.981330
done in 0:04:54.686504


In [8]:
from sklearn.metrics import f1_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
from arrow import now

best_k = 1
best = 0
# let's step through a range of cluster sizes to find the one that will give us the best accuracy
for n_neighbors in range(2, 15):
    current = KNeighborsClassifier(n_neighbors=n_neighbors)
    current.fit(X=df['value'].apply(pd.Series), y=df['tag'])
    score = f1_score(average='weighted', labels=test_df['tag'].unique().tolist(), y_true=test_df['tag'], y_pred=current.predict(X=test_df['value'].apply(pd.Series)))
    if score > best:
        best = score
        best_k = n_neighbors
    print('neighbors: {} score: {:5.4f}'.format(n_neighbors, score))
        
time_start = now()
print('building best-k model with k = {}'.format(best_k))
knn = KNeighborsClassifier(n_neighbors=best_k)
knn.fit(X=df['value'].apply(pd.Series), y=df['tag'])
print(classification_report(labels=test_df['tag'].unique().tolist(), y_true=test_df['tag'], y_pred=knn.predict(X=test_df['value'].apply(pd.Series))))
print('model time: {}'.format(now() - time_start))


neighbors: 2 score: 0.8104
neighbors: 3 score: 0.8311
neighbors: 4 score: 0.8284
neighbors: 5 score: 0.8280
neighbors: 6 score: 0.8297
neighbors: 7 score: 0.8267
neighbors: 8 score: 0.8245
neighbors: 9 score: 0.8251
neighbors: 10 score: 0.8239
neighbors: 11 score: 0.8241
neighbors: 12 score: 0.8253
neighbors: 13 score: 0.8227
neighbors: 14 score: 0.8193
building best-k model with k = 3
                 precision    recall  f1-score   support

       Bashmoti       0.96      0.86      0.90       225
      Lal Binni       1.00      0.98      0.99       225
           BR29       0.67      0.65      0.66       225
      Katarivog       0.77      0.78      0.77       225
      Jirashail       0.71      0.67      0.69       225
       Lal Aush       0.92      0.96      0.94       225
   Shampakatari       0.77      0.56      0.65       225
           Amon       0.93      0.93      0.93       225
        Shorna5       0.96      0.91      0.94       225
     Subol Lota       0.97      0.66    