In [15]:
# imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
import pandas as pd
import chroma

In [16]:
# Download MNIST and build the loaders used for the baseline model
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # presumably the MNIST pixel mean / std -- TODO confirm
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
train_kwargs = dict(batch_size=64)
test_kwargs = dict(batch_size=1000)

dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST("../data", train=False, transform=transform)

# Hold back half of the training pool so it can be sampled from later.
# The generator is seeded so the split can be reproduced in other cells.
train_size = len(dataset1) // 2
sample_from_size = len(dataset1) - train_size
train_dataset, sample_from_dataset = torch.utils.data.random_split(
    dataset1,
    [train_size, sample_from_size],
    generator=torch.Generator().manual_seed(42),
)

train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [17]:
# Setup our CNN to train on MNIST
class Net(nn.Module):
    """Small CNN classifier: two conv layers, max-pooling, then two linear layers.

    ``forward`` returns ``log_softmax`` over dim 1 of the final linear layer,
    i.e. 10 log-probabilities per input row.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 * 12 * 12, which matches 1x28x28 inputs -- assumes MNIST-sized images
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        # Convolutional feature extractor
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))

        # Classifier head on the flattened feature map
        x = F.relu(self.fc1(torch.flatten(x, 1)))
        x = self.fc2(self.dropout2(x))
        return F.log_softmax(x, dim=1)

# Baseline model, optimizer and LR schedule shared by the training cell below.
# Everything is placed on the CPU.
device = torch.device("cpu")
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
# StepLR with step_size=1 decays the learning rate by gamma=0.7 on every scheduler.step()
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

def attach_forward_hook(model, array):
    """Record every forward output of ``model`` into ``array``.

    A forward hook is registered on ``model``; each forward pass appends the
    module's output, converted to nested Python lists, to ``array``.

    Returns the handle from ``register_forward_hook``; call ``.remove()`` on it
    to stop recording.
    """

    def _record_output(module, inputs, output):
        # Returning None leaves the module's output unchanged
        array.append(output.data.detach().tolist())

    return model.register_forward_hook(_record_output)

def infer(model, device, data_loader, resource_uris, label_classes, inference_classes):
    """Run ``model`` over ``data_loader`` and record labels and predictions per sample.

    Args:
        model: callable module; its output is fed to ``F.nll_loss`` and
            ``argmax(dim=1)``. The model's train/eval mode is not changed here,
            so the caller is responsible for calling ``model.eval()``.
        device: device each batch is moved to before the forward pass.
        data_loader: iterable of ``(data, target, resource_uri)`` batches that
            exposes ``.dataset``; ``len(data_loader.dataset)`` is the
            denominator for the printed averages.
        resource_uris: list extended in place with each sample's resource URI.
        label_classes: list extended in place with each target as a string.
        inference_classes: list extended in place with each predicted class
            as a string.

    Returns:
        None. A summary line with the average loss and accuracy is printed.
        If the dataset is empty, a short notice is printed instead.
    """
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        # The batch of URIs gets its own name so it is not shadowed while iterating over it
        for data, target, batch_uris in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # Loss is accumulated only for the summary printed below
            total_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            resource_uris.extend(batch_uris)
            label_classes.extend(str(label) for label in target.tolist())
            inference_classes.extend(str(label) for label in pred.tolist())

    n_samples = len(data_loader.dataset)
    if n_samples == 0:
        # Nothing to average over; avoid dividing by zero
        print("\nNo samples to evaluate\n")
        return

    print(
        "\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            total_loss / n_samples, correct, n_samples, 100.0 * correct / n_samples
        )
    )


# We modify the MNIST dataset to expose some information about the source data
# to allow us to uniquely identify an input in a way that we can recover it later
class CustomDataset(datasets.MNIST):
    """MNIST dataset whose items also carry a string identifying the source image."""

    def __getitem__(self, index):
        """Return ``(img, target, resource_uri)`` for ``index``.

        ``resource_uri`` is ``train-images-idx3-ubyte-<index>`` when
        ``self.train`` is truthy and ``t10k-images-idx3-ubyte-<index>`` otherwise.
        """
        img, target = super().__getitem__(index)
        split_prefix = "train" if self.train else "t10k"
        return img, target, f"{split_prefix}-images-idx3-ubyte-{index}"

In [18]:
# Train the baseline model, evaluating on the test set after every epoch
epochs = 5
for epoch in range(1, epochs + 1):

    # --- Training pass ---
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # Only report progress on every 10th batch
        if batch_idx % 10 != 0:
            continue
        print(
            "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset),
                100.0 * batch_idx / len(train_loader),
                loss.item(),
            )
        )

    # --- Evaluation pass on the test set ---
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Summed (not averaged) per batch; normalised once after the loop
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )

    # Decay the learning rate once per epoch
    scheduler.step()

# The checkpoint is written only after all epochs have finished
torch.save(model.state_dict(), "mnist_cnn.pt")



KeyboardInterrupt: 

In [None]:
# Load the pre-trained model
# map_location remaps the stored tensors onto `device`, so the checkpoint still
# loads if it was saved from a different device than the one used here.
model = Net()
model.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))
# Switch to eval mode so dropout is disabled for the inference cells
model.eval()
model.to(device)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [None]:
# Run Inference on all data and generate embeddings
inference_kwargs = {"batch_size": 1000}

train_embeddings = []
train_resource_uris = []
train_label_classes = []
train_inference_classes = []

sample_from_embeddings = []
sample_from_resource_uris = []
sample_from_label_classes = []
sample_from_inference_classes = []

# Rebuild the train / holdback split over CustomDataset so every item carries a
# resource URI. Same seed and sizes as the earlier split, so presumably the
# same partition -- note that this rebinds train_dataset / sample_from_dataset.
train_mnist_data = CustomDataset("../data", train=True, transform=transform, download=True)
train_dataset, sample_from_dataset = torch.utils.data.random_split(train_mnist_data, [train_size, sample_from_size], generator=torch.Generator().manual_seed(42))

# infer() does not change the model's mode, so make sure dropout is disabled here
model.eval()

# from train
# The output of model.fc2 (the last linear layer) is captured as the embedding.
data_loader = torch.utils.data.DataLoader(train_dataset, **inference_kwargs)
hook = attach_forward_hook(model.fc2, train_embeddings)
try:
    infer(model, device, data_loader, train_resource_uris, train_label_classes, train_inference_classes)
finally:
    # Always detach the hook, otherwise later forward passes keep appending to the list
    hook.remove()

# from sample_from
data_loader = torch.utils.data.DataLoader(sample_from_dataset, **inference_kwargs)
hook = attach_forward_hook(model.fc2, sample_from_embeddings)
try:
    infer(model, device, data_loader, sample_from_resource_uris, sample_from_label_classes, sample_from_inference_classes)
finally:
    hook.remove()

# The hook stores one list per batch; flatten to one embedding per sample
train_embeddings = [item for sublist in train_embeddings for item in sublist]
sample_from_embeddings = [item for sublist in sample_from_embeddings for item in sublist]



Average loss: 0.0161, Accuracy: 29855/30000 (100%)


Average loss: 0.0505, Accuracy: 29606/30000 (99%)



In [62]:
from chroma.config import Settings

# Connect to a Chroma server over REST at localhost:8000
chroma_settings = Settings(
    chroma_api_impl="rest",
    chroma_server_host="localhost",
    chroma_server_http_port="8000",
)
api = chroma.get_api(chroma_settings)

# Print the server heartbeat to confirm the connection works
print(api.heartbeat())

Running Chroma in client mode using REST to connect to remote server
1669936805800599795000


In [75]:
# Load data into Chroma
# api = chroma.get_api()

# NOTE(review): reset() presumably wipes the server's existing data -- confirm before
# pointing this at a shared instance.
api.reset()
api.set_model_space("mnist")

# The original training cut is stored as dataset "train"
train_records = dict(
    embedding=train_embeddings,
    input_uri=train_resource_uris,
    dataset="train",
    inference_class=train_inference_classes,
    label_class=train_label_classes,
    model_space="mnist",
)
# The held-back pool is stored as dataset "test"
sample_from_records = dict(
    embedding=sample_from_embeddings,
    input_uri=sample_from_resource_uris,
    dataset="test",
    inference_class=sample_from_inference_classes,
    label_class=sample_from_label_classes,
    model_space="mnist",
)

api.add(**train_records)
api.add(**sample_from_records)





True

In [64]:
# Sanity check: presumably the number of records stored under the "mnist" model space
print(api.count(model_space="mnist"))

60000


In [39]:
# Create the index for the "mnist" model space
# NOTE(review): what exactly gets indexed is defined by the chroma API -- confirm
api.create_index(model_space="mnist")

True

In [76]:
# Preview stored rows for the "mnist" model space (limit=10)
print(api.fetch(limit=10, where={"model_space": "mnist"}))

  model_space                                  uuid  \
0       mnist  b42ae3f4-ac03-48b6-9aaf-6c446fc44981   
1       mnist  f9e5baf0-3d3c-4b0f-bcee-6fa64cb4e2fe   
2       mnist  d7413bff-d82b-406b-a605-c4500b5a43bd   
3       mnist  0589dff5-f03d-4207-86eb-ec1f029d638a   
4       mnist  5f6a34a3-f882-4ec1-836c-ab6eedf854c5   
5       mnist  04edec1a-7fc7-43c6-be17-cb5c565aed90   
6       mnist  de92d375-48fd-4cb0-be80-ef1292165fc9   
7       mnist  e181ef07-d805-4956-b0f4-a1d90df26ae3   
8       mnist  9a309430-b5c8-48ab-b7a5-6cd864a8fc5a   
9       mnist  d66f3cb2-54b9-4b16-8f54-67b0218a3ad8   

                                           embedding  \
0  [-16.076580047607422, -16.445514678955078, -16...   
1  [-10.801337242126465, -17.538055419921875, -14...   
2  [12.641288757324219, -20.95970344543457, -8.41...   
3  [-16.89393424987793, -18.083648681640625, -14....   
4  [-10.38875961303711, -2.6868276596069336, 8.27...   
5  [-13.484353065490723, 9.503548622131348, -10.3...   
6 

In [78]:
# Have Chroma process the "test" dataset against the "train" dataset in the "mnist" model space
# NOTE(review): the processing itself happens server-side and is not visible here -- confirm semantics
api.process(training_dataset_name="train", inference_dataset_name="test", model_space="mnist")

True

In [None]:
# Create an index and run ANN (commented out)
# api.create_index()
# results = api.get_nearest_neighbors(sample_from_embeddings[0], n_results=5)


In [79]:
# Get results back from Chroma
# NOTE(review): n_results=15000 does not guarantee 15000 rows -- the recorded run returned 13132.
results = api.get_results(dataset_name="test", n_results=15000)
print(results)
# NOTE(review): the filter below is disabled, so nothing in this cell maps the
# returned resource URIs back to dataset items.
# sample_from_crhoma_subset = [x for x in sample_from_dataset if x[2] in [y for y in results]]

                                   0
0       train-images-idx3-ubyte-7850
1      train-images-idx3-ubyte-19364
2      train-images-idx3-ubyte-18676
3      train-images-idx3-ubyte-38518
4       train-images-idx3-ubyte-1551
...                              ...
13127  train-images-idx3-ubyte-57165
13128  train-images-idx3-ubyte-45672
13129  train-images-idx3-ubyte-33034
13130   train-images-idx3-ubyte-9856
13131  train-images-idx3-ubyte-25398

[13132 rows x 1 columns]


In [None]:
# Randomly sample 15k results
# The permutation is seeded so the random baseline is reproducible across runs,
# and the indices are converted to plain ints before being handed to Subset.
random_sample_indices = torch.randperm(
    len(sample_from_dataset), generator=torch.Generator().manual_seed(42)
)[:15000].tolist()
random_sample_from_dataset = torch.utils.data.Subset(sample_from_dataset, random_sample_indices)

In [None]:
# Train from scratch on the original cut plus the samples selected by Chroma

# Map the resource URIs returned by Chroma back to dataset items.
# Assumes `results` is a single-column DataFrame of resource URIs (as printed
# by the results cell) and that the holdback items come from the train split,
# whose URIs have the form "train-images-idx3-ubyte-<index>".
chroma_uris = set(results.iloc[:, 0])
sample_from_chroma_subset = torch.utils.data.Subset(
    sample_from_dataset.dataset,
    [
        idx
        for idx in sample_from_dataset.indices
        if f"train-images-idx3-ubyte-{idx}" in chroma_uris
    ],
)

# Create a dataloader which is a combination of the original cut and the sampled results
train_sampled_dataset = torch.utils.data.ConcatDataset([train_dataset, sample_from_chroma_subset])
train_sampled_loader = torch.utils.data.DataLoader(train_sampled_dataset, **train_kwargs)

# Fresh model on the same device the batches are moved to.
# Note: this rebinds the notebook-level `optimizer` and `scheduler`.
sampled_model = Net().to(device)
optimizer = optim.Adadelta(sampled_model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

# Train and test our model
epochs = 5
for epoch in range(1, epochs + 1):

    # Train
    sampled_model.train()
    # Items from this loader are (data, target, resource_uri); the URI is not needed for training
    for batch_idx, (data, target, _) in enumerate(train_sampled_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = sampled_model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_sampled_loader.dataset),
                    100.0 * batch_idx / len(train_sampled_loader),
                    loss.item(),
                )
            )

    # Determine Loss on the test set
    sampled_model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = sampled_model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )

    scheduler.step()

torch.save(sampled_model.state_dict(), "mnist_cnn_sampled.pt")



KeyboardInterrupt: 

In [None]:
# Train from scratch on the original cut plus a random sample of the holdback pool
# (baseline to compare against the Chroma-selected samples)

# Create a dataloader which is a combination of the original cut and the random sample
train_random_dataset = torch.utils.data.ConcatDataset([train_dataset, random_sample_from_dataset])
train_random_loader = torch.utils.data.DataLoader(train_random_dataset, **train_kwargs)

# Fresh model on the same device the batches are moved to.
# Note: this rebinds the notebook-level `optimizer` and `scheduler`.
random_model = Net().to(device)
optimizer = optim.Adadelta(random_model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

# Train and test our model
epochs = 5
for epoch in range(1, epochs + 1):

    # Train
    random_model.train()
    # Items from this loader are (data, target, resource_uri); the URI is not needed for training
    for batch_idx, (data, target, _) in enumerate(train_random_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = random_model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_random_loader.dataset),
                    100.0 * batch_idx / len(train_random_loader),
                    loss.item(),
                )
            )

    # Determine Loss on the test set
    random_model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = random_model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )

    scheduler.step()

torch.save(random_model.state_dict(), "mnist_cnn_random.pt")


Test set: Average loss: 0.0474, Accuracy: 9844/10000 (98%)


Test set: Average loss: 0.0380, Accuracy: 9883/10000 (99%)


Test set: Average loss: 0.0372, Accuracy: 9875/10000 (99%)


Test set: Average loss: 0.0334, Accuracy: 9895/10000 (99%)


Test set: Average loss: 0.0315, Accuracy: 9899/10000 (99%)

