# Custom Embeddings

Encord Active has three different types of embeddings:

1. _Image embeddings_ are computed for each image / frame in the dataset.
2. _Classification embeddings_ are associated with specific frame-level classifications.
3. _Object embeddings_ are associated with specific objects, such as polygons or bounding boxes.

If you like, you can "swap out" these embeddings with your own by following the steps in this notebook.

There are two sections in this notebook: one for the image embeddings and one for the object embeddings.
If you have classifications in your project, you should also run

```
encord-active metric run "Image-level Annotation Quality"
```

This takes the image-level embeddings that you provided and also associates them with the classification labels.

In [3]:
import pickle
from pathlib import Path
from typing import List

import torch
from encord_active.lib.common.iterator import DatasetIterator, Iterator
from encord_active.lib.embeddings.dimensionality_reduction import (
    generate_2d_embedding_data,
)
from encord_active.lib.embeddings.types import LabelEmbedding
from encord_active.lib.metrics.types import EmbeddingType
from encord_active.lib.project.project_file_structure import ProjectFileStructure
from PIL import Image
from torchvision.transforms import ToTensor


def load_my_model() -> torch.nn.Module:
    """Return the model used to embed whole images and object crops.

    Replace the body with code that builds or loads your own model. The
    embedding functions in this notebook call the returned model on a single
    batched image tensor and flatten its output into a 1-D ``[d,]`` vector.

    Raises:
        NotImplementedError: until this placeholder is replaced. Failing
            loudly here is clearer than returning ``None`` and crashing later
            with "'NoneType' object is not callable".
    """
    # <- HERE: Edit here to return your model
    raise NotImplementedError(
        "Edit load_my_model() to return your own torch.nn.Module."
    )


def get_transform():
    """Return the callable applied to every RGB PIL image before embedding.

    The default simply converts the PIL image into a ``torch.Tensor`` with
    torchvision's ``ToTensor``. Edit this if your model needs additional
    preprocessing (resizing, normalization, ...).
    """
    pil_to_tensor = ToTensor()  # <- HERE: If you have any specific transforms to apply to PIL images.
    return pil_to_tensor

## Example of Image Embeddings

In [None]:
@torch.inference_mode()
def generate_cnn_image_embeddings(iterator: Iterator) -> List[LabelEmbedding]:
    """Compute one embedding per image / frame yielded by ``iterator``.

    Args:
        iterator: Encord Active iterator over the project's data units. Its
            ``label_hash``, ``frame`` and ``dataset_title`` attributes are read
            for every yielded item to fill in the embedding metadata.

    Returns:
        One ``LabelEmbedding`` per successfully embedded image. All
        label-specific fields (``labelHash``, ``featureHash``, ...) are ``None``
        because these embeddings describe whole images, not labels.
    """
    model = load_my_model()
    # inference_mode only disables autograd; layers such as dropout and
    # batch-norm additionally need eval mode to behave as at inference time.
    model.eval()
    transform = get_transform()

    collections: List[LabelEmbedding] = []
    for data_unit, image in iterator.iterate(desc="Embedding image data."):
        if image is None:
            continue

        image_tensor = transform(image.convert("RGB"))

        # START Embedding
        # Standard torch vision models expect a batch dimension:
        # [C, H, W] -> [1, C, H, W].
        embedding = model(image_tensor.unsqueeze(0))  # <- HERE - your logic for embedding data.

        if embedding is None:
            continue

        # .cpu() is a no-op on CPU tensors and required before .numpy() on GPU ones.
        embedding = embedding.flatten().detach().cpu().numpy()  # <- should be a [d,] array.
        # End Embedding

        entry = LabelEmbedding(
            url=data_unit["data_link"],
            label_row=iterator.label_hash,
            data_unit=data_unit["data_hash"],
            frame=iterator.frame,
            dataset_title=iterator.dataset_title,
            embedding=embedding,
            labelHash=None,
            lastEditedBy=None,
            featureHash=None,
            name=None,
            classification_answers=None,
        )
        collections.append(entry)

    return collections


project = Path("/path/to/your/project/root")  # <- HERE: Path to the Encord Project
pfs = ProjectFileStructure(project)

iterator = DatasetIterator(project)
embeddings = generate_cnn_image_embeddings(iterator)
# File the image embeddings are pickled into (must use `pfs` defined above).
out_file = pfs.get_embeddings_file(EmbeddingType.IMAGE)

with out_file.open("wb") as f:
    pickle.dump(embeddings, f)

# Build the 2D (dimensionality-reduced) embedding data for the image embeddings.
generate_2d_embedding_data(EmbeddingType.IMAGE, project)

## Example of Object Embeddings

In [None]:
from encord_active.lib.common.utils import get_bbox_from_encord_label_object

# Object shapes that are cropped and embedded below. These are the string
# values of encord's ObjectShape.POLYGON / BOUNDING_BOX / ROTATABLE_BOUNDING_BOX;
# the literals are used because ObjectShape is not imported in this notebook.
_CROPPABLE_SHAPES = frozenset({"polygon", "bounding_box", "rotatable_bounding_box"})


@torch.inference_mode()
def generate_cnn_object_embeddings(iterator: Iterator) -> List[LabelEmbedding]:
    """Compute one embedding per polygon / bounding-box object in the project.

    Each object is cropped tightly around its bounding box and the crop is
    passed through the model returned by ``load_my_model``.

    Args:
        iterator: Encord Active iterator over the project's data units. Its
            ``label_hash``, ``frame`` and ``dataset_title`` attributes are read
            for every yielded item to fill in the embedding metadata.

    Returns:
        One ``LabelEmbedding`` per successfully embedded object. Objects with
        other shapes, no bounding box, an empty crop, or a ``None`` model
        output are skipped.
    """
    model = load_my_model()
    # inference_mode only disables autograd; eval mode is still needed for
    # dropout / batch-norm layers.
    model.eval()
    transform = get_transform()

    embeddings: List[LabelEmbedding] = []
    for data_unit, image in iterator.iterate(desc="Embedding object data."):
        if image is None:
            continue

        image_tensor = transform(image.convert("RGB"))
        # The crop below indexes dim 1 with y/h and dim 2 with x/w, i.e. the
        # tensor layout is [C, H, W].
        img_height, img_width = image_tensor.shape[1], image_tensor.shape[2]

        for obj in data_unit["labels"].get("objects", []):
            if obj["shape"] not in _CROPPABLE_SHAPES:
                continue

            # Crops images tightly around object
            out = get_bbox_from_encord_label_object(obj, img_width, img_height)
            if out is None:
                continue

            x, y, w, h = out
            img_patch = image_tensor[:, y : y + h, x : x + w]
            if img_patch.numel() == 0:
                continue  # degenerate box: nothing to embed

            # Compute embeddings (add the batch dimension most models expect).
            embedding = model(img_patch.unsqueeze(0))
            if embedding is None:
                continue
            embedding = embedding.flatten().detach().cpu().numpy()  # <- should be a [d,] array.

            last_edited_by = obj["lastEditedBy"] if "lastEditedBy" in obj else obj["createdBy"]
            entry = LabelEmbedding(
                url=data_unit["data_link"],
                label_row=iterator.label_hash,
                data_unit=data_unit["data_hash"],
                frame=iterator.frame,
                labelHash=obj["objectHash"],
                lastEditedBy=last_edited_by,
                featureHash=obj["featureHash"],
                name=obj["name"],
                dataset_title=iterator.dataset_title,
                embedding=embedding,
                classification_answers=None,
            )
            embeddings.append(entry)

    return embeddings

# Object-level embeddings: reuses `iterator`, `pfs` and `project` from the
# image-embedding cell.
embeddings = generate_cnn_object_embeddings(iterator)
out_file = pfs.get_embeddings_file(EmbeddingType.OBJECT)

# Persist the embedding list, then build the 2D embedding data for objects.
with out_file.open("wb") as pickle_handle:
    pickle.dump(embeddings, pickle_handle)
generate_2d_embedding_data(EmbeddingType.OBJECT, project)