In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import autobot

import torch
from torch import nn
import numpy as np

In [3]:
batch_size = 4
num_view = 3
h, w = 256, 256

imgs = torch.rand((batch_size, num_view, 3, h, w))

In [4]:
emb_layer = autobot.MultiViewEmbedding(emb_dim=128, patch_size=(16, 16))
emb = emb_layer(imgs)
print(emb.shape)

torch.Size([4, 3, 256, 128])


In [44]:
class MultiViewAttention(nn.Module):
    '''
        MVA searches object in 3D space by computing similarity 
    '''
    def __init__(self,
        emb_dim, attn_dim, v_dim, num_heads,
        num_view, view_transformation=None
    ):
        super().__init__()

        self.emb_dim = emb_dim
        self.attn_dim = attn_dim
        self.v_dim = v_dim
        self.num_heads = num_heads

        self.query = nn.Linear(emb_dim, attn_dim)
        self.key = nn.Linear(emb_dim, attn_dim)
        self.value = nn.Linear(emb_dim, v_dim)
        self.linear = nn.Linear(v_dim*num_heads, v_dim*num_heads)

        self.per_view_mha = nn.ModuleList([
            autobot.MultiHeadAttention(emb_dim, attn_dim, v_dim, num_heads)
            for _ in range(num_view)
        ])
        
    def forward(self, v, vs):
        query = self.query(v)
        
        print('query:', v.shape, query.shape)
        print('-'*30)
        key = self.key(vs)
        value = self.value(vs)
        print('key, value:', vs.shape, key.shape, value.shape)

        print(query.unsqueeze(1).shape)
        logits = torch.matmul(query.unsqueeze(1), key.transpose(-1, -2))
        print(logits.shape)
        logits_ = logits.sum(1)
        print(logits.shape)
        attn_ = logits_.softmax(dim=-1)

        attn = logits.softmax(dim=-1)
        v = torch.matmul(attn, value)
        print('v', v.shape)

        o = torch.matmul(attn_, v.sum(1))
        print('o', o.shape)

        return self.linear(o)


mva_layer = MultiViewAttention(emb_dim=128, attn_dim=8, v_dim=16, num_heads=1, num_view=num_view)
i = 0
print(emb.shape)
mva = mva_layer(
    emb[:, i, :, :], 
    torch.cat([
        emb[:, :i, :, :],
        emb[:, i+1:, :, :]
    ], dim=1)  # from i-th view to other views that are not i-th
)
print(mva.shape)

torch.Size([4, 3, 256, 128])
query: torch.Size([4, 256, 128]) torch.Size([4, 256, 8])
------------------------------
key, value: torch.Size([4, 2, 256, 128]) torch.Size([4, 2, 256, 8]) torch.Size([4, 2, 256, 16])
torch.Size([4, 1, 256, 8])
torch.Size([4, 2, 256, 256])
torch.Size([4, 2, 256, 256])
v torch.Size([4, 2, 256, 16])
o torch.Size([4, 256, 16])
torch.Size([4, 256, 16])


In [2]:
K = [[ 282.363047,      0.,          166.21515189],
     [   0.,          280.10715905,  108.05494375],
     [   0.,            0.,            1.        ]]
K = np.array(K)
R = np.eye(3)
t = np.array([[0],[1.],[0]])
P = K.dot(np.hstack((R,t)))
C = np.array([0., 0., 1.])
p1 = np.array([215, 180, 1.])

In [3]:
P

array([[282.363047  ,   0.        , 166.21515189,   0.        ],
       [  0.        , 280.10715905, 108.05494375, 280.10715905],
       [  0.        ,   0.        ,   1.        ,   0.        ]])

In [4]:
np.hstack((R,t))

array([[1., 0., 0., 0.],
       [0., 1., 0., 1.],
       [0., 0., 1., 0.]])

In [2]:
from dataclasses import dataclass


@dataclass(slots=True)
class Object:
    s: str='abc'

def _process(row):
    return Object(str(row))

def _get_gen(l):
    for i in range(l):
        yield _process(i)

g = _get_gen(100)

print(g)

for obj in g:
    print(obj)

<generator object _get_gen at 0x7fdbf7c7e180>
Object(s='0')
Object(s='1')
Object(s='2')
Object(s='3')
Object(s='4')
Object(s='5')
Object(s='6')
Object(s='7')
Object(s='8')
Object(s='9')
Object(s='10')
Object(s='11')
Object(s='12')
Object(s='13')
Object(s='14')
Object(s='15')
Object(s='16')
Object(s='17')
Object(s='18')
Object(s='19')
Object(s='20')
Object(s='21')
Object(s='22')
Object(s='23')
Object(s='24')
Object(s='25')
Object(s='26')
Object(s='27')
Object(s='28')
Object(s='29')
Object(s='30')
Object(s='31')
Object(s='32')
Object(s='33')
Object(s='34')
Object(s='35')
Object(s='36')
Object(s='37')
Object(s='38')
Object(s='39')
Object(s='40')
Object(s='41')
Object(s='42')
Object(s='43')
Object(s='44')
Object(s='45')
Object(s='46')
Object(s='47')
Object(s='48')
Object(s='49')
Object(s='50')
Object(s='51')
Object(s='52')
Object(s='53')
Object(s='54')
Object(s='55')
Object(s='56')
Object(s='57')
Object(s='58')
Object(s='59')
Object(s='60')
Object(s='61')
Object(s='62')
Object(s='63')
Obje