In [4]:
%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from math import *
import time
from IPython import display

from PIL import Image

import pygame

from torchvision.datasets import Omniglot
from torchvision.transforms.functional import affine
# Compute device for the network and all tensors; fall back to CPU when no GPU is present.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
class DownBlock(nn.Module):
    """Encoder stage: two circular 3x3 convolutions with ELU, BatchNorm, then 2x max-pooling.

    Args:
        NI: number of input channels.
        NO: number of output channels.

    Maps (B, NI, H, W) -> (B, NO, H // 2, W // 2).
    """

    def __init__(self, NI, NO):
        super().__init__()
        # Attribute names (l1, l2, pool, b) form the state_dict keys of saved
        # checkpoints, so they must stay exactly as they are.
        conv_kw = dict(kernel_size=3, padding=1, padding_mode='circular')
        self.l1 = nn.Conv2d(NI, NO, **conv_kw)
        self.l2 = nn.Conv2d(NO, NO, **conv_kw)
        self.pool = nn.MaxPool2d(2)
        self.b = nn.BatchNorm2d(NO)

    def forward(self, x):
        features = F.elu(self.l1(x))
        features = F.elu(self.l2(features))
        # BatchNorm is applied after the second activation and before pooling.
        return self.pool(self.b(features))

class UpBlock(nn.Module):
    """Decoder stage: 2x bilinear upsampling, BatchNorm, then two circular 3x3 convolutions with ELU.

    Args:
        NI: number of input channels.
        NO: number of output channels.

    Maps (B, NI, H, W) -> (B, NO, 2 * H, 2 * W).
    """

    def __init__(self, NI, NO):
        super().__init__()
        conv_kw = dict(kernel_size=3, padding=1, padding_mode='circular')
        # align_corners=False is PyTorch's default; stating it explicitly avoids
        # the "default upsampling behavior" warning without changing results.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.l1 = nn.Conv2d(NI, NO, **conv_kw)
        self.l2 = nn.Conv2d(NO, NO, **conv_kw)
        # Normalises the upsampled input, hence NI (not NO) channels.
        self.b = nn.BatchNorm2d(NI)

    def forward(self, x):
        out = self.b(self.upsample(x))
        for conv in (self.l1, self.l2):
            out = F.elu(conv(out))
        return out

class Net(nn.Module):
    """U-Net style segmenter for single-channel drawings.

    Four DownBlocks (1 -> 16 -> 32 -> 64 -> 128 channels, each halving the
    resolution), a 5x5 circular bottleneck convolution, and four UpBlocks with
    additive skip connections back to full resolution.

    forward(x) returns a tuple:
      * per-pixel log-probabilities over 6 classes, shape (B, 6, H, W);
      * a 16-d embedding per image, shape (B, 16): the final decoder features
        averaged over pixels with weights x, then L2-normalised.

    H and W must be divisible by 16 so the skip-connection additions line up
    (four 2x poolings followed by four 2x upsamplings).
    """

    def __init__(self):
        super().__init__()
        # Attribute names are state_dict keys of the saved checkpoint; keep them.

        self.down1 = DownBlock(1,16)
        self.down2 = DownBlock(16,32)
        self.down3 = DownBlock(32,64)
        self.down4 = DownBlock(64,128)

        # Bottleneck convolution at RES/16.
        self.linter1 = nn.Conv2d(128,128,5,padding=2, padding_mode='circular')

        self.bn = nn.BatchNorm2d(128)

        self.up4 = UpBlock(128,64)
        self.up3 = UpBlock(64,32)
        self.up2 = UpBlock(32,16)
        self.up1 = UpBlock(16,16)
        # 1x1 classifier from the 16-channel decoder features to 6 class logits.
        self.lfinal = nn.Conv2d(16, 6, 1)

        # Adam optimiser over all parameters; not used by forward().
        self.optim = torch.optim.Adam(self.parameters(), lr = 1e-4)

    def forward(self, x):
        z1 = self.down1(x)   # 16 ch, RES/2
        z2 = self.down2(z1)  # 32 ch, RES/4
        z3 = self.down3(z2)  # 64 ch, RES/8
        z4 = self.down4(z3)  # 128 ch, RES/16

        z5 = self.bn(F.elu(self.linter1(z4)))  # 128 ch, RES/16

        z7 = z3 + self.up4(z5)  # 64 ch, RES/8
        z8 = z2 + self.up3(z7)  # 32 ch, RES/4
        z9 = z1 + self.up2(z8)  # 16 ch, RES/2

        embedding = self.up1(z9)  # 16 ch, RES
        z10 = F.log_softmax(self.lfinal(embedding), dim=1)

        # x-weighted mean of the per-pixel features. An all-zero input (e.g. a
        # blank canvas) would give 0/0 = NaN, so an exactly-zero total weight is
        # replaced by 1; every other input is divided exactly as before.
        mass = x.sum(3).sum(2)
        mass = torch.where(mass == 0, torch.ones_like(mass), mass)
        embedding = (x*embedding).sum(3).sum(2)/mass
        embedding = embedding/torch.sqrt(1e-8 + torch.sum(embedding**2,1).unsqueeze(1))
        return z10, embedding

In [6]:
# Build the segmentation network and restore the trained weights.
net = Net().to(DEVICE)
# map_location places the checkpoint tensors directly on DEVICE, so the file
# loads regardless of which device it was saved from.
net.load_state_dict(torch.load("symbol_seg.pth", map_location=DEVICE))
net.eval()  # inference mode: BatchNorm uses its running statistics

Net(
  (down1): DownBlock(
    (l1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
    (l2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (down2): DownBlock(
    (l1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
    (l2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (down3): DownBlock(
    (l1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
    (l2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),

In [7]:
# Display set-up: the RES x RES canvas is shown with every cell scaled up to
# MULT x MULT screen pixels, in a 24-bit-depth window.
RES = 64   # canvas side length, in cells
MULT = 8   # screen pixels per canvas cell along each axis
pygame.init()
screen = pygame.display.set_mode((RES * MULT, RES * MULT), 0, 24)

In [8]:
# Drawing surface indexed canvas[x, y] (screen column, screen row); sized from
# RES so it always matches the window set up above.
canvas = np.zeros((RES, RES))

In [9]:
# Interactive demo: draw with the left mouse button, erase with the right one,
# press R to clear, close the window to stop. Every frame the network segments
# the current drawing and each drawn pixel is coloured by a blend of the six
# class colours weighted by the (sharpened) class probabilities.
colors = torch.tensor(
    [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [0.5, 0.5, 0.5]],
    dtype=torch.float32, device=DEVICE,
)

# Brush footprint: offsets inside a 5x5 window with dx*dx + dy*dy <= 2,
# which is a 3x3 square around the cursor.
BRUSH_OFFSETS = [(dx, dy) for dy in range(-2, 3) for dx in range(-2, 3) if dx*dx + dy*dy <= 2]


def paint(canvas, pos, value):
    """Set the brush footprint around screen position `pos` to `value`, in place."""
    cx, cy = pos[0] // MULT, pos[1] // MULT
    for dx, dy in BRUSH_OFFSETS:
        px, py = cx + dx, cy + dy
        if 0 <= px < canvas.shape[0] and 0 <= py < canvas.shape[1]:
            canvas[px, py] = value


def render(canvas):
    """Segment `canvas` with `net` and return a uint8 RGB image upscaled by MULT.

    The returned array is (W * MULT, H * MULT, 3), i.e. x-major as expected by
    pygame.pixelcopy.make_surface.
    """
    with torch.no_grad():  # inference only; don't build an autograd graph per frame
        # canvas is [x, y]; the network expects row-major (1, 1, H, W).
        x = torch.tensor(canvas.T, dtype=torch.float32, device=DEVICE)[None, None]
        log_p, _ = net(x)
        # exp(2 * log p), renormalised: a softmax at temperature 0.5 for crisper colours.
        p = torch.exp(2 * log_p)
        p = p / p.sum(1, keepdim=True)
        # (1, 6, H, W, 1) * (1, 6, 1, 1, 3), summed over classes -> (H, W, 3)
        rgb = (p.unsqueeze(4) * colors.view(1, 6, 1, 1, 3)).sum(1)[0]
        rgb = rgb * x[0, 0].unsqueeze(2)  # colour only where something is drawn
        rgb = rgb.cpu().numpy().transpose(1, 0, 2)
    rgb = np.clip(256 * rgb, 0, 255).astype(np.uint8)
    return np.repeat(np.repeat(rgb, MULT, axis=0), MULT, axis=1)


running = True
while running:
    screen.blit(pygame.pixelcopy.make_surface(render(canvas)), (0, 0))
    pygame.display.update()
    time.sleep(0.005)

    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.KEYUP and event.key == pygame.K_r:
            canvas = np.zeros_like(canvas)
        elif event.type == pygame.MOUSEMOTION:
            if event.buttons[0]:
                paint(canvas, event.pos, 1)
            if event.buttons[2]:
                paint(canvas, event.pos, 0)

# Close the window; re-run the display set-up cell before running this cell again.
pygame.quit()
        



KeyboardInterrupt: 