In [17]:
import torch  
import torch.nn as nn  
import torch.optim as optim  
import numpy as np  
from torch.utils.data import DataLoader, TensorDataset  

import os
from scipy import io
import matplotlib.pyplot as plt


# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# Only query the GPU name when CUDA is actually in use: the call raises on
# CPU-only hosts, which would break this cell despite the CPU fallback above.
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))


cuda:0
NVIDIA A100 80GB PCIe


In [50]:
class Gaussian_af(nn.Module):
    """Bank of 12 two-dimensional Gaussian units with context-dependent modulation.

    Each unit has a preferred phase (15, 45, ..., 345 degrees) and a mean located at
    (cos(phase), sin(phase)). Its response to a 2-D input is a bivariate normal
    density with a diagonal covariance, divided by 0.7879. With the unmodulated
    covariance 0.202 * I the peak density is 1 / (2 * pi * 0.202) ~= 0.7879, so the
    peak response is ~1 in that case.

    Two per-unit, per-context quantities can be learned (shape ``(12, 2)`` each):

    * ``phase_parameter``  - shifts the preferred phase by at most +/- pi / phaseR.
    * ``amplitude_parameter`` - widens the covariance diagonal.

    Args:
        weight_std: standard deviation of the normal initialisation of the
            learnable modulation parameters.
        neuron_modulation_type: one of ``'none'``, ``'phase'``, ``'amplitude'``,
            ``'both'``; selects which parameters are trainable. The others are
            fixed zero tensors.
        phaseR: divisor applied to the maximum phase shift (must be non-zero).
        amplitudeR: stored on the module (forced to 0 when amplitude modulation is
            disabled).
            NOTE(review): ``forward`` never reads this value - confirm whether it
            was meant to scale the amplitude modulation.

    Raises:
        ValueError: if ``neuron_modulation_type`` is not one of the four modes.

    Note:
        Non-trainable tensors are created on, and ``forward`` computes on, the
        module-level ``device`` defined elsewhere in this notebook.
    """

    _MODULATION_TYPES = ('none', 'phase', 'amplitude', 'both')

    def __init__(self, weight_std, neuron_modulation_type, phaseR, amplitudeR):
        # neuron_modulation_type: none; phase; amplitude; both
        super(Gaussian_af, self).__init__()
        if neuron_modulation_type not in self._MODULATION_TYPES:
            # Previously an unknown mode silently left the parameters undefined and
            # only failed later, inside forward(), with an AttributeError.
            raise ValueError(
                "neuron_modulation_type must be one of "
                f"{self._MODULATION_TYPES}, got {neuron_modulation_type!r}"
            )

        # Preferred phases in radians (float64 here; cast in forward()).
        self.orignal_phase = torch.from_numpy(np.radians(np.arange(15, 360, step=30)))

        # Baseline (unmodulated) covariance 0.202 * I for every unit. Kept as an
        # attribute for reference; forward() builds its own covariance each call.
        neuron_sigma = torch.zeros(12, 2, 2)
        neuron_sigma[:, 0, 0] = 0.202
        neuron_sigma[:, 1, 1] = 0.202
        self.sigma = neuron_sigma
        self.pi = np.pi
        self.eps = torch.tensor(1e-6)
        self.phaseR = torch.tensor(phaseR)
        self.amplitudeR = torch.tensor(amplitudeR)

        learn_phase = neuron_modulation_type in ('phase', 'both')
        learn_amplitude = neuron_modulation_type in ('amplitude', 'both')

        # Phase is created before amplitude so the random draws and the parameter
        # registration order stay the same as before.
        self.phase_parameter = self._init_modulation(weight_std, learn_phase)
        self.amplitude_parameter = self._init_modulation(weight_std, learn_amplitude)
        if not learn_amplitude:
            self.amplitudeR = torch.tensor(0)

    @staticmethod
    def _init_modulation(weight_std, trainable):
        """Return a (12, 2) modulation tensor: a normal-initialised Parameter if trainable, else fixed zeros on ``device``."""
        if trainable:
            return nn.Parameter(torch.normal(mean=0, std=weight_std, size=(12, 2), requires_grad=True))
        return torch.zeros(12, 2).to(device)

    def forward(self, input, ctx_input):
        """Compute the 12 unit responses for every row of ``input``.

        Args:
            input: ``(batch, 2)`` tensor of 2-D coordinates.
            ctx_input: tensor whose first column selects the context of each row
                (rows are matched with ``ctx_input[:, 0] == 0`` and ``== 1``).

        Returns:
            Tuple ``(mvn, amplitude_modulation, phase_modulation)``:
            ``mvn`` is ``(batch, 12)`` responses (rows whose context flag is neither
            0 nor 1 are left at zero); ``amplitude_modulation`` is
            ``sigmoid(amplitude_parameter)`` and ``phase_modulation`` is the bounded
            phase shift, both ``(12, 2)``.
        """
        dtype = torch.get_default_dtype()
        orignal_phase = self.orignal_phase.to(device=device, dtype=dtype)
        # Local phaseR is the reciprocal of the constructor argument.
        phaseR = (1/self.phaseR).to(device=device, dtype=dtype)
        # Zero-initialised (was torch.empty) so unmatched rows never expose
        # uninitialised memory.
        mvn = torch.zeros(input.shape[0], 12, device=device, dtype=dtype)

        for ctx_i in range(2):
            ctx_index = ctx_input[:, 0] == ctx_i
            ctx_resp = input[ctx_index, :]
            batch_size = ctx_resp.shape[0]
            if batch_size == 0:
                # No rows for this context in the batch: nothing to compute.
                continue

            # Column selector: context 0 picks parameter column 1, context 1 picks column 0.
            para_ctx = torch.tensor([[0+ctx_i], [1-ctx_i]]).to(device=device, dtype=dtype)

            # phase modulation: shift bounded to +/- pi / phaseR
            delta_phase = phaseR*2*(torch.sigmoid(self.phase_parameter@para_ctx)-0.5)*self.pi
            phase = orignal_phase.unsqueeze(1) + delta_phase  # (12, 1)
            a = torch.cos(phase)
            b = torch.sin(phase)
            mu = torch.cat((a, b), dim=1)  # (12, 2) unit means

            # amplitude modulation: widen the diagonal covariance per axis
            resp_amplitude = self.amplitude_parameter@para_ctx  # (12, 1)
            aa = 0.202 + torch.abs(a*resp_amplitude)
            bb = 0.202 + torch.abs(b*resp_amplitude)
            # Build a fresh covariance instead of writing into self.sigma in place.
            # On CPU, `.to()` returned the very same tensor, so the old in-place
            # write mutated module state and attached autograd history to it.
            sigma = torch.diag_embed(torch.cat((aa, bb), dim=1))  # (12, 2, 2)

            # Broadcast samples against units.
            input_expanded = ctx_resp.unsqueeze(1)  # (batch_size, 1, features)
            mu_expanded = mu.unsqueeze(0)           # (1, num_components, features)
            diff = input_expanded - mu_expanded     # (batch_size, num_components, features)
            d = diff.shape[-1]

            # Lower Cholesky factor of every covariance. torch.cholesky is deprecated
            # in favour of torch.linalg.cholesky.
            L = torch.linalg.cholesky(sigma)  # (num_components, features, features)
            L_expanded = L.unsqueeze(0).expand(batch_size, -1, -1, -1)
            diff_unsqueezed = diff.unsqueeze(-1)  # (batch_size, num_components, features, 1)
            # Solve sigma @ sol = diff for all samples and units at once.
            sol = torch.cholesky_solve(diff_unsqueezed, L_expanded, upper=False)
            # Mahalanobis distance diff^T sigma^-1 diff
            maha = torch.sum(diff_unsqueezed.squeeze(-1) * sol.squeeze(-1), dim=-1)  # (batch_size, num_components)
            # log|sigma| = 2 * sum(log(diag(L)))
            log_det = 2.0 * torch.sum(torch.log(torch.diagonal(L, dim1=-2, dim2=-1)), dim=-1)  # (num_components,)
            log_det_expanded = log_det.unsqueeze(0).expand(batch_size, -1)
            log_2pi = torch.log(torch.tensor(2.0 * torch.pi, dtype=diff.dtype, device=device))
            log_norm = -0.5 * (d * log_2pi + log_det_expanded)
            log_pdf = log_norm - 0.5 * maha  # (batch_size, num_components)
            # Convert to a density and rescale by the fixed factor 0.7879.
            ctx_mvn = torch.exp(log_pdf) / 0.7879
            mvn[ctx_index, :] = ctx_mvn

        amplitude_modulation = torch.sigmoid(self.amplitude_parameter)
        phase_modulation = phaseR*2*(torch.sigmoid(self.phase_parameter)-0.5)*self.pi

        return mvn, amplitude_modulation, phase_modulation

In [20]:
class GetLoader(torch.utils.data.Dataset):
    """Minimal map-style dataset pairing ``data_root[i]`` with ``data_label[i]``."""

    def __init__(self, data_root, data_label):
        # Both containers are stored as given; no copy or validation is performed.
        self.data, self.label = data_root, data_label

    def __len__(self):
        # The dataset length follows the data container only.
        return len(self.data)

    def __getitem__(self, index):
        # Return the (sample, label) pair found at the same index.
        return self.data[index], self.label[index]

In [21]:
# Gaussian front-end followed by a one-hidden-layer MLP readout
class SimpleNN(nn.Module):
    """Gaussian_af population -> Linear(14, hiddenN) -> ReLU -> Linear(hiddenN, 1).

    The 14 inputs of the first linear layer are the 12 Gaussian responses
    concatenated with the 2 context columns of the input.

    Args:
        hiddenN: number of hidden units.
        weight_std: std of the zero-mean normal used to initialise every linear
            weight (also forwarded to ``Gaussian_af``).
        neuron_modulation_type, phaseR, amplitudeR: forwarded to ``Gaussian_af``.
    """

    def __init__(self, hiddenN, weight_std, neuron_modulation_type, phaseR, amplitudeR):
        super().__init__()
        self.Gaussian_neuron = Gaussian_af(weight_std, neuron_modulation_type, phaseR, amplitudeR)
        self.fc1 = nn.Linear(14, hiddenN)
        # `bn` and `softmax` are registered but not used in forward().
        self.bn = nn.BatchNorm1d(14)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hiddenN, 1)
        self.softmax = nn.Softmax(dim=-1)
        # Overwrite the default Linear initialisation with a Gaussian one.
        self._initialize_weights(weight_std)

    def _initialize_weights(self, weight_std):
        """Set every nn.Linear weight to N(0, weight_std) and every bias to zero."""
        linear_layers = (layer for layer in self.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0.0, std=weight_std)
            if layer.bias is None:
                continue
            nn.init.constant_(layer.bias, 0)

    def forward(self, input):
        """Run the network on ``input`` (columns 0-1: coordinates, columns 2-3: context).

        Returns:
            ``(output, gaussian_responses, hidden_pre_relu, hidden_post_relu,
            output, amplitude, phase)`` - the first and fifth entries are the same
            tensor; ``amplitude`` and ``phase`` come from ``Gaussian_af``.
        """
        stimulus_xy, ctx_input = input[:, :2], input[:, 2:4]
        population_resp, amplitude, phase = self.Gaussian_neuron(stimulus_xy, ctx_input)

        pre_activation = self.fc1(torch.cat([population_resp, ctx_input], dim=1))
        post_activation = self.relu(pre_activation)
        readout = self.fc2(post_activation)

        return readout, population_resp, pre_activation, post_activation, readout, amplitude, phase

In [13]:
# Training function: one optimizer step per batch yielded by the loader,
# with an evaluation snapshot on X_test every 100 steps.
def train_model_single_sample(model, criterion, optimizer, scheduler, train_datas, X_test, model_save_path, model_name):
    """Train ``model`` for one pass over ``train_datas`` and save the results.

    Args:
        model: module whose forward returns a 7-tuple
            ``(output, vtc_input, hidden, hidden_relu, output_layer, amplitude, phase)``.
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        scheduler: accepted for interface compatibility; not stepped here.
        train_datas: iterable of ``(inputs, targets)`` batches.
        X_test: tensor evaluated (in eval mode, without gradients) every 100 steps.
        model_save_path: output directory; created if it does not exist.
        model_name: file-name prefix for the two output files.

    Returns:
        ``(model, sub_train_loss)`` - the model (moved to CPU) and the per-step losses.

    Side effects:
        Writes ``<model_name>_result.mat`` (losses, snapshots, per-step amplitude and
        phase) and ``<model_name>_model.pth`` (the whole pickled model) into
        ``model_save_path``, and prints the final step count and loss.
    """
    # Create the output directory up front: previously a missing directory was
    # only discovered at savemat(), after the whole training run had finished.
    os.makedirs(model_save_path, exist_ok=True)

    sub_train_loss = []
    amplitude_training = []
    phase_training = []

    prediction_mat = []
    hidden_layer_mat = []
    hidden_layer_relu_mat = []
    output_layer_mat = []
    amplitude_mat = []
    phase_mat = []
    vtc_input_mat = []

    model.to(device)
    # Make sure the evaluation inputs live on the same device as the model
    # (a no-op when the caller already moved them).
    X_test = X_test.to(device)
    model.train()

    n_steps = 0
    last_loss = None
    for i, data in enumerate(train_datas):
        optimizer.zero_grad()
        output, _, _, _, _, amplitude, phase = model(data[0].to(device))

        loss = criterion(output, data[1].to(device))  # Corresponding target
        loss.backward()
        optimizer.step()

        last_loss = loss.item()
        n_steps = i + 1
        sub_train_loss.append(last_loss)
        amplitude_training.append(amplitude.cpu().detach().numpy())
        phase_training.append(phase.cpu().detach().numpy())

        if n_steps % 100 == 0:
            # Periodic evaluation snapshot on the held-out inputs.
            model.eval()
            with torch.no_grad():
                prediction, vtc_input, hidden_layer, hidden_layer_relu, output_layer, eval_amplitude, eval_phase = model(X_test)
                prediction_mat.append(prediction.cpu().numpy())
                hidden_layer_mat.append(hidden_layer.cpu().numpy())
                hidden_layer_relu_mat.append(hidden_layer_relu.cpu().numpy())
                output_layer_mat.append(output_layer.cpu().numpy())
                amplitude_mat.append(eval_amplitude.cpu().numpy())
                phase_mat.append(eval_phase.cpu().numpy())
                vtc_input_mat.append(vtc_input.cpu().numpy())
            model.train()

    if last_loss is None:
        # Empty loader: the old final print raised NameError on `i` / `loss`.
        print('Trials [0], no training batches were provided')
    else:
        print(f'Trials [{n_steps}], Loss: {last_loss:.4f}')
    model.to('cpu')

    io.savemat(os.path.join(model_save_path, model_name +'_result.mat'),
        {'sub_train_loss':sub_train_loss, 'prediction':prediction_mat, 'hidden_layer':hidden_layer_mat, 'hidden_layer_relu':hidden_layer_relu_mat, 'output_layer':output_layer_mat,
            'amplitude': amplitude_mat, 'phase': phase_mat, 'amplitude_training':amplitude_training, 'phase_training': phase_training, 'vtc_input':vtc_input_mat })
    torch.save(model, os.path.join(model_save_path, model_name + '_model.pth'))

    return model, sub_train_loss

In [None]:
# data loading for A100
data_path = r'/Re_analysis_data_path/d010_Gaussian_MLPtraing_set.mat'
# Parse the .mat file once and read both variables from the same dict
# (the file used to be loaded twice, once per variable).
mat_contents = io.loadmat(data_path)
sub_data = mat_contents['sub_data']

# Held-out inputs used for the periodic evaluation snapshots: float32, on `device`.
X_test = torch.from_numpy(mat_contents['test_trial'].astype(np.float32)).to(device)

In [None]:
# Partial parameter testing
model_save_path = r'/your_save_path/Gaussian_sigma_test'
# Fail fast (or create the folder) before the sweep starts: every run writes its
# results here, and a missing directory would otherwise surface only after the
# first full training run.
os.makedirs(model_save_path, exist_ok=True)

# Hyperparameters
# Unit-circle coordinates of the 12 preferred phases (15, 45, ..., 345 degrees).
# NOTE(review): cc_coord is not used by the sweep below - confirm whether a later
# cell relies on it.
cc_angle = np.radians(np.arange(15, 360, step=30))
a = np.cos(cc_angle)
a = a[:, np.newaxis]
b = np.sin(cc_angle)
b = b[:, np.newaxis]
cc_coord = torch.tensor(np.concatenate((a,b),axis=1))

batch_n = 24
learning_rate = 0.001
weight_std_list = [0.01, 0.1, 0.5, 1, 2, 3]
phaseR_list = [1, 2, 3, 6, 9, 18, 36, 72, 180]
amplitudeR_list = [1, 0.8, 0.6, 0.4, 0.2, 0]
modulation_list = ['none', 'phase', 'amplitude', 'both']

# Full grid: 6 amplitudeR x 9 phaseR x 6 weight stds x 4 modulation modes x 36
# subjects = 46656 training runs, each saved under its own model_name.
for amplitudeR in amplitudeR_list:
    for phaseR in phaseR_list:
        for wi in weight_std_list:
            for neuron_modulation in modulation_list:
                for sub_i in range(36):
                    # Column 0 of sub_data holds the inputs and column 2 the targets
                    # for this subject (array layout presumably set by the .mat file;
                    # verify against the data).
                    x_train = sub_data[sub_i,0].astype(np.float32)
                    y_train = sub_data[sub_i,2].astype(np.float32)

                    torch_data = GetLoader(x_train, y_train)
                    train_datas = DataLoader(torch_data, batch_size=batch_n, shuffle=True, drop_last=False, num_workers=0)

                    # Fresh model, loss and optimizer for every run.
                    model = SimpleNN(20, wi, neuron_modulation, phaseR, amplitudeR)
                    criterion = nn.MSELoss()
                    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
                    # The scheduler is passed along, but the training function does
                    # not step it, so the learning rate stays constant.
                    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

                    # Train on this subject's trials and save the results.
                    model_name = 'mGaussianMLP_WI' + str(wi) + '_' + neuron_modulation + '_'+ str(phaseR) + '_' + str(amplitudeR)+ '_Sub'+("{:02d}".format(sub_i+1))
                    sub_behavior_model, sub_train_loss = train_model_single_sample(model, criterion, optimizer, scheduler, train_datas, X_test, model_save_path, model_name)
                

Trials [4000], Loss: 2.5142
Trials [4000], Loss: 33.3029
Trials [4000], Loss: 24.1694
Trials [4000], Loss: 3.0948
Trials [4000], Loss: 43.0265
Trials [4000], Loss: 1.2950
Trials [4000], Loss: 29.9846
Trials [4000], Loss: 19.7853
Trials [4000], Loss: 1.5991
Trials [4000], Loss: 19.8555
Trials [4000], Loss: 19.0216
Trials [4000], Loss: 1.7291
Trials [4000], Loss: 17.9052
Trials [4000], Loss: 5.0041
Trials [4000], Loss: 26.7315
Trials [4000], Loss: 3.6840
Trials [4000], Loss: 27.9342
Trials [4000], Loss: 3.7098
Trials [4000], Loss: 2.9940
Trials [4000], Loss: 24.2873
Trials [4000], Loss: 27.5901
Trials [4000], Loss: 2.1368
Trials [4000], Loss: 2.2156
Trials [4000], Loss: 21.0818
Trials [4000], Loss: 4.4068
Trials [4000], Loss: 2.0076
Trials [4000], Loss: 14.4292
Trials [4000], Loss: 14.5304
Trials [4000], Loss: 1.4757
Trials [4000], Loss: 2.3089
Trials [4000], Loss: 7.6529
Trials [4000], Loss: 13.1612
Trials [4000], Loss: 1.0458
Trials [4000], Loss: 2.7516
Trials [4000], Loss: 18.3765
Tri