In [1]:
import torch
import torch.nn as nn

In [2]:
#CTransPath
#https://github.com/Xiyue-Wang/TransPath.git
import timm
from timm.models.layers.helpers import to_2tuple

class ConvStem(nn.Module):

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()

        assert patch_size == 4
        assert embed_dim % 8 == 0

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten


        stem = []
        input_dim, output_dim = 3, embed_dim // 8
        for l in range(2):
            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
            stem.append(nn.BatchNorm2d(output_dim))
            stem.append(nn.ReLU(inplace=True))
            input_dim = output_dim
            output_dim *= 2
        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*stem)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

def ctranspath():
    model = timm.create_model("swin_tiny_patch4_window7_224", embed_layer=ConvStem, pretrained=False)
    return model

td = torch.load("model_weights/CtransPath.pth")
model_ctranspath = ctranspath()
model_ctranspath.head = nn.Identity()
model_ctranspath.load_state_dict(td['model'], strict=True)
print(model_ctranspath(torch.zeros(1,3,224,224)).shape)

torch.Size([1, 768])


In [3]:
#Prov-GigaPath
#https://github.com/prov-gigapath/prov-gigapath.git
import timm

model_gigapath = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=False)
model_gigapath.load_state_dict(torch.load("model_weights/Prov_Gigapath.pt"))
print(model_gigapath(torch.zeros(1,3,224,224)).shape)

torch.Size([1, 1536])


In [4]:
#PathDino
#https://github.com/KimiaLabMayo/PathDino.git
from PathDino import get_pathDino_model

model_pathdino = get_pathDino_model(weights_path="model_weights/PathDino512.pth")
print(model_pathdino(torch.zeros(1,3,512,512)).shape)

torch.Size([1, 384])


In [5]:
#UNI
#https://github.com/mahmoodlab/UNI.git
import timm

model_uni = timm.create_model(
    "vit_large_patch16_224",
    img_size=224,
    patch_size=16,
    init_values=1e-5,
    num_classes=0,
    dynamic_img_size=True,
)
model_uni.load_state_dict(torch.load("model_weights/uni.bin", map_location="cpu"), strict=True)
print(model_uni(torch.zeros(1,3,224,224)).shape)

torch.Size([1, 1024])


In [6]:
#Virchow
#https://huggingface.co/paige-ai/Virchow
import timm
from timm.layers import SwiGLUPacked

model_virchow = timm.create_model("hf-hub:paige-ai/Virchow", pretrained=False, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU)
model_virchow.load_state_dict(torch.load("model_weights/virchow.pth"))
print(model_virchow(torch.zeros(1,3,224,224)).shape)
print(model_virchow(torch.zeros(1,3,224,224))[:,0].shape)

torch.Size([1, 257, 1280])
torch.Size([1, 1280])


In [7]:
#BEPH
#https://github.com/Zhcyoung/BEPH.git
from mmselfsup.apis import init_model
from mmengine.config import Config

cfg = Config.fromfile("beitv2_vit.py")
model_beph = init_model(cfg, "model_weights/BEPH_backbone.pth", device='cpu').backbone
print(model_beph(torch.zeros(1,3,224,224))[0].shape)

Loads checkpoint by local backend from path: ./BEPH_backbone.pth
The model and loaded state dict do not match exactly

unexpected key in source state_dict: backbone.mask_token, backbone.rel_pos_bias.relative_position_bias_table, backbone.rel_pos_bias.relative_position_index

missing keys in source state_dict: backbone.layers.0.attn.relative_position_bias_table, backbone.layers.0.attn.relative_position_index, backbone.layers.1.attn.relative_position_bias_table, backbone.layers.1.attn.relative_position_index, backbone.layers.2.attn.relative_position_bias_table, backbone.layers.2.attn.relative_position_index, backbone.layers.3.attn.relative_position_bias_table, backbone.layers.3.attn.relative_position_index, backbone.layers.4.attn.relative_position_bias_table, backbone.layers.4.attn.relative_position_index, backbone.layers.5.attn.relative_position_bias_table, backbone.layers.5.attn.relative_position_index, backbone.layers.6.attn.relative_position_bias_table, backbone.layers.6.attn.relativ

In [8]:
#Hibou
#https://github.com/HistAI/hibou.git

from hibou import build_model

model_hibou = build_model(weights_path="model_weights/hibou-b.pth")
print(model_hibou(torch.zeros(1,3,224,224)).shape)

<All keys matched successfully>
torch.Size([1, 768])


In [9]:
#HIPT
#https://github.com/mahmoodlab/HIPT.git
import os
import vision_transformer as vits

def get_vit256(pretrained_weights, arch='vit_small', device=torch.device('cuda:0')):
    r"""
    Builds ViT-256 Model.
    
    Args:
    - pretrained_weights (str): Path to ViT-256 Model Checkpoint.
    - arch (str): Which model architecture.
    - device (torch): Torch device to save model.
    
    Returns:
    - model256 (torch.nn): Initialized model.
    """
    
    checkpoint_key = 'teacher'
    device = torch.device("cpu")
    model256 = vits.__dict__[arch](patch_size=16, num_classes=0)
    for p in model256.parameters():
        p.requires_grad = False
    model256.eval()
    model256.to(device)

    if os.path.isfile(pretrained_weights):
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model256.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))s
        
    return model256

model_hipt = get_vit256(pretrained_weights="model_weights/vit256_small_dino.pth")
print(model_hipt(torch.zeros(1,3,256,256)).shape)

torch.Size([1, 384])


In [10]:
#CONCH
#https://github.com/mahmoodlab/CONCH.git
from conch.open_clip_custom import create_model_from_pretrained

model_conch, _ = create_model_from_pretrained('conch_ViT-B-16', "hf_hub:MahmoodLab/conch", hf_auth_token="_")
print(model_conch.encode_image(torch.zeros(1,3,448,448), proj_contrast=False, normalize=False).shape)

torch.Size([1, 512])


In [11]:
#Pathoduet
#https://github.com/openmedlab/PathoDuet.git
from vits import VisionTransformerMoCo

model_pathoduet = VisionTransformerMoCo(pretext_token=True, global_pool='avg')
model_pathoduet.load_state_dict(torch.load('model_weights/checkpoint_HE.pth', map_location="cpu"), strict=False)
model_pathoduet.head = nn.Identity()
print(model_pathoduet(torch.zeros(1,3,224,224)).shape)

torch.Size([1, 768])


In [12]:
#Ciga. et al.
#https://github.com/ozanciga/self-supervised-histopathology.git
import torchvision

def load_model_weights(model, weights):
    model_dict = model.state_dict()
    weights = {k: v for k, v in weights.items() if k in model_dict}
    if weights == {}:
        print('No weight could be loaded..')
    model_dict.update(weights)
    model.load_state_dict(model_dict)

    return model


model_ciga = torchvision.models.__dict__['resnet18'](pretrained=False)
state_dict = torch.load('model_weights/tenpercent_resnet18.ckpt', map_location='cpu')['state_dict']
for key in list(state_dict.keys()):
    state_dict[key.replace('model.', '').replace('resnet.', '')] = state_dict.pop(key)

model_ciga = load_model_weights(model_ciga, state_dict)
model_ciga.fc = nn.Identity()
print(model_ciga(torch.zeros(1,3,224,224)).shape)

torch.Size([1, 512])


In [13]:
#Phikon
#https://github.com/owkin/HistoSSLscaling.git
from rl_benchmarks.models import iBOTViT

model_phikon = iBOTViT(architecture="vit_base_pancan", encoder="teacher", weights_path= "model_weights/ibot_vit_base_pancan.pth")

print(model_phikon(torch.zeros(1,3,224,224)).shape)

[32m2025-02-03 14:32:07.519[0m | [1mINFO    [0m | [36mrl_benchmarks.models.feature_extractors.ibot_vit[0m:[36m__init__[0m:[36m78[0m - [1mPretrained weights found at model_weights/ibot_vit_base_pancan.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v', 'head.last_layer2.weight_g', 'head.last_layer2.weight_v'])[0m


torch.Size([1, 768])
