# Notkun gervigreindar fyrir teikningu þrívíddarmynda

Nathan Holmes-King

In [43]:
from einops import rearrange
from einops import repeat
from functools import wraps
import math
import numpy as np
import pandas as pd
import pywikibot
import random
import sklearn as sk
from sklearn.model_selection import train_test_split
from stl import mesh
import time
#from timm.models.layers import DropPath
import torch
from torch import einsum
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#from torch_cluster import fps
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.set_default_device(device)

## Inngangsorð
Við ætlum að þjálfa gervigreindarlíkan til að teikna þrívíddarmyndir.

Tilvísanir:
- https://arxiv.org/pdf/2301.11445
- https://github.com/1zb/3DShape2VecSet

## Gögn
Þessi gögn eru STL-skrár frá Wikimedia Commons. Það eru fimm flokkar:
- líkamshlutar
- byggingar
- rúmfræði
- geimfarartæki
- styttur

### Sækja gögn

In [4]:
# Category name suffixes; the download loop appends each one to
# 'STL files of ' to form a Wikimedia Commons category title.
flokkar = ['body parts', 'buildings', 'geometric shapes', 'objects in space', 'sculptures']
# Maps category name -> list of file titles collected for that category.
skrar = {}
# Maps category name -> integer index of the category.
catnum = {}

In [5]:
import os

# Directory holding the locally cached STL files.
stl_dir = '/Users/002-nathan/Desktop/Envalys/STLdata/'
# Upper bound on the number of files taken from each category.
max_files = 100

# Walk every category, download the files that are not cached yet, and record
# the file titles per category in `skrar` and the category index in `catnum`.
commons = pywikibot.Site('commons', 'commons')
cn = 0
for a in flokkar:
    print(a)
    cat = pywikibot.Category(commons, 'STL files of ' + a)
    catnum[a] = cn
    cn += 1
    n = 0
    for p in cat.members(member_type=['file']):
        # Progress indicator: print the running count every ten files.
        if n % 10 == 0:
            print(n)
        mynd = pywikibot.FilePage(p)
        # Strip the first five characters of the title
        # (presumably the 'File:' namespace prefix -- confirm).
        titill = p.title()[5:]
        slod = stl_dir + a + '_' + titill
        # Only hit the network when the file is not already on disk.
        if not os.path.exists(slod):
            mynd.download(filename=slod)
        skrar.setdefault(a, []).append(titill)
        n += 1
        if n >= max_files:
            break

body parts
0
10
20
30
40
50
60
70
80
buildings
0
10
20
geometric shapes
0
10
20
30
40
objects in space
0
10
20
30
40
50
sculptures
0
10
20
30
40
50


### Setja upp gögn fyrir notkun
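Við veljum af handahófi 2048 punkta úr hverri STL-skrá sem inntak og 512 punkta sem fyrirspurnir. Greypingin ("embedding") fyrir punktana er reiknuð í líkaninu sjálfu (`PointEmbed`).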

In [47]:
# Build one input point cloud and one query point set per STL file.
# Each entry is a tensor of shape (1, num_points, 3) placed on `device`.
pc_pre = []
queries_pre = []

# Directory holding the locally cached STL files.
stl_dir = '/Users/002-nathan/Desktop/Envalys/STLdata/'
# Number of points drawn per mesh for the encoder input and for the queries.
num_input_points = 2048
num_query_points = 512

for cat in skrar:
    print(cat)
    byrjun = time.time()
    for fi in skrar[cat]:
        # Load data
        gogn = mesh.Mesh.from_file(stl_dir + cat + '_' + fi)
        # Only gogn.v0 is sampled (per numpy-stl, the first vertex of each
        # triangle -- confirm against the library docs).
        vertices = gogn.v0
        for count, target in ((num_input_points, pc_pre),
                              (num_query_points, queries_pre)):
            # Uniform sampling with replacement, as one index draw instead of
            # a per-point Python loop.
            idx = random.choices(range(len(vertices)), k=count)
            # Index the array in one go and add the leading batch axis; this
            # avoids building a tensor from a Python list of ndarrays.
            target.append(torch.tensor(vertices[idx][None]).to(device))
    # Wall-clock seconds spent on this category.
    print(time.time() - byrjun)
    print('----')

body parts
2.312717914581299
----
buildings
1.2082140445709229
----
geometric shapes
1.6727559566497803
----
objects in space
exception (False, 'No lines found, impossible to read')
2.0670459270477295
----
sculptures
6.107502698898315
----


## Líkan

In [53]:
def dist(p1, p2):
    """Return the Euclidean distance between two 3-D points.

    Only indices 0, 1 and 2 of ``p1`` and ``p2`` are read; any further
    components are ignored.
    """
    squared = sum((p1[axis] - p2[axis]) ** 2 for axis in range(3))
    return squared ** 0.5

In [60]:
def fps(x, batch, ratio):
    """
    Self-written function, based off work done at Tengdu. Ignores batches entirely.

    Pick a subset of point indices ("hubs") from ``x``. The first hub is
    chosen at random; after that the point with the highest weight is added
    repeatedly until ``len(hubs) / x.shape[0] >= ratio`` or no candidate is
    left.

    Args:
        x: indexable collection of 3-D points exposing ``shape[0]``.
        batch: unused; kept so the signature matches the call site.
        ratio: target fraction of points to select.

    Returns:
        List of selected indices into ``x``, in selection order.

    Raises:
        ValueError: if ``x`` contains no points (from ``random.randrange``).

    Note: all pairwise distances are stored and every selection round visits
    all pairs again, so cost grows quadratically with the number of points.
    """
    n = x.shape[0]
    first = random.randrange(n)
    hubChoice = [first]
    # Distance from every point to its nearest chosen hub (so far: the first).
    distFromHub = [dist(x[i], x[first]) for i in range(n)]
    # Pairwise distances. dist() is symmetric, so each unordered pair is
    # computed once and stored under both orderings.
    distb = {a: {} for a in range(n)}
    for a in range(n):
        for b in range(a, n):
            gd = dist(x[a], x[b])
            distb[a][b] = gd
            distb[b][a] = gd
    # Main loop
    while len(hubChoice) / n < ratio:
        # Weight of every point that does not coincide with a chosen hub.
        hubWeight = []
        for a in range(n):
            d_a = distFromHub[a]
            if d_a == 0:
                continue
            # Upper bound on the contribution of a single neighbour.
            cap = math.log(2 * d_a) - 1
            w = []
            for b in range(n):
                if a == b:
                    continue
                d_b = distb[a][b]
                if d_b == 0:
                    continue
                if d_b > d_a:
                    # Neighbour is farther away than the nearest hub.
                    w.append(0)
                elif d_b < 0.5:
                    w.append(cap)
                else:
                    w.append(min(max(0, math.log(d_a / d_b) - 1), cap))
            hubWeight.append((a, sum(w)))
        if not hubWeight:
            # Every point coincides with a chosen hub; nothing left to pick.
            break
        # First point with the highest weight (same tie-break as a stable
        # descending sort).
        best = max(hubWeight, key=lambda pair: pair[1])[0]
        hubChoice.append(best)
        # Refresh the nearest-hub distances with the new hub.
        for a in range(n):
            newd = dist(x[a], x[best])
            if newd < distFromHub[a]:
                distFromHub[a] = newd
    return hubChoice

In [69]:
def exists(val):
    """Return ``True`` unless ``val`` is ``None`` (falsy values still count)."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val``, or the fallback ``d`` when ``val`` is ``None``."""
    if exists(val):
        return val
    return d

def cache_fn(f):
    """Wrap ``f`` so that its first non-``None`` result is reused afterwards.

    The wrapper accepts an extra keyword ``_cache`` (default ``True``). With
    ``_cache=False`` the call goes straight to ``f`` and the stored value is
    neither read nor written. The stored value does not depend on the call
    arguments: once set, it is returned for every cached call.
    """
    memo = {'value': None}

    @wraps(f)
    def cached_fn(*args, _cache = True, **kwargs):
        if _cache:
            # Compute on first use (or again, as long as f keeps returning None).
            if memo['value'] is None:
                memo['value'] = f(*args, **kwargs)
            return memo['value']
        return f(*args, **kwargs)

    return cached_fn

class PreNorm(nn.Module):
    """Apply LayerNorm to the input (and optionally the context) before ``fn``.

    Args:
        dim: feature size normalised on the main input.
        fn: wrapped module; called as ``fn(normed_x, **kwargs)``.
        context_dim: when given, a second LayerNorm of this size is created
            and applied to the ``context`` keyword argument.
    """

    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if exists(context_dim):
            self.norm_context = nn.LayerNorm(context_dim)
        else:
            self.norm_context = None

    def forward(self, x, **kwargs):
        """Normalise ``x`` (and ``kwargs['context']`` if configured), then call ``fn``.

        When a context norm is configured, ``context`` must be passed as a
        keyword argument; otherwise the lookup raises ``KeyError``.
        """
        normed_x = self.norm(x)

        if exists(self.norm_context):
            kwargs['context'] = self.norm_context(kwargs['context'])

        return self.fn(normed_x, **kwargs)

class GEGLU(nn.Module):
    """GELU-gated linear unit: halves the last dimension of its input."""

    def forward(self, x):
        """Split ``x`` in two along the last axis and gate the first half.

        Returns ``first_half * gelu(second_half)``.
        """
        value, gate = torch.chunk(x, 2, dim = -1)
        return F.gelu(gate) * value

class DropPath(nn.Module):
    """Stochastic depth: randomly drop whole samples of a residual branch.

    Local stand-in for ``timm.models.layers.DropPath``, whose import is
    commented out at the top of this file. In training mode each sample
    (index along dim 0) is zeroed with probability ``drop_prob``; kept
    samples are divided by ``1 - drop_prob`` when ``scale_by_keep`` is set.
    In eval mode, or with ``drop_prob == 0``, the input is returned as is.
    """

    def __init__(self, drop_prob = 0.0, scale_by_keep = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining axes.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            keep_mask.div_(keep_prob)
        return x * keep_mask


class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> GEGLU -> Linear.

    Args:
        dim: input and output feature size.
        mult: expansion factor of the hidden layer. The first linear layer
            produces ``dim * mult * 2`` features because GEGLU halves them.
        drop_path_rate: stochastic-depth probability applied to the output;
            ``0`` disables it.
    """

    def __init__(self, dim, mult = 4, drop_path_rate = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim)
        )

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        return self.drop_path(self.net(x))

class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Args:
        query_dim: feature size of the query input (and of the output).
        context_dim: feature size of the context; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size per head.
        drop_path_rate: stochastic-depth probability applied to the output;
            ``0`` disables it.
    """

    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        # Keys and values come from a single projection that is split in forward().
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, query_dim)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, context = None, mask = None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when it is None).

        Args:
            x: queries, laid out as (batch, n, query_dim).
            context: keys/values source, (batch, m, context_dim); optional.
            mask: optional boolean tensor with a leading batch axis; entries
                that are False are excluded from the softmax.

        Returns:
            Tensor laid out as (batch, n, query_dim).
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # Fold the heads into the batch axis so one einsum handles all heads.
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            # Most negative finite value: masked positions get ~0 weight.
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h = h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        # Unfold the heads back into the feature axis.
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.drop_path(self.to_out(out))


class PointEmbed(nn.Module):
    """Fourier-feature embedding of 3-D points followed by a linear layer.

    Args:
        hidden_dim: number of sin/cos features; must be divisible by 6
            (3 axes x sin/cos x ``hidden_dim // 6`` frequencies per axis).
        dim: output feature size of the final linear layer.
    """

    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        per_axis = self.embedding_dim // 6
        # Frequencies pi * 2^k for k = 0 .. per_axis - 1.
        freqs = torch.pow(2, torch.arange(per_axis)).float() * np.pi
        pad = torch.zeros(per_axis)
        # Row `axis` holds the frequencies in its own column group and zeros
        # elsewhere, so every coordinate is projected separately.
        rows = [
            torch.cat([freqs if group == axis else pad for group in range(3)])
            for axis in range(3)
        ]
        self.register_buffer('basis', torch.stack(rows))  # 3 x (hidden_dim // 2)

        self.mlp = nn.Linear(self.embedding_dim+3, dim)

    @staticmethod
    def embed(input, basis):
        """Project points onto ``basis`` and return ``[sin, cos]`` of the result.

        ``input`` is B x N x 3 and ``basis`` is 3 x E; the result is
        B x N x 2E.
        """
        projected = torch.einsum('bnd,de->bne', input, basis)
        return torch.cat([projected.sin(), projected.cos()], dim=2)

    def forward(self, input):
        """Embed points of shape B x N x 3 into B x N x ``dim``."""
        fourier = self.embed(input, self.basis)
        # The raw coordinates are appended to the Fourier features.
        features = torch.cat([fourier, input], dim=2)
        return self.mlp(features)


class DiagonalGaussianDistribution(object):
    """Gaussian with diagonal covariance, parameterised by mean and log-variance.

    Args:
        mean: tensor of means; its first axis is treated as the batch axis.
        logvar: tensor of log-variances, same shape as ``mean``; clamped to
            [-30, 20] before use.
        deterministic: when True the variance and standard deviation are
            zero, so ``sample()`` returns the mean and ``kl()`` / ``nll()``
            return a single zero.
    """

    def __init__(self, mean, logvar, deterministic=False):
        self.mean = mean
        # Clamp so exp() below can neither overflow nor underflow to zero.
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw one reparameterised sample ``mean + std * eps``."""
        # randn_like creates the noise directly on the mean's device.
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self, other=None):
        """KL divergence to ``other`` (or to a standard normal when omitted).

        Returns one value per batch element: the mean over axes 1 and 2 when
        ``other`` is None, otherwise the mean over all non-batch axes.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.mean.device)
        if other is None:
            return 0.5 * torch.mean(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2])
        # Reduce over every non-batch axis, so 3-D latents work here as well
        # (a hard-coded [1, 2, 3] only fits 4-D tensors).
        reduce_dims = list(range(1, self.mean.dim()))
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=reduce_dims)

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return torch.zeros(1, device=self.mean.device)
        logtwopi = float(np.log(2.0 * np.pi))
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean

class AutoEncoder(nn.Module):
    """Point-cloud autoencoder that decodes a set of latents at query points.

    ``encode`` cross-attends from a subsample of the input points to the full
    point cloud; ``decode`` runs ``depth`` self-attention/feed-forward layers
    over the latents and then cross-attends from the query points to them.

    Args (keyword-only):
        depth: number of latent self-attention + feed-forward layers.
        dim: feature size of point embeddings and latents.
        queries_dim: feature size of the decoder cross-attention queries.
        output_dim: features per query point in the output; ``None`` skips
            the final linear layer.
        num_inputs: number of points N every input cloud must have.
        num_latents: number of points subsampled to seed the latents.
        heads: attention heads in the latent self-attention layers.
        dim_head: feature size per head in those layers.
        weight_tie_layers: share one attention and one feed-forward module
            across all ``depth`` layers.
        decoder_ff: add a feed-forward block after the decoder
            cross-attention.
    """

    def __init__(
        self,
        *,
        depth=24,
        dim=512,
        queries_dim=512,
        output_dim = 1,
        num_inputs = 2048,
        num_latents = 512,
        heads = 8,
        dim_head = 64,
        weight_tie_layers = False,
        decoder_ff = False
    ):
        super().__init__()

        self.depth = depth

        self.num_inputs = num_inputs
        self.num_latents = num_latents

        # Encoder: single-head cross-attention followed by a feed-forward block.
        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(dim, Attention(dim, dim, heads = 1, dim_head = dim), context_dim = dim),
            PreNorm(dim, FeedForward(dim))
        ])

        self.point_embed = PointEmbed(dim=dim)

        get_latent_attn = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, drop_path_rate=0.1))
        get_latent_ff = lambda: PreNorm(dim, FeedForward(dim, drop_path_rate=0.1))
        # With weight tying the cached factories hand back the same module
        # every time; otherwise each call builds a fresh one.
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}

        for i in range(depth):
            self.layers.append(nn.ModuleList([
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, dim, heads = 1, dim_head = dim), context_dim = dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None

        self.to_outputs = nn.Linear(queries_dim, output_dim) if exists(output_dim) else nn.Identity()

    def encode(self, pc):
        """Encode point clouds ``pc`` (B x N x 3) into latents (B x num_latents x dim).

        Raises:
            AssertionError: if N differs from ``num_inputs``.
        """
        # pc: B x N x 3
        _, N, _ = pc.shape
        assert N == self.num_inputs

        # The reference implementation picks the latent seed points with
        # farthest-point sampling; fps() in this file takes too much time to
        # run, so the first num_latents points of EACH cloud are used instead.
        # Slicing per sample (rather than slicing the flattened batch) keeps
        # every batch element's seed points inside its own cloud.
        sampled_pc = pc[:, :self.num_latents]

        sampled_pc_embeddings = self.point_embed(sampled_pc)

        pc_embeddings = self.point_embed(pc)

        cross_attn, cross_ff = self.cross_attend_blocks

        # Residual cross-attention from the subsample to the full cloud.
        x = cross_attn(sampled_pc_embeddings, context = pc_embeddings, mask = None) + sampled_pc_embeddings
        x = cross_ff(x) + x

        return x


    def decode(self, x, queries):
        """Decode latents ``x`` at the 3-D locations ``queries`` (B x M x 3).

        Returns the output of ``to_outputs`` applied to the per-query
        features.
        """
        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x

        # cross attend from decoder queries to latents
        queries_embeddings = self.point_embed(queries)
        latents = self.decoder_cross_attn(queries_embeddings, context = x)

        # optional decoder feedforward
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)

        return self.to_outputs(latents)

    def forward(self, pc, queries):
        """Encode ``pc`` and decode at ``queries``.

        Returns:
            ``{'logits': tensor}``; the decoder output with its last axis
            squeezed away when that axis has size 1.
        """
        x = self.encode(pc)

        o = self.decode(x, queries).squeeze(-1)

        return {'logits': o}

In [70]:
# Smoke test: build the model with its default hyperparameters and run one
# forward pass on the first sampled point cloud and its query points.
model = AutoEncoder()
fw = model(pc_pre[0], queries_pre[0])
# fw is the dict returned by AutoEncoder.forward, with a single 'logits' entry.
print(fw)

{'logits': tensor([[-0.1341, -0.1468, -0.1333, -0.1404, -0.1480, -0.1428, -0.1339, -0.1460,
         -0.1353, -0.1349, -0.1458, -0.1483, -0.1472, -0.1344, -0.1360, -0.1343,
         -0.1484, -0.1458, -0.1457, -0.1464, -0.1455, -0.1460, -0.1355, -0.1374,
         -0.1458, -0.1470, -0.1428, -0.1386, -0.1343, -0.1368, -0.1473, -0.1410,
         -0.1338, -0.1405, -0.1448, -0.1476, -0.1445, -0.1482, -0.1453, -0.1469,
         -0.1386, -0.1368, -0.1284, -0.1338, -0.1376, -0.1479, -0.1392, -0.1432,
         -0.1452, -0.1366, -0.1440, -0.1341, -0.1341, -0.1435, -0.1446, -0.1352,
         -0.1471, -0.1337, -0.1322, -0.1430, -0.1333, -0.1484, -0.1489, -0.1284,
         -0.1462, -0.1395, -0.1442, -0.1394, -0.1340, -0.1455, -0.1467, -0.1442,
         -0.1420, -0.1394, -0.1415, -0.1466, -0.1448, -0.1489, -0.1346, -0.1355,
         -0.1409, -0.1476, -0.1452, -0.1408, -0.1473, -0.1467, -0.1404, -0.1478,
         -0.1395, -0.1406, -0.1470, -0.1416, -0.1351, -0.1440, -0.1373, -0.1335,
         -0.1479,