In [2]:
import torch
from monai.networks.nets import resnet50
from monai.config import print_config

model = resnet50(
    pretrained=False,  # Load MedicalNet pretrained weights (optional)
    spatial_dims=3,   # For 3D medical images
    n_input_channels=1,  # Single channel input for grayscale medical images
    num_classes=8     # Number of output classes (classification task)
)

model.load_state_dict(torch.load('best_loss_model.pth'))

<All keys matched successfully>

In [13]:
print(list(model.children()))

[Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False), BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False), Sequential(
  (0): ResNetBottleneck(
    (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (1): BatchNorm3d(256, eps=1

In [3]:
from copy import deepcopy
headless_m = deepcopy(model)
headless_m.fc = None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
headless_m.to(device)

cuda


ResNet(
  (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act): ReLU(inplace=True)
  (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResNetBottleneck(
      (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv3d(64, 25

In [16]:
# Modify the forward method to output the layer before the classification head
class ResNet50WithoutHead(torch.nn.Module):
    def __init__(self, original_model):
        super(ResNet50WithoutHead, self).__init__()
        # Copy the layers up to the adaptive avg pool (before the projection head)
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        #self.pooling = original_model.global_pool  # This is typically the global average pooling layer
    
    def forward(self, x):
        x = self.features(x)
        #x = self.pooling(x)
        return x

In [17]:
modified_model = ResNet50WithoutHead(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
modified_model.to(device)

cuda


ResNet50WithoutHead(
  (features): Sequential(
    (0): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): ResNetBottleneck(
        (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (dow

In [4]:
import tifffile as tiff
import numpy as np

# Load the TIFF image
superpixels = torch.from_numpy(tiff.imread('../copick_10439/segm_16193.tif'))

In [5]:
from monai.transforms import (
    Compose, 
    EnsureChannelFirst, 
    ScaleIntensityRange, 
    EnsureType, 
    Resize
)

# Define the transformation pipeline for NumPy arrays
transforms = Compose([
    EnsureChannelFirst(channel_dim="no_channel"),  # Handle no channel case (grayscale image)
    ScaleIntensityRange(a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),  # Intensity scaling
    EnsureType(),  # Ensure it's a torch.Tensor
    Resize(spatial_size=(64, 64, 64))  # Resize to the specified 3D spatial size
])


In [10]:
from scipy import ndimage
from tqdm import tqdm

resnet_emdeddings = []
labels = []
for label,obj in tqdm(enumerate(ndimage.find_objects(superpixels), start=1)):
    if obj is None:
        continue
    
    inputs = transforms(superpixels[obj]).unsqueeze(0).to(device)
    outputs = modified_model(inputs)
    print(outputs.shape)
    break
    resnet_emdeddings.append(outputs.detach().cpu())
    labels.append(label)

0it [00:00, ?it/s]

torch.Size([1, 2048, 1, 1, 1])



