![umap in atlas](https://docs.nomic.ai/img/umap-with-nomic-atlas.png)

# Visualizing MNIST Training Dynamics with Nomic Atlas

UMAP is available as a projection in [Nomic Atlas](https://atlas.nomic.ai), which creates interactive maps of your data with AI analysis, vector search APIs, and additional resources like topic label generation.

![mnist embeddings in atlas](https://assets.nomicatlas.com/mnist-training-embeddings-umap-short.gif)

Nomic Atlas automatically generates embeddings for your data and allows you to explore large datasets in a web browser. Atlas provides:

* In-browser analysis of your UMAP data with the [Atlas Analyst](https://docs.nomic.ai/atlas/data-maps/atlas-analyst)
* Vector search over your UMAP data using the [Nomic API](https://docs.nomic.ai/atlas/data-maps/guides/vector-search-over-your-data)
* Interactive features like zooming, recoloring, searching, and filtering in the [Nomic Atlas data map](https://docs.nomic.ai/atlas/data-maps/controls)
* Scalability for millions of data points
* Rich information display on hover
* Shareable UMAPs via URL links to your embeddings and data maps in Atlas

This example demonstrates how to use [Atlas](https://docs.nomic.ai/atlas/embeddings-and-retrieval/guides/using-umap-with-atlas) to visualize the training dynamics of your neural network using embeddings and UMAP.


## Setup

1. Install the Python package with `pip install nomic`

2. Get a Nomic API key [here](https://atlas.nomic.ai/cli-login)

3. Run `nomic login nk-...` in a terminal window, or log in at the top of your code:

```python
import nomic
nomic.login('nk-...')
```
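
To keep the key out of source files, it can be read from the environment before logging in. The following is a minimal sketch, assuming the key is stored in an environment variable named `NOMIC_API_KEY`; both that variable name and the `read_nomic_api_key` helper are illustrative and not part of the `nomic` package.

```python
import os


def read_nomic_api_key(env_var="NOMIC_API_KEY"):
    """Return the API key stored in `env_var` (hypothetical variable name)."""
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"Set {env_var} to your Nomic API key (nk-...) first.")
    return key


# Usage, with the nomic package installed and the variable set:
#   import nomic
#   nomic.login(read_nomic_api_key())
```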

We set up some imports, hyperparameters, and a helper function.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import time
from PIL import Image
import base64
import io

NUM_EPOCHS = 20
LEARNING_RATE = 3e-6
BATCH_SIZE = 128
NUM_VIS_SAMPLES = 2000   # test images embedded after every epoch
EMBEDDING_DIM = 128      # size of the layer whose activations are visualized
ATLAS_DATASET_NAME = "mnist-training-embeddings"  # same name as the dataset created in the upload step
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}\n")

def tensor_to_html(tensor):
    """Helper function to convert image tensors to HTML for rendering in Nomic Atlas"""
    # Denormalize the image (undo the MNIST mean/std normalization) and clamp to [0, 1]
    img = torch.clamp(tensor.clone().detach().cpu().squeeze(0) * 0.3081 + 0.1307, 0, 1)
    # Convert to an 8-bit grayscale image and encode it as a base64 PNG
    img_pil = Image.fromarray((img.numpy() * 255).astype('uint8'), mode='L')
    buffered = io.BytesIO()
    img_pil.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f'<img src="data:image/png;base64,{img_str}" width="28" height="28">'
```

Using device: cpu
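
The constants `0.3081` and `0.1307` in `tensor_to_html` undo the `Normalize((0.1307,), (0.3081,))` transform that is applied to the dataset later in this example. As a rough illustration, the round trip for a single pixel value can be sketched in plain Python; the function and constant names here are illustrative only.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixel):
    """Map a pixel in [0, 1] to the normalized value the model sees."""
    return (pixel - MNIST_MEAN) / MNIST_STD


def denormalize(value):
    """Invert `normalize` and clamp to [0, 1], as tensor_to_html does."""
    restored = value * MNIST_STD + MNIST_MEAN
    return min(max(restored, 0.0), 1.0)


for pixel in (0.0, 0.5, 1.0):
    z = normalize(pixel)
    print(f"pixel={pixel:.1f}  normalized={z:+.4f}  restored={denormalize(z):.4f}")
```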



## Download Data

We define a CNN image classifier and load the MNIST data:

```python
class MNIST_CNN(nn.Module):
    def __init__(self, embedding_dim=128):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, embedding_dim)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(embedding_dim, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        # The activations after fc1 are the embeddings we visualize
        embeddings = self.relu3(self.fc1(x))
        output = self.fc2(embeddings)
        return output, embeddings

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Use worker processes only when the device is neither MPS nor CPU
persistent_workers_flag = device.type not in ['mps', 'cpu']
num_workers_val = 2 if persistent_workers_flag else 0
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers_val, persistent_workers=persistent_workers_flag)
# Fixed, unshuffled subset of the test set so the same images are embedded every epoch
vis_indices = list(range(NUM_VIS_SAMPLES))
vis_subset = Subset(test_dataset, vis_indices)
test_loader_for_vis = DataLoader(vis_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers_val, persistent_workers=persistent_workers_flag)
print(f"Training on {len(train_dataset)} samples, visualizing {NUM_VIS_SAMPLES} test samples per epoch.\n")
```

Training on 60000 samples, visualizing 2000 test samples per epoch.
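
The `64 * 7 * 7` input size of `fc1` follows from how each layer changes the 28x28 input. The sketch below works through that arithmetic with the standard output-size formula for convolution and pooling layers. The `conv_out` helper is illustrative, and the batch counts use the sample counts and batch size from this example.

```python
import math


def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                    # MNIST images are 28x28
size = conv_out(size, kernel=3, padding=1)   # conv1 keeps 28
size = conv_out(size, kernel=2, stride=2)    # pool1 halves to 14
size = conv_out(size, kernel=3, padding=1)   # conv2 keeps 14
size = conv_out(size, kernel=2, stride=2)    # pool2 halves to 7
flattened = 64 * size * size                 # 64 channels -> 3136 inputs to fc1

train_batches = math.ceil(60000 / 128)       # 469 training batches per epoch
vis_batches = math.ceil(2000 / 128)          # 16 visualization batches per epoch
print(size, flattened, train_batches, vis_batches)
```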



## Collect Embeddings During Training

We save embeddings from the penultimate layer (the 128-dimensional output of `fc1`) for the same 2,000 test images at the end of every epoch, so we can track how the model's learned representation changes over the course of training. Atlas is well suited to visualizing this kind of change.

```python
model = MNIST_CNN(embedding_dim=EMBEDDING_DIM).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
all_embeddings_list = []
all_metadata_list = []
all_images_html = []
overall_start_time = time.time()
for epoch in range(NUM_EPOCHS):
    # Train for one epoch
    epoch_start_time = time.time()
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        outputs, _ = model(data)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Report the average loss over the last 200 batches
        if (batch_idx + 1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{batch_idx+1}/{len(train_loader)}], Avg Loss: {running_loss / 200:.4f}')
            running_loss = 0.0
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} training finished in {time.time() - epoch_start_time:.2f}s.\n")

    # Embed the fixed visualization subset with the current weights
    model.eval()
    vis_samples_collected_this_epoch = 0
    image_offset_in_vis_subset = 0
    with torch.no_grad():
        for data, target in test_loader_for_vis:
            data, target = data.to(device), target.to(device)
            _, embeddings_batch = model(data)
            for i in range(embeddings_batch.size(0)):
                original_idx_in_subset = image_offset_in_vis_subset + i
                if original_idx_in_subset >= NUM_VIS_SAMPLES:
                    continue
                all_embeddings_list.append(embeddings_batch[i].cpu().numpy())
                img_html = tensor_to_html(data[i])
                all_images_html.append(img_html)
                # One metadata row per (image, epoch) pair; the id combines both
                all_metadata_list.append({
                    'id': f'vis_img_{original_idx_in_subset}_epoch_{epoch}',
                    'epoch': epoch,
                    'label': f'Digit: {target[i].item()}',
                    'vis_sample_idx': original_idx_in_subset,
                    'image_html': img_html
                })
                vis_samples_collected_this_epoch += 1
            image_offset_in_vis_subset += embeddings_batch.size(0)
            if vis_samples_collected_this_epoch >= NUM_VIS_SAMPLES:
                break
    print(f"Collected {vis_samples_collected_this_epoch} embeddings for visualization in epoch {epoch+1}.\n")
total_script_time = time.time() - overall_start_time
print(f"Total training and embedding extraction time: {total_script_time:.2f}s\n")
```


Epoch [1/20], Batch [200/469], Avg Loss: 2.2656
Epoch [1/20], Batch [400/469], Avg Loss: 2.1646
Epoch 1/20 training finished in 15.08s.

Collected 2000 embeddings for visualization in epoch 1.

Epoch [2/20], Batch [200/469], Avg Loss: 1.9691
Epoch [2/20], Batch [400/469], Avg Loss: 1.7807
Epoch 2/20 training finished in 14.78s.

Collected 2000 embeddings for visualization in epoch 2.

Epoch [3/20], Batch [200/469], Avg Loss: 1.5193
Epoch [3/20], Batch [400/469], Avg Loss: 1.3360
Epoch 3/20 training finished in 14.75s.

Collected 2000 embeddings for visualization in epoch 3.

Epoch [4/20], Batch [200/469], Avg Loss: 1.1200
Epoch [4/20], Batch [400/469], Avg Loss: 0.9892
Epoch 4/20 training finished in 14.69s.

Collected 2000 embeddings for visualization in epoch 4.

Epoch [5/20], Batch [200/469], Avg Loss: 0.8479
Epoch [5/20], Batch [400/469], Avg Loss: 0.7668
Epoch 5/20 training finished in 14.80s.

Collected 2000 embeddings for visualization in epoch 5.

Epoch [6/20], Batch [200/469],
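
Because the same images are embedded in the same order after every epoch, the collected rows can also be compared numerically, not only on the map. The sketch below is one possible measure, not part of the original notebook: the average Euclidean distance each sample moves between consecutive epochs. The `mean_drift_per_epoch` helper and the toy data are illustrative.

```python
import math


def mean_drift_per_epoch(embeddings, num_epochs, num_samples):
    """Average Euclidean distance each sample moves between consecutive epochs.

    Assumes rows are ordered epoch by epoch, with the same sample order
    inside every epoch (as in the collection loop above, where the
    visualization loader is not shuffled).
    """
    if len(embeddings) != num_epochs * num_samples:
        raise ValueError("expected num_epochs * num_samples rows")
    drifts = []
    for epoch in range(1, num_epochs):
        prev_rows = embeddings[(epoch - 1) * num_samples : epoch * num_samples]
        curr_rows = embeddings[epoch * num_samples : (epoch + 1) * num_samples]
        total = sum(math.dist(p, c) for p, c in zip(prev_rows, curr_rows))
        drifts.append(total / num_samples)
    return drifts


# Toy data: 3 epochs, 2 samples, 2-dimensional embeddings.
toy = [
    [0.0, 0.0], [1.0, 1.0],   # epoch 0
    [3.0, 4.0], [1.0, 1.0],   # epoch 1
    [3.0, 4.0], [1.0, 2.0],   # epoch 2
]
print(mean_drift_per_epoch(toy, num_epochs=3, num_samples=2))  # [2.5, 0.5]
```

With the data from this example, the call would be `mean_drift_per_epoch(all_embeddings_list, NUM_EPOCHS, NUM_VIS_SAMPLES)`.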

## Create Atlas Dataset and Upload Data

```python
from nomic import AtlasDataset

dataset = AtlasDataset("mnist-training-embeddings")
dataset.add_data(data=all_metadata_list, embeddings=np.array(all_embeddings_list))
```

2025-05-11 16:50:19.819 | INFO     | nomic.dataset:_create_project:867 - Organization name: `nomic`
2025-05-11 16:50:20.486 | INFO     | nomic.dataset:_create_project:895 - Creating dataset `mnist-training-embeddings`
100%|██████████| 8/8 [00:29<00:00,  3.71s/it]
2025-05-11 16:50:50.784 | INFO     | nomic.dataset:_add_data:1702 - Upload succeeded.
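
A quick consistency check before calling `add_data` can catch mismatches between the metadata rows and the embedding rows. The sketch below is a hypothetical helper, not part of the `nomic` package. Its uniqueness check reflects how the `id` field is built in this example (sample index plus epoch), not a documented Atlas requirement.

```python
def check_upload_inputs(metadata, embeddings, id_field="id"):
    """Return (row_count, embedding_dim) or raise ValueError on a mismatch."""
    if len(metadata) != len(embeddings):
        raise ValueError(f"{len(metadata)} metadata rows vs {len(embeddings)} embeddings")
    ids = [row[id_field] for row in metadata]
    if len(set(ids)) != len(ids):
        raise ValueError(f"duplicate values in '{id_field}'")
    dims = {len(vector) for vector in embeddings}
    if len(dims) > 1:
        raise ValueError(f"inconsistent embedding sizes: {sorted(dims)}")
    return len(metadata), (dims.pop() if dims else 0)


# Toy data shaped like the rows built in the collection loop: 2 epochs x 3 samples.
toy_metadata = [
    {"id": f"vis_img_{i}_epoch_{epoch}", "epoch": epoch}
    for epoch in range(2)
    for i in range(3)
]
toy_embeddings = [[0.0] * 4 for _ in toy_metadata]
print(check_upload_inputs(toy_metadata, toy_embeddings))  # (6, 4)
```

For the run above, the expected result would be 40,000 rows (20 epochs of 2,000 samples) with 128 dimensions each.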


## Create Data Map

```python
dataset.create_index(projection='umap', topic_model=False)
```

2025-05-11 16:50:52.419 | INFO     | nomic.dataset:create_index:1289 - Created map `0196c11d-6611-46f0-38ab-b0e9778ad1fb` in dataset `nomic/mnist-training-embeddings`: https://atlas.nomic.ai/data/nomic/mnist-training-embeddings


Your map will be available in your [Atlas Dashboard](https://atlas.nomic.ai/data).