In [1]:
!git clone https://github.com/facebookresearch/msn.git
# `!cd msn` was removed: every `!` command runs in its own subshell, so a `cd`
# there is discarded immediately and never changed the notebook's working dir.
# The cells below write files via absolute /content/msn/... paths; use
# `%cd /content/msn` if a persistent working-directory change is intended.
# Pinned to the version this notebook was run with (medmnist 3.0.2); `%pip`
# installs into the running kernel's environment.
%pip install medmnist==3.0.2

Cloning into 'msn'...
remote: Enumerating objects: 69, done.[K
remote: Counting objects: 100% (11/11), done.[K
remote: Compressing objects: 100% (10/10), done.[K
remote: Total 69 (delta 4), reused 1 (delta 1), pack-reused 58 (from 1)[K
Receiving objects: 100% (69/69), 219.49 KiB | 9.14 MiB/s, done.
Resolving deltas: 100% (26/26), done.
Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Downloading fire-0.7.1-py3-none-any.whl (115 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.9/115.9 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fire, medmnist
Successfully installed fire-0.7.1 medmnist-3.0.2


In [None]:
%%writefile "/content/msn/configs/pretrain/msn_vits16.yaml"
criterion:
  ent_weight: 0.0               # weight for entropy term in the loss function
  final_sharpen: 0.25           # sharpening factor applied to final output probabilities
  me_max: true                  # mean entropy maximisation regulariser
  memax_weight: 1.0             # weight of me-max
  num_proto: 1024               # no. of prototype vectors (clusters)
  start_sharpen: 0.25           # initial sharpening (at the start of training)
  temperature: 0.1              # temp. param. for softmax
  batch_size: 64                # batch size used to compute loss
  use_ent: true                 # includes entropy term in loss
  use_sinkhorn: true            # uses Sinkhorn-Knopp algorithm for balanced cluster assignments
data:
  color_jitter_strength: 0.5    # intensity of color jitter augmentation
  pin_mem: true                 # uses pin memory for faster GPU training
  num_workers: 6                # no. of CPU processes to use for data loading
  image_folder: pathmnist       # dataset key, not a path: 'pathmnist' selects the MedMNIST PathMNIST loader
  label_smoothing: 0.0          # amount of label smoothing for classification
  patch_drop: 0.15              # prob. of dropping image patches
  rand_size: 224                # size of random/global crops
  focal_size: 96                # size of focal/local crops
  rand_views: 1                 # no. of random/global views per image
  focal_views: 10               # no. of focal/local views per image
  root_path: null               # base path (ignored by the PathMNIST loader, which downloads data via medmnist)
logging:
  folder: logs/                    # directory where logs/results will be stored
  write_tag: msn-pathmnist-train   # tag for this trial
meta:
  bottleneck: 1                   # no. of bottleneck layers in projection head
  copy_data: false                # copies dataset locally before training
  drop_path_rate: 0.0             # prob. of dropping entire residual paths
  hidden_dim: 2048                # hidden layer dim. in projection head
  load_checkpoint: false          # load model weights from a saved checkpoint (see read_checkpoint)
  model_name: deit_small          # model architecture used (Data-efficient Image Transformer)
  output_dim: 256                 # output embedding dim. (from projection head)
  read_checkpoint: null           # path to the checkpoint file to load
  use_bn: true                    # uses batch normalisation
  use_fp16: true                  # mixed precision for faster training and lower memory usage
  use_pred_head: true             # uses prediction head (false if only using encoder features)
optimization:
  clip_grad: 3.0                  # max. gradient norm for clipping
  epochs: 100                     # no. of training epochs
  final_lr: 1.0e-06               # learning rate (lr) after cosine decay schedule
  final_weight_decay: 0.4         # weight decay (wd) after schedule
  lr: 0.001                       # base lr
  start_lr: 0.0002                # starting lr (for warmup phase)
  warmup: 15                      # no. of warmup epochs
  weight_decay: 0.04              # initial wd for regularisation

Overwriting /content/msn/configs/pretrain/msn_vits16.yaml


In [None]:
%%writefile "/content/msn/configs/eval/lineval_msn_vits16.yaml"
meta:
  model_name: deit_small
  master_port: 8888
  load_checkpoint: false
  training: true
  copy_data: false
  device: cuda:0
data:
  root_path: null
  image_folder: pathmnist
  num_classes: 9
  train_subset_frac: 1             # fraction of labelled training samples used (e.g. 0.1 = 10 %); changed per label-budget run
  train_subset_seed: 42
optimization:
  weight_decay: 0.0
  lr: 3.0
  epochs: 30
  num_blocks: 1
  normalize: true
logging:
  folder: logs_sub/
  write_tag: msn-100-sub-linear    # write_tag was changed based on label %
  pretrain_path: msn-pathmnist-train-latest.pth.tar

In [None]:
%%writefile "/content/msn/src/data_manager.py"
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the Creative Commons Attribution–NonCommercial 4.0 International License
#
# Modifications: Adapted for use in 'Self-supervised learning applied to unlabelled histopathology imaging data –
# a comparison between Masked Siamese Networks and Self-Distillation with No Labels', 2025.
#

import numpy as np
import os
import subprocess
import time

# ADAPTED FOR USE
from medmnist import PathMNIST
from torchvision import transforms
from torch.utils.data import Dataset, Subset
from torchvision.transforms.functional import to_pil_image

from logging import getLogger

from PIL import ImageFilter
from PIL import Image

import torch
import torchvision

# NOTE(review): _GLOBAL_SEED is not used by any code visible in this file and
# does not seed any RNG by itself; label-subset sampling in init_data is seeded
# by its train_subset_seed argument instead.
_GLOBAL_SEED = 0
# getLogger() with no name returns the root logger, so this module's messages
# follow whatever level/handlers the calling script configures.
logger = getLogger()


# ADAPTED FOR USE
def _labels_from_wrapper(ds):
    # works with PathMNISTWrapper before subsetting
    if hasattr(ds, "labels"):
        return np.asarray(ds.labels).squeeze()
    if hasattr(ds, "dataset") and hasattr(ds.dataset, "labels"):
        return np.asarray(ds.dataset.labels).squeeze()
    raise RuntimeError("Could not find labels on dataset/wrapper for subsetting.")

# ADAPTED FOR USE
def _balanced_indices(labels, per_class=100, seed=42):
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels).squeeze()
    idxs = []
    for c in np.unique(labels):
        cls = np.where(labels == c)[0]
        take = min(per_class, len(cls))
        idxs.extend(rng.choice(cls, size=take, replace=False))
    rng.shuffle(idxs)
    return list(map(int, idxs))

# ADAPTED FOR USE
def init_data(
    transform,
    batch_size,
    pin_mem=True,
    num_workers=6,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    copy_data=False,
    drop_last=True,
    split=None,
    train_subset_frac=None,
    train_subset_size=None,
    train_subset_per_class=None,
    train_subset_seed=42,
):
    """Build a distributed DataLoader (and its sampler) over PathMNIST.

    Args:
        transform: callable applied to every raw sample image returned by
            medmnist's PathMNIST (e.g. a MultiViewTransform for pre-training).
        batch_size: number of samples per batch (per process).
        pin_mem: forwarded to ``DataLoader(pin_memory=...)``.
        num_workers: number of DataLoader worker processes.
        world_size: total number of distributed processes.
        rank: index of the current process (0 = main process).
        root_path: accepted for compatibility with the upstream MSN loader;
            ignored (PathMNIST is fetched by medmnist).
        image_folder: dataset key; only ``"pathmnist"`` is supported.
        training: chooses the default split ('train' if True else 'test') and
            is required for the label-budget subset options to take effect.
        copy_data: accepted for compatibility; ignored.
        drop_last: drop the final incomplete batch.
        split: explicit split name ('train' | 'val' | 'test'); overrides the
            split implied by ``training``.
        train_subset_frac: keep this fraction of the split (0.1 = 10 % labels);
            values above 1 are clamped to the whole split.
        train_subset_size: keep this many randomly chosen samples.
        train_subset_per_class: keep up to this many samples per class.
        train_subset_seed: seed for subset sampling.
        When several subset options are set, precedence is
        per-class > size > frac.

    Returns:
        tuple: ``(data_loader, dist_sampler)``. The sampler uses
        DistributedSampler's defaults, so it shuffles.

    Raises:
        ValueError: if ``image_folder`` is not ``"pathmnist"``.
    """
    if image_folder != "pathmnist":
        # Previously fell through with `dataset` unbound (UnboundLocalError).
        raise ValueError(
            f"Unsupported image_folder {image_folder!r}; only 'pathmnist' is implemented.")

    split_name = split if split is not None else ('train' if training else 'test')
    dataset = PathMNISTWrapper(split=split_name, transform=transform)

    # Few-label linear eval: restrict the labelled training set to a budget.
    budget_options = (train_subset_frac, train_subset_size, train_subset_per_class)
    if training and any(opt is not None for opt in budget_options):
        labels = _labels_from_wrapper(dataset)  # must run before Subset wrapping
        N = len(labels)
        rng = np.random.default_rng(int(train_subset_seed))
        if train_subset_per_class is not None:
            mode = 'per-class'
            idx = _balanced_indices(labels, per_class=int(train_subset_per_class),
                                    seed=int(train_subset_seed))
        else:
            if train_subset_size is not None:
                mode = 'size'
                k = int(train_subset_size)
            else:
                mode = 'frac'
                k = int(round(float(train_subset_frac) * N))
            # Clamp to [1, N]: sampling without replacement cannot exceed N.
            k = max(1, min(k, N))
            # Same draw as before (choice over arange(N)) so seeded subsets are unchanged.
            idx = [int(i) for i in rng.choice(np.arange(N), size=k, replace=False)]
        dataset = Subset(dataset, idx)
        logger.info(f"pathmnist TRAIN subset enabled: using {len(idx)} labeled samples "
                    f"(mode={mode})")

    # Each process gets a disjoint shard of the (possibly subset) dataset.
    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset,
        num_replicas=world_size,
        rank=rank)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
    )
    logger.info('pathmnist unsupervised data loader created')

    return (data_loader, dist_sampler)


def make_transforms(
    rand_size=224,
    focal_size=96,
    rand_crop_scale=(0.3, 1.0),
    focal_crop_scale=(0.05, 0.3),
    color_jitter=1.0,
    rand_views=2,
    focal_views=10,
):
    """Build the MSN multi-view augmentation for PathMNIST.

    Both the random (global) and focal (local) views use the same recipe and
    differ only in output size and crop-area range:
    RandomResizedCrop -> horizontal flip -> colour jitter (p=0.8) and
    grayscale (p=0.2) -> Gaussian blur (p=0.5) -> ToTensor -> Normalize.

    Args:
        rand_size: output side length of the random/global views.
        focal_size: output side length of the focal/local views.
        rand_crop_scale: area range (fraction of the image) of global crops.
        focal_crop_scale: area range of local crops.
        color_jitter: strength multiplier for the colour jitter.
        rand_views: number of global views per sample.
        focal_views: number of local views per sample.

    Returns:
        MultiViewTransform producing ``rand_views`` global views followed by
        ``focal_views`` local views.
    """
    logger.info('making pathmnist data transforms')

    def _color_distortion(strength):
        # brightness/contrast/saturation at 0.8*s, hue at 0.2*s
        jitter = transforms.ColorJitter(0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength)
        return transforms.Compose([
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ])

    def _view_pipeline(size, crop_scale):
        return transforms.Compose([
            transforms.RandomResizedCrop(size, scale=crop_scale),
            transforms.RandomHorizontalFlip(),
            _color_distortion(color_jitter),
            GaussianBlur(p=0.5),
            transforms.ToTensor(),
            # ADAPTED FOR USE: per-channel mean/std of 0.5 maps [0, 1] to [-1, 1]
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

    return MultiViewTransform(
        rand_transform=_view_pipeline(rand_size, rand_crop_scale),
        focal_transform=_view_pipeline(focal_size, focal_crop_scale),
        rand_views=rand_views,
        focal_views=focal_views,
    )


# creates multiple augmented views (random and focal) of a single input image
class MultiViewTransform(object):
    """Turn one image into a list of augmented views.

    Calling the object returns ``rand_views`` outputs of ``rand_transform``
    followed by ``focal_views`` outputs of ``focal_transform``. A transform
    is only invoked when its view count is positive, so it may be ``None``
    when that count is 0.
    """

    def __init__(
        self,
        rand_transform=None,
        focal_transform=None,
        rand_views=1,
        focal_views=1,
    ):
        self.rand_views = rand_views
        self.focal_views = focal_views
        self.rand_transform = rand_transform
        self.focal_transform = focal_transform

    def __call__(self, img):
        views = []
        # Global (random) views first, then local (focal) views.
        for transform, count in ((self.rand_transform, self.rand_views),
                                 (self.focal_transform, self.focal_views)):
            if count > 0:
                views.extend(transform(img) for _ in range(count))
        return views


class GaussianBlur(object):
    """Randomly blur an image with a uniformly sampled Gaussian radius.

    Args:
        p: probability of applying the blur (0 never blurs, 1 always blurs).
        radius_min: lower bound of the blur radius.
        radius_max: upper bound of the blur radius.

    The input must provide a PIL-style ``filter`` method (e.g. a PIL image);
    when the blur is skipped the input is returned unchanged.
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        # Stored as float: torch.bernoulli needs a floating-point probability
        # tensor, and an int p (e.g. p=1) would otherwise build an integer one.
        self.prob = float(p)
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        # Bernoulli draw: 1 with probability p (blur), 0 otherwise (skip).
        if torch.bernoulli(torch.tensor(self.prob)) == 0:
            return img

        # Radius uniformly sampled in [radius_min, radius_max).
        radius = self.radius_min + torch.rand(1) * (self.radius_max - self.radius_min)
        return img.filter(ImageFilter.GaussianBlur(radius=radius.item()))

# ADAPTED FOR USE
class PathMNISTWrapper(Dataset):
    """Torch Dataset over medmnist's PathMNIST returning ``(image, int_label)``.

    Args:
        split: PathMNIST split name ('train' | 'val' | 'test').
        transform: optional callable applied to each image; when ``None`` the
            raw image from PathMNIST is returned.

    The medmnist object is kept on ``self.dataset`` (``download=True`` is
    passed to PathMNIST).
    """

    def __init__(self, split='train', transform=None):
        self.dataset = PathMNIST(split=split, download=True)
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        # The default transform=None previously crashed here.
        if self.transform is not None:
            img = self.transform(img)
        # Labels may be Python ints, NumPy scalars, 0-d arrays or shape-(1,)
        # arrays; flatten and take the first element to get a scalar int.
        y = int(np.asarray(label).reshape(-1)[0])
        return img, y

Overwriting /content/msn/src/data_manager.py


In [None]:
%%writefile "/content/msn/src/deit.py"
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License")
#

"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
from functools import partial
import numpy as np

import torch
import torch.nn as nn

from src.utils import trunc_normal_


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: zero whole samples of ``x`` with probability ``drop_prob``.

    One Bernoulli(1 - drop_prob) draw is made per sample (first dimension) and
    broadcast over the remaining dimensions; kept samples are rescaled by
    ``1 / (1 - drop_prob)`` so the expectation is unchanged. When not training
    or when ``drop_prob`` is 0, ``x`` itself is returned.
    """
    if not training or drop_prob == 0.:
        return x

    survive = 1 - drop_prob
    # (B, 1, 1, ...) so the mask broadcasts over every non-batch dimension.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.floor(survive + torch.rand(mask_shape, dtype=x.dtype, device=x.device))
    return x.div(survive) * mask


class DropPath(nn.Module):
    """Module form of ``drop_path`` (per-sample stochastic depth).

    Intended for the main path of residual blocks; it is a no-op in eval mode
    because ``self.training`` is forwarded to ``drop_path``.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when
    falsy. The same Dropout module is applied after both layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # Registration order (fc1, act, fc2, drop) is kept for identical
        # parameter ordering / state_dict keys.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention that also returns its attention weights.

    Args:
        dim: token embedding size, split evenly across ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: add a bias to the fused q/k/v projection.
        qk_scale: explicit logit scale; any falsy value falls back to
            ``head_dim ** -0.5``.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale if qk_scale else per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Map x of shape (B, N, C) to (output (B, N, C), weights (B, heads, N, N)).

        The returned weights are taken after attention dropout.
        """
        batch, tokens, channels = x.shape
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        # -> (3, B, heads, N, head_dim), then split into q, k, v
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        weights = self.attn_drop(weights)

        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed)), weights


class Block(nn.Module):
    """Pre-norm transformer block.

    forward computes ``x + DropPath(Attn(norm1(x)))`` and then
    ``x + DropPath(MLP(norm2(x)))``. With ``return_attention=True`` it returns
    only the attention weights of ``norm1(x)`` and skips the residual updates.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Submodules are created in the original order (norm1, attn, drop_path,
        # norm2, mlp) so parameter order and init RNG consumption are unchanged.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    The projection is a Conv2d whose kernel size and stride both equal
    ``patch_size``; ``num_patches`` assumes a square ``img_size`` input.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        _batch, _chans, _height, _width = x.shape  # requires a 4-D (B, C, H, W) input
        grid = self.proj(x)                        # (B, embed_dim, H // p, W // p)
        return grid.flatten(2).transpose(1, 2)     # (B, num_tokens, embed_dim)


class ConvEmbed(nn.Module):
    """3x3 convolutional stem for ViT (ViT-C style) producing patch tokens.

    The stem is ``len(channels) - 1`` stages of 3x3 Conv (+ BatchNorm when
    ``batch_norm``) + ReLU, followed by a 1x1 projection conv to
    ``channels[-1]``. ``strides[i]`` is the stride of stage i, and the last
    stride belongs to the 1x1 conv. The total downsampling is ``prod(strides)``.

    Args:
        channels: output widths of each conv (the last is the token width).
        strides: per-conv strides.
        img_size: input side length, as an int or a sequence whose first
            element is used; the input is assumed to be square.
        in_chans: number of input channels.
        batch_norm: add BatchNorm2d after each 3x3 conv (the conv then has no bias).
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        widths = [in_chans] + channels
        layers = []
        for i in range(len(widths) - 2):
            layers.append(nn.Conv2d(widths[i], widths[i + 1], kernel_size=3,
                                    stride=strides[i], padding=1, bias=(not batch_norm)))
            if batch_norm:
                layers.append(nn.BatchNorm2d(widths[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(widths[-2], widths[-1], kernel_size=1, stride=strides[-1]))
        self.stem = nn.Sequential(*layers)

        # Accept an int as well as a sequence: the default img_size=224
        # previously crashed on img_size[0].
        side = img_size if isinstance(img_size, int) else img_size[0]
        stride_prod = int(np.prod(strides))
        self.num_patches = (side // stride_prod) ** 2

    def forward(self, x):
        tokens = self.stem(x)
        return tokens.flatten(2).transpose(1, 2)  # (B, num_tokens, channels[-1])


class VisionTransformer(nn.Module):
    """Vision Transformer backbone (timm / DINO / MSN variant).

    Args:
        img_size: sequence whose first element is the (square) training
            resolution; it sizes the learned positional embedding.
        patch_size: patch side length for the linear patch embedding
            (not used when ``conv_stem`` is True).
        in_chans: number of input channels.
        num_classes: if > 0, ``fc`` is a Linear layer applied to the [CLS]
            feature; otherwise ``fc`` is an identity.
        embed_dim, depth, num_heads, mlp_ratio, qkv_bias, qk_scale:
            transformer hyper-parameters.
        drop_rate, attn_drop_rate: dropout rates.
        drop_path_rate: stochastic-depth rate, increased linearly from 0 across blocks.
        norm_layer: normalisation layer factory.
        conv_stem, conv_stem_channels, conv_stem_strides: replace the patch
            projection with a ConvEmbed stem.
        **kwargs: accepted and ignored.
    """

    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm,
                 conv_stem=False, conv_stem_channels=None, conv_stem_strides=None, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        if conv_stem:
            self.patch_embed = ConvEmbed(conv_stem_channels, conv_stem_strides,
                                         in_chans=in_chans, img_size=img_size)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head (identity when num_classes == 0)
        self.fc = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        # Optional head applied by forward_head() when set; None here
        # (presumably attached by the training / evaluation script — TODO confirm).
        self.pred = None

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, unit/zero LayerNorm, Kaiming Conv2d, zero biases."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, return_before_head=False, patch_drop=0.):
        """Encode one image batch or a list of multi-crop view batches.

        Consecutive list entries with the same last-dimension size are
        concatenated along the batch dimension and encoded together. Patch
        dropping is applied only to the first such group; later groups keep
        all patches.

        Returns:
            ``z``, or ``(h, z)`` when ``return_before_head`` is True, where
            ``h`` is the backbone output and ``z = forward_head(h)``.
        """
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)
        start_idx = 0
        for end_idx in idx_crops:
            _h = self.forward_features(torch.cat(x[start_idx:end_idx]), patch_drop)
            _z = self.forward_head(_h)
            if start_idx == 0:
                h, z = _h, _z
            else:
                h, z = torch.cat((h, _h)), torch.cat((z, _z))
            patch_drop = 0.  # only the first resolution group is patch-dropped
            start_idx = end_idx

        if return_before_head:
            return h, z
        return z

    def forward_head(self, x):
        """Apply ``self.pred`` when it is set; otherwise return ``x`` unchanged."""
        if self.pred is not None:
            return self.pred(x)
        return x

    def _prepare_tokens(self, x):
        """Patchify ``x``, prepend [CLS], add resized positional embeddings, apply pos dropout."""
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.interpolate_pos_encoding(x, self.pos_embed)
        return self.pos_drop(x)

    def _drop_patches(self, x, patch_drop):
        """Keep [CLS] plus a random ``floor((T - 1) * (1 - patch_drop))`` patch tokens.

        One random subset is drawn per call and shared by the whole batch.
        Returns ``x`` unchanged when ``patch_drop <= 0``.
        """
        if patch_drop <= 0:
            return x
        patch_keep = 1. - patch_drop
        T_H = int(np.floor((x.shape[1] - 1) * patch_keep))
        perm = 1 + torch.randperm(x.shape[1] - 1)[:T_H]  # +1 skips the class token
        idx = torch.cat([torch.zeros(1, dtype=perm.dtype, device=perm.device), perm])
        return x[:, idx, :]

    def forward_features(self, x, patch_drop):
        """Backbone pass returning the final-normed [CLS] feature passed through ``fc``."""
        x = self._drop_patches(self._prepare_tokens(x), patch_drop)
        for blk in self.blocks:
            x = blk(x)
        if self.norm is not None:
            x = self.norm(x)
        x = x[:, 0]
        if self.fc is not None:
            x = self.fc(x)
        return x

    def forward_blocks(self, x, num_blocks=1, patch_drop=0.):
        """Concatenate the (un-normed) [CLS] tokens of the last ``num_blocks`` blocks."""
        x = self._drop_patches(self._prepare_tokens(x), patch_drop)
        cls_x = []
        for i in range(len(self.blocks)):
            x = self.blocks[i](x)
            if (len(self.blocks) - i) <= num_blocks:
                cls_x.append(x[:, 0])
        return torch.cat(cls_x, dim=-1)

    def interpolate_pos_encoding(self, x, pos_embed):
        """Resize the patch part of ``pos_embed`` to match the token count of ``x``.

        ``pos_embed`` holds a [CLS] embedding followed by an S x S grid; ``x``
        must hold a [CLS] token followed by a square patch grid. The grid is
        resized bicubically to an explicit target size, so floating-point
        error in a scale factor cannot yield a grid one row/column short.

        Raises:
            ValueError: if the number of patch tokens in ``x`` is not a perfect square.
        """
        npatch = x.shape[1] - 1
        N = pos_embed.shape[1] - 1
        if npatch == N:
            return pos_embed
        src_side = int(round(math.sqrt(N)))
        dst_side = int(round(math.sqrt(npatch)))
        if dst_side * dst_side != npatch:
            raise ValueError(
                f"interpolate_pos_encoding expects a square patch grid, got {npatch} patch tokens")
        class_emb = pos_embed[:, 0]
        dim = x.shape[-1]
        patch_pos = nn.functional.interpolate(
            pos_embed[:, 1:].reshape(1, src_side, src_side, dim).permute(0, 3, 1, 2),
            size=(dst_side, dst_side),
            mode='bicubic',
        )
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), patch_pos), dim=1)

    def forward_selfattention(self, x):
        """Return the last block's attention weights for a (B, C, H, W) input.

        Positional embeddings are resized to the (possibly non-square) patch
        grid using an explicit target size, so no padding is needed. Requires
        ``patch_embed`` to expose ``patch_size``.
        """
        B, nc, w, h = x.shape
        N = self.pos_embed.shape[1] - 1
        x = self.patch_embed(x)

        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        side = int(round(math.sqrt(N)))
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = nn.functional.interpolate(
            self.pos_embed[:, 1:].reshape(1, side, side, dim).permute(0, 3, 1, 2),
            size=(w0, h0),
            mode='bicubic',
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
        pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                return blk(x, return_attention=True)

    def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
        """Concatenate the normed [CLS] tokens of the last ``n`` blocks.

        With ``return_patch_avgpool``, also append the mean of the normed patch
        tokens from the last block (useful for linear evaluation).
        """
        x = self._prepare_tokens(x)

        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x)[:, 0])
        if return_patch_avgpool:
            x = self.norm(x)
            output.append(torch.mean(x[:, 1:], dim=1))
        return torch.cat(output, dim=-1)


def _make_vit(patch_size, embed_dim, depth, num_heads, mlp_ratio, **kwargs):
    """Build a VisionTransformer with the settings shared by every factory below.

    All variants use a bias on the qkv projection and LayerNorm(eps=1e-6).
    They differ only in patch size, width, depth, head count, MLP ratio and
    (for the conv-stem models) the stem. Extra ``kwargs`` (e.g. ``img_size``,
    ``drop_path_rate``) are forwarded unchanged.
    """
    return VisionTransformer(
        patch_size=patch_size, embed_dim=embed_dim, depth=depth, num_heads=num_heads,
        mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def deit_tiny(patch_size=16, **kwargs):
    """DeiT-Tiny: 192-d embeddings, 12 blocks, 3 heads."""
    return _make_vit(patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)


def deit_small(patch_size=16, **kwargs):
    """DeiT-Small: 384-d embeddings, 12 blocks, 6 heads."""
    return _make_vit(patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)


def deit_small_p8(patch_size=8, **kwargs):
    """DeiT-Small with 8x8 patches by default."""
    return _make_vit(patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)


def deit_small_p7(patch_size=7, **kwargs):
    """DeiT-Small with 7x7 patches by default."""
    return _make_vit(patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)


def vitc_4gf(patch_size=16, **kwargs):
    """ViT-C: 384-d, 11 blocks, 6 heads, MLP ratio 3, 3x3 conv stem (total stride 16).

    ``patch_size`` is accepted for signature compatibility; the conv stem
    determines the token grid.
    """
    channels = [48, 96, 192, 384, 384]
    strides = [2, 2, 2, 2, 1]
    return _make_vit(patch_size, embed_dim=384, depth=11, num_heads=6, mlp_ratio=3,
                     conv_stem=True, conv_stem_channels=channels, conv_stem_strides=strides,
                     **kwargs)


def deit_small_convstem(patch_size=16, **kwargs):
    """DeiT-Small with a 3x3 conv stem: 384-d, 11 blocks, 6 heads, MLP ratio 6."""
    channels = [48, 96, 192, 384, 384]
    strides = [2, 2, 2, 2, 1]
    return _make_vit(patch_size, embed_dim=384, depth=11, num_heads=6, mlp_ratio=6,
                     conv_stem=True, conv_stem_channels=channels, conv_stem_strides=strides,
                     **kwargs)


def deit_base_p8(patch_size=8, **kwargs):
    """DeiT-Base with 8x8 patches by default."""
    return _make_vit(patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)


def deit_base_p7(patch_size=7, **kwargs):
    """DeiT-Base with 7x7 patches by default."""
    return _make_vit(patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)


def deit_base_p4(patch_size=4, **kwargs):
    """DeiT-Base with 4x4 patches by default."""
    return _make_vit(patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)


def deit_base(patch_size=16, **kwargs):
    """DeiT-Base: 768-d embeddings, 12 blocks, 12 heads."""
    return _make_vit(patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)


def deit_large_p7(patch_size=7, **kwargs):
    """ViT-Large (1024-d, 24 blocks, 16 heads) with 7x7 patches by default."""
    return _make_vit(patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)


def deit_large_p8(patch_size=8, **kwargs):
    """ViT-Large (1024-d, 24 blocks, 16 heads) with 8x8 patches by default."""
    return _make_vit(patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)


def deit_large(patch_size=16, **kwargs):
    """ViT-Large: 1024-d embeddings, 24 blocks, 16 heads."""
    return _make_vit(patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)


def deit_huge(patch_size=16, **kwargs):
    """ViT-Huge: 1280-d embeddings, 32 blocks, 16 heads."""
    return _make_vit(patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)


def deit_huge_p8(patch_size=8, **kwargs):
    """ViT-Huge with 8x8 patches by default."""
    return _make_vit(patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)


def deit_huge_p7(patch_size=7, **kwargs):
    """ViT-Huge with 7x7 patches by default."""
    return _make_vit(patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)


def deit_huge_p10(patch_size=10, **kwargs):
    """ViT-Huge with 10x10 patches by default."""
    return _make_vit(patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)

Overwriting /content/msn/src/data_manager.py


In [None]:
%%writefile "/content/msn/src/losses.py"
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the Creative Commons Attribution–NonCommercial 4.0 International License
#

from logging import getLogger
import torch
import math
from src.utils import AllReduce


logger = getLogger()


def init_msn_loss(
    num_views=1,
    tau=0.1,
    me_max=True,
    return_preds=False
):
    """
    Build the unsupervised MSN loss closure.

    :param num_views: number of anchor views per target view; the targets are
        tiled this many times along dim 0 to line up with the anchors
    :param tau: temperature of the cosine-similarity softmax
    :param me_max: whether to add the mean-entropy-maximisation regulariser
    :param return_preds: whether the closure also returns the tiled targets
    :returns: a callable computing (ce_loss, me_max_loss, entropy_loss,
        log_dict[, targets])
    """
    row_softmax = torch.nn.Softmax(dim=1)

    def _sharpen(p, T):
        # temperature-sharpen each row (p ** (1/T)) and renormalise it
        powered = p ** (1. / T)
        return powered / powered.sum(dim=1, keepdim=True)

    def _snn(query, supports, support_labels, temp=tau):
        """ Soft Nearest Neighbours similarity classifier """
        q_unit = torch.nn.functional.normalize(query)
        s_unit = torch.nn.functional.normalize(supports)
        similarity = q_unit @ s_unit.T / temp
        return row_softmax(similarity) @ support_labels

    def msn_loss(
        anchor_views,
        target_views,
        prototypes,
        proto_labels,
        T=0.25,
        use_entropy=False,
        use_sinkhorn=False,
        sharpen=_sharpen,
        snn=_snn
    ):
        # (1) soft prototype assignments of the anchor views
        anchor_probs = snn(anchor_views, prototypes, proto_labels)

        # (2) sharpened (optionally Sinkhorn-balanced) target assignments,
        #     repeated once per anchor view; no gradient flows into targets
        with torch.no_grad():
            target_probs = sharpen(snn(target_views, prototypes, proto_labels), T=T)
            if use_sinkhorn:
                target_probs = distributed_sinkhorn(target_probs)
            target_probs = torch.cat([target_probs] * num_views, dim=0)

        # (3) cross-entropy H(targets, anchors), written as log(p ** -t)
        ce_loss = torch.log(anchor_probs ** (-target_probs)).sum(dim=1).mean()

        # (4) me-max: log(K) minus the entropy of the batch-mean prediction
        me_max_loss = 0.
        if me_max:
            mean_probs = AllReduce.apply(anchor_probs.mean(dim=0))
            me_max_loss = -torch.log(mean_probs ** (-mean_probs)).sum() \
                + math.log(float(len(mean_probs)))

        # (5) optional mean per-sample entropy of the anchor predictions
        ent_loss = 0.
        if use_entropy:
            ent_loss = torch.log(anchor_probs ** (-anchor_probs)).sum(dim=1).mean()

        # -- diagnostics: distinct argmax prototypes, mean max/min target prob
        with torch.no_grad():
            log_dct = {
                'np': float(target_probs.argmax(dim=1).unique().numel()),
                'max_t': target_probs.max(dim=1).values.mean(),
                'min_t': target_probs.min(dim=1).values.mean(),
            }

        if return_preds:
            return ce_loss, me_max_loss, ent_loss, log_dct, target_probs
        return ce_loss, me_max_loss, ent_loss, log_dct

    return msn_loss


@torch.no_grad()
def distributed_sinkhorn(Q, num_itr=3, use_dist=True):
    """
    Balance soft assignments with a few Sinkhorn-Knopp iterations.

    :param Q: (samples, prototypes) non-negative assignment matrix; it is
        normalised IN PLACE (the returned tensor shares its storage)
    :param num_itr: number of row/column normalisation rounds
    :param use_dist: all-reduce the global sums across ranks when a
        multi-process group is initialised
    :returns: tensor shaped like ``Q`` whose rows each sum to 1
    """
    dist_active = (
        use_dist
        and torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and (torch.distributed.get_world_size() > 1)
    )
    world_size = torch.distributed.get_world_size() if dist_active else 1

    def _global(t):
        # sum contributions from every rank when running distributed
        if dist_active:
            torch.distributed.all_reduce(t)
        return t

    assign = Q.T  # (prototypes, local samples) view
    n_samples = assign.shape[1] * world_size
    n_protos = assign.shape[0]

    # start from a matrix whose entries sum to 1
    assign /= _global(torch.sum(assign))

    for _ in range(num_itr):
        # every prototype row carries total mass 1/K
        assign /= _global(torch.sum(assign, dim=1, keepdim=True))
        assign /= n_protos
        # every sample column carries total mass 1/B
        assign /= torch.sum(assign, dim=0, keepdim=True)
        assign /= n_samples

    # rescale so each sample's assignment is a distribution
    assign *= n_samples
    return assign.T

In [None]:
%%writefile "/content/msn/src/msn_train.py"
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the Creative Commons Attribution–NonCommercial 4.0 International License
#

import os

# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS
# -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE
# --          SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE
# --          THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE
# --          TO EACH PROCESS
# -- This sits above the torch imports so the restriction is in place before
#    CUDA can be initialised; outside SLURM it is a no-op.
if 'SLURM_LOCALID' in os.environ:
    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID']

import copy
import logging
import sys
from collections import OrderedDict

import numpy as np

import torch
import torch.multiprocessing as mp

import src.deit as deit
from src.utils import (
    AllReduceSum,
    trunc_normal_,
    gpu_timer,
    init_distributed,
    WarmupCosineSchedule,
    CosineWDSchedule,
    CSVLogger,
    grad_logger,
    AverageMeter
)
from src.losses import init_msn_loss
from src.data_manager import (
    init_data,
    make_transforms
)

from torch.nn.parallel import DistributedDataParallel

# -- module-level knobs used by the training loop
# log_timings: toggle for CUDA-event timing of data loading / train steps
#   NOTE(review): confirm this flag is actually passed to gpu_timer
log_timings = True
log_freq = 10  # console summary every N iterations (the CSV gets every iteration)
checkpoint_freq = 25  # extra epoch-tagged checkpoint every N epochs
checkpoint_freq_itr = 2500  # also call save_checkpoint every N iterations within an epoch
# --

# -- seed numpy and torch RNGs so initialisation is repeatable
_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)
# let cuDNN autotune conv algorithms (faster, but not bit-exact run to run)
torch.backends.cudnn.benchmark = True

# -- root logger to stdout at INFO; used by every function in this module
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()


def main(args):
    """
    Run MSN self-supervised pre-training from a parsed config dictionary.

    :param args: nested config dict with sections 'meta', 'criterion',
        'data', 'optimization' and 'logging'; keys read below without an
        inline default must be present.

    Builds the online encoder and its momentum (target) copy, the learnable
    prototypes, the data loader, AdamW with its lr / weight-decay schedules,
    optionally resumes from a checkpoint, then trains for
    ``optimization.epochs`` epochs. Per-iteration stats are appended to a CSV
    file and checkpoints are written into ``logging.folder``.
    """

    # ----------------------------------------------------------------------- #
    #  PASSED IN PARAMS FROM CONFIG FILE
    # ----------------------------------------------------------------------- #

    # -- META
    model_name = args['meta']['model_name']
    two_layer = False if 'two_layer' not in args['meta'] else args['meta']['two_layer']
    bottleneck = 1 if 'bottleneck' not in args['meta'] else args['meta']['bottleneck']
    output_dim = args['meta']['output_dim']
    hidden_dim = args['meta']['hidden_dim']
    load_model = args['meta']['load_checkpoint']
    r_file = args['meta']['read_checkpoint']
    copy_data = args['meta']['copy_data']
    use_pred_head = args['meta']['use_pred_head']
    use_bn = args['meta']['use_bn']
    drop_path_rate = args['meta']['drop_path_rate']
    if not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:0')
        torch.cuda.set_device(device)

    # -- CRITERTION
    memax_weight = 1 if 'memax_weight' not in args['criterion'] else args['criterion']['memax_weight']
    ent_weight = 1 if 'ent_weight' not in args['criterion'] else args['criterion']['ent_weight']
    freeze_proto = False if 'freeze_proto' not in args['criterion'] else args['criterion']['freeze_proto']
    use_ent = False if 'use_ent' not in args['criterion'] else args['criterion']['use_ent']
    reg = args['criterion']['me_max']
    use_sinkhorn = args['criterion']['use_sinkhorn']
    num_proto = args['criterion']['num_proto']
    # --
    batch_size = args['criterion']['batch_size']
    temperature = args['criterion']['temperature']
    _start_T = args['criterion']['start_sharpen']
    _final_T = args['criterion']['final_sharpen']

    # -- DATA
    label_smoothing = args['data']['label_smoothing']
    pin_mem = False if 'pin_mem' not in args['data'] else args['data']['pin_mem']
    num_workers = 1 if 'num_workers' not in args['data'] else args['data']['num_workers']
    color_jitter = args['data']['color_jitter_strength']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    patch_drop = args['data']['patch_drop']
    rand_size = args['data']['rand_size']
    rand_views = args['data']['rand_views']
    focal_views = args['data']['focal_views']
    focal_size = args['data']['focal_size']
    # --

    # -- OPTIMIZATION
    clip_grad = args['optimization']['clip_grad']
    wd = float(args['optimization']['weight_decay'])
    final_wd = float(args['optimization']['final_weight_decay'])
    num_epochs = args['optimization']['epochs']
    warmup = args['optimization']['warmup']
    start_lr = args['optimization']['start_lr']
    lr = args['optimization']['lr']
    final_lr = args['optimization']['final_lr']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    # ----------------------------------------------------------------------- #

    try:
        mp.set_start_method('spawn')
    except Exception:
        pass

    # -- init torch distributed backend
    world_size, rank = init_distributed()
    logger.info(f'Initialized (rank/world-size) {rank}/{world_size}')
    # if rank > 0:
    #     logger.setLevel(logging.ERROR)

    # -- proto details
    assert num_proto > 0, 'unsupervised pre-training requires specifying prototypes'

    # -- log/checkpointing paths
    log_file = os.path.join(folder, f'{tag}_r{rank}.csv')
    save_path = os.path.join(folder, f'{tag}' + '-ep{epoch}.pth.tar')
    latest_path = os.path.join(folder, f'{tag}-latest.pth.tar')
    load_path = None
    if load_model:
        load_path = os.path.join(folder, r_file) if r_file is not None else latest_path

    # -- make csv_logger
    csv_logger = CSVLogger(log_file,
                           ('%d', 'epoch'),
                           ('%d', 'itr'),
                           ('%.5f', 'msn'),
                           ('%.5f', 'me_max'),
                           ('%.5f', 'ent'),
                           ('%d', 'time (ms)'))

    # -- init model
    encoder = init_model(
        device=device,
        model_name=model_name,
        two_layer=two_layer,
        use_pred=use_pred_head,
        use_bn=use_bn,
        bottleneck=bottleneck,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        drop_path_rate=drop_path_rate)
    target_encoder = copy.deepcopy(encoder)
    if (world_size > 1):
        encoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(encoder)
        target_encoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(target_encoder)

    # -- init losses
    msn = init_msn_loss(
        num_views=focal_views+rand_views,
        tau=temperature,
        me_max=reg,
        return_preds=True)

    def one_hot(targets, num_classes, smoothing=label_smoothing):
        # label-smoothed one-hot rows, built directly on `device`
        off_value = smoothing / num_classes
        on_value = 1. - smoothing + off_value
        targets = targets.long().view(-1, 1).to(device)
        return torch.full((len(targets), num_classes), off_value, device=device).scatter_(1, targets, on_value)

    # -- make data transforms
    transform = make_transforms(
        rand_size=rand_size,
        focal_size=focal_size,
        rand_views=rand_views+1,
        focal_views=focal_views,
        color_jitter=color_jitter)

    # -- init data-loaders/samplers
    (unsupervised_loader,
     unsupervised_sampler) = init_data(
         transform=transform,
         batch_size=batch_size,
         pin_mem=pin_mem,
         num_workers=num_workers,
         world_size=world_size,
         rank=rank,
         root_path=root_path,
         image_folder=image_folder,
         training=True,
         copy_data=copy_data)
    ipe = len(unsupervised_loader)
    logger.info(f'iterations per epoch: {ipe}')

    # -- make prototypes
    prototypes, proto_labels = None, None
    if num_proto > 0:
        with torch.no_grad():
            prototypes = torch.empty(num_proto, output_dim)
            _sqrt_k = (1./output_dim)**0.5
            torch.nn.init.uniform_(prototypes, -_sqrt_k, _sqrt_k)
            prototypes = torch.nn.parameter.Parameter(prototypes).to(device)

            # -- init prototype labels
            proto_labels = one_hot(torch.tensor([i for i in range(num_proto)]), num_proto)

        # -- set explicitly in both directions: when `device` is the CPU,
        #    `.to(device)` returns the Parameter itself (requires_grad=True),
        #    so freeze_proto has to actively switch gradients off
        prototypes.requires_grad = not freeze_proto
        logger.info(f'Created prototypes: {prototypes.shape}')
        logger.info(f'Requires grad: {prototypes.requires_grad}')

    # -- init optimizer and scheduler
    encoder, optimizer, scheduler, wd_scheduler = init_opt(
        encoder=encoder,
        prototypes=prototypes,
        wd=wd,
        final_wd=final_wd,
        start_lr=start_lr,
        ref_lr=lr,
        final_lr=final_lr,
        iterations_per_epoch=ipe,
        warmup=warmup,
        num_epochs=num_epochs)
    if world_size > 1:
        encoder = DistributedDataParallel(encoder)
        target_encoder = DistributedDataParallel(target_encoder)
        for p in target_encoder.parameters():
            p.requires_grad = False

    # -- momentum schedule (spans 1.25x the training iterations, like init_opt)
    _start_m, _final_m = 0.996, 1.0
    _increment = (_final_m - _start_m) / (ipe * num_epochs * 1.25)
    momentum_scheduler = (_start_m + (_increment*i) for i in range(int(ipe*num_epochs*1.25)+1))

    # -- sharpening schedule
    _increment_T = (_final_T - _start_T) / (ipe * num_epochs * 1.25)
    sharpen_scheduler = (_start_T + (_increment_T*i) for i in range(int(ipe*num_epochs*1.25)+1))

    start_epoch = 0
    # -- load training checkpoint
    if load_model:
        encoder, target_encoder, prototypes, optimizer, start_epoch = load_checkpoint(
            device=device,
            prototypes=prototypes,
            r_path=load_path,
            encoder=encoder,
            target_encoder=target_encoder,
            opt=optimizer)
        # -- fast-forward every schedule to where the checkpoint left off
        for _ in range(start_epoch*ipe):
            scheduler.step()
            wd_scheduler.step()
            next(momentum_scheduler)
            next(sharpen_scheduler)

    def save_checkpoint(epoch):

        if target_encoder is not None:
            target_encoder_state_dict = target_encoder.state_dict()
        else:
            target_encoder_state_dict = None

        save_dict = {
            'encoder': encoder.state_dict(),
            'opt': optimizer.state_dict(),
            'prototypes': prototypes.data,
            'target_encoder': target_encoder_state_dict,
            'epoch': epoch,
            'loss': loss_meter.avg,
            'batch_size': batch_size,
            'world_size': world_size,
            'lr': lr,
            'temperature': temperature
        }
        if rank == 0:
            torch.save(save_dict, latest_path)
            if (epoch + 1) % checkpoint_freq == 0 \
                    or (epoch + 1) % 10 == 0 and epoch < checkpoint_freq:
                torch.save(save_dict, save_path.format(epoch=f'{epoch + 1}'))

    # -- CUDA-event timing is only possible when a GPU is present
    _log_timings = log_timings and torch.cuda.is_available()

    # -- TRAINING LOOP
    for epoch in range(start_epoch, num_epochs):
        logger.info('Epoch %d' % (epoch + 1))

        # -- update distributed-data-loader epoch
        unsupervised_sampler.set_epoch(epoch)

        loss_meter = AverageMeter()
        ploss_meter = AverageMeter()
        rloss_meter = AverageMeter()
        eloss_meter = AverageMeter()
        np_meter = AverageMeter()
        maxp_meter = AverageMeter()
        time_meter = AverageMeter()
        data_meter = AverageMeter()

        for itr, (udata, _) in enumerate(unsupervised_loader):

            def load_imgs():
                # -- unsupervised imgs
                imgs = [u.to(device, non_blocking=True) for u in udata]
                return imgs
            imgs, dtime = gpu_timer(load_imgs, log_timings=_log_timings)
            data_meter.update(dtime)

            def train_step():
                optimizer.zero_grad()

                # --
                # h: representations of 'imgs' before head
                # z: representations of 'imgs' after head
                # -- If use_pred_head=False, then encoder.pred (prediction
                #    head) is None, and _forward_head just returns the
                #    identity, z=h
                # -- imgs[0] is the target view; the rest are anchor views
                h, z = encoder(imgs[1:], return_before_head=True, patch_drop=patch_drop)
                with torch.no_grad():
                    h, _ = target_encoder(imgs[0], return_before_head=True)

                # Step 1. convert representations to fp32
                h, z = h.float(), z.float()

                # Step 2. determine anchor views/supports and their
                #         corresponding target views/supports
                # --
                anchor_views, target_views = z, h.detach()
                T = next(sharpen_scheduler)

                # Step 3. compute msn loss with me-max regularization
                (ploss, me_max, ent, logs, _) = msn(
                    T=T,
                    use_sinkhorn=use_sinkhorn,
                    use_entropy=use_ent,
                    anchor_views=anchor_views,
                    target_views=target_views,
                    proto_labels=proto_labels,
                    prototypes=prototypes)
                loss = ploss + memax_weight*me_max + ent_weight*ent

                _new_lr = scheduler.step()
                _new_wd = wd_scheduler.step()
                # --

                # Step 4. Optimization step
                loss.backward()
                # -- frozen prototypes (freeze_proto=True) receive no grad
                if prototypes.grad is not None:
                    with torch.no_grad():
                        prototypes.grad.data = AllReduceSum.apply(prototypes.grad.data)
                grad_stats = grad_logger(encoder.named_parameters())
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip_grad)
                optimizer.step()

                # Step 5. momentum update of target encoder
                with torch.no_grad():
                    m = next(momentum_scheduler)
                    for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
                        param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)

                return (float(loss), float(ploss), float(me_max), float(ent),
                        logs, _new_lr, _new_wd, grad_stats)
            (loss, ploss, rloss, eloss,
             _logs, _new_lr, _new_wd, grad_stats), etime = gpu_timer(train_step, log_timings=_log_timings)
            loss_meter.update(loss)
            ploss_meter.update(ploss)
            rloss_meter.update(rloss)
            eloss_meter.update(eloss)
            maxp_meter.update(_logs['max_t'])
            np_meter.update(_logs['np'])

            time_meter.update(etime)

            # -- Save Checkpoint
            if itr % checkpoint_freq_itr == 0:
                save_checkpoint(epoch)

            # -- Logging
            def log_stats():
                csv_logger.log(epoch + 1, itr, ploss, rloss, eloss, etime)
                if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss):
                    logger.info('[%d, %5d] loss: %.3f (%.3f %.3f %.3f) '
                                '(np: %.1f, max-t: %.3f) '
                                '[wd: %.2e] [lr: %.2e] '
                                '[mem: %.2e] '
                                '(%d ms; %d ms)'
                                % (epoch + 1, itr,
                                   loss_meter.avg,
                                   ploss_meter.avg,
                                   rloss_meter.avg,
                                   eloss_meter.avg,
                                   np_meter.avg,
                                   maxp_meter.avg,
                                   _new_wd,
                                   _new_lr,
                                   torch.cuda.max_memory_allocated() / 1024.**2,
                                   time_meter.avg,
                                   data_meter.avg))

                    if grad_stats is not None:
                        logger.info('[%d, %5d] grad_stats: [%.2e %.2e] (%.2e, %.2e)'
                                    % (epoch + 1, itr,
                                       grad_stats.first_layer,
                                       grad_stats.last_layer,
                                       grad_stats.min,
                                       grad_stats.max))
            log_stats()
            assert not np.isnan(loss), 'loss is nan'

        # -- Save Checkpoint after every epoch
        logger.info('avg. loss %.3f' % loss_meter.avg)
        save_checkpoint(epoch+1)


def load_checkpoint(
    device,
    r_path,
    prototypes,
    encoder,
    target_encoder,
    opt
):
    """
    Restore a training checkpoint written by ``save_checkpoint`` in ``main``.

    :param device: device the restored prototype tensor is moved to
    :param r_path: path of the checkpoint file (loaded onto the CPU first)
    :param prototypes: prototype tensor whose ``.data`` is replaced; skipped
        when None or when the checkpoint has no 'prototypes' entry
    :param encoder: online encoder, receives ``checkpoint['encoder']``
    :param target_encoder: momentum encoder, receives
        ``checkpoint['target_encoder']`` unless it is None
    :param opt: optimizer, receives ``checkpoint['opt']``
    :returns: (encoder, target_encoder, prototypes, opt, epoch)
    """
    checkpoint = torch.load(r_path, map_location=torch.device('cpu'))
    epoch = checkpoint['epoch']

    def _rename_scaling_bias(state_dict):
        # map a 'scaling_bias' entry onto the 'scaling_module.bias' key
        if ('scaling_module.bias' not in state_dict) and ('scaling_bias' in state_dict):
            state_dict['scaling_module.bias'] = state_dict.pop('scaling_bias')
        return state_dict

    # -- loading encoder
    msg = encoder.load_state_dict(_rename_scaling_bias(checkpoint['encoder']))
    logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}')

    # -- loading target_encoder
    if target_encoder is not None:
        logger.debug(f'checkpoint keys: {list(checkpoint.keys())}')
        msg = target_encoder.load_state_dict(_rename_scaling_bias(checkpoint['target_encoder']))
        logger.info(f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}')

    # -- loading prototypes
    if (prototypes is not None) and ('prototypes' in checkpoint):
        with torch.no_grad():
            prototypes.data = checkpoint['prototypes'].to(device)
        logger.info(f'loaded prototypes from epoch {epoch}')

    # -- loading optimizer
    opt.load_state_dict(checkpoint['opt'])
    logger.info(f'loaded optimizers from epoch {epoch}')
    logger.info(f'read-path: {r_path}')
    del checkpoint
    return encoder, target_encoder, prototypes, opt, epoch


def init_model(
    device,
    model_name='resnet50',
    use_pred=False,
    use_bn=False,
    two_layer=False,
    bottleneck=1,
    hidden_dim=2048,
    output_dim=128,
    drop_path_rate=0.1,
):
    """
    Build a backbone from ``src.deit`` and attach a 3-layer projection head.

    :param device: device the finished encoder is moved to
    :param model_name: name of a factory function in ``src.deit``
    :param use_bn: insert BatchNorm1d after each hidden Linear layer
    :param hidden_dim: width of the two hidden head layers
    :param output_dim: output width of the head
    :param drop_path_rate: forwarded to the backbone factory
    :returns: the encoder, with ``encoder.fc`` set to the head

    NOTE(review): ``use_pred``, ``two_layer`` and ``bottleneck`` are accepted
    but not read in this function.
    """
    encoder = deit.__dict__[model_name](drop_path_rate=drop_path_rate)

    # -- backbone width from the size tag in the model name (first match
    #    wins; anything unmatched is treated as 1280)
    emb_dim = 1280
    for size_tag, width in (('tiny', 192), ('small', 384), ('base', 768), ('large', 1024)):
        if size_tag in model_name:
            emb_dim = width
            break

    # -- projection head: (Linear [-> BN] -> GELU) x2 -> Linear
    encoder.fc = None
    head = OrderedDict([])
    in_features = emb_dim
    for idx in (1, 2):
        head[f'fc{idx}'] = torch.nn.Linear(in_features, hidden_dim)
        if use_bn:
            head[f'bn{idx}'] = torch.nn.BatchNorm1d(hidden_dim)
        head[f'gelu{idx}'] = torch.nn.GELU()
        in_features = hidden_dim
    head['fc3'] = torch.nn.Linear(hidden_dim, output_dim)
    encoder.fc = torch.nn.Sequential(head)

    # -- (re)initialise every Linear / LayerNorm in the whole encoder
    for module in encoder.modules():
        if isinstance(module, torch.nn.Linear):
            trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0)
        elif isinstance(module, torch.nn.LayerNorm):
            torch.nn.init.constant_(module.bias, 0)
            torch.nn.init.constant_(module.weight, 1.0)

    encoder.to(device)
    logger.info(encoder)
    return encoder


def init_opt(
    encoder,
    iterations_per_epoch,
    start_lr,
    ref_lr,
    warmup,
    num_epochs,
    prototypes=None,
    wd=1e-6,
    final_wd=1e-6,
    final_lr=0.0
):
    """
    Build AdamW together with its lr and weight-decay schedules.

    The encoder's parameters are split into a decayed group (multi-dim
    weights) and a group excluded from weight decay (names containing 'bias'
    or 'bn', and any 1-D tensor). The prototypes, if given, form a third
    group that is also excluded from decay. Both schedules span
    int(1.25 * num_epochs * iterations_per_epoch) steps.

    :param warmup: warmup length in epochs (scaled by iterations_per_epoch)
    :returns: (encoder, optimizer, lr_scheduler, wd_scheduler); the encoder
        is returned unchanged
    """
    def _no_decay(name, param):
        return ('bias' in name) or ('bn' in name) or (len(param.shape) == 1)

    named_params = list(encoder.named_parameters())
    param_groups = [
        {'params': [p for n, p in named_params if not _no_decay(n, p)]},
        {'params': [p for n, p in named_params if _no_decay(n, p)],
         'WD_exclude': True,
         'weight_decay': 0},
    ]
    if prototypes is not None:
        param_groups.append({
            'params': [prototypes],
            'lr': ref_lr,
            'LARS_exclude': True,
            'WD_exclude': True,
            'weight_decay': 0
        })

    logger.info('Using AdamW')
    optimizer = torch.optim.AdamW(param_groups)

    total_steps = int(1.25*num_epochs*iterations_per_epoch)
    scheduler = WarmupCosineSchedule(
        optimizer,
        warmup_steps=int(warmup*iterations_per_epoch),
        start_lr=start_lr,
        ref_lr=ref_lr,
        final_lr=final_lr,
        T_max=total_steps)
    wd_scheduler = CosineWDSchedule(
        optimizer,
        ref_wd=wd,
        final_wd=final_wd,
        T_max=total_steps)
    return encoder, optimizer, scheduler, wd_scheduler


if __name__ == "__main__":
    # main() requires the parsed config dict, and this module has no config
    # loader of its own; calling main() bare would only raise a TypeError
    # about the missing `args`. Exit with an actionable message instead.
    sys.exit('msn_train.py needs a parsed config dict: load a YAML config '
             '(e.g. configs/pretrain/msn_vits16.yaml) and call main(args=params)')

In [None]:
%%writefile "/content/msn/src/sgd.py"
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the Creative Commons Attribution–NonCommercial 4.0 International License
#

import torch
from torch.optim import Optimizer


class SGD(Optimizer):
    """
    SGD with weight decay and optional (Nesterov) momentum.

    The learning rate is folded into the update before it enters the
    momentum buffer: buf = momentum * buf - lr * (grad + wd * p), then
    p += buf (or p += -lr * g + momentum * buf with Nesterov).
    Gradients are read but never modified.
    """

    def __init__(self, params, lr, momentum=0, weight_decay=0, nesterov=False):
        if lr < 0.0:
            raise ValueError(f'Invalid learning rate: {lr}')
        if weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay value: {weight_decay}')

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                        nesterov=nesterov)
        super(SGD, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one update.

        :param closure: optional callable that re-evaluates the model and
            returns the loss (standard ``torch.optim`` contract)
        :returns: the closure's loss, or None when no closure is given
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                # out-of-place: with weight_decay == 0, d_p still aliases
                # p.grad, and an in-place mul_ would corrupt the gradient
                d_p = d_p.mul(-group['lr'])

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = d_p.clone().detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p)

                    if nesterov:
                        d_p.add_(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p)

        return loss

In [None]:
%%writefile "/content/msn/src/utils.py"
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the Creative Commons Attribution–NonCommercial 4.0 International License
#

import os
import math
import torch
import torch.distributed as dist

from logging import getLogger

logger = getLogger()


def gpu_timer(closure, log_timings=True):
    """
    Run ``closure()`` and measure its GPU time.

    :param closure: zero-argument callable to execute
    :param log_timings: whether to time the call with CUDA events
    :returns: (closure result, elapsed milliseconds). The time is -1. when
        timing is disabled or CUDA is unavailable, since CUDA events cannot
        be recorded without a GPU.
    """
    if not (log_timings and torch.cuda.is_available()):
        return closure(), -1.

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = closure()
    end.record()
    # wait for queued kernels so elapsed_time reflects the whole closure
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)


def init_distributed(port=40111, rank_and_world_size=(None, None)):
    """
    Initialise the NCCL process group, falling back to a single process.

    :param port: value written to MASTER_PORT
    :param rank_and_world_size: explicit (rank, world_size); when either is
        None they are read from SLURM_PROCID / SLURM_NTASKS
    :returns: (world_size, rank); (1, 0) when SLURM variables are missing or
        the process group cannot be created
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()

    rank, world_size = rank_and_world_size
    os.environ['MASTER_ADDR'] = 'localhost'

    if (rank is None) or (world_size is None):
        try:
            world_size = int(os.environ['SLURM_NTASKS'])
            rank = int(os.environ['SLURM_PROCID'])
            os.environ['MASTER_ADDR'] = os.environ['HOSTNAME']
        except (KeyError, ValueError):
            # missing or non-integer SLURM variables: run single-process
            logger.info('SLURM vars not set (distributed training not available)')
            world_size, rank = 1, 0
            return world_size, rank

    try:
        os.environ['MASTER_PORT'] = str(port)
        torch.distributed.init_process_group(
            backend='nccl',
            world_size=world_size,
            rank=rank)
    except Exception as exc:
        # best-effort fallback, but keep the cause: otherwise every rank of a
        # failed multi-process launch silently trains as rank 0 of 1
        world_size, rank = 1, 0
        logger.warning(f'distributed training not available ({exc!r})')

    return world_size, rank


class WarmupCosineSchedule(object):
    """
    Linear warmup from ``start_lr`` to ``ref_lr``, then cosine decay to
    ``final_lr``, written into every param group on each ``step()``.

    ``T_max`` counts total steps, warmup included. Once it is reached, the lr
    stays at ``final_lr`` instead of following the cosine back up.
    ``last_epoch`` is accepted for signature compatibility and is unused.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps,
        start_lr,
        ref_lr,
        T_max,
        last_epoch=-1,
        final_lr=0.
    ):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.T_max = T_max - warmup_steps  # length of the cosine phase
        self._step = 0.

    def step(self):
        """Advance one step, update every param group's 'lr', and return it."""
        self._step += 1
        if self._step < self.warmup_steps:
            progress = float(self._step) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
        else:
            # -- progress after warmup, clamped so cos() never swings back up
            progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max))
            progress = min(1., progress)
            new_lr = max(self.final_lr,
                         self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress)))

        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

        return new_lr


class CosineWDSchedule(object):
    """
    Cosine schedule of weight decay from ``ref_wd`` to ``final_wd`` over
    ``T_max`` steps, applied to every param group not marked 'WD_exclude'.

    After ``T_max`` steps the value stays at ``final_wd``.
    """

    def __init__(
        self,
        optimizer,
        ref_wd,
        T_max,
        final_wd=0.
    ):
        self.optimizer = optimizer
        self.ref_wd = ref_wd
        self.final_wd = final_wd
        self.T_max = T_max
        self._step = 0.

    def step(self):
        """Advance one step, update the non-excluded groups, and return the new wd."""
        self._step += 1
        # clamp to 1 so the cosine does not swing back toward ref_wd after
        # T_max; max(1, .) guards against T_max == 0
        progress = min(1., self._step / max(1, self.T_max))
        new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress))

        # never overshoot final_wd, whichever direction the schedule runs
        if self.final_wd <= self.ref_wd:
            new_wd = max(self.final_wd, new_wd)
        else:
            new_wd = min(self.final_wd, new_wd)

        for group in self.optimizer.param_groups:
            if ('WD_exclude' not in group) or not group['WD_exclude']:
                group['weight_decay'] = new_wd
        return new_wd


class CSVLogger(object):
    """
    Append-only CSV writer with one printf-style format per column.

    The constructor appends a header row built from ``(fmt, name)`` pairs.
    ``log`` formats one value per column and always terminates the row with
    a newline. Values beyond the number of columns are ignored, rather than
    leaving the row unterminated.
    """

    def __init__(self, fname, *argv):
        self.fname = fname
        self.types = [v[0] for v in argv]
        # -- print headers
        with open(self.fname, '+a') as f:
            if argv:
                print(','.join(v[1] for v in argv), file=f)

    def log(self, *argv):
        # format the whole row first, so a formatting error writes nothing
        cells = [fmt % val for fmt, val in zip(self.types, argv)]
        with open(self.fname, '+a') as f:
            if cells:
                print(','.join(cells), file=f)


class AverageMeter(object):
    """Running statistics over a stream of (optionally weighted) scalars.

    Tracks the most recent value (``val``), the weighted sum/count/mean
    (``sum``, ``count``, ``avg``) and the extremes seen (``max``, ``min``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0
        self.max, self.min = float('-inf'), float('inf')

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the derived statistics."""
        self.val = val
        self.max, self.min = max(val, self.max), min(val, self.min)
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


class AllGather(torch.autograd.Function):
    """Differentiable all-gather along dim 0.

    In a multi-process ``torch.distributed`` job the forward concatenates ``x``
    from every rank; the backward sums the incoming gradient across ranks and
    returns this rank's slice. Otherwise both directions are the identity.
    """

    @staticmethod
    def forward(ctx, x):
        in_multi_process_job = (dist.is_available()
                                and dist.is_initialized()
                                and dist.get_world_size() > 1)
        if not in_multi_process_job:
            return x
        n_ranks = dist.get_world_size()
        gathered = [torch.zeros_like(x) for _ in range(n_ranks)]
        dist.all_gather(gathered, x)
        return torch.cat(gathered, 0)

    @staticmethod
    def backward(ctx, grads):
        in_multi_process_job = (dist.is_available()
                                and dist.is_initialized()
                                and dist.get_world_size() > 1)
        if not in_multi_process_job:
            return grads
        n_ranks, my_rank = dist.get_world_size(), dist.get_rank()
        rows_per_rank = grads.shape[0] // n_ranks
        grads = grads.contiguous()
        dist.all_reduce(grads)
        return grads[rows_per_rank * my_rank:rows_per_rank * (my_rank + 1)]


class AllReduceSum(torch.autograd.Function):
    """Sum ``x`` over all ranks in the forward pass; gradients pass through unchanged.

    Without an initialised multi-process ``torch.distributed`` group this is the identity.
    """

    @staticmethod
    def forward(ctx, x):
        in_multi_process_job = (dist.is_available()
                                and dist.is_initialized()
                                and dist.get_world_size() > 1)
        if in_multi_process_job:
            x = x.contiguous()
            dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grads):
        # straight-through gradient
        return grads


class AllReduce(torch.autograd.Function):
    """Average ``x`` over all ranks in the forward pass; gradients pass through unchanged.

    Each rank divides by the world size before the (sum) all-reduce, yielding
    the mean. Without a multi-process ``torch.distributed`` group this is the identity.
    """

    @staticmethod
    def forward(ctx, x):
        if not (dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1):
            return x
        n_ranks = dist.get_world_size()
        averaged = x.contiguous() / n_ranks
        dist.all_reduce(averaged)
        return averaged

    @staticmethod
    def backward(ctx, grads):
        # straight-through gradient
        return grads


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Fill ``tensor`` in place from N(mean, std^2) truncated to ``[a, b]``; returns it.

    Thin public wrapper around ``_no_grad_trunc_normal_``.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)


def grad_logger(named_params):
    """Summarise gradient norms over ``(name, parameter)`` pairs.

    Only parameters that have a ``.grad``, are not named ``*.bias`` and are
    not 1-D contribute. Returns an ``AverageMeter`` of their gradient L2 norms
    with two extra attributes: ``first_layer`` / ``last_layer`` hold the norm
    of the first / last contributing parameter whose name contains ``'qkv'``
    (both set to 0. when there is no such parameter).
    """
    stats = AverageMeter()
    stats.first_layer = None
    stats.last_layer = None
    for name, param in named_params:
        if param.grad is None or name.endswith('.bias') or len(param.shape) == 1:
            continue
        norm = float(torch.norm(param.grad.data))
        stats.update(norm)
        if 'qkv' not in name:
            continue
        if stats.first_layer is None:
            stats.first_layer = norm
        stats.last_layer = norm
    if stats.first_layer is None or stats.last_layer is None:
        stats.first_layer = stats.last_layer = 0.
    return stats

In [None]:
%%writefile "/content/msn/linear_eval.py"
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the Creative Commons Attribution–NonCommercial 4.0 International License
#
# Modifications: Adapted for use in 'Self-supervised learning applied to unlabelled histopathology imaging data –
# a comparison between Masked Siamese Networks and Self-Distillation with No Labels', 2025.
#

import os

from torch.nn.modules.module import T

# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS
# -- WARNING: this only works on SLURM (SLURM_LOCALID = local rank on the node).
# --          On a non-SLURM cluster, update this to obtain the local rank some
# --          other way, or launch jobs with exactly one visible device per
# --          process. Outside SLURM nothing is changed.
if 'SLURM_LOCALID' in os.environ:
    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID']

import logging
import sys

import argparse
import yaml

import numpy as np

import torch
import torchvision.transforms as transforms

import src.deit as deit
from src.utils import (
    AllReduce,
    init_distributed,
    WarmupCosineSchedule
)
from src.data_manager import init_data
from src.sgd import SGD
from torch.nn.parallel import DistributedDataParallel

# -- logging / checkpointing cadence
log_timings = True
log_freq = 10          # log every `log_freq` training iterations
checkpoint_freq = 50

# -- reproducibility: one seed shared by numpy and torch
_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)
# let cuDNN autotune convolution algorithms
torch.backends.cudnn.benchmark = True

logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger()


def main(args):
    """Train (or evaluate) a linear probe on top of a frozen pretrained encoder.

    ``args`` is the parsed YAML config with ``meta``, ``data``,
    ``optimization`` and ``logging`` sections. When ``meta.training`` is
    false, the saved ``<write_tag>-lin-eval.pth.tar`` checkpoint is loaded and
    a single pass is run over the train and test loaders without updates.

    :returns: ``(train_top1, val_top1)`` from the last epoch that ran; both are
        ``None`` when no epoch runs (e.g. resuming a run whose checkpoint
        already reached ``num_epochs``).
    """

    # -- META
    model_name = args['meta']['model_name']
    port = args['meta']['master_port']
    load_checkpoint = args['meta']['load_checkpoint']
    training = args['meta']['training']
    copy_data = args['meta']['copy_data']
    device = torch.device(args['meta']['device'])
    if 'cuda' in args['meta']['device']:
        torch.cuda.set_device(device)

    # -- DATA
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    num_classes = args['data']['num_classes']

    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    ref_lr = args['optimization']['lr']
    num_epochs = args['optimization']['epochs']
    num_blocks = args['optimization']['num_blocks']
    l2_normalize = args['optimization']['normalize']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    r_file_enc = args['logging']['pretrain_path']

    # -- log/checkpointing paths
    r_enc_path = os.path.join(folder, r_file_enc)
    w_enc_path = os.path.join(folder, f'{tag}-lin-eval.pth.tar')

    # -- init distributed
    world_size, rank = init_distributed(port)
    logger.info(f'initialized rank/world-size: {rank}/{world_size}')

    # -- optimization/evaluation params
    scaler = torch.cuda.amp.GradScaler(enabled=True)
    if training:
        batch_size = 256
    else:
        # evaluation only: reload the saved linear-eval checkpoint, one pass
        batch_size = 128
        load_checkpoint = True
        num_epochs = 1

    # "top-5" accuracy is capped at the number of classes: topk(k) raises when
    # k exceeds the logit dimension (e.g. binary MedMNIST tasks)
    top_k = min(5, num_classes)

    # -- init loss
    criterion = torch.nn.CrossEntropyLoss()

    # ADAPTED FOR USE
    # -- make train data transforms and data loaders/samples
    # NOTE(review): unlike val_transform below, this pipeline has no Grayscale
    # step, so the probe is trained on colour inputs but evaluated on grayscale
    # ones; confirm this mismatch is intended for RGB data such as PathMNIST.
    transform = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5),
            (0.5, 0.5, 0.5))])
    # ADAPTED FOR USE
    data_loader, dist_sampler = init_data(
        transform=transform,
        batch_size=batch_size,
        world_size=world_size,
        rank=rank,
        root_path=root_path,
        image_folder=image_folder,
        training=training,
        copy_data=copy_data,
        num_workers=6,
        drop_last=False,
        # label budget (read from your YAML under data:)
        train_subset_frac=args['data'].get('train_subset_frac', None),
        train_subset_size=args['data'].get('train_subset_size', None),
        train_subset_per_class=args['data'].get('train_subset_per_class', None),
        train_subset_seed=args['data'].get('train_subset_seed', 42)
        )

    ipe = len(data_loader)
    logger.info(f'initialized data-loader (ipe {ipe})')

    # ADAPTED FOR USE
    # -- make val data transforms and data loaders/samples
    val_transform = transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5),
            (0.5, 0.5, 0.5))])
    val_data_loader, val_dist_sampler = init_data(
        transform=val_transform,
        batch_size=batch_size,
        world_size=world_size,
        rank=rank,
        root_path=root_path,
        image_folder=image_folder,
        training=False,
        drop_last=False,
        copy_data=copy_data,
        split='test',
        num_workers=6
        )
    logger.info(f'initialized val data-loader (ipe {len(val_data_loader)})')

    # -- init model and optimizer
    encoder, linear_classifier, optimizer, scheduler = init_model(
        device=device,
        device_str=args['meta']['device'],
        num_classes=num_classes,
        num_blocks=num_blocks,
        normalize=l2_normalize,
        training=training,
        r_enc_path=r_enc_path,
        iterations_per_epoch=ipe,
        world_size=world_size,
        ref_lr=ref_lr,
        weight_decay=wd,
        num_epochs=num_epochs,
        model_name=model_name)
    logger.info(encoder)

    best_acc = None
    start_epoch = 0
    # -- load checkpoint
    if not training or load_checkpoint:
        encoder, linear_classifier, optimizer, scheduler, start_epoch, best_acc = load_from_path(
            r_path=w_enc_path,
            encoder=encoder,
            linear_classifier=linear_classifier,
            opt=optimizer,
            sched=scheduler,
            device_str=args['meta']['device'])
    if not training:
        logger.info('putting model in eval mode')
        encoder.eval()
        logger.info(sum(p.numel() for n, p in encoder.named_parameters()
                        if p.requires_grad and ('fc' not in n)))
        start_epoch = 0

    # the encoder is frozen (run under no_grad, not in the optimizer): keep it in eval mode
    encoder.eval()

    # defined up front so the return below is valid even if no epoch runs
    train_top1, val_top1 = None, None
    for epoch in range(start_epoch, num_epochs):

        def train_step():
            """One pass over the train loader (updating the probe when training); returns top-1 %."""
            # -- update distributed-data-loader epoch
            dist_sampler.set_epoch(epoch)
            top1_correct, top5_correct, total = 0, 0, 0
            for i, data in enumerate(data_loader):
                with torch.cuda.amp.autocast(enabled=True):
                    inputs, labels = data[0].to(device), data[1].to(device)
                    with torch.no_grad():
                        outputs = encoder.forward_blocks(inputs, num_blocks)
                outputs = linear_classifier(outputs)
                loss = criterion(outputs, labels)
                total += inputs.shape[0]
                top5_correct += float(outputs.topk(top_k, dim=1).indices.eq(labels.unsqueeze(1)).sum())
                top1_correct += float(outputs.max(dim=1).indices.eq(labels).sum())
                top1_acc = 100. * top1_correct / total
                top5_acc = 100. * top5_correct / total
                if training:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
                if i % log_freq == 0:
                    logger.info('[%d, %5d] %.3f%% %.3f%% (loss: %.3f)'
                                % (epoch + 1, i, top1_acc, top5_acc, loss))
            return 100. * top1_correct / total

        def val_step():
            """One pass over the test loader; returns top-1 % averaged across ranks."""
            top1_correct, total = 0, 0
            for i, data in enumerate(val_data_loader):
                with torch.cuda.amp.autocast(enabled=True):
                    inputs, labels = data[0].to(device), data[1].to(device)
                    outputs = encoder.forward_blocks(inputs, num_blocks)
                outputs = linear_classifier(outputs)
                total += inputs.shape[0]
                top1_correct += outputs.max(dim=1).indices.eq(labels).sum()
                top1_acc = 100. * top1_correct / total

            top1_acc = AllReduce.apply(top1_acc)
            logger.info('[%d, %5d] %.3f%%' % (epoch + 1, i, top1_acc))
            return top1_acc

        train_top1 = train_step()
        with torch.no_grad():
            val_top1 = val_step()

        log_str = 'train:' if training else 'test:'
        logger.info('[%d] (%s %.3f%%) (val: %.3f%%)'
                    % (epoch + 1, log_str, train_top1, val_top1))

        # -- logging/checkpointing: keep only the best-val-accuracy probe (rank 0)
        if training and (rank == 0) and ((best_acc is None) or (best_acc < val_top1)):
            best_acc = val_top1
            save_dict = {
                'target_encoder': encoder.state_dict(),
                'classifier': linear_classifier.state_dict(),
                'opt': optimizer.state_dict(),
                'epoch': epoch + 1,
                'world_size': world_size,
                'best_top1_acc': best_acc,
                'batch_size': batch_size,
                'lr': ref_lr,
            }
            torch.save(save_dict, w_enc_path)

    return train_top1, val_top1


class LinearClassifier(torch.nn.Module):
    """Linear probe head: flatten -> LayerNorm -> optional row L2-normalisation -> Linear.

    :param dim: flattened feature width of the input
    :param num_labels: number of output logits
    :param normalize: if True, L2-normalise each row after the LayerNorm
    """

    def __init__(self, dim, num_labels=9, normalize=True):
        super().__init__()
        self.normalize = normalize
        self.norm = torch.nn.LayerNorm(dim)
        self.linear = torch.nn.Linear(dim, num_labels)
        # re-initialise the head: small Gaussian weights, zero bias
        with torch.no_grad():
            self.linear.weight.normal_(mean=0.0, std=0.01)
            self.linear.bias.zero_()

    def forward(self, x):
        batch = x.size(0)
        features = self.norm(x.view(batch, -1))
        if self.normalize:
            features = torch.nn.functional.normalize(features)
        return self.linear(features)


def load_pretrained(
    r_path,
    encoder,
    linear_classifier,
    device_str
):
    """Load encoder (and optionally classifier) weights from checkpoint ``r_path``.

    Keys are matched after removing ``'module.'`` from checkpoint keys; model
    keys missing from the checkpoint or with a different shape are logged and
    keep the model's current values (non-strict load).

    :param r_path: checkpoint path; must contain ``'target_encoder'`` and
        ``'epoch'``, plus ``'classifier'`` when ``linear_classifier`` is given
    :param encoder: module receiving ``'target_encoder'``
    :param linear_classifier: module receiving ``'classifier'``, or None
    :param device_str: unused; the checkpoint is always mapped to CPU
    :returns: ``(encoder, linear_classifier)``
    """
    checkpoint = torch.load(r_path, map_location='cpu')

    encoder = _load_matching_state(encoder, checkpoint['target_encoder'])
    logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]} '
                f'path: {r_path}')

    if linear_classifier is not None:
        linear_classifier = _load_matching_state(linear_classifier, checkpoint['classifier'])
        logger.info(f'loaded linear classifier from epoch: {checkpoint["epoch"]} '
                    f'path: {r_path}')

    del checkpoint
    return encoder, linear_classifier


def _load_matching_state(module, state_dict):
    """Non-strictly load ``state_dict`` into ``module``, skipping missing/mismatched keys; returns ``module``."""
    # A DistributedDataParallel wrapper prefixes its keys with 'module.', while
    # checkpoint keys are stripped of it below; load into the wrapped module so
    # the keys actually match.
    target = module.module if isinstance(module, DistributedDataParallel) else module
    pretrained_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    for k, v in target.state_dict().items():
        if k not in pretrained_dict:
            logger.info(f'key "{k}" could not be found in loaded state dict')
        elif pretrained_dict[k].shape != v.shape:
            logger.info(f'key "{k}" is of different shape in model and loaded state dict')
            pretrained_dict[k] = v
    msg = target.load_state_dict(pretrained_dict, strict=False)
    logger.info(f'loaded pretrained model with msg: {msg}')
    return module


def load_from_path(
    r_path,
    encoder,
    linear_classifier,
    opt,
    sched,
    device_str
):
    """Restore a linear-eval run from checkpoint ``r_path``.

    Loads encoder and classifier weights (via ``load_pretrained``), the
    optimizer state and, when the checkpoint contains one and ``sched``
    supports it, the scheduler state.

    :returns: ``(encoder, linear_classifier, opt, sched, epoch, best_acc)`` —
        the six values ``main`` unpacks (previously only five were returned,
        so the unpacking raised ``ValueError``). ``best_acc`` is None when the
        checkpoint does not record ``'best_top1_acc'``.
    """
    encoder, linear_classifier = load_pretrained(r_path, encoder, linear_classifier, device_str)
    checkpoint = torch.load(r_path, map_location=device_str)

    best_acc = checkpoint.get('best_top1_acc')

    epoch = checkpoint['epoch']
    if opt is not None:
        opt.load_state_dict(checkpoint['opt'])
        logger.info(f'loaded optimizers from epoch {epoch}')
        # main() does not store a 'sched' entry, so only restore the scheduler
        # when the checkpoint has one and the scheduler can load it.
        if sched is not None and 'sched' in checkpoint and hasattr(sched, 'load_state_dict'):
            sched.load_state_dict(checkpoint['sched'])
        else:
            logger.info('no scheduler state restored from checkpoint')
    logger.info(f'read-path: {r_path}')
    del checkpoint
    return encoder, linear_classifier, opt, sched, epoch, best_acc


def init_model(
    device,
    device_str,
    num_classes,
    num_blocks,
    training,
    r_enc_path,
    iterations_per_epoch,
    world_size,
    ref_lr,
    num_epochs,
    normalize,
    model_name='deit_small',
    warmup_epochs=0,
    weight_decay=0
):
    """Build the pretrained encoder, the linear probe, its SGD optimizer and LR schedule.

    The encoder's ``fc`` and ``norm`` are removed and weights are loaded from
    ``r_enc_path``. The probe width is the model's embedding width (chosen
    from the model name; 1280 if none of tiny/small/base/large match) times
    ``num_blocks``. Weight decay applies only to non-bias, multi-dimensional
    probe parameters. The probe is wrapped in DistributedDataParallel when
    ``world_size > 1``.

    :returns: ``(encoder, linear_classifier, optimizer, scheduler)``
    """
    # -- encoder
    encoder = deit.__dict__[model_name]()
    encoder.fc = None
    encoder.norm = None
    encoder.to(device)
    encoder, _ = load_pretrained(
        r_path=r_enc_path,
        encoder=encoder,
        linear_classifier=None,
        device_str=device_str)

    # -- probe input width: per-block embedding width, first name match wins
    width_by_size = (('tiny', 192), ('small', 384), ('base', 768), ('large', 1024))
    emb_dim = next((width for size, width in width_by_size if size in model_name), 1280)
    emb_dim *= num_blocks

    linear_classifier = LinearClassifier(emb_dim, num_classes, normalize).to(device)

    # -- optimizer: split probe parameters into decayed / non-decayed groups
    decayed, not_decayed = [], []
    for name, param in linear_classifier.named_parameters():
        skip_decay = ('bias' in name) or ('bn' in name) or (len(param.shape) == 1)
        (not_decayed if skip_decay else decayed).append(param)
    param_groups = [
        {'params': decayed},
        {'params': not_decayed, 'weight_decay': 0},
    ]
    optimizer = SGD(
        param_groups,
        nesterov=True,
        weight_decay=weight_decay,
        momentum=0.9,
        lr=ref_lr)
    scheduler = WarmupCosineSchedule(
        optimizer,
        warmup_steps=warmup_epochs * iterations_per_epoch,
        start_lr=ref_lr,
        ref_lr=ref_lr,
        T_max=num_epochs * iterations_per_epoch)

    if world_size > 1:
        linear_classifier = DistributedDataParallel(linear_classifier)

    return encoder, linear_classifier, optimizer, scheduler


if __name__ == "__main__":
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--fname", type=str, required=True)
    cli_parser.add_argument("--devices", type=str, default="cuda:0")
    cli_args = cli_parser.parse_args()

    # everything comes from the YAML config except the device, taken from the CLI
    with open(cli_args.fname, "r") as config_file:
        config = yaml.safe_load(config_file)
    config["meta"]["device"] = cli_args.devices

    main(config)

In [None]:
%%writefile "/content/msn/logistic_eval.py"
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the Creative Commons Attribution–NonCommercial 4.0 International License
#
# Modifications: Adapted for use in 'Self-supervised learning applied to unlabelled histopathology imaging data –
# a comparison between Masked Siamese Networks and Self-Distillation with No Labels', 2025.
#

import os
import argparse
import logging
import pprint

import numpy as np
import torch
import torchvision.transforms as transforms
import cyanure as cyan

from torch.utils.data import DataLoader
from torchvision.transforms import InterpolationMode

from src.data_manager import init_data
import src.deit as deit

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    argparse's ``type=bool`` turns every non-empty string (including 'False')
    into True, so flags like ``--normalize False`` need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--lambd', type=float, default=0.00025, help='regularization')
parser.add_argument('--penalty', type=str, default='l2', choices=['l2', 'elastic-net'])
parser.add_argument('--mask', type=float, default=0.0, help='MSN masking used only when extracting TRAIN embs')
parser.add_argument('--preload', action='store_true', help='reuse precomputed embeddings if present')
parser.add_argument('--fname', type=str, required=True, help='short run tag to append to files')
parser.add_argument('--model-name', type=str, required=True, help='backbone (e.g., deit_small)')
parser.add_argument('--pretrained', type=str, required=True, help='folder containing pretrained checkpoint + where to save embs')
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--normalize', type=_str2bool, default=True, help='row-center + L2 if cyanure.preprocess is unavailable')
parser.add_argument('--root-path', type=str, default='/datasets/')
parser.add_argument('--image-folder', type=str, default='pathmnist', help='must be "pathmnist" for our data_manager')
parser.add_argument('--subset-path', type=str, default=None, help='(unused)')

# ADAPTED FOR USE
# eval split + train label-budget
parser.add_argument('--eval-split', type=str, default='test', choices=['val', 'test'])
parser.add_argument('--train-subset-frac', type=float, default=None)
parser.add_argument('--train-subset-size', type=int, default=None)
parser.add_argument('--train-subset-per-class', type=int, default=None)
parser.add_argument('--train-subset-seed', type=int, default=42)

_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)
torch.backends.cudnn.benchmark = True
pp = pprint.PrettyPrinter(indent=4)


def main(args):
    """Fit a multinomial logistic-regression probe on frozen encoder features.

    Extracts (or, with ``--preload``, reuses cached) embeddings for the train
    split — optionally restricted by ``--train-subset-*`` — and for the
    ``--eval-split``, row-preprocesses them, fits the classifier (cyanure's
    legacy ``MultiClassifier`` when available, scikit-learn otherwise) and
    returns the evaluation accuracy as a float.
    """
    device = torch.device(args.device)
    if 'cuda' in args.device:
        torch.cuda.set_device(device)

    # ADAPTED FOR USE
    # file tags so preload is safe across different subsets/splits
    subset_tag = 'pathmnist-train'
    if args.train_subset_per_class is not None:
        subset_tag += f'-pc{args.train_subset_per_class}-seed{args.train_subset_seed}'
    elif args.train_subset_size is not None:
        subset_tag += f'-n{args.train_subset_size}-seed{args.train_subset_seed}'
    elif args.train_subset_frac is not None:
        subset_tag += f'-frac{args.train_subset_frac}-seed{args.train_subset_seed}'
    # train embeddings also depend on --mask: tag masked runs so --preload never
    # mixes masked and unmasked features (mask 0 keeps the previous file name)
    if args.mask > 0:
        subset_tag += f'-mask{args.mask}'

    # ADAPTED FOR USE
    train_embs_path = os.path.join(args.pretrained, f'train-features-{subset_tag}-{args.fname}.pth.tar')
    eval_embs_path = os.path.join(args.pretrained, f'{args.eval_split}-features-{args.fname}.pth.tar')
    logger.info(f"Train embs file: {train_embs_path}")
    logger.info(f"Eval  embs file: {eval_embs_path}")

    pretrained_ckpt = os.path.join(args.pretrained, args.fname)

    # ADAPTED FOR USE
    # data transforms
    transform = transforms.Compose([
        transforms.Resize(size=256, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(size=224),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # ADAPTED FOR USE
    # data loaders via data_manager
    # train: apply label budget via train_subset_*
    train_loader, _ = init_data(
        transform=transform,
        batch_size=16,
        pin_mem=True,
        num_workers=6,
        world_size=1,
        rank=0,
        root_path=args.root_path,
        image_folder=args.image_folder,
        training=True,
        copy_data=False,
        drop_last=False,
        split='train',
        train_subset_frac=args.train_subset_frac,
        train_subset_size=args.train_subset_size,
        train_subset_per_class=args.train_subset_per_class,
        train_subset_seed=args.train_subset_seed,
    )

    # eval
    eval_loader, _ = init_data(
        transform=transform,
        batch_size=16,
        pin_mem=True,
        num_workers=6,
        world_size=1,
        rank=0,
        root_path=args.root_path,
        image_folder=args.image_folder,
        training=False,
        copy_data=False,
        drop_last=False,
        split=args.eval_split,
    )

    # initialise the model
    encoder = init_model(device=device, pretrained=pretrained_ckpt, model_name=args.model_name)
    encoder.eval()

    # if train embeddings already computed, load file, otherwise, compute embeddings and save
    if args.preload and os.path.exists(train_embs_path):
        checkpoint = torch.load(train_embs_path, map_location='cpu')
        embs, labs = checkpoint['embs'], checkpoint['labs']
        logger.info(f'loaded TRAIN embs of shape {embs.shape}')
    else:
        embs, labs = make_embeddings(blocks=1, device=device, mask_frac=args.mask,
                                     data_loader=train_loader, encoder=encoder)
        os.makedirs(os.path.dirname(train_embs_path), exist_ok=True)
        torch.save({'embs': embs, 'labs': labs}, train_embs_path)
        logger.info(f'saved TRAIN embs of shape {embs.shape}')

    # astype() copies, so in-place preprocessing never touches the cached tensors
    embs_np = embs.numpy().astype(np.float32)
    labs_np = labs.numpy().astype(np.int64)
    _preprocess_rows(embs_np, args.normalize)

    classifier = _fit_classifier(embs_np, labs_np, penalty=args.penalty, lambd=args.lambd)
    train_score = float(classifier.score(embs_np, labs_np))
    logger.info(f'TRAIN score: {train_score:.4f}')

    # if test embeddings already computed, load file, otherwise, compute embeddings and save
    if args.preload and os.path.exists(eval_embs_path):
        checkpoint = torch.load(eval_embs_path, map_location='cpu')
        eval_embs, eval_labs = checkpoint['embs'], checkpoint['labs']
        logger.info(f'loaded {args.eval_split.upper()} embs of shape {eval_embs.shape}')
    else:
        eval_embs, eval_labs = make_embeddings(blocks=1, device=device, mask_frac=0.0,
                                               data_loader=eval_loader, encoder=encoder)
        torch.save({'embs': eval_embs, 'labs': eval_labs}, eval_embs_path)
        logger.info(f'saved {args.eval_split.upper()} embs of shape {eval_embs.shape}')

    eval_embs_np = eval_embs.numpy().astype(np.float32)
    eval_labs_np = eval_labs.numpy().astype(np.int64)
    _preprocess_rows(eval_embs_np, args.normalize)

    eval_score = float(classifier.score(eval_embs_np, eval_labs_np))
    logger.info(f'{args.eval_split.upper()} score: {eval_score:.4f}\n')
    return eval_score


def _preprocess_rows(x, normalize):
    """Centre each row of the float array ``x`` in place and, if ``normalize``, L2-normalise it.

    Uses ``cyanure.preprocess`` (rows, centering on) when the installed
    cyanure exposes it; otherwise a NumPy equivalent. Centering is applied
    regardless of ``normalize``, matching the cyanure call.
    """
    if hasattr(cyan, 'preprocess'):
        cyan.preprocess(x, normalize=normalize, columns=False, centering=True)
        return
    x -= x.mean(axis=1, keepdims=True)
    if normalize:
        x /= np.linalg.norm(x, axis=1, keepdims=True) + 1e-12


def _fit_classifier(embs_np, labs_np, penalty, lambd):
    """Fit a multinomial logistic-regression classifier; returns an object with ``score(X, y)``.

    With the legacy cyanure API, the objective is
    ``(1/n) sum(loss) + psi(w)`` with regularisation strength ``lambd / n``.
    The scikit-learn fallback (``C * sum(loss) + penalty``) is set up to
    minimise the same objective up to a constant factor, assuming cyanure's
    documented penalties (``l2``: ``lam/2 ||w||^2``; ``elastic-net``:
    ``lam ||w||_1 + lam2/2 ||w||^2``). This gives ``C = 1 / lambd`` for l2,
    and ``C = 1 / (2 * lambd)`` with ``l1_ratio = 0.5`` for elastic-net. The
    previous fallback used ``C = n / lambd``, i.e. ``n`` times too little
    regularisation.
    """
    n = len(embs_np)
    if hasattr(cyan, 'MultiClassifier'):
        classifier = cyan.MultiClassifier(loss='multiclass-logistic', penalty=penalty, fit_intercept=False)
        classifier.fit(
            embs_np, labs_np,
            it0=10,
            lambd=lambd / n,
            lambd2=lambd / n,
            nthreads=-1,
            tol=1e-3,
            solver='auto',
            seed=0,
            max_epochs=100
        )
        return classifier

    # newer cyanure releases dropped MultiClassifier: fall back to scikit-learn
    from sklearn.linear_model import LogisticRegression
    if penalty == 'elastic-net':
        clf = LogisticRegression(
            penalty='elasticnet', l1_ratio=0.5, solver='saga',
            multi_class='multinomial', fit_intercept=False,
            max_iter=2000, tol=1e-3, C=max(1e-6, 1. / (2. * lambd)), n_jobs=-1
        )
    else:
        clf = LogisticRegression(
            penalty='l2', solver='saga',
            multi_class='multinomial', fit_intercept=False,
            max_iter=2000, tol=1e-3, C=max(1e-6, 1. / lambd), n_jobs=-1
        )
    clf.fit(embs_np, labs_np)
    return clf


def make_embeddings(blocks, device, mask_frac, data_loader, encoder, epochs=1):
    """Run ``encoder.forward_blocks`` over ``data_loader`` and collect results on CPU.

    :param blocks: block count forwarded to ``forward_blocks``
    :param device: device the images are moved to
    :param mask_frac: masking fraction forwarded to ``forward_blocks``
    :param data_loader: iterable of ``(images, labels)`` batches
    :param encoder: model exposing ``forward_blocks``
    :param epochs: number of passes over the loader (outputs are concatenated)
    :returns: ``(embeddings, labels)`` concatenated along dim 0
    """
    num_batches = len(data_loader)
    feature_chunks = []
    label_chunks = []
    for _ in range(epochs):
        for step, (images, targets) in enumerate(data_loader):
            with torch.no_grad():
                feats = encoder.forward_blocks(images.to(device), blocks, mask_frac)
            feature_chunks.append(feats.cpu())
            label_chunks.append(targets.cpu())
            if step % 50 == 0:
                logger.info(f'[{step}/{num_batches}]')
    embeddings = torch.cat(feature_chunks, 0)
    labels = torch.cat(label_chunks, 0)
    logger.info(embeddings.shape)
    logger.info(labels.shape)
    return embeddings, labels


def load_pretrained(encoder, pretrained):
    """Load the ``'target_encoder'`` weights of checkpoint ``pretrained`` into ``encoder``.

    Checkpoint keys have ``'module.'`` removed; model keys missing from the
    checkpoint or with a different shape are logged and keep their current
    values (non-strict load). The checkpoint's ``'epoch'`` is only logged and
    may be absent.

    :returns: ``encoder``
    """
    checkpoint = torch.load(pretrained, map_location='cpu')
    pretrained_dict = {k.replace('module.', ''): v for k, v in checkpoint['target_encoder'].items()}
    for k, v in encoder.state_dict().items():
        if k not in pretrained_dict:
            logger.info(f'key "{k}" could not be found in loaded state dict')
        elif pretrained_dict[k].shape != v.shape:
            logger.info(f'key "{k}" is of different shape in model and loaded state dict')
            pretrained_dict[k] = v
    msg = encoder.load_state_dict(pretrained_dict, strict=False)
    logger.info(f'loaded pretrained model with msg: {msg}')
    # explicit check instead of a blanket `except Exception: pass`
    if 'epoch' in checkpoint:
        logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]} path: {pretrained}')
    else:
        logger.info(f'loaded pretrained encoder (no epoch recorded) path: {pretrained}')
    del checkpoint
    return encoder


def init_model(device, pretrained, model_name):
    """Build ``deit.<model_name>``, drop its ``fc`` head, move it to ``device`` and load ``pretrained`` weights."""
    model = deit.__dict__[model_name]()
    model.fc = None
    model.to(device)
    return load_pretrained(encoder=model, pretrained=pretrained)


if __name__ == '__main__':
    cli_args = parser.parse_args()
    # echo the effective configuration before running
    pp.pprint(vars(cli_args))
    main(cli_args)

In [None]:
%%writefile "/content/msn/main.py"
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the Creative Commons Attribution–NonCommercial 4.0 International License
#

import argparse

import torch.multiprocessing as mp

import pprint
import yaml

from src.msn_train import main as msn

from src.utils import init_distributed

# command-line interface: config file plus the local devices to spawn on
parser = argparse.ArgumentParser()
parser.add_argument('--fname', type=str, default='configs.yaml',
                    help='name of config file to load')
parser.add_argument('--devices', type=str, nargs='+', default=['cuda:0'],
                    help='which devices to use on local machine')


def process_main(rank, fname, world_size, devices):
    """Entry point of one spawned pretraining process.

    Restricts the process to ``devices[rank]``, configures logging (INFO on
    rank 0, ERROR elsewhere), loads the YAML config ``fname``, saves a copy of
    it in the logging folder (rank 0 only, creating the folder if needed),
    initialises distributed training and runs MSN pretraining.

    :returns: whatever ``src.msn_train.main`` returns
    """
    import os
    # expose only this rank's GPU, e.g. 'cuda:1' -> CUDA_VISIBLE_DEVICES='1'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1])

    import logging
    logging.basicConfig()
    logger = logging.getLogger()
    if rank == 0:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)

    logger.info(f'called-params {fname}')

    # -- load script params
    params = None
    with open(fname, 'r') as y_file:
        params = yaml.load(y_file, Loader=yaml.FullLoader)
        logger.info('loaded params...')
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(params)

    # -- keep a copy of the config next to the run's logs. Only rank 0 writes it
    #    (all ranks would otherwise race on the same file), and the folder is
    #    created first so a fresh run does not fail with FileNotFoundError.
    if rank == 0:
        log_folder = params['logging']['folder']
        os.makedirs(log_folder, exist_ok=True)
        dump = os.path.join(log_folder, 'params-msn-train.yaml')
        with open(dump, 'w') as f:
            yaml.dump(params, f)

    world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))
    logger.info(f'Running... (rank: {rank}/{world_size})')

    return msn(params)


if __name__ == '__main__':
    cli_args = parser.parse_args()

    # one worker process per requested device
    n_procs = len(cli_args.devices)
    mp.spawn(
        process_main,
        args=(cli_args.fname, n_procs, cli_args.devices),
        nprocs=n_procs)