In [13]:
# Remote archive containing the Oxford 102 Flowers images.
DATASET_URL = "https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz"

# Local directory the archive is downloaded to and extracted into.
DATASET_PATH = "../dataset"

# SQLite file that stores the image list and the prediction index.
DATABASE_PATH = "../database/flower.db.sqlite"

# Directory the image-classification pipeline is saved to.
SAVED_MODEL_PATH = "./saved"


In [2]:
# Hugging Face trained model: https://huggingface.co/dima806/oxford_flowers_image_detection
from transformers import pipeline
from transformers import AutoModelForImageClassification  # retained so existing references keep working

# Hub identifier of the checkpoint, named once so every use stays in sync.
MODEL_ID = "dima806/oxford_flowers_image_detection"

# Use a pipeline as a high-level helper.
pipe = pipeline("image-classification", model=MODEL_ID)

# Reuse the model instance the pipeline already loaded rather than calling
# AutoModelForImageClassification.from_pretrained(MODEL_ID) a second time,
# which would keep two copies of the same weights in memory.
model = pipe.model


  from .autonotebook import tqdm as notebook_tqdm





In [10]:
# Download & extract dataset. From: https://www.tensorflow.org/datasets/catalog/oxford_flowers102
import urllib.request
import tarfile
import os

# urlretrieve does not create missing directories, so make sure the target exists.
os.makedirs(DATASET_PATH, exist_ok=True)

__save_path = os.path.join(DATASET_PATH, "102flowers.tgz")

# Skip the download when the archive is already present. The file is fetched under a
# temporary name and renamed only once complete, so an interrupted download is never
# mistaken for a finished archive on the next run.
if not os.path.exists(__save_path):
    __partial_path = __save_path + ".part"
    urllib.request.urlretrieve(DATASET_URL, __partial_path)
    os.replace(__partial_path, __save_path)

# The context manager closes the archive even when extraction fails part-way.
with tarfile.open(__save_path, mode="r|gz") as archive:
    # Where the running Python supports extraction filters, use the "data" filter so
    # archive members cannot write outside DATASET_PATH; older versions fall back to
    # the plain call the notebook used before.
    __extract_options = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    archive.extractall(DATASET_PATH, **__extract_options)

In [11]:
# Show every label this model can predict, as a {class index: label name} mapping
print(model.config.id2label)

{0: 'bolero deep blue', 1: 'toad lily', 2: 'bougainvillea', 3: 'blanket flower', 4: 'prince of wales feathers', 5: 'english marigold', 6: 'common dandelion', 7: 'mallow', 8: 'barbeton daisy', 9: 'desert-rose', 10: 'anthurium', 11: 'cyclamen', 12: 'marigold', 13: 'spring crocus', 14: 'petunia', 15: 'foxglove', 16: 'primula', 17: 'cape flower', 18: "colt's foot", 19: 'osteospermum', 20: 'buttercup', 21: 'balloon flower', 22: 'fire lily', 23: 'bromelia', 24: 'artichoke', 25: 'daffodil', 26: 'pink-yellow dahlia', 27: 'geranium', 28: 'peruvian lily', 29: 'king protea', 30: 'silverbush', 31: 'alpine sea holly', 32: 'hibiscus', 33: 'giant white arum lily', 34: 'canna lily', 35: 'sunflower', 36: 'sweet pea', 37: 'mexican aster', 38: 'californian poppy', 39: 'pincushion flower', 40: 'black-eyed susan', 41: 'blackberry lily', 42: 'gaura', 43: 'love in the mist', 44: 'spear thistle', 45: 'orange dahlia', 46: 'wallflower', 47: 'tiger lily', 48: 'stemless gentian', 49: 'morning glory', 50: 'frangip

In [14]:
# Predict / extract features for every image in the dataset
from PIL import Image
from pathlib import Path

# Maps each image path (as a string) to the pipeline's prediction output for that image.
data = {}

# Path.glob yields files in arbitrary filesystem order; sorting makes the processing
# order, and therefore the order in which `data` is later iterated, stable across runs.
pathlist = sorted(Path(DATASET_PATH).glob('**/*.jpg'))

for path in pathlist:
    path_str = str(path)
    # Image.open keeps the underlying file open until it is closed; close each image
    # as soon as it has been classified so handles do not accumulate over the dataset.
    with Image.open(path_str) as image:
        data[path_str] = pipe(image)

In [24]:
# Setup DB
import sqlite3
import os

db_dir = os.path.dirname(DATABASE_PATH)
# makedirs creates missing parent directories and tolerates an existing one.
# Skip it entirely when DATABASE_PATH has no directory component (db_dir == "").
if db_dir:
    os.makedirs(db_dir, exist_ok=True)

# Tip: for a throwaway database, connect to ':memory:' instead of DATABASE_PATH.
conn = sqlite3.connect(DATABASE_PATH)
c = conn.cursor()

# Create tables (no-ops when they already exist):
#   flowers_img    - one row per image; pid is the auto-generated image id
#   flowers_vector - FTS4 full-text table with one row per (image, predicted label)
c.execute('''CREATE TABLE IF NOT EXISTS flowers_img(pid INTEGER PRIMARY KEY AUTOINCREMENT, filename TEXT)''')
c.execute('''CREATE VIRTUAL TABLE IF NOT EXISTS flowers_vector USING fts4(pid, tokens, score)''')

<sqlite3.Cursor at 0x6ae39340>

In [25]:
# Process the prediction data and save it to the DB
import os

def process_label(in_label: str) -> str:
    """Turn a predicted label into the token text stored in the search index.

    Every occurrence of the substring "flower" is removed and surrounding
    whitespace is trimmed, e.g. "blanket flower" -> "blanket".

    NOTE(review): removal is substring-based, so "sunflower" becomes "sun".
    Confirm that is intended before restricting it to whole words.
    """
    return in_label.replace("flower", "").strip()


def save_img_to_db(filename, predicts, min_score=0.1, cursor=None, connection=None):
    """Insert one image and its sufficiently confident predictions.

    Args:
        filename: Value stored in ``flowers_img.filename``.
        predicts: Iterable of dicts with ``"label"`` and ``"score"`` keys.
        min_score: Predictions scoring below this value are skipped.
        cursor: Cursor to execute on; defaults to the notebook-level ``c``.
        connection: Connection to commit on; defaults to the notebook-level ``conn``.

    Returns:
        The ``pid`` generated for the new ``flowers_img`` row.
    """
    cursor = c if cursor is None else cursor
    connection = conn if connection is None else connection

    cursor.execute("INSERT INTO flowers_img(filename) VALUES (?)", [filename])
    # Capture the image id immediately: lastrowid is overwritten by every later
    # INSERT on this cursor, including the flowers_vector inserts below.
    pid = cursor.lastrowid

    for predict in predicts:
        if predict["score"] < min_score:
            continue
        cursor.execute("INSERT INTO flowers_vector(pid, tokens, score) VALUES (?, ?, ?)",
                       [pid, process_label(predict["label"]), predict["score"]])

    # One commit per image keeps the image row and its predictions together and
    # avoids a disk sync for every single prediction row.
    connection.commit()
    return pid

# Store every image under its bare file name (directory part dropped),
# together with the predictions collected for it.
for image_path, predictions in data.items():
    save_img_to_db(os.path.basename(image_path), predictions)

In [26]:
# Close the SQLite connection opened during DB setup.
# save_img_to_db commits its own writes, so no explicit commit is issued here.
conn.close()

In [22]:
# Save the pipeline to SAVED_MODEL_PATH
# (presumably so it can later be loaded from local disk instead of the Hub - confirm).
pipe.save_pretrained(SAVED_MODEL_PATH)
