In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init

In [None]:
#METHODOLOGY 1
# An implementation that combines basenet, learnable epinet, prior epinet

# shape of x is [batch_size, x_dim]; z is [z_dim]
# the input is always exposed to the epinet/prior and the output never is
# exposed_layers must have exactly one entry per hidden layer



class basenet_with_learnable_epinet_and_ensemble_prior(nn.Module):
    """Basenet MLP + learnable epinet + frozen ensemble prior (methodology 1).

    The returned logits are

        base(x) + epinet([h, z]) @ z + alpha * sum_k z[k] * prior_k(h)

    where ``h`` is the concatenation of the exposed basenet activations,
    taken with a stop-gradient (``detach``). The raw input is always exposed
    and the output layer never is.

    Args:
        input_size: Number of input features (``x_dim``).
        basenet_hidden_sizes: Widths of the basenet hidden layers.
        n_classes: Number of output logits.
        exposed_layers: One bool per hidden layer. ``True`` feeds that hidden
            layer's post-ReLU activation into the epinet and the prior.
            Must have the same length as ``basenet_hidden_sizes``.
        z_dim: Length of the index vector ``z``. Also the number of prior
            ensemble members.
        learnable_epinet_hiddens: Widths of the learnable epinet hidden layers.
        hidden_sizes_prior: Widths of each prior MLP's hidden layers.
        seed_base, seed_learnable_epinet, seed_prior_epinet: Seeds passed to
            ``torch.manual_seed`` before initializing each sub-network.
            NOTE: this reseeds torch's *global* RNG as a side effect.
        alpha: Scale applied to the prior ensemble's output.

    Raises:
        ValueError: if ``len(exposed_layers) != len(basenet_hidden_sizes)``.
    """

    def __init__(self, input_size, basenet_hidden_sizes, n_classes, exposed_layers, z_dim, learnable_epinet_hiddens, hidden_sizes_prior, seed_base, seed_learnable_epinet, seed_prior_epinet, alpha):
        super(basenet_with_learnable_epinet_and_ensemble_prior, self).__init__()

        # Fail fast: zip() below would otherwise silently truncate a mismatched
        # list and build an epinet whose input width disagrees with forward().
        if len(exposed_layers) != len(basenet_hidden_sizes):
            raise ValueError(
                "exposed_layers must have one entry per hidden layer: got "
                f"{len(exposed_layers)} flags for {len(basenet_hidden_sizes)} hidden layers"
            )

        self.z_dim = z_dim
        self.n_classes = n_classes
        self.num_ensemble = z_dim
        self.alpha = alpha

        # Width of every basenet activation: input, each hidden layer, output.
        basenet_all_sizes = [input_size] + list(basenet_hidden_sizes) + [n_classes]
        self.basenet_all_sizes = basenet_all_sizes
        # Pad the per-hidden-layer flags to one flag per entry of basenet_all_sizes:
        # the input is always exposed, the output never is.
        exposed_layers = [True] + list(exposed_layers) + [False]
        self.exposed_layers = exposed_layers

        torch.manual_seed(seed_base)
        self.basenet_layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(basenet_all_sizes[:-1], basenet_all_sizes[1:])
        )

        # Width of the concatenated exposed activations fed to epinet and prior.
        sum_input_base_epi = sum(
            size for size, exposed in zip(basenet_all_sizes, exposed_layers) if exposed
        )

        # The epinet sees [features, z] and emits one n_classes x z_dim matrix per row.
        learnable_epinet_all_sizes = [sum_input_base_epi + z_dim] + list(learnable_epinet_hiddens) + [n_classes * z_dim]
        self.learnable_epinet_all_sizes = learnable_epinet_all_sizes

        torch.manual_seed(seed_learnable_epinet)
        self.learnable_epinet_layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(learnable_epinet_all_sizes[:-1], learnable_epinet_all_sizes[1:])
        )

        torch.manual_seed(seed_prior_epinet)
        all_sizes_prior = [sum_input_base_epi] + list(hidden_sizes_prior) + [n_classes]
        self.ensemble = nn.ModuleList()
        for _ in range(self.num_ensemble):
            layers = []
            for i in range(len(all_sizes_prior) - 1):
                layer = nn.Linear(all_sizes_prior[i], all_sizes_prior[i + 1])
                init.xavier_uniform_(layer.weight)
                init.zeros_(layer.bias)
                layers.append(layer)
                if i < len(all_sizes_prior) - 2:  # no ReLU after the output layer
                    layers.append(nn.ReLU())

            mlp = nn.Sequential(*layers)
            # The prior is fixed at initialization and never trained.
            for param in mlp.parameters():
                param.requires_grad = False
            self.ensemble.append(mlp)

    def forward(self, x, z):
        """Return logits of shape (batch, n_classes).

        Args:
            x: Inputs of shape (batch, input_size).
            z: Index vector of shape (z_dim,), shared by every row of the batch.
        """
        hidden_outputs = []
        # zip stops after the last Linear, so the trailing False flag (output) is never read.
        for i, (basenet_layer, exposed) in enumerate(zip(self.basenet_layers, self.exposed_layers)):
            if exposed:
                hidden_outputs.append(x)
            x = basenet_layer(x)
            if i < len(self.basenet_layers) - 1:  # no activation on the output layer
                x = torch.relu(x)

        # Stop-gradient: the epinet and the prior treat the basenet features as
        # constants, so gradients only reach the basenet through `x` above. The
        # epinet's own parameters still get gradients. Nothing writes to this
        # tensor in place, so one detach() (without clone) is enough for both consumers.
        features = torch.cat(hidden_outputs, dim=1).detach()
        batch_size = features.size(0)

        # Append the same z to every row of the batch.
        z_repeated = z.unsqueeze(0).expand(batch_size, -1)
        epinet_hidden = torch.cat([features, z_repeated], dim=1)
        for j, layer in enumerate(self.learnable_epinet_layers):
            epinet_hidden = layer(epinet_hidden)
            if j < len(self.learnable_epinet_layers) - 1:  # no activation on the output layer
                epinet_hidden = torch.relu(epinet_hidden)

        # (batch, n_classes * z_dim) -> (batch, n_classes, z_dim), then contract with z.
        epinet_output = torch.matmul(
            epinet_hidden.reshape(batch_size, self.n_classes, self.z_dim), z
        )

        # One (batch, n_classes) output per frozen prior member, mixed by z.
        outputs_prior_tensor = torch.stack([mlp(features) for mlp in self.ensemble], dim=0)
        prior_output = torch.einsum('nbo,n->bo', outputs_prior_tensor, z)

        return x + epinet_output + self.alpha * prior_output




#optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)

# optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.01)



In [None]:
#METHODOLOGY 1
# An implementation that combines basenet, learnable epinet, prior epinet

# shape of x is [batch_size, x_dim]; z is [z_dim]
# the input is always exposed to the epinet/prior and the output never is
# exposed_layers must have exactly one entry per hidden layer



class basenet_with_learnable_epinet_and_ensemble_prior(nn.Module):
    """Debug variant: basenet + learnable epinet + frozen ensemble prior, with a printed trace.

    This cell re-defines the class under the same name, so it shadows the
    previous cell's definition. It prints every intermediate size, module and
    tensor while building and running the model. Pass ``verbose=False`` to
    build and run it silently.

    The returned logits are

        base(x) + epinet([h, z]) @ z + alpha * sum_k z[k] * prior_k(h)

    where ``h`` is the detached concatenation of the exposed basenet
    activations. The raw input is always exposed and the output layer never is.

    Args:
        input_size: Number of input features (``x_dim``).
        basenet_hidden_sizes: Widths of the basenet hidden layers.
        n_classes: Number of output logits.
        exposed_layers: One bool per hidden layer. ``True`` feeds that hidden
            layer's post-ReLU activation into the epinet and the prior.
            Must have the same length as ``basenet_hidden_sizes``.
        z_dim: Length of the index vector ``z``. Also the number of prior
            ensemble members.
        learnable_epinet_hiddens: Widths of the learnable epinet hidden layers.
        hidden_sizes_prior: Widths of each prior MLP's hidden layers.
        seed_base, seed_learnable_epinet, seed_prior_epinet: Seeds passed to
            ``torch.manual_seed`` before initializing each sub-network
            (this reseeds torch's global RNG).
        alpha: Scale applied to the prior ensemble's output.
        verbose: If True (default), print the construction and forward trace.

    Raises:
        ValueError: if ``len(exposed_layers) != len(basenet_hidden_sizes)``.
    """

    def __init__(self, input_size, basenet_hidden_sizes, n_classes, exposed_layers, z_dim, learnable_epinet_hiddens, hidden_sizes_prior, seed_base, seed_learnable_epinet, seed_prior_epinet, alpha, verbose=True):
        super(basenet_with_learnable_epinet_and_ensemble_prior, self).__init__()

        # Fail fast: zip() below would otherwise silently truncate a mismatched
        # list and build an epinet whose input width disagrees with forward().
        if len(exposed_layers) != len(basenet_hidden_sizes):
            raise ValueError(
                "exposed_layers must have one entry per hidden layer: got "
                f"{len(exposed_layers)} flags for {len(basenet_hidden_sizes)} hidden layers"
            )

        self.verbose = verbose
        self.z_dim = z_dim
        self.n_classes = n_classes
        self.num_ensemble = z_dim
        self.alpha = alpha

        # Width of every basenet activation: input, each hidden layer, output.
        basenet_all_sizes = [input_size] + list(basenet_hidden_sizes) + [n_classes]
        self._trace("basenet_all_sizes:", basenet_all_sizes)
        self.basenet_all_sizes = basenet_all_sizes
        # The input is always exposed and the output never is.
        exposed_layers = [True] + list(exposed_layers) + [False]
        self._trace("exposed_layers:", exposed_layers)
        self.exposed_layers = exposed_layers

        torch.manual_seed(seed_base)
        self.basenet_layers = nn.ModuleList()
        for i in range(len(basenet_all_sizes) - 1):
            self.basenet_layers.append(nn.Linear(basenet_all_sizes[i], basenet_all_sizes[i + 1]))
        self._trace("basenet_layers:", self.basenet_layers)

        sum_input_base_epi = sum(
            size for size, exposed in zip(basenet_all_sizes, exposed_layers) if exposed
        )
        self._trace("sum_input_base_epi:", sum_input_base_epi)
        learnable_epinet_all_sizes = [sum_input_base_epi + z_dim] + list(learnable_epinet_hiddens) + [n_classes * z_dim]
        self._trace("learnable_epinet_all_sizes:", learnable_epinet_all_sizes)
        self.learnable_epinet_all_sizes = learnable_epinet_all_sizes

        torch.manual_seed(seed_learnable_epinet)
        self.learnable_epinet_layers = nn.ModuleList()
        for j in range(len(learnable_epinet_all_sizes) - 1):
            self.learnable_epinet_layers.append(nn.Linear(learnable_epinet_all_sizes[j], learnable_epinet_all_sizes[j + 1]))
        self._trace("learnable_epinet_layers:", self.learnable_epinet_layers)

        torch.manual_seed(seed_prior_epinet)
        all_sizes_prior = [sum_input_base_epi] + list(hidden_sizes_prior) + [n_classes]
        self.ensemble = nn.ModuleList()
        for _ in range(self.num_ensemble):
            layers = []
            for i in range(len(all_sizes_prior) - 1):
                layer = nn.Linear(all_sizes_prior[i], all_sizes_prior[i + 1])
                init.xavier_uniform_(layer.weight)
                init.zeros_(layer.bias)
                layers.append(layer)
                if i < len(all_sizes_prior) - 2:  # no ReLU after the output layer
                    layers.append(nn.ReLU())

            mlp = nn.Sequential(*layers)
            # The prior is fixed at initialization and never trained.
            for param in mlp.parameters():
                param.requires_grad = False
            self.ensemble.append(mlp)

        self._trace("ensemble:", self.ensemble)

    def _trace(self, *args):
        """Print ``args`` (like ``print``) only when ``self.verbose`` is set."""
        if self.verbose:
            print(*args)

    def forward(self, x, z):
        """Return logits of shape (batch, n_classes), printing each intermediate when verbose.

        Args:
            x: Inputs of shape (batch, input_size).
            z: Index vector of shape (z_dim,), shared by every row of the batch.
        """
        hidden_outputs = []
        self._trace("x:", x)

        for i, (basenet_layer, flag) in enumerate(zip(self.basenet_layers, self.exposed_layers)):
            if flag:
                hidden_outputs.append(x)
                self._trace("hidden_outputs:", hidden_outputs)

            x = basenet_layer(x)
            self._trace("x:", x)
            if i < len(self.basenet_layers) - 1:  # no activation on the output layer
                x = torch.relu(x)
                self._trace("x:", x)

        concatenate_hidden = torch.cat(hidden_outputs, dim=1)
        self._trace("concatenate_hidden:", concatenate_hidden)
        # Stop-gradient: the epinet and the prior treat the basenet features as
        # constants, so gradients only reach the basenet through `x`. The epinet's
        # own parameters still get gradients. Nothing writes to the detached
        # tensor in place, so a single detach() (without clone) serves both consumers.
        detached_concatenate_hidden = concatenate_hidden.detach()
        self._trace("detached_concatenate_hidden:", detached_concatenate_hidden)
        detached_concatenate_hidden_to_prior = detached_concatenate_hidden
        self._trace("detached_concatenate_hidden_to_prior :", detached_concatenate_hidden_to_prior)

        z_repeated = z.unsqueeze(0).repeat(detached_concatenate_hidden.size(0), 1)
        self._trace("z_repeated:", z_repeated)
        combined_output = torch.cat([detached_concatenate_hidden, z_repeated], dim=1)
        self._trace("combined_output:", combined_output)

        for j, learnable_epinet_layer in enumerate(self.learnable_epinet_layers):
            combined_output = learnable_epinet_layer(combined_output)
            self._trace("combined_output:", combined_output)
            if j < len(self.learnable_epinet_layers) - 1:  # no activation on the output layer
                combined_output = torch.relu(combined_output)
                self._trace("combined_output:", combined_output)

        self._trace("intermediary_check, x:", x)
        self._trace("intermediary_check, concatenate_hidden:", concatenate_hidden)
        self._trace("intermediary_check, detached_concatenate_hidden:", detached_concatenate_hidden)
        self._trace("intermediary_check, detached_concatenate_hidden_to_prior:", detached_concatenate_hidden_to_prior)
        self._trace("detached_concatenate_hidden_to_prior:", detached_concatenate_hidden_to_prior)

        # (batch, n_classes * z_dim) -> (batch, n_classes, z_dim), then contract with z.
        reshaped_epinet_output = torch.reshape(combined_output, (combined_output.shape[0], self.n_classes, self.z_dim))
        self._trace("reshaped_epinet_output:", reshaped_epinet_output)
        epinet_output = torch.matmul(reshaped_epinet_output, z)
        self._trace("epinet_output:", epinet_output)

        outputs_prior = [mlp(detached_concatenate_hidden_to_prior) for mlp in self.ensemble]
        self._trace("outputs_prior:", outputs_prior)
        outputs_prior_tensor = torch.stack(outputs_prior, dim=0)
        self._trace("outputs_prior_tensor:", outputs_prior_tensor)
        # Mix the frozen prior members' outputs by z.
        prior_output = torch.einsum('nbo,n->bo', outputs_prior_tensor, z)
        self._trace("prior_output:", prior_output)
        final_output = x + epinet_output + self.alpha * prior_output
        self._trace("final_output:", final_output)

        return final_output




#optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)

# optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.01)



In [None]:
# Scratch value only; the model class builds its own local list with this name.
basenet_all_sizes = 3
print(f"basenet_all_sizes: {basenet_all_sizes}")

basenet_all_sizes: 3


In [None]:
# --- basenet architecture ---
input_size = 3
basenet_hidden_sizes = [5, 5]
n_classes = 2
# one flag per hidden layer: True exposes that layer's activations to the epinet/prior
exposed_layers = [False, True]

# --- epinet / prior architecture ---
z_dim = 3  # length of the index vector z (also the number of prior ensemble members)
learnable_epinet_hiddens = [8, 8]
hidden_sizes_prior = [2, 2]

# --- seeds and prior scale ---
seed_base, seed_learnable_epinet, seed_prior_epinet = 2, 1, 0
alpha = 0.1

In [None]:
model = basenet_with_learnable_epinet_and_ensemble_prior(
    input_size=input_size,
    basenet_hidden_sizes=basenet_hidden_sizes,
    n_classes=n_classes,
    exposed_layers=exposed_layers,
    z_dim=z_dim,
    learnable_epinet_hiddens=learnable_epinet_hiddens,
    hidden_sizes_prior=hidden_sizes_prior,
    seed_base=seed_base,
    seed_learnable_epinet=seed_learnable_epinet,
    seed_prior_epinet=seed_prior_epinet,
    alpha=alpha,
)

basenet_all_sizes: [3, 5, 5, 2]
exposed_layers: [True, False, True, False]
basenet_layers: ModuleList(
  (0): Linear(in_features=3, out_features=5, bias=True)
  (1): Linear(in_features=5, out_features=5, bias=True)
  (2): Linear(in_features=5, out_features=2, bias=True)
)
sum_input_base_epi: 8
learnable_epinet_all_sizes: [11, 8, 8, 6]
learnable_epinet_layers: ModuleList(
  (0): Linear(in_features=11, out_features=8, bias=True)
  (1): Linear(in_features=8, out_features=8, bias=True)
  (2): Linear(in_features=8, out_features=6, bias=True)
)
ensemble: ModuleList(
  (0-2): 3 x Sequential(
    (0): Linear(in_features=8, out_features=2, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2, out_features=2, bias=True)
    (3): ReLU()
    (4): Linear(in_features=2, out_features=2, bias=True)
  )
)


In [None]:
# Toy input batch: 5 samples with input_size features each (tied to the config instead of a hard-coded 3).
x = torch.randn(5, input_size)

In [None]:
# Show the sampled input batch.
x

tensor([[-2.1788,  0.5684, -1.0845],
        [-1.3986,  0.4033,  0.8380],
        [-0.7193, -0.4033, -0.5966],
        [ 0.1820, -0.8567,  1.1006],
        [-1.0712,  0.1227, -0.5663]])

In [None]:
# Index vector z; its length must equal z_dim (tied to the config instead of a hard-coded 3).
z = torch.randn(z_dim)

In [None]:
# Show the sampled index vector z.
z

tensor([ 0.3731, -0.8920, -1.5091])

In [None]:
# Add a leading axis: shape (3,) -> (1, 3).
z.unsqueeze(0)

tensor([[ 0.3731, -0.8920, -1.5091]])

In [None]:
 z_repeated = z.repeat(x.size(0), 1)  # one copy of z per row of x: shape (batch, len(z))

In [None]:
# Show z tiled along the batch axis (one identical row per input sample).
z_repeated

tensor([[ 0.3731, -0.8920, -1.5091],
        [ 0.3731, -0.8920, -1.5091],
        [ 0.3731, -0.8920, -1.5091],
        [ 0.3731, -0.8920, -1.5091],
        [ 0.3731, -0.8920, -1.5091]])

In [None]:
# Append z to each input row as extra feature columns.
combined_output = torch.cat((x, z_repeated), dim=-1)

In [None]:
# Show the input rows with z appended as extra columns.
combined_output

tensor([[-2.1788,  0.5684, -1.0845,  0.3731, -0.8920, -1.5091],
        [-1.3986,  0.4033,  0.8380,  0.3731, -0.8920, -1.5091],
        [-0.7193, -0.4033, -0.5966,  0.3731, -0.8920, -1.5091],
        [ 0.1820, -0.8567,  1.1006,  0.3731, -0.8920, -1.5091],
        [-1.0712,  0.1227, -0.5663,  0.3731, -0.8920, -1.5091]])

In [None]:
# Forward pass (prints the debug trace). The original referenced `model_2`,
# which is only created further down, so it failed on Restart & Run All;
# use the `model` constructed earlier in the notebook instead.
model(x, z)

x: tensor([[-2.1788,  0.5684, -1.0845],
        [-1.3986,  0.4033,  0.8380],
        [-0.7193, -0.4033, -0.5966],
        [ 0.1820, -0.8567,  1.1006],
        [-1.0712,  0.1227, -0.5663]])
hidden_outputs: [tensor([[-2.1788,  0.5684, -1.0845],
        [-1.3986,  0.4033,  0.8380],
        [-0.7193, -0.4033, -0.5966],
        [ 0.1820, -0.8567,  1.1006],
        [-1.0712,  0.1227, -0.5663]])]
x: tensor([[ 0.7222,  0.8597,  0.6871, -0.7004,  0.9911],
        [-0.2459,  0.8626,  0.5041, -0.7706,  0.1663],
        [ 0.1831,  0.5312,  0.2006, -0.3632,  0.4415],
        [-0.7677,  0.5118, -0.1042, -0.3390, -0.2862],
        [ 0.3331,  0.5684,  0.4439, -0.5122,  0.4273]],
       grad_fn=<AddmmBackward0>)
x: tensor([[0.7222, 0.8597, 0.6871, 0.0000, 0.9911],
        [0.0000, 0.8626, 0.5041, 0.0000, 0.1663],
        [0.1831, 0.5312, 0.2006, 0.0000, 0.4415],
        [0.0000, 0.5118, 0.0000, 0.0000, 0.0000],
        [0.3331, 0.5684, 0.4439, 0.0000, 0.4273]], grad_fn=<ReluBackward0>)
x: tensor([[ 0.0

tensor([[0.2567, 0.7528],
        [0.7295, 0.4753],
        [0.6009, 0.5719],
        [0.7760, 0.4688],
        [0.4677, 0.6300]], grad_fn=<AddBackward0>)

In [None]:
# Print only trainable parameters; frozen ones (requires_grad=False) are skipped.
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.data)

basenet_layers.0.weight tensor([[ 0.1324, -0.1374,  0.1583],
        [-0.0295,  0.2466,  0.1375],
        [-0.0664, -0.4668,  0.1318],
        [-0.5112,  0.0759,  0.0384],
        [-0.1270,  0.4721,  0.0385]])
basenet_layers.0.bias tensor([ 0.2394,  0.2443, -0.3406, -0.2220,  0.5552])
basenet_layers.1.weight tensor([[-0.4380, -0.0304, -0.0354,  0.3172, -0.0425],
        [ 0.1178, -0.0215, -0.2504, -0.2535, -0.2173],
        [-0.4063, -0.2902,  0.1053,  0.2943,  0.0220],
        [-0.2050,  0.1965, -0.1717, -0.0991, -0.2452],
        [-0.1404, -0.4144,  0.1908,  0.1739,  0.0888]])
basenet_layers.1.bias tensor([0.2196, 0.1895, 0.0198, 0.0474, 0.0341])
basenet_layers.2.weight tensor([[ 0.2386,  0.3004,  0.3212,  0.2592, -0.1090],
        [-0.0199, -0.0909,  0.2602,  0.0497,  0.4139]])
basenet_layers.2.bias tensor([ 0.2268, -0.3822])
learnable_epinet_layers.0.weight tensor([[ 0.1554, -0.1331, -0.0585,  0.1415, -0.2839,  0.1808, -0.0620,  0.1534,
          0.0419, -0.0369,  0.0836],
        

In [None]:
# Print every parameter of `model`, trainable or frozen.
for name, param in model.named_parameters():
    print(name, param.data)

basenet_layers.0.weight tensor([[-0.5181, -0.4985, -0.0903],
        [ 0.0075, -0.2623,  0.2175],
        [-0.5196, -0.0390,  0.5077],
        [-0.2355,  0.5213,  0.2091],
        [-0.5210,  0.3653, -0.0666]])
basenet_layers.0.bias tensor([-0.2577,  0.4617, -0.4666,  0.0620, -0.1209])
basenet_layers.1.weight tensor([[ 0.3194,  0.1248,  0.2149,  0.1579, -0.1075],
        [-0.0941, -0.3685,  0.2423,  0.3551,  0.3060],
        [-0.3155,  0.0199, -0.3153, -0.2462, -0.2606],
        [ 0.1528, -0.2665, -0.0098,  0.0188,  0.2883],
        [-0.3381, -0.3070, -0.2597,  0.3130, -0.1608]])
basenet_layers.1.bias tensor([ 0.3772,  0.1617,  0.0566, -0.0033, -0.0884])
basenet_layers.2.weight tensor([[ 0.0561, -0.1021, -0.0031,  0.0571, -0.3498],
        [-0.2344,  0.3611, -0.3629, -0.0321,  0.4424]])
basenet_layers.2.bias tensor([0.1616, 0.0127])
learnable_epinet_layers.0.weight tensor([[-0.2613,  0.1494, -0.2148, -0.0856, -0.1012, -0.0447,  0.0033,  0.2487,
          0.0376,  0.2701,  0.1844],
     

In [None]:
# Second model with the same architecture and the same three config seeds as
# `model`, so its parameters should match `model`'s exactly. (The original
# passed an undefined `seed_prior` and only 9 of the 11 required arguments.)
model_2 = basenet_with_learnable_epinet_and_ensemble_prior(
    input_size, basenet_hidden_sizes, n_classes, exposed_layers, z_dim,
    learnable_epinet_hiddens, hidden_sizes_prior,
    seed_base=seed_base,
    seed_learnable_epinet=seed_learnable_epinet,
    seed_prior_epinet=seed_prior_epinet,
    alpha=alpha,
)

basenet_all_sizes: [3, 5, 5, 2]
exposed_layers: [True, False, True, False]
basenet_layers: ModuleList(
  (0): Linear(in_features=3, out_features=5, bias=True)
  (1): Linear(in_features=5, out_features=5, bias=True)
  (2): Linear(in_features=5, out_features=2, bias=True)
)
sum_input_base_epi: 8
learnable_epinet_all_sizes: [11, 8, 8, 6]
learnable_epinet_layers: ModuleList(
  (0): Linear(in_features=11, out_features=8, bias=True)
  (1): Linear(in_features=8, out_features=8, bias=True)
  (2): Linear(in_features=8, out_features=6, bias=True)
)
ensemble: ModuleList(
  (0-2): 3 x Sequential(
    (0): Linear(in_features=8, out_features=2, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2, out_features=2, bias=True)
    (3): ReLU()
    (4): Linear(in_features=2, out_features=2, bias=True)
  )
)


In [None]:
# Print every parameter of `model_2`, trainable or frozen.
for name, param in model_2.named_parameters():
    print(name, param.data)

basenet_layers.0.weight tensor([[-0.5181, -0.4985, -0.0903],
        [ 0.0075, -0.2623,  0.2175],
        [-0.5196, -0.0390,  0.5077],
        [-0.2355,  0.5213,  0.2091],
        [-0.5210,  0.3653, -0.0666]])
basenet_layers.0.bias tensor([-0.2577,  0.4617, -0.4666,  0.0620, -0.1209])
basenet_layers.1.weight tensor([[ 0.3194,  0.1248,  0.2149,  0.1579, -0.1075],
        [-0.0941, -0.3685,  0.2423,  0.3551,  0.3060],
        [-0.3155,  0.0199, -0.3153, -0.2462, -0.2606],
        [ 0.1528, -0.2665, -0.0098,  0.0188,  0.2883],
        [-0.3381, -0.3070, -0.2597,  0.3130, -0.1608]])
basenet_layers.1.bias tensor([ 0.3772,  0.1617,  0.0566, -0.0033, -0.0884])
basenet_layers.2.weight tensor([[ 0.0561, -0.1021, -0.0031,  0.0571, -0.3498],
        [-0.2344,  0.3611, -0.3629, -0.0321,  0.4424]])
basenet_layers.2.bias tensor([0.1616, 0.0127])
learnable_epinet_layers.0.weight tensor([[-0.2613,  0.1494, -0.2148, -0.0856, -0.1012, -0.0447,  0.0033,  0.2487,
          0.0376,  0.2701,  0.1844],
     

In [None]:
# Rebuild `model` with explicit seeds (base=0, epinet=1, prior=2) and dump its parameters.
model = basenet_with_learnable_epinet_and_ensemble_prior(
    input_size, basenet_hidden_sizes, n_classes, exposed_layers, z_dim,
    learnable_epinet_hiddens, hidden_sizes_prior,
    seed_base=0,
    seed_learnable_epinet=1,
    seed_prior_epinet=2,
    alpha=alpha,
)
for name, param in model.named_parameters():
    print(name, param.data)

basenet_all_sizes: [3, 5, 5, 2]
exposed_layers: [True, False, True, False]
basenet_layers: ModuleList(
  (0): Linear(in_features=3, out_features=5, bias=True)
  (1): Linear(in_features=5, out_features=5, bias=True)
  (2): Linear(in_features=5, out_features=2, bias=True)
)
sum_input_base_epi: 8
learnable_epinet_all_sizes: [11, 8, 8, 6]
learnable_epinet_layers: ModuleList(
  (0): Linear(in_features=11, out_features=8, bias=True)
  (1): Linear(in_features=8, out_features=8, bias=True)
  (2): Linear(in_features=8, out_features=6, bias=True)
)
ensemble: ModuleList(
  (0-2): 3 x Sequential(
    (0): Linear(in_features=8, out_features=2, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2, out_features=2, bias=True)
    (3): ReLU()
    (4): Linear(in_features=2, out_features=2, bias=True)
  )
)
basenet_layers.0.weight tensor([[-0.0043,  0.3097, -0.4752],
        [-0.4249, -0.2224,  0.1548],
        [-0.0114,  0.4578, -0.0512],
        [ 0.1528, -0.1745, -0.1135],
        [-0.5516, -0.382

In [None]:
# Same base/prior seeds as the previous cell, different epinet seed (4 instead of 1).
model_2 = basenet_with_learnable_epinet_and_ensemble_prior(
    input_size, basenet_hidden_sizes, n_classes, exposed_layers, z_dim,
    learnable_epinet_hiddens, hidden_sizes_prior,
    seed_base=0,
    seed_learnable_epinet=4,
    seed_prior_epinet=2,
    alpha=alpha,
)
for name, param in model_2.named_parameters():
    print(name, param.data)

basenet_all_sizes: [3, 5, 5, 2]
exposed_layers: [True, False, True, False]
basenet_layers: ModuleList(
  (0): Linear(in_features=3, out_features=5, bias=True)
  (1): Linear(in_features=5, out_features=5, bias=True)
  (2): Linear(in_features=5, out_features=2, bias=True)
)
sum_input_base_epi: 8
learnable_epinet_all_sizes: [11, 8, 8, 6]
learnable_epinet_layers: ModuleList(
  (0): Linear(in_features=11, out_features=8, bias=True)
  (1): Linear(in_features=8, out_features=8, bias=True)
  (2): Linear(in_features=8, out_features=6, bias=True)
)
ensemble: ModuleList(
  (0-2): 3 x Sequential(
    (0): Linear(in_features=8, out_features=2, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2, out_features=2, bias=True)
    (3): ReLU()
    (4): Linear(in_features=2, out_features=2, bias=True)
  )
)
basenet_layers.0.weight tensor([[-0.0043,  0.3097, -0.4752],
        [-0.4249, -0.2224,  0.1548],
        [-0.0114,  0.4578, -0.0512],
        [ 0.1528, -0.1745, -0.1135],
        [-0.5516, -0.382

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyBranchingModel(nn.Module):
    """Toy two-branch MLP: a shared layer feeds two parallel branches whose outputs are summed.

        forward(x) = post(relu(branch1(h)) + sigmoid(branch2(h))),   h = initial(x)

    Args:
        in_features: Width of the input (default 10).
        hidden_features: Output width of the shared initial layer (default 20).
        branch_features: Output width of each branch (default 15).
        out_features: Width of the final output (default 10).
    """

    def __init__(self, in_features=10, hidden_features=20, branch_features=15, out_features=10):
        super(MyBranchingModel, self).__init__()
        # Shared trunk
        self.initial_layer = nn.Linear(in_features=in_features, out_features=hidden_features)
        # Parallel branches; they are summed element-wise, so both emit branch_features.
        self.branch1_layer1 = nn.Linear(in_features=hidden_features, out_features=branch_features)
        self.branch2_layer1 = nn.Linear(in_features=hidden_features, out_features=branch_features)
        # Head applied after recombining the branches
        self.post_combine_layer = nn.Linear(in_features=branch_features, out_features=out_features)

    def forward(self, x):
        """Map x of shape (..., in_features) to shape (..., out_features)."""
        shared = self.initial_layer(x)
        # Both branches read the same tensor; neither modifies it in place, so no copies are needed.
        branch1_output = F.relu(self.branch1_layer1(shared))
        branch2_output = torch.sigmoid(self.branch2_layer1(shared))
        return self.post_combine_layer(branch1_output + branch2_output)

# Example usage. NOTE: this rebinds the notebook-level name `model`.
model = MyBranchingModel()
input_tensor = torch.randn(1, 10)
output = model(input_tensor)
