In [45]:

def is_colab():
    try:
        import google
        return True
    except ImportError:
        return False
    
import cv2
from matplotlib import pyplot as plt

import numpy as np
import base64
from PIL import Image
import io
from IPython.display import HTML
import os

from torch import optim, nn, utils, Tensor
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L


if is_colab():
    try:
        import lightning
    except ImportError:
        ! pip install lightning
    from google.colab.patches import cv2_imshow as imshow
    from google.colab import output
else:
    from cv2 import imshow 
    

image = None

In [ ]:

def save_image(data_url):
    global image
    _, encoded = data_url.split(",", 1)
    data = base64.b64decode(encoded)
    image = Image.open(io.BytesIO(data))

    # For demonstration, just print the shape
    print("Your image is now saved in the 'image' global variable as an Image object.")


if is_colab():
    # Register the callback
    output.register_callback('notebook.save_image', save_image)


canvas_html = """
<canvas width=%d height=%d></canvas>
<button id="saveBtn">Save</button>
<script>
  var canvas = document.querySelector('canvas');
  var ctx = canvas.getContext('2d');
  ctx.fillStyle = "white";
  ctx.fillRect(0, 0, canvas.width, canvas.height);
  var pos = { x: 0, y: 0 };

  document.addEventListener('mousemove', draw);
  document.addEventListener('mousedown', setPosition);
  document.addEventListener('mouseenter', setPosition);

  function setPosition(e) {
    pos.x = e.clientX - canvas.offsetLeft;
    pos.y = e.clientY - canvas.offsetTop;
  }

  function draw(e) {
    if (e.buttons !== 1) return;

    ctx.beginPath();
    ctx.lineWidth = 20;
    ctx.lineCap = 'round';
    ctx.strokeStyle = 'black';

    ctx.moveTo(pos.x, pos.y);
    setPosition(e);
    ctx.lineTo(pos.x, pos.y);

    ctx.stroke();
  }

  document.getElementById('saveBtn').onclick = function() {
    var data = canvas.toDataURL('image/png');
    google.colab.kernel.invokeFunction('notebook.save_image', [data], {});
  };
</script>
"""

def get_image_drawing_from_user(width:int=400, height:int=400):
    if is_colab():
        display(HTML(canvas_html % (width, height)))
    else:
        print("Not in colab, can't collect user drawing input!")

In [26]:


# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


class MNISTClassifier(L.LightningModule):
    def __init__(self, n_layer:int = 5):
        super().__init__()
        self.classifier = None
        
        
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer



# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

In [54]:
from torch.nn import functional as F


class ConvBlock(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, kernel_size:int = 3, padding:int = 1, padding_mode:str = 'zeros'):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, padding_mode=padding_mode)
        self.dropout = nn.Dropout2d()
        
    def forward(self, x:Tensor) -> Tensor:
        x = self.conv(x)
        x = F.relu(x, inplace=False)
        x = self.dropout(x)
        x = F.max_pool2d(x, kernel_size=2, stride=1)
        return x

In [55]:
n_layer = 5
model = nn.Sequential(
            ConvBlock(1, 64),
            *[ConvBlock(64, 64) for _ in range(n_layer)],
        )




tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0457],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0457],
         [0.0000, 0.0000, 0.0000,  ..., 0.0765, 0.0765, 0.1270],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0615],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0615],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0246]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0411, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0034, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.

In [33]:
dataset[0][0].shape

torch.Size([1, 28, 28])

In [43]:
nn.Conv2d(1, 5, kernel_size=3, padding=1)(dataset[0][0]).shape

torch.Size([5, 28, 28])

In [9]:


# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = DataLoader(dataset)

In [20]:
def show_image(i:int):
    x,y = dataset[i]
    print(f"Class: {y}")
    Image.fromarray(cv2.resize(np.array(x.squeeze(0)) * 255, (400, 400))).show()


Class: 8


In [27]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

tensor(0.1459, grad_fn=<MseLossBackward0>)
tensor(0.1540, grad_fn=<MseLossBackward0>)
tensor(0.0947, grad_fn=<MseLossBackward0>)
tensor(0.0923, grad_fn=<MseLossBackward0>)
tensor(0.1109, grad_fn=<MseLossBackward0>)
tensor(0.1301, grad_fn=<MseLossBackward0>)
tensor(0.0886, grad_fn=<MseLossBackward0>)
tensor(0.1615, grad_fn=<MseLossBackward0>)
tensor(0.0523, grad_fn=<MseLossBackward0>)
tensor(0.0983, grad_fn=<MseLossBackward0>)
tensor(0.1235, grad_fn=<MseLossBackward0>)
tensor(0.0622, grad_fn=<MseLossBackward0>)
tensor(0.1513, grad_fn=<MseLossBackward0>)
tensor(0.1127, grad_fn=<MseLossBackward0>)
tensor(0.0495, grad_fn=<MseLossBackward0>)
tensor(0.1023, grad_fn=<MseLossBackward0>)
tensor(0.0939, grad_fn=<MseLossBackward0>)
tensor(0.0981, grad_fn=<MseLossBackward0>)
tensor(0.0513, grad_fn=<MseLossBackward0>)
tensor(0.0642, grad_fn=<MseLossBackward0>)
tensor(0.1364, grad_fn=<MseLossBackward0>)
tensor(0.1297, grad_fn=<MseLossBackward0>)
tensor(0.0532, grad_fn=<MseLossBackward0>)
tensor(0.06

`Trainer.fit` stopped: `max_epochs=10` reached.


tensor(0.0384, grad_fn=<MseLossBackward0>)
tensor(0.0245, grad_fn=<MseLossBackward0>)
tensor(0.0335, grad_fn=<MseLossBackward0>)
tensor(0.0293, grad_fn=<MseLossBackward0>)
tensor(0.0257, grad_fn=<MseLossBackward0>)
tensor(0.0159, grad_fn=<MseLossBackward0>)


In [49]:
import torch

# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[0.0816, 0.5224, 0.5690],
        [0.1280, 0.4803, 0.5752],
        [0.0669, 0.4184, 0.6746],
        [0.0463, 0.5535, 0.5125]], grad_fn=<AddmmBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
