In [1]:
%load_ext autoreload
%autoreload 2


import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from cellmates.data.sample import Sample
from cellmates.data.dataset import CellMatesDataset
from cellmates.data.stubs import (
    generate_dataset_for_n_cells_test, 
    generate_dataset_for_cell_type, 
    generate_dataset_for_distances
)


In [2]:
sample = Sample(
    cell_types=[1,2,3], 
    distances=torch.zeros((3,3)), 
    responder_cell_type=1, 
    is_dividing=False
)

# Test 1: Padding shouldn't affect the prediction of a sample

In [13]:
from cellmates.data.stubs import repeated_cell_sample
from cellmates.data import CellMatesDataset, collate_fn
from cellmates.model.transformer import CellMatesTransformer


def test_samples_are_independent_wrt_n_cells():

    tr = CellMatesTransformer(D=512, K=int(512/16), num_encoder_layers=0).eval()

    b1 = collate_fn([repeated_cell_sample(n) for n in [2, 10]])
    b2 = collate_fn([repeated_cell_sample(n) for n in [2, 5]])


    o1 = tr(
        cell_types_BL=b1['cell_types_BL'], 
        distances_BLL=b1['distances_BLL'],
        padding_mask_BL=b1['padding_mask_BL']
    )

    o2 = tr(
        cell_types_BL=b2['cell_types_BL'], 
        distances_BLL=b2['distances_BLL'],
        padding_mask_BL=b2['padding_mask_BL']
    )


    display(o1)
    display(o2)

    assert torch.allclose(o1[0], o2[0])

tensor([[0.6690],
        [3.3507]], grad_fn=<AddmmBackward0>)

tensor([[0.6690],
        [1.6744]], grad_fn=<AddmmBackward0>)

In [19]:
repeated_cell_sample(10).distances

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [20]:
SPECIAL_CELL_TYPE = -1

class AddClassificationCellSample:
    def __init__(self, sample: Sample) -> None:
        self.cell_types = self._cell_types(sample)
        self.distances = self._distances(sample)
        self.responder_cell_type = sample.responder_cell_type
        self.is_dividing = sample.is_dividing
        self.L = len(sample.cell_types)

    def _cell_types(self, sample: Sample) -> np.ndarray:
        return np.concatenate([[SPECIAL_CELL_TYPE], sample.cell_types])
    
    def _distances(self, sample: Sample) -> np.ndarray:
        L = sample.distances.shape[0]
        distances = np.zeros((L+1, L+1))
        distances[1:, 1:] = sample.distances
        return distances
        




In [38]:
from cellmates.model.transformer import bucketize_distances, N_DISTANCES

bucketize_distances(torch.tensor([-1])), bucketize_distances(torch.tensor([0])), bucketize_distances(torch.tensor([20]))

(tensor([15]), tensor([0]), tensor([2]))

In [39]:
N_DISTANCES

16

In [40]:
4/15

0.26666666666666666

In [46]:
def test_vanilla_einsum():

    B,L,H,K = 2,3,5,7   

    Q_BLHK = torch.randn((B,L,H,K))
    K_BLHK = torch.randn((B,L,H,K))

    Q_BHLK = Q_BLHK.permute(0,2,1,3)
    K_BHLK = K_BLHK.permute(0,2,1,3)

    slow_o = torch.zeros((B,H,L,L))
    for sample in range(B):
        for head in range(H):
            for q in range(L):
                for k in range(L):
                    slow_o[sample, head, q,k] = torch.dot(Q_BHLK[sample, head, q], K_BHLK[sample, head, k])

    fast_o = torch.einsum("BLHK,BXHK->BHLX", Q_BLHK, K_BLHK)

    assert torch.allclose(slow_o, fast_o)

In [56]:
from cellmates.data.breast import get_datasets

ds = get_datasets('F', 100, concatenated=True)

In [62]:
dl = DataLoader(ds, batch_size=1024, shuffle=True, collate_fn=collate_fn)



In [70]:
b1 = next(iter(dl))
b1['cell_types_BL'].shape

torch.Size([1024, 343])

In [85]:
B,H,L,K = 2,3,5,7

padding_BL = torch.tensor([
    [1,1,0,0,0],
    [1,1,1,1,1]]
)


tensor([[[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1]],

        [[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1]],

        [[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1]]])

# HERE!!!

In [72]:
import pandas as pd

csv1 = pd.read_csv('../1.csv')
csv2 = pd.read_csv('../2.csv')

np.allclose(csv1, csv2, rtol=1e-4)

True

In [44]:
import torch.nn as nn
torch.manual_seed(42)

D = 3
L = 5
B = 2

lin = nn.Linear(D,D, bias=False)

display(list(lin.parameters())[0])

m1 = torch.zeros((B,D))
m1[0,0] = 1
m1[1,0] = 2

print(f'm1: {m1.shape}')
display(m1)

print(f'm1 @ W: {lin(m1).shape}')
display(lin(m1))


m1 = m1.unsqueeze(dim=1).repeat(1,L,1)
m1[0,0,0] = 2

print(f'm1: {m1.shape}')
display(m1)

print(f'm1 @ W: {lin(m1).shape}')
lin(m1)

Parameter containing:
tensor([[ 0.4414,  0.4792, -0.1353],
        [ 0.5304, -0.1265,  0.1165],
        [-0.2811,  0.3391,  0.5090]], requires_grad=True)

m1: torch.Size([2, 3])


tensor([[1., 0., 0.],
        [2., 0., 0.]])

m1 @ W: torch.Size([2, 3])


tensor([[ 0.4414,  0.5304, -0.2811],
        [ 0.8828,  1.0607, -0.5622]], grad_fn=<MmBackward0>)

m1: torch.Size([2, 5, 3])


tensor([[[2., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.]],

        [[2., 0., 0.],
         [2., 0., 0.],
         [2., 0., 0.],
         [2., 0., 0.],
         [2., 0., 0.]]])

m1 @ W: torch.Size([2, 5, 3])


tensor([[[ 0.8828,  1.0607, -0.5622],
         [ 0.4414,  0.5304, -0.2811],
         [ 0.4414,  0.5304, -0.2811],
         [ 0.4414,  0.5304, -0.2811],
         [ 0.4414,  0.5304, -0.2811]],

        [[ 0.8828,  1.0607, -0.5622],
         [ 0.8828,  1.0607, -0.5622],
         [ 0.8828,  1.0607, -0.5622],
         [ 0.8828,  1.0607, -0.5622],
         [ 0.8828,  1.0607, -0.5622]]], grad_fn=<UnsafeViewBackward0>)

In [11]:
o1, o2

(tensor([[0.1479],
         [0.6237]], grad_fn=<AddmmBackward0>),
 tensor([[0.1397],
         [0.5513]], grad_fn=<AddmmBackward0>))

# Test datasets:

In [3]:
generate_dataset_for_n_cells_test()
generate_dataset_for_cell_type()
generate_dataset_for_distances()

<cellmates.data.dataset.CellMatesDataset at 0x13d98a110>

In [4]:
area = (3.14*(140**2))


In [26]:
H,L,K = 2,3,4

E_HLL = torch.randn((H,L,L))

V1_LK = torch.randn((L,K))
V2_LK = torch.randn((L,K))

V_HLK = torch.stack([V1_LK,V2_LK])

# using einsum
V_LHK = V_HLK.permute(1,0,2)
Z_LHK = torch.einsum("HLX,XHK->LHK", E_HLL, V_LHK)

# using straightforward:
Z_HLK_sf = torch.matmul(E_HLL, V_HLK)
Z_HLK_sf = Z_HLK_sf.view(L,H,K)

In [27]:
Z_LHK

tensor([[[-3.4113, -0.0669, -0.8160, -3.3525],
         [ 3.7275,  0.8368,  1.2183,  1.2031]],

        [[ 2.6401, -0.3938, -0.3007,  1.2835],
         [-2.6774, -0.2836, -1.9936, -1.0862]],

        [[ 1.6987, -1.6065, -0.0619,  2.3682],
         [ 1.0618,  2.3300,  0.4523, -0.8643]]])

In [28]:
Z_HLK_sf

tensor([[[-3.4113, -0.0669, -0.8160, -3.3525],
         [ 2.6401, -0.3938, -0.3007,  1.2835]],

        [[ 1.6987, -1.6065, -0.0619,  2.3682],
         [ 3.7275,  0.8368,  1.2183,  1.2031]],

        [[-2.6774, -0.2836, -1.9936, -1.0862],
         [ 1.0618,  2.3300,  0.4523, -0.8643]]])

In [29]:
from cellmates.model.transformer import SpatialMultiHeadAttention, CellMatesTransformer

In [30]:
H,L,K = 2,3,4
D = H*K

smh = SpatialMultiHeadAttention(D,H,K,'cpu')
tf = CellMatesTransformer(D,H,K)

In [31]:
B = 16
x = torch.randn((B,L,D))


distance_idxs_BLL = torch.randint(low=0, high=10, size=(B,L,L))
distance_embeddings = tf.distance_embeddings

smh(x, distance_idxs_BLL, distance_embeddings).shape

torch.Size([16, 3, 8])

In [32]:
cell_types_BL = torch.randint(low=0, high=5, size=(B,L))
distance_idxs_BLL = torch.randint(low=0, high=10, size=(B,L,L))

tf(cell_types_BL, distance_idxs_BLL).shape

torch.Size([16, 1])

# Scratch:

In [11]:
tf = CellMatesTransformer(D,H,K)

In [50]:
V = torch.randn((L,D))
Wq = torch.randn((D,K))
A = torch.randn((L,L,K))

E = torch.zeros((L,L))

for i in range(L):
    for j in range(L):
        E[i,j] = (V[i,:] @ Wq) @ A[i,j]


E_einsum = torch.einsum('LK,LXK -> LX', V@Wq, A)


In [None]:
V = torch.randn((L,D))
Wq = torch.randn((D,K))
A = torch.randn((L,L,K))

E = torch.zeros((L,L))

for i in range(L):
    for j in range(L):
        E[i,j] = (V[i,:] @ Wq) @ A[i,j]


E_einsum = torch.einsum('LK,LXK -> LX', V@Wq, A)


In [52]:
E

tensor([[-3.8700,  3.8456, -0.8910],
        [ 1.3790, -6.7893, -5.1333],
        [ 2.3019,  1.3572, -0.9927]])

In [51]:
E_einsum

tensor([[-3.8700,  3.8456, -0.8910],
        [ 1.3790, -6.7893, -5.1333],
        [ 2.3019,  1.3572, -0.9927]])

In [15]:
Eqr_BHL = torch.randn((H,L))

Eqr_BHL.unsqueeze(-1).expand((H, L, L))

tensor([[[-0.4611, -0.4611, -0.4611],
         [-0.6132, -0.6132, -0.6132],
         [ 0.5478,  0.5478,  0.5478]],

        [[-0.9738, -0.9738, -0.9738],
         [-1.4068, -1.4068, -1.4068],
         [-0.5524, -0.5524, -0.5524]]])

In [27]:
Zqr_BLK = torch.randn((2,3,4))

Zqr_BLK.unsqueeze(1).expand(size=(2,2,3,4))[0]


tensor([[[ 1.9296,  1.5269,  0.1034,  0.4090],
         [ 1.0287, -0.4787,  0.0894, -0.8882],
         [-0.2348, -0.3270, -0.5398, -1.3754]],

        [[ 1.9296,  1.5269,  0.1034,  0.4090],
         [ 1.0287, -0.4787,  0.0894, -0.8882],
         [-0.2348, -0.3270, -0.5398, -1.3754]]])

In [47]:
A[i,j].shape, K

(torch.Size([4]), 4)

In [49]:
V[i,:].shape, D

(torch.Size([8]), 8)

In [44]:
idxs = torch.LongTensor([[1,2,4],[1,2,3]])

tf.distance_embeddings['K_qr'](idxs).shape

torch.Size([2, 3, 4])