In [25]:
# get file size of this notebook

import os
print(f"Notebook file size: {os.path.getsize('vlm_embed_train_data.ipynb') / 1024:.2f} KB")

Notebook file size: 565.48 KB


In [None]:
import numpy as np

data = np.load('../data/vlm_embed/iiif_no_text_embedding_matrix.npy', allow_pickle=True)

## Common diagnostics when you’re training a **variational autoencoder (VAE)**

---

### **Top-left: `latent/kl_var`**
- This shows the **KL divergence term** of your VAE’s loss, often averaged across dimensions.
- It measures how close your approximate posterior \( q_\phi(z|x) \) is to the prior \( p(z) \) (usually \( \mathcal{N}(0,I) \)).
- **Interpretation:**  
  - Starts low (the encoder is ignoring the latent variables).  
  - Rises as the encoder begins to use the latent space.  
  - A good balance is important: too low → posterior collapse, too high → poor reconstructions.
- The upward trend then plateauing and slightly decreasing suggests the model is learning to balance KL and reconstruction terms.

---

### **Top-right: `latent/logvar`**
- This is the distribution of the **log-variance outputs** from the encoder.  
- Each encoder latent dimension predicts a mean (`mu`) and log-variance (`logvar`), defining the Gaussian distribution from which you sample \( z \).  
- **Interpretation:**  
  - Spikes around zero mean log-variance ≈ 0 → variance ≈ 1 (close to prior).  
  - Spread distributions show how much uncertainty the encoder is predicting for each latent dimension.

---

### **Bottom-left: `latent/mu`**
- This is the distribution of the **means** predicted by the encoder for each latent variable.  
- Ideally, they should cluster around zero if the KL regularization is working well (since prior is \( N(0, I) \)).  
- If they drift too far away from zero, the KL term will increase to pull them back.

---

### **Bottom-right: `latent/z_sample`**
- These are actual **samples of latent variables \( z \)**, drawn as \( z = \mu + \sigma \cdot \epsilon \) where \( \epsilon \sim \mathcal{N}(0, I) \).  
- This shows how the *effective latent space* looks during training.  
- Should roughly follow a Gaussian distribution centered at zero, though skew or multimodality can appear if the VAE is strongly encoding structured information.

---

✅ **Putting it all together:**  
- `latent/mu` + `latent/logvar` tell you what your encoder is outputting.  
- `latent/z_sample` shows what you actually pass to the decoder.  
- `latent/kl_var` tells you how far this is from the prior and whether your KL term is being respected.  

---

## How you actually use the latent variable \(z\) once the VAE is trained.  

---

### 🔹 Recap: What the encoder produces
For each input \(x\), the encoder outputs:
- \( \mu(x) \) → mean of the approximate posterior  
- \( \log \sigma^2(x) \) → log-variance  

So the latent posterior is:
\[
q_\phi(z|x) = \mathcal{N}(z; \mu(x), \sigma^2(x) I)
\]

And the reparameterization trick samples:
\[
z = \mu(x) + \sigma(x) \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
\]

---

### 🔹 Two ways of using \(z\) at inference
1. **Posterior mean (deterministic)**  
   - Just take \( z = \mu(x) \).  
   - This ignores variance and gives the “most likely” latent code.  
   - Useful if you want **stable reconstructions** (less noise).  
   - Downside: you aren’t really sampling, so you might lose generative diversity.

2. **Stochastic sample (sample)**  
   - Actually draw \( z \) using the full distribution \( \mu(x), \sigma(x) \).  
   - This adds noise according to the learned uncertainty.  
   - More faithful to the probabilistic nature of the VAE.  
   - Useful for **generation**, data augmentation, or exploring diversity.  

---

### 🔹 When to use which
- **Reconstruction tasks** → use posterior mean (\(z = \mu(x)\)) for sharper, more stable outputs.  
- **Generative tasks / sampling** → use stochastic \(z\) so you capture the randomness and diversity.  
- **Evaluation (e.g. ELBO)** → always stochastic, because that’s how the model was trained.  

# Embedding Data

In [4]:
import numpy as np

data_infer = '../data/vlm_embed/iiif_no_text_embedding_matrix.npy'


X = np.load(data_infer, mmap_mode='r')
# X = np.load(data_infer).astype(np.float32)
# X_subset = X[np.random.choice(X.shape[0], 100000, replace=False)]


# VAE Latents

In [5]:
from wc_simd.vlm_embed_vae import VAE3DWrapper

sample = False
model_checkpoint = '../runs/vlm_embed_vae3d_hires_1/vae3d.pt'
wrapper = VAE3DWrapper(model_checkpoint)
Z = wrapper.to3d(
    X, use_mu=(not sample),
    batch_size=64 * 1024)

# construct output filename from model and input
output_filename = f'../data/vlm_embed/iiif_no_text_embedding_matrix_{model_checkpoint.split("/")[-2]}.npy'
np.save(output_filename, Z)
del Z, wrapper  # free memory

In [6]:
import numpy as np

# data_file = output_filename
data_file = "../data/vlm_embed/iiif_no_text_embedding_matrix_vlm_embed_vae3d_hires_1.npy"
# data_file = "../data/vlm_embed/iiif_no_text_embedding_matrix_vlm_embed_vae3d_light_8.npy"
# data_file = "../data/vlm_embed/iiif_no_text_embedding_matrix_vlm_embed_ae3d_light_2.npy"

data = np.load(data_file, allow_pickle=True)

# Subset

In [7]:

subset_idx = np.random.choice(data.shape[0], 10000, replace=False)

In [9]:
# Save subset_idx for later use
np.save(data_file.replace('.npy', '_subset_idx.npy'), subset_idx)

In [10]:
# Load subset_idx if needed
subset_idx = np.load(data_file.replace('.npy', '_subset_idx.npy'))

In [11]:
subset = data[subset_idx]
del data

# Images

In [3]:
import pandas as pd

df_image_indices = pd.read_parquet(
    '../data/vlm_embed/iiif_no_text_embedding_index.parquet')
df_image_indices = df_image_indices.set_index('row_index')
df_image_indices['image_id'].to_list()

['https://iiif.wellcomecollection.org/image/b19314164_0270.jp2/full/1338,1519/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b18597452_0033.jp2/full/1338,1636/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b33537525_0232.jp2/full/1338,2047/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b32274683_0049.jp2/full/1338,1671/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b28705804_0043.jp2/full/1338,2217/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b30634672_0189.jp2/full/1338,1697/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b33585489_0249.jp2/full/1338,2037/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b28705804_0366.jp2/full/1338,2102/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b28704940_0519.jp2/full/1338,1525/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b19191728_0174.jp2/full/1338,1629/0/default.jpg',
 'https://iiif.wellcomecollection.org/image/b19371810_0133.j

In [12]:
# select the image indices for the subset
subset_images = df_image_indices.loc[subset_idx].values
subset_images = np.reshape(subset_images, (-1))
subset_images

array(['https://iiif.wellcomecollection.org/image/b14755117_0076.jp2/full/1338,1353/0/default.jpg',
       'https://iiif.wellcomecollection.org/image/b18758897_0179.jp2/full/1338,1856/0/default.jpg',
       'https://iiif.wellcomecollection.org/image/b33249519_0001.jp2/full/1338,935/0/default.jpg',
       ...,
       'https://iiif.wellcomecollection.org/image/b18758897_0204.jp2/full/1338,1713/0/default.jpg',
       'https://iiif.wellcomecollection.org/image/b18318162_0115.jp2/full/1338,1898/0/default.jpg',
       'https://iiif.wellcomecollection.org/image/b2493625x_RET_3_3_2_16_0570.jp2/full/1024,1017/0/default.jpg'],
      dtype=object)

In [13]:
import umap
import plotly.express as px

centered = subset - subset.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
whitened = centered @ Vt.T / S

# UMAP
umap_model = umap.UMAP(n_neighbors=15, n_components=2, metric='euclidean')
umap_2d_embeddings = umap_model.fit_transform(whitened)
umap_2d_embeddings

array([[ 9.394396  ,  7.2479796 ],
       [ 1.9158536 ,  8.637921  ],
       [-0.25186992, 10.690093  ],
       ...,
       [ 3.639392  ,  7.611991  ],
       [ 0.04486572,  5.220537  ],
       [-4.1493754 ,  7.1644235 ]], dtype=float32)

# Plot with Images

In [16]:
from io import BytesIO
from PIL import Image, ImageOps
import base64
import requests
import requests_cache


from concurrent.futures import ThreadPoolExecutor, as_completed

# Optional: on-disk HTTP cache (pip install requests-cache) to avoid
# re-downloading
try:
    session = requests_cache.CachedSession(
        'img_cache',
        backend='sqlite',
        expire_after=7 * 24 * 3600,  # 7 days
        allowable_methods=('GET',),
        allowable_codes=(200,),
    )
except Exception:
    session = requests.Session()


def embeddable_jpeg(url, max_side=512, timeout=10):
    """Fetch JPEG at URL, scale longest side to max_side (keep aspect), return data:image/jpeg;base64,..."""
    r = session.get(url, timeout=timeout, stream=True)
    r.raise_for_status()

    # Decode + correct EXIF orientation
    img = Image.open(r.raw)
    img = ImageOps.exif_transpose(img).convert("RGB")

    # Scale proportionally: longest side = max_side
    w, h = img.size
    scale = max_side / max(w, h) if max(w, h) > max_side else 1.0
    if scale != 1.0:
        img = img.resize((int(w * scale), int(h * scale)),
                         Image.Resampling.LANCZOS)

    # Encode JPEG efficiently (smaller payloads = faster base64 + render)
    buf = BytesIO()
    img.save(buf, format="JPEG", quality=70, optimize=True, progressive=True)
    return "data:image/jpeg;base64," + \
        base64.b64encode(buf.getvalue()).decode()


def batch_embeddable_jpeg(urls, max_side=512, max_workers=16, timeout=10):
    """Deduplicate, parallelize, and preserve order."""
    # Deduplicate to avoid repeated downloads
    unique_urls = list(dict.fromkeys(urls))
    results_map = {}

    def _wrap(u):
        try:
            return u, embeddable_jpeg(u, max_side=max_side, timeout=timeout)
        except Exception:
            return u, None  # or a tiny placeholder

    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futs = [ex.submit(_wrap, u) for u in unique_urls]
        for fut in as_completed(futs):
            u, val = fut.result()
            results_map[u] = val

    # Re-expand to original order
    return [results_map.get(u) for u in urls]

# Normalize whitened x, y, z to [0, 1] for RGB


def normalize(arr):
    min_val = arr.min()
    max_val = arr.max()
    return (arr - min_val) / (max_val - min_val + 1e-8)

# Convert to hex color
def rgb_to_hex(r, g, b):
    return ['#%02x%02x%02x' % (int(255 * x), int(255 * y), int(255 * z))
            for x, y, z in zip(r, g, b)]

In [17]:
images_df = pd.DataFrame(umap_2d_embeddings, columns=('x', 'y'))
images_df['image'] = batch_embeddable_jpeg(
    subset_images, max_side=512, max_workers=16)
images_df['idx'] = subset_idx

In [18]:
r = normalize(whitened[:, 0])
g = normalize(whitened[:, 1])
b = normalize(whitened[:, 2])

images_df['color'] = rgb_to_hex(r, g, b)

In [19]:
# Write parquet
images_df.to_parquet('../data/vlm_embed/iiif_no_text_umap_2d_vlm_embed_vae3d_hires_1_umap_2d_plot.parquet')

In [20]:
import pandas as pd
# Load the data
images_df = pd.read_parquet(
    '../data/vlm_embed/iiif_no_text_umap_2d_vlm_embed_vae3d_hires_1_umap_2d_plot.parquet')

In [21]:
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from bokeh.palettes import Spectral10

output_notebook()

In [22]:
datasource = ColumnDataSource(images_df)

plot_figure = figure(
    title='UMAP projection of the IIIF non-ocr dataset',
    width=1920,
    height=1080,
    tools=('pan, wheel_zoom, reset')
)

plot_figure.add_tools(HoverTool(tooltips="""
<div>
    <div>
        <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
    </div>
    <div>
        <span style='font-size: 16px; color: #224499'>idx:</span>
        <span style='font-size: 18px'>@idx</span>
    </div>
</div>
"""))


plot_figure.scatter(
    'x',
    'y',
    source=datasource,
    line_alpha=0.5,
    fill_alpha=0.5,
    size=5,
    color='color'
)
# show(plot_figure)

In [23]:
import os
from bokeh.embed import file_html
from bokeh.resources import INLINE

html = file_html(plot_figure, resources=INLINE, title="UMAP projection")
out_path = "iiif_no_text_umap_2d_vlm_embed_vae3d_hires_1_umap_2d_plot.html"
with open(out_path, "w", encoding="utf-8") as f:
    f.write(html)

print(f"Wrote {out_path} (INLINE resources).")
# Optional size check
print("File size (MB):", os.path.getsize(out_path) / (1024 * 1024))

Wrote iiif_no_text_umap_2d_vlm_embed_vae3d_hires_1_umap_2d_plot.html (INLINE resources).
File size (MB): 377.00646018981934


# Whitened Data

In [14]:
import plotly.express as px

centered = subset - subset.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
whitened = centered @ Vt.T / S

fig = px.scatter_3d(
    x=whitened[:, 0],
    y=whitened[:, 1],
    z=whitened[:, 2],
    opacity=0.4)
fig.update_traces(marker=dict(size=2))
fig.update_layout(width=1024,
                  height=768, scene=dict(aspectmode='cube'))
fig.show()

# UMAP (White)

In [15]:
import umap
import plotly.express as px

centered = subset - subset.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
whitened = centered @ Vt.T / S

# UMAP
umap_model = umap.UMAP(n_neighbors=15, n_components=3, metric='euclidean')
umap_embeddings = umap_model.fit_transform(whitened)

fig = px.scatter_3d(
    x=umap_embeddings[:, 0],
    y=umap_embeddings[:, 1],
    z=umap_embeddings[:, 2],
    opacity=0.4)
fig.update_traces(marker=dict(size=2))
fig.update_layout(width=1024,
                  height=768, scene=dict(aspectmode='cube'))
fig.show()

# Sphericalised (White)

In [None]:
import plotly.express as px


centered = subset - subset.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
whitened = centered @ Vt.T / S
whitened /= np.linalg.norm(whitened, axis=1, keepdims=True)
fig = px.scatter_3d(
    x=whitened[:, 0],
    y=whitened[:, 1],
    z=whitened[:, 2],
    opacity=0.4)
fig.update_traces(marker=dict(size=2))
fig.update_layout(width=1024,
                  height=768, scene=dict(aspectmode='cube'))
fig.show()

# Raw

In [4]:
import numpy as np
import plotly.express as px

# Expect subset to be defined already (sampled points)
dim_indices = (0, 1, 2)

x, y, z = (subset[:, i] for i in dim_indices)

# Outlier filtering parameters
lower_q = 0.0
upper_q = 1.0

x_low, x_high = np.quantile(x, [lower_q, upper_q])
y_low, y_high = np.quantile(y, [lower_q, upper_q])
z_low, z_high = np.quantile(z, [lower_q, upper_q])

mask = (
    (x >= x_low) & (x <= x_high) &
    (y >= y_low) & (y <= y_high) &
    (z >= z_low) & (z <= z_high)
)

removed_count = (~mask).sum()
print(f"Removed {removed_count} outlier points out of {subset.shape[0]} ({removed_count / subset.shape[0]:.2%})")

if removed_count > 0:
    outliers = np.column_stack([x[~mask], y[~mask], z[~mask]])
    max_print = 10
    print("Sample outlier points (x,y,z):")
    print(outliers[:max_print])

# Keep only inliers
x_f, y_f, z_f = x[mask], y[mask], z[mask]

# Axis stats BEFORE unification (for debugging / understanding clustering)
for name, arr in zip(['x','y','z'], [x_f, y_f, z_f]):
    print(f"{name}: min={arr.min():.6f} max={arr.max():.6f} span={(arr.max()-arr.min()):.6f} mean={arr.mean():.6f} std={arr.std():.6f}")

# Detect degenerate (near-constant) axes
EPS = 1e-12
spans = np.array([x_f.max()-x_f.min(), y_f.max()-y_f.min(), z_f.max()-z_f.min()])
if np.any(spans < EPS):
    print("WARNING: One or more axes have (near) zero span. Data may be constant or precision-collapsed.")

# Optional standardization to reveal variation if values are very tiny
apply_standardize_if_tiny = True
tiny_threshold = 1e-8
if apply_standardize_if_tiny and np.all(spans < tiny_threshold):
    print("All spans are tiny; applying z-score standardization to x,y,z for visualization only.")
    def safe_std(a):
        s = a.std()
        return s if s > 0 else 1.0
    x_v = (x_f - x_f.mean()) / safe_std(x_f)
    y_v = (y_f - y_f.mean()) / safe_std(y_f)
    z_v = (z_f - z_f.mean()) / safe_std(z_f)
    standardized = True
else:
    x_v, y_v, z_v = x_f, y_f, z_f
    standardized = False
print(f"Using standardized coordinates: {standardized}")

# Scaling strategy options
# 'symmetric': center each axis at its own midpoint but use the MAX half-span (after any standardization)
# 'independent': each axis keeps its own min/max
# 'global': single global min/max across axes (can collapse if data spans are identical tiny intervals)
scaling_mode = 'symmetric'  # change to 'symmetric' if collapse persists

if scaling_mode == 'independent':
    xr, yr, zr = ( [x_v.min(), x_v.max()], [y_v.min(), y_v.max()], [z_v.min(), z_v.max()] )
elif scaling_mode == 'global':
    vals = np.concatenate([x_v, y_v, z_v])
    gmin, gmax = float(vals.min()), float(vals.max())
    if abs(gmax - gmin) < EPS:
        # Expand artificially so Plotly can render a volume
        pad = 1.0 if standardized else 1e-3
        print(f"Global span ~0 (gmin={gmin:.6g}, gmax={gmax:.6g}); padding by +/-{pad}")
        gmin -= pad
        gmax += pad
    xr = yr = zr = [gmin, gmax]
else:  # symmetric
    mins = np.array([x_v.min(), y_v.min(), z_v.min()])
    maxs = np.array([x_v.max(), y_v.max(), z_v.max()])
    centers = (mins + maxs) / 2.0
    half_spans = (maxs - mins) / 2.0
    H = half_spans.max()
    if H < EPS:
        H = 1.0 if standardized else 1e-3
        print(f"Symmetric mode: half-span ~0, padding H={H}")
    xr = [centers[0]-H, centers[0]+H]
    yr = [centers[1]-H, centers[1]+H]
    zr = [centers[2]-H, centers[2]+H]

print(f"Using scaling mode: {scaling_mode}")
print(f"Ranges -> x:{xr} y:{yr} z:{zr}")

fig = px.scatter_3d(x=x_v, y=y_v, z=z_v, opacity=0.5)
fig.update_traces(marker=dict(size=2))
fig.update_layout(
    width=1024,
    height=768,
    scene=dict(
        xaxis=dict(range=xr),
        yaxis=dict(range=yr),
        zaxis=dict(range=zr),
        aspectmode='cube'
    )
)
fig.show()

filtered_indices = np.nonzero(mask)[0]

Removed 0 outlier points out of 10000 (0.00%)
x: min=-0.035323 max=-0.034646 span=0.000677 mean=-0.035014 std=0.000083
y: min=0.110261 max=0.110981 span=0.000719 mean=0.110711 std=0.000075
z: min=-0.132297 max=-0.131648 span=0.000649 mean=-0.132017 std=0.000081
Using standardized coordinates: False
Using scaling mode: symmetric
Ranges -> x:[-0.035344444, -0.034625202] y:[0.11026126, 0.11098051] z:[-0.13233203, -0.13161278]


# UMAP Hi-dim

In [13]:
hd_subset = X[subset_idx]

In [14]:
hd_subset.shape

(10000, 1536)

In [16]:
import umap
import plotly.express as px

hd_centered = hd_subset - hd_subset.mean(axis=0)
U, S, Vt = np.linalg.svd(hd_centered, full_matrices=False)
hd_whitened = hd_centered @ Vt.T / S

# UMAP
hd_umap_model = umap.UMAP(n_neighbors=15, n_components=2, metric='euclidean')
hd_umap_2d_embeddings = hd_umap_model.fit_transform(hd_whitened)
hd_umap_2d_embeddings

array([[10.613586 ,  4.572473 ],
       [10.722154 ,  4.7413826],
       [11.199226 ,  5.150993 ],
       ...,
       [10.678477 ,  4.746382 ],
       [ 7.2929125,  4.8433213],
       [10.973627 ,  5.034195 ]], dtype=float32)

In [None]:
hd_images_df = pd.DataFrame(hd_umap_2d_embeddings, columns=('x', 'y'))
hd_images_df['image'] = batch_embeddable_jpeg(
    subset_images, max_side=512, max_workers=16)
hd_images_df['idx'] = subset_idx

In [None]:
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from bokeh.palettes import Spectral10

output_notebook()

In [None]:
datasource = ColumnDataSource(hd_images_df)

plot_figure = figure(
    title='UMAP projection of the IIIF non-ocr dataset',
    width=1920,
    height=1080,
    tools=('pan, wheel_zoom, reset')
)

plot_figure.add_tools(HoverTool(tooltips="""
<div>
    <div>
        <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
    </div>
    <div>
        <span style='font-size: 16px; color: #224499'>idx:</span>
        <span style='font-size: 18px'>@idx</span>
    </div>
</div>
"""))


plot_figure.scatter(
    'x',
    'y',
    source=datasource,
    line_alpha=0.5,
    fill_alpha=0.5,
    size=5,
    # color='color'
)
# show(plot_figure)

In [None]:
import os
from bokeh.embed import file_html
from bokeh.resources import INLINE

html = file_html(plot_figure, resources=INLINE, title="UMAP projection")
out_path = "iiif_no_text_umap_2d_vlm_embed_umap_2d_plot.html"
with open(out_path, "w", encoding="utf-8") as f:
    f.write(html)

print(f"Wrote {out_path} (INLINE resources).")
# Optional size check
print("File size (MB):", os.path.getsize(out_path) / (1024 * 1024))