In [1]:
import torch
import os
import argparse
import json
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader, random_split
import math
import json
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import pickle
import time
from scipy.optimize import minimize
from tqdm import *
plt.rcParams.update({'font.size': 32})

In [2]:
class BilinearModified(nn.Module):
    """Bilinear layer ``x1^T W x2 + b`` without diagonal terms and with a fixed bias.

    Differences from ``torch.nn.Bilinear`` that this class implements:

    * ``weight[:, i, i]`` is zeroed at initialisation and masked out on every
      forward pass, so only cross terms (``i != j``) contribute to the output.
    * With ``bias=True`` the bias is a non-trainable buffer of shape ``(1,)``
      holding -70.0 instead of a learnable per-output parameter.

    The diagonal is masked out of place (``masked_fill``) rather than by writing
    into the parameter inside ``forward``. Writing into a parameter in place bumps
    its autograd version counter, which makes a second forward pass before
    ``backward`` fail, and it still lets gradient flow into the diagonal entries.

    Args:
        in1_features: size of the last dimension of the first input.
        in2_features: size of the last dimension of the second input.
        out_features: size of the last dimension of the output.
        bias: if True, add the fixed -70.0 bias; if False, no bias is used.
        device: device on which the weight, bias and mask are created.
        dtype: dtype of the weight and bias.
    """

    __constants__ = ['in1_features', 'in2_features', 'out_features']
    in1_features: int
    in2_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty((out_features, in1_features, in2_features), **factory_kwargs))

        if bias:
            # A buffer (not a Parameter) keeps the bias fixed at -70 during training
            # while still saving it in the state_dict under the key 'bias'.
            self.register_buffer('bias', torch.tensor([-70.0], **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        # True on the (possibly rectangular) main diagonal of the last two weight dims.
        # persistent=False keeps it out of the state_dict, so the checkpoint layout
        # stays exactly {'weight', 'bias'}.
        self.register_buffer(
            '_diag_mask',
            torch.eye(in1_features, in2_features, device=device).bool(),
            persistent=False,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Draw the weight from U(-1/sqrt(in1_features), 1/sqrt(in1_features)) and zero its diagonal."""
        bound = 1 / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -bound, bound)

        # Zero weight[:, i, i]; this runs before any graph exists, so in-place is safe here.
        with torch.no_grad():
            torch.diagonal(self.weight, dim1=1, dim2=2).zero_()

    def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
        """Apply the bilinear form using the weight with its diagonal masked to zero."""
        # Out-of-place masking: the stored parameter is left untouched and the
        # masked entries receive exactly zero gradient.
        off_diagonal_weight = self.weight.masked_fill(self._diag_mask, 0.0)
        return F.bilinear(input1, input2, off_diagonal_weight, self.bias)

    def extra_repr(self) -> str:
        return (f'in1_features={self.in1_features}, in2_features={self.in2_features}, '
                f'out_features={self.out_features}, bias={self.bias is not None}')


class TimeSeriesDataset(Dataset):
    """Map-style dataset that pairs ``data[idx]`` with ``targets[idx]``."""

    def __init__(self, data, targets):
        # Both containers are stored as given; no copy or validation is performed.
        self.targets = targets
        self.data = data

    def __len__(self):
        # The dataset length is taken from ``data`` alone.
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.targets[idx]
        return sample, target
    

# input size: (batch_size, N_synapse, time_dur)
# kernels size: (N_synapse, 1, time_dur)
# output after convolution: (batch_size, N_synapse, time_dur)
# transpose to (batch_size, time_dur, N_synapse)
# bilinear weight size: (1, N_synapse, N_synapse), i.e. (out_features, in1_features, in2_features)
# output size: (batch_size, time_dur)

class DBNN(nn.Module):
    """Per-channel temporal filtering followed by a linear-plus-bilinear readout.

    Each of the ``num_dimensions`` input channels is convolved causally with its own
    kernel ``omega * (1 - exp(-t / tau_rise)) * exp(-t / tau_decay)`` sampled at
    ``t = 0 .. time_dur - 1``. At every time step the filtered channels ``y`` are
    combined as ``bilinear(y, y) + sum(y)``.

    Args:
        num_dimensions: number of input channels (one kernel per channel).
        time_dur: kernel length in samples, and the number of output time steps kept.
        device: device on which the parameters are created. It is still stored as
            ``self.device`` for backward compatibility, but the kernels are built on
            the parameters' current device so the model keeps working after ``.to(...)``.
    """

    def __init__(self, num_dimensions, time_dur, device):
        super().__init__()
        self.num_dimensions = num_dimensions
        self.time_dur = time_dur
        self.device = device
        # Initial values of the per-channel kernel parameters.
        self.tau_rise = nn.Parameter(torch.ones(num_dimensions).to(self.device) * 50)
        self.tau_decay = nn.Parameter(torch.ones(num_dimensions).to(self.device) * 200)
        self.omega = nn.Parameter(torch.ones(num_dimensions).to(self.device) * 2)

        self.bilinear = BilinearModified(num_dimensions, num_dimensions, 1).to(self.device)

    def create_kernels(self):
        """Return the per-channel kernels with shape ``(num_dimensions, 1, time_dur)``."""
        # Use the parameters' current device, not the constructor-time ``self.device``,
        # which becomes stale once the module is moved.
        t = torch.arange(self.time_dur, device=self.tau_rise.device)
        tau_rise = self.tau_rise.unsqueeze(1)    # (N, 1)
        tau_decay = self.tau_decay.unsqueeze(1)  # (N, 1)
        omega = self.omega.unsqueeze(1)          # (N, 1)

        # Broadcasting (N, 1) against (time_dur,) yields (N, time_dur).
        kernels = omega * (1 - torch.exp(-t / tau_rise)) * torch.exp(-t / tau_decay)
        return kernels.unsqueeze(1)

    def forward(self, x):
        """Filter ``x`` of shape ``(batch, num_dimensions, time)`` and return ``(batch, time_dur)``."""
        kernels = self.create_kernels()
        # conv1d computes a cross-correlation; flipping the kernel in time turns it
        # into a true convolution.
        kernel_flipped = torch.flip(kernels, dims=[2])

        # groups=num_dimensions gives every channel its own kernel. Padding by
        # time_dur - 1 and keeping the first time_dur samples makes the filter causal
        # (the output at step t depends only on inputs at steps <= t).
        y = F.conv1d(
            x, kernel_flipped, groups=self.num_dimensions, padding=self.time_dur - 1
        )[:, :, :self.time_dur]
        y_permuted = y.permute(0, 2, 1)  # (batch, time_dur, num_dimensions)
        bilinear_term = self.bilinear(y_permuted, y_permuted)     # (batch, time_dur, 1)
        linear_term = torch.sum(y_permuted, dim=2).unsqueeze(-1)  # (batch, time_dur, 1)
        output = bilinear_term + linear_term
        return output.squeeze(-1)

def variance_explained(y, y_hat, axis=None):
    """
    Variance explained of a time series: ``1 - SS_res / SS_tot``.

    Parameters:
        y      : array-like, e.g. shape (batch, time, ...) or (time,)
        y_hat  : array-like with the same shape as y
        axis   : int or tuple, axis/axes over which the sums and the mean are taken
                 - None -> all elements are pooled into a single value
                 - 1    -> one value per sample, computed along the time axis

    Returns:
        ve : float or ndarray
             a single number when axis=None, otherwise the result reduced
             over the given axis
    """
    target = np.asarray(y)
    prediction = np.asarray(y_hat)

    # Deviation of the target from its own mean along ``axis`` (total variation).
    centered = target - target.mean(axis=axis, keepdims=True)
    # Deviation of the target from the prediction (unexplained variation).
    residual = target - prediction

    total_ss = (centered ** 2).sum(axis=axis)
    residual_ss = (residual ** 2).sum(axis=axis)
    return 1 - residual_ss / total_ss

def add_poisson_noise(ip, noise_rate=5, dt=1e-3):
    """
    Superimpose independent random noise spikes on every element of the input.

    Parameters:
        ip: input tensor of 0/1 values (described by the author as (N, D, T));
            the operation is element-wise, so the shape itself is not checked
        noise_rate: noise frequency (Hz)
        dt: sampling time step (seconds)

    Returns:
        noisy_ip: noisy signal, clipped to the range [0, 1]
    """
    # Probability that a noise spike occurs in any single time step.
    spike_prob = noise_rate * dt
    # One independent uniform draw per element; a draw below spike_prob is a noise spike.
    noise_spikes = torch.rand_like(ip).lt(spike_prob).float()
    # Overlay on the original signal; coincident spikes are clipped back to 1.
    return (ip + noise_spikes).clamp(0, 1)

In [3]:
# Evaluate the trained DBNN checkpoint on the 5 s dataset and report the mean
# per-sample variance explained.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; keep it
# only if the .npz really stores object arrays and the file is trusted.
data_std = np.load('/home/mjy/project/bilinear_network/NC_code/data/data_5s_active.npz', allow_pickle=True)
ip_std = torch.tensor(data_std["ip"], dtype=torch.float32)
op_std = torch.tensor(data_std["op"], dtype=torch.float32)
num_dims = ip_std.size(1)
time_dur = ip_std.size(2)
model = DBNN(num_dimensions=num_dims, time_dur=time_dur, device='cpu')
checkpoint = '/home/mjy/project/bilinear_network/NC_code/parameters/DBNN_spike_bs16_lr0.004_epochs3000_seed6660451868631776210_LAST.pth'
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs, instead of executing arbitrary pickled code.
state_dict = torch.load(checkpoint, map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
model.eval()
# Inference only: skip building the autograd graph over the full-length convolution.
with torch.no_grad():
    predict_std = model(ip_std)
ve_std = variance_explained(op_std.detach().cpu(), predict_std.detach().cpu(), axis=1)
print(ip_std.shape)
np.mean(ve_std)

  state_dict = torch.load(checkpoint, map_location='cpu')


torch.Size([120, 9, 5001])


0.87729967

In [None]:
data_100s = np.load('/home/mjy/project/bilinear_network/data/data_100s_active.npz', allow_pickle=True)
ip_100s = torch.tensor(data_100s["ip"], dtype=torch.float32)
op_100s = torch.tensor(data_100s["op"], dtype=torch.float32)
num_dims = ip_100s.size(1)
time_dur = ip_100s.size(2)
model = DBNN(num_dimensions=num_dims, time_dur=time_dur, device='cpu')
checkpoint = '/home/mjy/project/bilinear_network/NC_code/parameters/DBNN