In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Fusion Encoder Architectures

In [36]:
class QuickGELU(nn.Module):
    """Element-wise activation computing ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        # The sigmoid acts as a smooth gate on the input itself.
        gate = torch.sigmoid(x * 1.702)
        return gate * x

In [37]:
# Scratch instance of the stock PyTorch encoder, used only to inspect its structure.
# Pick the device at runtime so the cell also runs on machines without a GPU
# (a hard-coded "cuda" raises when CUDA is unavailable).
testlayer = nn.TransformerEncoderLayer(d_model=512, nhead=8, activation="gelu")
test = nn.TransformerEncoder(testlayer, num_layers=6).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

In [38]:
# Show the activation held by the first encoder layer, then the full module tree.
print(test.layers[0].activation, test, sep="\n")

<built-in function gelu>
TransformerEncoder(
  (layers): ModuleList(
    (0-5): 6 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=512, bias=True)
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
)


In [39]:
#Individual bottleneck layer
class BottleneckLayer(nn.Module):
    """One fusion layer that exchanges information between three token streams
    (schematic, graph, tabular) through a small set of learnable latent tokens,
    then runs each stream through its own encoder layer.

    Args:
        num_latents: Number of learnable latent (bottleneck) tokens.
        schem_enc: Module applied to the schematic tokens after fusion.
        graph_enc: Module applied to the graph tokens after fusion.
        tab_enc: Module applied to the tabular tokens after fusion.

    All token tensors are expected as ``(batch, tokens, 512)``; the width 512 is
    fixed by the shape of ``self.latents``.
    """

    def __init__(self, num_latents, schem_enc, graph_enc, tab_enc): #add encoders if modified to include attn and mlp and such
        super(BottleneckLayer, self).__init__()

        # Per-modality encoder layers (schematic, graph, tabular)
        self.schem_encoder = schem_enc
        self.graph_encoder = graph_enc
        self.tab_encoder = tab_enc

        # Latents: one shared set of tokens, broadcast over the batch at use time
        self.num_latents = num_latents
        self.latents = nn.Parameter(torch.empty(1, num_latents, 512).normal_(std=0.02)) #512 to match dimensionality
        # Gates on the latent -> modality update. They start at zero, so a freshly
        # built layer leaves the tokens unchanged by fusion.
        self.scale_s = nn.Parameter(torch.zeros(1))
        self.scale_g = nn.Parameter(torch.zeros(1))
        self.scale_t = nn.Parameter(torch.zeros(1))

    def attention(self, q, k, v): # requires q,k,v to have same dim. In future I want multi head self attention (MSA)
        """Single-head scaled dot-product attention without learned projections.

        Args:
            q: Queries, shape ``(B, N, C)``.
            k: Keys, shape ``(B, M, C)``.
            v: Values, shape ``(B, M, C)``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = q.shape
        attn = (q @ k.transpose(-2, -1)) * (C ** -0.5) # scaling
        attn = attn.softmax(dim=-1)
        x = (attn @ v).reshape(B, N, C)
        return x

    # Latent Fusion
    def fusion(self, schem_tokens, graph_tokens, tab_tokens):
        """Fuse the three token streams through the latent bottleneck.

        The latents first attend over the concatenation of all tokens; each
        stream then attends over the fused latents and adds the result to
        itself, scaled by its learnable gate.

        Returns:
            The updated ``(schem_tokens, graph_tokens, tab_tokens)``.
        """
        BS = schem_tokens.shape[0]
        # concat all the tokens along the token axis
        concat_ = torch.cat((schem_tokens, graph_tokens, tab_tokens), dim=1)
        # cross attention (modalities -->> latents)
        fused_latents = self.attention(q=self.latents.expand(BS, -1, -1), k=concat_, v=concat_)
        # cross attention (latents -->> modalities)
        schem_tokens = schem_tokens + self.scale_s * self.attention(q=schem_tokens, k=fused_latents, v=fused_latents)
        graph_tokens = graph_tokens + self.scale_g * self.attention(q=graph_tokens, k=fused_latents, v=fused_latents)
        tab_tokens = tab_tokens + self.scale_t * self.attention(q=tab_tokens, k=fused_latents, v=fused_latents)
        return schem_tokens, graph_tokens, tab_tokens

    def forward(self, x, y, z):
        """Run bottleneck fusion followed by the per-modality encoder layers.

        Args:
            x: Schematic tokens ``(batch, tokens, 512)``.
            y: Graph tokens ``(batch, tokens, 512)``.
            z: Tabular tokens ``(batch, tokens, 512)``.

        Returns:
            ``(x, y, z, latents)`` where ``latents`` has shape
            ``(batch, num_latents, 512)``.
        """
        # Batch size is needed to broadcast the latents in the return value.
        BS = x.shape[0]

        # Bottleneck Fusion
        x, y, z = self.fusion(x, y, z)

        # The encoders are attributes of this layer, not module-level names.
        x = self.schem_encoder(x)
        y = self.graph_encoder(y)
        z = self.tab_encoder(z)

        # NOTE(review): this returns the raw learnable latents broadcast over the
        # batch, not the `fused_latents` computed inside `fusion`, so the fourth
        # output does not depend on x/y/z. Confirm this is intended.
        return x, y, z, self.latents.expand(BS, -1, -1)
        
#####################################################################################################################################################
#####################################################################################################################################################
#####################################################################################################################################################

### Unimodal Encoder Architectures

In [40]:
#Schematic (CNN)
class SchemEncoder(nn.Module):
    """CNN encoder mapping a schematic image to a ``(fmax, 12)`` output.

    Five conv / ReLU / 2x2 max-pool stages are followed by a dense layer to a
    512-wide feature vector (``self.cnn``), then a tanh-activated dense head
    (``self.fc``) reshaped to ``(fmax, 12)`` (``self.reshape``).

    Args:
        fmax: Size of the first output axis; the head emits ``fmax * 12`` values.
        base_channels: Channel count of the first conv stage. Later stages use
            2x and 4x this value. When ``None`` (the default), the notebook-level
            ``n_channel`` is used, so existing ``SchemEncoder(fmax)`` calls keep
            working.

    The dense layer after ``Flatten`` expects a 7x7 feature map, i.e. a 224x224
    input image after five halvings.
    """

    def __init__(self, fmax, base_channels=None):
        super(SchemEncoder, self).__init__()

        if base_channels is None:
            # Fall back to the notebook configuration cell.
            base_channels = n_channel
        c1, c2, c4 = base_channels, base_channels * 2, base_channels * 4

        # CNN trunk; shape comments are (H, W, channels) for a 224x224x3 input.
        # Layer order and indices are kept fixed so saved state_dicts still load.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, c1, kernel_size=3, stride=1, padding=1),   # (224, 224, c1)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # (112, 112, c1)
            nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1),  # (112, 112, c2)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # (56, 56, c2)
            nn.Conv2d(c2, c4, kernel_size=3, stride=1, padding=1),  # (56, 56, c4)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # (28, 28, c4)
            nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1),  # (28, 28, c4)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # (14, 14, c4)
            nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1),  # (14, 14, c4)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # (7, 7, c4)
            nn.Flatten(),                                           # (7 * 7 * c4,)
            nn.Linear(7 * 7 * c4, 512),  # Fully connected layer
            nn.ReLU(),
        )

        # Output head: 512 features -> fmax * 12 values in [-1, 1]
        self.fc = nn.Sequential(
            nn.Linear(512, fmax * 12),
            nn.Tanh(),
        )
        self.reshape = nn.Unflatten(1, (fmax, 12))  # Reshape to (fmax, 12)

    def forward(self, x1):
        """Encode a batch of images.

        Args:
            x1: Image batch, ``(batch, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(batch, fmax, 12)``.
        """
        # CNN trunk -> 512-wide features
        x = self.cnn(x1)
        # Dense head and reshape
        x = self.fc(x)
        x = self.reshape(x)
        return x


In [41]:
#Graph (Graph Transformer)

def get_activation(activation_str):
    """Return the ``torch.nn.functional`` activation named by ``activation_str``.

    The lookup is case-insensitive. Supported names are ``"relu"`` and
    ``"gelu"``; anything else raises ``ValueError``.
    """
    key = activation_str.lower()
    # Extend this table to support further activations.
    known = {"relu": F.relu, "gelu": F.gelu}
    if key not in known:
        raise ValueError(f"Unsupported activation: {activation_str}")
    return known[key]

class TransformerEncoderBlock(nn.Module):
    """
    One Transformer encoder block:
      1) Multi-head self-attention (with dropout)
      2) Feed-forward network (two linear layers + activation + dropout)
      3) LayerNorm + residual connections
      4) "pre-norm" (norm_first=True) or "post-norm" (norm_first=False)

    Args:
      hidden_size: Input/output feature width; must be divisible by
          ``num_attention_heads`` (a requirement of ``nn.MultiheadAttention``).
      num_attention_heads: Number of attention heads.
      intermediate_size: Width of the feed-forward hidden layer.
      activation: Name of the feed-forward activation, resolved by ``get_activation``.
      dropout_rate: Dropout applied to the output of each sub-block.
      attention_dropout_rate: Dropout applied to the attention weights.
      use_bias: Whether the two feed-forward linear layers have a bias.
      norm_first: Pre-norm if True, post-norm if False.
      norm_epsilon: Epsilon of both layer norms.
      intermediate_dropout: Dropout applied after the feed-forward activation.
    """
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        intermediate_size,
        activation="relu", #check with houbo which activation is the best for GT block performance
        dropout_rate=0.0,
        attention_dropout_rate=0.0,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        intermediate_dropout=0.0
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout_rate = dropout_rate
        self.attention_dropout_rate = attention_dropout_rate
        self.use_bias = use_bias
        self.norm_first = norm_first
        self.norm_epsilon = norm_epsilon
        self.intermediate_dropout = intermediate_dropout

        # ---- Self-Attention ----
        # Built sequence-first ([seq_len, batch_size, embed_dim]); forward()
        # transposes around the call.
        # `use_bias` is NOT forwarded here, so the attention projections keep
        # their default biases; it only affects the feed-forward linears below.
        # NOTE(review): nn.MultiheadAttention does take a `bias` argument, but
        # passing `bias=use_bias` would change the parameter set and therefore
        # compatibility with previously saved weights — decide deliberately.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=attention_dropout_rate,
            batch_first=False,  # We'll reshape manually
        )

        self.attention_dropout = nn.Dropout(dropout_rate)
        self.attention_layer_norm = nn.LayerNorm(hidden_size, eps=norm_epsilon)

        # ---- Feed-Forward Network (FFN) ----
        self.intermediate_dense = nn.Linear(hidden_size, intermediate_size, bias=use_bias)
        self.intermediate_act_fn = get_activation(activation)
        self.intermediate_dropout_layer = nn.Dropout(intermediate_dropout)

        self.output_dense = nn.Linear(intermediate_size, hidden_size, bias=use_bias)
        self.output_dropout = nn.Dropout(dropout_rate)
        self.output_layer_norm = nn.LayerNorm(hidden_size, eps=norm_epsilon)

    def _to_additive_mask(self, attention_mask, batch_size):
        """Convert a keep-mask (1 = attend, 0 = masked) into the additive float
        mask expected by ``nn.MultiheadAttention``, or return None."""
        if attention_mask is None:
            return None
        if attention_mask.dtype == torch.bool:
            # `1.0 - mask` is not defined for bool tensors; True means "keep".
            attention_mask = attention_mask.float()
        # 0 where attention is allowed, a large negative number where it is not.
        additive = (1.0 - attention_mask) * -1e9
        if (
            additive.dim() == 3
            and additive.shape[0] == batch_size
            and self.num_attention_heads > 1
        ):
            # A 3-D attn_mask must have leading size batch * num_heads, with the
            # heads of one batch element adjacent, so repeat each entry per head.
            additive = additive.repeat_interleave(self.num_attention_heads, dim=0)
        return additive

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
          hidden_states: Tensor of shape [batch_size, seq_len, hidden_size].
          attention_mask: Optional keep-mask, 1/True = position may be attended,
              0/False = masked out. Accepted shapes:
                * [seq_len, seq_len] (shared across batch and heads),
                * [batch_size, seq_len, seq_len] (expanded per head here),
                * [batch_size * num_heads, seq_len, seq_len] (used as is).
        Returns:
          hidden_states: Tensor of shape [batch_size, seq_len, hidden_size].
        """

        # --- Self-Attention block ---
        # If norm_first, we layer-norm before attention; otherwise after
        residual = hidden_states
        if self.norm_first:
            hidden_states = self.attention_layer_norm(hidden_states)

        attn_mask_pytorch = self._to_additive_mask(attention_mask, hidden_states.shape[0])

        # [batch, seq, dim] -> [seq, batch, dim] for the sequence-first attention
        hidden_states_t = hidden_states.transpose(0, 1)

        # Apply multi-head attention (attention weights are discarded):
        attn_output, _ = self.self_attention(
            hidden_states_t,   # query
            hidden_states_t,   # key
            hidden_states_t,   # value
            attn_mask=attn_mask_pytorch,
        )

        # Transpose back to [batch, seq, dim]
        attn_output = attn_output.transpose(0, 1)

        attn_output = self.attention_dropout(attn_output)
        # Residual connection
        hidden_states = residual + attn_output

        if not self.norm_first:
            hidden_states = self.attention_layer_norm(hidden_states)

        # --- Feed Forward block ---
        residual = hidden_states
        if self.norm_first:
            hidden_states = self.output_layer_norm(hidden_states)

        # Intermediate (expand) + activation
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.intermediate_dropout_layer(hidden_states)

        # Project back to hidden_size
        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.output_dropout(hidden_states)

        # Residual connection
        hidden_states = residual + hidden_states

        if not self.norm_first:
            hidden_states = self.output_layer_norm(hidden_states)

        return hidden_states


class TransformerEncoder(nn.Module):
    """
    A stack of ``num_layers`` TransformerEncoderBlock modules followed by one
    final LayerNorm (``output_normalization``).
    """
    def __init__(
        self,
        num_layers=6,
        num_attention_heads=8,
        intermediate_size=2048,
        activation="relu",
        dropout_rate=0.0,
        attention_dropout_rate=0.0,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        intermediate_dropout=0.0,
        hidden_size=None,
    ):
        """
        Args:
          num_layers: How many encoder blocks to stack.
          num_attention_heads: Attention heads per block.
          intermediate_size: Hidden width of each block's feed-forward network.
          activation: Name of the feed-forward activation.
          dropout_rate: Dropout on the output of each sub-layer.
          attention_dropout_rate: Dropout on the attention weights.
          use_bias: Passed to each block as ``use_bias``.
          norm_first: Pre-norm when True, post-norm when False.
          norm_epsilon: Epsilon used by every layer norm.
          intermediate_dropout: Dropout inside the feed-forward network.
          hidden_size: Input/output feature width. Required; ``None`` raises
              ``ValueError``.
        """
        super().__init__()

        # Keep the configuration on the instance for later inspection.
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.attention_dropout_rate = attention_dropout_rate
        self.use_bias = use_bias
        self.norm_first = norm_first
        self.norm_epsilon = norm_epsilon
        self.intermediate_dropout = intermediate_dropout

        # The width is not inferred from data, so it has to be given up front.
        if hidden_size is None:
            raise ValueError(
                "You must specify 'hidden_size' (the input feature dimension)."
            )

        # Every block is built from the same settings.
        block_kwargs = dict(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            activation=activation,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            use_bias=use_bias,
            norm_first=norm_first,
            norm_epsilon=norm_epsilon,
            intermediate_dropout=intermediate_dropout,
        )
        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerEncoderBlock(**block_kwargs))
        self.encoder_layers = nn.ModuleList(blocks)

        self.output_normalization = nn.LayerNorm(hidden_size, eps=norm_epsilon)

    def forward(self, encoder_inputs, attention_mask=None):
        """
        Args:
          encoder_inputs: shape [batch_size, seq_len, hidden_size].
          attention_mask: shape [batch_size, seq_len, seq_len] or None; handed
              unchanged to every block.
        Returns:
          output shape [batch_size, seq_len, hidden_size].
        """
        states = encoder_inputs
        for block in self.encoder_layers:
            states = block(states, attention_mask=attention_mask)
        # Final layer normalization over the feature axis.
        return self.output_normalization(states)


#################################
# Main Model
#################################
class GraphEncoder(nn.Module):
    """Graph branch: two 7-feature sequences are projected, passed through a
    shared transformer encoder, concatenated, and mapped by an MLP to a
    ``(fmax, 12)`` output.

    Args:
        transformer_encoder: Module called as ``transformer_encoder(x, mask)``;
            the same instance is used for both inputs.
        fmax: Size of the first output axis; the head emits ``fmax * 12`` values.
        feat_dim: Width of the per-token projection. Defaults to the
            notebook-level ``fdim`` when ``None``.
        seq_len: Number of tokens per input sequence. Defaults to the
            notebook-level ``max_len`` when ``None``.
    """

    def __init__(self, transformer_encoder, fmax, feat_dim=None, seq_len=None):
        super(GraphEncoder, self).__init__()
        # Fall back to the notebook configuration cell so existing
        # GraphEncoder(encoder, fmax) calls behave as before.
        if feat_dim is None:
            feat_dim = fdim
        if seq_len is None:
            seq_len = max_len

        self.transformer_encoder = transformer_encoder  # Use the pre-defined transformer model
        self.fcl1 = nn.Linear(7, feat_dim)
        self.fcl2 = nn.Linear(7, feat_dim)
        self.flatten = nn.Flatten()
        # Both encoded sequences are concatenated along the token axis and
        # flattened, hence 2 * seq_len * feat_dim input features.
        self.fc1 = nn.Linear(seq_len * feat_dim * 2, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 512)
        self.fc5 = nn.Linear(512, 512)
        self.out = nn.Linear(512, fmax * 12)
        # A module (not a lambda) so it shows up in repr, pickles with the model,
        # and matches SchemEncoder. It has no parameters, so state_dicts are unchanged.
        self.reshape = nn.Unflatten(1, (fmax, 12))

    def forward(self, inp1, inp2):
        """Encode a pair of graph feature sequences.

        Args:
            inp1: First sequence, ``(batch, seq_len, 7)``.
            inp2: Second sequence, ``(batch, seq_len, 7)``.

        Returns:
            Tensor of shape ``(batch, fmax, 12)`` with values in [-1, 1].
        """
        l1 = self.fcl1(inp1)
        l2 = self.fcl2(inp2)
        # NOTE(review): create_padding_mask is not defined in the visible part
        # of the notebook; it must exist before forward() is called.
        l1 = self.transformer_encoder(l1, create_padding_mask(inp1))
        l2 = self.transformer_encoder(l2, create_padding_mask(inp2))

        out = torch.cat((l1, l2), dim=1)  # join along the token axis
        out = self.flatten(out)
        out = torch.relu(self.fc1(out))
        out = torch.relu(self.fc2(out))
        out = torch.relu(self.fc3(out))
        out = torch.relu(self.fc4(out))
        out = torch.relu(self.fc5(out))
        out = torch.tanh(self.out(out))
        out = self.reshape(out)
        return out


In [42]:
#Tabular (MLP)
# PyTorch model equivalent to Keras Sequential model
class TabEncoder(nn.Module):
    """MLP for tabular input: four ReLU dense layers of width 512 followed by a
    tanh-activated output layer, reshaped to ``(fband, 12)``.

    Args:
        fband: Size of the first output axis; the head emits ``fband * 12`` values.
        input_size: Number of input features per sample.
    """

    def __init__(self, fband, input_size=16):
        super().__init__()
        self.fband = fband
        # Hidden stack
        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 512)
        # Output head
        self.out = nn.Linear(512, fband * 12)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map ``(batch, input_size)`` features to a ``(batch, fband, 12)`` tensor."""
        # Layers are looked up on each call so replaced submodules take effect.
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = torch.relu(dense(x))
        activated = self.tanh(self.out(x))
        return activated.view(-1, self.fband, 12)




### Fusion Model

In [57]:
import torch.optim as optim

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Frequency range
fstart = 0
fmax = 100

# Band layout: fband is the number of frequency points per band.
nband = 1
overlap = 1
fband = int(fmax * overlap / nband)
bandslice = int(fband / overlap)

# Model sizes read as globals by the encoder classes.
n_channel = 16  # channel count of SchemEncoder's first conv stage
fdim = 32       # GraphEncoder projection width / TransformerEncoder hidden_size
max_len = 7     # sequence-length factor in GraphEncoder's first dense layer (max_len * fdim * 2)

# Pretrained weights for the unimodal encoders, keyed by modality.
models_dir = '/home/ch106/Desktop/ASP_DAC2026/MxN/models'
pretrained_paths = {
    key: f"{models_dir}/{relative}"
    for key, relative in (
        ('schem_weights', 'cnn_1band/0.pth'),
        ('graph_weights', 'GT_1band/0.pth'),
        ('tab_weights', 'mlp_1band/0.pth'),
    )
}

Using device: cuda


In [44]:
#Helper function to freeze parameters for unimodal heads (non-joint training scheme)

def freeze(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for weight in model.parameters():
        weight.requires_grad_(False)

In [59]:
class MultimodalModel(nn.Module):
    """Fusion model over three unimodal encoders (schematic CNN, graph
    transformer, tabular MLP).

    The unimodal heads are optionally initialised from saved weights, stripped
    of their output layers and frozen. Their features are then passed through a
    stack of ``BottleneckLayer`` fusion blocks, and an MLP "spectral head" maps
    the bottleneck latents to the output.

    Args:
        pretrained_paths: Optional dict with any of the keys ``'schem_weights'``,
            ``'graph_weights'``, ``'tab_weights'`` mapping to state_dict files.
        num_latents: Number of latent tokens in every fusion block.
        dim: Currently unused; the fusion width is fixed at 512.

    NOTE(review): ``forward`` cannot run end-to-end yet with the other classes
    as defined in this notebook:
      * ``self.v2(y)`` passes one tensor, but ``GraphEncoder.forward`` takes
        ``(inp1, inp2)``.
      * ``TabEncoder.forward`` still ends with ``view(-1, fband, 12)``, which
        does not fit 512-wide features once ``out`` is an ``Identity``.
      * The stripped heads return 2-D ``(batch, 512)`` features, while
        ``BottleneckLayer.attention`` unpacks a 3-D ``q.shape``.
      * The spectral head runs on every latent token, so the final ``view``
        folds the latent axis into the batch axis.
    """

    def __init__(self, pretrained_paths=None, num_latents=4, dim=512):
        super(MultimodalModel, self).__init__()

        # unimodal heads
        self.v1 = SchemEncoder(fband)  # for schematic
        self.v2 = GraphEncoder(TransformerEncoder(intermediate_size=512, hidden_size=fdim), fmax=fband)  # for graph
        self.v3 = TabEncoder(fband)  # for tabular

        if pretrained_paths:  # load pretrained weights
            heads = (
                ('schem_weights', self.v1),
                ('graph_weights', self.v2),
                ('tab_weights', self.v3),
            )
            for key, head in heads:
                if key in pretrained_paths:
                    # `strict` belongs to load_state_dict, not torch.load.
                    # map_location="cpu" lets GPU-saved checkpoints load on any
                    # machine; load_state_dict copies into the existing parameters.
                    state = torch.load(pretrained_paths[key], map_location="cpu")
                    head.load_state_dict(state, strict=False)

        # Discard the unimodal output layers so each head yields its features.
        self.v1.fc = nn.Identity()
        self.v1.reshape = nn.Identity()

        self.v2.out = nn.Identity()
        self.v2.reshape = nn.Identity()

        self.v3.out = nn.Identity()
        self.v3.tanh = nn.Identity()

        # Freeze parameters (comment out for joint training scheme)
        freeze(self.v1)
        freeze(self.v2)
        freeze(self.v3)

        # Auxiliary unimodal transformers for the fusion layers.
        # batch_first=True because BottleneckLayer treats dim 0 as the batch
        # (tokens are concatenated along dim 1); the default would read dim 0
        # as the sequence axis.
        # nn.TransformerEncoder deep-copies the layer, so the three stacks do
        # not share weights.
        encoder_base = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        fusion_layers = 4

        self.schem_aux = nn.TransformerEncoder(encoder_base, num_layers=fusion_layers)
        self.graph_aux = nn.TransformerEncoder(encoder_base, num_layers=fusion_layers)
        self.tab_aux = nn.TransformerEncoder(encoder_base, num_layers=fusion_layers)

        # Fusion encoder: block i wraps layer i of each auxiliary transformer.
        encoder_layers = []
        for i in range(fusion_layers):

            # Vanilla Transformer Encoder (use for full fine tuning)
            encoder_layers.append(
                BottleneckLayer(
                    num_latents=num_latents,
                    schem_enc=self.schem_aux.layers[i],
                    graph_enc=self.graph_aux.layers[i],
                    tab_enc=self.tab_aux.layers[i],
                )
            )

            # Frozen Transformer Encoder with AdaptFormer
            #encoder_layers.append(AdaptFormer(num_latents=num_latents, dim=dim, schem_enc=self.schem_aux.blocks[i], graph_enc=self.graph_aux.blocks[i], tab_enc=self.tab_aux.blocks[i]))

        # Used as an iterable container only; forward_encoder loops over it.
        self.fusion_blocks = nn.Sequential(*encoder_layers)

        #add normalization of bottlenecks maybe?

        # spectral head
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 512)
        # Bind the output size once so `out` and `reshape` always agree, even if
        # the notebook-level `fmax` is reassigned later.
        n_freq = fmax
        self.out = nn.Linear(512, n_freq * 12)
        self.reshape = lambda x: x.view(-1, n_freq, 12)

    def forward_encoder(self, x, y, z):
        """Run the three token streams through every fusion block.

        Returns:
            ``(x, y, z, bottlenecks)`` as produced by the last block.
        """
        for blk in self.fusion_blocks:
            x, y, z, bottlenecks = blk(x, y, z)
        return x, y, z, bottlenecks

    def forward(self, x, y, z):
        """Encode the three modalities, fuse them, and apply the spectral head.

        Args:
            x: Schematic input for ``self.v1``.
            y: Graph input for ``self.v2``.
            z: Tabular input for ``self.v3``.

        Returns:
            Tensor whose last two axes are ``(fmax, 12)``.
        """
        #unimodal heads
        x = self.v1(x)
        y = self.v2(y)
        z = self.v3(z)

        #fusion encoders
        x, y, z, bottlenecks = self.forward_encoder(x, y, z) #try unified transformer after this in future w/ concatenated tokens

        #spectral head
        out = torch.relu(self.fc1(bottlenecks))
        out = torch.relu(self.fc2(out))
        out = torch.relu(self.fc3(out))
        out = torch.relu(self.fc4(out))
        out = torch.tanh(self.out(out))
        out = self.reshape(out)
        return out

In [61]:
# Build the fusion model without pretrained weights and print its module tree.
test = MultimodalModel(pretrained_paths=None)
print(test)

MultimodalModel(
  (v1): SchemEncoder(
    (cnn): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU()
      (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (9): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (10): ReLU()
      (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU()
      (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (15): Flatten(start_dim