In [7]:
# internal imports 
from MIL import build_model 
from utils.generic_utils import print_network 

#external imports 
import os 

import torch

# Load the best-performing model for masses (FPN-AbMIL)

In [31]:
class Args:
    """Configuration for rebuilding the mass classifier (FPN-AbMIL).

    Every setting becomes a plain instance attribute; `build_model` consumes
    the resulting object in the next cell, and the cell after that loads the
    trained weights from a checkpoint.

    NOTE(review): a later cell redefines `Args` (and rebinds `args`) for the
    calcification model, so this class is shadowed from that point on.
    NOTE(review): `clip_chk_pt_path` names a 'b2' checkpoint while `arch`
    names a 'b5' variant — confirm this pairing is intended.
    """

    def __init__(self):
        vars(self).update(
            # --- task ---
            dataset='ViNDr',
            label='Mass',
            n_class=1,
            train=False,
            # --- input geometry ---
            img_size=[1520, 912],
            patch_size=512,
            # --- Mammo-CLIP backbone ---
            clip_chk_pt_path='Mammo-CLIP_Checkpoints/b2-model-best-epoch-10.tar',
            arch='upmc_breast_clip_det_b5_period_n_ft',
            model_type='Classifier',
            feat_dim=352,
            # 'offline' uses pre-trained Mammo-CLIP features; set to 'online'
            # to feed input images instead.
            feature_extraction='offline',
            # --- FPN-MIL architecture ---
            mil_type='pyramidal_mil',
            nested_model=False,
            multi_scale_model='fpn',
            fpn_dim=256,
            upsample_method='nearest',
            norm_fpn=False,
            drop_classhead=0.0,
            map_prob_func='softmax',
            # Instance encoders are MLPs (the "AbMIL" flavour of FPN-MIL).
            type_mil_encoder='mlp',
            fcl_encoder_dim=256,
            fcl_dropout=0.25,
            # Instance pooling and the cross-scale aggregator both use gated attention.
            pooling_type='gated-attention',
            deep_supervision=True,
            type_scale_aggregator='gated-attention',
            fcl_attention_dim=128,
            drop_attention_pool=0.25,
            scales=[16, 32, 128],
        )

# Instantiate the mass-model configuration consumed by `build_model`.
args = Args()

In [32]:
# Build the mass model from the `args` defined in the previous cell; its weights
# are loaded from the checkpoint in the next cell.
# NOTE: must run before the calcification section rebinds `args`.
best_model_masses = build_model(args)

In [33]:
# Prefer the GPU when one is visible; everything below also works on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location='cpu' lets a checkpoint saved on GPU open on any machine; the
# model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load('FPN-MIL_Checkpoints/best_FPN-MIL_mass.pth', map_location='cpu')

# strict=False tolerates key mismatches, which could leave layers at their random
# initialisation without any error; surface the mismatch report instead of dropping it.
load_result = best_model_masses.load_state_dict(checkpoint['model'], strict=False)
if load_result.missing_keys:
    print(f"[mass model] {len(load_result.missing_keys)} key(s) missing from checkpoint: "
          f"{load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"[mass model] {len(load_result.unexpected_keys)} unexpected key(s) in checkpoint: "
          f"{load_result.unexpected_keys}")

# The restored model is used for inference: eval() disables its Dropout layers
# (visible in the printed network), which would otherwise stay active.
best_model_masses = best_model_masses.to(device).eval()
print_network(best_model_masses)

PyramidalMILmodel(
  (inst_encoder): FeaturePyramidNetwork(
    (inner_blocks): ModuleDict(
      (inner_block_0): Sequential(
        (0): Conv2d(120, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Identity()
      )
      (inner_block_1): Sequential(
        (0): Conv2d(352, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Identity()
      )
    )
    (layer_blocks): ModuleDict(
      (layer_block_0): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Identity()
      )
      (layer_block_1): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Identity()
      )
    )
  )
  (side_inst_aggregator): ModuleDict(
    (encoders): ModuleDict(
      (encoder_16): ModuleList(
        (0): Sequential(
          (0): Linear(in_features=256, out_features=256, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.25, inplace=False)
        )
      )
      (encoder_32): Mod

# Load the best-performing model for calcifications (FPN-SetTrans)

In [35]:
class Args:
    """Configuration for rebuilding the suspicious-calcification classifier (FPN-SetTrans).

    Data, backbone and FPN settings match the mass configuration. The
    differences are the set-transformer components: 'isab' instance encoders
    and 'pma' pooling, plus their head / block / dropout / layer-norm settings.

    NOTE(review): this cell redefines the `Args` class from the masses section
    and rebinds `args`; re-running the mass `build_model(args)` cell after this
    one would build that model with the calcification configuration.
    NOTE(review): `clip_chk_pt_path` names a 'b2' checkpoint while `arch`
    names a 'b5' variant — confirm this pairing is intended.
    """

    def __init__(self):
        vars(self).update(
            # --- task ---
            dataset='ViNDr',
            label='Suspicious_Calcification',
            n_class=1,
            train=False,
            # --- input geometry ---
            img_size=[1520, 912],
            patch_size=512,
            # --- Mammo-CLIP backbone ---
            clip_chk_pt_path='Mammo-CLIP_Checkpoints/b2-model-best-epoch-10.tar',
            arch='upmc_breast_clip_det_b5_period_n_ft',
            model_type='Classifier',
            feat_dim=352,
            # 'offline' uses pre-trained Mammo-CLIP features; set to 'online'
            # to feed input images instead.
            feature_extraction='offline',
            # --- FPN-MIL architecture ---
            mil_type='pyramidal_mil',
            nested_model=False,
            multi_scale_model='fpn',
            fpn_dim=256,
            upsample_method='nearest',
            norm_fpn=False,
            drop_classhead=0.0,
            map_prob_func='softmax',
            # Instance encoders: stacks of induced set-attention blocks
            # (printed as "2 x InducedSetAttentionBlock ... heads=4").
            type_mil_encoder='isab',
            fcl_encoder_dim=256,
            isab_num_heads=4,
            num_encoder_blocks=2,
            # Instance pooling via PMA; the cross-scale aggregator stays gated attention.
            pooling_type='pma',
            pma_num_heads=1,
            drop_mha=0.0,
            trans_layer_norm=True,
            deep_supervision=True,
            type_scale_aggregator='gated-attention',
            fcl_attention_dim=128,
            drop_attention_pool=0.25,
            scales=[16, 32, 128],
        )

# Instantiate the calcification-model configuration consumed by `build_model`.
args = Args()

In [36]:
# Build the calcification model from the `args` bound in the previous cell; its
# weights are loaded from the checkpoint in the next cell.
best_model_calcifications = build_model(args)

In [37]:
# Prefer the GPU when one is visible; everything below also works on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location='cpu' lets a checkpoint saved on GPU open on any machine; the
# model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load('FPN-MIL_Checkpoints/best_FPN-MIL_calcifications.pth', map_location='cpu')

# strict=False tolerates key mismatches, which could leave layers at their random
# initialisation without any error; surface the mismatch report instead of dropping it.
load_result = best_model_calcifications.load_state_dict(checkpoint['model'], strict=False)
if load_result.missing_keys:
    print(f"[calcification model] {len(load_result.missing_keys)} key(s) missing from checkpoint: "
          f"{load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"[calcification model] {len(load_result.unexpected_keys)} unexpected key(s) in checkpoint: "
          f"{load_result.unexpected_keys}")

# The restored model is used for inference: eval() switches every submodule to
# inference mode (e.g. disables any dropout), which would otherwise stay in training mode.
best_model_calcifications = best_model_calcifications.to(device).eval()
print_network(best_model_calcifications)

PyramidalMILmodel(
  (inst_encoder): FeaturePyramidNetwork(
    (inner_blocks): ModuleDict(
      (inner_block_0): Sequential(
        (0): Conv2d(120, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Identity()
      )
      (inner_block_1): Sequential(
        (0): Conv2d(352, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Identity()
      )
    )
    (layer_blocks): ModuleDict(
      (layer_block_0): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Identity()
      )
      (layer_block_1): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Identity()
      )
    )
  )
  (side_inst_aggregator): ModuleDict(
    (encoders): ModuleDict(
      (encoder_16): ModuleList(
        (0-1): 2 x InducedSetAttentionBlock(256, d_hidden = 256, num_induced_points=38, heads=4, layer_norm=True, activation=softmax)
      )
      (encoder_32): ModuleList(
        (0-1): 2 x InducedSe