In [1]:
import sys
sys.path.append('../Pretraining/')

import torch
import torch.nn.functional as F

from vision_transformer256 import vit_small
from models.model_hierarchical_mil import HIPT_LGP_FC

### Example Input
Input: $[M \times L \times D]$ Tensor, where:
- M: Number of (non-overlapping) $[4096 \times 4096]$ Image regions in a WSI (On Average: 38)
- L: Number of (non-overlapping) $[256 \times 256]$ Image Patches in a $[4096 \times 4096]$ Image Region (Default: 256)
- D: Embedding Dimension (Default: 384)

### 1. Example Forward Pass (with Pre-Extracted $x_{256}$ Features)

In [2]:
# Dimensions of the synthetic input, following the [M x L x D] layout
# described in the "Example Input" section above.
N_REGIONS = 38    # M: number of [4096 x 4096] regions
N_PATCHES = 256   # L: number of [256 x 256] patches per region
EMBED_DIM = 384   # D: embedding dimension of each patch feature

# Random stand-in for pre-extracted x_256 features (created before the model
# so the random-number stream is consumed in the same order as before).
x = torch.randn(N_REGIONS, N_PATCHES, EMBED_DIM)

# NOTE: the model instance is bound to the name `self` because the shape
# walkthrough cell further down accesses its submodules as `self.<module>`.
self = HIPT_LGP_FC()

# Last expression of the cell: the tuple returned by `forward` is displayed.
self.forward(x)

# of Patches: 196




(tensor([[ 0.0744,  0.3583, -0.0675, -0.0432]], grad_fn=<AddmmBackward>),
 tensor([[0.2448, 0.3252, 0.2124, 0.2176]], grad_fn=<SoftmaxBackward>),
 tensor([[1]]),
 None,
 None)

### 2. Forward Pass Shape Walkthrough (with Pre-Extracted $x_{256}$ Features)

In [3]:
# Step-by-step walkthrough that prints the tensor shape after each stage.
# Relies on the model instance `self` created in the previous cell.

# Stage 1: random stand-in for pre-extracted features, [38, 256, 384].
x_256 = torch.randn(38, 256, 384)
print("1. Input Tensor:", x_256.shape, end="\n\n")

# Stage 2: split the length-256 token axis into 16 windows of 16 tokens
# (unfold -> [38, 16, 384, 16]), then swap axes 1 and 2 so the embedding
# axis comes second, giving a [38, 384, 16, 16] grid per region.
grid_side = 16
x_256 = x_256.unfold(1, grid_side, grid_side).transpose(1, 2)
print("2. Re-Arranging 1D-(Seq Length of # [256x256] tokens in [4096x4096] Region) Axis to be a 2D-Grid:", x_256.shape, end="\n\n")

# Stage 3: the local ViT maps each region grid to a single vector ([38, 192]).
h_4096 = self.local_vit(x_256)
print("3. Seq length of [4096x4096] Tokens in the WSI", h_4096.shape, end="\n\n")

# Stage 4: global aggregation over the region vectors.
h_4096 = self.global_phi(h_4096)
# The transformer is fed a singleton axis at position 1, which is removed again
# afterwards (presumably a batch axis of size 1 -- confirm against the model).
h_4096 = self.global_transformer(h_4096.unsqueeze(1)).squeeze(1)
# The pooling module returns a pair; the first item is treated as per-region
# attention scores, the second as the features to be pooled.
A_4096, h_4096 = self.global_attn_pool(h_4096)
# Transpose the scores and normalise them along dim 1.
A_4096 = F.softmax(A_4096.transpose(1, 0), dim=1)
# Attention-weighted sum of the region features.
h_path = A_4096.mm(h_4096)
h_WSI = self.global_rho(h_path)
print("4. ViT-4K + Global Attention Pooling to get WSI-Level Embedding:", h_WSI.shape)

1. Input Tensor: torch.Size([38, 256, 384])

2. Re-Arranging 1D-(Seq Length of # [256x256] tokens in [4096x4096] Region) Axis to be a 2D-Grid: torch.Size([38, 384, 16, 16])

3. Seq length of [4096x4096] Tokens in the WSI torch.Size([38, 192])

4. ViT-4K + Global Attention Pooling to get WSI-Level Embedding: torch.Size([1, 192])
