<a href="https://colab.research.google.com/github/ollihansen90/Quickdraw_CNN/blob/main/Quickdraw_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Architecture: two stages of ``5x5 convolution -> 2x2 max-pooling -> ReLU``
    followed by one linear layer that produces the class scores.

    Args:
        num_classes: Number of output scores per image (default 13).
        feature_maps_1: Output channels of the first convolution (default 32).
        feature_maps_2: Output channels of the second convolution (default 64).
    """

    def __init__(self, num_classes=13, feature_maps_1=32, feature_maps_2=64):
        super().__init__()

        self.conv1 = torch.nn.Conv2d(1, feature_maps_1, 5)
        self.conv2 = torch.nn.Conv2d(feature_maps_1, feature_maps_2, 5)

        # Spatial size for a 28x28 input:
        # 28 -(conv 5)-> 24 -(pool 2)-> 12 -(conv 5)-> 8 -(pool 2)-> 4
        self.fully_connected = torch.nn.Linear(4*4*feature_maps_2, num_classes)

    def forward(self, x):
        """Return the unnormalized class scores (logits) for a batch.

        Args:
            x: Tensor of shape ``(batch, 1, 28, 28)``; the linear layer's
                input size is only valid for 28x28 images.

        Returns:
            Tensor of shape ``(batch, num_classes)``. No softmax is applied.
        """
        stage_1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage_2 = F.relu(F.max_pool2d(self.conv2(stage_1), 2))

        # Keep the batch dimension, merge channels and spatial dimensions.
        flattened = stage_2.flatten(start_dim=1)
        return self.fully_connected(flattened)

In [None]:
import requests
import io
import numpy as np
from torch.utils.data import Dataset, DataLoader

def MasterDset(n_imgs, n_test):
    """Download the Quick, Draw! bitmaps and build a train and a test dataset.

    For every category in ``linklist`` the complete ``.npy`` file is
    downloaded and its rows are shuffled. The first ``n_imgs`` rows go to the
    training set and the following ``n_test`` rows to the test set, so the two
    sets never share an image. Pixels are binarized: 1 where
    ``value / 255 > 0.5``, otherwise 0.

    Args:
        n_imgs: Number of training images taken from each category.
        n_test: Number of test images taken from each category.

    Returns:
        Tuple ``(train, test)`` of ``Dset`` objects. Their data arrays are
        float32 with shape ``(num_categories * n, 1, 784)`` and are ordered
        category by category, which is what ``Dset`` relies on to assign
        labels by position.

    Raises:
        requests.HTTPError: If a download returns an error status.
        ValueError: If a category holds fewer than ``n_imgs + n_test`` images
            (the positional labels would otherwise be misaligned).
    """
    linklist = [
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/butterfly.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/camera.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/cup.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/duck.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/mouse.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/owl.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/pineapple.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sea%20turtle.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/snowman.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sun.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/swan.npy",
                "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/umbrella.npy",
                ]

    def binarize(parts):
        """Stack per-category arrays and return a 0/1 array of shape (N, 1, 784)."""
        stacked = np.concatenate(parts, axis=0)
        # value / 255 > 0.5 is the same test as value > 127.5.
        return (stacked[:, None, :] > 127.5).astype(np.float32)

    # Collect the slices and concatenate once at the end; growing one array
    # inside the loop would copy all previously loaded data on every step.
    train_parts = []
    test_parts = []
    for link in linklist:
        print("Loading",link)
        # The timeout bounds connection setup and stalls between received
        # chunks, not the total download time.
        response = requests.get(link, timeout=60)
        response.raise_for_status()
        samples = np.random.permutation(np.load(io.BytesIO(response.content)))
        if samples.shape[0] < n_imgs + n_test:
            raise ValueError(
                f"{link} contains only {samples.shape[0]} images, "
                f"but n_imgs + n_test = {n_imgs + n_test} were requested"
            )
        train_parts.append(samples[:n_imgs])
        test_parts.append(samples[n_imgs:n_imgs+n_test])

    return Dset(binarize(train_parts)), Dset(binarize(test_parts))

class Dset(Dataset):
    """Dataset over an array of flattened 28x28 images grouped by class.

    ``data`` is expected to hold the classes one after another in equally
    sized groups, in the same order as ``labelnames``; labels are derived
    purely from the row position. Rows beyond ``13 * (data.shape[0] // 13)``
    get no label and are not counted by ``len()``.
    """

    def __init__(self, data):
        # German display names, one per class index.
        self.labelnames = [
            "Schmetterling",
            "Kamera",
            "Becher",
            "Ente",
            "Maus",
            "Eule",
            "Ananas",
            "Schildkröte",
            "Schaf",
            "Schneemann",
            "Sonne",
            "Schwan",
            "Regenschirm",
        ]
        self.data = data

        num_classes = len(self.labelnames)
        samples_per_class = data.shape[0] // num_classes
        class_ids = torch.arange(0, num_classes)
        # Outer product with a vector of ones repeats every class id
        # ``samples_per_class`` times; the labels end up as floats.
        repeated = torch.outer(class_ids, torch.ones(samples_per_class))
        self.label = repeated.flatten()

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is a float tensor of shape (1, 28, 28)."""
        row = torch.from_numpy(self.data[index])
        image = row.float().reshape(1, 28, 28)
        return image, self.label[index]

    def __len__(self):
        """Return the number of labelled samples."""
        return self.label.shape[0]

# Build the datasets: 20,000 training and 1,000 test images per category.
# This downloads the complete bitmap file of every category.
DS_train, DS_test = MasterDset(n_imgs=20_000, n_test=1000)
# Sanity check: total number of samples in each dataset.
print(len(DS_train), len(DS_test))

#print(data.shape)

In [20]:
from torch.utils.data import Dataset, DataLoader
# Small shuffled loader used only to preview samples; the training cell
# replaces `dataloader_train` with a loader using a larger batch size.
dataloader_train = DataLoader(DS_train, batch_size=16, shuffle=True)

# Fetch a single batch of 16 images together with their labels.
batch, label = next(iter(dataloader_train))

In [None]:
import matplotlib.pyplot as plt

# Show the first 16 images of the current batch in a 4x4 grid, each titled
# with the display name of its label.
print(label)
plt.figure(figsize=[15, 15])
for position in range(1, 17):
    sample_index = position - 1
    image = batch[sample_index].squeeze().numpy()
    class_index = int(label[sample_index])
    plt.subplot(4, 4, position)
    plt.imshow(image)
    plt.title(DS_train.labelnames[class_index])
plt.show()

In [None]:
from torch.utils.data import Dataset, DataLoader
# Train the CNN and, after every epoch, report the test accuracy and show the
# confusion matrix (rows: true class, columns: predicted class).
device = "cuda" if torch.cuda.is_available() else "cpu"
n_epochs = 3
batch_size = 128
lr = 1e-3
betas = (0.9, 0.999)

model = CNN().to(device)

dataloader_train = DataLoader(DS_train, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(DS_test, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
criterion = nn.CrossEntropyLoss().to(device)
num_classes = len(DS_test.labelnames)

for epoch in range(n_epochs):
    model.train()
    for i, (batch, label) in enumerate(dataloader_train):
        batch = batch.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output = model(batch)
        # The dataset stores labels as floats; the loss needs integer class ids.
        loss = criterion(output, label.long())
        loss.backward()
        optimizer.step()
        if i%500==0:
            # Step index, batch loss, batch accuracy.
            print(i, loss.item(), (torch.sum(torch.argmax(output, dim=-1)==label)/len(batch)).item())

    model.eval()
    confusion_matrix = torch.zeros([num_classes, num_classes])
    acc = 0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for batch, label in dataloader_test:
            batch = batch.to(device)
            label = label.to(device)
            output = model(batch)
            prediction = torch.argmax(output, dim=-1)
            # Encode every (true, predicted) pair as one flat cell index and
            # count all pairs of the batch at once instead of per sample.
            cell_index = (label.long() * num_classes + prediction).cpu()
            counts = torch.bincount(cell_index, minlength=num_classes * num_classes)
            confusion_matrix += counts.reshape(num_classes, num_classes).to(confusion_matrix.dtype)
            acc += torch.sum(prediction==label).item()
    acc /= len(DS_test)
    print(f"Test-Accuracy epoch {epoch}: {acc}")
    plt.figure()
    plt.imshow(confusion_matrix.detach().cpu().numpy())
    plt.show()
        

In [None]:
import numpy as np
import torch
import sys

# Drawing canvas size in pixels. Used as (width, height) of the HTML canvas
# and by predict_func to reshape the flat pixel list it receives.
canvas_size = (280, 280)
from IPython.core.display import display, HTML

def make_prediction_canvas(canvas_size, prediction_func_name):
    """Display an HTML drawing canvas wired to a Python prediction function.

    The widget has a canvas plus "Predict" and "Clear canvas" buttons.
    "Predict" collects the alpha channel of every canvas pixel (row by row)
    and sends that flat list to the Python function named
    ``prediction_func_name``; the returned text is shown below the canvas.
    Two front ends are supported: the classic notebook (``IPython`` JS
    object, runs ``name(<list>)`` on the kernel) and Google Colab
    (``google.colab.kernel.invokeFunction``, which requires the function to
    be registered as a callback beforehand).

    Args:
        canvas_size: ``(width, height)`` of the canvas in pixels.
        prediction_func_name: Name of the Python function to call. It is
            inserted verbatim into the generated JavaScript.

    Returns:
        None. The widget is rendered through ``IPython.display.display``.
    """
    # Imported locally from the public IPython.display module so the function
    # does not depend on how the surrounding cell imported these names.
    from IPython.display import HTML, display

    # Mouse positions use clientX/clientY because getBoundingClientRect() is
    # relative to the viewport; pageX/pageY would shift the strokes as soon
    # as the page is scrolled.
    canvas_html = '''
<div>
<p>Drawing canvas:</p>
<canvas id="canvas" width="''' + str(canvas_size[0]) + '''" height="''' + str(canvas_size[1]) + '''" style="border: 5px solid black"></canvas>
<button onclick="predict()">Predict</button>
<button onclick="clear_canvas()">Clear canvas</button>
<p id="predictionfield">Prediction:</p>
</div>
<script type="text/Javascript">
function prediction_callback(data){
    if (data.msg_type === 'execute_result') {
        document.getElementById("predictionfield").innerHTML = "Prediction: " + data.content.data['text/plain']
        /*$('#predictionfield').html("Prediction: " + data.content.data['text/plain'])*/
    } else {
        console.log(data)
    }
}
function predict(){
    var imgData = ctx.getImageData(0, 0, ctx.canvas.width, ctx.canvas.height);
    imgData = Array.prototype.slice.call(imgData.data).filter(function (data, idx) { return idx % 4 == 3; })
    var kernelAPI = undefined
    try {
      //check if defined
      if (IPython) {
        kernelAPI = "IPython"
      }
    } catch(err) {
    }
    try {
      //check if defined
      if (google) {
        kernelAPI = "google"
      }
    } catch(err) {
    }
    if (kernelAPI === "IPython") {
        var command = "''' + prediction_func_name + '''(" + JSON.stringify(imgData) + ")"
        document.getElementById("predictionfield").innerHTML = "Prediction: calculating..."
        /*$('#predictionfield').html("Prediction: calculating...")*/
        var kernel = IPython.notebook.kernel;
        kernel.execute(command, {iopub: {output: prediction_callback}}, {silent: false});
    } else if (kernelAPI === "google") {
        google.colab.kernel.invokeFunction("''' + prediction_func_name + '''", [imgData], {})
        .then(function(result) {
            prediction_callback({msg_type: 'execute_result', content: {data: result.data}})
        })
    } else {
        console.error('no kernel api found to invoke predictions!')
    }
}
canvas = document.getElementById('canvas')
ctx = canvas.getContext("2d")
var clickX = new Array();
var clickY = new Array();
var clickDrag = new Array();
var paint;
function clear_canvas() {
    clickX = new Array();
    clickY = new Array();
    clickDrag = new Array();

    redraw();
}
function addClick(x, y, dragging)
{
  clickX.push(x);
  clickY.push(y);
  clickDrag.push(dragging);
}
var canvas = document.getElementById("canvas")
/*$('#canvas').mousedown(*/
canvas.addEventListener('mousedown', function(e){
  var boundingRect = canvas.getBoundingClientRect()
  var mouseX = e.clientX - boundingRect.left;
  var mouseY = e.clientY - boundingRect.top;

  paint = true;
  addClick(mouseX, mouseY);
  redraw();
});
/*$('#canvas').mousemove(*/
canvas.addEventListener('mousemove', function(e){
  if(paint){
    var boundingRect = canvas.getBoundingClientRect()
    addClick(e.clientX - boundingRect.left, e.clientY - boundingRect.top, true);
    redraw();
  }
});
/*$('#canvas').mouseup(*/
canvas.addEventListener('mouseup', function(e){
  paint = false;
});
/*$('#canvas').mouseleave(*/
canvas.addEventListener('mouseleave', function(e){
  paint = false;
});
function redraw(){
  ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height); // Clears the canvas

  ctx.strokeStyle = '#000000';//"#df4b26";
  ctx.lineJoin = "round";
  ctx.lineWidth = 10;
  for(var i=0; i < clickX.length; i++) {
    ctx.beginPath();
    if(clickDrag[i] && i){
      ctx.moveTo(clickX[i-1], clickY[i-1]);
     }else{
       ctx.moveTo(clickX[i]-1, clickY[i]);
     }
     ctx.lineTo(clickX[i], clickY[i]);
     ctx.closePath();
     ctx.stroke();
  }
}
</script>
'''
    display(HTML(canvas_html))

def predict_func(img_data):
    """Classify a drawing from the canvas and return its class name.

    Args:
        img_data: Flat sequence with one value per canvas pixel (the canvas
            script sends the alpha channel, 0-255), with
            ``canvas_size[0] * canvas_size[1]`` entries.

    Returns:
        The entry of ``DS_test.labelnames`` with the highest score.

    Side effects:
        Switches the global ``model`` to evaluation mode and stores the raw
        scores in the global ``out`` so other cells can inspect them.
    """
    global out

    # Flat list -> 2D float image on the model's device, scaled to [0, 1].
    pixels = np.asarray(img_data, dtype=np.float32).reshape(*canvas_size)
    image = torch.tensor(pixels, dtype=torch.float).to(device)
    image = image / 255

    # interpolate expects (batch, channel, height, width); shrink to 28x28.
    small = torch.nn.functional.interpolate(
        image[None, None], size=(28, 28), mode='bilinear', align_corners=True
    )

    # Swap the two spatial axes, keeping batch size 1 and one channel.
    # NOTE(review): the swap is meant to match the dataset's orientation -
    # confirm it against images produced by Dset.
    # NOTE(review): the input is not binarized here, while MasterDset
    # thresholds the training images at 0.5 - confirm this is intended.
    x = small.transpose(2, 3)

    # Evaluation mode only matters for modules such as dropout or batch
    # normalization, but setting it always is the safe habit.
    model.eval()
    with torch.no_grad():  # no gradient bookkeeping needed for inference
        out = model(x)

    best_class = out.argmax(1).item()
    # The returned name is displayed by the HTML widget.
    return DS_test.labelnames[best_class]

# Create the prediction canvas and make it use the prediction function defined above.
# In Google Colab the function must be registered as a callback first, because
# the canvas script calls it through google.colab.kernel.invokeFunction.
# Note: this import rebinds the global name `output`, which the training cell
# also uses for the model output.
if 'google.colab' in sys.modules:
    from google.colab import output
    output.register_callback('predict_func', predict_func)
make_prediction_canvas(canvas_size, "predict_func")

In [None]:
# Horizontal bar chart of the class probabilities for the most recent canvas
# prediction (`out` is set by predict_func).
variable_name = ""
class_probabilities = F.softmax(out, dim=-1).squeeze().detach().cpu().numpy()
plt.figure()
plt.barh(DS_train.labelnames, class_probabilities)
plt.grid()
plt.show()