In [None]:
#| default_exp configuration

In [None]:
#| hide
# Development convenience: have IPython re-import edited modules before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export

from typing import Literal, Any
from typing_extensions import Self 
from pathlib import Path
from importlib.metadata import version
from uuid import uuid4
from glob import glob

from pydantic import (
    Field, BaseModel,
    model_validator, field_validator, ValidationInfo
)
from pydantic_settings import (
    BaseSettings,
    SettingsConfigDict,
    TomlConfigSettingsSource,
    CliApp,
    CliSuppress
)

import clip_plot # for version

In [None]:
#| export

class Paths(BaseModel):
    # Input and output locations for a run.
    #
    # `images`, `tables` and `metadata` accept a single path, a directory, a glob
    # pattern, or a list of any of those; `expand_paths` turns them into a flat
    # list of `Path` objects before pydantic validates the field.
    # `tables` and `metadata` may not both be set (see `check_table_vs_meta`).
    images: list[Path] = Field(default_factory=list, description="Folder or glob that expands to input images")
    tables: list[Path] | None = Field(None,
                               description="Path/folder/glob to table(s) with image_path, hidden_vector_path cols")
    metadata: list[Path] | None = Field(None,
                               description="Path/folder/glob to table(s) with image_path cols")
    table_id: str = Field(default_factory=lambda: str(uuid4()), description="Identifier for table output")
    # NOTE: this default is resolved once, against the working directory at the
    # time the class body is executed, not when an instance is created.
    output_dir: Path = Field((Path()/"clipplot_output").resolve(),
            description="Directory for output data files and viewer")
    table_format: Literal["parquet", "csv"] = Field("parquet",
                                                    description="Format for output table, `csv` or `parquet`")


    @field_validator("images", "tables", "metadata", mode="before")
    @classmethod
    def expand_paths(cls, value: str | list | Path | None, info: ValidationInfo) -> Any:
        """Expand directories and glob patterns into a flat list of paths.

        Args:
            value: ``None``, a single path-like value, or a list/tuple of them.
            info: Validation context; ``info.field_name`` selects which file
                extensions are collected when an entry is a directory.

        Returns:
            ``None`` when ``value`` is ``None``; otherwise a list of ``Path``
            objects. An empty list stays an empty list.

        Each entry is handled independently:
          * an entry containing ``*`` is expanded with ``glob(..., recursive=True)``;
          * an existing directory is walked recursively and filtered by suffix
            (image suffixes for ``images``, table suffixes otherwise);
          * anything else is kept as a single path, whether or not it exists.
        """
        if value is None:
            return None

        # Normalise to a list of entries so that empty and multi-element lists
        # take the same route as a lone value. (Previously `[]` was stringified
        # to the bogus path "[]" and longer lists skipped expansion entirely.)
        entries = list(value) if isinstance(value, (list, tuple)) else [value]

        if info.field_name == "images":
            exts = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif'}
        else:
            exts = {'.json', '.csv', '.parquet'}

        expanded: list[Path] = []
        for entry in entries:
            text = str(entry)
            candidate = Path(text)
            if "*" in text:
                expanded.extend(Path(p) for p in glob(text, recursive=True))
            elif candidate.is_dir():
                expanded.extend(p for p in candidate.rglob('*') if p.suffix.lower() in exts)
            else:
                expanded.append(candidate)
        return expanded

    @model_validator(mode='after')
    def check_table_vs_meta(self) -> Self:
        """Reject configurations that set both ``metadata`` and ``tables``.

        Raises:
            ValueError: if both fields are not ``None``.
        """
        if self.metadata is not None and self.tables is not None:
            raise ValueError("'metadata' and 'tables' are mutually exclusive.")
        return self

In [None]:
#| export

class UmapSpec(BaseModel):
    # Settings for the dimensionality-reduction step.
    # `n_neighbors` and `min_dist` may be given as a scalar or a collection;
    # the `to_list` validator below always hands pydantic a list.
    n_neighbors: int | list[int] = Field(default=[15], description="Number of neighbors in UMAP")
    min_dist: float | list[float] = Field(default=[0.1], description="Minimum distance in UMAP")
    metric: CliSuppress[str] = Field(default="correlation", description="Metric argument for UMAP")
    reducer: Literal["umap", "localmap", "pacmap"] = Field(
        default="umap",
        description="Dimensionality reduction algorithm.",
    )

    @field_validator("n_neighbors", "min_dist", mode="before")
    @classmethod
    def to_list(cls, value: str | list | int | float) -> list:
        """Coerce the raw input into a list before field validation.

        Lists pass through unchanged, tuples and sets are converted with
        ``list()``, and any other value is wrapped in a one-element list.
        """
        if isinstance(value, list):
            return value
        if isinstance(value, (tuple, set)):
            return list(value)
        return [value]

In [None]:
#| export

class ClusterSpec(BaseModel):
    # Clustering parameters; every field is wrapped in CliSuppress.
    # NOTE(review): the default n_clusters (12) is larger than the default
    # max_clusters (10) — confirm whether the two values are meant to be related.
    n_clusters: CliSuppress[int] = Field(default=12)
    max_clusters: CliSuppress[int] = Field(default=10)
    min_cluster_size: CliSuppress[int] = Field(default=20)

In [None]:
#| export

# Image loading / atlas options, used as the nested `image_opts` model of `Cfg`.
# Derives from BaseModel (like the other nested models in this module) rather
# than BaseSettings: a BaseSettings sub-model loads its own values from the
# environment when instantiated, using un-prefixed names such as SEED or
# SHUFFLE, independently of the parent settings class.
class ImageLoaderOptions(BaseModel):
    seed: CliSuppress[int | None] = Field(42, description="Seed for reproducible transforms")
    shuffle: CliSuppress[bool] = Field(False, description="Shuffle images before creating viewer")
    cell_size: CliSuppress[int] = Field(64, description="Size of cell in viewer atlas")
    lod_cell_height: CliSuppress[int] = Field(128)
    atlas_size: CliSuppress[int] = Field(4096, description="Size for atlases")

In [None]:
#| export

# Viewer customisation, used as the nested `view_opts` model of `Cfg`.
# Derives from BaseModel rather than BaseSettings so that instantiating it does
# not separately read un-prefixed environment variables (LOGO, TAGLINE);
# this also matches the other nested models in this module.
class ViewerOptions(BaseModel):
    logo: CliSuppress[None | Path] = Field(None, description="Path to custom logo")
    tagline: CliSuppress[None | str] = Field(None, description="Custom tagline for viewer")


In [None]:
#| export

# Top-level settings object. Values come from the pydantic-settings sources
# configured in `model_config` below: environment variables prefixed with
# CLIPPLOT_ and command-line arguments (parsed on instantiation, unknown
# arguments ignored).
class Cfg(BaseSettings):
    thumbnail_size: int = Field(128, description="Size of images in main clip-plot view")
    model: str = Field("timm/convnext_tiny.dinov3_lvd1689m",
                            description="Model name on huggingface.co/models")
    # Nested models use default_factory so each Cfg builds its own instances.
    # A shared `Paths()` default would be created once when the class body runs,
    # which froze `Paths.table_id` to a single UUID for every Cfg that did not
    # receive explicit `paths`.
    umap_spec: UmapSpec = Field(default_factory=UmapSpec)
    cluster_spec: ClusterSpec = Field(default_factory=ClusterSpec)
    # Looked up once, when the class body is executed.
    clipplot_version: str = Field(version(clip_plot.__name__), description="Version of clipplot")
    plot_id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for plot")
    paths: Paths = Field(default_factory=Paths)
    view_opts: ViewerOptions = Field(default_factory=ViewerOptions)
    image_opts: ImageLoaderOptions = Field(default_factory=ImageLoaderOptions)
    image_path_col: str = Field("image_path", description="Name of column with paths to images")
    vectors_col: str = Field("hidden_vectors", description="Name of column with hidden vectors i.e. embeddings")

    model_config = SettingsConfigDict(
        env_prefix = "CLIPPLOT_",
        cli_parse_args = True,
        use_attribute_docstrings = True,
        cli_prog_name = "clipplot",
        cli_hide_none_type = True,
        cli_ignore_unknown_args=True,
        # pyproject_toml_table_header=(),
    )

In [None]:
#| hide

# Smoke check: build a Cfg without explicit arguments and display it as a plain dict.
# NOTE(review): the output saved with this cell looks stale relative to the
# definitions above — it lists `vectors_path_col` (Cfg defines `vectors_col`) and
# scalar `n_neighbors`/`min_dist` (UmapSpec coerces both to lists). Re-run to refresh.
cfg = Cfg()
cfg.model_dump()

{'thumbnail_size': 128,
 'model': 'timm/convnext_tiny.dinov3_lvd1689m',
 'umap_spec': {'n_neighbors': 15,
  'min_dist': 0.1,
  'metric': 'correlation',
  'reducer': 'umap'},
 'cluster_spec': {'n_clusters': 12,
  'max_clusters': 10,
  'min_cluster_size': 20},
 'clipplot_version': '0.1.4',
 'plot_id': '0bfd3038-7429-467e-8ed8-e858cba52d63',
 'paths': {'images': [],
  'tables': None,
  'metadata': None,
  'table_id': 'a1685e7b-a571-47b5-9d98-c5fffbca31e7',
  'output_dir': PosixPath('/home/wsanger/git/clip-plot/nbs/clipplot_output'),
  'table_format': 'parquet'},
 'view_opts': {'logo': None, 'tagline': None},
 'image_opts': {'seed': 42,
  'shuffle': False,
  'cell_size': 64,
  'lod_cell_height': 128,
  'atlas_size': 4096},
 'image_path_col': 'image_path',
 'vectors_path_col': 'hidden_vectors_path'}

In [None]:
#|hide
# Run nbdev's export step for this notebook (module name set by `#| default_exp` in the first cell).
import nbdev; nbdev.nbdev_export()