In [1]:
# List the training split; its sub-folders (named 1-102 in the output below) are used as the class tags.
!ls '/kaggle/input/flower-classification-dataset/Flower Classification Dataset/train/'

1    12  18  23  29  34  4   45  50  56  61  67  72  78  83  89  94
10   13  19  24  3   35  40  46  51  57  62  68  73  79  84  9	 95
100  14  2   25  30  36  41  47  52  58  63  69  74  8	 85  90  96
101  15  20  26  31  37  42  48  53  59  64  7	 75  80  86  91  97
102  16  21  27  32  38  43  49  54  6	 65  70  76  81  87  92  98
11   17  22  28  33  39  44  5	 55  60  66  71  77  82  88  93  99


In [2]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np


DEVICE = torch.device('cpu')
# Channel width of the pooled feature map captured from the model's avgpool layer.
OUTPUT_SIZE = 2048

model = models.resnext50_32x4d(weights=models.ResNeXt50_32X4D_Weights.IMAGENET1K_V2)
# In torchvision's ResNet family the pooled features feed `fc`, so its input width
# must equal OUTPUT_SIZE; fail here rather than inside a forward hook later.
assert model.fc.in_features == OUTPUT_SIZE, model.fc.in_features

# Forward hooks on this layer capture the embedding. Plain attribute access (instead
# of the private `_modules.get`) raises immediately if the layer is missing rather
# than returning None.
extraction_layer = model.avgpool
model.to(DEVICE)
model.eval()  # inference mode: batch-norm uses its running statistics

# Preprocessing used by get_vec: resize -> float tensor in [0, 1] -> ImageNet normalisation.
# NOTE(review): torchvision's published inference transform for these weights resizes
# to 232 and centre-crops 224; this squashes straight to 224x224 -- confirm intended.
scaler = transforms.Resize((224, 224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()

def get_vec(arg, model, extraction_layer):
    """Return the activation of ``extraction_layer`` for a single image.

    Parameters
    ----------
    arg
        Image accepted by the module-level ``scaler`` / ``to_tensor`` transforms
        (a PIL image in this notebook). It is resized, converted to a tensor,
        normalised with ``normalize`` and given a batch dimension.
    model : torch.nn.Module
        Network to run; expected to already be on ``DEVICE`` and in eval mode.
    extraction_layer : torch.nn.Module
        Sub-module of ``model`` whose forward output is captured.

    Returns
    -------
    torch.Tensor
        CPU tensor of shape ``(1, OUTPUT_SIZE, 1, 1)`` holding the captured
        activation (all zeros if ``extraction_layer`` never ran).
    """
    image = normalize(to_tensor(scaler(arg))).unsqueeze(0).to(DEVICE)
    result = torch.zeros(1, OUTPUT_SIZE, 1, 1)

    def copy_data(module, inputs, output):
        # copy_ also moves the activation into the CPU buffer if DEVICE is a GPU.
        result.copy_(output.detach())

    hook = extraction_layer.register_forward_hook(copy_data)
    try:
        with torch.no_grad():
            model(image)
    finally:
        # Always unregister: otherwise a failed forward pass leaves the hook
        # attached, and it keeps firing (and holding its buffer) on every later pass.
        hook.remove()
    return result

Downloading: "https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth" to /root/.cache/torch/hub/checkpoints/resnext50_32x4d-1a0047aa.pth
100%|██████████| 95.8M/95.8M [00:01<00:00, 65.1MB/s]


In [3]:
import base64
from glob import iglob
from io import BytesIO
from os.path import basename, join, splitext

import arrow
import pandas as pd
from PIL import Image

THUMBNAIL_SIZE = (64, 64)
TRAIN = '/kaggle/input/flower-classification-dataset/Flower Classification Dataset/train/'


def embed(model, filename: str, layer=None):
    """Compute a flat feature vector for the image stored at ``filename``.

    The image is opened with PIL, converted to RGB and passed to ``get_vec``.
    The captured activation is returned as a 1-D numpy array of length
    ``OUTPUT_SIZE``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to run the image through.
    filename : str
        Path of the image file.
    layer : torch.nn.Module, optional
        Sub-module of ``model`` to capture. Defaults to the module-level
        ``extraction_layer``.

    Raises
    ------
    ValueError
        If ``layer`` is not a sub-module of ``model``. Its hook would never
        fire, and the result would silently be all zeros.
    """
    layer = extraction_layer if layer is None else layer
    if not any(module is layer for module in model.modules()):
        raise ValueError('extraction layer is not a sub-module of the given model')
    with Image.open(fp=filename, mode='r') as image:
        vector = get_vec(arg=image.convert('RGB'), model=model, extraction_layer=layer)
    return vector.numpy().reshape(OUTPUT_SIZE,)


# https://stackoverflow.com/a/952952
def flatten(arg):
    """Concatenate an iterable of iterables into one list (a single level deep)."""
    flat = []
    for sublist in arg:
        flat.extend(sublist)
    return flat

def png(filename: str) -> str:
    """Encode an image file as a base64 PNG data URI, shrunk to ``THUMBNAIL_SIZE``.

    The source images are large, so the hover thumbnail is resized first and
    then converted to RGB before being PNG-encoded in memory.
    """
    with Image.open(fp=filename, mode='r') as source:
        thumbnail = source.resize(size=THUMBNAIL_SIZE).convert('RGB')
        out = BytesIO()
        thumbnail.save(out, format='png')
        encoded = base64.b64encode(out.getvalue()).decode()
    return 'data:image/png;base64,' + encoded

def label(filename: str) -> str:
    """Return the class id stored in the label file paired with an image.

    The label path is derived from the image path in two steps. First, its
    extension is swapped for ``.txt``. Second, every ``images`` directory
    component is replaced by ``labels``. For example,
    ``data/images/a.jpg`` becomes ``data/labels/a.txt``. The first
    whitespace-separated token of the label file's first line is returned.

    Raises
    ------
    IndexError
        If the label file is empty or its first line is blank.
    OSError
        If the label file cannot be opened.
    """
    # splitext only touches the final path component, so extension-less files
    # and dotted directory names are handled correctly.
    stem, _ = splitext(filename)
    parts = [('labels' if part == 'images' else part) for part in (stem + '.txt').split('/')]
    with open(file='/'.join(parts), mode='r') as fp:
        # split() (not split(' ')) also drops the trailing newline when the
        # line holds a single token.
        return fp.readline().split()[0]
    

def get_picture_from_glob(arg: str, tag: str, limit: int = 500) -> list:
    """Embed up to ``limit`` JPEG images matching a glob pattern.

    Parameters
    ----------
    arg : str
        Glob pattern selecting candidate files.
    tag : str
        Value stored in the ``tag`` column of every row (the class folder name
        in this notebook).
    limit : int or None, default 500
        Maximum number of ``.jpg`` files to encode (matched case-insensitively);
        ``None`` means no cap.

    Returns
    -------
    list of pd.Series
        One Series per image, indexed ``['tag', 'name', 'value', 'png']``. The
        entries are the tag, the file's basename, its embedding from ``embed``,
        and an inline PNG thumbnail from ``png``.
    """
    time_get = arrow.now()
    # Sorting makes repeated runs pick and order the same files, since glob
    # order is filesystem-dependent. The cap counts JPEGs only, so other
    # matched files do not use up the budget.
    jpg_files = [f for f in sorted(iglob(pathname=arg)) if f.lower().endswith('.jpg')][:limit]
    result = [pd.Series(data=[tag, basename(f), embed(model=model, filename=f), png(filename=f)],
                        index=['tag', 'name', 'value', 'png'])
              for f in jpg_files]
    print('encoded {} rows of {}  in {}'.format(len(result), tag, arrow.now() - time_get))
    return result

time_start = arrow.now()
# One glob pattern per class folder; the folder name becomes the row tag.
# Sorting makes the folder order (and therefore the row order) reproducible.
train_dict = {basename(folder): join(folder, '*.*') for folder in sorted(iglob(join(TRAIN, '*')))}
train_df = pd.DataFrame(data=flatten(arg=[get_picture_from_glob(arg=value, tag=key) for key, value in train_dict.items()]))
print('done in {}'.format(arrow.now() - time_start))

encoded 22 rows of 7  in 0:00:02.618534
encoded 32 rows of 47  in 0:00:03.513001
encoded 41 rows of 17  in 0:00:04.514302
encoded 81 rows of 81  in 0:00:09.448197
encoded 24 rows of 19  in 0:00:02.663498
encoded 29 rows of 22  in 0:00:03.143883
encoded 30 rows of 2  in 0:00:03.310213
encoded 17 rows of 35  in 0:00:01.937620
encoded 28 rows of 92  in 0:00:03.159455
encoded 53 rows of 50  in 0:00:05.921131
encoded 49 rows of 23  in 0:00:05.748875
encoded 25 rows of 87  in 0:00:02.946764
encoded 26 rows of 10  in 0:00:02.937495
encoded 27 rows of 5  in 0:00:02.935617
encoded 29 rows of 61  in 0:00:03.148363
encoded 39 rows of 36  in 0:00:04.258922
encoded 27 rows of 20  in 0:00:02.924278
encoded 17 rows of 45  in 0:00:01.845456
encoded 49 rows of 60  in 0:00:05.379466
encoded 25 rows of 27  in 0:00:02.748596
encoded 29 rows of 64  in 0:00:03.767951
encoded 66 rows of 41  in 0:00:07.237292
encoded 96 rows of 89  in 0:00:10.449535
encoded 18 rows of 39  in 0:00:01.951224
encoded 22 rows of 