In [1]:
import sys
sys.path.append('models')
from models.model_coattn import *

In [2]:
###########################
### MCAT Implementation ###
###########################
class MCAT_Surv(nn.Module):
    # Fusion modes that forward() knows how to apply.
    _SUPPORTED_FUSIONS = ('concat', 'bilinear')

    def __init__(self, fusion='concat', omic_sizes=[100, 200, 300, 400, 500, 600], model_size_wsi: str='small', 
        model_size_omic: str='small', n_classes=4, dropout=0.25):
        r"""
        Multimodal Co-Attention Transformer (MCAT) Implementation.

        Args:
            fusion (str): Late fusion method (Choices: concat, bilinear, or None).
                With any value other than concat/bilinear no fusion layer is built
                (``self.mm`` is None) and ``forward`` raises ``ValueError``.
            omic_sizes (List): List of sizes of genomic embeddings (one per signature)
            model_size_wsi (str): Size of WSI encoder (Choices: small or big)
            model_size_omic (str): Size of Genomic encoder (Choices: small or big)
            n_classes (int): Output shape of NN
            dropout (float): Dropout rate of the transformers, attention heads and
                rho layers. The WSI FC layer and the genomic SNN blocks use a fixed
                rate of 0.25 (the first SNN block uses SNN_Block's own default).
        """
        super().__init__()
        self.fusion = fusion
        # Copy so neither the shared default list nor the caller's list is aliased.
        self.omic_sizes = list(omic_sizes)
        self.n_classes = n_classes
        self.size_dict_WSI = {"small": [1024, 256, 256], "big": [1024, 512, 384]}
        self.size_dict_omic = {'small': [256, 256], 'big': [1024, 1024, 1024, 256]}

        ### FC Layer over WSI bag
        size = self.size_dict_WSI[model_size_wsi]
        self.wsi_net = nn.Sequential(nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(0.25))

        ### Constructing Genomic SNN (one independent network per omic signature)
        hidden = self.size_dict_omic[model_size_omic]
        sig_networks = []
        for input_dim in self.omic_sizes:
            fc_omic = [SNN_Block(dim1=input_dim, dim2=hidden[0])]
            fc_omic.extend(
                SNN_Block(dim1=d_in, dim2=d_out, dropout=0.25)
                for d_in, d_out in zip(hidden, hidden[1:])
            )
            sig_networks.append(nn.Sequential(*fc_omic))
        self.sig_networks = nn.ModuleList(sig_networks)

        ### Multihead Attention
        # NOTE(review): the widths below are hard-coded to 256, which matches the
        # 'small' WSI config (size[1] == size[2] == 256); the 'big' config
        # ([1024, 512, 384]) looks incompatible with them -- confirm before use.
        self.coattn = MultiheadAttention(embed_dim=256, num_heads=1)

        ### Path Transformer + Attention Head
        path_encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=512, dropout=dropout, activation='relu')
        self.path_transformer = nn.TransformerEncoder(path_encoder_layer, num_layers=2)
        self.path_attention_head = Attn_Net_Gated(L=size[2], D=size[2], dropout=dropout, n_classes=1)
        self.path_rho = nn.Sequential(nn.Linear(size[2], size[2]), nn.ReLU(), nn.Dropout(dropout))

        ### Omic Transformer + Attention Head
        omic_encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=512, dropout=dropout, activation='relu')
        self.omic_transformer = nn.TransformerEncoder(omic_encoder_layer, num_layers=2)
        self.omic_attention_head = Attn_Net_Gated(L=size[2], D=size[2], dropout=dropout, n_classes=1)
        self.omic_rho = nn.Sequential(nn.Linear(size[2], size[2]), nn.ReLU(), nn.Dropout(dropout))

        ### Fusion Layer
        if self.fusion == 'concat':
            self.mm = nn.Sequential(nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU())
        elif self.fusion == 'bilinear':
            self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
        else:
            self.mm = None

        ### Classifier
        self.classifier = nn.Linear(size[2], n_classes)


    def forward(self, x_path, x_omic):
        r"""
        Run MCAT on one bag of patch embeddings and its genomic signatures,
        printing the tensor shape after every stage.

        Args:
            x_path (Tensor): Patch embeddings fed to the WSI FC layer
                (the notebook demo uses shape (num_patches, 1024)).
            x_omic (List[Tensor]): One tensor per genomic signature; entry ``i``
                is passed through ``self.sig_networks[i]``.

        Returns:
            tuple: ``(hazards, S, Y_hat, attention_scores, None)`` where
                ``hazards`` is the sigmoid of the classifier logits, ``S`` is the
                cumulative product of ``1 - hazards``, ``Y_hat`` is the index of the
                largest logit and ``attention_scores`` is a dict with the keys
                'coattn', 'path' and 'omic'.

        Raises:
            ValueError: If ``self.fusion`` is neither 'concat' nor 'bilinear'.
        """
        # Fail fast: without a fusion layer there is no shared representation to
        # classify, and the old code only crashed (UnboundLocalError) after doing
        # all of the expensive work below.
        if self.fusion not in self._SUPPORTED_FUSIONS:
            raise ValueError(
                f"Unsupported fusion {self.fusion!r}; forward() requires one of {self._SUPPORTED_FUSIONS}."
            )

        ### Bag-Level Representation
        print("*** 1. Bag-Level Representation (FC Processing) ***")
        h_path_bag = self.wsi_net(x_path).unsqueeze(1) ### path embeddings are fed through a FC layer
        ### each omic signature goes through its own FC network (modules are called, not .forward, so hooks run)
        h_omic = [self.sig_networks[idx](sig_feat) for idx, sig_feat in enumerate(x_omic)]
        h_omic_bag = torch.stack(h_omic).unsqueeze(1) ### omic embeddings are stacked (to be used in co-attention)
        print("Instance-Level 256 x 256 Patch Embedings (H_bag before GCA):\n", h_path_bag.shape)
        print("Genomic Embeddings (G_bag before GCA):\n", h_omic_bag.shape)
        print()

        ### Genomic-Guided Co-Attention (genomic embeddings are the queries, patches the keys/values)
        print("*** 2. Genomic-Guided Co-Attention ***")
        h_path_coattn, A_coattn = self.coattn(h_omic_bag, h_path_bag, h_path_bag)
        print("Genomic-Guided WSI-Level Embeddings (H_bag after GCA becomes H_coattn):\n", h_path_coattn.shape)
        print("Genomic Embeddings (G_bag after GCA stays same):\n", h_omic_bag.shape)
        print("Co-Attention Matrix:\n", A_coattn[0,0,:,:].shape)
        # Report the actual sizes instead of the numbers of one particular demo input.
        print(f'- Note that the # of embeddings in H_coattn goes from {h_path_bag.shape[0]} -> {h_path_coattn.shape[0]}')
        print()

        ### Set-Based MIL Transformers
        print("*** 3. Set-based MIL Transformers ***")
        h_path_trans = self.path_transformer(h_path_coattn)
        h_omic_trans = self.omic_transformer(h_omic_bag)
        print("H_coattn after Transformers:\n", h_path_trans.shape)
        print("G_bag after Transformers:\n", h_omic_trans.shape)
        print('- Note that attention is permutation-equivariant, so dimensions are the same')        
        print()

        ### Global Attention Pooling
        print("*** 4. Global Attention Pooling ***")
        A_path, h_path = self.path_attention_head(h_path_trans.squeeze(1))
        A_path = torch.transpose(A_path, 1, 0)
        h_path = torch.mm(F.softmax(A_path, dim=1), h_path) ### attention-weighted sum over the embeddings
        h_path = self.path_rho(h_path).squeeze()
        print("Final WSI-Level Representation (h^L):\n", h_path.shape)

        A_omic, h_omic = self.omic_attention_head(h_omic_trans.squeeze(1))
        A_omic = torch.transpose(A_omic, 1, 0)
        h_omic = torch.mm(F.softmax(A_omic, dim=1), h_omic)
        h_omic = self.omic_rho(h_omic).squeeze()
        print("Final Genomic Representation (g^L):\n", h_omic.shape)
        print()

        ### Late Fusion (self.fusion was validated at the top of forward)
        print("*** 5. Late Fusion ***")
        if self.fusion == 'bilinear':
            h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
        else:  # 'concat'
            h = self.mm(torch.cat([h_path, h_omic], dim=0))
        print("Final shared representation (h_final):\n", h.shape)
        print()

        ### Survival Layer
        logits = self.classifier(h).unsqueeze(0)
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        attention_scores = {'coattn': A_coattn, 'path': A_path, 'omic': A_omic}
        # The fifth slot is a placeholder and is always None.
        return hazards, S, Y_hat, attention_scores, None


In [3]:
# Demo: run MCAT once on random inputs to trace the tensor shapes of each stage.
omic_sizes = [100, 200, 300, 400, 500, 600]  # input size of each genomic signature
model = MCAT_Surv(omic_sizes=omic_sizes)
x_path = torch.randn((15231, 1024)) # 15231 patches with 1024-dim embedding size
x_omic = [torch.randn(dim) for dim in omic_sizes]  # one random vector per signature
# Call the module itself rather than .forward() so nn.Module.__call__ runs any registered hooks.
model(x_path, x_omic)

*** 1. Bag-Level Representation (FC Processing) ***
Instance-Level 256 x 256 Patch Embedings (H_bag before GCA):
 torch.Size([15231, 1, 256])
Genomic Embeddings (G_bag before GCA):
 torch.Size([6, 1, 256])

*** 2. Genomic-Guided Co-Attention ***
Genomic-Guided WSI-Level Embeddings (H_bag after GCA becomes H_coattn):
 torch.Size([6, 1, 256])
Genomic Embeddings (G_bag after GCA stays same):
 torch.Size([6, 1, 256])
Co-Attention Matrix:
 torch.Size([6, 15231])
- Note that the # of embeddings in H_coattn goes from 15231 -> 6

*** 3. Set-based MIL Transformers ***
H_coattn after Transformers:
 torch.Size([6, 1, 256])
G_bag after Transformers:
 torch.Size([6, 1, 256])
- Note that attention is permutation-equivariant, so dimensions are the same

*** 4. Global Attention Pooling ***
Final WSI-Level Representation (h^L):
 torch.Size([256])
Final Genomic Representation (g^L):
 torch.Size([256])

*** 5. Late Fusion ***
Final shared representation (h_final):
 torch.Size([256])



(tensor([[0.5072, 0.5105, 0.5107, 0.4955]], grad_fn=<SigmoidBackward>),
 tensor([[0.4928, 0.2412, 0.1180, 0.0595]], grad_fn=<CumprodBackward>),
 tensor([[2]]),
 {'coattn': tensor([[[[-5.2956e-02, -2.7924e-02,  6.8047e-03,  ..., -1.4772e-01,
             -2.3292e-01,  4.3942e-02],
            [ 8.3089e-02, -4.5703e-02,  2.0750e-01,  ...,  9.3415e-02,
             -1.7095e-01, -2.6080e-01],
            [ 2.6508e-01,  3.9206e-01,  3.2596e-01,  ...,  2.1387e-01,
              1.9855e-01,  2.3128e-01],
            [-1.3625e-01, -2.4742e-01, -2.6088e-01,  ..., -2.7371e-01,
             -1.0058e-01, -2.1232e-01],
            [ 2.7769e-01,  1.5632e-01,  3.0384e-01,  ...,  7.5612e-02,
              3.0659e-02,  2.5983e-04],
            [ 2.4861e-01,  6.0646e-02,  8.7220e-02,  ..., -6.3945e-02,
              2.3488e-02,  1.9480e-01]]]], grad_fn=<ViewBackward>),
  'path': tensor([[ 0.1641, -0.1381, -0.0578, -0.1118,  0.1523, -0.2283]],
         grad_fn=<TransposeBackward0>),
  'omic': tensor([[ 0