In [1]:
!pip install --quiet img2vec_pytorch
print('pip install img2vec complete')

pip install img2vec complete


In [2]:
import base64
import pandas as pd
import os
import csv

from arrow import now
from glob import glob
from img2vec_pytorch import Img2Vec
from io import BytesIO
from PIL import Image

# we're going to use the updated dataset
# GLOB is the dataset root; '<split>/images/*.jpg' and '<split>/labels/*.txt' patterns are appended to it
GLOB = '/kaggle/input/tomato-leaf-diseases-detection-computer-vision/'
# length of the flat vector that embed() reshapes each model output to
SIZE = 512
# upper bound on the number of files get_picture_from_glob() encodes per call
STOP = 100000

def embed(model, filename: str):
    """Return the embedding of one image file as a flat NumPy vector.

    Args:
        model: object exposing ``get_vec(image, tensor=True)`` whose result has a
            ``.numpy()`` method (the notebook passes an ``Img2Vec`` instance).
        filename: path of the image file to embed.

    Returns:
        NumPy array of shape ``(SIZE,)``.

    Raises:
        ValueError: from ``reshape`` when the model output does not hold exactly
            ``SIZE`` values.
    """
    with Image.open(fp=filename, mode='r') as image:
        # JPEG files can be grayscale or CMYK; normalise to three channels first.
        # NOTE(review): assumes the model pipeline expects RGB input - confirm against img2vec.
        rgb_image = image.convert('RGB')
        return model.get_vec(rgb_image, tensor=True).numpy().reshape(SIZE,)


# https://stackoverflow.com/a/952952
def flatten(arg):
    """Concatenate the items of every inner iterable of ``arg`` into one list.

    Only one level of nesting is removed; the order of the items is preserved.
    """
    flat = []
    for inner in arg:
        flat.extend(inner)
    return flat

def png(filename: str) -> str:
    """Return a 128x128 PNG rendering of an image file as a base64 ``data:`` URI.

    Args:
        filename: path of the image file to read.

    Returns:
        A string starting with 'data:image/png;base64,' followed by the encoded PNG bytes.
    """
    # The thumbnails are embedded in a scatter plot with hover images; there is
    # 100Mb of space, so either the images are shrunk or the points are sampled.
    # Here the images are shrunk.
    thumbnail_size = (128, 128)
    stream = BytesIO()
    with Image.open(fp=filename, mode='r') as source:
        thumbnail = source.resize(size=thumbnail_size)
        thumbnail.save(stream, format='png')
    encoded = base64.b64encode(stream.getvalue()).decode()
    return 'data:image/png;base64,' + encoded

def get_picture_from_glob(arg: str, stop: int) -> list:
    """Embed and thumbnail at most ``stop`` image files matching a glob pattern.

    Relies on the module-level ``model`` being defined before the call; it is
    handed to ``embed`` for every file.

    Args:
        arg: glob pattern selecting the image files.
        stop: maximum number of files to process; zero or a negative value
            yields an empty result.

    Returns:
        A list of ``pd.Series``, one per file, indexed by 'name' (base file name),
        'value' (result of ``embed``) and 'image' (result of ``png``).
    """
    time_get = now()
    # glob order depends on the file system; sorting keeps both the row order and
    # the subset retained when ``stop`` truncates the list reproducible between runs
    filenames = sorted(glob(pathname=arg))[:max(stop, 0)]
    result = [
        pd.Series(
            data=[os.path.basename(input_file), embed(model=model, filename=input_file), png(filename=input_file)],
            index=['name', 'value', 'image'],
        )
        for input_file in filenames
    ]
    print('encoded {} rows in {}'.format(len(result), now() - time_get))
    return result

def get_labels(arg: str) -> pd.DataFrame:
    """Read one label per annotation file matching a glob pattern.

    Every file is parsed as space-delimited text; the first field of its last
    non-blank row becomes the label. The 'name' column holds the file's base name
    with its extension swapped for '.jpg' so it can be joined with the image rows.

    Files without any non-blank row are skipped: they carry no label, and must
    not inherit one from whichever file happened to be read before them.

    Args:
        arg: glob pattern selecting the label files.

    Returns:
        DataFrame with columns 'name' and 'label' (both strings), one row per
        labelled file. When nothing matches, the frame is empty but still has
        both columns, so a later merge on 'name' keeps working.
    """
    records = []
    # sorted: glob order depends on the file system; a fixed order keeps the frame reproducible
    for input_file in sorted(glob(arg)):
        label = None  # reset per file so a value can never leak in from the previous file
        with open(file=input_file, mode='r') as input_fp:
            for row in csv.reader(input_fp, delimiter=' '):
                if row:  # csv.reader yields [] for blank lines
                    label = row[0]
        if label is None:
            continue
        # replace only the extension; str.replace would rewrite every '.txt' in the name
        stem, _ = os.path.splitext(os.path.basename(input_file))
        records.append({'name': stem + '.jpg', 'label': label})
    return pd.DataFrame(data=records, columns=['name', 'label'])

    
time_start = now()
# the embedding model is shared: get_picture_from_glob() reads this module-level name
model = Img2Vec(cuda=False, model='resnet-18')
# join the encoded training images with their labels on the image file name
train_df = pd.DataFrame(data=get_picture_from_glob(arg=GLOB + 'train/images/*.jpg', stop=STOP)).merge(
    right=get_labels(arg=GLOB + 'train/labels/*.txt'), on='name', how='inner')

print('done in {}'.format(now() - time_start))
# a bare expression is only rendered when it is the last statement of the cell
train_df.head()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 93.7MB/s]


encoded 645 rows in 0:00:58.245227
done in 0:01:02.079397


In [3]:
# Build the validation and test frames the same way as the training frame:
# encode the images of the split, then inner-join them with the labels on 'name'.
valid_df = pd.DataFrame(
    data=get_picture_from_glob(arg=GLOB + 'valid/images/*.jpg', stop=STOP)
).merge(
    right=get_labels(arg=GLOB + 'valid/labels/*.txt'),
    on='name',
    how='inner',
)
test_df = pd.DataFrame(
    data=get_picture_from_glob(arg=GLOB + 'test/images/*.jpg', stop=STOP)
).merge(
    right=get_labels(arg=GLOB + 'test/labels/*.txt'),
    on='name',
    how='inner',
)


encoded 61 rows in 0:00:05.339045
encoded 31 rows in 0:00:02.708057
