In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HyperNetwork(nn.Module):
    """Small MLP that generates the weights of another network.

    Maps an input of size ``input_size`` through one hidden ReLU layer to a
    vector of ``output_size`` values. When used with ``MainNetwork``,
    ``output_size`` must equal the total number of weight elements the main
    network consumes (its ``fc1`` weights followed by its ``fc2`` weights).

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Number of generated values per input row.
        hidden_size: Width of the hidden layer (default 128).
    """

    def __init__(self, input_size, output_size, hidden_size=128):
        super(HyperNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return generated values of shape ``(..., output_size)``.

        There is no activation on the output layer because the values are
        used directly as raw weights.
        """
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class MainNetwork(nn.Module):
    """Bias-free two-layer MLP whose weights are supplied at call time.

    ``fc1`` and ``fc2`` are registered as frozen ``nn.Linear`` layers filled
    with ones, but ``forward`` never reads their weights. It only uses their
    ``in_features``/``out_features`` to slice and reshape the weights passed
    in (e.g. the output of a hypernetwork).

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Size of the last dimension of the output.
        hidden_size: Width of the hidden layer (default 128).
    """

    def __init__(self, input_size, output_size, hidden_size=128):
        super(MainNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, output_size, bias=False)
        self._initialize_weights()

    @property
    def num_weight_params(self):
        """Number of weight elements ``forward`` consumes (fc1 then fc2)."""
        return self.fc1.weight.numel() + self.fc2.weight.numel()

    def _initialize_weights(self):
        """Fill the registered layer weights with ones and freeze them.

        These tensors are placeholders; ``forward`` uses the externally
        supplied weights instead.
        """
        with torch.no_grad():
            self.fc1.weight.fill_(1.0)
            self.fc2.weight.fill_(1.0)
            # Make weights non-trainable
            self.fc1.weight.requires_grad = False
            self.fc2.weight.requires_grad = False

    def forward(self, x, weights):
        """Apply the MLP (linear -> ReLU -> linear) with supplied weights.

        The weight layout is row-major: first ``fc1`` as
        ``(fc1.out_features, fc1.in_features)``, then ``fc2`` as
        ``(fc2.out_features, fc2.in_features)``.

        Args:
            x: Input whose last dimension is ``fc1.in_features``. When
                ``weights`` is per-sample (see below), ``x`` must be
                ``(batch, fc1.in_features)``.
            weights: One of:
                * a 2-D ``(batch, M)`` tensor with ``M >= num_weight_params``.
                  Row ``i`` holds the weights for sample ``x[i]``, which is
                  the raw output of a hypernetwork on a batch.
                * any other shape. It is read as one flat vector that must
                  hold at least ``num_weight_params`` elements and is shared
                  by every row of ``x``. For example, this covers a 1-D
                  vector or an ``(N, 1)`` column.
                In both cases, elements past ``num_weight_params`` are
                ignored.

        Returns:
            Tensor whose last dimension is ``fc2.out_features``.
        """
        in1, out1 = self.fc1.in_features, self.fc1.out_features
        in2, out2 = self.fc2.in_features, self.fc2.out_features
        w1_elements = in1 * out1
        needed = w1_elements + in2 * out2

        if weights.dim() == 2 and weights.shape[1] >= needed:
            # Per-sample weights: batched matrix-vector products.
            w1 = weights[:, :w1_elements].reshape(-1, out1, in1)
            w2 = weights[:, w1_elements:needed].reshape(-1, out2, in2)
            hidden = F.relu(torch.einsum("boi,bi->bo", w1, x))
            return torch.einsum("boi,bi->bo", w2, hidden)

        # Shared weights: one flat vector reshaped into each layer's matrix.
        flat = weights.reshape(-1)
        w1 = flat[:w1_elements].reshape(out1, in1)
        w2 = flat[w1_elements:needed].reshape(out2, in2)

        x = F.linear(x, weight=w1)
        x = F.relu(x)
        return F.linear(x, weight=w2)

# Example usage
input_size = 100
output_size = 10

# Initialize MainNetwork
main_network = MainNetwork(input_size, output_size)

# The hypernetwork must emit one value per weight of the main network's
# bias-free layers (fc1 then fc2). Derive the count from the layers
# themselves instead of repeating the hidden width here, so the two cannot
# drift apart.
hypernetwork_output_size = (
    main_network.fc1.weight.numel() + main_network.fc2.weight.numel()
)

# Initialize HyperNetwork with the matching output size
hypernetwork = HyperNetwork(input_size, hypernetwork_output_size)

# Example input: a batch containing a single sample
x = torch.randn(1, input_size)

# Forward pass: the hypernetwork returns a (1, N) tensor. Drop only the
# leading batch dimension so MainNetwork receives a flat weight vector.
hypernetwork_weights = hypernetwork(x).squeeze(0)
output = main_network(x, hypernetwork_weights)

# Check the output shape (expected: torch.Size([1, 10]))
print(output.shape)


torch.Size([1, 10])
