In [None]:
#| default_exp to_emb_atlas

In [None]:
#| hide

%load_ext autoreload
%autoreload 2

In [None]:
#| export

import atexit
import zipfile
from pathlib import Path
from tempfile import TemporaryDirectory as TmpDir
from typing import Literal

import daft
import embedding_atlas.cli as emb_atlas_cli
import numpy as np
import polars as pl
from daft.functions import to_struct
from PIL import Image

from clip_plot.images import resize_to_max_side
from clip_plot.utils import timestamp


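# Export to Embedding Atlas

1. Concatenate all data into a single dataframe.
2. Load the images and encode them as JPEG bytes (an inline preview plus a full-size copy saved next to the viewer).
3. Run the Embedding Atlas export.
4. Unzip the exported viewer bundle when the process exits.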

- `image_local_path` is relative to the output (viewer) folder, so it can be turned into a link with `<a>`.
- TODO: check the Hugging Face example to see why its images render properly.
- Open question: should we drop `_row_index`?
- `embed_path` appears in the test output only because it is present in the input; not a problem.

In [None]:
#| export

@daft.func
def relative_image_path(img_path: str, viewer_dir: Path) -> str:
    """Return `img_path` expressed relative to `viewer_dir`, as a string.

    `Path.relative_to` raises ValueError when `img_path` does not lie under
    `viewer_dir`; that error is not caught here.
    """
    image_path = Path(img_path)
    relative = image_path.relative_to(viewer_dir)
    return str(relative)

In [None]:
#| export

@daft.func
def resize_for_preview(image: np.ndarray, max_side: int = 128,
                       mode: Literal["RGB", "RGBA", "L"] = "RGB") -> Image.Image:
    """Build a downscaled preview of one decoded image.

    Args:
        image: decoded pixel array of a single image.
        max_side: size limit forwarded to `resize_to_max_side`
            (presumably the longest side, in pixels).
        mode: PIL mode the image is converted to before resizing.

    Returns:
        The PIL image returned by `resize_to_max_side`.
    """
    # PIL's `Image.fromarray` cannot build an image from a uint8 array shaped
    # (H, W, 1); drop a trailing singleton channel axis so single-channel
    # images become a plain (H, W) array. NOTE(review): confirm whether daft
    # hands grayscale images to UDFs with that trailing axis.
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]
    im = Image.fromarray(image).convert(mode)
    return resize_to_max_side(im, max_side)

NameError: name 'daft' is not defined

In [None]:
#| export

def load_images(df: daft.DataFrame, image_path_col: str,
                viewer_dir: Path, mode: Literal["RGB", "RGBA", "L"],
                preview_size: int = 128) -> daft.DataFrame:
    """Attach image previews to `df` and copy each full-size image into the viewer.

    Builds a lazy daft pipeline that, per row, downloads the file named in
    `image_path_col` (failed downloads become null), decodes it, builds a
    preview with `resize_for_preview`, converts the full-size image to `mode`,
    re-encodes both as JPEG, and uploads the full-size JPEG to
    `viewer_dir / "jpg_originals" / <file name>`.

    Args:
        df: input rows; must contain `image_path_col`.
        image_path_col: column holding the source image paths/URLs.
        viewer_dir: viewer output directory; `jpg_originals/` is created inside it.
        mode: image mode for both the full-size image and the preview.
        preview_size: forwarded to `resize_for_preview` as `max_side`.

    Returns:
        A DataFrame whose leading columns are `image_preview` (struct with
        `bytes` = preview JPEG and `path` = the value of `image_path_col`),
        `image_local_path` (`jpg_originals/<file name>`, relative to
        `viewer_dir`) and `image_path_col`, followed by the remaining columns
        in sorted order.
    """
    originals_dir = viewer_dir / "jpg_originals"
    originals_dir.mkdir(exist_ok=True, parents=True)

    # Per daft's docs, `image.decode()` without a `mode` keeps each file's own
    # mode (e.g. an RGBA PNG stays RGBA), so decoded images are NOT guaranteed
    # to be RGB. The full-size image is therefore converted to `mode` for every
    # target mode, including the default "RGB". The preview needs no extra
    # conversion: `resize_for_preview` already converts it to `mode`.
    df = (
        df
        .with_column("image_bytes", daft.col(image_path_col).url.download(on_error="null"))
        .with_column("image_fullsize", daft.col("image_bytes").image.decode())
        .with_column("image_preview",
                     resize_for_preview(daft.col("image_fullsize"), preview_size, mode))
        .with_column("image_fullsize",
                     daft.functions.convert_image(daft.col("image_fullsize"), mode=mode))
    )

    # Bare file name, accepting both "/" and "\" as path separators.
    df = (
        df
        .with_column("image_name", daft.col(image_path_col).str.replace("\\", "/").str.split("/"))
        .with_column("image_name", daft.col("image_name").list.get(-1))
    )

    # `destination` is the upload target for the full-size copy;
    # `image_local_path` is the same file relative to `viewer_dir`.
    df = (
        df
        .with_column("destination", daft.functions.concat(
            daft.lit(str(originals_dir) + "/"), daft.col("image_name")))
        .with_column("image_local_path", daft.functions.concat(
            daft.lit(originals_dir.name + "/"), daft.col("image_name")))
    )

    # NOTE(review): JPEG has no alpha channel; confirm `encode("JPEG")` accepts
    # images when mode="RGBA".
    df = (
        df
        .with_column("image_preview_bytes", daft.col("image_preview").image.encode("JPEG"))
        .with_column("image_fullsize_bytes", daft.col("image_fullsize").image.encode("JPEG"))
        # TODO: put None in for duplicate image names
        # and handle appropriate in upload (don't upload to orig folder, which is flat)
        # NOTE(review): the plan is lazy and `saved_path` is excluded below;
        # confirm daft does not prune this upload before the parquet write.
        .with_column("saved_path",
                     daft.col("image_fullsize_bytes").url.upload(daft.col("destination")))
    )

    # Pack the preview as {bytes, path} and drop the intermediate columns.
    # NOTE(review): `image_fullsize` is not excluded, so decoded full-size
    # images are carried into the output as well -- confirm that is intended.
    df = (
        df
        .with_column("image_preview", to_struct(bytes=daft.col("image_preview_bytes"),
                                                path=daft.col(image_path_col)))
        .exclude("image_bytes", "image_preview_bytes", "image_fullsize_bytes", "image_name",
                 "saved_path", "destination")
    )

    front_cols = ["image_preview", "image_local_path", image_path_col]
    others = sorted(c for c in df.column_names if c not in front_cols)
    keep = front_cols + others
    print(f"Columns to keep: {keep}")
    return df.select(*keep)

In [None]:
#| export

def unzip_atlas(zip_path: Path, extract_target: Path, delete_after: bool = False):
    """Extract the exported viewer bundle at `zip_path` into `extract_target`.

    `run_emb_atlas` registers this with `atexit`, so it runs at interpreter exit.

    Args:
        zip_path: zip archive to extract.
        extract_target: directory the archive contents are written into.
        delete_after: when True, remove `zip_path` once extraction succeeds.
    """
    print(timestamp(), f"Extracting the viewer bundle to \n{extract_target.as_posix()}")
    with zipfile.ZipFile(zip_path, mode="r") as bundle:
        bundle.extractall(extract_target)
    if not delete_after:
        return None
    zip_path.unlink()
    return None

In [None]:
#| export

def run_emb_atlas(parquet_path: Path, zip_path: Path,
                  viewer_dir: Path, temp_dir: TmpDir):
    """Export an Embedding Atlas viewer for `parquet_path`; finish up at exit.

    Runs the `embedding_atlas` CLI in-process with `emb_x`/`emb_y` as the x/y
    columns and `--export-application zip_path`. Unzipping into `viewer_dir`,
    cleanup of `temp_dir`, and the completion message are deferred to
    interpreter exit via `atexit`.

    Args:
        parquet_path: input table handed to the CLI.
        zip_path: where the CLI writes the exported viewer bundle.
        viewer_dir: directory the bundle is extracted into at exit.
        temp_dir: temporary directory cleaned up at exit, after extraction
            (presumably the one holding `parquet_path` and `zip_path`).
    """
    def _report_finished():
        # Evaluated when the hook runs, so the timestamp marks completion
        # rather than the moment the hook was registered.
        print(f"{timestamp()} Finished creating atlas viewer at {viewer_dir}")

    # atexit runs hooks in reverse registration order, so at exit the sequence
    # is: unzip_atlas -> temp_dir.cleanup -> _report_finished.
    # The hooks are registered before the CLI runs, so they fire at exit even
    # if the export below raises.
    atexit.register(_report_finished)
    atexit.register(temp_dir.cleanup)
    atexit.register(unzip_atlas, zip_path=zip_path, extract_target=viewer_dir)
    emb_atlas_cli.main.main(
        args=[parquet_path.as_posix(),
              "--x", "emb_x", "--y", "emb_y",
              "--export-application", zip_path.as_posix(),
              ],
        standalone_mode=False,
    )
    return None

NameError: name 'Path' is not defined

In [None]:
#| export

def create_emb_atlas(table: pl.DataFrame, image_path_col: str,
                     viewer_dir: Path, plot_id: str,
                     mode: Literal["RGB", "RGBA", "L"] = "RGB",
                     preview_size: int = 128) -> None:
    """Build an Embedding Atlas viewer for `table` inside `viewer_dir`.

    The polars table is converted to daft and passed through `load_images`.
    The result is written as `viewer-input-<plot_id>.parquet` into a fresh
    temporary directory. That path, its `.zip` sibling and the temporary
    directory are then passed to `run_emb_atlas`, which registers the
    directory's cleanup.
    """
    source = daft.from_arrow(table.to_arrow())
    prepared = load_images(source, image_path_col, viewer_dir, mode, preview_size)

    staging = TmpDir()
    parquet_path = Path(staging.name) / f"viewer-input-{plot_id}.parquet"
    zip_path = parquet_path.with_suffix(".zip")

    prepared.write_parquet(parquet_path, write_mode="overwrite")
    run_emb_atlas(parquet_path,
                  zip_path=zip_path,
                  viewer_dir=viewer_dir,
                  temp_dir=staging)
    return None

NameError: name 'pl' is not defined

In [None]:
#| hide

from nbdev import nbdev_export

# Run nbdev's export step, which writes the `#| export` cells out to their modules.
nbdev_export()