In [85]:
from components import Bilinear
import torch
import numpy as np
import torch
import matplotlib.pyplot as plt

# Fix the NumPy and PyTorch RNG seeds so repeated runs draw the same numbers.
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [86]:

# define the 1-hidden layer MLP
class MLP(torch.nn.Module):
    """One-hidden-layer network built from two bias-free ``Bilinear`` layers.

    Besides the ordinary forward pass the class can

    * fit a single cosine to each hidden neuron's incoming and outgoing
      weights (``calculate_neuron_properties`` / ``find_cosine``), and
    * run a forward pass in which selected neurons' weights are temporarily
      replaced by those fitted cosines (``forward_with_replaced_weights``).

    The fitted values are stored per neuron in plain tensors of length
    ``hidden_size``: amplitude ``u*``, frequency ``w*``, phase ``s*`` and
    offset ``o*``. Suffix 1 / 2 refers to the first / second half of a
    neuron's input weights, suffix 3 to its output weights. These tensors are
    ordinary attributes created on the module-level ``device``; they are not
    registered as parameters or buffers.

    NOTE(review): the weight-editing helpers index ``bl1.weight[i, :p]`` /
    ``[i, p:]`` and write a length-``p`` vector into ``bl2.weight[:, i]``,
    i.e. they assume weights shaped ``(hidden, 2p)`` and ``(p, hidden)``.
    The shapes printed further down in this notebook (``bl1.weight`` is
    ``[100, 226]`` for ``hidden_size=50``) suggest ``Bilinear`` stores a
    larger, fused weight -- confirm these helpers address the intended
    entries before relying on them.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create both layers and zero-initialised cosine-fit storage."""
        super().__init__()
        self.bl1 = Bilinear(input_size, hidden_size, bias=False)
        self.bl2 = Bilinear(hidden_size, output_size, bias=False)
        self.output_size = output_size
        self.hidden_size = hidden_size
        # Creates u1..u3, w1..w3, s1..s3, o1..o3 (see class docstring).
        for kind in "uwso":
            for part in "123":
                setattr(self, kind + part, torch.zeros(hidden_size, device=device))

    def forward(self, x, val=False):
        """Return ``(output, pre_relu, post_relu)`` for input ``x``.

        ``val`` is accepted for interface compatibility and is not used.
        """
        pre_relu = self.bl1(x)
        post_relu = torch.nn.functional.relu(pre_relu)
        x = self.bl2(post_relu)
        return x, pre_relu, post_relu

    def cosine(self, x, u, w, s, o):
        """Evaluate ``u * cos(2*pi*w*x + s) + o``.

        ``u`` is the amplitude, ``w`` the frequency in cycles per unit of
        ``x``, ``s`` the phase in radians and ``o`` the constant offset.
        """
        return u*torch.cos(x*w*2*np.pi + s) + o

    def _plot_first_neuron(self, pre_relu, p):
        """Show column 0 of ``pre_relu`` as a ``(p, p)`` image.

        The reshape requires ``pre_relu`` to have exactly ``p * p`` rows.
        """
        plt.matshow(pre_relu[:, 0].cpu().detach().numpy().reshape(p, p))
        plt.show()

    def forward_with_replaced_weights(self, x, p_l1=0, plot=True, altered_i=None):
        """Forward pass with some neurons' weights replaced by fitted cosines.

        Args:
            x: input batch. When ``plot`` is true it must contain
                ``output_size ** 2`` rows so the activations can be shown as
                a square image.
            p_l1: fraction of hidden neurons to alter; only used when
                ``altered_i`` is ``None``.
            plot: if true, show neuron 0's pre-activation before and after
                the replacement.
            altered_i: indices of the neurons to alter (list or tensor).
                ``None`` draws ``int(p_l1 * n_hidden)`` random neurons; an
                explicitly given empty collection alters no neuron.

        Returns:
            The network output computed with the replaced weights. The
            layers' original weights are put back before returning, also
            when an exception is raised part-way through.
        """
        # Assumes the input is two concatenated blocks of `output_size`
        # features each (the notebook builds MLP(2*p, hidden, p)) -- TODO confirm.
        p = self.output_size

        # Keep copies so the in-place edits below can always be undone.
        bl1_weight_original = self.bl1.weight.data.clone()
        bl2_weight_original = self.bl2.weight.data.clone()

        try:
            pre_relu = self.bl1(x)
            if plot:
                self._plot_first_neuron(pre_relu, p)

            # `is None` rather than truthiness: `not tensor` raises for
            # multi-element tensors and is True for the valid index tensor [0].
            if altered_i is None:
                num_altered_neurons = int(p_l1 * pre_relu.shape[-1])
                altered_i = torch.randperm(pre_relu.shape[-1])[:num_altered_neurons].to(device)

            positions = torch.arange(p).to(device)  # same for every neuron
            for i in altered_i:
                # Replace incoming weights of the altered neuron
                self.bl1.weight.data[i, :p] = self.cosine(positions, self.u1[i], self.w1[i], self.s1[i], self.o1[i])
                self.bl1.weight.data[i, p:] = self.cosine(positions, self.u2[i], self.w2[i], self.s2[i], self.o2[i])

                # Replace outgoing weights of the altered neuron
                self.bl2.weight.data[:, i] = self.cosine(positions, self.u3[i], self.w3[i], self.s3[i], self.o3[i])

            pre_relu = self.bl1(x)  # re-compute with the updated weights
            if plot:
                self._plot_first_neuron(pre_relu, p)

            post_relu = torch.nn.functional.relu(pre_relu)
            return self.bl2(post_relu)
        finally:
            # Restore the original weights whatever happened above.
            self.bl1.weight.data = bl1_weight_original
            self.bl2.weight.data = bl2_weight_original

    def calculate_neuron_properties(self):
        """Fit a cosine to every neuron's weights and store the parameters.

        For each hidden neuron three fits are made -- first half of its
        input weights, second half of its input weights, and its output
        weights -- and written to the ``*1``, ``*2`` and ``*3`` tensors.
        """
        p = self.output_size  # same split assumption as forward_with_replaced_weights
        # Convert once; the weights do not change while fitting.
        in_weights = self.bl1.weight.detach().cpu().numpy()
        out_weights = self.bl2.weight.detach().cpu().numpy()
        x_data = np.arange(p)

        for neuron in range(self.hidden_size):
            targets = (
                ("1", in_weights[neuron, :p]),
                ("2", in_weights[neuron, p:]),
                ("3", out_weights[:, neuron]),
            )
            for suffix, weights in targets:
                freq, phase, scale, offset = self.find_cosine(x_data, weights)
                # Plain floats are cast to the storage tensor's dtype/device.
                getattr(self, "u" + suffix)[neuron] = float(scale)
                getattr(self, "w" + suffix)[neuron] = float(freq)
                getattr(self, "s" + suffix)[neuron] = float(phase)
                getattr(self, "o" + suffix)[neuron] = float(offset)

    def find_cosine(self, x_data, y_data):
        """Estimate the dominant cosine in ``y_data`` from its DFT.

        Args:
            x_data: evenly spaced sample positions starting at 0 (the fitted
                phase is relative to sample index 0).
            y_data: samples taken at ``x_data``.

        Returns:
            ``(freq, phase_shift, scale, offset)`` such that
            ``scale * cos(2*pi*freq*x + phase_shift) + offset`` approximates
            ``y_data`` -- the same parameterisation ``cosine`` uses.
        """
        # Calculate DFT
        yf = np.fft.fft(y_data)
        xf = np.fft.fftfreq(x_data.size, d=(x_data[1]-x_data[0]))  # assuming x_data is evenly spaced

        # Peak among the positive frequencies; bin 0 (the mean) is skipped,
        # hence the +1 when converting back to an index into yf / xf.
        peak = np.argmax(np.abs(yf[1:yf.size//2])) + 1
        freq = np.abs(xf[peak])

        # For y = A*cos(2*pi*f*x + phi), numpy's DFT gives
        # Y_k = (N*A/2) * exp(+i*phi), so the phase is the bin's angle itself
        # (negating it would mirror the reconstructed wave).
        phase_shift = np.angle(yf[peak])

        # |Y_k| = N*A/2  ->  A = 2*|Y_k|/N
        scale = 2 * np.abs(yf[peak]) / x_data.size

        # The constant term is the sample mean.
        offset = np.mean(y_data)
        return freq, phase_shift, scale, offset


In [87]:
p = 113
model = MLP(2*p, 50, p).to(device)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state dict needs.
state_dict = torch.load("model_weights.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

  model.load_state_dict(torch.load("model_weights.pth"))


MLP(
  (bl1): Bilinear(
    in_features=226, out_features=100, bias=False
    (gate): Identity()
  )
  (bl2): Bilinear(
    in_features=50, out_features=226, bias=False
    (gate): Identity()
  )
)

In [88]:
# Pull the `w_l` / `w_r` attributes out of each bilinear layer.
w_l1 = model.bl1.w_l
w_r1 = model.bl1.w_r
w_l2 = model.bl2.w_l
w_r2 = model.bl2.w_r

In [92]:
# Short aliases for the two bilinear layers.
layer1, layer2 = model.bl1, model.bl2

# Alternative (disabled): recover the two halves by splitting the fused weight.
# Which axis to split along is unconfirmed (dim=0 or dim=-1).
# wl1, wr1 = layer1.weight.detach().cpu().chunk(2, dim=0)
# wl2, wr2 = layer2.weight.detach().cpu().chunk(2, dim=0)

In [90]:
# One line per layer: shape of w_l followed by shape of w_r.
print(*(w.shape for w in (w_l1, w_r1)))
print(*(w.shape for w in (w_l2, w_r2)))

torch.Size([50, 226]) torch.Size([50, 226])
torch.Size([113, 50]) torch.Size([113, 50])


In [91]:
# Row-wise outer product: B[a, i, j] = w_l1[a, i] * w_r1[a, j].
B = w_l1.unsqueeze(2) * w_r1.unsqueeze(1)
B.shape  # d_hidden, d_input, d_input (third-order bilinear tensor)

torch.Size([50, 226, 226])

In [96]:
# Inspect the shape of the first layer's `weight` tensor.
model.bl1.weight.size()

torch.Size([100, 226])