In [1]:
import sys
import os
# Put the parent directory first on the import path so packages that live
# one level above this notebook can be imported.
sys.path.insert(0, "..")

In [2]:
# Location of the test universe; read back through the environment when the
# loader is built.
os.environ["UNIVERSE_DIR"] = "/home/lennon/workspace/hdd/ecad_database/features/65k/universe_test"

In [3]:
# from identificationWorker.utils import file_utils
import os
from tqdm import tqdm
import deepdish as dd
import glob


def get_unique_by_order(data):
    """Return the distinct items of *data* as a list, in first-seen order.

    Items must be hashable, since they are used as dictionary keys.
    """
    # Dict keys are unique and keep insertion order, so building a dict from
    # the items de-duplicates them while preserving their first occurrence.
    return list(dict.fromkeys(data))


class DataLoader:
    """
    A class for loading and managing data for the identification engine.

    The universe directory is scanned as ``<universe_dir>/<work>/<track>.h5``.
    Every file found becomes one entry ``{"obra": <work>, "fonograma": <track>}``
    in ``universe_mapping``; the position of an entry in that list is the
    "universe index" used by the ``*_idx`` methods.

    Args:
        universe_dir (str): The directory path where the universe data is stored.

    """

    def __init__(self, universe_dir):
        self.universe_mapping = []
        self.universe_features = []
        self.universe_context_idx = []
        self.ready = False
        self.universe_dir = universe_dir
        self.timestamp_dataloader = None
        self.universe_selection = []

    def get_status(self) -> bool:
        """Return True once a universe has been loaded through ``set_universe``."""
        return self.ready

    def set_status(self, st: bool):
        """Overwrite the readiness flag."""
        self.ready = st

    def get_features_path(self, selection: list[str] = None) -> list[str]:
        """
        Retrieves the paths of the .h5 feature files.

        The order of the returned paths is whatever ``glob`` yields; it is not
        sorted.

        Args:
            selection (list[str], optional): Names of sub-directories of the
                universe directory to search. Defaults to None, which searches
                every direct sub-directory.

        Returns:
            list[str]: A list of paths to the .h5 feature files.

        Raises:
            ValueError: If no .h5 files are found in the specified directory/directories.
        """
        if selection is None:
            # Exactly one directory level below the universe directory. The
            # previous "**/*.h5" pattern behaved the same way because glob only
            # treats "**" as recursive when recursive=True is passed.
            h5_list = glob.glob(os.path.join(self.universe_dir, "*", "*.h5"))
            if not h5_list:
                raise ValueError(
                    f"No .h5 files found in the universe directory {self.universe_dir}"
                )
            return h5_list

        h5_list = []
        for dir_name in selection:
            h5_files = glob.glob(os.path.join(self.universe_dir, dir_name, "*.h5"))
            if not h5_files:
                raise ValueError(f"No .h5 files found in directory {dir_name}")
            h5_list.extend(h5_files)

        return h5_list

    def load_h5_features(self, file: str) -> dict:
        """
        Load one .h5 feature file.

        Args:
            file (str): Path of the .h5 file.

        Returns:
            dict: The first element of the ``"segments"`` entry when the loaded
            data has one, otherwise the loaded data itself. ``None`` is
            returned (after printing a message) when loading fails for any
            reason, so callers must check for it.
        """
        try:
            file_data = dd.io.load(file)

            if "segments" in file_data:
                return file_data["segments"][0]
            return file_data
        except Exception:
            # Best effort: a missing or unreadable file must not abort a batch.
            print(f"Error: {file} not found or corrupt data.")
            return None

    def get_h5data_dict(self, work_dict):
        """
        Load the features of one universe mapping entry.

        Args:
            work_dict (dict): Entry with the keys ``"obra"`` and ``"fonograma"``.

        Returns:
            The result of ``load_h5_features`` for the matching file.
        """
        return self.get_h5data(work_dict["obra"], work_dict["fonograma"])

    def get_h5data(self, work_id, track_id):
        """
        Load the features stored at ``<universe_dir>/<work_id>/<track_id>.h5``.

        Returns:
            The result of ``load_h5_features`` for that file.
        """
        file = os.path.join(self.universe_dir, work_id, track_id + ".h5")
        return self.load_h5_features(file)

    def batch_load_h5data(self, list_files):
        """Not implemented yet; currently does nothing and returns None."""
        pass

    def load_universe(self, selection: list[str] = None) -> bool:
        """
        Appends one mapping entry per .h5 file found for the given selection.

        Existing entries are NOT cleared, so calling this twice accumulates
        entries; use ``set_universe`` to start from a clean state.

        Args:
            selection (list[str], optional): Names of sub-directories of the
                universe directory to load. Defaults to None (all of them).

        Returns:
            bool: True when the mapping has been extended.

        Raises:
            ValueError: If no universe files data is found.
        """
        universe_features_path: list[str] = self.get_features_path(selection)

        if not universe_features_path:
            raise ValueError("No universe files data found.")

        for file in universe_features_path:
            work_dir, file_name = os.path.split(file)
            work_id = os.path.basename(work_dir)
            # Strip only the trailing ".h5" so that track names containing
            # dots survive and get_h5data() can rebuild the same file name.
            track_id = os.path.splitext(file_name)[0]
            self.universe_mapping.append({"obra": work_id, "fonograma": track_id})
            self.universe_selection.append(work_id)

        return True

    def set_universe(self, selection: list[str] = None) -> None:
        """
        Resets the loader state, loads the universe and sets the readiness status.

        Args:
            selection (list[str], optional): List of universe selections. Defaults to None.

        Returns:
            None. Query ``get_status()`` to find out whether a non-empty
            universe was loaded.

        Raises:
            ValueError: Propagated from ``load_universe`` when nothing is found.
        """
        self.ready = False
        self.universe_mapping = []
        self.universe_features = []
        self.universe_context_idx = []
        self.universe_selection = []
        if self.load_universe(selection) and self.universe_mapping:
            self.ready = True

    def get_universe(self):
        """Return the universe mapping (same list as ``get_universe_mapping``)."""
        return self.universe_mapping

    def get_universe_idx(self):
        """Return a ``range`` over every index of the universe mapping."""
        return range(len(self.universe_mapping))

    def get_universe_mapping(self):
        """Return the list of ``{"obra", "fonograma"}`` mapping entries."""
        return self.universe_mapping

    def _search_restricted_universe(self, universe_context=None):
        """
        Store and return the indices whose ``"obra"`` is in *universe_context*.

        A ``None`` context matches nothing and yields an empty list.
        """
        self.universe_context_idx = []
        if universe_context is None:
            return self.universe_context_idx

        for idx, u in enumerate(self.universe_mapping):
            if u["obra"] in universe_context:
                self.universe_context_idx.append(idx)

        return self.universe_context_idx

    def get_restricted_universe(
        self, universe_context_by_id=None, universe_context_by_idx=None
    ):
        """
        Return the indices and mapping entries of a restricted universe.

        Args:
            universe_context_by_id: Container of work ids ("obra") to keep.
                Takes precedence when both arguments are given.
            universe_context_by_idx: Iterable of universe indices to keep.

        Returns:
            tuple: ``(indices, mapping_entries)``. With no restriction at all
            the indices are a ``range`` over the whole universe and the
            entries are the full mapping; otherwise the indices are a list,
            which is also stored in ``universe_context_idx``.
        """
        if universe_context_by_id is None and universe_context_by_idx is None:
            return (
                self.get_universe_idx(),
                self.universe_mapping,
            )

        self.universe_context_idx = []
        universe_restricted_mapping = []

        if universe_context_by_id is not None:
            for idx, u in enumerate(self.universe_mapping):
                if u["obra"] in universe_context_by_id:
                    self.universe_context_idx.append(idx)
                    universe_restricted_mapping.append(u)
        else:
            for idx in universe_context_by_idx:
                self.universe_context_idx.append(idx)
                universe_restricted_mapping.append(self.universe_mapping[idx])

        return (
            self.universe_context_idx,
            universe_restricted_mapping,
        )

    def get_restricted_universe_idx(self, universe_context=None):
        """
        Return the indices of the restricted universe.

        When *universe_context* is given the restriction is recomputed first;
        otherwise the last stored restriction is used. An empty restriction
        falls back to the indices of the whole universe.
        """
        if universe_context is not None:
            self._search_restricted_universe(universe_context=universe_context)

        if not self.universe_context_idx:
            return self.get_universe_idx()

        return self.universe_context_idx

    def get_work_from_idx(self, list_idx):
        """Return the distinct work ids ("obra") at *list_idx*, in first-seen order."""
        works = [self.universe_mapping[idx]["obra"] for idx in list_idx]
        return get_unique_by_order(works)

    def get_fonograma_from_idx(self, list_idx):
        """Return the track ids ("fonograma") at *list_idx*, duplicates included."""
        return [self.universe_mapping[idx]["fonograma"] for idx in list_idx]

In [4]:
# Read the universe location back from the environment and build the loader
# on top of it. The value is None when the variable is not set.
UNIVERSE_DIR = os.environ.get("UNIVERSE_DIR")
dl = DataLoader(UNIVERSE_DIR)

In [5]:
# set_universe() clears the mapping before loading and flips the readiness
# flag, so re-running this cell does not append duplicate entries the way a
# bare load_universe() call does. The final expression shows the flag.
dl.set_universe()
dl.get_status()

True

In [6]:
# Load every feature file of the universe. An entry is None when its file
# could not be read (load_h5_features reports the error and returns None).
features_path = dl.get_features_path()
features_list = list(map(dl.load_h5_features, features_path))

In [7]:
# One CQTNet feature block per loaded file, paired by position with the
# {"obra", "fonograma"} mapping entries that serve as payloads.
# NOTE(review): the pairing assumes get_features_path() returned the files in
# the same order as when the universe mapping was built — confirm.
vectors_cqt = [entry["cqtnet"]["features"] for entry in features_list]
payloads_cqt = dl.get_universe_mapping()

In [8]:
from db.qdrant import Qdrant
# Vector database client handed to the repository below; built with its
# default constructor arguments.
qdrant = Qdrant()

In [9]:
from typing import List, TypedDict, overload
import numpy as np
from db.vector_db import Item, VectorDB

class Filter(TypedDict):
    """Payload filter accepted by ``CQTNetRepository.get``.

    Uses the same keys as the universe mapping entries that are stored as
    payloads.
    """

    # Work identifier.
    obra: str
    # Track identifier.
    fonograma: str


class CQTNetRepository:
    """Repository that stores and queries CQTNet vectors in one collection.

    Every method delegates to the injected vector database, always against
    ``collection_name``.

    Args:
        vector_db (VectorDB): Backend that performs the actual storage.
    """

    def __init__(self, vector_db: VectorDB):
        self.vector_db = vector_db
        self.collection_name = "cqtnet_test"  # TODO: change to the correct name
        # Vector size passed to the backend when the collection is created.
        self.dim = 300

    def try_create_collection(self):
        """Ask the backend to create the collection with ``dim``-sized vectors."""
        self.vector_db.try_create_collection(self.collection_name, self.dim)

    def add(self, cqtnets: np.ndarray, payloads=None):
        """Store the vectors of every item, one backend call per item.

        Args:
            cqtnets: Iterable of items; each item is itself an iterable of
                vectors.
            payloads: Optional sequence with one payload per item. The payload
                of an item is attached to each of its vectors. Falsy payloads
                are stored as no payload.
        """
        for i, cqtnet in enumerate(cqtnets):
            vectors = list(cqtnet)
            payload = payloads[i] if payloads else None
            if payload:
                # The backend receives one payload per vector.
                payload = [payload for _ in vectors]
            self.vector_db.add(self.collection_name, vectors, payload)

    def search(self, queries: np.ndarray, top_k=5, filter: dict = None):
        """Lazily search the collection, yielding one backend result per query.

        This is a generator: no search is executed until it is iterated.

        Args:
            queries: Iterable of query items; each item is an iterable of
                vectors that are searched in a single backend call.
            top_k: Forwarded to the backend as ``top_k``.
            filter: Optional filter forwarded to the backend.

        Yields:
            Whatever ``vector_db.search`` returns for each query item.
        """
        for cqtnet in queries:
            vectors = list(cqtnet)
            yield self.vector_db.search(
                self.collection_name, vectors, top_k=top_k, filter=filter
            )

    # The overloads use the implementation's parameter name so the typed
    # signatures only advertise keywords that the method really accepts.
    @overload
    def get(self, id_or_filter: str) -> List[Item]: ...

    @overload
    def get(self, id_or_filter: Filter, top_k: int = ...) -> List[Item]: ...

    def get(self, id_or_filter=None, top_k=5):
        """Fetch stored items by id or by payload filter.

        Args:
            id_or_filter: Either an item id (``str``) or a ``Filter`` dict.
            top_k: Forwarded to the backend together with ``id_or_filter``.

        Returns:
            Whatever ``vector_db.get`` returns.
        """
        return self.vector_db.get(self.collection_name, id_or_filter, top_k)

    def delete(self, id: str):
        """Delete the item with the given id from the collection."""
        self.vector_db.delete(self.collection_name, id)

    def delete_collection(self):
        """Delete the whole collection from the backend."""
        self.vector_db.delete_collection(self.collection_name)


In [10]:
# Repository bound to the Qdrant client created above.
cqtnet_repository = CQTNetRepository(qdrant)

In [11]:
# Ask the backend to create the repository's collection (name and vector
# dimension come from the repository itself).
cqtnet_repository.try_create_collection()

In [12]:
# Ingest a small sample: items 20..34, each with its matching payload. Naming
# the slice once keeps vectors and payloads aligned.
sample = slice(20, 35)
cqtnet_repository.add(vectors_cqt[sample], payloads=payloads_cqt[sample])

In [13]:
# search() is a generator, so nothing is queried yet; the searches run when
# `results` is iterated. Items 23..25 were part of the ingested sample.
results = cqtnet_repository.search(vectors_cqt[23:26])

In [14]:
# Iterating consumes the generator (it cannot be iterated a second time).
# `result` stays bound to the last query's result after the loop and is read
# by the next cell, so the loop variable must keep this name.
for result in results:
    print(result)

[[{id: 4c4c34b1-1fa2-4876-81e6-20fd8d07e60b, score: 0.99999994, vector: None, payload: {'obra': '1673194', 'fonograma': '1727267'}}, {id: b59b8d28-509b-48ee-b934-9ca916aa476c, score: 0.9517004, vector: None, payload: {'obra': '1673194', 'fonograma': '1727267'}}, {id: 55364e19-70e7-4dfe-952a-837bcbb5b782, score: 0.8960926, vector: None, payload: {'obra': '1673194', 'fonograma': '1727267'}}, {id: 3dc4483b-2a60-4775-a55b-cd1f825694c6, score: 0.84135884, vector: None, payload: {'obra': '1673194', 'fonograma': '1727267'}}, {id: ecee8186-0408-4189-b3b5-0366466d8cb6, score: 0.8377396, vector: None, payload: {'obra': '1673194', 'fonograma': '1727267'}}], [{id: b59b8d28-509b-48ee-b934-9ca916aa476c, score: 1.0, vector: None, payload: {'obra': '1673194', 'fonograma': '1727267'}}, {id: 4c4c34b1-1fa2-4876-81e6-20fd8d07e60b, score: 0.9517004, vector: None, payload: {'obra': '1673194', 'fonograma': '1727267'}}, {id: 55364e19-70e7-4dfe-952a-837bcbb5b782, score: 0.9444559, vector: None, payload: {'obra

In [15]:
# Id of the first hit for the first vector of the last query printed above
# (presumably result is a list of hit lists, one per vector — confirm).
# NOTE(review): this shadows the builtin `id`; rename it together with the
# next cell, which reads it.
id = result[0][0].id

In [16]:
# Fetch the stored item by its id (a string argument selects the id lookup).
cqtnet_repository.get(id)

[{id: 9812297c-0464-40b3-9a7b-ab3b823135de, score: None, vector: [-9.17224100e-02 -9.69308400e-03  4.33622400e-02 -7.77479730e-03
   4.89704570e-02 -3.35638100e-02 -1.09387140e-01 -1.53065920e-04
   5.70040640e-02 -9.26600900e-03 -4.02580870e-02 -4.71739000e-03
  -9.63377650e-02 -5.09774400e-02  4.36813830e-02  4.13660670e-02
   7.56352400e-02 -2.65072200e-02  8.31943200e-02 -2.18100080e-03
  -1.34259640e-02 -6.78108700e-03  9.29658300e-02  7.39799300e-02
  -7.70657000e-02 -4.17729470e-02 -7.51776570e-03 -1.21131100e-01
   7.99212100e-02 -4.04765870e-02 -2.53900160e-02  3.02027800e-02
   2.40766220e-02 -9.33104500e-03  6.99430150e-03  2.55661180e-02
   2.84533850e-02  1.03225420e-02 -3.91824580e-02 -1.94073300e-02
   5.43707720e-02 -4.35805660e-02  1.59331730e-02  3.63925400e-02
   7.36118900e-02 -4.60129980e-02  5.41084370e-02  1.22749210e-02
   1.02709670e-02  7.28259700e-02 -2.71089280e-02 -1.39628570e-02
  -2.69793510e-02  2.28989660e-02 -8.90322300e-02  1.81170440e-02
   5.0239350

In [17]:
# Fetch stored items by payload instead of by id. Index 22 lies inside the
# ingested 20..34 sample, so matching vectors exist in the collection.
payload: dict = payloads_cqt[22]
cqtnet_repository.get(payload)

[{id: 00d83a95-95e8-4cd1-9da6-cd5472a998b3, score: None, vector: [-0.0150219   0.07021485  0.08314963  0.01428978  0.01516576  0.02844776
   0.05345831  0.02578808  0.0498589  -0.14448293  0.0137181   0.0653075
  -0.00776858 -0.08801954 -0.0922954  -0.00430782 -0.01737774  0.04016379
  -0.04978291 -0.04494949  0.00597721  0.04220469 -0.04387295  0.0636512
  -0.05103396 -0.08027597 -0.01315975 -0.03198038 -0.01486111 -0.07701579
   0.07442997  0.02023323  0.05519397  0.05580831  0.0005034   0.01877271
   0.01383358 -0.04276371  0.0059994  -0.02038506 -0.01179786 -0.00391527
   0.02032257 -0.06717927 -0.00491828  0.0399592   0.02526927  0.05434679
  -0.02861935 -0.0005604   0.00843222 -0.08073065 -0.00351661  0.01639787
  -0.08886606 -0.02292683  0.06591406 -0.06820139  0.02366526 -0.1551224
   0.02772722 -0.01929058 -0.02713663  0.03105416  0.03066763  0.05215802
   0.00413685  0.06692016  0.04534233  0.00712938  0.0533129  -0.00250684
   0.04959437 -0.05574548 -0.00133383 -0.05168218  

In [18]:
# Clean up: drop the test collection and everything ingested above.
cqtnet_repository.delete_collection()

## TODO: FeaturesDbLoader

In [None]:
class FeaturesDbLoader:
    """Placeholder for a loader built on a DataLoader and feature repositories.

    Not implemented yet: the constructor accepts its collaborators but does
    not store or use them.
    """

    def __init__(
        self,
        data_loader: DataLoader,
        cqtnet_repository: CQTNetRepository,
        coverhunter_repository: ...,
    ):
        """Accept the collaborators; intentionally a no-op for now."""
        pass