In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader

# Load the 800-epoch MoCo v2 checkpoint
checkpoint_url = "https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar"
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")

# Define a standard ResNet-50 backbone
resnet50 = models.resnet50()
encoder = nn.Sequential(*list(resnet50.children())[:-1])  # Remove final FC layer

# Load MoCo weights into the model (encoder_q is the online encoder)
state_dict = checkpoint['state_dict']
new_state_dict = {}

for k, v in state_dict.items():
    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
        new_k = k.replace('module.encoder_q.', '')
        new_state_dict[new_k] = v

msg = encoder.load_state_dict(new_state_dict, strict=False)
print("Loaded keys:", msg)

# Wrap in a model that outputs flattened features
class MoCoResNetBackbone(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder  # Output: (B, 2048, 1, 1)

    def forward(self, x):
        x = self.encoder(x)
        return torch.flatten(x, 1)  # Output: (B, 2048)

# Instantiate final model
moco_backbone = MoCoResNetBackbone(encoder)
moco_backbone.eval()

# Example input
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    features = moco_backbone(x)
print("Fedature shape:", features.shape)  # Should be [2, 2048]

Loaded keys: _IncompatibleKeys(missing_keys=['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '4.0.conv1.weight', '4.0.bn1.weight', '4.0.bn1.bias', '4.0.bn1.running_mean', '4.0.bn1.running_var', '4.0.conv2.weight', '4.0.bn2.weight', '4.0.bn2.bias', '4.0.bn2.running_mean', '4.0.bn2.running_var', '4.0.conv3.weight', '4.0.bn3.weight', '4.0.bn3.bias', '4.0.bn3.running_mean', '4.0.bn3.running_var', '4.0.downsample.0.weight', '4.0.downsample.1.weight', '4.0.downsample.1.bias', '4.0.downsample.1.running_mean', '4.0.downsample.1.running_var', '4.1.conv1.weight', '4.1.bn1.weight', '4.1.bn1.bias', '4.1.bn1.running_mean', '4.1.bn1.running_var', '4.1.conv2.weight', '4.1.bn2.weight', '4.1.bn2.bias', '4.1.bn2.running_mean', '4.1.bn2.running_var', '4.1.conv3.weight', '4.1.bn3.weight', '4.1.bn3.bias', '4.1.bn3.running_mean', '4.1.bn3.running_var', '4.2.conv1.weight', '4.2.bn1.weight', '4.2.bn1.bias', '4.2.bn1.running_mean', '4.2.bn1.running_var', '4.2.conv2.weight', '4.2.bn2.weight

In [None]:
from CellDataset import CellDataset, moco_transform
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

batch_size = 64

dataset = CellDataset(transform=moco_transform)
moco_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
predictions=np.zeros((len(dataset),2048))

moco_backbone.eval().to(device)

for i, (view1, _) in tqdm(enumerate(moco_dataloader), total=len(moco_dataloader), ncols=100):
    print(view1.to(device).shape)

predictions

cuda:0


  0%|                                                                     | 0/15876 [00:00<?, ?it/s]