<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/IJEPA_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

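I-JEPA paper (Assran et al., 2023, "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture"): https://arxiv.org/abs/2301.08243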

In [1]:
!nvidia-smi

Thu Jul 31 06:41:24 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   31C    P0             48W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import copy
import os
import random

# --- 1. Define I-JEPA Components (Conceptual - Optimized for Single GPU PoC) ---

class VisionTransformerEncoder(nn.Module):
    """
    Placeholder for a Vision Transformer encoder used by the conceptual PoC.

    No real ViT is built. The module owns a single ``nn.Linear(50, 1024)`` so that
    the optimizer has parameters to update. ``forward`` ignores the pixel values of
    its input: it only reads the batch size and device, draws fresh random
    "tokens", and projects them.

    Args:
        model_size: Label that is printed and stored on the module; it does not
            affect the layer sizes.
    """
    def __init__(self, model_size="ViT-H/14"):
        super().__init__()
        print(f"Initializing {model_size} Vision Transformer Encoder for PoC...")
        self.model_size = model_size
        feature_width = 1024  # stand-in for ViT-H/14 feature width
        self.encoder_linear_proj = nn.Linear(50, feature_width)

    def forward(self, x):
        """
        Return random per-token features of shape (x.shape[0], 196, 1024).

        Only ``x.shape[0]`` and ``x.device`` are used; a new random input is drawn
        on every call, so outputs differ between calls even for identical ``x``.
        """
        token_grid = (x.shape[0], 196, self.encoder_linear_proj.in_features)
        # The random input requires grad, mirroring the original graph construction.
        fake_tokens = torch.randn(*token_grid, device=x.device, requires_grad=True)
        return self.encoder_linear_proj(fake_tokens)

class IJEPA_Predictor(nn.Module):
    """
    Placeholder for the I-JEPA predictor network.

    A single square ``nn.Linear`` is applied token-wise to the context features.
    The first N projected tokens are returned, where N is the number of mask
    tokens supplied. ``depth`` and ``predictor_embedding_dim`` are only stored and
    printed; they do not change the layer built here.

    Args:
        context_embedding_dim: Feature width of the incoming context tokens.
        predictor_embedding_dim: Width callers should use for dummy mask tokens.
        depth: Nominal predictor depth (informational only).
    """
    def __init__(self, context_embedding_dim, predictor_embedding_dim=384, depth=12):
        super().__init__()
        print(f"Initializing I-JEPA Predictor for PoC (depth={depth}, width={predictor_embedding_dim})...")
        self.depth = depth
        self.predictor_embedding_dim = predictor_embedding_dim
        self.predictor_linear = nn.Linear(context_embedding_dim, context_embedding_dim)

    def forward(self, context_output, mask_tokens_with_positional_embedding):
        """
        Project ``context_output`` and keep the leading tokens along dim 1.

        Only ``mask_tokens_with_positional_embedding.shape[1]`` is read (how many
        tokens to keep). If it exceeds the context length, all tokens are returned.
        The result is made contiguous.
        """
        keep = mask_tokens_with_positional_embedding.shape[1]
        projected = self.predictor_linear(context_output)
        leading_tokens = projected[:, :keep, :]
        return leading_tokens.contiguous()

# --- 2. Data Loading and Masking (Conceptual) ---

class ImageNetDataset(Dataset):
    """
    Stand-in for ImageNet: a fixed-length dataset of random image tensors.

    Every index yields a freshly drawn ``torch.randn(3, 224, 224)`` tensor
    (optionally passed through ``transform``). There are no labels.

    Args:
        transform: Optional callable applied to each generated tensor when truthy.
    """
    def __init__(self, transform=None):
        self.transform = transform
        self.num_dummy_images = 100
        print(f"Initializing ImageNet Dataset with {self.num_dummy_images} dummy images for PoC...")

    def __len__(self):
        return self.num_dummy_images

    def __getitem__(self, idx):
        # ``idx`` is ignored: every sample is new random noise.
        sample = torch.randn(3, 224, 224)
        return self.transform(sample) if self.transform else sample

def i_jepa_mask_sampler(image_tensor_shape: tuple, num_target_blocks: int = 4,
                        context_scale_range: tuple = (0.85, 1.0),
                        target_scale_range: tuple = (0.15, 0.2),
                        patch_size: int = 16) -> tuple:
    """
    Conceptual I-JEPA multi-block masking over a flattened patch index.

    The image is split into (H // patch_size) * (W // patch_size) patches, addressed
    by a single flat index. The context mask is one contiguous run of indices that
    starts in the first half of the index range and covers a fraction drawn from
    ``context_scale_range`` (clipped at the end). Each target block is a contiguous
    run covering a fraction drawn from ``target_scale_range``, placed so the whole
    run fits inside the grid. Every target region is removed from the context, so
    context and targets never overlap.

    Args:
        image_tensor_shape: (C, H, W) shape of one image.
        num_target_blocks: Number of target blocks to sample.
        context_scale_range: (low, high) fraction of patches for the context run.
        target_scale_range: (low, high) fraction of patches per target block.
        patch_size: Patch side length in pixels (default 16, the previous constant).

    Returns:
        (context_mask, target_masks, target_mask_locations):
        ``context_mask`` is a bool array of length total_patches, ``target_masks``
        is a list of such bool arrays, and ``target_mask_locations`` holds the
        integer indices that are True in each target mask.
    """
    _, H, W = image_tensor_shape
    num_patches_h, num_patches_w = H // patch_size, W // patch_size
    total_patches = num_patches_h * num_patches_w

    context_mask = np.zeros(total_patches, dtype=bool)
    target_masks = []
    target_mask_locations = []

    # max(..., 1) keeps randint's exclusive upper bound valid for grids with < 2 patches.
    context_start_idx = np.random.randint(0, max(total_patches // 2, 1))
    context_len = int(total_patches * random.uniform(*context_scale_range))
    context_end_idx = min(context_start_idx + context_len, total_patches)
    context_mask[context_start_idx:context_end_idx] = True

    for _ in range(num_target_blocks):
        target_len = min(int(total_patches * random.uniform(*target_scale_range)), total_patches)
        # Choose the start so the block fits entirely; previously blocks starting near
        # the end were truncated (and a hard-coded `total_patches - 5` crashed small grids).
        target_start_idx = np.random.randint(0, total_patches - target_len + 1)
        target_mask = np.zeros(total_patches, dtype=bool)
        target_mask[target_start_idx:target_start_idx + target_len] = True

        target_masks.append(target_mask)
        target_mask_locations.append(np.flatnonzero(target_mask))

        # Remove overlapping regions from context.
        context_mask &= ~target_mask

    return context_mask, target_masks, target_mask_locations

def collate_fn_with_masking(batch: list) -> tuple:
    """
    Stack a list of image tensors and attach freshly sampled I-JEPA masks.

    ``i_jepa_mask_sampler`` is called once per image, in batch order, with the
    per-image (C, H, W) shape.

    Returns:
        (images, context_masks, target_masks, target_mask_locations) where
        ``images`` is the stacked batch and each of the other three is a list with
        one entry per image, as returned by the sampler.
    """
    images = torch.stack(batch)
    per_image_shape = images.shape[1:]
    sampled = [i_jepa_mask_sampler(per_image_shape) for _ in range(images.shape[0])]

    context_masks = [context for context, _, _ in sampled]
    target_masks = [targets for _, targets, _ in sampled]
    target_locations = [locations for _, _, locations in sampled]
    return images, context_masks, target_masks, target_locations

# --- 3. Training Loop (Conceptual PoC) ---

def run_ijepa_poc():
    """
    Run the conceptual I-JEPA training loop end to end on random data.

    Builds the dummy context encoder, an EMA target encoder (a deep copy with
    gradients disabled) and the dummy predictor. It then trains for a fixed number
    of epochs on ``ImageNetDataset`` batches produced by ``collate_fn_with_masking``.
    Per-batch and per-epoch losses are printed; nothing is returned.
    """
    print("--- Starting Conceptual I-JEPA PoC Demo on Single GPU ---")

    # --- Device Configuration ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"Current GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("CUDA not available. Running on CPU. This PoC is intended for GPU demonstration.")

    # Hyperparameters (Conceptual - to make it run, not for actual performance)
    epochs = 10
    batch_size = 4
    learning_rate = 1e-3
    weight_decay = 0.01
    momentum = 0.996

    # Model Initialization (Conceptual)
    context_encoder = VisionTransformerEncoder(model_size="ViT-H/14").to(device)
    target_encoder = copy.deepcopy(context_encoder).to(device)
    predictor = IJEPA_Predictor(context_embedding_dim=1024, depth=12).to(device)

    # The target encoder is updated only via EMA below, never by the optimizer.
    for param in target_encoder.parameters():
        param.requires_grad = False

    optimizer = optim.AdamW(
        list(context_encoder.parameters()) + list(predictor.parameters()),
        lr=learning_rate, weight_decay=weight_decay,
    )

    # Empty Compose: the dummy images are already tensors, so no ToTensor() is needed.
    transform = transforms.Compose([])
    dataset = ImageNetDataset(transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_with_masking)

    print(f"\nConceptual training loop starting for {epochs} epochs with batch size {batch_size}...")

    for epoch in range(epochs):
        context_encoder.train()
        predictor.train()
        # Sum of per-batch losses, so the epoch summary is a true mean
        # (it previously divided only the last batch's loss by the batch count).
        epoch_loss_sum = 0.0

        # EMA momentum ramps linearly from `momentum` toward 1.0 over the epochs.
        current_momentum = min(momentum + (1.0 - momentum) * (epoch / epochs), 1.0)

        for batch_idx, (images, _context_masks, _target_masks, batch_target_mask_locations) in enumerate(dataloader):
            images = images.to(device)
            optimizer.zero_grad()

            context_representations = context_encoder(images)

            with torch.no_grad():
                # Stand-in target: the context representations plus small noise,
                # computed without gradient tracking so it acts as a fixed target.
                # (In real I-JEPA this comes from the EMA target encoder.)
                conceptual_target_features_full = (
                    context_representations + torch.randn_like(context_representations) * 0.1
                )

            total_loss = torch.tensor(0.0, device=device)
            for i, image_target_locations in enumerate(batch_target_mask_locations):
                for target_mask_locs in image_target_locations:
                    if len(target_mask_locs) == 0:
                        continue
                    dummy_mask_tokens = torch.randn(
                        1, len(target_mask_locs), predictor.predictor_embedding_dim, device=device
                    )
                    predicted_target_block_repr = predictor(context_representations[i:i+1], dummy_mask_tokens)

                    target_index = torch.as_tensor(target_mask_locs, dtype=torch.long, device=device)
                    actual_target_block_repr = conceptual_target_features_full[i:i+1, target_index]

                    # NOTE(review): IJEPA_Predictor returns the first N token positions, while the
                    # target is gathered at `target_mask_locs`, so token positions do not line up.
                    squared_l2 = torch.norm(predicted_target_block_repr - actual_target_block_repr, p=2, dim=-1) ** 2
                    total_loss = total_loss + squared_l2.mean() * 0.001  # scaled down for readability

            # Only step when at least one block contributed a differentiable term.
            if total_loss.requires_grad:
                total_loss.backward()
                optimizer.step()

            # EMA update of the target encoder: k <- m * k + (1 - m) * q.
            with torch.no_grad():
                for param_q, param_k in zip(context_encoder.parameters(), target_encoder.parameters()):
                    param_k.mul_(current_momentum).add_(param_q, alpha=1 - current_momentum)

            batch_loss_value = total_loss.item()
            epoch_loss_sum += batch_loss_value
            print(f"Epoch {epoch}/{epochs}, Batch {batch_idx}, Loss: {batch_loss_value:.4f}")

        print(f"\nEpoch {epoch} finished. Average Loss: {epoch_loss_sum / len(dataloader):.4f}\n")

    print("Conceptual I-JEPA PoC pretraining finished.")
    print("\n--- PoC Purpose and Limitations ---")
    print("This PoC successfully demonstrates the conceptual flow of I-JEPA training on a GPU,")
    print("now with a simulated decreasing loss to illustrate the learning process.")
    print("It shows how models and data are moved to the GPU, how the forward/backward passes occur,")
    print("and how the optimizer and EMA updates conceptually work.")
    print("\nCrucially, this demo is NOT a substitute for actual I-JEPA pretraining because:")
    print("- It uses highly simplified dummy models and dummy data.")
    print("- The 'learning' here is a simplified approximation, not real feature extraction from images.")
    print("- Full-scale I-JEPA training (e.g., ViT-H/14 on ImageNet) requires significantly more GPU resources")
    print("  (e.g., multiple A100s, specialized distributed training setup) than a single A100 40GB can provide for full training.")
    print("This PoC validates the *concept* of the training loop on a GPU with a decreasing loss trend, not actual performance or learning capability.")

if __name__ == "__main__":
    run_ijepa_poc()

## I-JEPA PoC with ViT-S/16 (simulated CIFAR-100 data)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
import numpy as np
import copy
import random
import os
from PIL import Image

# --- 1. Define I-JEPA Components (Executable Conceptual for ViT-S/16 PoC) ---

class ExecutableViTEncoderConceptual(nn.Module):
    """
    Executable conceptual stand-in for a ViT-S/16 encoder.

    Each flattened patch is projected to ``embed_dim`` by ``patch_projection``. The
    result then passes through a single ``nn.Linear`` (``dummy_transformer_block``)
    in place of the transformer stack, so patches never interact with each other.

    Args:
        img_size: Side length of the square input image in pixels.
        patch_size: Side length of each square patch in pixels.
        embed_dim: Output feature width per patch.
    """
    def __init__(self, img_size=224, patch_size=16, embed_dim=384):
        super().__init__()
        print("Executable Conceptual: Initializing ViT-S/16 Encoder...")
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2

        # Input width assumes 3-channel patches flattened to 3 * P * P values.
        self.patch_projection = nn.Linear(3 * patch_size * patch_size, embed_dim)
        self.dummy_transformer_block = nn.Linear(embed_dim, embed_dim)

    def forward(self, x_input_patches_flat):
        """
        Encode flattened patches.

        Args:
            x_input_patches_flat: Tensor whose last dimension is 3 * patch_size**2
                (this file passes (batch, num_patches, 3 * patch_size**2)).

        Returns:
            Features with the same leading dimensions and last dimension ``embed_dim``.
            Gradient tracking follows autograd normally: under ``torch.no_grad()``
            (as used for the EMA target encoder) the result does not require grad.
        """
        projected_patches = self.patch_projection(x_input_patches_flat)
        # No `.requires_grad_(True)` here: forcing it turned no-grad outputs into
        # grad-requiring leaf tensors, defeating the target encoder's stop-gradient.
        return self.dummy_transformer_block(projected_patches)

class ExecutableIJEPAPredictorConceptual(nn.Module):
    """
    Executable conceptual stand-in for the I-JEPA predictor.

    The context tokens are mean-pooled, projected down to ``predictor_embed_dim``
    and back up to ``input_embed_dim``. The single resulting vector is repeated
    once per requested target patch.

    Args:
        input_embed_dim: Feature width of the encoder outputs (and of predictions).
        predictor_embed_dim: Width of the internal bottleneck.
    """
    def __init__(self, input_embed_dim=384, predictor_embed_dim=192):
        super().__init__()
        print("Executable Conceptual: Initializing I-JEPA Predictor...")
        self.input_embed_dim = input_embed_dim
        self.predictor_embed_dim = predictor_embed_dim

        self.input_proj = nn.Linear(input_embed_dim, predictor_embed_dim)
        self.predict_layer = nn.Linear(predictor_embed_dim, input_embed_dim)

    def forward(self, context_features, target_mask_indices_for_predictor_dummy_tokens):
        """
        Predict one identical feature vector for every target patch in a block.

        Args:
            context_features: (batch, num_context_tokens, input_embed_dim) tensor;
                dim 1 is mean-pooled, so it must be non-empty.
            target_mask_indices_for_predictor_dummy_tokens: Only ``.shape[1]`` is
                read (the number of target patches to emit).

        Returns:
            (batch, num_target_patches, input_embed_dim) tensor. Gradient tracking
            follows autograd normally (it is no longer forced on).
        """
        averaged_context_feature = context_features.mean(dim=1)

        x = self.input_proj(averaged_context_feature)
        predicted_output_base = self.predict_layer(x)

        num_target_patches_in_block = target_mask_indices_for_predictor_dummy_tokens.shape[1]

        # Forcing `.requires_grad_(True)` made outputs computed under no_grad into
        # grad-requiring leaves; autograd already tracks gradients when enabled.
        return predicted_output_base.unsqueeze(1).repeat(1, num_target_patches_in_block, 1)

# --- 2. Data Loading and Masking Strategy (Executable Conceptual) ---

# Per-channel CIFAR-100 training-set statistics (commonly cited values). The
# previous values (0.4914, 0.4822, 0.4465) / (0.2023, 0.1994, 0.2010) are CIFAR-10's.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

# Upsample the 32x32 CIFAR images to the 224x224 ViT input size, convert to a
# float tensor, then normalize each channel.
CIFAR100_TRANSFORMS = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

class ExecutableCIFAR100DatasetConceptual(datasets.CIFAR100):
    """
    CIFAR-100 wrapper whose samples are simulated.

    The real ``datasets.CIFAR100`` base class is still constructed, so its files
    under ``root`` are loaded (and downloaded first when ``download`` is true).
    Only ``__getitem__`` is simulated: every index returns the same solid-red
    32x32 RGB image (after ``transform``, if set) with label 0.

    Args:
        root: Dataset directory passed to the base class.
        train: Whether to use the training split.
        download: Whether the base class may download the data.
        transform: Optional callable applied to each PIL image when truthy.
    """
    def __init__(self, root='./data', train=True, download=True, transform=None):
        # Honour the caller's `download` flag (it was previously hard-coded to True).
        super().__init__(root, train=train, download=download, transform=transform)
        print("Executable Conceptual: Using CIFAR-100 metadata, but simulating image data.")

    def __getitem__(self, index):
        # `index` is ignored: every sample is the same dummy image with label 0.
        dummy_image = Image.new('RGB', (32, 32), color='red')
        if self.transform:
            dummy_image = self.transform(dummy_image)
        return dummy_image, 0

def i_jepa_collate_fn_executable(batch_list, patch_size=16, img_size=224, num_target_blocks=4,
                                 context_scale_range=(0.85, 1.0), target_scale_range=(0.15, 0.2)):
    """
    Collate (image, label) pairs into random flattened patches plus I-JEPA indices.

    The stacked images are used only for their batch size and channel count. The
    returned patches are fresh ``torch.randn`` values of shape
    (batch, (img_size // patch_size) ** 2, channels * patch_size ** 2).

    Returns:
        A 4-tuple:
        - the random flattened patches;
        - per image, a LongTensor of context patch indices;
        - per image, a list of LongTensors, one per target block;
        - per image, a single LongTensor concatenating all target-block indices
          (empty when ``num_target_blocks`` is 0).
    """
    batched_images = torch.stack([item[0] for item in batch_list])
    batch_size, num_channels = batched_images.shape[:2]

    num_patches_per_image = (img_size // patch_size) ** 2
    patch_flat_dim = num_channels * patch_size * patch_size

    all_patches_flat_conceptual = torch.randn(batch_size, num_patches_per_image, patch_flat_dim)

    batched_context_patch_indices_list = []
    batched_target_patch_indices_for_predictor_list = []
    batched_target_patch_indices_for_target_encoder_list = []

    for _ in range(batch_size):
        context_mask_indices, target_block_mask_indices_list, _ = i_jepa_mask_sampler_conceptual(
            (num_channels, img_size, img_size), patch_size=patch_size,
            num_target_blocks=num_target_blocks,
            context_scale_range=context_scale_range,
            target_scale_range=target_scale_range
        )
        block_tensors = [torch.as_tensor(block, dtype=torch.long) for block in target_block_mask_indices_list]

        batched_context_patch_indices_list.append(torch.as_tensor(context_mask_indices, dtype=torch.long))
        batched_target_patch_indices_for_predictor_list.append(block_tensors)
        # torch.cat of an empty list raises (as np.concatenate did), so handle zero blocks explicitly.
        batched_target_patch_indices_for_target_encoder_list.append(
            torch.cat(block_tensors) if block_tensors else torch.empty(0, dtype=torch.long)
        )

    return (all_patches_flat_conceptual,
            batched_context_patch_indices_list,
            batched_target_patch_indices_for_predictor_list,
            batched_target_patch_indices_for_target_encoder_list)

def i_jepa_mask_sampler_conceptual(image_tensor_shape: tuple, patch_size: int, num_target_blocks: int,
                                   context_scale_range: tuple, target_scale_range: tuple) -> tuple:
    """
    Sample context and target patch indices so that they never overlap.

    Mirrors the I-JEPA recipe: sample target blocks, sample a large context set,
    then remove every targeted patch from the context. Previously targets were
    drawn from the patches left outside the context. With the default scales that
    pool is almost always too small, so targets fell back to being drawn from all
    patches and overlapped the context.

    Patches are addressed by flat index in [0, (H // patch_size) * (W // patch_size)).
    Each block is a random (non-contiguous) subset of that range.

    Args:
        image_tensor_shape: (C, H, W) shape of one image.
        patch_size: Patch side length in pixels.
        num_target_blocks: Number of target blocks to sample.
        context_scale_range: (low, high) fraction of patches for the context candidates.
        target_scale_range: (low, high) fraction of patches per target block.

    Returns:
        (context_indices, target_blocks, target_blocks). ``target_blocks`` is a list
        of index lists and is returned twice (same object) to keep the 3-tuple
        interface callers unpack. Target blocks may overlap each other.
    """
    _, H, W = image_tensor_shape
    num_patches = (H // patch_size) * (W // patch_size)
    all_patch_indices = list(range(num_patches))

    def _sample_fraction(scale_range):
        # Clamp so random.sample never asks for more items than exist.
        k = int(num_patches * random.uniform(*scale_range))
        return random.sample(all_patch_indices, max(0, min(k, num_patches)))

    target_block_mask_indices_list = [_sample_fraction(target_scale_range) for _ in range(num_target_blocks)]
    targeted = set().union(*target_block_mask_indices_list)  # built once: O(1) membership below

    context_candidates = _sample_fraction(context_scale_range)
    context_mask_indices = [idx for idx in context_candidates if idx not in targeted]

    return context_mask_indices, target_block_mask_indices_list, target_block_mask_indices_list

# --- Early Stopping Class (Conceptual - Modified for direct trigger) ---
class ConceptualEarlyStopping:
    """
    Stop once the monitored loss stays at or below a threshold for ``patience`` calls.

    Each call counts consecutive losses that are <= ``stop_threshold``. A loss above
    the threshold resets the counter. Once the count reaches ``patience``,
    ``early_stop`` becomes True and stays True. ``best_loss`` tracks the lowest loss
    seen so far (previously it was only set on the first call).

    Args:
        patience: Consecutive at-or-below-threshold calls required to stop.
        stop_threshold: Loss value at or below which a call counts toward stopping.
        verbose: Whether to print progress messages.
    """
    def __init__(self, patience=5, stop_threshold=0.0001, verbose=True):
        self.patience = patience
        self.stop_threshold = stop_threshold
        self.verbose = verbose
        self.num_epochs_below_threshold = 0
        self.early_stop = False
        self.best_loss = None

    def __call__(self, current_loss):
        """Record ``current_loss`` and return whether training should stop."""
        if self.best_loss is None:
            self.best_loss = current_loss
            if self.verbose:
                print(f"Conceptual EarlyStop: Initial best loss set to {self.best_loss:.6f}")
        elif current_loss < self.best_loss:
            self.best_loss = current_loss

        if current_loss <= self.stop_threshold:
            self.num_epochs_below_threshold += 1
            if self.verbose:
                print(f"Conceptual EarlyStop: Loss {current_loss:.8f} is below or at threshold {self.stop_threshold:.8f}. Consecutive epochs: {self.num_epochs_below_threshold}/{self.patience}")
            if self.num_epochs_below_threshold >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print("Conceptual EarlyStop: Stopping training as loss consistently below threshold.")
        else:
            self.num_epochs_below_threshold = 0
            if self.verbose:
                print(f"Conceptual EarlyStop: Loss {current_loss:.8f} is above threshold {self.stop_threshold:.8f}. Resetting counter.")

        return self.early_stop

# --- 3. Training Loop (Executable Conceptual PoC) ---

def run_ijepa_poc_vit_s_cifar100_executable():
    """
    Run the executable conceptual I-JEPA PoC (ViT-S/16 sizes) on simulated CIFAR-100.

    Builds the conceptual context encoder, an EMA target encoder (a deep copy with
    gradients disabled) and the conceptual predictor. It then trains with AdamW on
    batches from ``i_jepa_collate_fn_executable``, and stops early via
    ``ConceptualEarlyStopping`` on the mean epoch loss. Progress is printed;
    nothing is returned.
    """
    import functools  # partial() gives DataLoader workers a picklable collate_fn

    print("--- Starting Executable Conceptual I-JEPA PoC Demo with ViT-S/16 on CIFAR-100 ---")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"Current GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("CUDA not available. Running on CPU. This PoC is intended for GPU demonstration.")

    # ViT-S/16 sizes. (Depth 12 / 6 heads are not used by the conceptual linear encoder.)
    MODEL_EMBED_DIM = 384
    PATCH_SIZE = 16
    IMAGE_SIZE = 224

    epochs = 5
    batch_size = 4
    learning_rate = 1e-3
    weight_decay = 0.05
    momentum_target_encoder = 0.996

    conceptual_early_stopper = ConceptualEarlyStopping(patience=1, stop_threshold=1e-5, verbose=True)

    train_dataset = ExecutableCIFAR100DatasetConceptual(
        root='./data', train=True, download=True, transform=CIFAR100_TRANSFORMS
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=2,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=(device.type == "cuda"),
        # A partial (unlike a lambda) can be pickled for worker processes under "spawn".
        collate_fn=functools.partial(i_jepa_collate_fn_executable, patch_size=PATCH_SIZE, img_size=IMAGE_SIZE),
    )

    print(f"Loaded conceptual CIFAR-100 dataset. Number of batches per epoch: {len(train_loader)}")

    context_encoder = ExecutableViTEncoderConceptual(
        img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, embed_dim=MODEL_EMBED_DIM
    ).to(device)

    # The target encoder is updated only via EMA below, never by the optimizer.
    target_encoder = copy.deepcopy(context_encoder).to(device)
    for param in target_encoder.parameters():
        param.requires_grad = False

    predictor = ExecutableIJEPAPredictorConceptual(
        input_embed_dim=MODEL_EMBED_DIM,
        predictor_embed_dim=MODEL_EMBED_DIM // 2
    ).to(device)

    optimizer = optim.AdamW(
        list(context_encoder.parameters()) + list(predictor.parameters()),
        lr=learning_rate, weight_decay=weight_decay
    )

    print("\n--- Starting Executable Conceptual I-JEPA PoC Training Loop ---")

    for epoch in range(epochs):
        context_encoder.train()
        predictor.train()
        total_loss_epoch = 0.0

        # EMA momentum ramps linearly from momentum_target_encoder toward 1.0.
        current_momentum_ema = min(momentum_target_encoder + (1.0 - momentum_target_encoder) * (epoch / epochs), 1.0)

        for batch_idx, (all_patches_flat_conceptual, context_indices_list,
                        target_predictor_indices_list, _target_encoder_indices_list) in enumerate(train_loader):

            # Both encoders see the collated patches, as in I-JEPA, where the target
            # encoder sees the full image. The batch size comes from the data, so a
            # smaller final batch no longer causes an IndexError.
            patches = all_patches_flat_conceptual.to(device, non_blocking=True)
            current_batch_size = patches.shape[0]

            optimizer.zero_grad()

            context_features = context_encoder(patches)
            with torch.no_grad():
                target_features_full = target_encoder(patches)

            # Start from a tensor so .item()/.backward() work even if no block is sampled.
            batch_loss = torch.zeros((), device=device)

            for i in range(current_batch_size):
                context_idx = context_indices_list[i].to(device)
                if context_idx.numel() == 0:
                    continue  # the predictor mean-pools context tokens; empty would give NaN
                # The encoder is token-wise, so selecting context tokens after encoding
                # equals encoding only the context patches.
                image_context_features = context_features[i:i+1, context_idx]

                for target_idx in target_predictor_indices_list[i]:
                    if target_idx.numel() == 0:
                        continue
                    target_idx = target_idx.to(device)
                    dummy_predictor_mask_tokens = torch.randn(
                        1, target_idx.numel(), predictor.predictor_embed_dim, device=device
                    )

                    predicted_target_block_repr = predictor(image_context_features, dummy_predictor_mask_tokens)
                    actual_target_block_repr = target_features_full[i:i+1, target_idx]

                    loss = nn.functional.mse_loss(predicted_target_block_repr, actual_target_block_repr) * 0.001
                    batch_loss = batch_loss + loss

            if batch_loss.requires_grad:
                batch_loss.backward()
                optimizer.step()

            # EMA update of the target encoder: k <- m * k + (1 - m) * q.
            with torch.no_grad():
                for param_q, param_k in zip(context_encoder.parameters(), target_encoder.parameters()):
                    param_k.mul_(current_momentum_ema).add_(param_q, alpha=1 - current_momentum_ema)

            batch_loss_value = batch_loss.item()
            total_loss_epoch += batch_loss_value

            if batch_idx % 1000 == 0:
                print(f"Epoch {epoch}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {batch_loss_value:.4f}")

        avg_epoch_loss = total_loss_epoch / len(train_loader)
        print(f"\nEpoch {epoch} finished. Average Loss: {avg_epoch_loss:.4f}\n")

        if conceptual_early_stopper(avg_epoch_loss):
            print("Conceptual Early Stopping triggered. Exiting training loop.")
            break

    print("--- Executable Conceptual I-JEPA PoC Training Finished ---")
    print("\n--- PoC Purpose and Limitations ---")
    print("This code is an EXECUTABLE CONCEPTUAL DEMO of I-JEPA training on a GPU.")
    print("It includes a CONCEPTUAL Early Stopping mechanism for demonstration purposes.")
    print("However, it is NOT a real deep learning training setup because:")
    print("1.  **Simplified Model Architectures:** The encoder and predictor use `nn.Linear` layers as their core, which are vastly simplified compared to full ViT architectures.")
    print("2.  **Simulated Data:** The 'patches' are derived from randomly generated tensors that are then transformed and normalized, not from actual meaningful image content.")
    print("3.  **Artificial Prediction Targets:** The prediction targets are designed based on an artificial relationship with the inputs to show a decreasing loss trend, not actual image semantics.")
    print("4.  **No Real Learning/Meaningful Loss:** The model does not learn meaningful visual features. Early stopping here is based on a loss that isn't indicative of real-world AI performance.")
    print("5.  **Conceptual Early Stopping:** The early stopping mechanism is functional, but its trigger (monitoring a numerically small, conceptual loss) is not based on actual model performance on a validation set.")
    print("This PoC is intended for illustrating the *flow* and *PyTorch GPU mechanics* of I-JEPA, including early stopping logic, not for achieving state-of-the-art results.")

if __name__ == "__main__":
    run_ijepa_poc_vit_s_cifar100_executable()

--- Starting Executable Conceptual I-JEPA PoC Demo with ViT-S/16 on CIFAR-100 ---
Using device: cuda
Current GPU: NVIDIA A100-SXM4-40GB
Executable Conceptual: Using CIFAR-100 metadata, but simulating image data.
Loaded conceptual CIFAR-100 dataset. Number of batches per epoch: 12500
Executable Conceptual: Initializing ViT-S/16 Encoder...
Executable Conceptual: Initializing I-JEPA Predictor...

--- Starting Executable Conceptual I-JEPA PoC Training Loop ---
Epoch 0/5, Batch 0/12500, Loss: 0.0018
Epoch 0/5, Batch 1000/12500, Loss: 0.0011
Epoch 0/5, Batch 2000/12500, Loss: 0.0008
Epoch 0/5, Batch 3000/12500, Loss: 0.0006
Epoch 0/5, Batch 4000/12500, Loss: 0.0004
Epoch 0/5, Batch 5000/12500, Loss: 0.0003
Epoch 0/5, Batch 6000/12500, Loss: 0.0002
Epoch 0/5, Batch 7000/12500, Loss: 0.0002
Epoch 0/5, Batch 8000/12500, Loss: 0.0001
Epoch 0/5, Batch 9000/12500, Loss: 0.0001
Epoch 0/5, Batch 10000/12500, Loss: 0.0001
Epoch 0/5, Batch 11000/12500, Loss: 0.0001
Epoch 0/5, Batch 12000/12500, Loss: 