In [3]:
import torch
from torchvision.models import vit_b_16, vit_l_16
from torchvision import datasets, transforms
import torch.nn as nn
import math
import torch.nn.functional as F
import copy
from einops import rearrange, repeat
from x_transformers import Encoder, Decoder

import numpy as np
import pytorch_lightning as pl
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
    ModelSummary,
)
from pytorch_lightning.loggers import WandbLogger
from PIL import Image


In [4]:
class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=64):
        super().__init__()
        if isinstance(img_size, int):
            img_size = img_size, img_size
        if isinstance(patch_size, int):
            patch_size = patch_size, patch_size
        #calculate the number of patches
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])

        #convolutional layer to convert the image into patches
        self.conv = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        

    def forward(self, x):
        x = self.conv(x)
        #flatten the patches
        x = rearrange(x, 'b e h w -> b (h w) e')
        return x


# class Predictor(nn.Module):
#     def __init__(self, embed_dim, num_heads, depth):
#         super().__init__()
        
#         self.predictor = Decoder(dim = embed_dim, depth = depth, heads = num_heads)
#     def forward(self, context_encoding, target_masks):
#         x = torch.cat((context_encoding, target_masks), dim = 1)
#         x = self.predictor(x)
#         #return last len(target_masks) tokens
#         l = x.shape[1]
#         return x[:, l - target_masks.shape[1]:, :]


class Predictor(nn.Module):
    """
    Recurrent predictor network to predict future representations
    """
    def __init__(self, embed_dim=1024, action_dim=2):
        super(Predictor, self).__init__()
        self.rnn = nn.GRUCell(representation_dim + action_dim, representation_dim)
    
    def forward(self, prev_rep, action):
        # Concatenate previous representation and action
        input_combined = torch.cat([prev_rep, action], dim=1)
        return self.rnn(input_combined, prev_rep)


In [3]:
class JEPA(nn.Module):
    def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, pred_depth, num_heads, post_emb_norm=False, M = 4, mode="train", layer_dropout=0.):
        super().__init__()
        self.M = M
        self.mode = mode
        self.layer_dropout = layer_dropout

        #define the patch embedding and positional embedding
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.patch_dim  = (self.patch_embed.patch_shape[0], self.patch_embed.patch_shape[1])
        self.num_tokens = self.patch_embed.patch_shape[0] * self.patch_embed.patch_shape[1]
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim))

        #define the cls and mask tokens
        self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        nn.init.trunc_normal_(self.mask_token, 0.02)

        print(self.mask_token.size())

        #define the encoder and decoder, as well as the layer normalization and dropout
        self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity()
        self.norm = nn.LayerNorm(embed_dim)
        self.teacher_encoder = Encoder(
            dim=embed_dim,
            heads=num_heads,
            depth=enc_depth, 
            layer_dropout=self.layer_dropout,
        )  
        # self.student_encoder = copy.deepcopy(self.teacher_encoder).cuda()
        self.student_encoder = copy.deepcopy(self.teacher_encoder)
        self.predictor = Predictor(embed_dim)

    @torch.no_grad() 
    def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M):  
        #get the target block
        target_encoder = target_encoder.eval()
        x = target_encoder(x)
        x = self.norm(x)
        #get the patch dimensions
        patch_h, patch_w = patch_dim
        #get the number of patches
        num_patches = patch_h * patch_w
        #get the number of patches in the target block
        num_patches_block = int(patch_h * patch_w * scale)
        #get the height and width of the target block with aspect ratio
        block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio)))
        block_w = int(aspect_ratio * block_h)
        #get the patches in the target block
        target_block = torch.zeros((M, x.shape[0], block_h*block_w, x.shape[2]))
        target_patches = []
        all_patches = []
        for z in range(M):
            #get the starting patch
            start_patch_h = torch.randint(0, patch_h - block_h+1, (1,)).item()
            start_patch_w = torch.randint(0, patch_w - block_w+1, (1,)).item()
            start_patch = start_patch_h * patch_w + start_patch_w

            patches = []
            #get the patches in the target block
            for i in range(block_h):
                for j in range(block_w):
                    patches.append(start_patch + i * patch_w + j)
                    if start_patch + i * patch_w + j not in all_patches:
                        all_patches.append(start_patch + i * patch_w + j)
                    
            #get the target block
            target_patches.append(patches)
            target_block[z] = x[:, patches, :]
        # return target_block.cuda(), target_patches, all_patches
        return target_block, target_patches, all_patches

    def get_context_block(self, x, patch_dim, aspect_ratio, scale, target_patches):
        patch_h, patch_w = patch_dim
        #get the number of patches in the target block
        num_patches_block = int(patch_h * patch_w * scale)
        #get the height and width of the target block with aspect ratio
        block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio)))
        block_w = int(aspect_ratio * block_h)
        #get the starting patch
        start_patch_h = torch.randint(0, patch_h - block_h+1, (1,)).item()
        start_patch_w = torch.randint(0, patch_w - block_w+1, (1,)).item()
        start_patch = start_patch_h * patch_w + start_patch_w
        #get the patches in the context_block
        patches = []
        for i in range(block_h):
            for j in range(block_w):
                if start_patch + i * patch_w + j not in target_patches: #remove the target patches
                    patches.append(start_patch + i * patch_w + j)
        return x[:, patches, :]
    

    def forward(self, x, y, target_aspect_ratio=1, target_scale=1, context_aspect_ratio=1, context_scale=1):
        #get the patch embeddings
        x = self.patch_embed(x)
        y = self.patch_embed(y)
        b, n, e = x.shape
        #add the positional embeddings
        x = x + self.pos_embedding
        y = y + self.pos_embedding
        #normalize the embeddings
        x = self.post_emb_norm(x)
        y = self.post_emb_norm(y)
        #if mode is test, we get return full embedding:
        if self.mode == 'test':
            return self.student_encoder(x)
        # #get target embeddings
        target_blocks, target_patches, all_patches = self.get_target_block(self.teacher_encoder, y, self.patch_dim, target_aspect_ratio, target_scale, self.M)
        m, b, n, e = target_blocks.shape
        #get context embedding

        context_block = self.get_context_block(x, self.patch_dim, context_aspect_ratio, context_scale, all_patches)
        context_encoding = self.student_encoder(context_block)
        context_encoding = self.norm(context_encoding)


        # prediction_blocks = torch.zeros((m, b, n, e)).cuda()
        prediction_blocks = torch.zeros((m, b, n, e))
        #get the prediction blocks, predict each target block separately
        for i in range(m):
            target_masks = self.mask_token.repeat(b, n, 1)
            target_pos_embedding = self.pos_embedding[:, target_patches[i], :]
            target_masks = target_masks + target_pos_embedding
            prediction_blocks[i] = self.predictor(context_encoding, target_masks)

        return prediction_blocks, target_blocks




In [4]:
class IJEPA(pl.LightningModule):
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3, 
            embed_dim=64,
            enc_heads=8,
            enc_depth=8,
            decoder_depth=6,
            lr=1e-6,
            weight_decay=0.05,
            target_aspect_ratio = (0.75,1.5),
            target_scale = (0.15, .2),
            context_aspect_ratio = 1,
            context_scale = (0.85,1.0),
            M = 4, #number of different target blocks
            m=0.996, #momentum
            m_start_end = (.996, 1.)

    ):
        super().__init__()
        self.save_hyperparameters()
        
        #define models
        self.model = JEPA(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 
                                enc_depth = enc_depth, num_heads=enc_heads, pred_depth=decoder_depth, M=M)

        #define hyperparameters
        self.M = M
        self.lr = lr
        self.weight_decay = weight_decay
        self.m = m
        self.target_aspect_ratio = target_aspect_ratio
        self.target_scale = target_scale
        self.context_aspect_ratio = context_aspect_ratio
        self.context_scale = context_scale
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_tokens = (img_size // patch_size) ** 2
        self.m_start_end = m_start_end

        #define loss
        self.criterion = nn.MSELoss()
    
    def forward(self, x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale):
        return self.model(x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale)
    
    '''Update momentum for teacher encoder'''
    def update_momentum(self, m):
        student_model = self.model.student_encoder.eval()
        teacher_model = self.model.teacher_encoder.eval()
        with torch.no_grad():
            for student_param, teacher_param in zip(student_model.parameters(), teacher_model.parameters()):
                teacher_param.data.mul_(other=m).add_(other=student_param.data, alpha=1 - m)


    def training_step(self, batch, batch_idx):
        x = batch
        #generate random target and context aspect ratio and scale
        target_aspect_ratio = np.random.uniform(self.target_aspect_ratio[0], self.target_aspect_ratio[1])
        target_scale = np.random.uniform(self.target_scale[0], self.target_scale[1])
        context_aspect_ratio = self.context_aspect_ratio
        context_scale = np.random.uniform(self.context_scale[0], self.context_scale[1])

        y_student, y_teacher = self(x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale)
        loss = self.criterion(y_student, y_teacher)
        self.log('train_loss', loss)
                    
        return loss
    
    def validation_step(self, batch, batch_idx):
        x = batch
        target_aspect_ratio = np.random.uniform(self.target_aspect_ratio[0], self.target_aspect_ratio[1])
        target_scale = np.random.uniform(self.target_scale[0], self.target_scale[1])
        context_aspect_ratio = self.context_aspect_ratio
        context_scale = np.random.uniform(self.context_scale[0], self.context_scale[1])

        y_student, y_teacher = self(x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale)
        loss = self.criterion(y_student, y_teacher)
        self.log('val_loss', loss)
        
        return loss
    
    def predict_step(self, batch, batch_idx, dataloader_idx):
        target_aspect_ratio = np.random.uniform(self.target_aspect_ratio[0], self.target_aspect_ratio[1])
        target_scale = np.random.uniform(self.target_scale[0], self.target_scale[1])
        context_aspect_ratio = self.context_aspect_ratio
        context_scale = 1
        self.model.mode = "test"

        return self(batch, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale) #just get teacher embedding

    def on_after_backward(self):
        self.update_momentum(self.m)
        self.m += (self.m_start_end[1] - self.m_start_end[0]) / self.trainer.estimated_stepping_batches


    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.lr,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }

In [5]:
from typing import NamedTuple, Optional
import torch
import numpy as np


class WallSample(NamedTuple):
    states: torch.Tensor
    locations: torch.Tensor
    actions: torch.Tensor


class WallDataset:
    def __init__(
        self,
        data_path,
        probing=False,
        device="cuda",
    ):
        self.device = device
        self.states = np.load(f"{data_path}/states.npy", mmap_mode="r")
        self.actions = np.load(f"{data_path}/actions.npy")

        if probing:
            self.locations = np.load(f"{data_path}/locations.npy")
        else:
            self.locations = None

    def __len__(self):
        return len(self.states)

    def __getitem__(self, i):
        states = torch.from_numpy(self.states[i]).float().to(self.device)
        actions = torch.from_numpy(self.actions[i]).float().to(self.device)

        if self.locations is not None:
            locations = torch.from_numpy(self.locations[i]).float().to(self.device)
        else:
            locations = torch.empty(0).to(self.device)

        return WallSample(states=states, locations=locations, actions=actions)


def create_wall_dataloader(
    data_path,
    probing=False,
    device="cuda",
    batch_size=64,
    train=True,
):
    ds = WallDataset(
        data_path=data_path,
        probing=probing,
        device=device,
    )

    loader = torch.utils.data.DataLoader(
        ds,
        batch_size,
        shuffle=train,
        drop_last=True,
        pin_memory=False,
    )

    return loader

In [4]:
from typing import NamedTuple, Optional
import torch
import numpy as np


class WallSample(NamedTuple):
    states: torch.Tensor
    locations: torch.Tensor
    actions: torch.Tensor


class WallDataset:
    def __init__(
        self,
        data_path,
        probing=False,
        device="cuda",
    ):
        self.device = device
        self.states = np.load(f"{data_path}/states.npy", mmap_mode="r")
        self.actions = np.load(f"{data_path}/actions.npy")

        if probing:
            self.locations = np.load(f"{data_path}/locations.npy")
        else:
            self.locations = None

    def __len__(self):
        return len(self.states)

    def __getitem__(self, i):
        states = torch.from_numpy(self.states[i]).float().to(self.device)
        actions = torch.from_numpy(self.actions[i]).float().to(self.device)

        if self.locations is not None:
            locations = torch.from_numpy(self.locations[i]).float().to(self.device)
        else:
            locations = torch.empty(0).to(self.device)

        return WallSample(states=states, locations=locations, actions=actions)


def create_wall_dataloader(
    data_path,
    probing=False,
    device="cuda",
    batch_size=64,
    train=True,
):
    ds = WallDataset(
        data_path=data_path,
        probing=probing,
        device=device,
    )

    loader = torch.utils.data.DataLoader(
        ds,
        batch_size,
        shuffle=train,
        drop_last=True,
        pin_memory=False,
    )

    return loader

In [5]:
dataloader = create_wall_dataloader('./DL24FA/train', device='cpu', batch_size=1)

In [6]:
for x in dataloader:
    print(x.states.shape)
    print(x.actions.shape)
    print(x.locations.shape)
    print()
    print()
    
    t1 = transforms.ToPILImage()
    for y in range(16):
        print(x.actions[0][y])
    
    break

torch.Size([1, 17, 2, 65, 65])
torch.Size([1, 16, 2])
torch.Size([1, 0])


tensor([ 0.8981, -0.1178])
tensor([ 0.6551, -0.4548])
tensor([ 0.9138, -0.0485])
tensor([0.5497, 0.3975])
tensor([0.6934, 0.4740])
tensor([0.7901, 0.2268])
tensor([ 1.3507, -0.4438])
tensor([ 1.1281, -0.1383])
tensor([1.0353, 0.4285])
tensor([ 1.0585, -0.4855])
tensor([1.0278, 0.3916])
tensor([ 0.8944, -0.4860])
tensor([0.9172, 0.4508])
tensor([0.5561, 0.2728])
tensor([ 0.7411, -0.0362])
tensor([0.8644, 0.0572])


  states = torch.from_numpy(self.states[i]).float().to(self.device)


## New Implementation

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights

class VitEncoder(nn.Module):
    """
    Vision Transformer (ViT) based encoder for custom input dimensions
    """
    def __init__(self, representation_dim=512, input_channels=2):
        super(VitEncoder, self).__init__()
        # Custom initial convolutional layer to handle 2-channel input
        self.input_adapter = nn.Conv2d(input_channels, 3, kernel_size=3, padding=1)
        
        # Load pre-trained ViT model
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        
        # Remove the classification head
        self.vit.heads = nn.Identity()
        
        # Add a projection head to get desired representation dimension
        self.projection_head = nn.Sequential(
            nn.Linear(self.vit.hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, representation_dim)
        )
        
        # Freeze base ViT weights
        for param in self.vit.parameters():
            param.requires_grad = False
    
    def forward(self, x):
        # Adapt input to 3 channels
        x = self.input_adapter(x)
        
        # Ensure input is compatible with ViT
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        
        # Project the input through the convolutional layer
        x = self.vit.conv_proj(x)  # Shape: (batch_size, hidden_dim, height, width)
        
        # Flatten spatial dimensions into a sequence
        batch_size, hidden_dim, height, width = x.shape
        x = x.flatten(2).transpose(1, 2)  # Shape: (batch_size, seq_length, hidden_dim)
        
        # Add positional embeddings
        x = x + self.vit.encoder.pos_embedding[:, :x.size(1), :]
        
        # Pass through the encoder
        features = self.vit.encoder.layers(self.vit.encoder.dropout(x))
        features = self.vit.encoder.ln(features)
        
        # Take the [CLS] token embedding (first token)
        cls_embedding = features[:, 0]
        
        # Project to desired representation dimension
        return self.projection_head(cls_embedding)




class ActionAwarePredictor(nn.Module):
    """
    Multi-layer predictor that considers both previous representation and action
    """
    def __init__(self, representation_dim=512, action_dim=2):
        super(ActionAwarePredictor, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(representation_dim + action_dim, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, representation_dim)
        )
    
    def forward(self, prev_rep, action):
        # Concatenate previous representation and action
        input_combined = torch.cat([prev_rep, action], dim=1)
        return self.network(input_combined)

class JEPAWorldModel(nn.Module):
    """
    Joint Embedding Predictive Architecture World Model with ViT
    """
    def __init__(self, representation_dim=512, action_dim=2):
        super(JEPAWorldModel, self).__init__()
        self.encoder = VitEncoder(representation_dim)
        self.predictor = ActionAwarePredictor(representation_dim, action_dim)
        
        # Use same encoder for target encoder (similar to VicReg)
        self.target_encoder = VitEncoder(representation_dim)
        
        # Synchronize target encoder with main encoder
        self.update_target_encoder()
    
    def update_target_encoder(self, tau=0.995):
        """
        Exponential Moving Average (EMA) update of target encoder
        """
        for param_q, param_k in zip(self.encoder.parameters(), self.target_encoder.parameters()):
            param_k.data = param_k.data * tau + param_q.data * (1. - tau)
    
    def forward(self, observations, actions):
        # Encode observations
        encoded_states = [self.encoder(observations[:, 0])]
        predicted_states = []
        
        # Predict future representations
        for t in range(1, observations.shape[1]):
            prev_state = encoded_states[-1]
            curr_action = actions[:, t-1]
            
            # Predict next state
            predicted_state = self.predictor(prev_state, curr_action)
            predicted_states.append(predicted_state)
            
            # Encode current observation with target encoder
            with torch.no_grad():
                curr_encoded_state = self.target_encoder(observations[:, t])
            encoded_states.append(curr_encoded_state)
        
        return predicted_states, encoded_states[1:]
    
    def compute_loss(self, predicted_states, target_states):
        """
        Multi-objective loss to prevent representation collapse
        """
        # 1. Prediction Loss: Minimize distance between predicted and target states
        pred_loss = F.mse_loss(torch.stack(predicted_states), torch.stack(target_states))
        
        # 2. Variance Loss: Encourage representations to have non-zero variance
        std_loss = self.variance_loss(predicted_states)
        
        # 3. Covariance Loss: Decorrelate representation dimensions
        cov_loss = self.covariance_loss(predicted_states)
        
        # Weighted combination of losses
        total_loss = pred_loss + 1e-4 * (std_loss + cov_loss)
        return total_loss
    
    def variance_loss(self, representations, min_std=0.1):
        """Encourage each feature to have non-zero variance"""
        repr_tensor = torch.stack(representations)
        std_loss = torch.max(
            torch.tensor(min_std), 
            torch.sqrt(repr_tensor.var(dim=0) + 1e-7)
        ).mean()
        return std_loss
    
    def covariance_loss(self, representations):
        """Decorrelate representation dimensions"""
        repr_tensor = torch.stack(representations)
        
        # Center the representations
        repr_tensor = repr_tensor - repr_tensor.mean(dim=0)
        
        # Flatten tensor (keep batch dimension intact)
        repr_tensor = repr_tensor.view(repr_tensor.shape[0], -1)
        
        # Compute covariance matrix
        cov_matrix = (repr_tensor.T @ repr_tensor) / (repr_tensor.shape[0] - 1)
        
        # Decorrelate dimensions (set diagonal to zero)
        cov_matrix.fill_diagonal_(0)
        
        # Compute loss
        cov_loss = (cov_matrix ** 2).sum()
        return cov_loss

class DataTransforms:
    """
    Image augmentations and preprocessing for JEPA training
    """
    @staticmethod
    def get_train_transforms():
        return transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225])
        ])

def train_jepa_model(model, dataloader, optimizer, device, epoch):
    """
    Training loop for JEPA world model
    """
    model.train()
    total_loss = 0
    
    for batch in dataloader:
        batch_observations = batch.states.to(device)
        batch_actions = batch.actions.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        predicted_states, target_states = model(batch_observations, batch_actions)
        
        # Compute loss
        loss = model.compute_loss(predicted_states, target_states)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Update target encoder (EMA)
        model.update_target_encoder()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
representation_dim = 1024
action_dim = 2

# Initialize model
jepa_model = JEPAWorldModel(
    representation_dim=representation_dim, 
    action_dim=action_dim
).to(device)

# Optimizer
optimizer = torch.optim.Adam(jepa_model.parameters(), lr=1e-4)

# TODO: Implement actual data loading from /scratch/DL24FA/train
# dataloader = ...
dataloader = create_wall_dataloader('./DL24FA/train', device='cpu', batch_size=8)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    avg_loss = train_jepa_model(jepa_model, dataloader, optimizer, device, epoch)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
    
    # Optional: Learning rate scheduling, model checkpointing
    # scheduler.step()

# Save model
torch.save(jepa_model.state_dict(), "jepa_vit_world_model.pth")

NameError: name 'JEPAWorldModel' is not defined