In [1]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

import torch 
import torch.nn as nn
import torch.nn.functional as F

from IPython import display

import time
from math import *
from PIL import Image
from torchvision import transforms

import glob
import cv2
import pygame

pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class DownBlock(nn.Module):
    def __init__(self, NI, NO):
        super().__init__()
        self.l1 = nn.Conv2d(NI, NO, 3, padding=1, padding_mode='circular')
        self.l2 = nn.Conv2d(NO, NO, 3, padding=1, padding_mode='circular')
        self.pool = nn.MaxPool2d(2)
        self.b = nn.BatchNorm2d(NO)
        
    def forward(self, x):
        z = F.elu(self.l1(x))
        z = self.b(F.elu(self.l2(z)))
        return self.pool(z)

class UpBlock(nn.Module):
    def __init__(self, NI, NO):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        self.l1 = nn.Conv2d(NI, NO, 3, padding=1, padding_mode='circular')
        self.l2 = nn.Conv2d(NO, NO, 3, padding=1, padding_mode='circular')
        self.b = nn.BatchNorm2d(NI)
        
    def forward(self, x):
        z = self.b(self.upsample(x))
        z = F.elu(self.l1(z))
        return F.elu(self.l2(z))
    
class FaceSegNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.down1 = DownBlock(3,16)
        self.down2 = DownBlock(16,32)
        self.down3 = DownBlock(32,64)
        self.down4 = DownBlock(64,128)
        self.down5 = DownBlock(128,256)
        
        self.linter1 = nn.Conv2d(256,256,3,padding=1, padding_mode='circular')
        self.linter2 = nn.Conv2d(256,256,3,padding=1, padding_mode='circular')
        
        self.bn = nn.BatchNorm2d(256)
        
        self.up5 = UpBlock(256,128)
        self.up4 = UpBlock(128,64)
        self.up3 = UpBlock(64,32)
        self.up2 = UpBlock(32,16)
        self.up1 = UpBlock(16,16)
        self.lfinal = nn.Conv2d(16, 8, 1)
        
        self.optim = torch.optim.Adam(self.parameters(), lr = 1e-4)
        
    def forward(self, x):
        z1 = self.down1(x)
        z2 = self.down2(z1) # 64, RES/4
        z3 = self.down3(z2) # 128, RES/8
        z4 = self.down4(z3) # 256, RES/16
        z5 = self.down5(z4) # 512, RES/32
        
        z5 = F.elu(self.linter1(z5))
        z5 = self.bn(F.elu(self.linter2(z5)))
        
        z6 = z4 + self.up5(z5) # 256, RES/16
        z7 = z3 + self.up4(z6) # 128, RES/8
        z8 = z2 + self.up3(z7) # 64, RES/4
        z9 = z1 + self.up2(z8) # 32, RES/2
                
        z10 = F.log_softmax(self.lfinal(self.up1(z9)), dim=1)
                
        return z10

In [3]:
net = FaceSegNet().cuda()
net.load_state_dict(torch.load("segmentation.pth"))

<All keys matched successfully>

In [4]:
classnames = [ 'neck', 'skin', 'u_lip', 'l_lip', 'mouth', 'hair', 'r_ear', 'l_ear', 'l_eye', 'r_eye', 'eye_g', 'nose' ]
mapping = [ 1, 2, 3, 3, 3, 4, 5, 5, 6, 6, 6, 7 ]
annotation_dir = "/sata/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno/"
image_dir = "/sata/data/CelebAMask-HQ/CelebA-HQ-img/"

In [5]:
BS = 5

colors = np.random.rand(8,3)

tr_err = []

for epoch in range(10000):
    err = []
    
    for i in range(10000//BS):
        print(i)
        x = []
        y = []
        
        for j in range(BS):
            im = torch.cuda.FloatTensor((np.array(Image.open(image_dir+"%d.jpg" % (j+i*BS)))[::2,::2,:3]-128.0)/128.0)
            x.append(im.unsqueeze(0))

            seg = np.zeros((512,512))

            for k,cn in enumerate(classnames):
                try:
                    im = np.array(Image.open(annotation_dir + "%.5d_%s.png" % (j+i*BS, cn)))[::1,::1,0]>128
                    seg = im*mapping[k] + (1-im)*seg    
                except BaseException as e:
                    pass
            
            y.append(torch.cuda.LongTensor(seg))
            
        x = torch.cat(x,0).permute(0,3,1,2)
        y = torch.cat(y,0).view(BS*512*512,)
                
        net.optim.zero_grad()
        
        p = net.forward(x)
        
        pred = p.permute(0,2,3,1).contiguous().view(BS*512*512,8)
        idx = torch.arange(BS*512*512).long().cuda()
        
        loss = -torch.mean(pred[idx, y[idx]])
        loss.backward()
        
        net.optim.step()
        err.append(loss.cpu().detach().numpy())
    
    tr_err.append(np.mean(err))
    
    plt.clf()
    plt.subplot(1,3,1)
    plt.plot(tr_err)
    
    plt.subplot(1,3,2)
    p = torch.argmax(p[0],0).cpu().detach().numpy() # channel x XR x YR
    p = colors[p]
    
    plt.imshow(p)
    
    plt.subplot(1,3,3)
    y = y.view(BS,512,512)[0].cpu().detach().numpy()
    y = colors[y]
    plt.imshow(y)
    
    plt.gcf().set_size_inches((15,5))
    
    display.clear_output(wait=True)
    display.display(plt.gcf())
    time.sleep(0.01)

0




1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74


KeyboardInterrupt: 

In [7]:
torch.save(net.state_dict(), open("segmentation.pth","wb"))

In [5]:
class LangevinNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.down1 = DownBlock(6,32)
        self.down2 = DownBlock(32,64)
        self.down3 = DownBlock(64,128)
        #self.down4 = DownBlock(128,256)
        #self.down5 = DownBlock(256,256)
        
        self.linter1 = nn.Conv2d(128,128,3,padding=1, padding_mode='circular')
        self.linter2 = nn.Conv2d(128,128,3,padding=1, padding_mode='circular')
        
        self.bn = nn.BatchNorm2d(128)
        #self.up5 = UpBlock(256,256)
        #self.up4 = UpBlock(256,128)
        self.up3 = UpBlock(128,64)
        self.up2 = UpBlock(64,32)
        self.up1 = UpBlock(32,32)
        self.lfinal1 = nn.Conv2d(32, 3, 1)
        self.lfinal2 = nn.Conv2d(3, 3, 1)
        
        self.optim = torch.optim.Adam(self.parameters(), lr = 1e-4)
        
    def forward(self, x, seg):
        z1 = self.down1(torch.cat([x,seg.expand(x.shape[0],3,x.shape[2],x.shape[3])],1)) # 32, RES/2
        z2 = self.down2(z1) # 64, RES/4
        z3 = self.down3(z2) # 128, RES/8
        #z4 = self.down4(z3) # 256, RES/16
        #z5 = self.down5(z4) # 512, RES/32
        
        z5 = F.elu(self.linter1(z3))
        z5 = self.bn(F.elu(self.linter2(z3)))
        
        #z6 = z4 + self.up5(z5) # 256, RES/16
        #z7 = z3 + self.up4(z6) # 128, RES/8
        z8 = z2 + self.up3(z5) # 64, RES/4
        z9 = z1 + self.up2(z8) # 32, RES/2
        z10 = self.lfinal2(x + self.lfinal1(self.up1(z9)))
        
        return z10

In [6]:
# dx/dt = -x
# x(t+delta) = x(t) + delta*(-x)

# dx/dt = -x + eta
# x(t+delta) = x(t) + delta*(-x) + sqrt(delta)*Gaussian sample

def forwardLangevin(src, steps = 100, BETA = 8, M = 1, DT = 0.0025):
    with torch.no_grad():
        seq_x = []
        seq_v = []
        vel = torch.zeros_like(src)
        cur = src.clone()

        for step in range(steps):
            vel = vel - DT * cur * BETA - DT * M * BETA * sqrt(2/M) * vel + sqrt(2 * BETA * DT) * torch.randn_like(vel)
            cur = cur + DT * M * BETA * vel
            
            seq_x.append(cur.detach().clone())
            seq_v.append(vel.detach().clone())
        
        return torch.cat(seq_x,0), torch.cat(seq_v,0)
    
def generateFromNoise(steps = 100, BETA = 8, M = 1, DT = 0.0025):
    with torch.no_grad():
        x, v = forwardLangevin(torch.randn(1,3,256,256).cuda())
        x = x[-1:]
        
        for step in range(steps):
            v = net.forward(x, segmap)
            x = x - DT * M * BETA * v
    
    return x

In [7]:
ddpm = LangevinNet().cuda()
ddpm.load_state_dict(torch.load("cat.pth"))

<All keys matched successfully>

In [12]:
pygame.init()
screen = pygame.display.set_mode((512, 512), 0, 24)

In [27]:
# Render the thing
step = 0

ddpm.eval()
net.eval()

colors = 2*np.array([
    [0,0,0],
    [0.28, 0.19, 0.38],
    [0.73, 0, 1],
    [0, 1, 0.62],
    [0.73, 0, 1],
    [0, 0.11, 1],
    [1, 0.9, 0],
    [1, 0, 0]])-1

colors = torch.cuda.FloatTensor(colors)

cam = cv2.VideoCapture(0)

with torch.no_grad():
    x, v = forwardLangevin(torch.randn(1,3,512,512).cuda())
    x = x[-1:]
    while True:        
        ret_val, img = cam.read()
        img = torch.cuda.FloatTensor((img-128.0)/128.0).permute(2,0,1).unsqueeze(0)[:,:,::2,::2]
        img = F.pad(img, [0, 0, 8, 8, 0, 0, 0, 0])[:,:,:256,:256]

        seg = torch.exp(net.forward(img))
        seg = torch.sum(seg.unsqueeze(2)*colors.view(1,8,3,1,1),1)
        seg = F.upsample(seg, scale_factor=2)
        
        for i in range(10):
            v = ddpm.forward(x, seg)
            x = x - 0.005 * 1 * 8 * v + 0.5*sqrt(0.005) * torch.randn_like(x) - 0.005 * x

        im = np.clip(seg.cpu().detach().numpy()[0].transpose(2,1,0)*0.5+0.5,0,1)
        im = np.clip(x.cpu().detach().numpy()[0].transpose(2,1,0)*0.5+0.5,0,1)
        im = (255*im).astype(np.uint8)

        s_image = pygame.pixelcopy.make_surface(im)        
        screen.blit(s_image, (0, 0))

        pygame.display.update()   
        time.sleep(0.005)
        step += 1

KeyboardInterrupt: 

In [28]:
cam.release()

In [17]:
img.shape

torch.Size([1, 3, 272, 320])