In [1]:
import torch as th
import torch.nn as nn
from torchvision import datasets, transforms

In [2]:
# get the MNIST dataset
def get_mnist_data(batch_size=64, root='data'):
    """Download MNIST (if needed) and wrap both splits in data loaders.

    Args:
        batch_size: Number of samples per batch for both loaders. Defaults
            to 64, the value that used to be hard-coded.
        root: Directory the dataset is stored in / downloaded to. Defaults
            to ``'data'``, the value that used to be hard-coded.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader is
        created with ``shuffle=True``; the test loader with ``shuffle=False``.
    """
    # ToTensor is the only transform applied to the images of either split
    to_tensor = transforms.ToTensor()
    mnist_train = datasets.MNIST(root, train=True, download=True, transform=to_tensor)
    mnist_test = datasets.MNIST(root, train=False, download=True, transform=to_tensor)

    # create the data loaders
    train_loader = th.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    test_loader = th.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [3]:
class MNISTEmbedding(nn.Module):
    """Convolutional encoder that maps a single-channel image to a 2-D point.

    Two 3x3 convolutions (stride 1, padding 1, so the spatial size is kept)
    are followed by a flatten step and a three-layer MLP whose last layer
    has 2 outputs.

    Args:
        input_size: Side length of the (square) input image; the first
            linear layer expects ``channels_hidden * input_size ** 2``
            features.
        channels_hidden: Number of output channels of both convolutions.
        mlp_hidden: Width of the two hidden layers of the MLP.
    """

    def __init__(self, input_size=28, channels_hidden=32, mlp_hidden=128):
        super().__init__()
        # shared settings of both convolutions: 3x3 kernel, size-preserving
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        # [batch, 1, x, y] -> [batch, channels_hidden, x, y]
        self.conv1 = nn.Conv2d(1, channels_hidden, **conv_kwargs)
        # [batch, channels_hidden, x, y] -> [batch, channels_hidden, x, y]
        self.conv2 = nn.Conv2d(channels_hidden, channels_hidden, **conv_kwargs)
        # [batch, channels_hidden, x, y] -> [batch, channels_hidden * x * y]
        self.flatten = nn.Flatten()
        # [batch, channels_hidden * x * y] -> [batch, 2]
        flat_features = channels_hidden * input_size * input_size
        self.mlp = nn.Sequential(
            nn.Linear(flat_features, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, 2),
        )

    def forward(self, x):
        """Embed a batch of images and return a tensor of shape [batch, 2]."""
        # NOTE(review): conv1 feeds conv2 directly, with no activation in
        # between, so the two convolutions compose into one affine map.
        # Confirm this is intended before relying on the "two-stack" depth.
        feature_maps = self.conv2(self.conv1(x))
        return self.mlp(self.flatten(feature_maps))
    

In [4]:
class MNISTModel(nn.Module):
    """Digit classifier that routes every image through a 2-D embedding.

    The image is first reduced to 2 values by ``MNISTEmbedding``; a small
    head (linear -> ReLU -> linear) then maps those 2 values to 10 outputs,
    one per class.

    Args:
        input_size: Forwarded to ``MNISTEmbedding``.
        channels_hidden: Forwarded to ``MNISTEmbedding``.
        mlp_hidden: Forwarded to ``MNISTEmbedding``.
        embedding_to_result_hidden: Width of the hidden layer of the head.
    """

    def __init__(self, input_size=28, channels_hidden=32, mlp_hidden=128, embedding_to_result_hidden=32):
        super().__init__()
        # image -> 2-D embedding
        self.embedding = MNISTEmbedding(input_size, channels_hidden, mlp_hidden)
        # 2-D embedding -> hidden -> 10 class scores
        self.lc_in = nn.Linear(2, embedding_to_result_hidden)
        self.relu = nn.ReLU()
        self.lc_out = nn.Linear(embedding_to_result_hidden, 10)

    def forward(self, x):
        """Return the 10 class scores for a batch of images."""
        point = self.embedding(x)
        hidden = self.relu(self.lc_in(point))
        return self.lc_out(hidden)

In [5]:
# set up the device, data loaders, model and optimizer
# Prefer Apple's MPS backend (the original target), then CUDA, then the CPU, so
# the notebook also runs on machines where MPS is not available.
if getattr(th.backends, "mps", None) is not None and th.backends.mps.is_available():
    device = "mps"
elif th.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
train_loader, test_loader = get_mnist_data()
model = MNISTModel().to(device=device, dtype=th.float32)
optimizer = th.optim.Adam(model.parameters(), lr=1e-4)

In [6]:
# number of full passes over the training loader
epochs = 20
# the training loop prints the current loss every `logging_steps` batches
logging_steps = 400

In [7]:
from tqdm.notebook import tqdm, trange

In [8]:
# Select training mode explicitly so that re-running this cell after the model
# has been switched to eval mode still trains in training mode.
model.train()
for epoch in trange(epochs):
    for i, (x, y) in enumerate(tqdm(train_loader)):
        x = x.to(device=device, dtype=th.float32)
        # cross_entropy takes integer class indices as the target, so move the
        # labels as int64 directly instead of casting to float32 and back
        y = y.to(device=device, dtype=th.long)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = th.nn.functional.cross_entropy(y_pred, y)
        loss.backward()
        optimizer.step()

        # periodic progress report: loss of the current batch only
        if i % logging_steps == 0:
            print(f"Epoch {epoch}, step {i}, loss {loss.item()}")

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 0, step 0, loss 2.361536979675293
Epoch 0, step 400, loss 1.0303179025650024
Epoch 0, step 800, loss 0.6761588454246521


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 1, step 0, loss 0.2868022322654724
Epoch 1, step 400, loss 0.30871737003326416
Epoch 1, step 800, loss 0.25503405928611755


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 2, step 0, loss 0.19075971841812134
Epoch 2, step 400, loss 0.2534377872943878
Epoch 2, step 800, loss 0.28929153084754944


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 3, step 0, loss 0.21914570033550262
Epoch 3, step 400, loss 0.10021018981933594
Epoch 3, step 800, loss 0.33594009280204773


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 4, step 0, loss 0.14676645398139954
Epoch 4, step 400, loss 0.09709945321083069
Epoch 4, step 800, loss 0.19906559586524963


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 5, step 0, loss 0.22849851846694946
Epoch 5, step 400, loss 0.22098049521446228
Epoch 5, step 800, loss 0.3400810658931732


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 6, step 0, loss 0.19728060066699982
Epoch 6, step 400, loss 0.27697843313217163
Epoch 6, step 800, loss 0.19847756624221802


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 7, step 0, loss 0.1521684229373932
Epoch 7, step 400, loss 0.08726197481155396
Epoch 7, step 800, loss 0.14757946133613586


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 8, step 0, loss 0.04396019130945206
Epoch 8, step 400, loss 0.07300981879234314
Epoch 8, step 800, loss 0.2020092010498047


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 9, step 0, loss 0.3562160134315491
Epoch 9, step 400, loss 0.069264255464077
Epoch 9, step 800, loss 0.15198764204978943


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 10, step 0, loss 0.1177256628870964
Epoch 10, step 400, loss 0.15538866817951202
Epoch 10, step 800, loss 0.0703272819519043


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 11, step 0, loss 0.08311384916305542
Epoch 11, step 400, loss 0.2656365633010864
Epoch 11, step 800, loss 0.06454846262931824


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 12, step 0, loss 0.06674327701330185
Epoch 12, step 400, loss 0.017385583370923996
Epoch 12, step 800, loss 0.06393568962812424


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 13, step 0, loss 0.12351129949092865
Epoch 13, step 400, loss 0.07670784741640091
Epoch 13, step 800, loss 0.08993058651685715


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 14, step 0, loss 0.10186967253684998
Epoch 14, step 400, loss 0.08534817397594452
Epoch 14, step 800, loss 0.055442288517951965


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 15, step 0, loss 0.10225304216146469
Epoch 15, step 400, loss 0.08070546388626099
Epoch 15, step 800, loss 0.0058848499320447445


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 16, step 0, loss 0.06715400516986847
Epoch 16, step 400, loss 0.12058861553668976
Epoch 16, step 800, loss 0.030284132808446884


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 17, step 0, loss 0.05192433297634125
Epoch 17, step 400, loss 0.1863536387681961
Epoch 17, step 800, loss 0.024821851402521133


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 18, step 0, loss 0.03700615465641022
Epoch 18, step 400, loss 0.0830673947930336
Epoch 18, step 800, loss 0.01992711052298546


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 19, step 0, loss 0.007939216680824757
Epoch 19, step 400, loss 0.06376026570796967
Epoch 19, step 800, loss 0.15957710146903992


In [9]:
# switch the model to evaluation mode before extracting embeddings
model = model.eval()

In [10]:
# handle on the embedding sub-network (the MNISTEmbedding held by the model)
embedding = model.embedding

In [11]:
# Convert test data to embedding vectors
embeddings = []
labels = []
# no gradients are needed for inference, so the whole pass runs under no_grad
with th.no_grad():
    for x, y in tqdm(test_loader):
        x = x.to(device=device, dtype=th.float32)
        e = embedding(x)
        # one [x, y] pair per image, appended sample by sample
        embeddings.extend(e.cpu().tolist())
        # the labels are only collected, never fed to the model, so they are
        # taken straight from the loader (no device transfer, no float cast)
        labels.extend(y.tolist())

  0%|          | 0/157 [00:00<?, ?it/s]

In [12]:
# make sure every collected label is a plain Python int
labels = [int(label) for label in labels]

In [13]:
import plotly.express as px
import pandas as pd

In [14]:
# Scatter-plot the 2-D embeddings with plotly, one colour per digit
df = pd.DataFrame(embeddings, columns=["x", "y"])
# the labels are stored as strings so the colour is treated as a category
df["label"] = [str(label) for label in labels]
fig = px.scatter(
    df,
    x="x",
    y="y",
    color="label",
    opacity=0.7,
    # list the digits in order "0" .. "9"
    category_orders={"label": list(map(str, range(10)))},
)
# enlarge the size of the graph
fig.update_layout(width=800, height=600)
fig.show()