In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import torchvision.models as models
from einops import rearrange

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
from render_tools import renderer, sample_quadratic_bezier_curve
from vgg import VGG19

In [11]:
class Transformer(nn.Module):
    """Single-head cross-attention between content and style feature maps.

    Every channel of a ``(B, C, H, W)`` feature map is treated as one token
    whose feature vector is that channel's flattened ``H * W`` spatial map.
    The projection layers have ``in_features=1024``, so ``H * W`` must equal
    1024 (e.g. 32 x 32). Queries are projected from the style features; keys
    and values are projected from the content features.
    """

    def __init__(self):
        super(Transformer, self).__init__()
        self.key_matrix = nn.Linear(1024, 256, bias=False)
        self.query_matrix = nn.Linear(1024, 256, bias=False)
        self.value_matrix = nn.Linear(1024, 256, bias=False)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=256, num_heads=1)
        self.out_matrix = nn.Linear(256, 1024, bias=False)

    def forward(self, content, style):
        """Attend from style tokens to content tokens.

        Args:
            content: tensor of shape ``(B, C, H, W)`` with ``H * W == 1024``.
            style: tensor with the same layout as ``content``.

        Returns:
            Tensor of shape ``(B, C_style, H, W)`` where ``H`` and ``W`` are
            taken from ``content``.
        """
        _, _, h, w = content.shape

        # (B, C, H, W) -> (B, C, H*W): one token per channel.
        content_tokens = content.flatten(2)
        style_tokens = style.flatten(2)

        key = self.key_matrix(content_tokens)
        query = self.query_matrix(style_tokens)
        value = self.value_matrix(content_tokens)

        # nn.MultiheadAttention (default batch_first=False) expects
        # (seq, batch, embed). Passing (B, C, E) directly would make the
        # batch axis the sequence axis: with B == 1 the softmax runs over a
        # single key, so the style query would have no influence at all.
        # Move the token axis to the front, attend, then move it back.
        query, key, value = (t.transpose(0, 1) for t in (query, key, value))
        attn_output, _ = self.multihead_attn(query, key, value)
        attn_output = attn_output.transpose(0, 1)

        out = self.out_matrix(attn_output)  # (B, C_style, 1024)
        # Unflatten the 1024-wide vectors back into (H, W) maps.
        return out.reshape(out.shape[0], out.shape[1], h, w)
        

In [20]:
class Conv_Block(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=False)
        self.net = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> ReLU and return the result."""
        features = self.net(x)
        return features

In [21]:
class Up_Conv(nn.Module):
    """Transposed-convolution upsampling followed by batch norm and ReLU.

    The transposed convolution uses kernel size 2 and stride 2, which doubles
    the height and width of the input.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=False)
        self.net = nn.Sequential(upsample, norm, activation)

    def forward(self, x):
        """Upsample ``x`` by a factor of two and return the activated result."""
        upsampled = self.net(x)
        return upsampled

In [78]:
class Decoder(nn.Module):
    """Convolutional decoder mapping 512-channel features to 12 channels.

    The forward pass applies, in order: block1, up1, block2, up2, block3,
    block4, final_block. Only ``up1`` and ``up2`` are applied, so the spatial
    size grows by a factor of four overall. ``up3`` and ``up4`` are created in
    ``__init__`` (their parameters are part of the module) but are not used in
    ``forward``.
    """

    def __init__(self):
        super(Decoder, self).__init__()

        # Channel widths shrink 512 -> 256 -> 128 -> 64 -> 32 -> 12.
        self.block1 = Conv_Block(512, 256)
        self.up1 = Up_Conv(256, 256)

        self.block2 = Conv_Block(256, 128)
        self.up2 = Up_Conv(128, 128)

        self.block3 = Conv_Block(128, 64)
        self.up3 = Up_Conv(64, 64)    # currently not applied in forward

        self.block4 = Conv_Block(64, 32)
        self.up4 = Up_Conv(32, 32)    # currently not applied in forward

        self.final_block = Conv_Block(32, 12)

    def forward(self, x):
        """Decode ``x`` (512 channels) into a 12-channel map at 4x resolution."""
        active_stages = (
            self.block1,
            self.up1,        # spatial size x2
            self.block2,
            self.up2,        # spatial size x2 again
            self.block3,
            self.block4,
            self.final_block,
        )
        for stage in active_stages:
            x = stage(x)
        return x
        
        
        

In [69]:
def image_loader(image_name, img_size, device):
    """Load an image file as a normalised, batched float tensor.

    Args:
        image_name: path (or file object) accepted by ``PIL.Image.open``.
        img_size: the image is resized to ``(img_size, img_size)``; the
            aspect ratio is not preserved.
        device: device the resulting tensor is moved to.

    Returns:
        A ``torch.float`` tensor of shape ``(1, 3, img_size, img_size)`` on
        ``device``, normalised per channel with the mean/std below (the
        values torchvision documents for its ImageNet-pretrained models).
    """
    preprocess = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=[0.229, 0.224, 0.225]),
    ])

    # Image.open keeps the underlying file open until the image is closed;
    # the context manager releases it as soon as the tensor has been built.
    with Image.open(image_name) as img:
        # Force three channels: grayscale, palette or RGBA files would
        # otherwise not match the 3-element mean/std used by Normalize.
        image = preprocess(img.convert('RGB'))

    image = image.unsqueeze(0)  # add the batch dimension
    return image.to(device, torch.float)

In [91]:
class BrushStrokeRenderer(nn.Module):
    """Render brushstroke parameters onto a canvas.

    Thin module wrapper around ``sample_quadratic_bezier_curve`` and
    ``renderer`` from ``render_tools``; it holds the rendering settings and
    has no learnable parameters of its own.

    Args:
        canvas_height: canvas height passed to ``renderer``.
        canvas_width: canvas width passed to ``renderer``.
        samples_per_curve: number of points sampled along each Bezier curve.
        strokes_per_pixel: forwarded unchanged to ``renderer``.
        canvas_color: forwarded unchanged to ``renderer``.

    The defaults reproduce the previously hard-coded configuration.
    """

    def __init__(self, canvas_height=512, canvas_width=512, samples_per_curve=10,
                 strokes_per_pixel=20, canvas_color=.5):
        super().__init__()
        self.samples_per_curve = samples_per_curve
        self.canvas_height = canvas_height
        self.canvas_width = canvas_width
        self.strokes_per_pixel = strokes_per_pixel
        self.canvas_color = canvas_color

    def forward(self, curve_s, curve_e, curve_c, location, color, width):
        """Sample each stroke's curve and rasterise all strokes.

        ``curve_s``, ``curve_e`` and ``curve_c`` are the start, end and
        control points of each quadratic Bezier curve, given as offsets that
        are added to ``location`` before sampling.
        NOTE(review): callers in this notebook pass ``[N, 2]`` points and
        locations, ``[N, 3]`` colors and ``[N, 1]`` widths -- confirm against
        ``render_tools``.

        Returns:
            Whatever ``renderer`` returns for the sampled curves.
        """
        # Curve control points are relative to the stroke location.
        curve_points = sample_quadratic_bezier_curve(s=curve_s + location,
                                                     e=curve_e + location,
                                                     c=curve_c + location,
                                                     num_points=self.samples_per_curve)

        canvas = renderer(curve_points, location, color, width,
                          self.canvas_height, self.canvas_width,
                          self.strokes_per_pixel, self.canvas_color)

        return canvas

In [92]:
class Model(nn.Module):
    """Content/style images -> brushstroke parameters -> rendered canvas.

    Pipeline: a truncated pretrained VGG16 encodes both images, the
    ``Transformer`` fuses the two feature maps, the ``Decoder`` predicts a
    12-channel stroke-parameter map (one stroke per spatial position), and
    ``BrushStrokeRenderer`` rasterises the strokes.
    """

    def __init__(self):
        super(Model, self).__init__()
        encoder = models.vgg16(pretrained=True).features

        # First 27 layers of the VGG16 feature extractor (four max-pools,
        # i.e. a 16x spatial reduction: 512x512 input -> 32x32 features).
        self.block_32x32 = nn.Sequential(*encoder[:27])
        self.transformer = Transformer()

        self.decoder = Decoder()

        self.renderer = BrushStrokeRenderer()

    def _render_sample(self, params):
        """Render one sample's ``(12, H, W)`` parameter map into a canvas.

        Channel layout (12 channels in total), with N = H * W strokes:
            0:2   curve start   [N x 2]
            2:4   curve end     [N x 2]
            4:6   curve control [N x 2]
            6:8   location      [N x 2]
            8:11  color         [N x 3]
            11:12 width         [N x 1]

        NOTE(review): the decoder's last stage ends in a ReLU, so all twelve
        parameters are non-negative -- confirm this is intended for the curve
        offsets.
        """
        curve_s = rearrange(params[0:2], 'c h w -> (h w) c')
        curve_e = rearrange(params[2:4], 'c h w -> (h w) c')
        curve_c = rearrange(params[4:6], 'c h w -> (h w) c')
        location = rearrange(params[6:8], 'c h w -> (h w) c')
        color = rearrange(params[8:11], 'c h w -> (h w) c')
        width = rearrange(params[11:12], 'c h w -> (h w) c')

        return self.renderer(curve_s, curve_e, curve_c, location, color, width)

    def forward(self, content_img, style_img):
        """Render a stylised canvas for each (content, style) pair.

        Args:
            content_img: batched content images, ``(B, 3, H, W)``.
            style_img: batched style images with the same layout.

        Returns:
            The per-sample canvases stacked along a new leading batch axis.
        """
        feats_content = self.block_32x32(content_img)
        feats_style = self.block_32x32(style_img)

        fused = self.transformer(feats_content, feats_style)
        stroke_params = self.decoder(fused)  # (B, 12, H', W')

        # Render every sample on its own instead of assuming a batch of one
        # (the previous squeeze(0) + channel slicing only worked for B == 1).
        # For B == 1 the stacked result equals the old canvas.unsqueeze(0).
        canvases = [self._render_sample(sample) for sample in stroke_params]
        return torch.stack(canvases, dim=0)


In [87]:
# Input images for the style-transfer run (absolute paths on the author's machine).
content_img_file = (
    '/scratch/umeleti/code/style/pytorch_brushstroke/'
    'brushstroke-parameterized-style-transfer/images/golden-gate-bridge.jpg'
)
style_img_file = (
    '/scratch/umeleti/code/style/pytorch_brushstroke/'
    'brushstroke-parameterized-style-transfer/images/starry_night.jpg'
)
# vgg_weight_file = '/scratch/umeleti/code/style/pytorch_brushstroke/brushstroke-parameterized-style-transfer/vgg_weights/vgg19_weights_normalized.h5'


In [88]:
# Load both inputs at the same 512x512 resolution onto the selected device.
content_img, style_img = (
    image_loader(path, 512, device)
    for path in (content_img_file, style_img_file)
)

In [93]:
# Build the network, then move its parameters and buffers to the target device.
model = Model()
model = model.to(device)

In [94]:
# Full forward pass: encode both images, predict strokes and render the canvas.
out = model(content_img=content_img, style_img=style_img)

torch.Size([16384, 2]) torch.Size([16384, 2]) torch.Size([16384, 2]) torch.Size([16384, 2]) torch.Size([16384, 3]) torch.Size([16384, 1])


In [95]:
# `out` carries a leading batch axis (Model.forward adds it); imshow needs a
# single (H, W, 3) image, so show the first canvas. The original `out..detach()`
# was a syntax error.
plt.imshow(out[0].detach().cpu().numpy())

torch.Size([1, 512, 512, 3])

In [85]:
# Scratch check: flatten the single-channel slice 11:12 into (h*w, 1) rows.
# NOTE(review): `out_` is not assigned in any visible cell -- confirm where it
# comes from before relying on this cell after a kernel restart.
rearrange(out_[11:12], 'c h w -> (h w) c').shape

torch.Size([16384, 1])

In [53]:
# Scratch check: drop the batch axis and keep index 10 of the next axis.
# NOTE(review): the recorded (1, 128, 128) shape suggests `out` held the raw
# decoder output when this ran; Model.forward now returns the rendered canvas,
# so this slice presumably means something different today -- confirm.
curve_s = out.squeeze(0)[10:11,...]

In [54]:
# Scratch check: shape of the `curve_s` slice assigned from `out`.
curve_s.shape

torch.Size([1, 128, 128])

In [44]:
# Scratch check: flatten (c, h, w) into (h*w, c) rows.
# NOTE(review): the recorded output below has 2 columns while the `curve_s`
# slice assigned from `out` has a single channel; the output looks stale
# (this cell has an earlier execution count) -- confirm.
rearrange(curve_s, 'c h w -> (h w) c').shape

torch.Size([16384, 2])

In [75]:
# Scratch check: shape of the single-index slice 10:11 of `out_`.
# NOTE(review): `out_` is not assigned in any visible cell -- confirm its origin.
out_[10:11].shape

torch.Size([1, 128, 128])

In [76]:
# Scratch check: overall shape of `out_`.
# NOTE(review): the recorded output shows 11 channels, whereas the current
# Decoder ends in 12 -- presumably from an older decoder version; confirm.
out_.shape

torch.Size([11, 128, 128])

In [7]:
# def give_vgg16_stages(self):

#     block1 = nn.Sequential(*self.encoder[:6])  # outshape 128, 256, 256
#     block2 = nn.Sequential(*self.encoder[6:13])  # 256, 128, 128
#     block3 = nn.Sequential(*self.encoder[13:20])  # 512, 64, 64
#     block4 = nn.Sequential(*self.encoder[20:27])  # 512, 32, 32
#     block5 = nn.Sequential(*self.encoder[27:34]) # 512, 16, 16

#     return [128, 256, 512, 512, 512], (block1, block2, block3, block4, block5)

In [None]:
# canvas_height = 512 
# canvas_width = 512
# strokes_per_pixel = 20
# canvas_color = .5
# curve_points = torch.randint(0, canvas_height, (10000, 20, 2))
# location = torch.randint(0, canvas_height, (10000,2))
# color =  torch.randint(0, 1, (10000,3))
# width = torch.randint((10000,1))

In [4]:
# # from tensorflow
# num_strokes = 1000
# canvas_height = 512
# canvas_width = 512

# width_scale=.1
# sec_scale = 1.1
# samples_per_curve = 10

# color = np.random.rand(num_strokes, 3)

# # Brushstroke widths
# width = np.random.rand(num_strokes, 1) * width_scale

# # Brushstroke locations
# location = np.stack([np.random.rand(num_strokes) * canvas_height, np.random.rand(num_strokes) * canvas_width], axis=-1)

# # Start point for the Bezier curves
# s = np.stack([np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_height,
#               np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_width], axis=-1)

# # End point for the Bezier curves
# e = np.stack([np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_height,
#               np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_width], axis=-1)

# # Control point for the Bezier curves
# c = np.stack([np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_height,
#               np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_width], axis=-1)

# # Normalize control points
# sec_center = (s + e + c) / 3.0
# s, e, c = [x - sec_center for x in [s, e, c]]
# s, e, c = [x * sec_scale for x in [s, e, c]]

# curve_s = torch.nn.Parameter(torch.from_numpy(np.array(s, 'float32')), requires_grad=True)
# curve_e = torch.nn.Parameter(torch.from_numpy(np.array(e, 'float32')), requires_grad=True)
# curve_c = torch.nn.Parameter(torch.from_numpy(np.array(c, 'float32')), requires_grad=True)
# color = torch.nn.Parameter(torch.from_numpy(np.array(color, 'float32')), requires_grad=True)
# location = torch.nn.Parameter(torch.from_numpy(np.array(location, 'float32')), requires_grad=True)
# width = torch.nn.Parameter(torch.from_numpy(np.array(width, 'float32')), requires_grad=True)

# curve_points = sample_quadratic_bezier_curve(s=curve_s + location,
#                                             e=curve_e + location,
#                                             c=curve_c + location,
#                                             num_points=samples_per_curve)

# strokes_per_pixel=20
# canvas_color = 0.5
# canvas = renderer(curve_points, location, color, width,
#                                  canvas_height, canvas_width, strokes_per_pixel, canvas_color)

# plt.imshow(canvas.detach().cpu().numpy())