# Playing with MNIST

This notebook expects you to have previously trained the MNIST model and saved the resulting file. 

## Canvas Installation: Two Workflows

### 1. Jupyter (locally)

The recommended way is to clone the repo, which contains `canvas.py`. Make sure you have [pycairo](https://anaconda.org/conda-forge/pycairo) installed:

```bash
conda activate dmlap
conda install -c conda-forge pycairo
```

### 2. Google Colab

When using Google Colab you will need to use `pip` and install additional libraries (based on [this](https://github.com/pygobject/pycairo/issues/39#issuecomment-391830334)):

```bash
# WARNING!!!! Do NOT do this if you are running jupyter/python locally!!!
!apt-get install libcairo2-dev libjpeg-dev libgif-dev
!pip install pycairo
```

#### 2.1 Working with the repo in your drive

Mount your drive and change to the correct directory:

```python
from google.colab import drive
drive.mount('/content/drive')

# change directory using the os module
import os
os.chdir('drive/My Drive/')
os.listdir()             # shows the contents of the current dir, you can use chdir again after that
# os.mkdir("DMLCP-2023") # creating a directory
# os.chdir("DMLCP-2023") # moving to this directory
# os.getcwd()            # printing the current directory
```

See [this notebook](https://colab.research.google.com/notebooks/io.ipynb), and [Working With Files](https://realpython.com/working-with-files-in-python/) on Real Python.

#### 2.2 Working on it as a standalone notebook

Get the `canvas` module:

```python
!curl -O https://raw.githubusercontent.com/jchwenger/DMLCP/main/python/canvas.py
```

Download the necessary images and move them into an `images` folder:

```python
!curl -O https://raw.githubusercontent.com/jchwenger/DMLCP/main/python/images/3.png
!curl -O https://raw.githubusercontent.com/jchwenger/DMLCP/main/python/images/4.png
!mkdir images
!mv 3.png 4.png images
```

In [None]:
import canvas
import pathlib
from PIL import Image

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import torchvision as tv
from torchvision.transforms import v2

# Use a GPU (CUDA, or MPS on Apple silicon) if available, otherwise the CPU
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

## Load a trained network

In [None]:
NUM_CLASSES = 10
INPUT_SHAPE = [1,28,28]

MODELS_DIR = pathlib.Path("models")
MODELS_DIR.mkdir(exist_ok=True)

MODEL_NAME = "dense_mnist"
MNIST_DIR = MODELS_DIR / MODEL_NAME

GENERATED_DIR = pathlib.Path("generated")
GENERATED_DIR.mkdir(exist_ok=True)

MNIST_GEN_DIR = GENERATED_DIR / f"{MODEL_NAME}_images"
MNIST_GEN_DIR.mkdir(exist_ok=True)

model = torch.jit.load(MNIST_DIR / f"{MODEL_NAME}_scripted.pt", map_location=device)

### Load weights only

If you saved only the weights (the model's `state_dict`) with `torch.save` instead of a scripted model with `torch.jit.save`, you need to redefine your model class first, then load the weights into it:

```python
# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten() # [1, 28, 28] -> [1, 28*28]
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(INPUT_SHAPE[1] * INPUT_SHAPE[2], 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, NUM_CLASSES)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
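# map_location lets you load weights saved on a GPU onto whatever device you have now
model.load_state_dict(torch.load(MNIST_DIR / f"{MODEL_NAME}.pt", map_location=device, weights_only=True))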
```

The TorchScript (`jit`) method is ideal if you only want to use the model (inference). **However**, if you want to fine-tune your model after reloading it, prefer the full method above (class definition + loading weights).
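
For reference, here is a minimal sketch of the two ways of saving that this section distinguishes, using a tiny stand-in model. The `TinyNet` class and the temporary file names are illustrative assumptions, not the files produced by the training notebook. Both reloaded models should give the same outputs as the original.

```python
import pathlib
import tempfile

import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for the MNIST model (same input/output shapes)."""
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.linear(self.flatten(x))

model = TinyNet().eval()
x = torch.rand(1, 1, 28, 28)

with tempfile.TemporaryDirectory() as tmp:
    tmp = pathlib.Path(tmp)

    # 1. Weights only: a dict of tensors, the class definition is needed to reload it
    torch.save(model.state_dict(), tmp / "tiny.pt")
    reloaded = TinyNet()
    reloaded.load_state_dict(torch.load(tmp / "tiny.pt", map_location="cpu", weights_only=True))
    reloaded.eval()

    # 2. TorchScript: code + weights in one file, no class definition needed to reload
    torch.jit.script(model).save(str(tmp / "tiny_scripted.pt"))
    scripted = torch.jit.load(str(tmp / "tiny_scripted.pt"), map_location="cpu")

    with torch.no_grad():
        same_weights = torch.allclose(model(x), reloaded(x))
        same_script = torch.allclose(model(x), scripted(x))

print(same_weights, same_script)
```

The scripted file is self-contained, which is convenient for inference; the `state_dict` route keeps you working with your own Python class, which is what you want for further training.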

## Classify an image of a number

In [None]:
img = Image.open('images/3.png') # try also images/4.png

transforms = v2.Compose([
    v2.Grayscale(num_output_channels=1),
    v2.Resize(size=(28, 28), antialias=True),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
])

# (we avoid calling this `input`, which would shadow Python's built-in function)
x = transforms(img).to(device)

print(f"Input shape: {x.shape}")

def predict(model, x):
    """Return the most likely class for the input `x`."""
    model.eval()
    with torch.no_grad():
        probs = nn.Softmax(dim=-1)(model(x)).cpu().numpy()
        return np.argmax(probs[0])

predicted = predict(model, x)
canvas.show_image(img, title=f'Predicted number: {predicted}', cmap='gray')
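
The `predict` function applies a softmax before taking the `argmax`. Since the softmax preserves the order of the logits, the predicted class is the same with or without it; the softmax only matters when you want actual probabilities, for instance to see how confident the model is. A small NumPy check, with made-up logits:

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # avoids overflow in exp
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

# Made-up logits for one image (batch of 1, 10 classes)
logits = np.array([[1.2, -0.3, 0.5, 4.1, -2.0, 0.0, 0.7, 1.9, 2.2, -1.1]])
probs = softmax(logits)

print(probs.sum())                                  # the probabilities sum to 1
print(np.argmax(probs[0]), np.argmax(logits[0]))    # same class either way
print(f"confidence: {probs[0].max():.2f}")
```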

## Two directions

1. **disrupt**: try and find cases where the network fails to predict the images properly
2. **generate**: come up with your own images and try to classify them! Combining the two, you can try to generate images that the network fails to classify!

### Note: Dense vs ConvNet

If you also trained a ConvNet, you will probably notice that its predictions tend to be more stable!
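
If you have not trained a ConvNet yet, here is a minimal sketch of one with the same input and output shapes as the dense model. The layer sizes are an illustrative assumption, not the course's reference architecture, and the model would still need to be trained and saved like the dense one before loading it here.

```python
import torch
from torch import nn

NUM_CLASSES = 10

class ConvNet(nn.Module):
    """A small ConvNet for 1x28x28 inputs (illustrative layer sizes)."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # [B, 1, 28, 28] -> [B, 16, 28, 28]
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> [B, 16, 14, 14]
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # -> [B, 32, 14, 14]
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> [B, 32, 7, 7]
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                                 # -> [B, 32*7*7]
            nn.Linear(32 * 7 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, NUM_CLASSES),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

conv_model = ConvNet()
logits = conv_model(torch.rand(4, 1, 28, 28))  # a batch of 4 random images
print(logits.shape)
```

Unlike the dense model, this one needs an explicit batch dimension: a `[1, 28, 28]` input such as `x` in the cells below should be passed as `x.unsqueeze(0)`, otherwise `nn.Flatten` merges the wrong dimensions.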

## 1. Disrupt

Here we use a canvas object to generate images of a number. You will see that a Dense net does not always succeed (whereas a ConvNet usually does).

In [None]:
# Generate a random number between 0 and 9 (the max is excluded)
number = np.random.randint(0, 10) 
c = canvas.Canvas(28, 28)
c.background(0)
c.fill(255)
c.text_size(26)
c.text([c.width/2, c.height/2 + 9], str(number), center=True)
x = c.get_image_grayscale()

# little things:
# convert the NumPy array to a float32 tensor, reshape it to INPUT_SHAPE and move it to the device
print(x.shape, x.dtype)
x = torch.tensor(x, dtype=torch.float32).view(INPUT_SHAPE).to(device)
print(x.shape, x.dtype)

predicted = predict(model, x)
c.show(title=f'Predicted number: {predicted}', size=(512, 512))

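Disruption, first idea: how about we invert the colours? Our pixel values lie between 0 and 1, so inverting an image amounts to computing `1.0 - c.get_image_grayscale()`. Here we get the same result directly by drawing a black number on a white background, so that the canvas we display is exactly what the model sees. (Doing both would invert the image twice and hand the model an ordinary white-on-black number!)

In [None]:
number = np.random.randint(0, 10)
c = canvas.Canvas(28, 28)
c.background(255)  # white background...
c.fill(0)          # ...and a black number: the inverse of the previous cell
c.text_size(26)
c.text([c.width/2, c.height/2 + 9], str(number), center=True)

# test: rotation? (requires `import math`, and replaces the c.text call above)
# c.translate(c.width/2, c.height/2 + 7)
# c.rotate(torch.rand(1).item() * 2 * math.pi) # random rotation from 0 to 2 pi
# c.text([0, 0], str(number), center=True)

# No `1.0 - ...` here: the drawing itself is already inverted
# (note: this array already has values in [0,1], no need to divide by 255)
x = c.get_image_grayscale()

x = torch.tensor(x, dtype=torch.float32).view(INPUT_SHAPE).to(device)

predicted = predict(model, x)
c.show(title=f'Predicted number: {predicted}', size=(512, 512))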
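
Because the pixel values lie in [0, 1], inverting is just `1.0 - x`, and inverting twice gives back the original image. This is why you should invert *either* on the canvas (by swapping background and fill colours) *or* with `1.0 - x`, not both. A quick NumPy check on a random stand-in image:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((28, 28))          # stand-in grayscale image, values in [0, 1)

inverted = 1.0 - x                # bright pixels become dark and vice versa
double = 1.0 - inverted           # inverting again undoes the first inversion

print(inverted.min() >= 0.0, inverted.max() <= 1.0)  # still a valid image
print(np.allclose(double, x))                        # back to the original
```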

### Ideas for exploration

- Creatively disrupt the image, keeping it recognizable to a human, but causing the model to produce an incorrect prediction. You could add random dots, or patches, for instance. Or simply create an array of random numbers of the same size as the image and add it to the image.
- Try to do this in steps, e.g. incrementally adding modifications to the image and observing when and how it stops being recognized by the model.
- Briefly discuss the steps you are taking, taking advantage of the hybrid markdown/code format of the notebook.
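
As a starting point for the first two ideas, here is a sketch of a few disruptions written directly in NumPy, on a 28x28 image with values in [0, 1] (for example the output of `c.get_image_grayscale()`). The helpers `add_noise`, `add_dots` and `add_patch` are hypothetical names introduced here; each result is kept within [0, 1], and the strength parameters let you apply the modifications incrementally.

```python
import numpy as np

rng = np.random.default_rng(42)

def add_noise(img, strength):
    """Add uniform noise in [-strength, strength], then clip back to [0, 1]."""
    noise = rng.uniform(-strength, strength, size=img.shape)
    return np.clip(img + noise, 0.0, 1.0)

def add_dots(img, n_dots, value=1.0):
    """Set `n_dots` randomly chosen pixels to `value`."""
    out = img.copy()
    rows = rng.integers(0, img.shape[0], size=n_dots)
    cols = rng.integers(0, img.shape[1], size=n_dots)
    out[rows, cols] = value
    return out

def add_patch(img, top, left, size, value=0.0):
    """Paint a square patch of side `size` with `value`."""
    out = img.copy()
    out[top:top + size, left:left + size] = value
    return out

# Stand-in "image": a white vertical bar on black, vaguely like a 1
img = np.zeros((28, 28))
img[4:24, 13:15] = 1.0

# Incremental disruption: increase the noise step by step.
# With the notebook's model you would convert each version with
# torch.tensor(noisy, dtype=torch.float32).view(INPUT_SHAPE).to(device),
# call predict() on it, and note the strength at which the prediction changes.
for strength in np.linspace(0.0, 0.5, 6):
    noisy = add_noise(img, strength)
    print(f"strength {strength:.1f}: mean pixel {noisy.mean():.3f}")

dotted = add_dots(img, n_dots=20)
patched = add_patch(img, top=10, left=10, size=6)
```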

Make sure to display the images you are creating!

You may want to work with the `Canvas` object directly, using some tools demonstrated in the relevant notebook, in which case you should keep in mind that you are only producing grayscale images and that the images have size 28x28.

Otherwise you might as well work by preparing images externally (e.g. by hand, or using p5js) and then loading them as we did earlier with `images/3.png`. If you take this approach, make sure you start from an image that is consistently recognizable to a human as a given number and correctly classified by the model as that same number.
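
If you prepare images externally, remember that MNIST digits are light strokes on a dark background. Below is a sketch of a hypothetical helper, `to_mnist_style`, that rescales a 28x28 grayscale array to [0, 1] and inverts it when the background looks light. The rule "mean above 0.5 means a light background" is a heuristic assumption that suits mostly-empty digit drawings, not a general solution.

```python
import numpy as np

def to_mnist_style(gray):
    """Rescale a 2D grayscale array to [0, 1] and make the background dark.

    Assumes the digit covers less than half of the image, so that a mean
    above 0.5 indicates a light background (heuristic, not foolproof).
    """
    gray = np.asarray(gray, dtype=np.float32)
    lo, hi = gray.min(), gray.max()
    if hi > lo:
        gray = (gray - lo) / (hi - lo)   # stretch to the full [0, 1] range
    else:
        gray = np.zeros_like(gray)       # flat image: nothing to show
    if gray.mean() > 0.5:
        gray = 1.0 - gray                # light background -> invert
    return gray

# A black stroke on a white sheet, as you might draw it by hand (values 0-255)
drawing = np.full((28, 28), 255, dtype=np.uint8)
drawing[5:23, 12:16] = 0
x = to_mnist_style(drawing)
print(x.min(), x.max(), x[0, 0], x[10, 13])
```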

## 2. Generate

Here is a simple example that looks like a `0`, and usually gets classified as such.

In [None]:
c = canvas.Canvas(28, 28)
c.background(0)

c.no_stroke()
for t in np.linspace(1, 0.2, 5):
    c.fill(255*t)
    c.circle([c.width/2, c.height/2], 10*t)

x = c.get_image_grayscale()

x = torch.tensor(x, dtype=torch.float32).view(INPUT_SHAPE).to(device)

predicted = predict(model, x)
c.show(title=f'Predicted number: {predicted}', size=(512, 512))

This is most interesting when you stop using the text function and use the drawing abilities of the canvas instead.

Try different numbers!

**Also**, try shapes that *really do not look like numbers* to us, and see what happens.
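
If you want full control over every pixel, you can also draw directly into a NumPy array instead of using the canvas. A sketch using coordinate grids: a ring (which may look like a `0` to the model) and a checkerboard (which looks like no digit at all). The functions `ring` and `checkerboard` are illustrative and not part of `canvas`; the resulting arrays can be converted with `torch.tensor(..., dtype=torch.float32).view(INPUT_SHAPE)` exactly like `c.get_image_grayscale()` above.

```python
import numpy as np

H = W = 28
rows, cols = np.mgrid[0:H, 0:W]           # row and column index of every pixel

def ring(cy, cx, radius, thickness):
    """White ring on black: pixels whose distance to the centre is close to `radius`."""
    dist = np.sqrt((rows - cy) ** 2 + (cols - cx) ** 2)
    return (np.abs(dist - radius) <= thickness / 2).astype(np.float32)

def checkerboard(cell):
    """Alternating squares of side `cell`, definitely not a digit."""
    return (((rows // cell) + (cols // cell)) % 2).astype(np.float32)

zero_like = ring(cy=14, cx=14, radius=8, thickness=3)
not_a_digit = checkerboard(cell=4)

print(zero_like.shape, zero_like.max(), not_a_digit.mean())
```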

As before, a ConvNet will probably perform better than a plain Dense net.

### Note

If you trained a net on FashionMNIST, you can do the same thing but with pieces of clothing! (The images must always be grayscale and 28x28!)
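
With FashionMNIST, the predicted index refers to a clothing category rather than a digit. A small lookup table in the standard FashionMNIST label order makes the titles readable; `FASHION_CLASSES` and `label_name` are names introduced here for illustration:

```python
FASHION_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def label_name(predicted):
    """Map a predicted class index (int or NumPy integer) to its FashionMNIST name."""
    return FASHION_CLASSES[int(predicted)]

print(label_name(9))  # prints "Ankle boot"
# e.g. c.show(title=f'Predicted: {label_name(predicted)}', size=(512, 512))
```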

## 3. Optional: fine-tune images!

This requires you to install `imageio`:

```bash
conda install -c conda-forge imageio
# or: pip install imageio
```

In [None]:
import base64
import mimetypes
import textwrap  # used by embed_data() below
import imageio as iio
from datetime import datetime
import matplotlib.pyplot as plt
from IPython.display import HTML

# Function to save image as a frame
def save_image(tensor_img, iteration):
    img = tensor_img.squeeze().detach().cpu().numpy()
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    
    # Save each image to a file
    file_path = CURRENT_RUN / f'frame_{iteration}.png'
    plt.savefig(file_path, bbox_inches='tight', pad_inches=0)
    plt.close()

    # Return image path to later convert to a gif
    return file_path

This exploits a key idea in generative deep learning: we use the same technique of computing the influence of each parameter on our loss (the gradients), but this time the **pixels** of the image are the 'parameters' that we modify (whilst the model parameters remain fixed!).

This *definitely* doesn't work as smoothly as I would want it to (some classes don't produce very recognisable results). Maybe a ConvNet would work better? Or some small detail in there might lead to improvements, experiments required!
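
To see the idea in isolation, here is a NumPy sketch with a stand-in *linear* model, `logits = W @ x + b`, whose random weights stay frozen. For a linear model, the gradient of the chosen logit with respect to the pixels is simply the row `W[c]`, so each gradient-ascent step adds a multiple of that row to the image and then clips to [0, 1]. In the cell below, `loss.backward()` computes this gradient through the trained network instead; the weights, learning rate and number of steps here are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pixels, n_classes = 28 * 28, 10

# Frozen stand-in "model": these weights are never updated
W = rng.normal(0.0, 0.05, size=(n_classes, n_pixels))
b = np.zeros(n_classes)

def logits_of(x):
    return W @ x + b

chosen_class = 0
x = rng.random(n_pixels)            # start from uniform noise in [0, 1)
lr = 0.1

start = logits_of(x)[chosen_class]
for step in range(100):
    grad = W[chosen_class]          # d(logit_c)/dx for a linear model
    x = x + lr * grad               # ascent: move the pixels *up* the gradient
    x = np.clip(x, 0.0, 1.0)        # keep a valid image
end = logits_of(x)[chosen_class]

print(f"logit for class {chosen_class}: {start:.3f} -> {end:.3f}")
image = x.reshape(28, 28)           # back to image shape for display
```

For a linear model, the optimum is an image where every pixel with a positive weight is 1 and every other pixel is 0: a noisy "template" of the class rather than a clean digit, which gives some intuition for why the results below can be hard to recognise.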

In [None]:
# Chosen class (we'll optimize for '0' which corresponds to class 0)
CHOSEN_CLASS = 0

# Freeze the model: we only want gradients with respect to the image
model.eval()
for p in model.parameters():
    p.requires_grad_(False)

# Initialize a random image of size (28, 28) with values between 0 and 1
UNIFORM_NOISE = True
if UNIFORM_NOISE:
    # using uniform noise
    image = torch.rand(1, 1, 28, 28, device=device, requires_grad=True)
else:
    # using gaussian noise (mean .5, std .1), clamped to [0, 1], then tracked by autograd
    image = (0.5 + 0.1 * torch.randn(1, 1, 28, 28, device=device)).clamp(0, 1).requires_grad_()

# Define optimizer (we'll use gradient ascent, so we'll update the image's pixel values)
optimizer = torch.optim.Adam([image], lr=0.001)

# Number of iterations (this can be tuned)
iters = 5000

# When do we save the intermediate result
SAVE_EVERY = 10
PRINT_EVERY = 200
SHOW_EVERY = 1000

# List to store frames for the gif (I save them instead)
# frames = []

now = datetime.now().strftime("%m-%d-%Y_%Hh%Mm%Ss")
CURRENT_RUN = MNIST_GEN_DIR / now
CURRENT_RUN.mkdir(exist_ok=True)

# Training loop for gradient ascent
for i in range(iters):

    # 1: prediction
    output = model(image)

    # 2: loss
    # (negative on our class, we want to *maximize* the pixels that activate the class)
    loss = - output[0, CHOSEN_CLASS]
    
    # # (positive on all the rest, *exclude* other classes from prediction)
    # # (trick: torch.arange(10) != CHOSEN_CLASS is an array of booleans used as indices)
    # loss += output[0, torch.arange(10) != CHOSEN_CLASS].sum()

    # 3: 'backward' | Backpropagation *on the image*!
    loss.backward()

    # 4: 'step'
    optimizer.step()

    # 5: 'zero grad' (otherwise the gradients remain there)
    optimizer.zero_grad()

    # # Inject some noise into our data
    # image.data += torch.rand(1, 1, 28, 28, device=device) * .02

    # # Standardize the data (mean 0, std 1)
    # image.data = (image.data - torch.mean(image))/ torch.std(image)

    # Normalize the data between 0 and 1
    image.data = (image.data - image.data.min()) / (image.data.max() - image.data.min())    

    # Clamp the pixel values between 0 and 1 to keep it a valid image
    image.data.clamp_(0, 1)

    # Plot probs every `SHOW_EVERY` iterations
    if i % SHOW_EVERY == 0:
        with torch.no_grad():
            probs = F.softmax(output, dim = -1).squeeze().detach().cpu()
            plt.figure(figsize=(2,2))
            plt.bar(range(10), probs)
            plt.xticks(range(10))
            plt.show()        
        
    # Print every `PRINT_EVERY` iterations
    if (i+1) % PRINT_EVERY == 0 or i == iters - 1:
        print(f"Iteration {i+1:>{len(str(iters))}}, Loss: {loss.item():.5f}")
        
    # Save the intermediate images every `SAVE_EVERY` iterations
    if i % SAVE_EVERY == 0:
        image_path = save_image(image, i)
        # frames.append(image_path)

In [None]:
# annoying business: sorting file names numerically (frame_2 before frame_10) rather than alphabetically
# https://stackoverflow.com/a/4836734
import re
def natural_sort(l):
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', str(key))]
    return sorted(l, key=alphanum_key)

# Create a GIF from the frames using imageio
MNIST_GEN_GIF = MNIST_GEN_DIR / f"{now}.class_{CHOSEN_CLASS}.gif"
with iio.get_writer(MNIST_GEN_GIF, mode="I", loop=0, duration=0.1) as writer:
    for f in natural_sort(CURRENT_RUN.glob("frame_*.png")):
        frame = iio.v3.imread(f)  # `frame`, not `image`, so the optimised image tensor is not overwritten
        writer.append_data(frame)
        
print(f"GIF saved as {MNIST_GEN_GIF}")

# adapted from here: https://github.com/tensorflow/docs/blob/master/tools/tensorflow_docs/vis/embed.py

def embed_data(mime, data):
    """Embeds data as an html tag with a data-url."""
    b64 = base64.b64encode(data).decode()
    if mime.startswith('image'):
        tag = f'<img src="data:{mime};base64,{b64}"/>'
    elif mime.startswith('video'):
        tag = textwrap.dedent(f"""
            <video width="640" height="480" controls>
              <source src="data:{mime};base64,{b64}" type="video/mp4">
              Your browser does not support the video tag.
            </video>
            """)
    else:
        raise ValueError('Images and Video only.')
    return HTML(tag)

def embed_file(path):
    """Embeds a file in the notebook as an html tag with a data-url."""
    path = pathlib.Path(path)
    mime, unused_encoding = mimetypes.guess_type(str(path))
    data = path.read_bytes()
    return embed_data(mime, data)

embed_file(MNIST_GEN_GIF)

A nice extension would be to transform this code so that, instead of working with a single image, it uses a batch of 10 images, optimises each one for its own class, and plots a grid of all 10 images in one go!
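
A sketch of the batched version, again with a NumPy stand-in linear model so that it runs on its own. In the notebook, `image` would become a `(10, 1, 28, 28)` tensor and the loss `-output[torch.arange(10), torch.arange(10)].sum()`, so that image `i` maximises the logit of class `i`. The random weights, learning rate and step count are assumptions; the last lines tile the 10 results into a 2x5 grid that a single `imshow` can display.

```python
import numpy as np

rng = np.random.default_rng(1)
n_classes, n_pixels = 10, 28 * 28

W = rng.normal(0.0, 0.05, size=(n_classes, n_pixels))  # frozen stand-in model
images = rng.random((n_classes, n_pixels))               # one image per class
lr = 0.1

for step in range(100):
    # loss = -sum_i logits[i, i]: image i only receives the gradient of its own class,
    # which for a linear model is row i of W, so the batched gradient is W itself
    images = np.clip(images + lr * W, 0.0, 1.0)

logits = images @ W.T                                   # shape (10, 10)
print(np.argmax(logits, axis=1))                        # ideally [0 1 2 ... 9]

# Tile the 10 images into a 2 x 5 grid: (10, 28, 28) -> (2*28, 5*28)
grid = (images.reshape(2, 5, 28, 28)
              .transpose(0, 2, 1, 3)
              .reshape(2 * 28, 5 * 28))
print(grid.shape)
# plt.imshow(grid, cmap="gray")
```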

Also, for a savagely awesome example of this process, check out [the end of this notebook](https://github.com/johnowhitaker/aiaiart/blob/master/AIAIART_1.ipynb) ([YT Video](https://youtu.be/p814BapRq2U?si=wD-wtcQqB77EjSVY&t=2821)).