# I-JEPA

## Imports

In [None]:
# core train
import os
import copy
import sys
import numpy as np
import torch
import torch.nn.functional as F
import yaml

## Config

In [None]:
# Basic global config
_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)
torch.backends.cudnn.benchmark = True

In [None]:
# Experiment configuration, grouped by subsystem. The whole mapping is dumped
# to YAML further down, so it only contains plain Python values.
# NOTE(review): "warmup" (20) exceeds "epochs" (2); confirm this is intended.
args = dict(
    # Input pipeline: where the data lives and how it is batched/cropped.
    data=dict(
        batch_size=64,
        crop_scale=[0.3, 1.0],
        crop_size=224,
        image_folders=[
            "ssl-s2l1c/data/ssl4eo-s12/train/S2L1C",
            "ssl-s2l2a/data/ssl4eo-s12/train/S2L2A",
        ],
        validation_folders=[
            "ssl-s2l1c-val/data/ssl4eo-s12/val/S2L1C",
            "ssl-s2l2a-val/data/ssl4eo-s12/val/S2L2A",
        ],
        num_workers=2,
        pin_mem=True,
        root_path="/kaggle/input",
        use_horizontal_flip=False,
    ),
    # Output location and file-name tag for logs and checkpoints.
    logging=dict(
        folder="/kaggle/working/logs",
        write_tag="jepa",
    ),
    # Context/target block sampling parameters.
    mask=dict(
        allow_overlap=False,
        aspect_ratio=[0.75, 1.5],
        enc_mask_scale=[0.85, 1.0],
        min_keep=10,
        num_enc_masks=1,
        num_pred_masks=4,
        patch_size=14,
        pred_mask_scale=[0.15, 0.2],
    ),
    # Model selection and checkpoint handling.
    meta=dict(
        copy_data=False,
        load_checkpoint=False,
        model_name="vit_huge",
        pred_depth=12,
        pred_emb_dim=384,
        read_checkpoint=None,
        use_bfloat16=True,
    ),
    # Optimizer and schedule hyper-parameters.
    optimization=dict(
        ema=[0.996, 1.0],
        epochs=2,
        final_lr=1.0e-5,
        final_weight_decay=0.4,
        ipe_scale=1.0,
        lr=0.001,
        start_lr=0.0002,
        warmup=20,
        weight_decay=0.04,
    ),
)


In [None]:
# Run-level switches that are not stored in the configuration mapping.
resume_preempt = False
rank = 0

# Short aliases for the configuration sections read below.
_meta_cfg = args['meta']
_data_cfg = args['data']
_mask_cfg = args['mask']
_opt_cfg = args['optimization']
_log_cfg = args['logging']

# -- META
use_bfloat16 = _meta_cfg['use_bfloat16']
model_name = _meta_cfg['model_name']
load_model = _meta_cfg['load_checkpoint'] or resume_preempt
r_file = _meta_cfg['read_checkpoint']
copy_data = _meta_cfg['copy_data']
pred_depth = _meta_cfg['pred_depth']
pred_emb_dim = _meta_cfg['pred_emb_dim']
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)
else:
    device = torch.device('cpu')

# -- DATA
use_horizontal_flip = _data_cfg['use_horizontal_flip']
batch_size = _data_cfg['batch_size']
pin_mem = _data_cfg['pin_mem']
num_workers = _data_cfg['num_workers']
root_path = _data_cfg['root_path']
image_folders = _data_cfg['image_folders']
validation_folders = _data_cfg['validation_folders']
crop_size = _data_cfg['crop_size']
crop_scale = _data_cfg['crop_scale']

# -- MASK
allow_overlap = _mask_cfg['allow_overlap']      # whether context and target blocks may overlap
patch_size = _mask_cfg['patch_size']            # patch size used for model training
num_enc_masks = _mask_cfg['num_enc_masks']      # number of context blocks
min_keep = _mask_cfg['min_keep']                # minimum number of patches in a context block
enc_mask_scale = _mask_cfg['enc_mask_scale']    # scale range of context blocks
num_pred_masks = _mask_cfg['num_pred_masks']    # number of target blocks
pred_mask_scale = _mask_cfg['pred_mask_scale']  # scale range of target blocks
aspect_ratio = _mask_cfg['aspect_ratio']        # aspect-ratio range of target blocks

# -- OPTIMIZATION
ema = _opt_cfg['ema']
ipe_scale = _opt_cfg['ipe_scale']  # scheduler scale factor (def: 1.0)
wd = float(_opt_cfg['weight_decay'])
final_wd = float(_opt_cfg['final_weight_decay'])
num_epochs = _opt_cfg['epochs']
warmup = _opt_cfg['warmup']
start_lr = _opt_cfg['start_lr']
lr = _opt_cfg['lr']
final_lr = _opt_cfg['final_lr']

# -- LOGGING
folder = _log_cfg['folder']
tag = _log_cfg['write_tag']

# Persist the exact configuration next to the logs for later reference.
os.makedirs(folder, exist_ok=True)
dump = os.path.join(folder, 'params-ijepa.yaml')
with open(dump, 'w') as f:
    yaml.dump(args, f)

## Logging

In [None]:
import logging
# Logging / checkpoint cadence knobs. Their consumers are outside this cell
# (presumably the training loop; TODO confirm whether the frequencies are in
# iterations or epochs).
log_timings = True
log_freq = 10
checkpoint_freq = 1
# Send INFO-and-above records to stdout so they show up in the cell output.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Root logger, used by later cells.
logger = logging.getLogger()

In [None]:
def gpu_timer(closure, log_timings=True):
    """Run ``closure()`` and measure its GPU execution time with CUDA events.

    Timing is only attempted when ``log_timings`` is truthy and CUDA is
    available; otherwise the closure is simply executed.

    Returns:
        ``(result, elapsed)``: ``result`` is whatever ``closure`` returned and
        ``elapsed`` is the time between the two CUDA events as reported by
        ``Event.elapsed_time``, or ``-1.`` when timing was skipped.
    """
    timing_enabled = log_timings and torch.cuda.is_available()
    if not timing_enabled:
        return closure(), -1.

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    tic.record()
    result = closure()
    toc.record()

    # Both events must have completed before the elapsed time can be read.
    torch.cuda.synchronize()
    return result, tic.elapsed_time(toc)

In [None]:
class CSVLogger(object):
    """Minimal append-only CSV writer with printf-style column formats.

    Args:
        fname: Path of the CSV file. It is created if missing and appended to
            otherwise.
        *argv: One ``(format, header)`` pair per column, e.g.
            ``('%.5f', 'train_loss')``.
    """

    def __init__(self, fname, *argv):
        self.fname = fname
        self.types = [col[0] for col in argv]
        header = ','.join(str(col[1]) for col in argv)
        # Only a new or empty file gets the header row; re-creating the logger
        # against an existing log (re-running the cell, resuming a run) must
        # not inject a second header in the middle of the data.
        needs_header = (not os.path.isfile(self.fname)
                        or os.path.getsize(self.fname) == 0)
        with open(self.fname, '+a') as f:
            if argv and needs_header:
                print(header, file=f)

    def log(self, *argv):
        """Append one row, formatting each value with its column's format.

        Values beyond the declared columns are ignored and missing trailing
        values produce a shorter row; in both cases the row is terminated by a
        newline so that later rows are not glued onto it.
        """
        fields = [fmt % value for fmt, value in zip(self.types, argv)]
        with open(self.fname, '+a') as f:
            if fields:
                print(','.join(fields), file=f)


class AverageMeter(object):
    """Tracks the latest value, weighted running mean and extremes of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0
        self.max, self.min = float('-inf'), float('inf')

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the statistics."""
        self.val = val
        # Extremes are best-effort: values that cannot be compared are ignored.
        try:
            self.max = max(val, self.max)
            self.min = min(val, self.min)
        except Exception:
            pass
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
# -- log/checkpointing paths
# Per-rank CSV log file.
log_file = os.path.join(folder, f'{tag}_r{rank}.csv')
# NOTE: '{epoch}' is deliberately left as a literal placeholder here (it sits
# in a plain string, not the f-string); presumably filled in later with
# str.format when a checkpoint is written — TODO confirm in the training loop.
save_path = os.path.join(folder, f'{tag}' + '-ep{epoch}.pth.tar')
latest_path = os.path.join(folder, f'{tag}-latest.pth.tar')
# Checkpoint to resume from: the explicitly named file if given, otherwise
# the "latest" checkpoint. Stays None when no checkpoint is to be loaded.
load_path = None
if load_model:
    load_path = os.path.join(folder, r_file) if r_file is not None else latest_path

# csv logger
# One (printf-format, column-name) pair per CSV column, in output order.
csv_logger = CSVLogger(log_file,
                       ('%d', 'epoch'),
                       ('%d', 'itr'),
                       ('%.5f', 'train_loss'),
                       ('%.5f', 'val_loss'),  # Added validation loss
                       ('%.5f', 'mask-A'),
                       ('%.5f', 'mask-B'),
                       ('%d', 'time (ms)'))

## Dataset creation and preprocessing

### Data transformation

#### Masking

In [None]:
import math
from multiprocessing import Value

In [None]:
# class MaskCollator(object):

#     def __init__(
#         self,
#         input_size=(224, 224),
#         patch_size=16,
#         enc_mask_scale=(0.2, 0.8),
#         pred_mask_scale=(0.2, 0.8),
#         aspect_ratio=(0.3, 3.0),
#         nenc=1,
#         npred=2,
#         min_keep=4,
#         allow_overlap=False
#     ):
#         super(MaskCollator, self).__init__()
#         if not isinstance(input_size, tuple):
#             input_size = (input_size, ) * 2
#         self.patch_size = patch_size
#         self.height, self.width = input_size[0] // patch_size, input_size[1] // patch_size
#         self.enc_mask_scale = enc_mask_scale
#         self.pred_mask_scale = pred_mask_scale
#         self.aspect_ratio = aspect_ratio
#         self.nenc = nenc
#         self.npred = npred
#         self.min_keep = min_keep  # minimum number of patches to keep
#         self.allow_overlap = allow_overlap  # whether to allow overlap b/w enc and pred masks
#         self._itr_counter = Value('i', -1)  # collator is shared across worker processes

#     def step(self):
#         i = self._itr_counter
#         with i.get_lock():
#             i.value += 1
#             v = i.value
#         return v

#     def _sample_block_size(self, generator, scale, aspect_ratio_scale):
#         _rand = torch.rand(1, generator=generator).item()
#         # -- Sample block scale
#         min_s, max_s = scale
#         mask_scale = min_s + _rand * (max_s - min_s)
#         max_keep = int(self.height * self.width * mask_scale)
#         # -- Sample block aspect-ratio
#         min_ar, max_ar = aspect_ratio_scale
#         aspect_ratio = min_ar + _rand * (max_ar - min_ar)
#         # -- Compute block height and width (given scale and aspect-ratio)
#         h = int(round(math.sqrt(max_keep * aspect_ratio)))
#         w = int(round(math.sqrt(max_keep / aspect_ratio)))
#         while h >= self.height:
#             h -= 1
#         while w >= self.width:
#             w -= 1

#         return (h, w)

#     def _sample_block_mask(self, b_size, acceptable_regions=None):
#         h, w = b_size

#         def constrain_mask(mask, tries=0):
#             """ Helper to restrict given mask to a set of acceptable regions """
#             N = max(int(len(acceptable_regions)-tries), 0)
#             for k in range(N):
#                 mask *= acceptable_regions[k]
#         # --
#         # -- Loop to sample masks until we find a valid one
#         tries = 0
#         timeout = og_timeout = 20
#         valid_mask = False
#         while not valid_mask:
#             # -- Sample block top-left corner
#             top = torch.randint(0, self.height - h, (1,))
#             left = torch.randint(0, self.width - w, (1,))
#             mask = torch.zeros((self.height, self.width), dtype=torch.int32)
#             mask[top:top+h, left:left+w] = 1
#             # -- Constrain mask to a set of acceptable regions
#             if acceptable_regions is not None:
#                 constrain_mask(mask, tries)
#             mask = torch.nonzero(mask.flatten())
#             # -- If mask too small try again
#             valid_mask = len(mask) > self.min_keep
#             if not valid_mask:
#                 timeout -= 1
#                 if timeout == 0:
#                     tries += 1
#                     timeout = og_timeout
#                     logger.warning(f'Mask generator says: "Valid mask not found, decreasing acceptable-regions [{tries}]"')
#         mask = mask.squeeze()
#         # --
#         mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
#         mask_complement[top:top+h, left:left+w] = 0
#         # --
#         return mask, mask_complement

#     def __call__(self, batch):
#         '''
#         Create encoder and predictor masks when collating imgs into a batch
#         # 1. sample enc block (size + location) using seed
#         # 2. sample pred block (size) using seed
#         # 3. sample several enc block locations for each image (w/o seed)
#         # 4. sample several pred block locations for each image (w/o seed)
#         # 5. return enc mask and pred mask
#         '''
#         B = len(batch)

#         collated_batch = torch.utils.data.default_collate(batch)

#         seed = self.step()
#         g = torch.Generator()
#         g.manual_seed(seed)
#         p_size = self._sample_block_size(
#             generator=g,
#             scale=self.pred_mask_scale,
#             aspect_ratio_scale=self.aspect_ratio)
#         e_size = self._sample_block_size(
#             generator=g,
#             scale=self.enc_mask_scale,
#             aspect_ratio_scale=(1., 1.))

#         collated_masks_pred, collated_masks_enc = [], []
#         min_keep_pred = self.height * self.width
#         min_keep_enc = self.height * self.width
#         for _ in range(B):

#             masks_p, masks_C = [], []
#             for _ in range(self.npred):
#                 mask, mask_C = self._sample_block_mask(p_size)
#                 masks_p.append(mask)
#                 masks_C.append(mask_C)
#                 min_keep_pred = min(min_keep_pred, len(mask))
#             collated_masks_pred.append(masks_p)

#             acceptable_regions = masks_C
#             try:
#                 if self.allow_overlap:
#                     acceptable_regions= None
#             except Exception as e:
#                 logger.warning(f'Encountered exception in mask-generator {e}')

#             masks_e = []
#             for _ in range(self.nenc):
#                 mask, _ = self._sample_block_mask(e_size, acceptable_regions=acceptable_regions)
#                 masks_e.append(mask)
#                 min_keep_enc = min(min_keep_enc, len(mask))
#             collated_masks_enc.append(masks_e)

#         collated_masks_pred = [[cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred]
#         collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
#         # --
#         collated_masks_enc = [[cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc]
#         collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)

#         return collated_batch, collated_masks_enc, collated_masks_pred

In [None]:
class TimeSeriesMaskCollator:
    """Collate function that hides whole frames of a horizontally tiled sequence.

    The mask grid has ``frame_size`` rows and ``num_frames * frame_size``
    columns, i.e. the frames are laid out side by side. For every sample,
    ``npred`` distinct frames are chosen as prediction targets and the encoder
    (context) mask is everything that is not a target. All masks are returned
    as index tensors into the flattened grid.

    Args:
        num_frames: Number of frames laid out along the width.
        frame_size: Side length of one (square) frame, in grid cells.
        nenc: Number of encoder masks per sample.
        npred: Number of prediction masks (distinct frames) per sample.

    Raises:
        ValueError: If ``npred`` exceeds ``num_frames``, since that many
            distinct frames cannot be drawn.
    """

    def __init__(self, num_frames=300, frame_size=16, nenc=1, npred=1):
        if npred > num_frames:
            raise ValueError(
                f'npred ({npred}) cannot exceed num_frames ({num_frames})')
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.height = frame_size  # fixed spatial height of a frame
        self.width = num_frames * frame_size  # time flows horizontally
        self.nenc = nenc
        self.npred = npred
        # Shared counter used to derive one seed per collated batch.
        self._itr_counter = Value('i', -1)

    def step(self):
        """Atomically advance the shared iteration counter and return it."""
        i = self._itr_counter
        with i.get_lock():
            i.value += 1
            v = i.value
        return v

    def _sample_frame_mask(self, generator, exclude_frames=None):
        """Pick one frame that is not in ``exclude_frames`` and mask it.

        Returns:
            ``(mask, frame_idx)``: ``mask`` is a ``(height, width)`` int32
            tensor with ones over the chosen frame; ``frame_idx`` is the
            chosen frame index as a 0-dim long tensor.

        Raises:
            ValueError: If every frame is excluded.
        """
        # Plain-int set: O(1) membership instead of comparing tensors one by one.
        excluded = {int(f) for f in exclude_frames} if exclude_frames else set()
        choices = torch.tensor(
            [i for i in range(self.num_frames) if i not in excluded],
            dtype=torch.long
        )
        if len(choices) == 0:
            raise ValueError('no frame left to sample: all frames are excluded')
        # Sample one index using the generator
        idx = torch.randint(0, len(choices), (1,), generator=generator).item()
        frame_idx = choices[idx]

        # Top-left corner of the frame in the horizontally stacked layout
        top = 0
        left = int(frame_idx) * self.frame_size
        mask = torch.zeros((self.height, self.width), dtype=torch.int32)
        mask[top:top+self.frame_size, left:left+self.frame_size] = 1
        return mask, frame_idx

    def build_encoder_mask_from_pred(self, pred_masks):
        """Return the flat indices of all grid cells not covered by ``pred_masks``."""
        enc_mask = torch.ones((self.height, self.width), dtype=torch.int32)
        for pred_mask in pred_masks:
            enc_mask.view(-1)[pred_mask] = 0  # Zero out the masked regions
        # squeeze(-1) keeps the result 1-D even when a single cell remains.
        return torch.nonzero(enc_mask.flatten()).squeeze(-1)

    def __call__(self, batch):
        """Collate ``batch`` and sample encoder/prediction masks for each sample.

        Returns:
            ``(collated_batch, collated_masks_enc, collated_masks_pred)``; the
            two mask collections are the ``default_collate`` of per-sample
            lists of 1-D index tensors.
        """
        B = len(batch)
        collated_batch = torch.utils.data.default_collate(batch)

        # One generator per batch, seeded from the shared counter.
        seed = self.step()
        g = torch.Generator()
        g.manual_seed(seed)

        collated_masks_enc, collated_masks_pred = [], []

        for _ in range(B):
            pred_masks = []
            pred_frame_idxs = []
            for _ in range(self.npred):
                mask, idx = self._sample_frame_mask(generator=g, exclude_frames=pred_frame_idxs)
                # squeeze(-1): stay 1-D even for a single-cell frame so that
                # every sample collates to the same rank.
                pred_masks.append(torch.nonzero(mask.flatten()).squeeze(-1))
                pred_frame_idxs.append(int(idx))

            collated_masks_pred.append(pred_masks)

            enc_masks = [self.build_encoder_mask_from_pred(pred_masks)
                         for _ in range(self.nenc)]
            collated_masks_enc.append(enc_masks)

        collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
        collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)

        return collated_batch, collated_masks_enc, collated_masks_pred


In [None]:
def apply_masks(x, masks):
    """
    Gather the kept patches of ``x`` for every mask and stack the results.

    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors containing indices of patches in [N] to keep
    :returns: the per-mask selections concatenated along the batch dimension
    """
    feat_dim = x.size(-1)
    selected = [
        # Repeat each kept index across the feature dimension for gather.
        torch.gather(x, dim=1, index=keep.unsqueeze(-1).repeat(1, 1, feat_dim))
        for keep in masks
    ]
    return torch.cat(selected, dim=0)

In [None]:
# mask_collator = MaskCollator(
#     input_size=crop_size,
#     patch_size=patch_size,
#     pred_mask_scale=pred_mask_scale,
#     enc_mask_scale=enc_mask_scale,
#     aspect_ratio=aspect_ratio,
#     nenc=num_enc_masks,
#     npred=num_pred_masks,
#     allow_overlap=allow_overlap,
#     min_keep=min_keep)

In [None]:
# NOTE(review): with its defaults the collator emits indices into a flattened
# 16 x 4800 grid (300 frames of 16 x 16 cells), whereas crop_size=224 with
# patch_size=14 in the config corresponds to a 16 x 16 = 256 patch grid.
# Confirm that the model consuming these masks expects the larger index range.
mask_collator = TimeSeriesMaskCollator() # defaults to 300 frames of size 16x16

#### Image transforms

In [None]:
import torchvision.transforms as transforms

In [None]:
import numpy as np
import random
from skimage.transform import resize
import torch # Import PyTorch

class Compose:
    """
    Chains callables so that each one receives the previous one's output.

    Args:
        transforms (list of callables): steps to apply, in order.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img_array):
        """
        Run the input through every step in sequence.

        Args:
            img_array: Value handed to the first step (in this notebook a
                NumPy image array of shape (C, H, W)).
        Returns:
            Whatever the last step returns; the input itself when there are
            no steps.
        """
        result = img_array
        for step in self.transforms:
            result = step(result)
        return result

def random_resized_crop_np(img_array, size, scale=(0.08, 1.0), interpolation_order=1):
    """
    Crop the given NumPy array to a random size and aspect ratio, then resize.
    The same crop window and resize are applied to every channel.

    Up to 10 candidate windows are drawn; the first one that fits inside the
    image is used. If none fits, the last candidate is clipped to the image
    and centred.

    Args:
        img_array (numpy.ndarray): Input image array (C, H, W).
        size (int or tuple): Desired output size of the crop (H_out, W_out).
            If int, it's (size, size).
        scale (tuple): Range of the crop area as a fraction of the image area.
        interpolation_order (int): Order of interpolation for resizing (0-5).
                                   0: Nearest-neighbor, 1: Bi-linear (default), 3: Bi-cubic.
    Returns:
        numpy.ndarray: Cropped and resized image array (C, H_out, W_out).
    """
    C, H, W = img_array.shape
    target_h, target_w = (size, size) if isinstance(size, int) else size

    area = H * W
    log_ratio = (np.log(3 / 4), np.log(4 / 3))  # Default aspect ratio range

    for _ in range(10):
        # Draw a fresh area and aspect ratio on every attempt; otherwise all
        # retries would test the very same (possibly invalid) window.
        target_area = random.uniform(scale[0], scale[1]) * area
        aspect_ratio = np.exp(random.uniform(log_ratio[0], log_ratio[1]))

        w = int(round(np.sqrt(target_area * aspect_ratio)))
        h = int(round(np.sqrt(target_area / aspect_ratio)))

        if 0 < w <= W and 0 < h <= H:
            i = random.randint(0, H - h)
            j = random.randint(0, W - w)
            break
    else:
        # Fallback: clip the last candidate to the image (never below one
        # pixel) and centre it.
        h = min(max(h, 1), H)
        w = min(max(w, 1), W)
        i = (H - h) // 2
        j = (W - w) // 2

    # NOTE(review): skimage's resize rescales integer inputs unless
    # preserve_range=True is passed; the pipeline in this notebook casts to
    # float32 before this step. Confirm before calling it on integer arrays.
    resized_channels = [
        resize(img_array[c, i:i + h, j:j + w], (target_h, target_w),
               order=interpolation_order,
               mode='reflect',
               anti_aliasing=True)
        for c in range(C)
    ]
    return np.stack(resized_channels, axis=0)

def random_horizontal_flip_np(img_array, p=0.5):
    """
    Mirror a (C, H, W) array left-to-right with probability ``p``.

    Args:
        img_array (numpy.ndarray): Input image array (C, H, W).
        p (float): Probability of flipping. Default value is 0.5.
    Returns:
        numpy.ndarray: A flipped copy, or the input object itself when no
        flip is performed.
    """
    flip_now = random.random() < p
    if not flip_now:
        return img_array
    # Axis 2 is the width axis; copy() materialises the reversed view.
    return np.flip(img_array, axis=2).copy()

def scale_to_01_np(img_array, max_int_value=32767.0):
    """
    Divide every element of the array by ``max_int_value``.

    With non-negative inputs no larger than ``max_int_value`` the result lies
    in the 0-1 range; values are not clipped.

    Args:
        img_array (numpy.ndarray): Input image array.
        max_int_value (float): Divisor, i.e. the largest value expected in the
                               input. The default is the int16 maximum.
    Returns:
        numpy.ndarray: Scaled image array.
    """
    scaled = img_array / max_int_value
    return scaled


def ensure_13_channels(img_array):
    """
    Return the image as float32 with exactly 13 channels.

    A 13-channel input is only cast to float32. A 12-channel input gets the
    pixel-wise mean of its 12 channels appended as the 13th channel.

    Args:
        img_array (numpy.ndarray): Input image array (C, H, W)

    Returns:
        numpy.ndarray: float32 array with exactly 13 channels

    Raises:
        ValueError: If the input has neither 12 nor 13 channels.
    """
    num_channels, _, _ = img_array.shape
    as_float = img_array.astype(np.float32)

    if num_channels == 13:
        return as_float

    if num_channels != 12:
        raise ValueError(f"Expected 12 or 13 channels, but got {num_channels}")

    # Mean over the channel axis, kept as a (1, H, W) slab for concatenation.
    band_mean = np.mean(as_float, axis=0, keepdims=True)
    return np.concatenate([as_float, band_mean], axis=0)

# Preprocessing pipeline, applied to each raw (C, H, W) NumPy sample in order.
transform_list = [
    # 1. Cast to float32 and pad 12-channel inputs to 13 channels.
    ensure_13_channels,
    # 2. Random crop, resized to the training resolution (augmentation).
    lambda x: random_resized_crop_np(x, crop_size, scale=crop_scale),
]

# 3. Optional horizontal flip.
if use_horizontal_flip:
    transform_list.append(random_horizontal_flip_np)

transform_list.extend([
    # 4. Divide by the int16 maximum (still a NumPy array at this point).
    lambda x: scale_to_01_np(x, max_int_value=32767.0),
    # 5. Convert to a torch.Tensor; from_numpy keeps the array's dtype.
    torch.from_numpy,
])

# Composition of transforms
transform = Compose(transform_list)

### Dataset & Dataloader

In [None]:
!pip install zarr

In [None]:
import zarr

In [None]:
class SSLLEODataset(torch.utils.data.Dataset):
    """Dataset over SSL4EO ``.zarr`` files, each holding a fixed number of samples."""

    # Each zarr file contains 256 samples (64 * 4).
    SAMPLES_PER_FILE = 256

    def __init__(self, data_paths: list, transform=None, normalize=False, stats=None):
        """Initialize SSLLEO dataset from partial downloads.

        Args:
            data_paths (list): Paths to the SSLEO dataset parts.
            transform: Transformations to apply to each sample.
            normalize (bool): Whether to apply channel-wise normalization.
            stats (dict, optional): Dict containing 'mean' and 'std' for normalization.
                                   If None and normalize=True, zero mean and unit
                                   std for 13 channels are used.
        """
        self.data_paths = data_paths
        self.transform = transform
        self.normalize = normalize
        self.stats = stats

        # Default stats if none provided but normalization requested
        if self.normalize and self.stats is None:
            self.stats = {
                'mean': torch.tensor([0.0] * 13),  # Default zeros for 13 channels
                'std': torch.tensor([1.0] * 13)    # Default ones for 13 channels
            }

        # Collect every .zarr entry of every data path. The listing is sorted
        # because os.listdir returns entries in arbitrary order, which would
        # make the index -> sample mapping differ between runs and machines.
        self.sample_paths = []
        for path in data_paths:
            self.sample_paths.extend(
                os.path.join(path, name)
                for name in sorted(os.listdir(path))
                if name.endswith('.zarr')
            )

        self.data_len = len(self.sample_paths) * self.SAMPLES_PER_FILE

    def __len__(self):
        return self.data_len

    def __getitem__(self, index):
        """Load, transform and (optionally) normalize the sample at ``index``."""
        zarr_index, sample_index = divmod(index, self.SAMPLES_PER_FILE)
        zarr_data = zarr.open(self.sample_paths[zarr_index], mode='r')

        # 'bands' is expected to be (64, 4, 12/13, 264, 264). Flattening the
        # first two axes and taking element k equals indexing [k // 4, k % 4];
        # indexing the zarr array directly reads just that one sample instead
        # of loading the whole array for every item.
        # NOTE(review): relies on 'bands' being 5-D as described; confirm.
        bands = zarr_data['bands']
        outer, inner = divmod(sample_index, bands.shape[1])
        sample = np.asarray(bands[outer, inner])

        # Apply transformations if specified
        if self.transform is not None:
            sample = self.transform(sample)

        # Normalize after the other transforms. This only happens for tensor
        # samples (assumed shape (C, H, W)); anything else is returned as is.
        if self.normalize and isinstance(sample, torch.Tensor):
            sample = self.normalize_sample(sample)

        return sample

    def normalize_sample(self, sample):
        """Apply channel-wise normalization to the sample."""
        # Ensure the mean and std are on the same device as the sample
        mean = self.stats['mean'].to(sample.device)
        std = self.stats['std'].to(sample.device)

        # Expand dimensions to match the sample shape: (C) -> (C, 1, 1)
        mean = mean.view(-1, 1, 1)
        std = std.view(-1, 1, 1)

        # Small epsilon guards against division by a zero std
        return (sample - mean) / (std + 1e-8)

In [None]:
def calculate_dataset_stats(dataset, num_samples=None, batch_size=64, num_workers=4):
    """
    Calculate mean, std, and max of the dataset across channels.

    Batches are flattened to [B, C, -1] and statistics are accumulated over
    the batch and flattened dimensions, giving one value per channel.

    Args:
        dataset: A dataset instance without normalization.
        num_samples: Number of samples to use. If None, uses all samples.
        batch_size: Batch size for processing.
        num_workers: Number of DataLoader worker processes.

    Returns:
        dict: Dictionary containing 'mean', 'std', and 'max' tensors.

    Raises:
        ValueError: If no tensor batch could be processed.
    """
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    if num_samples is not None:
        # ceil(num_samples / batch_size), capped at the loader length
        total_batches = min(num_samples // batch_size + (1 if num_samples % batch_size else 0),
                            len(loader))
    else:
        total_batches = len(loader)

    channels_sum = 0
    channels_squared_sum = 0
    channels_max = None  # Initialized from the first processed batch
    num_batches = 0
    num_samples_seen = 0
    num_values = 0  # values accumulated per channel

    print(f"Calculating dataset statistics across {total_batches} batches...")

    for i, data in enumerate(loader):

        if not isinstance(data, torch.Tensor):
            print(f"Warning: Batch {i} returned non-tensor data, skipping")
            continue

        data = data.reshape(data.size(0), data.size(1), -1)  # [B, C, H*W]

        batch_max = torch.max(data, dim=2)[0].max(dim=0)[0]
        channels_max = batch_max if channels_max is None else torch.max(channels_max, batch_max)

        channels_sum += torch.sum(data, dim=[0, 2])
        channels_squared_sum += torch.sum(data**2, dim=[0, 2])

        num_samples_seen += data.size(0)
        num_values += data.size(0) * data.size(2)
        num_batches += 1

        if i % 10 == 0:
            print(f"Processed {i}/{total_batches} batches ({num_samples_seen} samples)")

        # Count processed batches rather than the loader index so that a
        # skipped batch cannot step over the stopping point.
        if num_batches == total_batches:
            break

    if num_values == 0:
        raise ValueError("No tensor batches were processed; cannot compute statistics")

    mean = channels_sum / num_values
    # E[x^2] - E[x]^2 can come out slightly negative through rounding; clamp
    # so that the square root never yields NaN.
    variance = channels_squared_sum / num_values - mean**2
    std = torch.sqrt(torch.clamp(variance, min=0))

    return {'mean': mean, 'std': std, 'max': channels_max}


In [None]:
# # Step 1: Create dataset without normalization to calculate statistics
# dataset_for_stats = SSLLEODataset(
#     data_paths=[os.path.join(root_path, path) for path in image_folders],
#     transform=transform,  
#     normalize=False
# )

# # Step 2: Calculate dataset statistics (can be time-consuming)
# # You might want to save these stats after calculation
# stats = calculate_dataset_stats(dataset_for_stats, num_samples=64)  # Using subset for speed
# print(f"Channel means: {stats['mean']}")
# print(f"Channel stds: {stats['std']}")

In [None]:
# Hard-coded per-channel normalization statistics (13 entries each, one per
# channel). Presumably produced by calculate_dataset_stats on the transformed
# (0-1 scaled) training data — TODO confirm provenance.
stats = {"mean": torch.tensor([0.0673, 0.0660, 0.0688, 0.0720, 0.0799, 0.0970, 0.1048, 0.1051, 0.1102, 0.0845, 0.0675, 0.0926, 0.0832]), 
         "std": torch.tensor([0.0327, 0.0324, 0.0324, 0.0388, 0.0387, 0.0388, 0.0412, 0.0417, 0.0425, 0.0474, 0.0484, 0.0416, 0.0363])}

In [None]:
# Step 3: Create your actual training dataset with normalization
# Every configured training folder is resolved relative to root_path.
dataset = SSLLEODataset(
    data_paths=[os.path.join(root_path, path) for path in image_folders],
    transform=transform,
    normalize=True,
    stats=stats
)

# Save the stats for future use
torch.save(stats, os.path.join(folder, 'dataset_stats.pth'))

In [None]:
# Smoke test: load one sample through the full transform/normalize path and
# display it as the cell output.
dataset[1]

In [None]:
# Training loader. The mask collator is the collate function, so every batch
# is a (images, encoder masks, prediction masks) triple. drop_last=True keeps
# all batches at exactly batch_size samples.
data_loader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=mask_collator,
    batch_size=batch_size,
    drop_last=True,
    pin_memory=pin_mem,
    num_workers=num_workers,
    persistent_workers=False)

# Iterations per epoch.
ipe = len(data_loader)

In [None]:
# Create validation dataset using a portion of the data or separate validation data
# NOTE(review): validation reuses the training `transform` (which includes the
# random crop) and the same `mask_collator` instance, hence the same iteration
# counter, as the training loader — confirm this sharing is intended.
val_dataset = SSLLEODataset(
    data_paths=[os.path.join(root_path, path) for path in validation_folders],  # Adjust path as needed
    transform=transform,
    normalize=True,
    stats=stats
)

# Create validation dataloader
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    collate_fn=mask_collator,
    batch_size=batch_size,
    drop_last=True,
    pin_memory=pin_mem,
    num_workers=num_workers // 2,  # Use fewer workers for validation
    persistent_workers=False
)

logger.info(f"Created validation dataloader with {len(val_loader)} batches")

## Train

In [None]:
import math
from functools import partial
import numpy as np
import torch
import torch.nn as nn

### Model backbone (vision transformer)

In [None]:
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Fill ``tensor`` in place with normal samples truncated to ``[a, b]``.

    Thin wrapper that supplies defaults (mean 0, std 1, bounds -2 and 2) to
    ``_no_grad_trunc_normal_`` and returns its result.
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled


def repeat_interleave_batch(x, B, repeat):
    """Repeat each consecutive chunk of ``B`` rows ``repeat`` times.

    ``x`` is split along dim 0 into ``len(x) // B`` chunks of ``B`` rows; the
    output is chunk 0 repeated ``repeat`` times, then chunk 1, and so on.
    """
    num_chunks = len(x) // B
    pieces = [
        x[k * B:(k + 1) * B]
        for k in range(num_chunks)
        for _ in range(repeat)
    ]
    return torch.cat(pieces, dim=0)

In [None]:
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a 2D sine/cosine positional embedding for a square grid.

    Args:
        embed_dim: embedding width of each position.
        grid_size: int, the grid height and width.
        cls_token: when True, a row of zeros is prepended to the table.

    Returns:
        NumPy array of shape ``[grid_size*grid_size, embed_dim]``, or
        ``[1+grid_size*grid_size, embed_dim]`` when ``cls_token`` is True.
    """
    coords = np.arange(grid_size, dtype=float)
    # Both axes share the same coordinates; np.meshgrid puts the
    # width-varying plane first.
    planes = np.meshgrid(coords, coords)
    grid = np.stack(planes, axis=0).reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    zero_row = np.zeros([1, embed_dim])
    return np.concatenate([zero_row, pos_embed], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-plane coordinate grid as sine/cosine embeddings.

    Half of the embedding width encodes ``grid[0]`` and the other half
    encodes ``grid[1]``; the two halves are concatenated feature-wise.

    Args:
        embed_dim: total embedding width; must be even.
        grid: array whose first axis has the two coordinate planes.

    Returns:
        NumPy array of shape ``(H*W, embed_dim)``.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # One (H*W, D/2) table per coordinate plane, first plane first.
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis])
        for axis in (0, 1)
    ]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a 1D sine/cosine positional embedding.

    Args:
        embed_dim: embedding width of each position.
        grid_size: int, the number of positions (0 .. grid_size-1).
        cls_token: when True, a row of zeros is prepended to the table.

    Returns:
        NumPy array of shape ``[grid_size, embed_dim]``, or
        ``[1+grid_size, embed_dim]`` when ``cls_token`` is True.
    """
    positions = np.arange(grid_size, dtype=float)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return pos_embed
    zero_row = np.zeros([1, embed_dim])
    return np.concatenate([zero_row, pos_embed], axis=0)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode positions with sine/cosine features at geometric frequencies.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: NumPy array of positions to encode; it is flattened to (M,).

    Returns:
        NumPy array of shape (M, embed_dim): the first half of the columns
        holds the sines, the second half the cosines.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Frequencies 1 / 10000^(i / (D/2)) for i = 0 .. D/2 - 1.
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    omega = 1. / 10000**exponents  # (D/2,)

    # Outer product: one row per position, one column per frequency.
    angles = np.outer(pos.reshape(-1), omega)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    When active, each sample along dim 0 is kept with probability
    ``1 - drop_prob``; kept samples are divided by the keep probability.

    Args:
        x: input tensor; dim 0 indexes samples.
        drop_prob: probability of dropping a sample.
        training: the op is a no-op unless this is True.

    Returns:
        ``x`` itself when inactive, otherwise a new tensor of the same shape.
    """
    if not training or drop_prob == 0.:
        return x

    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over every remaining dim
    # (works with tensors of any rank, not just 2D ConvNet features).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    keep_mask = torch.floor(
        keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device))
    return x.div(keep_prob) * keep_mask


class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    Intended for the main path of residual blocks. The drop probability is
    stored at construction; the module's ``training`` flag decides whether
    dropping is active.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Stored as given; forwarded unchanged to drop_path on every call.
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    The same dropout module is applied after the activation and again
    after the second linear layer.

    Args:
        in_features: input width.
        hidden_features: hidden width; falls back to ``in_features``.
        out_features: output width; falls back to ``in_features``.
        act_layer: zero-argument factory for the activation module.
        drop: dropout probability.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention weights.

    Args:
        dim: token embedding width (split evenly across heads).
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: logit scale; when falsy, ``head_dim ** -0.5`` is used.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return ``(output, attention)`` for input of shape (B, N, C)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # Fused projection, then split into q, k, v of shape
        # (B, heads, N, head_dim) each.
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(attn.softmax(dim=-1))

        # Merge heads back into the channel dimension.
        out = (attn @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(out))
        return out, attn


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        dim: token embedding width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: whether the q/k/v projection has a bias.
        qk_scale: optional override of the attention logit scale.
        drop: dropout probability for the attention projection and the MLP.
        attn_drop: dropout probability on the attention weights.
        drop_path: stochastic-depth probability; 0 disables it.
        act_layer: activation factory for the MLP.
        norm_layer: normalisation factory taking the width.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # Stochastic depth only when a positive rate is requested.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    def forward(self, x, return_attention=False):
        """Run the block; return only the attention weights if requested."""
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class PatchEmbed(nn.Module):
    """Image to patch embedding via a strided convolution.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each patch; also the conv stride.
        in_chans: number of input channels.
        embed_dim: output embedding width per patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        # Non-overlapping patches: kernel size equals stride.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map (B, C, H, W) images to (B, num_patches, embed_dim) tokens."""
        # Unpacking enforces a 4D input, exactly as before.
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)
        return feature_map.flatten(2).transpose(1, 2)


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models

    The stem is a stack of 3x3 conv (+ optional BatchNorm) + ReLU layers,
    followed by a final 1x1 convolution that produces the token width.

    Args:
        channels: output channels of each stem layer; the last entry is the
            output width of the final 1x1 convolution.
        strides: stride of each layer, one per entry in ``channels``.
        img_size: side length of the (square) input, given either as an int
            or as a list/tuple whose first element is the side length.
        in_chans: number of input channels.
        batch_norm: insert BatchNorm2d after each 3x3 conv (and drop the
            conv bias) when True.
    """

    def __init__(self, channels, strides, img_size=224, in_chans=13, batch_norm=True):
        super().__init__()
        # Build the stems
        widths = [in_chans] + list(channels)
        stem = []
        for i in range(len(widths) - 2):
            stem.append(nn.Conv2d(widths[i], widths[i+1], kernel_size=3,
                                  stride=strides[i], padding=1, bias=(not batch_norm)))
            if batch_norm:
                stem.append(nn.BatchNorm2d(widths[i+1]))
            stem.append(nn.ReLU(inplace=True))
        stem.append(nn.Conv2d(widths[-2], widths[-1], kernel_size=1, stride=strides[-1]))
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches. The default img_size is a plain int,
        # which the previous ``img_size[0]`` indexing could not handle, so
        # accept both an int and a list/tuple here.
        side = img_size[0] if isinstance(img_size, (list, tuple)) else img_size
        stride_prod = int(np.prod(strides))
        self.num_patches = (side // stride_prod)**2

    def forward(self, x):
        """Map (B, C, H, W) images to (B, num_tokens, channels[-1]) tokens."""
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)


class VisionTransformerPredictor(nn.Module):
    """Transformer predictor operating on encoder tokens plus mask tokens.

    Encoder outputs are projected to ``predictor_embed_dim``, concatenated
    with learnable mask tokens that carry the positional embedding of the
    positions selected by ``masks``, processed by ``depth`` transformer
    blocks, and the mask-token outputs are projected back to ``embed_dim``.

    Args:
        num_patches: number of patch positions (a square grid is assumed
            when building the fixed sin-cos positional table).
        embed_dim: width of the encoder tokens fed to / returned by forward.
        predictor_embed_dim: internal width of the predictor.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate,
        norm_layer: forwarded to every ``Block``.
        drop_path_rate: maximum stochastic-depth rate; rates increase
            linearly from 0 across the blocks.
        init_std: std of the truncated-normal weight initialisation.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        num_patches,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=6,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))

        # -- fixed (non-trainable) 2D sin-cos positional table
        self.predictor_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, predictor_embed_dim),
            requires_grad=False)
        sincos_table = get_2d_sincos_pos_embed(
            predictor_embed_dim,
            int(num_patches**.5),
            cls_token=False)
        self.predictor_pos_embed.data.copy_(
            torch.from_numpy(sincos_table).float().unsqueeze(0))

        # -- transformer blocks with a linearly increasing drop-path rate
        drop_path_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        self.predictor_blocks = nn.ModuleList([
            Block(
                dim=predictor_embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=rate,
                norm_layer=norm_layer)
            for rate in drop_path_rates])
        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)

        # -- weight initialisation
        self.init_std = init_std
        trunc_normal_(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Divide each block's output projections by sqrt(2 * layer index)."""
        for depth_idx, blk in enumerate(self.predictor_blocks, start=1):
            divisor = math.sqrt(2.0 * depth_idx)
            blk.attn.proj.weight.data.div_(divisor)
            blk.mlp.fc2.weight.data.div_(divisor)

    def _init_weights(self, m):
        """Initialise one sub-module (used with ``self.apply``)."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, masks_x, masks):
        """Predict representations at the positions selected by ``masks``.

        Args:
            x: encoder tokens for the context positions.
            masks_x: mask (or list of masks) that produced the context
                tokens in ``x``.
            masks: mask (or list of masks) selecting the positions to
                predict.

        Returns:
            Predictions for the mask tokens, projected to ``embed_dim``.
        """
        assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'

        if not isinstance(masks_x, list):
            masks_x = [masks_x]
        if not isinstance(masks, list):
            masks = [masks]

        # -- batch size: x stacks one copy of the batch per context mask
        B = len(x) // len(masks_x)

        # -- map from encoder-dim to predictor-dim
        x = self.predictor_embed(x)

        # -- add positional embedding of the context positions (in place)
        # apply_masks is defined elsewhere; presumably it gathers the rows
        # indexed by each mask -- confirm against its definition.
        x += apply_masks(self.predictor_pos_embed.repeat(B, 1, 1), masks_x)

        _, num_ctxt, _ = x.shape

        # -- mask tokens carrying the positional embedding of the targets
        target_pos = apply_masks(self.predictor_pos_embed.repeat(B, 1, 1), masks)
        target_pos = repeat_interleave_batch(target_pos, B, repeat=len(masks_x))
        pred_tokens = self.mask_token.repeat(target_pos.size(0), target_pos.size(1), 1)
        pred_tokens += target_pos

        # -- context tokens (tiled once per target mask) followed by mask tokens
        x = torch.cat([x.repeat(len(masks), 1, 1), pred_tokens], dim=1)

        # -- fwd prop
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- keep only the outputs at the mask-token positions
        return self.predictor_proj(x[:, num_ctxt:])


class VisionTransformer(nn.Module):
    """Vision Transformer encoder without a class token.

    Images are patchified, a fixed 2D sin-cos positional embedding is added,
    an optional set of masks selects which patch tokens are kept, and the
    remaining tokens go through ``depth`` transformer blocks and a final
    normalisation layer.

    Args:
        img_size: sequence whose first element is the (square) image side.
        patch_size: side length of each patch.
        in_chans: number of input channels.
        embed_dim: token embedding width.
        predictor_embed_dim: accepted for interface compatibility; unused.
        depth: number of transformer blocks.
        predictor_depth: accepted for interface compatibility; unused.
        num_heads: attention heads per block.
        mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate,
        norm_layer: forwarded to every ``Block``.
        drop_path_rate: maximum stochastic-depth rate; rates increase
            linearly from 0 across the blocks.
        init_std: std of the truncated-normal weight initialisation.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=13,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        # --
        self.patch_embed = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # -- fixed (non-trainable) 2D sin-cos positional table
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # --
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ------
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Divide each block's output projections by sqrt(2 * layer index)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Initialise one sub-module (used with ``self.apply``)."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, masks=None):
        """Encode a batch of images.

        Args:
            x: image batch accepted by ``self.patch_embed``.
            masks: optional mask, or list of masks, passed to
                ``apply_masks`` to select the patch tokens to keep.

        Returns:
            Normalised token representations.
        """
        if masks is not None:
            if not isinstance(masks, list):
                masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)

        # -- add positional embedding to x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        # -- mask x
        if masks is not None:
            x = apply_masks(x, masks)

        # -- fwd prop
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x

    def interpolate_pos_encoding(self, x, pos_embed):
        """Resize the positional table to match the number of tokens in ``x``.

        This model has no class token, so every row of ``pos_embed`` is a
        patch position. The table is returned unchanged when the token
        counts already match; otherwise it is reshaped to its square grid
        and resized with bicubic interpolation.

        Args:
            x: patch tokens of shape (B, npatch, dim).
            pos_embed: positional table of shape (1, N, dim).

        Returns:
            Positional table of shape (1, npatch, dim).

        Raises:
            ValueError: if either token count is not a perfect square.
        """
        npatch = x.shape[1]
        N = pos_embed.shape[1]
        if npatch == N:
            return pos_embed

        old_side = int(round(math.sqrt(N)))
        new_side = int(round(math.sqrt(npatch)))
        if old_side * old_side != N or new_side * new_side != npatch:
            raise ValueError(
                f'cannot interpolate positional embedding from {N} to '
                f'{npatch} tokens: both must form a square grid')

        dim = x.shape[-1]
        grid = pos_embed.reshape(1, old_side, old_side, dim).permute(0, 3, 1, 2)
        # An explicit target size avoids rounding surprises from a float
        # scale factor.
        grid = nn.functional.interpolate(
            grid,
            size=(new_side, new_side),
            mode='bicubic',
        )
        return grid.permute(0, 2, 3, 1).reshape(1, -1, dim)


def vit_predictor(**kwargs):
    """Build a ``VisionTransformerPredictor`` with the shared ViT defaults.

    Fixes ``mlp_ratio=4``, ``qkv_bias=True`` and LayerNorm with
    ``eps=1e-6``; every other constructor argument comes from ``kwargs``.
    """
    return VisionTransformerPredictor(
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_tiny(patch_size=16, **kwargs):
    """Build a ``VisionTransformer`` with width 192, depth 12 and 3 heads.

    Extra keyword arguments are forwarded to the constructor.
    """
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_small(patch_size=16, **kwargs):
    """Build a ``VisionTransformer`` with width 384, depth 12 and 6 heads.

    Extra keyword arguments are forwarded to the constructor.
    """
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_base(patch_size=16, **kwargs):
    """Build a ``VisionTransformer`` with width 768, depth 12 and 12 heads.

    Extra keyword arguments are forwarded to the constructor.
    """
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_large(patch_size=16, **kwargs):
    """Build a ``VisionTransformer`` with width 1024, depth 24 and 16 heads.

    Extra keyword arguments are forwarded to the constructor.
    """
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_huge(patch_size=16, **kwargs):
    """Build a ``VisionTransformer`` with width 1280, depth 32 and 16 heads.

    Extra keyword arguments are forwarded to the constructor.
    """
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_giant(patch_size=16, **kwargs):
    """Build a ``VisionTransformer`` with width 1408, depth 40 and 16 heads.

    Uses an MLP ratio of 48/11. Extra keyword arguments are forwarded to
    the constructor.
    """
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48/11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


# Token embedding width used by each ``vit_*`` builder above, keyed by
# builder name.
VIT_EMBED_DIMS = dict(
    vit_tiny=192,
    vit_small=384,
    vit_base=768,
    vit_large=1024,
    vit_huge=1280,
    vit_giant=1408,
)

### Helper functions

In [None]:
def load_checkpoint(
    device,
    r_path,
    encoder,
    predictor,
    target_encoder,
    opt,
    scaler,
):
    """Restore models, optimizer and grad scaler from a training checkpoint.

    Loading is best-effort: any exception is logged and the function falls
    back to ``epoch = 0``. Because the state dicts are loaded one after
    another, a failure part-way through leaves the objects loaded so far
    updated while the epoch is reset.

    Args:
        device: unused; the checkpoint is always mapped to CPU first.
        r_path: path of the checkpoint file to read.
        encoder: module receiving ``checkpoint['encoder']``.
        predictor: module receiving ``checkpoint['predictor']``.
        target_encoder: module receiving ``checkpoint['target_encoder']``;
            skipped when None.
        opt: optimizer receiving ``checkpoint['opt']``.
        scaler: grad scaler receiving ``checkpoint['scaler']``; skipped
            when None.

    Returns:
        Tuple ``(encoder, predictor, target_encoder, opt, scaler, epoch)``
        where ``epoch`` is the value stored in the checkpoint, or 0 when
        loading failed.
    """
    try:
        # NOTE(review): torch.load unpickles the file -- only point r_path
        # at checkpoints from a trusted source.
        checkpoint = torch.load(r_path, map_location=torch.device('cpu'))
        epoch = checkpoint['epoch']

        # -- loading encoder
        msg = encoder.load_state_dict(checkpoint['encoder'])
        logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}')

        # -- loading predictor
        msg = predictor.load_state_dict(checkpoint['predictor'])
        logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}')

        # -- loading target_encoder
        if target_encoder is not None:
            logger.info(f'checkpoint keys: {list(checkpoint.keys())}')
            msg = target_encoder.load_state_dict(checkpoint['target_encoder'])
            logger.info(f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}')

        # -- loading optimizer
        opt.load_state_dict(checkpoint['opt'])
        if scaler is not None:
            scaler.load_state_dict(checkpoint['scaler'])
        logger.info(f'loaded optimizers from epoch {epoch}')
        logger.info(f'read-path: {r_path}')
        del checkpoint

    except Exception as e:
        # Deliberate best-effort: start from scratch rather than abort.
        logger.info(f'Encountered exception when loading checkpoint {e}')
        epoch = 0

    return encoder, predictor, target_encoder, opt, scaler, epoch


def init_model(
    device,
    patch_size=16,
    model_name='vit_base',
    crop_size=224,
    pred_depth=6,
    pred_emb_dim=384
):
    """Build the encoder named by ``model_name`` and its matching predictor.

    Previously ``model_name`` was ignored and ``vit_small`` was always
    built; the requested architecture is now honoured.

    Args:
        device: device both modules are moved to.
        patch_size: patch side length passed to the encoder.
        model_name: one of the ``vit_*`` builder names defined in this file.
        crop_size: input image side length.
        pred_depth: number of transformer blocks in the predictor.
        pred_emb_dim: internal width of the predictor.

    Returns:
        Tuple ``(encoder, predictor)``, both on ``device``.

    Raises:
        ValueError: if ``model_name`` is not a known builder.
    """
    builders = {
        'vit_tiny': vit_tiny,
        'vit_small': vit_small,
        'vit_base': vit_base,
        'vit_large': vit_large,
        'vit_huge': vit_huge,
        'vit_giant': vit_giant,
    }
    if model_name not in builders:
        raise ValueError(
            f'unknown model_name {model_name!r}; expected one of {sorted(builders)}')

    encoder = builders[model_name](
        img_size=[crop_size],
        patch_size=patch_size)
    predictor = vit_predictor(
        num_patches=encoder.patch_embed.num_patches,
        embed_dim=encoder.embed_dim,
        predictor_embed_dim=pred_emb_dim,
        depth=pred_depth,
        num_heads=encoder.num_heads)

    def init_weights(m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
        elif isinstance(m, torch.nn.LayerNorm):
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1.0)

    # NOTE: this re-draws every Linear weight, which overrides the per-layer
    # rescaling that the model constructors apply in fix_init_weight().
    for m in encoder.modules():
        init_weights(m)

    for m in predictor.modules():
        init_weights(m)

    encoder.to(device)
    predictor.to(device)
    logger.info(encoder)
    return encoder, predictor


def init_opt(
    encoder,
    predictor,
    iterations_per_epoch,
    start_lr,
    ref_lr,
    warmup,
    num_epochs,
    wd=1e-6,
    final_wd=1e-6,
    final_lr=0.0,
    use_bfloat16=False,
    ipe_scale=1.25
):
    """Create the AdamW optimizer, its LR / weight-decay schedules and scaler.

    Parameters are split into four groups: weights of the encoder, weights
    of the predictor, and the biases / 1-D parameters of each, which are
    flagged ``WD_exclude`` and given zero weight decay.

    Args:
        encoder: context encoder module.
        predictor: predictor module.
        iterations_per_epoch: optimizer steps per epoch.
        start_lr: learning rate at the start of warmup.
        ref_lr: learning rate reached at the end of warmup.
        warmup: warmup length in epochs.
        num_epochs: number of training epochs.
        wd: initial weight decay.
        final_wd: weight decay at the end of the schedule.
        final_lr: learning rate at the end of the cosine decay.
        use_bfloat16: when True a ``GradScaler`` is created.
        ipe_scale: multiplier on the schedule length.

    Returns:
        Tuple ``(optimizer, scaler, scheduler, wd_scheduler)``; ``scaler``
        is None unless ``use_bfloat16`` is set.
    """
    def _is_no_decay(name, param):
        # Biases and 1-D parameters are excluded from weight decay.
        return ('bias' in name) or (len(param.shape) == 1)

    def _decayed(module):
        return [p for n, p in module.named_parameters() if not _is_no_decay(n, p)]

    def _undecayed(module):
        return [p for n, p in module.named_parameters() if _is_no_decay(n, p)]

    param_groups = [
        {'params': _decayed(encoder)},
        {'params': _decayed(predictor)},
        {'params': _undecayed(encoder), 'WD_exclude': True, 'weight_decay': 0},
        {'params': _undecayed(predictor), 'WD_exclude': True, 'weight_decay': 0},
    ]

    logger.info('Using AdamW')
    optimizer = torch.optim.AdamW(param_groups)

    # Both schedules span the same (scaled) number of optimizer steps.
    total_steps = int(ipe_scale*num_epochs*iterations_per_epoch)
    scheduler = WarmupCosineSchedule(
        optimizer,
        warmup_steps=int(warmup*iterations_per_epoch),
        start_lr=start_lr,
        ref_lr=ref_lr,
        final_lr=final_lr,
        T_max=total_steps)
    wd_scheduler = CosineWDSchedule(
        optimizer,
        ref_wd=wd,
        final_wd=final_wd,
        T_max=total_steps)

    scaler = None
    if use_bfloat16:
        scaler = torch.cuda.amp.GradScaler()
    return optimizer, scaler, scheduler, wd_scheduler

In [None]:
class WarmupCosineSchedule(object):
    """Linear warmup followed by cosine decay of the learning rate.

    Each ``step()`` computes the new learning rate and writes it into
    every param group of the wrapped optimizer.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        warmup_steps: number of steps of linear warmup.
        start_lr: learning rate at the start of warmup.
        ref_lr: learning rate at the end of warmup / start of the decay.
        T_max: total number of steps including warmup.
        last_epoch: accepted but unused.
        final_lr: floor of the cosine decay.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps,
        start_lr,
        ref_lr,
        T_max,
        last_epoch=-1,
        final_lr=0.
    ):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        # Length of the cosine phase only.
        self.T_max = T_max - warmup_steps
        self._step = 0.

    def step(self):
        """Advance one step, update every param group and return the LR."""
        self._step += 1
        current = self._step

        if current >= self.warmup_steps:
            # -- cosine decay, measured from the end of warmup
            progress = float(current - self.warmup_steps) / float(max(1, self.T_max))
            cosine_term = 1. + math.cos(math.pi * progress)
            decayed = self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * cosine_term
            new_lr = max(self.final_lr, decayed)
        else:
            # -- linear warmup from start_lr towards ref_lr
            progress = float(current) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)

        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

        return new_lr


class CosineWDSchedule(object):
    """Cosine schedule moving the weight decay from ``ref_wd`` to ``final_wd``.

    Each ``step()`` writes the new value into every param group of the
    wrapped optimizer except groups flagged with a truthy ``'WD_exclude'``.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        ref_wd: weight decay at the start of the schedule.
        T_max: number of steps over which the schedule runs.
        final_wd: weight decay at the end of the schedule.
    """

    def __init__(
        self,
        optimizer,
        ref_wd,
        T_max,
        final_wd=0.
    ):
        self.optimizer = optimizer
        self.ref_wd = ref_wd
        self.final_wd = final_wd
        self.T_max = T_max
        self._step = 0.

    def step(self):
        """Advance one step, update eligible groups and return the new value."""
        self._step += 1
        progress = self._step / self.T_max
        new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress))

        # Never overshoot final_wd, whichever direction the schedule moves.
        bound = max if self.final_wd <= self.ref_wd else min
        new_wd = bound(self.final_wd, new_wd)

        for group in self.optimizer.param_groups:
            if group.get('WD_exclude'):
                continue
            group['weight_decay'] = new_wd
        return new_wd

### Initializing model

In [None]:
# -- init model
# Build the context encoder and the predictor from the notebook-level
# configuration values.
encoder, predictor = init_model(
    device=device,
    patch_size=patch_size,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim,
    model_name=model_name)
# The target encoder starts as an exact copy of the context encoder; the
# training loop below updates it as a moving average of the encoder weights.
target_encoder = copy.deepcopy(encoder)

In [None]:
# -- init optimizer and scheduler
# Only the encoder and predictor are optimised; the target encoder is
# excluded from the optimizer and frozen just below.
optimizer, scaler, scheduler, wd_scheduler = init_opt(
    encoder=encoder,
    predictor=predictor,
    wd=wd,
    final_wd=final_wd,
    start_lr=start_lr,
    ref_lr=lr,
    final_lr=final_lr,
    iterations_per_epoch=ipe,
    warmup=warmup,
    num_epochs=num_epochs,
    ipe_scale=ipe_scale,
    use_bfloat16=use_bfloat16)
# Distributed wrapping is disabled in this notebook; kept for reference:
# encoder = DistributedDataParallel(encoder, static_graph=True)
# predictor = DistributedDataParallel(predictor, static_graph=True)
# target_encoder = DistributedDataParallel(target_encoder)
# The target encoder receives no gradients; it is only updated through the
# momentum rule in the training loop.
for p in target_encoder.parameters():
    p.requires_grad = False

# -- momentum schedule
# Generator yielding the momentum coefficient for each training step: a
# linear ramp from ema[0] to ema[1] over ipe*num_epochs*ipe_scale steps.
# It is lazy, so the globals it reads are looked up when values are drawn.
momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale)
                      for i in range(int(ipe*num_epochs*ipe_scale)+1))

start_epoch = 0
# -- load training checkpoint
if load_model:
    encoder, predictor, target_encoder, optimizer, scaler, start_epoch = load_checkpoint(
        device=device,
        r_path=load_path,
        encoder=encoder,
        predictor=predictor,
        target_encoder=target_encoder,
        opt=optimizer,
        scaler=scaler)
    # Fast-forward every per-step schedule by the number of steps already
    # trained so the resumed run continues where the checkpoint left off.
    # (mask_collator.step() is defined elsewhere; presumably it advances the
    # collator's internal step counter -- confirm against its definition.)
    for _ in range(start_epoch*ipe):
        scheduler.step()
        wd_scheduler.step()
        next(momentum_scheduler)
        mask_collator.step()

def save_checkpoint(epoch):
    """Write the current training state to disk (rank 0 only).

    Reads the notebook globals ``encoder``, ``predictor``,
    ``target_encoder``, ``optimizer``, ``scaler``, ``loss_meter``,
    ``batch_size`` and ``lr``. The checkpoint always goes to
    ``latest_path``; every ``checkpoint_freq`` epochs an additional copy is
    written to ``save_path`` formatted with the 1-based epoch number.

    Args:
        epoch: epoch index stored in the checkpoint.
    """
    save_dict = dict(
        encoder=encoder.state_dict(),
        predictor=predictor.state_dict(),
        target_encoder=target_encoder.state_dict(),
        opt=optimizer.state_dict(),
        scaler=scaler.state_dict() if scaler is not None else None,
        epoch=epoch,
        loss=loss_meter.avg,
        batch_size=batch_size,
        lr=lr,
    )
    if rank != 0:
        return

    torch.save(save_dict, latest_path)
    is_periodic_epoch = (epoch + 1) % checkpoint_freq == 0
    if is_periodic_epoch:
        torch.save(save_dict, save_path.format(epoch=f'{epoch + 1}'))

In [None]:
def validate(epoch):
    """Run validation on the validation dataset and return average loss.

    The encoder, predictor and target encoder are switched to eval mode for
    the duration of the pass and then restored to whatever mode each one was
    in before the call, even if an exception interrupts validation.
    (Previously the target encoder was left in eval mode afterwards.)

    Args:
        epoch: zero-based epoch index, used only for log messages.

    Returns:
        Average of the per-batch smooth-L1 losses over ``val_loader``, as
        reported by the loss meter.
    """
    logger.info('Running validation...')
    val_loss_meter = AverageMeter()
    val_maskA_meter = AverageMeter()
    val_maskB_meter = AverageMeter()
    val_time_meter = AverageMeter()

    # Remember each model's mode so it can be restored afterwards.
    models = (encoder, predictor, target_encoder)
    previous_modes = [m.training for m in models]
    for m in models:
        m.eval()

    try:
        with torch.no_grad():
            for itr, (udata, masks_enc, masks_pred) in enumerate(val_loader):
                # Load and process images
                imgs = udata.to(device, non_blocking=True)
                masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
                masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]

                val_maskA_meter.update(len(masks_1[0][0]))
                val_maskB_meter.update(len(masks_2[0][0]))

                # Forward pass
                def val_step():
                    with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16):
                        # Target encoding
                        h = target_encoder(imgs)
                        h = F.layer_norm(h, (h.size(-1),))
                        B = len(h)
                        h = apply_masks(h, masks_2)
                        h = repeat_interleave_batch(h, B, repeat=len(masks_1))

                        # Context encoding and prediction
                        z = encoder(imgs, masks_1)
                        z = predictor(z, masks_1, masks_2)

                        # Loss calculation
                        loss = F.smooth_l1_loss(z, h)
                        return float(loss)

                loss, etime = gpu_timer(val_step)
                val_loss_meter.update(loss)
                val_time_meter.update(etime)

                # Log progress occasionally
                if itr % (log_freq * 2) == 0:
                    logger.info(f'Val: [{epoch + 1}, {itr}] loss: {val_loss_meter.avg:.3f} '
                               f'masks: {val_maskA_meter.avg:.1f} {val_maskB_meter.avg:.1f} '
                               f'({val_time_meter.avg:.1f} ms)')
    finally:
        # Put every model back into the mode it was in before validation.
        for m, was_training in zip(models, previous_modes):
            m.train(was_training)

    return val_loss_meter.avg

### Train loop

In [None]:
# Validation settings used by the training loop.
val_frequency = 120  # Run validation every 120 training steps
best_val_loss = float('inf')  # Lowest validation loss seen so far

In [None]:
# -- TRAINING LOOP
# For every batch: optionally run validation, move the batch to the device,
# take one optimisation step on encoder + predictor, update the target
# encoder by momentum averaging, and log statistics.
global_step = 0  # counts training steps across epochs (drives validation)
for epoch in range(start_epoch, num_epochs):
    logger.info('Epoch %d' % (epoch + 1))

    # Per-epoch running statistics.
    loss_meter = AverageMeter()
    maskA_meter = AverageMeter()
    maskB_meter = AverageMeter()
    time_meter = AverageMeter()

    for itr, (udata, masks_enc, masks_pred) in enumerate(data_loader):
        new_validation = False
        # NOTE(review): prints on every iteration; looks like a debugging
        # leftover -- consider removing.
        print(udata.shape)
        # Validation runs whenever global_step is a multiple of
        # val_frequency, which includes the very first step (0).
        if global_step % val_frequency == 0:
            new_validation = True
            # Run validation and check if we need to save the model
            val_loss = validate(epoch)
            print(f"Validation loss: {val_loss}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # NOTE(review): save_checkpoint writes to latest_path, and the
                # end-of-epoch save below writes to the same path, so the
                # best-validation checkpoint can be overwritten. This call
                # also stores `epoch` while the end-of-epoch call stores
                # `epoch+1`.
                save_checkpoint(epoch)

        def load_imgs():
            # -- unsupervised imgs
            imgs = udata.to(device, non_blocking=True)  # the whole batch item is the image tensor
            masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
            masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]
            return (imgs, masks_1, masks_2)
        # Rebinds masks_enc / masks_pred to their on-device copies.
        imgs, masks_enc, masks_pred = load_imgs()
        maskA_meter.update(len(masks_enc[0][0]))
        maskB_meter.update(len(masks_pred[0][0]))

        def train_step():
            # Schedules are advanced before the optimizer step.
            _new_lr = scheduler.step()
            _new_wd = wd_scheduler.step()
            # --

            def forward_target():
                # Targets come from the target encoder, without gradients.
                with torch.no_grad():
                    h = target_encoder(imgs)
                    h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
                    B = len(h)
                    # -- create targets (masked regions of h)
                    h = apply_masks(h, masks_pred)
                    h = repeat_interleave_batch(h, B, repeat=len(masks_enc))
                    return h

            def forward_context():
                # Encode the context tokens, then predict the target positions.
                z = encoder(imgs, masks_enc)
                z = predictor(z, masks_enc, masks_pred)
                return z

            def loss_fn(z, h):
                loss = F.smooth_l1_loss(z, h)
                return loss

            # Step 1. Forward
            with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16):
                h = forward_target()
                z = forward_context()
                loss = loss_fn(z, h)

            #  Step 2. Backward & step
            if use_bfloat16:
                # The grad scaler is only created when use_bfloat16 is set.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            # grad_stats = grad_logger(encoder.named_parameters())
            optimizer.zero_grad()

            # Step 3. momentum update of target encoder
            # target <- m * target + (1 - m) * encoder, with m drawn from
            # the momentum schedule.
            with torch.no_grad():
                m = next(momentum_scheduler)
                for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
                    param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)

            return (float(loss), _new_lr, _new_wd, None)  # grad_stats logging is disabled, so None
        (loss, _new_lr, _new_wd, grad_stats), etime = gpu_timer(train_step)
        loss_meter.update(loss)
        time_meter.update(etime)

        # -- Logging
        def log_stats():
            # The validation loss is written to the CSV only on the steps
            # where validation actually ran.
            if new_validation:
                log_val_value = val_loss
            else:
                log_val_value = None
            csv_logger.log(epoch + 1, itr, loss, log_val_value, maskA_meter.val, maskB_meter.val, etime)
            # Always log when the loss is NaN/inf, otherwise every log_freq steps.
            if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss):
                logger.info('[%d, %5d] loss: %.3f '
                            'masks: %.1f %.1f '
                            '[wd: %.2e] [lr: %.2e] '
                            '[mem: %.2e] '
                            '(%.1f ms)'
                            % (epoch + 1, itr,
                               loss_meter.avg,
                               maskA_meter.avg,
                               maskB_meter.avg,
                               _new_wd,
                               _new_lr,
                               torch.cuda.max_memory_allocated() / 1024.**2,
                               time_meter.avg))
                print(f"loss: {loss_meter.avg}, maskA: {maskA_meter.avg}, maskB: {maskB_meter.avg}")

                # grad_stats is always None at the moment (see train_step).
                if grad_stats is not None:
                    logger.info('[%d, %5d] grad_stats: [%.2e %.2e] (%.2e, %.2e)'
                                % (epoch + 1, itr,
                                   grad_stats.first_layer,
                                   grad_stats.last_layer,
                                   grad_stats.min,
                                   grad_stats.max))

        log_stats()

        # Abort training as soon as the loss becomes NaN.
        assert not np.isnan(loss), 'loss is nan'

        global_step += 1

    # -- Save Checkpoint after every epoch
    logger.info('avg. loss %.3f' % loss_meter.avg)
    save_checkpoint(epoch+1)

In [None]:
def save_checkpoint(epoch, global_step=None, val_loss=None):
    """Write the current training state to disk (rank 0 only).

    Extended variant that also records the global step and the validation
    loss. Reads the notebook globals ``encoder``, ``predictor``,
    ``target_encoder``, ``optimizer``, ``scaler``, ``loss_meter``,
    ``batch_size`` and ``lr``. The checkpoint always goes to
    ``latest_path``; every ``checkpoint_freq`` epochs an additional copy is
    written to ``save_path`` formatted with the 1-based epoch number.

    Args:
        epoch: epoch index stored in the checkpoint.
        global_step: training step stored under ``'step'``.
        val_loss: validation loss stored under ``'val_loss'``.
    """
    save_dict = dict(
        encoder=encoder.state_dict(),
        predictor=predictor.state_dict(),
        target_encoder=target_encoder.state_dict(),
        opt=optimizer.state_dict(),
        scaler=scaler.state_dict() if scaler is not None else None,
        epoch=epoch,
        step=global_step,
        train_loss=loss_meter.avg,
        val_loss=val_loss,
        batch_size=batch_size,
        lr=lr,
    )
    if rank != 0:
        return

    torch.save(save_dict, latest_path)
    is_periodic_epoch = (epoch + 1) % checkpoint_freq == 0
    if is_periodic_epoch:
        torch.save(save_dict, save_path.format(epoch=f'{epoch + 1}'))