In [1]:
%matplotlib inline

# MNIST classification using PyTorch

In this tutorial, we will use DeepTrack to classify MNIST digits using a PyTorch model.

## 1 - Importing DeepTrack

First, we import DeepTrack, its PyTorch models module (`deeptrack.torchmodels`), NumPy, and Matplotlib.

In [2]:
import deeptrack as dt
import deeptrack.torchmodels as dtm
import numpy as np 
import matplotlib.pyplot as plt

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## 3 - Loading the MNIST dataset

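We download and load the MNIST dataset using DeepTrack. The first 80% of the official training split is used for training, the remaining 20% for validation, and the official test split for testing.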
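The `split` strings in the next cell look like TensorFlow Datasets percent-slicing syntax. Assuming `dt.Dataset` interprets them that way, the sketch below works out which images each split receives. The parser `percent_slice` is a hypothetical helper written only for this illustration (it is not part of DeepTrack), and the sizes assume the standard MNIST split of 60,000 training and 10,000 test images.

```python
import re

# Standard MNIST split sizes
MNIST_TRAIN_SIZE = 60_000
MNIST_TEST_SIZE = 10_000


def percent_slice(split: str, total: int) -> range:
    """Hypothetical helper: index range selected by a TFDS-style split string.

    Handles "name", "name[:P%]", "name[P%:]" and "name[P%:Q%]".
    """
    if "[" not in split:
        return range(total)  # e.g. "test": the whole split
    match = re.fullmatch(r"\w+\[(?:(\d+)%)?:(?:(\d+)%)?\]", split)
    if match is None:
        raise ValueError(f"unsupported split string: {split!r}")
    start_pct, end_pct = match.groups()
    start = total * int(start_pct) // 100 if start_pct else 0
    end = total * int(end_pct) // 100 if end_pct else total
    return range(start, end)


train_idx = percent_slice("train[:80%]", MNIST_TRAIN_SIZE)
val_idx = percent_slice("train[80%:]", MNIST_TRAIN_SIZE)
test_idx = percent_slice("test", MNIST_TEST_SIZE)

print(len(train_idx), len(val_idx), len(test_idx))  # 48000 12000 10000
```

Under this reading, the training and validation ranges meet at index 48,000 without overlapping, so no image is used for both.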

In [3]:
mnist_train_dataset = dt.Dataset("mnist", split="train[:80%]")
mnist_val_dataset = dt.Dataset("mnist", split="train[80%:]")
mnist_test_dataset = dt.Dataset("mnist", split="test")

# Create a data pipeline that picks which dataset to use from train, val, or test
mnist_image_pipeline = dt.Select(
    on_train=mnist_train_dataset.image,
    on_val=mnist_val_dataset.image,
    on_test=mnist_test_dataset.image,
)

mnist_gt_pipeline = dt.Select(
    on_train=mnist_train_dataset.label,
    on_val=mnist_val_dataset.label,
    on_test=mnist_test_dataset.label,
)
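`dt.Select` lets a single pipeline draw from a different dataset depending on whether it is being used for training, validation, or testing. The class below is a hypothetical, pure-Python stand-in (`ModeSelector` is not a DeepTrack name) that illustrates the idea, assuming some outer object such as the data module sets the current mode; DeepTrack's actual mechanism may differ.

```python
class ModeSelector:
    """Hypothetical stand-in for dt.Select (not DeepTrack's implementation).

    Holds one data source per mode and reads from whichever source matches
    the current mode, which some outer object is assumed to set.
    """

    def __init__(self, on_train, on_val, on_test):
        self._sources = {"train": on_train, "val": on_val, "test": on_test}
        self.mode = "train"

    def __call__(self, index):
        # Delegate to the source registered for the current mode
        return self._sources[self.mode][index]


# Toy "datasets": lists of strings standing in for images
images = ModeSelector(on_train=["t0", "t1"], on_val=["v0"], on_test=["x0"])
print(images(1))  # t1

images.mode = "val"  # e.g. switched by the data module before validation
print(images(0))  # v0
```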

## 4 - Preparing the data

We prepare the data for training by scaling the pixel values from [0, 255] to [0, 1] and converting the integer labels to one-hot vectors of length 10.

In [4]:
normalization = dt.Divide(255)
normalized_image_pipeline = mnist_image_pipeline >> normalization
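`dt.Divide(255)` rescales 8-bit pixel values into the range [0, 1]. Assuming it performs plain elementwise true division, the NumPy sketch below shows the same arithmetic on a synthetic 28×28 `uint8` image (not real MNIST data).

```python
import numpy as np

# Synthetic 28x28 uint8 image whose pixels cycle through 0..255
raw_image = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)

# Elementwise true division by 255 (assumed to match dt.Divide(255));
# the integer input is promoted to a floating-point dtype
normalized = raw_image / 255

print(raw_image.min(), raw_image.max())    # 0 255
print(normalized.min(), normalized.max())  # 0.0 1.0
```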

In [5]:
def to_one_hot(x):
    return np.eye(10)[x]

one_hot_gt_pipeline = mnist_gt_pipeline >> to_one_hot
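`np.eye(10)[x]` works by picking row `x` of the 10×10 identity matrix, which has a single 1 in column `x`. Because this is NumPy fancy indexing, the same function also converts an array of labels into a batch of one-hot rows in one call, as sketched below with made-up labels.

```python
import numpy as np


def to_one_hot(x):
    # Row x of the 10x10 identity matrix has a single 1 in column x
    return np.eye(10)[x]


print(to_one_hot(3))  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]

# Fancy indexing also converts a whole array of (made-up) labels at once
labels = np.array([7, 0, 2])
one_hot_batch = to_one_hot(labels)
print(one_hot_batch.shape)           # (3, 10)
print(one_hot_batch.argmax(axis=1))  # [7 0 2]
```

Taking `argmax` along the last axis recovers the original integer labels, which is a quick sanity check on the encoding.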

In [16]:
training_pipeline = normalized_image_pipeline & one_hot_gt_pipeline

training_pipeline.update()
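The combined pipeline produces pairs of a normalized image and a one-hot label. The model summary printed during training lists a `CrossEntropyLoss`, and with one-hot targets cross-entropy reduces to minus the log-probability the model assigns to the correct class. The NumPy sketch below illustrates that formula only; it is not DeepTrack's loss code, and whether DeepTrack passes the one-hot vectors straight to PyTorch (whose `CrossEntropyLoss` accepts class probabilities as targets in recent versions) or converts them to class indices is an assumption this notebook does not settle.

```python
import numpy as np


def log_softmax(logits):
    # Subtract the row maximum first for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy(logits, one_hot_targets):
    # Batch mean of -sum_k y_k * log p_k; with one-hot y this is -log p_true
    return float(-(one_hot_targets * log_softmax(logits)).sum(axis=-1).mean())


targets = np.eye(10)[[3, 7]]  # two made-up labels, one-hot encoded

uniform_logits = np.zeros((2, 10))  # no preference: p = 0.1 for every class
print(round(cross_entropy(uniform_logits, targets), 4))  # 2.3026 (= ln 10)

confident_logits = np.where(targets == 1, 5.0, -5.0)  # strongly favors the true class
print(cross_entropy(confident_logits, targets) < cross_entropy(uniform_logits, targets))  # True
```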

## 5 - Defining the model

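We create a PyTorch-based image classifier from DeepTrack's `torchmodels` module. It takes single-channel 28×28 images and outputs a score for each of the 10 digit classes.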

In [6]:
model = dtm.ImageClassifier(
    input_shape=(1, 28, 28),  # PyTorch uses channels-first (C, H, W) layout
    num_classes=10
)
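PyTorch convolution layers expect channels-first input, (C, H, W) per image, which is why `input_shape` is `(1, 28, 28)` for grayscale digits. Whether the MNIST images arrive as plain `(28, 28)` arrays or channels-last `(28, 28, 1)` arrays is an assumption here, so the sketch shows the NumPy reshaping for both cases.

```python
import numpy as np

# Case 1 (assumed): the image is a plain (H, W) array -> add a leading channel axis
image_hw = np.zeros((28, 28), dtype=np.float32)
image_chw = image_hw[np.newaxis, ...]
print(image_chw.shape)  # (1, 28, 28)

# Case 2 (assumed): the image is channels-last (H, W, C) -> move C to the front
image_hwc = np.zeros((28, 28, 1), dtype=np.float32)
image_chw_from_hwc = np.transpose(image_hwc, (2, 0, 1))
print(image_chw_from_hwc.shape)  # (1, 28, 28)
```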

## 6 - Training the model

We wrap the data pipeline in a data module, configure a trainer, and fit the model on the training data.

In [7]:
# The data module creates and manages data generators
data = dtm.DataModule(
    ,
    batch_size=32,
    training_size=1000,
    validation_size=1000,
    test_size=1000
)
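With `batch_size=32` and `training_size=1000`, the arithmetic below estimates how many batches make up one training epoch and what a collated batch looks like. It assumes that `training_size` is the number of samples drawn per epoch and that the final, smaller batch is kept rather than dropped; neither is confirmed by the code above. The sample arrays are placeholders.

```python
import math

import numpy as np

batch_size = 32
training_size = 1000  # assumed: number of samples drawn per epoch

full_batches, leftover = divmod(training_size, batch_size)
batches_per_epoch = math.ceil(training_size / batch_size)  # assumes the last partial batch is kept
print(full_batches, leftover, batches_per_epoch)  # 31 8 32

# Collating placeholder samples into the (N, C, H, W) / (N, num_classes) layout
images = [np.zeros((1, 28, 28), dtype=np.float32) for _ in range(batch_size)]
labels = [np.eye(10)[i % 10] for i in range(batch_size)]
x_batch = np.stack(images)
y_batch = np.stack(labels)
print(x_batch.shape, y_batch.shape)  # (32, 1, 28, 28) (32, 10)
```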

In [9]:
# The trainer manages the training loop
trainer = dtm.Trainer(
    max_epochs=100,
    accelerator="cpu",
    log_every_n_steps=4,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
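Assuming each epoch processes `training_size=1000` samples in batches of 32 (about 32 batches, with the last one partial), the sketch below counts the optimizer steps implied by `max_epochs=100` and how often metrics are logged with `log_every_n_steps=4`. PyTorch Lightning's default logging interval is 50 steps, longer than such an epoch, which is a plausible reason for the small value used here.

```python
import math

max_epochs = 100
log_every_n_steps = 4
# Assumed: 1000 samples per epoch in batches of 32, last partial batch kept
batches_per_epoch = math.ceil(1000 / 32)

total_steps = max_epochs * batches_per_epoch  # one optimizer step per batch
logs_per_epoch = batches_per_epoch // log_every_n_steps
print(batches_per_epoch, total_steps, logs_per_epoch)  # 32 3200 8
```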


In [10]:
trainer.fit(model, data)


  | Name         | Type             | Params
--------------------------------------------------
0 | loss         | CrossEntropyLoss | 0     
1 | conv_layers  | ModuleList       | 23.3 K
2 | dense_layers | ModuleList       | 91.7 K
--------------------------------------------------
114 K     Trainable params
0         Non-trainable params
114 K     Total params
0.460     Total estimated model params size (MB)


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
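The estimated model size in the summary table can be checked by hand. Assuming 32-bit float parameters (4 bytes each), roughly 115,000 parameters occupy about 0.46 MB. The counts in the table are rounded, so the sum of the two layer groups only approximately matches the reported total of 114 K.

```python
# Rounded per-group parameter counts copied from the summary table
conv_params = 23.3e3
dense_params = 91.7e3
total_params = conv_params + dense_params

bytes_per_param = 4  # assumed 32-bit float parameters
size_mb = total_params * bytes_per_param / 1e6
print(round(size_mb, 3))  # 0.46
```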


In [None]:
# The trainer can also be used to test the model
trainer.test(model, data)
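Which metrics `trainer.test` reports depends on how `dtm.ImageClassifier` defines its test step, which this notebook does not show. A common choice for digit classification is accuracy; the sketch below computes it from class scores and one-hot labels, using made-up numbers rather than results from this model.

```python
import numpy as np


def accuracy(logits, one_hot_targets):
    # Fraction of samples whose highest-scoring class equals the labelled class
    predicted = logits.argmax(axis=1)
    expected = one_hot_targets.argmax(axis=1)
    return float((predicted == expected).mean())


# Made-up scores for 4 samples: the model predicts classes 1, 4, 4, 9
logits = np.eye(10)[[1, 4, 4, 9]] * 5.0
# Made-up true labels 1, 4, 7, 9, one-hot encoded
targets = np.eye(10)[[1, 4, 7, 9]]

print(accuracy(logits, targets))  # 0.75
```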