In [None]:
# src.architecture.adViT.critic

import torch
import torch.nn as nn


class AdversarialVisionTransformer(nn.Module):

    def __init__(
            self,
            vit_encoder: nn.Module,
            z_dim: int | None = None,  # proposal dimension
            c_dim: int | None = None,  # context dimension
            hidden_dim: int = 256
    ):
        super().__init__()
        self.vit = vit_encoder
        self.z_dim = z_dim
        self.c_dim = c_dim

        embed_dim = self.vit.c_token.size(-1)

        # Total feature dimension (one long vector)
        in_dim = embed_dim
        if z_dim is not None:
            in_dim += z_dim
        if c_dim is not None:
            in_dim += c_dim

        self.mlp = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def forward(
            self,
            I_in: torch.Tensor,        # (B, C_in, H, W)
            O_pred: torch.Tensor,      # (B, C_out, H, W) or (B, T, C_out, H, W)
            mask_in: torch.Tensor | None = None,
            mask_out: torch.Tensor | None = None,
            z: torch.Tensor | None = None,
            C: torch.Tensor | None = None
    ) -> torch.Tensor:
        ###############################
        #   B = batch size            #    
        #   P = num proposals         #
        #   C_in = input channels     #
        #   C_out = output channels   #
        #   H = height                #
        #   W = width                 #
        ###############################

        B, C_in, H, W = I_in.shape

        #################################
        #   MULTI-PROPOSAL BRANCH       #
        #################################

        if O_pred.dim() == 5:
            B, T, C_out, H, W = O_pred.shape

            # Expand inputs
            I_exp = I_in.unsqueeze(1).expand(B, T, C_in, H, W)
            I_flat = I_exp.reshape(B*T, C_in, H, W)
            O_flat = O_pred.reshape(B*T, C_out, H, W)

            # Convert O_flat to 1 channel if needed
            if O_flat.size(1) > 1:
                O_flat = torch.argmax(O_flat, dim=1, keepdim=True).float()

            # Combine masks
            if mask_in is not None or mask_out is not None:
                if mask_in is None:
                    mask_in = torch.zeros(B, H, W, dtype=torch.bool, device=O_pred.device)
                if mask_out is None:
                    mask_out = torch.zeros(B, H, W, dtype=torch.bool, device=O_pred.device)
                mask = torch.logical_or(mask_in, mask_out)
                mask = mask.unsqueeze(1).expand(B, T, H, W).reshape(B*T, H, W)
            else:
                mask = None

            # Concatenate input + output
            x = torch.cat([I_flat, O_flat], dim=1)  # (B*T, 2, H, W)

            # Encode with ViT
            h_flat = self.vit.forward_grid(x, mask=mask)
            h = h_flat.view(B, T, -1)

            # Collect features
            feats = [h]

            if z is not None:
                if z.dim() == 2:
                    z = z.unsqueeze(1).expand(B, T, -1)
                feats.append(z)

            if C is not None:
                feats.append(C.unsqueeze(1).expand(B, T, -1))

            feat = torch.cat(feats, dim=-1)
            return self.mlp(feat).squeeze(-1)

        #################################
        #   SINGLE-PROPOSAL BRANCH      #
        #################################

        B, C_out, H, W = O_pred.shape

        # Convert O_pred to 1 channel if needed
        if O_pred.size(1) > 1:          # multi-channel class logits
            classes = torch.arange(O_pred.size(1), device=O_pred.device).view(1, -1, 1, 1)
            probs = O_pred.softmax(dim=1)
            O_pred = (probs * classes).sum(dim=1, keepdim=True)   # differentiable


        # Combine masks
        if mask_in is not None or mask_out is not None:
            if mask_in is None:
                mask_in = torch.zeros(B, H, W, dtype=torch.bool, device=O_pred.device)
            if mask_out is None:
                mask_out = torch.zeros(B, H, W, dtype=torch.bool, device=O_pred.device)
            mask = torch.logical_or(mask_in, mask_out)
        else:
            mask = None

        # Concatenate input + output
        x = torch.cat([I_in, O_pred], dim=1)  # (B, 2, H, W)

        # Encode with ViT
        h = self.vit.forward_grid(x, mask=mask)

        # Combine features for MLP
        feats = [h]
        if z is not None:
            feats.append(z)
        if C is not None:
            feats.append(C)
        feat = torch.cat(feats, dim=-1)

        return self.mlp(feat).squeeze(-1)


In [None]:
# src.architecture.context_encoding.example_pair_encoder

import torch
import torch.nn as nn


class ExamplePairEncoder(nn.Module):
    """
    Encodes a single example pair (I_i, O_i) into a vector h_i.

    The pair is stacked channel-wise, run through the wrapped ViT's
    ``forward_grid`` and layer-normalized.
    """

    def __init__(
            self,
            vit_pair: nn.Module
    ):
        super().__init__()
        self.vit = vit_pair  # generic

        # LayerNorm sized from the last dimension of the ViT's c_token
        embed_dim = self.vit.c_token.size(-1)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(
            self,
            I_i: torch.Tensor,
            O_i: torch.Tensor,
            mask_I: torch.Tensor,
            mask_O: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            I_i: example input grid (4-D; channels on dim 1).
            O_i: example output grid, concatenated with ``I_i`` on dim 1.
            mask_I: mask for the input grid.
            mask_O: mask for the output grid.

        Returns:
            Layer-normalized output of ``vit.forward_grid`` for the pair.
        """
        # Unpacking keeps the requirement that I_i is exactly 4-D
        _batch, _channels, _height, _width = I_i.shape

        # Input and output become separate channels of one tensor
        pair = torch.cat((I_i, O_i), dim=1)

        # True only where neither mask is set; this is what the ViT receives
        inverted_mask = torch.logical_not(torch.logical_or(mask_I, mask_O))

        embedding = self.vit.forward_grid(pair, mask=inverted_mask)
        return self.norm(embedding)


In [None]:
# src.architecture.context_encoding.example_pair_aggregator

import torch
import torch.nn as nn
import torch.nn.functional as F

class ExamplePairAggregator(nn.Module):
    """
    Aggregates the context vectors from k example pairs (I_i, O_i) into a single embedding
    """

    def __init__(
            self, 
            embed_dim: int,
            hidden_dim: int = 256
    ):
        super().__init__()

        # MLP to weigh context vectors
        self.score_mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.Tanh(), # [-1,1]
            nn.Linear(hidden_dim, 1)
        )

        # Normalize context vector
        self.norm = nn.LayerNorm(embed_dim)

    def forward(
            self,
            h: torch.Tensor,  # all embeddings 
            mask: torch.Tensor | None = None  # where the tokens are valid
    ) -> torch.Tensor:
        ###############################
        #   B = batch size            #
        #   K = num example pairs     #
        #   D = embedding dimension   #
        ###############################

        B, K, D = h.shape

        # Infer scores
        scores = self.score_mlp(h)

        # Mask = where the tokens are
        if mask is not None:

            # Ensure boolean
            mask = mask.to(dtype=torch.bool)

            # Match to score shape
            mask_expanded = mask.unsqueeze(-1) # (B, K, 1)

            # Set where the mask is not to very negative
            scores = scores.masked_fill(~mask_expanded, float("-inf"))

        # Attention weights over k example pairs
        attn = F.softmax(scores, dim=1)

        # Weighted sum
        C = torch.sum(attn * h, dim=1)

        # Final normalization
        C = self.norm(C)

        return C


In [None]:
# src.architecture.context_encoding.conditional_encoder.py

import torch
import torch.nn as nn

class ConditionalTestInputEncoder(nn.Module):
    """
    Tokenizes a test input grid and prepends a projected context token.

    Dimension legend used in the comments below:
        B = batch size
        D = token embedding dim
        S = num tokens
        H = height
        W = width
    """

    def __init__(
            self,
            vit_test: nn.Module
    ):
        """
        Args:
            vit_test: module exposing ``c_token`` (its last dimension is read
                as D), ``patch_embedding``, ``pos_encoding`` and ``dropout``.
        """
        super().__init__()
        self.vit = vit_test
        self.embed_dim = self.vit.c_token.size(-1)

        # Project context vector to embedding dim
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
            self,
            I_test,  # (B, 1, H, W)
            mask_test,  # (B, H, W) or None
            C  # (B, D)
    ):
        """
        Args:
            I_test: test input grid, (B, 1, H, W).
            mask_test: (B, H, W) mask, truthy where a cell is valid, or None
                when there is no padding.
            C: context vector, (B, D).

        Returns:
            ``(tokens, key_padding_mask)``: tokens are (B, 1+S, D) with the
            context token first; the mask is a (B, 1+S) bool tensor with
            True = pad, or None when ``mask_test`` is None.
        """
        B = I_test.shape[0]

        ######################
        #   Encode with ViT  #
        ######################

        tokens = self.vit.patch_embedding(I_test)  # (B, S, D)

        ##########################
        #   Add Context Vector   #
        ##########################

        C_token = self.c_proj(C).unsqueeze(1)  # (B,1,D)
        tokens = torch.cat([C_token, tokens], dim=1)  # (B,1+S,D)

        #######################
        #   Build Test Mask   #
        #######################

        key_padding_mask = None
        if mask_test is not None:
            # Cast to bool BEFORE inverting so that `~` is a logical not even
            # for int/float masks
            valid = mask_test.reshape(B, -1).to(torch.bool)  # (B, S)

            # The context token is never padding
            c_pad = torch.zeros(B, 1, dtype=torch.bool, device=valid.device)
            key_padding_mask = torch.cat([c_pad, ~valid], dim=1)  # (B,1+S), True = pad

        #####################################
        #   Positional Encoding + Dropout   #
        #####################################

        tokens = self.vit.pos_encoding(tokens)
        tokens = self.vit.dropout(tokens)

        return tokens, key_padding_mask

In [None]:
# src.architecture.executor.attention.py

import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """
    Self-attention wrapper: one tensor serves as query, key and value.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, x, key_padding_mask=None):
        """Return the attended tensor; the attention weights are discarded."""
        attended, _weights = self.attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
        )
        return attended

In [None]:
# src.architecture.executor.CNNBlock

import torch.nn as nn
from src.architecture.executor.FiLM import FiLM

class CNNBlock(nn.Module):
    def __init__(
            self,
            channels: int,
            z_dim: int | None = None
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            channels,  # in
            channels,  # out
            kernel_size=3, 
            padding=1
        )
        self.norm = nn.GroupNorm(channels, channels)  # normalize each channel with respect to itself
        self.activation = nn.GELU()

        self.film = FiLM(channels, z_dim) if z_dim is not None else None

    def forward(
            self,
            x,
            z=None
    ):
        # Feature extraction
        x = self.norm(self.conv(x))

        # Proposed feature modulation
        if self.film is not None:
            x = self.film(x, z)

        return self.activation(x)

In [None]:
# src.architecture.executor.executor

import torch
import torch.nn as nn
from src.architecture.executor.CNNBlock import CNNBlock
from src.architecture.ViT.body import TransformerEncoderBlock


# Hybrid ViT and CNN
class Executor(nn.Module):
    """
    Applies a latent transformation z to an input grid.

    Pipeline: CNN enricher -> z-conditioned CNN blocks -> per-cell tokens plus
    one z token -> transformer blocks -> CNN head producing per-cell class
    logits.
    """

    def __init__(
            self,
            embed_dim,
            num_heads,
            mlp_dim,
            depth,
            z_dim,
            hidden_channels=64,
            num_classes=10  # ARC colors
    ):
        super().__init__()

        # --- CNN feature enricher: 1 input channel -> hidden_channels ---
        self.enricher = nn.Sequential(
            nn.Conv2d(1, hidden_channels, kernel_size=3, padding=1),
            nn.GELU()
        )

        # --- CNN proposal feature detection: two z-conditioned blocks ---
        self.cnn_blocks = nn.ModuleList(
            CNNBlock(hidden_channels, z_dim=z_dim) for _ in range(2)
        )

        # --- Tokenizers ---
        # Per-cell CNN features -> token embedding
        self.to_embedding = nn.Linear(hidden_channels, embed_dim)
        # Proposal z -> one extra token
        self.z_token = nn.Linear(z_dim, embed_dim)

        # --- ViT layers ---
        self.blocks = nn.ModuleList(
            TransformerEncoderBlock(embed_dim, num_heads, mlp_dim)
            for _ in range(depth)
        )

        # --- CNN discretizer ---
        # 3x3 conv detects features in the token map, 1x1 conv classifies them
        self.discretizer = nn.Sequential(
            nn.Conv2d(embed_dim, hidden_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, num_classes, kernel_size=1)
        )

    def forward(
            self,
            grid,
            z
    ):
        """
        Args:
            grid: input grid, (B, 1, H, W).
            z: proposal vector, (B, z_dim).

        Returns:
            Class logits of shape (B, num_classes, H, W).
        """
        batch, _, height, width = grid.shape

        # Enrich, then let every CNN block modulate the features with z
        feature_map = self.enricher(grid)
        for cnn_block in self.cnn_blocks:
            feature_map = cnn_block(feature_map, z)

        # Tokenize: (B, C, H, W) -> (B, H, W, C) -> (B, H*W, C) -> (B, S, D)
        cells = feature_map.movedim(1, -1).flatten(1, 2)
        cell_tokens = self.to_embedding(cells)

        # One proposal token per sample goes in front: (B, 1+S, D)
        proposal_token = self.z_token(z).unsqueeze(1)
        sequence = torch.cat((proposal_token, cell_tokens), dim=1)

        # Global reasoning; the blocks are called without a mask
        for layer in self.blocks:
            sequence = layer(sequence, None)

        # Un-tokenize: drop the proposal token and restore (B, D, H, W)
        cell_tokens = sequence[:, 1:, :]
        token_map = cell_tokens.reshape(batch, height, width, -1).movedim(-1, 1)

        # Discretize over the embedding dimension
        return self.discretizer(token_map)

In [None]:
# src.architecture.executor.FiLM

import torch
import torch.nn as nn

# Feature-wise modulation
class FiLM(nn.Module):
    """
    Feature-wise modulation: scales and shifts each channel of ``x`` with
    coefficients predicted from ``z``.
    """

    def __init__(
            self,
            feature_dim,
            z_dim
    ):
        super().__init__()
        # Two independent linear maps from z: one for scale, one for shift
        self.to_gamma = nn.Linear(z_dim, feature_dim)
        self.to_beta = nn.Linear(z_dim, feature_dim)

    def forward(
            self,
            x: torch.Tensor,
            z: torch.Tensor
    ):
        """
        Args:
            x: feature map, (B, C, H, W).
            z: conditioning vector, (B, z_dim).

        Returns:
            ``x * (1 + gamma) + beta`` with gamma and beta broadcast over the
            two trailing axes.
        """
        # Two trailing singleton axes let the per-channel coefficients
        # broadcast across H and W: (B, C) -> (B, C, 1, 1)
        gamma = self.to_gamma(z)[..., None, None]
        beta = self.to_beta(z)[..., None, None]

        # Apply feature modulation; gamma == 0 and beta == 0 leave x unchanged
        scale = 1 + gamma
        return x * scale + beta