In [None]:
#| default_exp to_emb_atlas

In [None]:
#| hide

%load_ext autoreload
%autoreload 2

In [None]:
#| export

import atexit
import zipfile
from pathlib import Path

import daft
import embedding_atlas.cli as emb_atlas_cli
import polars as pl

from clip_plot.utils import timestamp

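# Export to Embedding Atlas

1. Concatenate all data into a single dataframe.
2. Load the images and encode them as base64 JPEG data-URI strings.
3. Run Embedding Atlas to export the viewer application as a zip archive.
4. Unzip the exported viewer bundle when the Python process exits.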

In [None]:
#| export

def load_images(df: daft.DataFrame, image_path_col: str) -> daft.DataFrame:
    """Download the images referenced in `image_path_col` and embed them as data URIs.

    Each path/URL in `image_path_col` is downloaded, decoded, re-encoded as JPEG,
    base64-encoded and stored in an ``image`` column as a
    ``data:image/jpeg;base64,...`` string. The intermediate columns are dropped;
    every other column (including `image_path_col`) is kept in its original order.

    Args:
        df: input dataframe.
        image_path_col: name of the column holding image paths or URLs.

    Returns:
        The dataframe with an ``image`` data-URI column added (or replaced).
    """
    tmp_cols = {"image_jpeg_bytes", "image_jpeg_base64"}
    df = (
        df
        # `url.download()` yields raw bytes; decode them into an image so they
        # can be re-encoded as JPEG in the next step.
        .with_column("image", daft.col(image_path_col).url.download().image.decode())
        .with_column("image_jpeg_bytes", daft.col("image").image.encode("JPEG"))
        .with_column("image_jpeg_base64", daft.col("image_jpeg_bytes").encode("base64").decode("utf-8"))
        .with_column(
            "image",
            daft.functions.concat(daft.lit("data:image/jpeg;base64,"),
                                  daft.col("image_jpeg_base64")),
        )
    )
    # Keep column order deterministic: drop only the temporary columns.
    keep = [c for c in df.columns if c not in tmp_cols]
    return df.select(*keep)

In [None]:
#| export

def unzip_atlas(atlas_zip: Path, delete_after: bool = True):
    """Extract an exported Embedding Atlas viewer bundle next to its zip archive.

    Meant to run at interpreter exit (it is registered with `atexit` by
    `run_emb_atlas`). The archive is extracted into
    ``<zip parent>/<zip stem>_atlas-viewer-bundle/``.

    Args:
        atlas_zip: path to the exported zip archive (a ``str`` is also accepted).
        delete_after: if True, delete the zip archive after extracting it.

    Returns:
        None. If `atlas_zip` does not exist (e.g. the export failed, or an earlier
        handler already extracted and deleted it) a message is printed and
        nothing is extracted.
    """
    atlas_zip = Path(atlas_zip)
    if not atlas_zip.is_file():
        # An exception raised inside an atexit handler only produces a traceback
        # at shutdown; a failed export or duplicate registration should not.
        print(timestamp(), f"No viewer bundle found at \n{atlas_zip.as_posix()}, skipping extraction")
        return None
    extract_target = atlas_zip.parent / f"{atlas_zip.stem}_atlas-viewer-bundle/"
    print(timestamp(), f"Extracting the viewer bundle to \n{extract_target.as_posix()}")
    with zipfile.ZipFile(atlas_zip, 'r') as zip_ref:
        zip_ref.extractall(extract_target)
    if delete_after:
        atlas_zip.unlink()
    return None

In [None]:
#| export

def run_emb_atlas(parquet_path: Path, atlas_zip: Path,
                  x_col: str = "emb_x", y_col: str = "emb_y"):
    """Run the Embedding Atlas CLI on `parquet_path` and export the viewer app.

    The `unzip_atlas` handler is registered with `atexit` *before* the CLI is
    invoked, so the exported zip is extracted when the Python process exits.
    NOTE(review): presumably this is because the CLI call may not return
    normally (e.g. it keeps serving until interrupted) — confirm.

    Args:
        parquet_path: input parquet file handed to the CLI.
        atlas_zip: path the viewer application zip is exported to.
        x_col: column used for the x coordinate (``--x``).
        y_col: column used for the y coordinate (``--y``).

    Returns:
        None.
    """
    atexit.register(unzip_atlas, Path(atlas_zip))
    emb_atlas_cli.main.main(
        # click parses a list of strings; Path objects must be converted first.
        args=[
            Path(parquet_path).as_posix(),
            "--x", x_col,
            "--y", y_col,
            "--export-application", Path(atlas_zip).as_posix(),
        ],
        standalone_mode=False,
    )
    return None

In [None]:
#| export

def create_emb_atlas(table: pl.DataFrame, image_path_col: str,
                     plot_id: str, output_path: Path) -> None:
    """Build an Embedding Atlas viewer for `table`.

    Converts the polars table to a daft dataframe and embeds the images from
    `image_path_col` as data URIs (`load_images`). It writes the result to
    ``<output_path>/viewer-input-<plot_id>.parquet``, then exports the viewer
    application to ``<output_path>/emb-atlas-<plot_id>.zip`` (`run_emb_atlas`).
    That zip is extracted when the process exits.

    The table is expected to carry the ``emb_x``/``emb_y`` coordinate columns
    that `run_emb_atlas` uses by default.

    Args:
        table: input data, one row per point.
        image_path_col: column holding image paths or URLs.
        plot_id: identifier used in the output file names.
        output_path: directory for the outputs (created if missing).

    Returns:
        None.
    """
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    df = daft.from_arrow(table.to_arrow())
    df = load_images(df, image_path_col)
    parquet_path = output_path / f"viewer-input-{plot_id}.parquet"
    # NOTE(review): daft's `write_parquet` may treat this path as a root
    # directory of part files rather than a single file — confirm the atlas
    # CLI reads it correctly.
    df.write_parquet(parquet_path)
    atlas_zip = output_path / f"emb-atlas-{plot_id}.zip"
    run_emb_atlas(parquet_path, atlas_zip)

    return None

In [None]:
#| hide

import nbdev

# Regenerate the library module(s) from the `#| export` cells via nbdev.
nbdev.nbdev_export()