# viz

> Visualization routines: UMAP/PCA projection of embeddings and 3D plotly scatterplots


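**NOTE:** The heavy optional dependencies (umap, cuML, CuPy, scikit-learn, plotly) are imported lazily inside the functions that use them, so importing this module does not require them.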

In [None]:
#| default_exp viz

In [None]:
#| hide
from nbdev.showdoc import *


In [None]:
#| export
import torch
import numpy as np
import wandb
import gc

## UMAP

In [None]:
#| export
def cpu_umap_project(embeddings, n_components=3, n_neighbors=15, min_dist=0.1, random_state=42):
    """Reduce `embeddings` to `n_components` dimensions with umap-learn on the CPU.

    Torch tensors are detached and copied to a host NumPy array first; any other
    input is handed to UMAP unchanged. Returns the result of `UMAP.fit_transform`.
    """
    import umap  # imported lazily, like the other optional dependencies here
    data = embeddings.detach().cpu().numpy() if isinstance(embeddings, torch.Tensor) else embeddings
    model = umap.UMAP(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    )
    return model.fit_transform(data)

In [None]:
#| export
def cuml_umap_project(embeddings, n_components=3, n_neighbors=15, min_dist=0.1, random_state=42):
    """Project `embeddings` to `n_components` dimensions with cuML's GPU UMAP.

    Torch tensors are detached, moved to the current CUDA device if they are on
    the CPU, and shared with CuPy via DLPack. Other inputs (e.g. NumPy or CuPy
    arrays) are passed to cuML unchanged. The coordinates are returned as a NumPy
    array so they can go straight to plotly.
    """
    from cuml import UMAP
    import cupy as cp
    if isinstance(embeddings, torch.Tensor):
        t = embeddings.detach()
        # CuPy arrays live in GPU memory, so give DLPack a CUDA tensor; a CPU
        # tensor (e.g. from the demo cell) cannot be imported as-is.
        if not t.is_cuda: t = t.cuda()
        embeddings = cp.from_dlpack(t)
    reducer = UMAP(n_components=n_components, n_neighbors=n_neighbors, min_dist=min_dist, random_state=random_state)
    coords = reducer.fit_transform(embeddings)
    del reducer  # drop our reference to the fitted model before returning
    return cp.asnumpy(coords)  # back to numpy for plotly

In [None]:
#| export
def umap_project(embeddings, **kwargs):
    """Project `embeddings` with UMAP, preferring the cuML GPU implementation.

    Falls back to `cpu_umap_project` when CUDA is unavailable, when cuML/CuPy
    cannot be imported, or when the GPU runs out of memory. `kwargs` are
    forwarded unchanged to whichever implementation runs.
    """
    if not torch.cuda.is_available():
        return cpu_umap_project(embeddings, **kwargs)
    try:
        return cuml_umap_project(embeddings, **kwargs)
    except (ImportError, MemoryError, torch.cuda.OutOfMemoryError):
        # ImportError: cuML/CuPy not installed. MemoryError covers CuPy/RMM
        # out-of-memory errors; torch's OOM is a RuntimeError subclass, listed separately.
        torch.cuda.empty_cache()
        return cpu_umap_project(embeddings, **kwargs)

## PCA

In [None]:
#| export
def cuml_pca_project(embeddings, n_components=3):
    """Project `embeddings` to `n_components` dimensions with cuML's GPU PCA.

    Torch tensors are detached, moved to the current CUDA device if they are on
    the CPU, and shared with CuPy via DLPack. Other inputs are passed to cuML
    unchanged. Returns the coordinates as a NumPy array.
    """
    from cuml import PCA
    import cupy as cp
    if isinstance(embeddings, torch.Tensor):
        t = embeddings.detach()
        # CuPy arrays live in GPU memory, so give DLPack a CUDA tensor.
        if not t.is_cuda: t = t.cuda()
        embeddings = cp.from_dlpack(t)
    coords = PCA(n_components=n_components).fit_transform(embeddings)
    return cp.asnumpy(coords)

In [None]:
#| export
def cpu_pca_project(embeddings, n_components=3):
    """Reduce `embeddings` to `n_components` principal components with scikit-learn (CPU).

    Torch tensors are detached and converted to a host NumPy array before fitting;
    other inputs are passed to `PCA.fit_transform` unchanged.
    """
    from sklearn.decomposition import PCA
    data = embeddings
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    pca = PCA(n_components=n_components)
    return pca.fit_transform(data)

In [None]:
#| export
def pca_project(embeddings, **kwargs):
    """Project `embeddings` with PCA, preferring cuML on the GPU and falling back to sklearn.

    `kwargs` are forwarded unchanged to whichever implementation runs.
    """
    try:
        return cuml_pca_project(embeddings, **kwargs)
    except Exception:
        # Deliberately broad: missing cuML/CuPy, no GPU, OOM and CUDA runtime
        # errors all mean "use the CPU path". A bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit, which must propagate.
        return cpu_pca_project(embeddings, **kwargs)

## 3D Plotly Scatterplots

In [None]:
#| export
def plot_embeddings_3d(coords, num_tokens, color_by='pairs', file_idx=None, title='Embeddings', debug=False):
    """3D plotly scatter plot of the first three columns of `coords`.

    color_by: 'none' (all blue), 'file' (color by `file_idx`, all blue if it is None),
    or 'pairs' (rows 2k and 2k+1 share a random color). Hover text shows the file id
    when `file_idx` is given and, for 'pairs', the pair number. `num_tokens` is
    currently unused. Returns the plotly Figure; raises ValueError for an unknown
    `color_by`.
    """
    import plotly.graph_objects as go
    n = len(coords)
    if debug: print(" plot_embeddings_3d: n =",n)

    if color_by == 'none':     colors = ['blue'] * n
    elif color_by == 'file':   colors = file_idx.tolist() if file_idx is not None else ['blue'] * n
    elif color_by == 'pairs':
        # Round up so an unpaired trailing point (odd n) still gets a color;
        # n // 2 colors would make pair_colors[(n-1)//2] raise IndexError.
        n_pairs = (n + 1) // 2
        pair_colors = [f'rgb({np.random.randint(0,256)},{np.random.randint(0,256)},{np.random.randint(0,256)})' for _ in range(n_pairs)]
        colors = [pair_colors[i // 2] for i in range(n)]  # pairs are adjacent in index-space
    else: raise ValueError(f"Unknown color_by: {color_by}")

    hover_text = [f"file_id: {int(fid)}" for fid in file_idx] if file_idx is not None else None
    if color_by == 'pairs':
        hover_text = [f"pair {i//2}" for i in range(n)] if hover_text is None else [f"{s}, pair {i//2}" for i, s in enumerate(hover_text)]

    fig = go.Figure(data=[go.Scatter3d(
        x=coords[:,0], y=coords[:,1], z=coords[:,2],
        mode='markers',
        # a colorscale only matters for numeric colors ('file'); 'pairs' uses explicit rgb strings
        marker=dict(size=4, color=colors, colorscale='Viridis' if color_by != 'pairs' else None, opacity=0.8),
        hovertext=hover_text, hoverinfo='x+y+z+text' if hover_text else 'x+y+z'
    )])
    title = title + f', n={n}'
    fig.update_layout(title=title, margin=dict(l=0, r=0, b=0, t=30))
    return fig


## Main Routine 

These functions tie together the preceding projection (UMAP/PCA) and plotting routines.

In [None]:
#| export
def _make_emb_viz(zs, num_tokens, epoch=-1, title='Embeddings', do_umap=True, file_idx=None):
    """Build a PCA (and optionally a UMAP) 3D figure for `zs` and log them to wandb.

    Returns `(pca_fig, umap_fig)`; `umap_fig` is None when `do_umap` is False.
    When a wandb run is active, the figures are logged as HTML under the keys
    "<title> UMAP" / "<title> PCA" at step `epoch`.
    """
    def _flush():
        # Wait for queued CUDA work and collect garbage; per the original note,
        # PCA hits CUDA errors without this step.
        if torch.cuda.is_available(): torch.cuda.synchronize()
        gc.collect()

    umap_fig = None
    if do_umap:
        umap_fig = plot_embeddings_3d(umap_project(zs), num_tokens,
                                      title=title+f' (UMAP), epoch {epoch}', file_idx=file_idx)
    _flush()
    pca_fig = plot_embeddings_3d(pca_project(zs), num_tokens,
                                 title=title+f' (PCA), epoch {epoch}', file_idx=file_idx)

    if wandb.run is not None:
        payload = {}
        if do_umap:
            payload[f"{title} UMAP"] = wandb.Html(umap_fig.to_html())
        payload[f"{title} PCA"] = wandb.Html(pca_fig.to_html())
        wandb.log(payload, step=epoch)

    _flush()
    return pca_fig, umap_fig


In [None]:
#| export
def _subsample(data, indices, max_points):
    "Subsample data and indices together, in pairs"
    perm1 = torch.randperm(len(data)//2)[:max_points//2]
    perm2 = perm1 + len(data)//2
    perm = torch.cat([perm1,perm2])
    return data[perm], indices[perm] if indices is not None else None

In [None]:
#| export
def make_emb_viz(zs,
                num_tokens, epoch=-1,
                model=None,
                title='Embeddings',
                max_points=5000,
                pmask=None,
                file_idx=None,
                do_umap=True):
    """Main routine: plot CLS-token and patch-token embeddings from `zs`.

    Rows of `zs` are taken in blocks of `num_tokens`, with the CLS token first in each
    block (`zs[::num_tokens]`). If `model` is given, it is moved to the CPU while plotting
    and moved back to `zs.device` afterwards, even if plotting raises. If `file_idx` is
    shorter than `zs`, it is expanded with `repeat_interleave`.

    When `pmask` is given, nonzero entries of `pmask[:, 1:]` mark non-empty patches. Two
    extra plots are then made from random subsamples of at most `max_points` rows:
    non-empty patches (PCA plus optional UMAP) and empty patches (PCA only). A group
    with nothing to plot is skipped.

    Returns a dict with keys 'cls_pca_fig', 'cls_umap_fig', 'patch_pca_fig',
    'patch_umap_fig' and 'empty_pca_fig'; figures that were not produced are None.
    """
    device = zs.device
    if model is not None: model.to('cpu')  # presumably to free GPU memory for the projections
    torch.cuda.empty_cache()
    # Defaults for the patch plots, which are only produced when `pmask` is given
    # (without these the return below raised NameError for pmask=None).
    patch_pca_fig = patch_umap_fig = empty_pca_fig = None
    try:
        if file_idx is not None and file_idx.shape[0] < zs.shape[0]:
            # expand per-sample ids to one id per row of zs
            file_idx = file_idx.repeat_interleave(zs.shape[0]//file_idx.shape[0]).to(device)

        # CLS tokens
        cls_tokens = zs[::num_tokens]
        cls_file_idx = file_idx[::num_tokens] if file_idx is not None else None
        cls_pca_fig, cls_umap_fig = _make_emb_viz(cls_tokens, num_tokens, epoch=epoch, title='CLS Tokens '+title, file_idx=cls_file_idx, do_umap=do_umap)

        if pmask is not None:
            # Patches (non-CLS)
            patch_mask = torch.arange(len(zs)) % num_tokens != 0
            patch_only = zs[patch_mask]
            patch_file_idx = file_idx[patch_mask] if file_idx is not None else None

            patch_pmask = pmask[:, 1:].flatten().bool()
            print(f"Non-empty patches: {patch_pmask.sum()}/{len(patch_pmask)} ({patch_pmask.float().mean()*100:.1f}%)")

            # Non-empty patches
            valid_patches = patch_only[patch_pmask]
            valid_file_idx = patch_file_idx[patch_pmask] if patch_file_idx is not None else None
            rnd_patches, rnd_file_idx = _subsample(valid_patches, valid_file_idx, max_points)
            if len(rnd_patches) > 0:  # projecting an empty set would fail
                patch_pca_fig, patch_umap_fig = _make_emb_viz(rnd_patches, num_tokens, epoch=epoch, title='RND Patches '+title, file_idx=rnd_file_idx, do_umap=do_umap)

            # Empty patches (PCA only)
            empty_patches = patch_only[~patch_pmask]
            empty_file_idx = patch_file_idx[~patch_pmask] if patch_file_idx is not None else None
            rnd_empty, rnd_empty_idx = _subsample(empty_patches, empty_file_idx, max_points)
            if len(rnd_empty) > 0:
                # _make_emb_viz returns (pca_fig, umap_fig); keep only the PCA figure
                empty_pca_fig, _ = _make_emb_viz(rnd_empty, num_tokens, epoch=epoch, title='RND Empty Patches '+title, do_umap=False, file_idx=rnd_empty_idx)
    finally:
        # NOTE(review): restores to zs.device, assumed to be where the model started
        if model is not None: model.to(device)

    figs = {'cls_pca_fig':cls_pca_fig, 'cls_umap_fig':cls_umap_fig, 'patch_pca_fig':patch_pca_fig, 'patch_umap_fig':patch_umap_fig, 'empty_pca_fig': empty_pca_fig}
    return figs

Testing visualization:

In [None]:
#| eval: false
# Demo on synthetic data. Relies on `torch` from the `#| export` import cell above; run that cell first.
import plotly.io as pio
pio.renderers.default = 'notebook'

# bs samples x num_tokens tokens (token 0 plays the role of CLS) x dim features
bs, num_tokens, dim = 32, 65, 256
z1 = torch.randn([bs, num_tokens, dim])
file_idx = torch.arange(bs)
z2 = z1 +  0.1*torch.randn([bs, num_tokens, dim]) # z2 is a slightly perturbed copy of z1

# Row order after stack(dim=1) + reshape: [s0 z1 (num_tokens rows), s0 z2, s1 z1, s1 z2, ...],
# so the CLS rows zs[::num_tokens] alternate z1/z2 of the same sample (adjacent pairs).
#zs = torch.cat([z1, z2], dim=0).view(-1, dim)  # flatten to [64*65, 256]
zs = torch.stack([z1, z2], dim=1).reshape(-1, z1.shape[-1])

# one file id per sample, repeated over that sample's 2*num_tokens rows
file_idx = file_idx.repeat_interleave(zs.shape[0]//file_idx.shape[0])  

# pmask rows follow the (sample, view) order of zs; column 0 is the CLS slot, which make_emb_viz drops
pmask = torch.ones([2*bs, num_tokens])  # all ones
pmask[:, 30:] = 0  # mark roughly half the patches as empty

figs = make_emb_viz(zs,  num_tokens, title='testing', pmask=pmask, file_idx=file_idx, do_umap=False) 
figs['patch_pca_fig'].show()

NameError: name 'torch' is not defined

In [None]:
#| hide
# Run nbdev's exporter so the `#| export` cells are written out to the library module.
import nbdev; nbdev.nbdev_export()