# Chroma with MNIST

This notebook is an illustrative example of how to use Chroma with a simple image classifier on the MNIST digits dataset.

## Set up Chroma

In [None]:
# Remove any previous checkout so this cell can be re-run from scratch
! rm -rf chroma
# Clone and install Chroma as a Python package.
# SECURITY: never commit a personal access token to a notebook. The token that
# was previously embedded in this URL must be treated as leaked and revoked.
# Supply a token through GITHUB_TOKEN instead (an environment variable of the
# kernel process, or a Python variable of that name set in an earlier cell).
! git clone --branch anton/mnist-notebook-colab "https://oauth2:$GITHUB_TOKEN@github.com/chroma-core/chroma.git"
! cd chroma/ && python -m pip install --upgrade pip && pip install .

## Set up the digits classifier

In [1]:
# Import what we need to create the model and dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
import pandas as pd

In [19]:
# Thin wrapper around the stock MNIST dataset: each sample additionally carries
# a string identifying the source data, so that an input can be uniquely identified.
class CustomDataset(datasets.MNIST):
    """MNIST dataset whose items are ``(img, label, resource_uri)`` triples."""

    def __getitem__(self, index):
        """Return the parent's ``(img, label)`` pair plus a URI naming the sample."""
        image, target = super().__getitem__(index)
        # The URI prefix follows this dataset's `train` flag.
        split = "train" if self.train else "t10k"
        return image, target, f"{split}-images-idx3-ubyte-{index}"

In [20]:
# MNIST preprocessing: convert to a tensor, then normalize
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Load the MNIST training split through our URI-aware dataset wrapper
train_mnist_data = CustomDataset(
    "../data",
    train=True,
    transform=transform,
    download=True,
)

# Halve the training data: one half is the 'training' set, the remainder the 'unlabeled' set
train_size = len(train_mnist_data) // 2
unlabeled_size = len(train_mnist_data) - train_size
train_dataset, unlabeled_dataset = torch.utils.data.random_split(
    train_mnist_data,
    [train_size, unlabeled_size],
    generator=torch.Generator().manual_seed(42),  # fixed seed keeps the split reproducible
)


In [23]:
# A small feed-forward CNN classifier: two conv. layers followed by two fully connected layers.
class Net(nn.Module):
    """Digit classifier whose forward pass yields predicted class indices."""

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine the
        # state_dict keys and the printed module summary.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the arg-max class index for each sample, shape ``(batch, 1)``."""
        # Convolutional feature extractor
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Fully connected head; fc2 is invoked as a module so forward hooks on it fire
        hidden = self.dropout2(F.relu(self.fc1(torch.flatten(features, 1))))
        log_probs = F.log_softmax(self.fc2(hidden), dim=1)
        # The class index is returned, not the log-probabilities
        return log_probs.argmax(dim=1, keepdim=True)

In [24]:
# Set the device
device = 'cpu'

# Load up the pretrained model.
# map_location remaps the checkpoint's tensors onto `device`, so weights that
# were saved from a GPU still load on a CPU-only machine.
model = Net()
model.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))
model.eval()  # inference mode: dropout layers become no-ops
model.to(device)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [25]:
# Create a forward hook class to extract and store embeddings from a supplied model layer
class EmbeddingHook:
  """Capture the output of a module on every forward pass.

  Attributes:
    embeddings: output of the hooked module from its most recent forward pass,
      converted to nested Python lists; ``None`` until the module has run.
    hook: handle returned by ``register_forward_hook``; ``None`` once removed.
  """

  def __init__(self, module):
    """Register a forward hook on ``module``."""
    # Define both attributes up front: reading `embeddings` before any forward
    # pass yields None instead of AttributeError, and cleanup stays safe even
    # if registration below raises.
    self.embeddings = None
    self.hook = None
    self.hook = module.register_forward_hook(self.hook_fn)

  def hook_fn(self, module, input, output):
    """Forward-hook callback: store the module output, detached, as lists."""
    self.embeddings = output.detach().tolist()

  def remove(self):
    """Detach the hook from the module; safe to call more than once."""
    hook = getattr(self, "hook", None)
    if hook is not None:
      hook.remove()
      self.hook = None

  def __del__(self):
    # Best-effort cleanup; tolerates a partially constructed instance.
    self.remove()

# Hook the model's final fully connected layer (fc2), whose output feeds the softmax
embedding_hook = EmbeddingHook(module=model.fc2)

In [7]:
# Import Chroma to our environment
from chroma import chroma
# Set up Chroma: get_api() provides the client object used by the cells below.
# Per the recorded cell output, this runs the direct local API on an in-memory
# DuckDB, so stored data is transient.
chroma_client = chroma.get_api()

Running Chroma using direct local API.
Using DuckDB in-memory for database. Data will be transient.


In [28]:
from tqdm.notebook import tqdm

# Send training data to Chroma
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64)

# Inference only: under no_grad no autograd graph is built, so the tensors
# need no explicit .detach() before conversion to lists.
with torch.no_grad():
    for img, label, uri in tqdm(train_dataloader):
        # Move the batch to the model's device so this cell keeps working when
        # `device` is changed from 'cpu'.
        inference_class = model(img.to(device))
        chroma_client.add_training(
            # Populated by the forward hook on model.fc2 during the call above
            embedding=embedding_hook.embeddings,
            input_uri=list(uri),
            inference_class=inference_class.flatten().tolist(),
            label_class=label.tolist(),
        )

  0%|          | 0/469 [00:00<?, ?it/s]

In [29]:
# Send unlabeled data to Chroma
unlabeled_dataloader = torch.utils.data.DataLoader(unlabeled_dataset, batch_size=64)

# Inference only: under no_grad no autograd graph is built, so the tensors
# need no explicit .detach() before conversion to lists.
with torch.no_grad():
    # The dataset still yields labels; they are deliberately not sent to Chroma here.
    for img, label, uri in tqdm(unlabeled_dataloader):
        # Move the batch to the model's device so this cell keeps working when
        # `device` is changed from 'cpu'.
        inference_class = model(img.to(device))
        chroma_client.add_unlabeled(
            # Populated by the forward hook on model.fc2 during the call above
            embedding=embedding_hook.embeddings,
            input_uri=list(uri),
            inference_class=inference_class.flatten().tolist(),
        )

  0%|          | 0/469 [00:00<?, ?it/s]

In [31]:
# Run Chroma's processing step on the data added above. In the recorded run it
# logs three fetches of 30000 embeddings and returns True.
# NOTE(review): what process() computes is not visible in this notebook — confirm against the Chroma API.
chroma_client.process()

time to fetch 30000 embeddings:  0.2687249183654785
time to fetch 30000 embeddings:  0.1871638298034668
time to fetch 30000 embeddings:  0.11549115180969238


True

In [32]:
# Fetch the results of processing; the recorded output is a list of resource URI strings.
# NOTE(review): this displays the entire list — consider showing only a slice
# (e.g. the first 10 entries) to keep the notebook output short.
chroma_client.get_results()

['train-images-idx3-ubyte-58329',
 'train-images-idx3-ubyte-12741',
 'train-images-idx3-ubyte-32247',
 'train-images-idx3-ubyte-47355',
 'train-images-idx3-ubyte-31713',
 'train-images-idx3-ubyte-38586',
 'train-images-idx3-ubyte-40612',
 'train-images-idx3-ubyte-31579',
 'train-images-idx3-ubyte-37758',
 'train-images-idx3-ubyte-46316',
 'train-images-idx3-ubyte-43433',
 'train-images-idx3-ubyte-40893',
 'train-images-idx3-ubyte-3137',
 'train-images-idx3-ubyte-26380',
 'train-images-idx3-ubyte-1512',
 'train-images-idx3-ubyte-6102',
 'train-images-idx3-ubyte-19272',
 'train-images-idx3-ubyte-20861',
 'train-images-idx3-ubyte-37668',
 'train-images-idx3-ubyte-38408',
 'train-images-idx3-ubyte-25230',
 'train-images-idx3-ubyte-7199',
 'train-images-idx3-ubyte-24587',
 'train-images-idx3-ubyte-41852',
 'train-images-idx3-ubyte-1758',
 'train-images-idx3-ubyte-8415',
 'train-images-idx3-ubyte-15804',
 'train-images-idx3-ubyte-56774',
 'train-images-idx3-ubyte-35260',
 'train-images-idx3-