
# MNIST with 10 labeled examples

Proof that I can do machine learning 🌈

## The task

Use only one labeled example of each digit from the MNIST training set; the rest can be used unlabeled. Make a working classifier!

## To check it out

Without installing anything, you can take a look at the data, test, validation, training, and baseline implementation notebooks.

But if you do want to go deeper, read on!

### With Poetry

Poetry makes things easier, but it's not strictly necessary.

```sh
poetry install # install dependencies in a new venv
poetry shell # spawn a new shell within the venv
```

### Without Poetry

```sh
python -m venv venv # create a virtual environment
source venv/bin/activate
pip install -r requirements.txt
ipython kernel install --name "mnist-ten" --user # register a Jupyter kernel for this venv
```

## Example code

```python
import torch
from mnist_ten.data import test_loader
from mnist_ten.models import classifier, weights_path

classifier.load_state_dict(torch.load(weights_path))
classifier.eval()

images, labels = next(iter(test_loader))  # assuming the loader yields (images, labels) batches
classifier(images)  # classify one batch
```

## What I did

I used the unlabeled data to train a classifier on an auxiliary task: predicting how an image has been rotated/flipped. To be successful at this, the classifier needs to figure out what the different digits look like, and form a sensible early vision pipeline.
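The exact transform set lives in the training code; as a minimal sketch of such a pretext task (rotations only, with a hypothetical `rotation_batch` helper), the self-supervised labels can be generated like this:

```python
import torch

def rotation_batch(images):
    """Turn an unlabeled batch of shape (B, C, H, W) into a
    self-supervised one: each image appears rotated by 0, 90,
    180, and 270 degrees, and the rotation index is the target."""
    rotated, targets = [], []
    for k in range(4):
        rotated.append(torch.rot90(images, k, dims=(-2, -1)))
        targets.append(torch.full((images.shape[0],), k, dtype=torch.long))
    return torch.cat(rotated), torch.cat(targets)
```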

I then took the first few of that classifier's layers and trained a few new layers on top of them, this time to solve the main task: classifying digits. The insights from the auxiliary task were re-used to help solve the main one, yielding an accuracy of 52%.
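The real architecture is in `mnist_ten.models`; as a sketch of the reuse-and-freeze step, with a made-up toy network standing in for the auxiliary classifier, the idea is roughly:

```python
import torch
import torch.nn as nn

# Hypothetical auxiliary-task network; the real one is not reproduced here.
pretext_net = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(32, 4),  # 4 classes: one per rotation
)

# Reuse everything up to the final layer as a frozen feature extractor...
backbone = nn.Sequential(*list(pretext_net.children())[:6])
for param in backbone.parameters():
    param.requires_grad = False

# ...and train a fresh head on the 10 labeled digits.
head = nn.Linear(32, 10)
model = nn.Sequential(backbone, head)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
```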

I also checked how well a simple nearest-neighbor classifier performs, and it's not bad at all: it also reached an accuracy of 52%.
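The baseline notebook has the details; as a sketch, a pixel-space nearest-neighbor classifier over the 10 labeled examples (hypothetical `nearest_neighbor_predict`) boils down to:

```python
import torch

def nearest_neighbor_predict(labeled_images, labels, query_images):
    """Classify each query image as the label of its closest labeled
    example, using Euclidean distance in raw pixel space."""
    flat_refs = labeled_images.flatten(1)             # (10, 784)
    flat_queries = query_images.flatten(1)            # (N, 784)
    distances = torch.cdist(flat_queries, flat_refs)  # (N, 10)
    return labels[distances.argmin(dim=1)]
```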

## The time it took

- 2 hours: basic data loading & baseline implementation
- 5 hours: programming model architecture, training & validation
- 2 hours: experimenting with data augmentation
- 3 hours: training models (without interaction)

## What didn't work

- Training the entire model, instead of just the head, on the main task
- Alternating between training on the main and auxiliary tasks
- A few different architectures & hyperparameters
- Data augmentation: it's there, but doesn't seem to be helping much

## Possible improvements

- Try many different architectures (especially simpler ones! I probably over-engineered the model)
- Try a mix between training the entire model and training just the head
- Experiment with learning rate scheduling and optimizers
- Use the label-smoothing technique from this paper
