In [1]:
%load_ext autoreload
%autoreload 2


import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from cellmates.data.sample import Sample
from cellmates.data.dataset import CellMatesDataset
from cellmates.data.stubs import (
    generate_dataset_for_n_cells_test, 
    generate_dataset_for_cell_type, 
    generate_dataset_for_distances
)


In [2]:
sample = Sample(
    cell_types=[1,2,3], 
    distances=torch.zeros((3,3)), 
    responder_cell_type=1, 
    is_dividing=False
)

# Test datasets:

In [3]:
generate_dataset_for_n_cells_test()
generate_dataset_for_cell_type()
generate_dataset_for_distances()

<cellmates.data.dataset.CellMatesDataset at 0x11a96f2c0>

In [4]:
area = (3.14*(140**2))


In [5]:
H,L,K = 2,3,4

E_HLL = torch.randn((H,L,L))

V1_LK = torch.randn((L,K))
V2_LK = torch.randn((L,K))

V_HLK = torch.stack([V1_LK,V2_LK])

# using einsum
V_LHK = V_HLK.permute(1,0,2)
Z_LHK = torch.einsum("HLX,XHK->LHK", E_HLL, V_LHK)

# using straightforward:
Z_HLK_sf = torch.matmul(E_HLL, V_HLK)
Z_HLK_sf = Z_HLK_sf.view(L,H,K)

In [6]:
Z_LHK

tensor([[[ 0.6873,  1.1187,  2.0891, -1.7101],
         [-0.3542,  0.6051,  0.1643, -1.6346]],

        [[ 4.1546, -1.5799, -4.3477, -0.7346],
         [-2.8730,  1.7009,  1.3387, -1.6222]],

        [[-2.5943, -1.3149, -0.3851,  2.2987],
         [ 2.1534, -1.0339, -1.0051,  0.3659]]])

In [7]:
Z_HLK_sf

tensor([[[ 0.6873,  1.1187,  2.0891, -1.7101],
         [ 4.1546, -1.5799, -4.3477, -0.7346]],

        [[-2.5943, -1.3149, -0.3851,  2.2987],
         [-0.3542,  0.6051,  0.1643, -1.6346]],

        [[-2.8730,  1.7009,  1.3387, -1.6222],
         [ 2.1534, -1.0339, -1.0051,  0.3659]]])

In [8]:
from cellmates.model.transformer import SpatialMultiHeadAttention, CellMatesTransformer

In [13]:
H,L,K = 2,3,4
D = H*K

smh = SpatialMultiHeadAttention(D,H,K,'cpu')
tf = CellMatesTransformer(D,H,K)

In [12]:
B = 16
x = torch.randn((B,L,D))


distance_idxs_BLL = torch.randint(low=0, high=10, size=(B,L,L))
distance_embeddings = tf.distance_embeddings

smh(x, distance_idxs_BLL, distance_embeddings).shape

torch.Size([16, 3, 8])

In [19]:
cell_types_BL = torch.randint(low=0, high=5, size=(B,L))
distance_idxs_BLL = torch.randint(low=0, high=10, size=(B,L,L))

tf(cell_types_BL, distance_idxs_BLL).shape

torch.Size([16, 1])

# Scratch:

In [11]:
tf = CellMatesTransformer(D,H,K)

In [50]:
V = torch.randn((L,D))
Wq = torch.randn((D,K))
A = torch.randn((L,L,K))

E = torch.zeros((L,L))

for i in range(L):
    for j in range(L):
        E[i,j] = (V[i,:] @ Wq) @ A[i,j]


E_einsum = torch.einsum('LK,LXK -> LX', V@Wq, A)


In [None]:
V = torch.randn((L,D))
Wq = torch.randn((D,K))
A = torch.randn((L,L,K))

E = torch.zeros((L,L))

for i in range(L):
    for j in range(L):
        E[i,j] = (V[i,:] @ Wq) @ A[i,j]


E_einsum = torch.einsum('LK,LXK -> LX', V@Wq, A)


In [52]:
E

tensor([[-3.8700,  3.8456, -0.8910],
        [ 1.3790, -6.7893, -5.1333],
        [ 2.3019,  1.3572, -0.9927]])

In [51]:
E_einsum

tensor([[-3.8700,  3.8456, -0.8910],
        [ 1.3790, -6.7893, -5.1333],
        [ 2.3019,  1.3572, -0.9927]])

In [15]:
Eqr_BHL = torch.randn((H,L))

Eqr_BHL.unsqueeze(-1).expand((H, L, L))

tensor([[[-0.4611, -0.4611, -0.4611],
         [-0.6132, -0.6132, -0.6132],
         [ 0.5478,  0.5478,  0.5478]],

        [[-0.9738, -0.9738, -0.9738],
         [-1.4068, -1.4068, -1.4068],
         [-0.5524, -0.5524, -0.5524]]])

In [27]:
Zqr_BLK = torch.randn((2,3,4))

Zqr_BLK.unsqueeze(1).expand(size=(2,2,3,4))[0]


tensor([[[ 1.9296,  1.5269,  0.1034,  0.4090],
         [ 1.0287, -0.4787,  0.0894, -0.8882],
         [-0.2348, -0.3270, -0.5398, -1.3754]],

        [[ 1.9296,  1.5269,  0.1034,  0.4090],
         [ 1.0287, -0.4787,  0.0894, -0.8882],
         [-0.2348, -0.3270, -0.5398, -1.3754]]])

In [47]:
A[i,j].shape, K

(torch.Size([4]), 4)

In [49]:
V[i,:].shape, D

(torch.Size([8]), 8)

In [44]:
idxs = torch.LongTensor([[1,2,4],[1,2,3]])

tf.distance_embeddings['K_qr'](idxs).shape

torch.Size([2, 3, 4])