<a href="https://colab.research.google.com/github/foxtrotmike/CS909/blob/master/multilinear.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
import torch.nn as nn
import math
import torch
import torch.nn as nn

class AdditiveMultiplicativeLayer(nn.Module):
    """Layer summing a multilinear (multiplicative) term and an affine (additive) term.

    For ``k`` inputs ``x_1, ..., x_k`` whose last dimensions are
    ``input_sizes[0], ..., input_sizes[k-1]``, each output feature ``l`` is::

        sum_{i_1..i_k} x_1[i_1] * ... * x_k[i_k] * weight_mul[i_1, ..., i_k, l]
        + (concat(x_1, ..., x_k) @ weight_add)[l] + bias[l]

    Leading (batch) dimensions are carried through via einsum's ``...`` and
    ``torch.matmul``.

    Args:
        input_sizes: Feature size of each input, in the order the inputs are
            passed to ``forward``. Any number (at least one) of inputs is supported.
        output_size: Number of output features.

    Raises:
        ValueError: If ``input_sizes`` is empty, or has more entries than
            einsum has subscript letters for.
    """

    # Subscript letters accepted by torch.einsum: one per input plus one for the output.
    _EINSUM_LETTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

    def __init__(self, input_sizes, output_size):
        super().__init__()
        self.input_sizes = tuple(input_sizes)
        if not self.input_sizes:
            raise ValueError("input_sizes must contain at least one size")
        if len(self.input_sizes) >= len(self._EINSUM_LETTERS):
            raise ValueError(
                f"at most {len(self._EINSUM_LETTERS) - 1} inputs are supported, "
                f"got {len(self.input_sizes)}"
            )

        # Multiplicative weights: one axis per input, followed by the output axis.
        self.weight_mul = nn.Parameter(torch.empty(*self.input_sizes, output_size))

        # Additive weights stored as (in_features, out_features) so forward can use x @ W.
        self.weight_add = nn.Parameter(torch.empty(sum(self.input_sizes), output_size))
        self.bias = nn.Parameter(torch.empty(output_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every parameter from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

        That is the bound ``kaiming_uniform_(a=sqrt(5))`` produces when fan-in
        is read correctly. Fan-in is computed explicitly here because PyTorch's
        fan helper assumes an ``(out, in, ...)`` layout. These weights are stored
        input-axes-first, so the helper would count ``output_size`` as fan-in.
        """
        # Additive term: each output sums over all concatenated input features.
        add_fan_in = max(sum(self.input_sizes), 1)  # guard against zero-width inputs
        add_bound = 1.0 / math.sqrt(add_fan_in)
        nn.init.uniform_(self.weight_add, -add_bound, add_bound)
        nn.init.uniform_(self.bias, -add_bound, add_bound)

        # Multiplicative term: each output sums over prod(input_sizes) products.
        mul_fan_in = max(self.weight_mul.numel() // max(self.weight_mul.shape[-1], 1), 1)
        mul_bound = 1.0 / math.sqrt(mul_fan_in)
        nn.init.uniform_(self.weight_mul, -mul_bound, mul_bound)

    @classmethod
    def _einsum_equation(cls, num_inputs):
        """Build the contraction string, e.g. '...a,...b,...c,abcd->...d' for 3 inputs."""
        in_labels = cls._EINSUM_LETTERS[:num_inputs]
        out_label = cls._EINSUM_LETTERS[num_inputs]
        operands = ",".join("..." + label for label in in_labels)
        return f"{operands},{in_labels}{out_label}->...{out_label}"

    def forward(self, *inputs):
        """Apply the layer.

        Args:
            *inputs: One tensor per entry of ``input_sizes``, in the same order.
                The last dimension of the i-th tensor must equal ``input_sizes[i]``.
                The leading dimensions must match, as ``torch.cat`` requires.

        Returns:
            Tensor with the inputs' leading dimensions and a last dimension of
            ``output_size``.

        Raises:
            ValueError: If the number of inputs differs from ``len(input_sizes)``.
        """
        if len(inputs) != len(self.input_sizes):
            raise ValueError(
                f"expected {len(self.input_sizes)} inputs, got {len(inputs)}"
            )

        # Multiplicative term: contract each input against its own axis of weight_mul.
        equation = self._einsum_equation(len(inputs))
        multiplicative_output = torch.einsum(equation, *inputs, self.weight_mul)

        # Additive term: affine map of the concatenated features.
        concatenated_inputs = torch.cat(inputs, dim=-1)
        additive_output = torch.matmul(concatenated_inputs, self.weight_add) + self.bias

        return multiplicative_output + additive_output

# Example usage: a batch of 10 samples split into three feature groups
# (widths 5, 4 and 6), projected to 7 output features.
x1, x2, x3 = (torch.randn(10, width) for width in (5, 4, 6))

layer = AdditiveMultiplicativeLayer((5, 4, 6), 7)
output = layer(x1, x2, x3)


In [12]:
# Shape of the layer output: (batch, output_size).
output.size()

torch.Size([10, 7])

In [16]:
# Multiplicative weight shape: one axis per input size, then output_size.
layer.weight_mul.size()

torch.Size([5, 4, 6, 7])