In [1]:
# Adapted from the CrossViT-pytorch cross-attention implementation:
# https://github.com/rishikksh20/CrossViT-pytorch/blob/master/crossvit.py

import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange


class CrossAttention(nn.Module):
    """Multi-head attention in which only the first (CLS) token forms the query.

    Keys and values are projected from every token of the input; the query is
    projected from token 0 alone, so the result contains a single token.

    Args:
        input_dim: size of the last axis of the input and of the output.
        intermediate_dim: total width of the query/key/value projections,
            split evenly across ``heads``; must be divisible by ``heads``.
        heads: number of attention heads.
        dropout: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``intermediate_dim`` is not divisible by ``heads``.
    """

    def __init__(self, input_dim=128, intermediate_dim=256, heads=4, dropout=0.):
        super().__init__()
        if intermediate_dim % heads != 0:
            raise ValueError(
                f"intermediate_dim ({intermediate_dim}) must be divisible by heads ({heads})"
            )

        self.heads = heads
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # per-head width of the projected queries/keys -- not the input width.
        head_dim = intermediate_dim // heads
        self.scale = head_dim ** -0.5

        self.key = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.value = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.query = nn.Linear(input_dim, intermediate_dim, bias=False)

        self.out = nn.Sequential(
            nn.Linear(intermediate_dim, input_dim),
            nn.Dropout(dropout)
        )

    def forward(self, data):
        """Attend from token 0 of ``data`` over all tokens of ``data``.

        Args:
            data: 3-D tensor indexed as (batch, tokens, input_dim); token 0
                supplies the query.

        Returns:
            Tensor of shape (batch, 1, input_dim).
        """
        h = self.heads

        k = rearrange(self.key(data), 'b n (h d) -> b h n d', h=h)
        v = rearrange(self.value(data), 'b n (h d) -> b h n d', h=h)

        # Only the CLS token (index 0) forms a query; keep a length-1 token axis.
        q = rearrange(self.query(data[:, 0].unsqueeze(1)), 'b n (h d) -> b h n d', h=h)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attention = dots.softmax(dim=-1)

        output = einsum('b h i j, b h j d -> b h i d', attention, v)
        output = rearrange(output, 'b h n d -> b n (h d)')
        return self.out(output)


class PreNorm(nn.Module):
    """Wrap a callable so that its input is layer-normalised first (pre-norm).

    Args:
        dim: size of the last axis normalised by ``nn.LayerNorm``.
        fn: callable invoked as ``fn(normalised_x, **kwargs)``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normalised = self.norm(x)
        result = self.fn(normalised, **kwargs)
        return result

    
class CrossAttentionLayer(nn.Module):
    """Stack of CrossViT-style cross-attention blocks between two token sequences.

    In each block, the CLS token (index 0) of each sequence is projected into
    the other sequence's feature space, used as the query over that sequence's
    remaining tokens, and projected back. Only the CLS tokens are replaced;
    every other token is carried through unchanged.

    Args:
        molecule_dim: feature size of the molecule sequence.
        molecule_intermediate_dim: attention width when attending in molecule space.
        protein_dim: feature size of the protein sequence.
        protein_intermediate_dim: attention width when attending in protein space.
        cross_attn_depth: number of stacked cross-attention blocks.
        cross_attn_heads: number of heads in every attention module.
        dropout: dropout passed to every attention module.
    """

    def __init__(self, 
                 molecule_dim=128, molecule_intermediate_dim=256,
                 protein_dim=1024, protein_intermediate_dim=2048,
                 cross_attn_depth=1, cross_attn_heads=4, dropout=0.):
        super().__init__()

        def build_block():
            # The order of these six modules matters: forward() unpacks them
            # positionally, and it also fixes state_dict keys and the order in
            # which weights draw from the RNG.
            mol_to_prot_in = nn.Linear(molecule_dim, protein_dim)
            prot_to_mol_out = nn.Linear(protein_dim, molecule_dim)
            attend_in_protein = PreNorm(protein_dim, CrossAttention(
                protein_dim, protein_intermediate_dim, cross_attn_heads, dropout
            ))
            prot_to_mol_in = nn.Linear(protein_dim, molecule_dim)
            mol_to_prot_out = nn.Linear(molecule_dim, protein_dim)
            attend_in_molecule = PreNorm(molecule_dim, CrossAttention(
                molecule_dim, molecule_intermediate_dim, cross_attn_heads, dropout
            ))
            return nn.ModuleList([
                mol_to_prot_in, prot_to_mol_out, attend_in_protein,
                prot_to_mol_in, mol_to_prot_out, attend_in_molecule,
            ])

        self.cross_attn_layers = nn.ModuleList(
            [build_block() for _ in range(cross_attn_depth)]
        )

    def forward(self, molecule, protein):
        """Exchange information between the two sequences through their CLS tokens.

        Args:
            molecule: tensor indexed as (batch, tokens, molecule_dim); token 0 is its CLS token.
            protein: tensor indexed as (batch, tokens, protein_dim); token 0 is its CLS token.

        Returns:
            ``(molecule, protein)`` with the input shapes, where only token 0 of
            each sequence has been replaced.
        """
        for (mol_to_prot_in, prot_to_mol_out, attend_in_protein,
             prot_to_mol_in, mol_to_prot_out, attend_in_molecule) in self.cross_attn_layers:
            # Split both sequences up front so each branch sees this block's inputs.
            mol_cls, mol_tokens = molecule[:, :1], molecule[:, 1:]
            prot_cls, prot_tokens = protein[:, :1], protein[:, 1:]

            # Protein CLS -> molecule space, attend over molecule tokens, back to protein space.
            query = prot_to_mol_in(prot_cls)
            attended = query + attend_in_molecule(torch.cat((query, mol_tokens), dim=1))
            new_prot_cls = mol_to_prot_out(attended)

            # Molecule CLS -> protein space, attend over protein tokens, back to molecule space.
            query = mol_to_prot_in(mol_cls)
            attended = query + attend_in_protein(torch.cat((query, prot_tokens), dim=1))
            new_mol_cls = prot_to_mol_out(attended)

            protein = torch.cat((new_prot_cls, prot_tokens), dim=1)
            molecule = torch.cat((new_mol_cls, mol_tokens), dim=1)

        return molecule, protein

In [5]:
# Seed before construction: the Linear weights are randomly initialised, so the
# outputs displayed further down are only reproducible with a fixed seed.
torch.manual_seed(0)
cross_attention_layer = CrossAttentionLayer(cross_attn_depth=2)
cross_attention_layer

CrossAttentionLayer(
  (cross_attn_layers): ModuleList(
    (0): ModuleList(
      (0): Linear(in_features=128, out_features=1024, bias=True)
      (1): Linear(in_features=1024, out_features=128, bias=True)
      (2): PreNorm(
        (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (fn): CrossAttention(
          (key): Linear(in_features=1024, out_features=2048, bias=False)
          (value): Linear(in_features=1024, out_features=2048, bias=False)
          (query): Linear(in_features=1024, out_features=2048, bias=False)
          (out): Sequential(
            (0): Linear(in_features=2048, out_features=1024, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (3): Linear(in_features=1024, out_features=128, bias=True)
      (4): Linear(in_features=128, out_features=1024, bias=True)
      (5): PreNorm(
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (fn): CrossAttention(
          (key): Line

In [6]:
# Random inputs indexed as (batch, tokens, features); token 0 of each acts as the CLS token.
# A local seeded generator keeps this cell reproducible even when re-run on its own.
input_rng = torch.Generator().manual_seed(0)
molecule_seq = torch.rand((1, 23, 128), generator=input_rng)
protein_seq = torch.rand((1, 512, 1024), generator=input_rng)

molecule_out, protein_out = cross_attention_layer(molecule_seq, protein_seq)

# Cross attention only rewrites each sequence's CLS token: shapes are preserved
# and all remaining tokens pass through unchanged.
assert molecule_out.shape == molecule_seq.shape
assert protein_out.shape == protein_seq.shape
assert torch.equal(molecule_out[:, 1:], molecule_seq[:, 1:])
assert torch.equal(protein_out[:, 1:], protein_seq[:, 1:])

molecule_out, protein_out

(tensor([[[-0.0765, -0.0243,  0.0252,  ..., -0.0181, -0.1369, -0.0219],
          [ 0.0341,  0.0771,  0.2278,  ...,  0.6557,  0.9227,  0.6813],
          [ 0.6740,  0.3114,  0.1296,  ...,  0.5097,  0.7271,  0.6912],
          ...,
          [ 0.9594,  0.0214,  0.1743,  ...,  0.4185,  0.4760,  0.4756],
          [ 0.4065,  0.2358,  0.2313,  ...,  0.9924,  0.4578,  0.0556],
          [ 0.4851,  0.2601,  0.6943,  ...,  0.8129,  0.1236,  0.8042]]],
        grad_fn=<CatBackward0>),
 tensor([[[-0.0437, -0.0517, -0.1115,  ...,  0.0574, -0.0288, -0.0726],
          [ 0.4962,  0.3244,  0.2906,  ...,  0.1367,  0.6601,  0.1835],
          [ 0.5011,  0.4542,  0.6685,  ...,  0.2757,  0.8375,  0.4043],
          ...,
          [ 0.6170,  0.4904,  0.1649,  ...,  0.9705,  0.5077,  0.9278],
          [ 0.2185,  0.5124,  0.2504,  ...,  0.6507,  0.9039,  0.9518],
          [ 0.3343,  0.7399,  0.8789,  ...,  0.6956,  0.2932,  0.1575]]],
        grad_fn=<CatBackward0>))