## 1. Importing libraries

In [None]:
import os
from docarray import Document, DocumentArray

## 2. Configuration and setup

In [None]:
# Passed as `size` when the image files are loaded.
MAX_DOCS = 10

# Glob for the dataset: every .jpg exactly one sub-directory below DATA_DIR.
DATA_DIR = "data/images"
DATA_PATH = DATA_DIR + "/*/*.jpg"

# Image used as the search query.
QUERY_IMAGE = "data/query.jpg"

In [None]:
# DocumentArray stored in Weaviate, reached at http://localhost:8080 under the
# name 'Image'. n_dim is presumably the embedding size -- confirm it matches
# the output size of the model used for embedding.
da = DocumentArray(
    storage='weaviate',
    config={
        'name': 'Image',
        'client': 'http://localhost:8080',
        'n_dim': 1000,
    },
)

da.summary()

## 3. Load Data

In [None]:
# Build `docs` from the files matching DATA_PATH with size=MAX_DOCS, passing
# the same Weaviate settings that were used for `da`.
docs = da.from_files(
    DATA_PATH,
    size=MAX_DOCS,
    storage='weaviate',
    config={
        'name': 'Image',
        'client': 'http://localhost:8080',
        'n_dim': 1000,
    },
)
print(f"{len(docs)} Documents in DocumentArray")

In [None]:
# Preview the images
docs.plot_image_sprites()

## 4. Preprocess Data

In [None]:
# Convert to tensor, normalize so they're all similar enough
def preproc(d: Document):
    """Load the image referenced by ``d`` and prepare its tensor for the model.

    The steps run in this order: load the URI into an image tensor, set the
    tensor shape to (80, 60), normalize it, and move the channel axis from
    the last position to the first.

    Returns whatever the final call in that chain returns.
    """
    # load
    loaded = d.load_uri_to_image_tensor()
    # ensure all images right size (dataset image size _should_ be (80, 60))
    resized = loaded.set_image_tensor_shape((80, 60))
    # normalize color
    normalized = resized.set_image_tensor_normalization()
    # switch color axis for the PyTorch model later
    return normalized.set_image_tensor_channel_axis(-1, 0)

In [None]:
# Run preproc on every Document in docs
docs.apply(preproc)

## 5. Embed images

In [None]:
# Install the pinned torchvision release (0.11.2) before it is imported below
%pip install torchvision==0.11.2

In [None]:
# Pick the GPU when torch can see one; otherwise fall back to the CPU.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import torchvision

# Load ResNet50 with pretrained weights (pretrained=True)
model = torchvision.models.resnet50(pretrained=True)

In [None]:
# Embed the dataset images with the model, on the device selected earlier
docs.embed(model, device=device)

## 6. Query dataset

In [None]:
# Wrap the query image in a Document and show it
query_doc = Document(uri=QUERY_IMAGE)
query_doc.display()


In [None]:
# Keep the query in a plain in-memory DocumentArray. Backing it with the same
# Weaviate 'Image' name as the dataset would write the query into the very
# index it is about to be matched against (and, depending on the backend's
# handling of initial documents, may replace that index's contents).
# An in-memory array also holds `query_doc` itself rather than a stored copy,
# so anything match() writes onto the query is visible on `query_doc`.
query_docs = DocumentArray([query_doc])

In [None]:
# Give the query the same preprocessing as the dataset images
query_docs.apply(preproc)

In [None]:
# `device` was set earlier to "cuda" when available and "cpu" otherwise,
# so no manual change is needed on a machine without a GPU
query_docs.embed(model, device=device)

In [None]:
# Match each query Document against docs, keeping at most 3 results (limit=3)
query_docs.match(docs, limit=3)

In [None]:
# Read the matches from query_docs, the array match() was called on; the
# standalone `query_doc` is only the same object when that array is in memory.
# copy=True so that undoing the preprocessing for display (channel axis back
# to last, inverse normalization) does not modify the matched Documents.
(DocumentArray(query_docs[0].matches, copy=True)
    .apply(lambda d: d.set_image_tensor_channel_axis(0, -1)
                      .set_image_tensor_inv_normalization())).plot_image_sprites()