# viz

> visualization routines


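**NOTE:** Heavy or optional dependencies (`umap`, `cuml`, `cupy`, `sklearn`, `plotly`) are imported lazily, inside the functions that use them.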

In [None]:
#| default_exp viz

In [None]:
#| hide
from nbdev.showdoc import *


In [None]:
#| export
import torch
import numpy as np
import wandb
import gc

## UMAP

In [None]:
#| export
def cpu_umap_project(embeddings, n_components=3, n_neighbors=15, min_dist=0.1, random_state=42):
    "Reduce embeddings to `n_components` dimensions with the `umap` package, working on CPU/numpy data"
    import umap
    if isinstance(embeddings, torch.Tensor):
        # torch tensors are detached, moved to CPU and converted to numpy before fitting
        embeddings = embeddings.detach().cpu().numpy()
    umap_args = dict(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    )
    return umap.UMAP(**umap_args).fit_transform(embeddings)

In [None]:
#| export
def cuml_umap_project(embeddings, n_components=3, n_neighbors=15, min_dist=0.1, random_state=42):
    "Reduce embeddings to `n_components` dimensions with cuML's UMAP (GPU); returns a numpy array"
    from cuml import UMAP
    import cupy as cp
    gpu_input = embeddings
    if isinstance(gpu_input, torch.Tensor):
        # hand torch tensors to CuPy through DLPack
        gpu_input = cp.from_dlpack(gpu_input.detach())
    umap_model = UMAP(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    )
    projected = umap_model.fit_transform(gpu_input)
    del umap_model  # drop the fitted model before converting the result
    # plotly wants numpy, so convert the result back
    return cp.asnumpy(projected)

In [None]:
#| export
def umap_project(embeddings, **kwargs):
    """Project embeddings with UMAP, preferring the GPU (cuML) path and falling back to CPU.

    `kwargs` are forwarded unchanged to `cuml_umap_project` / `cpu_umap_project`.
    The CPU routine is used when the GPU stack cannot be imported or when the GPU
    path runs out of memory; any other error from the GPU path is propagated.
    """
    try:
        return cuml_umap_project(embeddings, **kwargs)
    except (ImportError, MemoryError, torch.cuda.OutOfMemoryError):
        # ImportError: cuml / cupy are not installed on this machine.
        # torch.cuda.OutOfMemoryError: torch's own allocator ran out of GPU memory.
        # MemoryError: NOTE(review) CuPy/RMM allocators are expected to report
        # exhaustion as MemoryError subclasses -- confirm against installed versions.
        torch.cuda.empty_cache()  # release cached blocks before the CPU attempt
        return cpu_umap_project(embeddings, **kwargs)

## PCA

In [None]:
#| export
def cuml_pca_project(embeddings, n_components=3):
    "Reduce embeddings to `n_components` principal components with cuML's PCA (GPU); returns a numpy array"
    from cuml import PCA
    import cupy as cp
    gpu_input = embeddings
    if isinstance(gpu_input, torch.Tensor):
        # hand torch tensors to CuPy through DLPack
        gpu_input = cp.from_dlpack(gpu_input.detach())
    pca_model = PCA(n_components=n_components)
    projected = pca_model.fit_transform(gpu_input)
    return cp.asnumpy(projected)

In [None]:
#| export
def cpu_pca_project(embeddings, n_components=3):
    "Reduce embeddings to `n_components` principal components with scikit-learn's PCA (CPU)"
    from sklearn.decomposition import PCA
    if isinstance(embeddings, torch.Tensor):
        # torch tensors are detached, moved to CPU and converted to numpy before fitting
        embeddings = embeddings.detach().cpu().numpy()
    pca_model = PCA(n_components=n_components)
    return pca_model.fit_transform(embeddings)

In [None]:
#| export
def pca_project(embeddings, **kwargs):
    """Project embeddings with PCA, preferring the GPU (cuML) path and falling back to CPU.

    `kwargs` are forwarded unchanged to `cuml_pca_project` / `cpu_pca_project`.
    Any ordinary exception from the GPU path (missing packages, GPU errors, ...)
    triggers the CPU fallback; errors from the CPU path are propagated.
    """
    try:
        return cuml_pca_project(embeddings, **kwargs)
    except Exception:
        # Deliberate best-effort fallback. `except Exception` (instead of a bare
        # `except:`) lets KeyboardInterrupt and SystemExit propagate, so Ctrl-C
        # during the GPU attempt no longer silently starts a CPU PCA.
        return cpu_pca_project(embeddings, **kwargs)

## 3D Plotly Scatterplots

In [None]:
#| export
def plot_embeddings_3d(coords, num_tokens=None, color_by='pairs', file_idx=None, title='Embeddings', debug=False):
    """3D scatter plot of embeddings; returns a plotly `Figure`.

    coords:     2D array indexed as `coords[:, 0]`, `coords[:, 1]`, `coords[:, 2]` (x, y, z)
    num_tokens: not used by this function; kept (now optional) so existing positional
                calls keep working
    color_by:   'none'  -> every point blue
                'file'  -> color by `file_idx` (blue if `file_idx` is None)
                'pairs' -> point i and point i + n//2 share one random color
    file_idx:   optional per-point file ids; used for 'file' coloring and as hover text
    title:      figure title; `, n=<number of points>` is appended
    debug:      if True, print the number of points

    Raises ValueError if `color_by` is not one of the values above.
    """
    import plotly.graph_objects as go
    n = len(coords)
    if debug: print(" plot_embeddings_3d: n =",n)

    if color_by == 'none':
        colors = ['blue'] * n
    elif color_by == 'file':
        if file_idx is None:
            colors = ['blue'] * n
        else:
            # accept tensors / arrays (which have .tolist()) as well as plain sequences
            colors = file_idx.tolist() if hasattr(file_idx, 'tolist') else list(file_idx)
    elif color_by == 'pairs':
        # at least one color, so a single point does not hit `i % 0`
        n_pairs = max(1, n // 2)
        pair_colors = [
            f'rgb({np.random.randint(0,256)},{np.random.randint(0,256)},{np.random.randint(0,256)})'
            for _ in range(n_pairs)
        ]
        colors = [pair_colors[i % n_pairs] for i in range(n)]
    else:
        raise ValueError(f"Unknown color_by: {color_by}")

    hover_text = [f"fileid: {int(fid)}" for fid in file_idx] if file_idx is not None else None

    # explicit rgb strings are used for 'pairs', so no colorscale is set in that mode
    marker = dict(size=4, color=colors, colorscale='Viridis' if color_by != 'pairs' else None, opacity=0.8)
    fig = go.Figure(data=[go.Scatter3d(
        x=coords[:,0], y=coords[:,1], z=coords[:,2],
        mode='markers',
        marker=marker,
        hovertext=hover_text,
        hoverinfo='text' if hover_text else 'x+y+z'
    )])
    title = title + f', n={n}'
    fig.update_layout(title=title, margin=dict(l=0, r=0, b=0, t=30))
    return fig


## Main Routine 

Calls the preceding routines

In [None]:
#| export
def _make_emb_viz(zs, epoch, title='Embeddings', do_umap=True, file_idx=None):
    """Project embeddings to 3D, plot them, and log the plots to wandb as HTML.

    zs:       embeddings to project (passed to `umap_project` / `pca_project`)
    epoch:    shown in the plot titles and used as the wandb `step`
    title:    prefix for the plot titles and for the wandb keys
              (`"<title> UMAP"`, `"<title> PCA"`)
    do_umap:  if False, only the PCA plot is produced and logged
    file_idx: optional per-point file ids forwarded to `plot_embeddings_3d`
    """
    cuda_ok = torch.cuda.is_available()  # torch.cuda.synchronize() raises without CUDA

    fig = None
    if do_umap:
        coords = umap_project(zs)
        # plot_embeddings_3d does not use num_tokens; pass it explicitly so the call is valid
        fig = plot_embeddings_3d(coords, num_tokens=None, title=title+f' (UMAP), epoch {epoch}', file_idx=file_idx)

    # cleanup before PCA or else you get CUDA errors
    if cuda_ok: torch.cuda.synchronize()
    gc.collect()

    coords = pca_project(zs)
    fig2 = plot_embeddings_3d(coords, num_tokens=None, title=title+f' (PCA), epoch {epoch}', file_idx=file_idx)

    payload = {}
    if fig is not None:
        payload[f"{title} UMAP"] = wandb.Html(fig.to_html())
    payload[f"{title} PCA"] = wandb.Html(fig2.to_html())
    wandb.log(payload, step=epoch)

    # cleanup again
    if cuda_ok: torch.cuda.synchronize()
    gc.collect()


In [None]:
#| export
def _subsample(data, indices, max_points):
    "Subsample data and indices together"
    perm = torch.randperm(len(data))[:max_points]
    return data[perm], indices[perm] if indices is not None else None

In [None]:
#| export
def make_emb_viz(zs, model, num_tokens, epoch, title='Embeddings', max_points=8192, pmask=None, file_idx=None):
    """Main routine: visualize several groups of embeddings and log them via `_make_emb_viz`.

    zs:         embeddings, laid out as `num_tokens` consecutive rows per item with the
                CLS token first (every `num_tokens`-th row is treated as a CLS token)
    model:      moved to CPU (followed by `torch.cuda.empty_cache()`) while plotting and
                moved to `zs.device` afterwards, even if plotting raises
    num_tokens: number of rows of `zs` per item (CLS + patches)
    epoch:      forwarded to `_make_emb_viz`
    title:      appended to each group's name ('CLS Tokens', 'RND Patches', 'RND Empty Patches')
    max_points: maximum number of patch embeddings plotted per patch group
    pmask:      optional mask; `pmask[:, 1:]` is flattened and used to split patches into
                non-empty (truthy) and empty ones. Presumably column 0 is the CLS token
                -- TODO confirm against caller. If None, only CLS tokens are plotted.
    file_idx:   optional per-row file ids, aligned with `zs`
    """
    device = zs.device
    model.to('cpu')
    torch.cuda.empty_cache()

    try:
        # CLS tokens
        cls_tokens = zs[::num_tokens]
        cls_file_idx = file_idx[::num_tokens] if file_idx is not None else None
        # _make_emb_viz takes (zs, epoch, ...); it has no num_tokens parameter
        _make_emb_viz(cls_tokens, epoch, title='CLS Tokens'+title, file_idx=cls_file_idx)

        # Patches (non-CLS)
        patch_mask = torch.arange(len(zs)) % num_tokens != 0
        patch_only = zs[patch_mask]
        patch_file_idx = file_idx[patch_mask] if file_idx is not None else None

        if pmask is not None:
            patch_pmask = pmask[:, 1:].flatten().bool()
            print(f"Non-empty patches: {patch_pmask.sum()}/{len(patch_pmask)} ({patch_pmask.float().mean()*100:.1f}%)")

            # (group name, selection mask, whether to also run UMAP)
            groups = (
                ('RND Patches', patch_pmask, True),         # non-empty patches
                ('RND Empty Patches', ~patch_pmask, False), # empty patches: PCA only
            )
            for group_name, keep, do_umap in groups:
                group_patches = patch_only[keep]
                group_file_idx = patch_file_idx[keep] if patch_file_idx is not None else None
                rnd_patches, rnd_file_idx = _subsample(group_patches, group_file_idx, max_points)
                if len(rnd_patches) == 0:
                    # nothing to project; the projection routines cannot fit on empty input
                    print(f"make_emb_viz: no points for '{group_name}', skipping")
                    continue
                _make_emb_viz(rnd_patches, epoch, title=group_name+title, do_umap=do_umap, file_idx=rnd_file_idx)
    finally:
        # always put the model back, even if visualization failed
        model.to(device)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()