In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR,MultiStepLR
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision
from pylab import *
from dataclasses import dataclass
import math
import matplotlib as mpl
from datetime import date
import dill

# Date stamp taken when this cell runs (not referenced again in the visible part of the notebook).
today = date.today()
# NOTE(review): CUDA_VISIBLE_DEVICES is normally only honoured if it is set before CUDA is
# first initialised in this process — confirm nothing touches the GPU before this line.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  #  Specify which GPUs should be visible and available to PyTorch

In [21]:
# Whether to generate a new dataset or load an existing one.
#   True  -> the "Generate Dataset" cell skips the generation subprocesses.
#   False -> the generation scripts are re-run for every dataset template.
LOAD_DATASET = True

In [3]:
# Loss used for training; train_test_val applies it to the summed output spikes of the network.
criterion = nn.CrossEntropyLoss();  # because it is a classification

# Quantization functions

In [4]:
from collections import namedtuple
import torch
import torch.nn as nn
import math

# Result bundle returned by the quantize_* helpers:
#   tensor     - the quantized codes
#   scale      - step size used to map codes back to real values
#   zero_point - code subtracted before scaling (the symmetric helpers always store 0)
QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])

def calcScaleZeroPoint(min_val, max_val,num_bits=8):
  """Derive the affine quantization parameters for the range [min_val, max_val].

  The code range is [0, 2**num_bits - 1]. The scale is the real-valued width of
  one code step; the zero point is the code that maps to real 0, clipped to the
  code range and truncated to an int.

  Returns:
      (scale, zero_point) where zero_point is a Python int.
  """
  lowest_code = 0.
  highest_code = 2.**num_bits - 1.

  scale = (max_val - min_val) / (highest_code - lowest_code)
  unclipped_zero_point = lowest_code - min_val / scale

  # Guard clauses: clip to the representable code range before truncating.
  if unclipped_zero_point < lowest_code:
      return scale, int(lowest_code)
  if unclipped_zero_point > highest_code:
      return scale, int(highest_code)
  return scale, int(unclipped_zero_point)

def calcScaleZeroPointSym(min_val, max_val,num_bits=8):
  """Derive the scale for symmetric quantization around zero.

  The larger magnitude of the two bounds is mapped onto the top positive code
  2**(num_bits - 1) - 1.

  Returns:
      (scale, 0) — the zero point of the symmetric scheme is always 0.
  """
  largest_magnitude = max(abs(min_val), abs(max_val))
  top_code = 2.**(num_bits-1) - 1.
  return largest_magnitude / top_code, 0

def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):
    """Affine-quantize `x` into unsigned integer codes.

    Args:
        x: tensor to quantize.
        num_bits: code width; codes lie in [0, 2**num_bits - 1]. The codes are
            stored with `.byte()` (uint8), so widths above 8 cannot be represented.
        min_val, max_val: real-valued range to map onto the code range. When
            neither is usable (both None, or both zero) the range of `x` is used.
            When exactly one is None, only that bound is taken from `x`.

    Returns:
        QTensor(tensor=<uint8 codes>, scale, zero_point).
    """
    if not min_val and not max_val:
        # Nothing usable supplied: fall back to the tensor's own range.
        min_val, max_val = x.min(), x.max()
    else:
        # One bound may still be None; previously it was forwarded as None and
        # crashed inside calcScaleZeroPoint.
        if min_val is None:
            min_val = x.min()
        if max_val is None:
            max_val = x.max()

    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)

    # Clamp first, then round once (the second round in the old code was a no-op).
    q_x = (zero_point + x / scale).clamp(qmin, qmax).round().byte()

    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)

def dequantize_tensor(q_x):
    """Map the codes of an affine QTensor back to real values: scale * (code - zero_point)."""
    codes = q_x.tensor.float()
    centred = codes - q_x.zero_point
    return q_x.scale * centred

# num_bits=4 means 3 bits of magnitude on each of the positive and negative sides
# (codes run from -(2**(num_bits-1) - 1) to +(2**(num_bits-1) - 1)).
def quantize_tensor_sym(x, num_bits=4, min_val=None, max_val=None):
    """Symmetric (zero-centred) quantization of `x`.

    Args:
        x: tensor to quantize.
        num_bits: code width including the sign.
        min_val, max_val: configured real-valued range. The range of `x` itself
            is used instead when the configured range is unusable (either bound
            is None, or both are zero) or when `x.max()` is at most 90% of
            `max_val`.

    Returns:
        QTensor(tensor=<rounded codes, same dtype as x / scale>, scale, zero_point=0).
    """
    # A single missing bound used to reach `max_val * 0.9` or `abs(None)` and raise
    # TypeError; treat it like "no range configured".
    range_unusable = min_val is None or max_val is None or (not min_val and not max_val)
    if range_unusable or x.max() <= max_val * 0.9:
        min_val, max_val = x.min(), x.max()

    max_abs = max(abs(min_val), abs(max_val))
    qmax = 2.**(num_bits-1) - 1.

    scale = max_abs / qmax
    if scale == 0:
        # Degenerate all-zero range: dividing by a zero scale turned every code into
        # NaN. Any non-zero scale yields the correct all-zero codes; use 1.
        scale = torch.ones_like(scale) if torch.is_tensor(scale) else 1.0

    # Clamp to the signed code range, then round once.
    q_x = (x / scale).clamp(-qmax, qmax).round()
    return QTensor(tensor=q_x, scale=scale, zero_point=0)

def dequantize_tensor_sym(q_x):
    """Map the codes of a symmetric QTensor back to real values: scale * code."""
    codes = q_x.tensor.float()
    return q_x.scale * codes

def exp_scaler(n_bits, a, b, c, max_val):
    """Build the exponentially spaced quantization levels and their bin edges.

    The levels are y = a * exp(b * x + c) sampled at 2**n_bits - 1 equally spaced
    x positions, the last of which is chosen so that y equals `max_val`.

    Returns:
        (y_range, y_ticks): `y_ticks` holds the levels; `y_range[k]` is the lower
        edge of level k — half of the first level for k == 0, otherwise the
        midpoint between level k-1 and level k.
    """
    level_count = int(2. ** n_bits)
    x_at_max = (math.log(max_val / a) - c) / b
    x_step = x_at_max / (2. ** n_bits - 1)

    x_ticks = torch.tensor([x_step * k for k in range(1, level_count)])
    y_ticks = a * torch.exp(b * x_ticks + c)

    y_range = torch.zeros(len(y_ticks))
    # Interior edges sit halfway between neighbouring levels.
    y_range[1:] = (y_ticks[:-1] + y_ticks[1:]) / 2
    # Anything below half of the smallest level is treated as zero by the callers.
    y_range[0] = y_ticks[0] / 2
    return y_range, y_ticks

def exp_quantizer(x, n_bits, max_val, a, b, c):
    """Deterministically snap `x` onto the signed exponential level grid.

    Magnitudes below the first bin edge become 0; every other magnitude is
    replaced by the level whose bin (as defined by exp_scaler's edges) contains
    it. The sign of `x` is restored afterwards.
    """
    edges, levels = exp_scaler(n_bits, a, b, c, max_val)
    edges = edges.to(x.device)
    levels = levels.to(x.device)

    magnitude = torch.abs(x)

    # searchsorted returns the insertion position; one less is the bin whose lower
    # edge lies below the magnitude. Clamp so out-of-range values use the end bins.
    bin_index = torch.searchsorted(edges, magnitude) - 1
    bin_index = torch.clamp(bin_index, 0, len(levels) - 1)

    below_first_edge = magnitude < edges[0]
    zero = torch.tensor(0.0, device=x.device)
    snapped = torch.where(below_first_edge, zero, levels[bin_index])
    return snapped * torch.sign(x)

def exp_quantizer_probabilistic(x, n_bits, max_val, a, b, c, base_temp=0.02, scale_factor=0.5):
    """Stochastically snap `x` onto the signed exponential level grid.

    Each magnitude |x| is assigned to one of the levels {0} ∪ exp_scaler ticks by
    sampling from a distribution that decays exponentially with the distance to
    each level. The decay uses a per-level temperature
    base_temp * (scale_factor + level / max_level), so larger levels are "softer".

    Args:
        x: tensor of any shape (it is flattened internally for sampling).
        n_bits, max_val, a, b, c: forwarded to exp_scaler.
        base_temp: overall temperature scale.
        scale_factor: temperature offset; with scale_factor == 0 the temperature of
            the zero level is 0 and the division below is undefined.

    Returns:
        Tensor with the same shape as `x` holding the sampled levels with the sign
        of `x` restored.
    """
    device = x.device
    _, y_ticks = exp_scaler(n_bits, a, b, c, max_val)

    # Prepend 0 so that "quantize to zero" is one of the candidate levels.
    y_ticks = torch.cat([torch.tensor([0.0], device=device), y_ticks.to(device)])

    abs_x = torch.abs(x)
    sign = torch.sign(x)

    # Adaptive temperature based on the levels normalised to [0, 1].
    normalized_y_ticks = y_ticks / y_ticks.max()
    adaptive_temperature = base_temp * (scale_factor + normalized_y_ticks)

    # Distance of every element to every level; shape x.shape + (num_ticks,).
    distances = torch.abs(abs_x.unsqueeze(-1) - y_ticks)

    # softmax(-d/T) equals exp(-d/T) / sum(exp(-d/T)) but subtracts the row maximum
    # first, so a value far from every level no longer underflows to an all-zero
    # row (which produced NaN probabilities and made multinomial fail).
    probs = torch.softmax(-distances / adaptive_temperature, dim=-1)

    # torch.multinomial only accepts 1-D/2-D input: sample on a flattened view and
    # restore the original shape afterwards.
    num_ticks = y_ticks.numel()
    idx = torch.multinomial(probs.reshape(-1, num_ticks), 1).reshape(abs_x.shape)

    return y_ticks[idx] * sign  # Restore sign

class FakeQuantOp_exp_probabilistic(torch.autograd.Function):
    """Stochastic exponential fake-quantization with a straight-through gradient.

    forward: passes `x` through exp_quantizer_probabilistic with fixed curve
    parameters a=0.01, b=0.5, c=0. `min_val` is accepted but not used. `max_val`
    defaults to None but a real value is required by the quantizer.
    backward: passes the incoming gradient through unchanged for `x`.
    """

    @staticmethod
    def forward(ctx, x, base_temp, scale_factor, num_bits=3, min_val=None, max_val=None):
        return exp_quantizer_probabilistic(
            x, num_bits, max_val, a=0.01, b=0.5, c=0,
            base_temp=base_temp, scale_factor=scale_factor,
        )

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator. autograd expects one gradient per forward
        # input (x, base_temp, scale_factor, num_bits, min_val, max_val); the old
        # 4-tuple was rejected whenever backward actually ran.
        return grad_output, None, None, None, None, None

class FakeQuantOp_exp(torch.autograd.Function):
    """Deterministic exponential fake-quantization with a straight-through gradient.

    forward delegates to exp_quantizer with fixed curve parameters a=0.01, b=0.5,
    c=0 (`min_val` is accepted but unused); backward hands the gradient to `x`
    untouched and reports no gradient for the three remaining inputs.
    """

    @staticmethod
    def forward(ctx, x, num_bits=3, min_val=None, max_val=None):
        return exp_quantizer(x, num_bits, max_val, a=0.01, b=0.5, c=0)

    @staticmethod
    def backward(ctx, grad_output):
        # straight through estimator: one slot per forward input
        non_tensor_slots = (None, None, None)
        return (grad_output,) + non_tensor_slots
    
class FakeQuantOp(torch.autograd.Function):
    """Linear symmetric fake-quantization with a straight-through gradient.

    forward quantizes with quantize_tensor_sym and immediately dequantizes, so
    the result has the input's scale but only takes quantized values; backward
    hands the gradient to `x` untouched.
    """

    @staticmethod
    def forward(ctx, x, num_bits=3, min_val=None, max_val=None):
        quantized = quantize_tensor_sym(x, num_bits=num_bits, min_val=min_val, max_val=max_val)
        return dequantize_tensor_sym(quantized)

    @staticmethod
    def backward(ctx, grad_output):
        # straight through estimator: one slot per forward input
        non_tensor_slots = (None, None, None)
        return (grad_output,) + non_tensor_slots

# Neuron model

In [29]:
def gaussian(x, mu=0., sigma=.5):
    """Evaluate the normal probability density with mean `mu` and std `sigma` at `x`."""
    exponent = -((x - mu) ** 2) / (2 * sigma ** 2)
    root_two_pi = torch.sqrt(2 * torch.tensor(math.pi))
    return torch.exp(exponent) / root_two_pi / sigma

class ActFun(torch.autograd.Function):
    """Spike nonlinearity: hard threshold forward, smooth surrogate gradient backward.

    The backward pass reads the module-level globals `lens` (surrogate width) and
    `gamma` (gradient scale).
    """

    @staticmethod
    def forward(ctx, input):  # input = membrane potential - threshold
        ctx.save_for_backward(input)
        # 1.0 where the membrane potential exceeds the threshold, else 0.0
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):  # approximate the gradients
        (shifted_mem,) = ctx.saved_tensors
        side_width_ratio = 15.0
        side_height = .15

        # Narrow positive bump at 0 minus two wide, low bumps centred at +/- lens.
        centre_bump = gaussian(shifted_mem, mu=0., sigma=lens) * (1. + side_height)
        right_bump = gaussian(shifted_mem, mu=lens, sigma=side_width_ratio * lens) * side_height
        left_bump = gaussian(shifted_mem, mu=-lens, sigma=side_width_ratio * lens) * side_height
        surrogate = centre_bump - right_bump - left_bump

        return grad_output.clone() * surrogate.float() * gamma

act_fun = ActFun.apply

### NEURON MODEL
def mem_update_adp(x, mem, thr):
    """Advance one layer of neurons by one time step.

    The membrane potential is decayed by the global `decay_neu`, the input `x` is
    added, and a spike is emitted where the potential exceeds `thr`. Neurons that
    spiked are reset to 0, and any remaining negative potential is set to 0.

    Returns:
        (mem, spike): the updated membrane potential and the 0/1 spike tensor.
    """
    integrated = decay_neu * mem + x
    spike = act_fun(integrated - thr)  # (integrated - thr > 0).float()

    # Hard reset: zero the potential of every neuron that just fired.
    after_reset = integrated * (1 - spike)

    # Zero out negative potentials with a multiplicative mask.
    non_negative_mask = 1 - (after_reset < 0).float()
    mem = after_reset * non_negative_mask
    return mem, spike

# mlpsnn definition

In [6]:
class mlpsnn(nn.Module):
    """Spiking MLP with an input, a hidden and an output layer of spiking neurons.

    Each layer has a learnable per-neuron threshold (initialised to 1.0). The two
    linear maps `i2h` and `h2o` have no bias and are Xavier-uniform initialised.

    Args:
        n_inputs: number of input neurons (size of the feature axis of the input).
        hidden_dim: number of hidden neurons.
        n_outputs: number of output neurons (classes).
    """

    def __init__(self, n_inputs, hidden_dim, n_outputs):
        super(mlpsnn, self).__init__()

        self.n_inputs = n_inputs
        self.hidden_size = hidden_dim
        self.output_size = n_outputs

        # Registration order is kept as-is: code that walks model.parameters() by
        # index (see pre_update) depends on it.
        self.i2h = nn.Linear(self.n_inputs, self.hidden_size, bias=False)
        self.h2o = nn.Linear(self.hidden_size, self.output_size, bias=False)

        # Learnable firing thresholds, one per neuron.
        self.thr_i = nn.Parameter(torch.Tensor(self.n_inputs))
        self.thr_h = nn.Parameter(torch.Tensor(self.hidden_size))
        self.thr_o = nn.Parameter(torch.Tensor(self.output_size))

        nn.init.xavier_uniform_(self.i2h.weight)
        nn.init.xavier_uniform_(self.h2o.weight)

        nn.init.constant_(self.thr_h, 1.0)
        nn.init.constant_(self.thr_o, 1.0)
        nn.init.constant_(self.thr_i, 1.0)

    def forward(self, input):
        """Run the network over every time step and count output spikes.

        Args:
            input: 4-D tensor; axis 0 is the batch, only index 0 of axis 1 is read,
                axis 2 is the input-neuron axis and axis 3 is iterated as time.

        Returns:
            Tensor of shape (batch, n_outputs) with the number of spikes each
            output neuron emitted over all time steps.
        """
        batch_size, _, _, n_steps = input.shape

        # Work on the device the parameters live on instead of a notebook-level
        # `device` global, so the module is usable before that global exists and
        # stays consistent wherever the module has been moved.
        dev = self.thr_i.device

        # Membrane potentials start at rest; they broadcast to (batch, n) on the
        # first update.
        input_mem = torch.zeros(self.n_inputs, device=dev)
        hidden_mem = torch.zeros(self.hidden_size, device=dev)
        output_mem = torch.zeros(self.output_size, device=dev)

        output_spike_sum = torch.zeros(batch_size, self.output_size, device=dev)

        for this_t in range(n_steps):
            input_x = input[:, 0, :, this_t].to(dev)

            # The input layer has no weights: it only thresholds the raw signal.
            input_mem, input_spike = mem_update_adp(input_x, input_mem, self.thr_i)

            hidden_in = self.i2h(input_spike.float())
            hidden_mem, hidden_spike = mem_update_adp(hidden_in, hidden_mem, self.thr_h)

            output_in = self.h2o(hidden_spike)
            output_mem, output_spike = mem_update_adp(output_in, output_mem, self.thr_o)

            # Per-step spike trains used to be copied to the CPU here and then
            # thrown away; only the running count is needed.
            output_spike_sum = output_spike_sum + output_spike

        return output_spike_sum

# lr scaling (for exp quantization)

In [7]:
import torch
import torch.optim as optim
import torch.nn as nn

def adjusted_sigmoid(x, x_target=0.6):
    """Sigmoid of |x| centred at x_target / 2 that reaches 0.95 at |x| == x_target."""
    midpoint = x_target / 2
    # Slope chosen so that the sigmoid evaluates to 0.95 at x_target.
    steepness = -torch.log(torch.tensor(1/0.95 - 1)) / (x_target - midpoint)
    exponent = -steepness * (torch.abs(x) - midpoint)
    return 1 / (1 + torch.exp(exponent))

def cal_lr(x, max_val, min_lr=0.01, max_lr=0.5):
    """Element-wise learning-rate factor in [min_lr, max_lr] that grows with |x|.

    The interpolation weight is adjusted_sigmoid(x) with `max_val` as its target.
    """
    lr_span = max_lr - min_lr
    weight = adjusted_sigmoid(x, x_target=max_val)
    return min_lr + lr_span * weight

def pre_update(model, max_val, min, max):
    """Rescale gradients element-wise before the optimizer step.

    For every parameter from index 2 of `model.parameters()` onwards, the gradient
    is multiplied in place by cal_lr(param, max_val, min, max), i.e. a factor
    between `min` and `max` that grows with the magnitude of the parameter value.
    Parameters without a gradient are skipped.

    NOTE(review): for the mlpsnn class in this notebook, parameters() yields
    thr_i, thr_h, thr_o, i2h.weight, h2o.weight, so `i >= 2` rescales thr_o as well
    as both weight matrices — confirm that is the intended selection.

    Args:
        model: module whose gradients are rescaled.
        max_val: magnitude at which the factor is close to `max`.
        min, max: lower and upper bound of the factor (these names shadow the
            builtins inside this function).
    """
    # The factor is a constant for this step: computing it under no_grad keeps an
    # autograd graph from being attached to param.grad.
    with torch.no_grad():
        for i, param in enumerate(model.parameters()):
            if i < 2:
                continue
            if param.grad is None:
                # Frozen or unused parameter: there is nothing to rescale.
                continue
            lr_matrice = cal_lr(param, max_val, min_lr=min, max_lr=max)
            param.grad.mul_(lr_matrice)

# Normalizing initial weight matrices

In [8]:
def normalize_tensor_safe(tensor, min_val, max_val, epsilon=1e-6):
    """Linearly rescale `tensor` so its values span approximately [min_val, max_val].

    `epsilon` is added to the source span so a constant tensor does not divide by
    zero (it then maps to `min_val` everywhere).
    """
    lowest = tensor.min()
    highest = tensor.max()
    target_span = max_val - min_val
    source_span = highest - lowest + epsilon
    return min_val + (tensor - lowest) * target_span / source_span

# Train/Testing functions

In [9]:
def test1(config, model, test_loader, odorant_idx):
    """Evaluate `model` on `test_loader` and return the mean per-batch accuracy.

    Args:
        config: dict; `config['amp']` scales the (sign-flipped) input traces.
        model: network called on the prepared traces; its output is treated as
            per-class scores of shape (batch, classes).
        test_loader: iterable of (traces, labels) batches. Column 1 of `labels`
            is looked up in `odorant_idx` to obtain the class index.
        odorant_idx: list whose positions define the class indices.

    Returns:
        The unweighted mean of the per-batch accuracies (so a smaller final batch
        counts as much as a full one).
    """
    test_acc = []
    # Evaluation needs no gradients; without no_grad every forward pass built and
    # kept an autograd graph.
    with torch.no_grad():
        for traces, labels in test_loader:
            # Flip the sign, scale, and add a singleton axis at dim 1. The result of
            # the arithmetic is already a fresh tensor, so no extra copy is needed.
            traces = (traces * -1 * config['amp']).unsqueeze(1).to(device)

            class_ids = [odorant_idx.index(element) for element in labels[:, 1]]
            targets = torch.tensor(class_ids).long().to(device)

            output_spike_sum = model(traces)

            #################   classification  #########################
            output_sumspike = F.log_softmax(output_spike_sum, dim=1)
            pred_ = output_sumspike.argmax(axis=1)

            correct = (pred_.to(device) == targets).sum().item()
            test_acc.append(correct / float(len(pred_)))
    return np.mean(test_acc)

In [32]:
# Default loss for train_test_val (bound as the default of its `criterion` argument when
# the function is defined). This re-creates the loss object defined in an earlier cell.
criterion =nn.CrossEntropyLoss()

def _flat_model_weights(model):
    """Return the i2h and h2o weights of `model` concatenated into one 1-D tensor."""
    return torch.cat([model.i2h.weight.data.flatten(), model.h2o.weight.data.flatten()])


def _quantize_model_weights(config, model, n_bits, min_val, max_val, base_temp, scale_factor):
    """Fake-quantize the i2h and h2o weights jointly and write them back in place."""
    i2h_weight = model.i2h.weight.data
    h2o_weight = model.h2o.weight.data

    # Both matrices share one quantization grid, so they are quantized as one vector.
    stacked_weights = torch.cat([i2h_weight.flatten(), h2o_weight.flatten()])

    quantization_type = config['quantization_type']
    if quantization_type == "exponential":
        quantized_weights = FakeQuantOp_exp_probabilistic.apply(stacked_weights, base_temp, scale_factor, n_bits, min_val, max_val)
    elif quantization_type == "linear":
        quantized_weights = FakeQuantOp.apply(stacked_weights, n_bits, min_val, max_val)
    else:
        raise ValueError(f"Unknown quantization_type: {quantization_type!r}")

    # Split the vector back into the two matrices.
    split_at = i2h_weight.numel()
    model.i2h.weight.data = quantized_weights[:split_at].reshape(i2h_weight.shape)
    model.h2o.weight.data = quantized_weights[split_at:].reshape(h2o_weight.shape)


def train_test_val(config, model, train_loader, test_loader, val_loader, odorant_idx, optimizer, min_val, max_val, num_epochs, criterion = criterion, n_bits = 4, gamma=0.8, min_lr=0.01, max_lr=0.5, base_temp=0.02, scale_factor=0.5, skip_first_evals=4):
    """Train `model`, evaluating on the validation and test sets after every epoch.

    Validation/test accuracy is recorded once before training and once after each
    epoch. When `config['use_quantization']` is set, the i2h/h2o weights are
    overwritten with their fake-quantized values before every batch and again at
    the end of every epoch.

    Args:
        config: dict read for 'use_scheduler', 'amp', 'use_quantization' and
            'quantization_type' ("exponential" or "linear").
        model: network exposing `i2h.weight` and `h2o.weight`.
        train_loader, test_loader, val_loader: iterables of (traces, labels)
            batches; column 1 of `labels` is looked up in `odorant_idx`.
        odorant_idx: list whose positions define the class indices.
        optimizer: optimizer stepping the model parameters.
        min_val, max_val: range forwarded to the quantizers.
        num_epochs: number of training epochs (must be at least 1).
        criterion: loss applied to the summed output spikes.
        n_bits: quantization bit width.
        gamma: decay factor of the StepLR scheduler (step size 2).
        min_lr, max_lr: only used by the pre_update call that is currently disabled.
        base_temp, scale_factor: forwarded to the exponential quantizer.
        skip_first_evals: number of leading entries of the validation history
            (entry 0 is the pre-training evaluation) ignored when choosing the
            best epoch. Shortened automatically when the history is too short.

    Returns:
        (test accuracy at the evaluation with the best validation accuracy,
         per-batch weight history before quantization,
         per-batch weight history after quantization).

    Raises:
        ValueError: if `num_epochs` is below 1 or `quantization_type` is unknown.
    """
    w_tracker = []
    w_tracker_q = []

    if config['use_scheduler']:
        scheduler = StepLR(optimizer, step_size=2, gamma=gamma)

    test_acc = []
    val_acc = []
    for epoch in range(num_epochs):
        train_correct = 0
        train_loss_sum = 0.0
        data_size = 0
        if epoch == 0:
            temp_val = test1(config, model, val_loader, odorant_idx)
            print(f'Validation acc before training: {temp_val}')
            val_acc.append(temp_val)
            temp_test = test1(config, model, test_loader, odorant_idx)
            print(f'Test acc before training: {temp_test}')
            test_acc.append(temp_test)

        for traces, labels in train_loader:
            data_size += traces.shape[0]
            traces = (traces * -1 * config['amp']).unsqueeze(1).to(device)
            class_ids = [odorant_idx.index(element) for element in labels[:, 1]]
            targets = torch.tensor(class_ids).long().to(device)

            # Clear gradients w.r.t. parameters
            optimizer.zero_grad()
            w_tracker.append(_flat_model_weights(model).tolist())

            if config['use_quantization']:
                _quantize_model_weights(config, model, n_bits, min_val, max_val, base_temp, scale_factor)

            w_tracker_q.append(_flat_model_weights(model).tolist())

            output_spike_sum = model(traces)

            #################   classification  #########################
            pred_ = output_spike_sum.argmax(axis=1)
            loss_h = criterion(output_spike_sum.to(device), targets)

            # Getting gradients w.r.t. parameters
            loss_h.sum().backward()
            train_loss_sum += loss_h.detach().sum().item()

            # Manually update parameters
            # pre_update(model, max_val, min=min_lr, max=max_lr)
            optimizer.step()
            train_correct += (pred_ == targets).sum().item()

        if config['use_quantization']:
            # Shapes are read from the model itself; the old code reused variables
            # that only existed if the batch loop above had run at least once.
            _quantize_model_weights(config, model, n_bits, min_val, max_val, base_temp, scale_factor)

        temp_val = test1(config, model, val_loader, odorant_idx)
        val_acc.append(temp_val)
        temp_test = test1(config, model, test_loader, odorant_idx)
        test_acc.append(temp_test)

        if config['use_scheduler']:
            scheduler.step()

        # Avoid dividing by zero when the training loader yielded nothing.
        sample_count = data_size if data_size else 1
        # NOTE(review): "Train Loss" divides the summed per-batch losses by the
        # number of samples, as before — confirm this is the intended scaling.
        print('epoch: {:3d}, Train Loss: {:.4f}, Train Acc: {:.4f}, Current val_accuracy: {:.4f}, Current test_accuracy: {:.4f}'.format(epoch,
                                                                           train_loss_sum/sample_count,
                                                                           train_correct/sample_count,
                                                                           temp_val,
                                                                           temp_test))

    if not val_acc:
        raise ValueError("num_epochs must be at least 1 to select a best epoch")

    # Pick the test accuracy at the best validation accuracy, ignoring the first
    # `skip_first_evals` evaluations. With a short history, fall back to the last
    # entry instead of calling max() on an empty slice.
    first_candidate = skip_first_evals if len(val_acc) > skip_first_evals else len(val_acc) - 1
    best_index = val_acc.index(max(val_acc[first_candidate:]), first_candidate)
    return test_acc[best_index], w_tracker, w_tracker_q

# Generate Dataset

In [22]:
import os
import shutil
import subprocess
import sys
import dill
import json

# Define number of datasets and the folders for each group
num_datasets = 12
groups = {
    "group1": range(0, 4),
    "group2": range(4, 8),
    "group3": range(8, 12)
}

# Directory where the dataset templates and the generated files are saved
saved_datasets_dir = "saved_datasets/"

# Path to the new directories for groups
group_dirs = {group: os.path.join(saved_datasets_dir, group) for group in groups}

# Create group directories if they don't exist
for group_dir in group_dirs.values():
    os.makedirs(group_dir, exist_ok=True)

# Scripts run (in this order) to regenerate one dataset from its template
generation_scripts = ("odor_space_analysis.py", "generate_dataset.py")

# Generate datasets and move them into appropriate group folders
for i in range(num_datasets):
    dataset_name = f"dataset_template_{i}.json"
    dataset_path = os.path.join(saved_datasets_dir, dataset_name)

    if LOAD_DATASET:
        print(f"Loading dataset {dataset_name}")
    else:
        print("Generating new dataset. This may take a while")
        for script in generation_scripts:
            # Use the kernel's own interpreter for both scripts (one of them used to
            # be launched with whatever "python" was first on PATH), and stop on
            # failure instead of continuing with missing or stale files.
            subprocess.run([sys.executable, script, "--params", dataset_path], check=True)

    # Load the dataset_dict after running subprocess
    with open(dataset_path) as f:
        dataset_dict = json.load(f)

    # Generate the filename for the generated voltage data
    voltage_file_name = f"dataset_size{dataset_dict['dataset_size']}_Nodor{dataset_dict['N_odorants']}_NOR{dataset_dict['N_OR']}_voltage.pkl"
    voltage_file_path = os.path.join(saved_datasets_dir, voltage_file_name)

    # Determine the group based on the index
    for group, idx_range in groups.items():
        if i in idx_range:
            group_dir = group_dirs[group]

            # Move the generated file to the appropriate group folder. Nothing is
            # moved when the file is absent (e.g. it was moved on an earlier run).
            if os.path.exists(voltage_file_path):
                destination_file = os.path.join(group_dir, voltage_file_name)
                shutil.move(voltage_file_path, destination_file)
                print(f"Moved {voltage_file_name} to {group_dir}")
            break  # the index ranges do not overlap, so the first match is the only one



Generating new dataset. This may take a while
Loading saved_datasets/dataset_template_0.json
Loaded DoOR 2.0.1 dataset
Performing analysis for odorants with idx [693 131 164  10]:
none
methanol
propyl acetate
gamma-butyrolactone

Best receptors for N_OR = 3 :
Selected receptor idx: [16 17 18]
Selected receptors: ['Or35a' 'Or42a' 'Or42b']
Saving results to saved_datasets/dataset_template_0.json


  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_0.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 164, 10]
odorant names: ['none', 'methanol', 'propyl acetate', 'gamma-butyrolactone']
OR_idx: [16, 17, 18]
OR names: ['Or35a' 'Or42a' 'Or42b']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for 

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_1.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 164, 10]
odorant names: ['none', 'methanol', 'propyl acetate', 'gamma-butyrolactone']
OR_idx: [3, 7, 16, 17, 18, 20, 24, 40, 49]
OR names: ['Or9a' 'Or22a' 'Or35a' 'Or42a' 'Or42b' 'Or43b' 'Or47a' 'Or85b' 'Or98a']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace for odor [164]
Calculating trace f

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_2.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 164, 10]
odorant names: ['none', 'methanol', 'propyl acetate', 'gamma-butyrolactone']
OR_idx: [3, 7, 14, 16, 17, 18, 20, 21, 24, 25, 32, 40, 41, 49, 51]
OR names: ['Or9a' 'Or22a' 'Or33b' 'Or35a' 'Or42a' 'Or42b' 'Or43b' 'Or45a' 'Or47a'
 'Or47b' 'Or67a' 'Or85b' 'Or85c' 'Or98a' 'Or83c']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odo

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_3.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 164, 10]
odorant names: ['none', 'methanol', 'propyl acetate', 'gamma-butyrolactone']
OR_idx: [2, 3, 4, 5, 7, 14, 16, 17, 18, 19, 20, 21, 24, 25, 32, 37, 40, 41, 44, 49, 51]
OR names: ['Or7a' 'Or9a' 'Or10a' 'Or13a' 'Or22a' 'Or33b' 'Or35a' 'Or42a' 'Or42b'
 'Or43a' 'Or43b' 'Or45a' 'Or47a' 'Or47b' 'Or67a' 'Or74a' 'Or85b' 'Or85c'
 'Or85f' 'Or98a' 'Or83c']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace for odor [10]
Calculating trace fo

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_4.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 97, 149]
odorant names: ['none', 'methanol', 'benzyl alcohol', '2,3-butanediol']
OR_idx: [3, 33, 46]
OR names: ['Or9a' 'Or67b' 'Or92a']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating tr

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_5.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 97, 149]
odorant names: ['none', 'methanol', 'benzyl alcohol', '2,3-butanediol']
OR_idx: [3, 5, 16, 19, 20, 27, 33, 46, 50]
OR names: ['Or9a' 'Or13a' 'Or35a' 'Or43a' 'Or43b' 'Or49b' 'Or67b' 'Or92a' 'Or69a']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for od

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_6.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 97, 149]
odorant names: ['none', 'methanol', 'benzyl alcohol', '2,3-butanediol']
OR_idx: [2, 3, 5, 7, 14, 16, 19, 20, 25, 27, 32, 33, 46, 50, 51]
OR names: ['Or7a' 'Or9a' 'Or13a' 'Or22a' 'Or33b' 'Or35a' 'Or43a' 'Or43b' 'Or47b'
 'Or49b' 'Or67a' 'Or67b' 'Or92a' 'Or69a' 'Or83c']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
Calculating trace for odor [97]
C

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_7.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 97, 149]
odorant names: ['none', 'methanol', 'benzyl alcohol', '2,3-butanediol']
OR_idx: [1, 2, 3, 4, 5, 6, 7, 14, 16, 18, 19, 20, 25, 27, 32, 33, 39, 44, 46, 50, 51]
OR names: ['Or2a' 'Or7a' 'Or9a' 'Or10a' 'Or13a' 'Or19a' 'Or22a' 'Or33b' 'Or35a'
 'Or42b' 'Or43a' 'Or43b' 'Or47b' 'Or49b' 'Or67a' 'Or67b' 'Or85a' 'Or85f'
 'Or92a' 'Or69a' 'Or83c']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace for odor [149]
Calculating trace

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_8.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 132, 133]
odorant names: ['none', 'methanol', 'ethanol', '1-propanol']
OR_idx: [2, 7, 16]
OR names: ['Or7a' 'Or22a' 'Or35a']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calcula

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_9.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 132, 133]
odorant names: ['none', 'methanol', 'ethanol', '1-propanol']
OR_idx: [1, 2, 3, 7, 14, 16, 20, 25, 40]
OR names: ['Or2a' 'Or7a' 'Or9a' 'Or22a' 'Or33b' 'Or35a' 'Or43b' 'Or47b' 'Or85b']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calculating trace for odor [131]
Calcu

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_10.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 132, 133]
odorant names: ['none', 'methanol', 'ethanol', '1-propanol']
OR_idx: [1, 2, 3, 4, 6, 7, 14, 16, 19, 20, 25, 34, 39, 40, 44]
OR names: ['Or2a' 'Or7a' 'Or9a' 'Or10a' 'Or19a' 'Or22a' 'Or33b' 'Or35a' 'Or43a'
 'Or43b' 'Or47b' 'Or67c' 'Or85a' 'Or85b' 'Or85f']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]
Calculating trace for odor [132]

  from scipy.signal import butter,filtfilt


Loading saved_datasets/dataset_template_11.json

Generating dataset using DoOR 2.0.1 database

dataset_size: 50
output_dt: 0.005
downsampling factor: 100
N_ORCOs: 10
odorant_idx: [693, 131, 132, 133]
odorant names: ['none', 'methanol', 'ethanol', '1-propanol']
OR_idx: [1, 2, 3, 4, 6, 7, 14, 16, 19, 20, 24, 25, 27, 29, 31, 34, 38, 39, 40, 44, 51]
OR names: ['Or2a' 'Or7a' 'Or9a' 'Or10a' 'Or19a' 'Or22a' 'Or33b' 'Or35a' 'Or43a'
 'Or43b' 'Or47a' 'Or47b' 'Or49b' 'Or59b' 'Or65a' 'Or67c' 'Or82a' 'Or85a'
 'Or85b' 'Or85f' 'Or83c']
output: voltage
total simulation length: 0.5s
ligand binding simulation length: 0.4s
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odor [133]
Calculating trace for odo

# run_batch function

In [26]:
############################################
# Module-level device, read as a global by mlpsnn.forward, test1 and train_test_val.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE(review): anomaly detection is a debugging aid and is documented to slow down
# training — confirm it should stay enabled for full runs.
torch.autograd.set_detect_anomaly(True)

# PARAMETERS FOR BACKPROP TRAINING
# These are read as globals: gamma and lens by ActFun.backward, decay_neu by mem_update_adp.
gamma = .8  # gradient scale
lens = .5  # std of the central surrogate-gradient Gaussian and offset of the two side ones
decay_neu = 1.0  # no decay
#################################################
#################################################

def _resolve_group_folder(dataset_index, base_dir):
    """Return the group sub-folder of ``base_dir`` that holds one dataset.

    The notebook-level ``groups`` mapping is searched in iteration order
    (key: group folder name, value: container of dataset indices) and the
    first group containing ``dataset_index`` wins.

    Raises:
        ValueError: if ``dataset_index`` is not listed in any group. Without
            this check the caller would either hit a NameError or silently
            reuse the folder of the previously processed dataset.
    """
    for group, idx_range in groups.items():
        if dataset_index in idx_range:
            return os.path.join(base_dir, group)
    raise ValueError(
        f"dataset index {dataset_index} is not assigned to any group in `groups`"
    )


def _select_batch_size(train_size):
    """Choose a power-of-two batch size near one tenth of ``train_size``.

    The candidate ``train_size / 10`` is rounded down to a power of two unless
    it is at least 1.7x that lower power, in which case it is rounded up.
    The candidate is clamped to 1 so that very small training sets do not
    trigger a ``math.log2(0)`` domain error.
    """
    batch_size = max(1, int(train_size / 10))
    lower_power_of_2 = 2 ** math.floor(math.log2(batch_size))
    if batch_size >= 1.7 * lower_power_of_2:
        return 2 ** math.ceil(math.log2(batch_size))
    return lower_power_of_2


def run_batch(config):
    """Train and evaluate the spiking MLP repeatedly on every saved dataset.

    For each of the ``num_dataset`` dataset templates the matching voltage
    dataset is loaded, and ``num_iteration`` independent runs are performed.
    Every run draws a fresh random train/validation/test split, builds a new
    ``mlpsnn`` model, rescales its initial weights and calls
    ``train_test_val``.

    Relies on notebook-level names defined in other cells: ``groups``,
    ``json``, ``dill``, ``mlpsnn``, ``normalize_tensor_safe``,
    ``train_test_val``, ``criterion`` and ``device``.

    Args:
        config: dict of hyper-parameters. ``config['lr']`` (Adam learning
            rate) and ``config['gamma']`` are read here; the whole dict is
            also forwarded to ``train_test_val``.

    Returns:
        list[list[float]]: one inner list per dataset holding the first value
        returned by ``train_test_val`` for each run (reported as the test
        accuracy in the log output).

    Raises:
        ValueError: if a dataset index is not assigned to any group.
    """
    num_dataset = 12
    ratio1 = 0.7                                 # train fraction of the dataset
    ratio2 = 0.15                                # validation fraction; the remainder is the test set
    hidden_dim = 50                              # number of hidden layer neurons

    # training configuration
    num_epochs = 15
    learning_rate = config['lr']
    num_iteration = 10

    # quantization parameters
    min_v = 0.0
    max_v = 0.6
    n_bits = 4
    min_lr = 0.01
    max_lr = 0.2

    base_temp = 0.02
    scale_factor = 0.5

    # regularization parameters
    w_decay = 1e-4  # currently not passed to the optimizer (see the commented option below)

    result_set = []
    saved_datasets_dir = "saved_datasets/"  # The original location of dataset_template_0~11.json
    generated_datasets_dir = "saved_datasets/"  # The folder where generated files are moved (group1, group2, group3)

    # Loop through the dataset indices and load the corresponding files
    for i in range(num_dataset):
        # Folder of the group this dataset belongs to (raises if it has none)
        group_folder = _resolve_group_folder(i, generated_datasets_dir)

        # Load the dict describing the dataset from its JSON template
        dataset_name = f"dataset_template_{i}.json"
        dataset_path = os.path.join(saved_datasets_dir, dataset_name)
        with open(dataset_path) as f:
            dataset_dict = json.load(f)

        # Build the file name for the generated voltage data (assuming they are in group folders)
        voltage_file_name = f"dataset_size{dataset_dict['dataset_size']}_Nodor{dataset_dict['N_odorants']}_NOR{dataset_dict['N_OR']}_voltage.pkl"
        voltage_file_path = os.path.join(group_folder, voltage_file_name)

        print(f"Loading dataset_template_{i}.json -> {voltage_file_path} completed")

        # SECURITY NOTE: dill.load can execute arbitrary code while unpickling;
        # only load dataset files that were generated locally / are trusted.
        with open(voltage_file_path, 'rb') as f:
            dataset = dill.load(f)

        dataset_size_total = dataset_dict['dataset_size'] * dataset_dict['N_odorants']

        # Split sizes are identical for every run, so compute them once per dataset
        train_size = int(dataset_size_total * ratio1)
        val_size = int(ratio2 * dataset_size_total)
        test_size = dataset_size_total - train_size - val_size  # Ensure all samples are used
        print(f"Total number of train samples is : {train_size}")
        print(f"Total number of validation samples is : {val_size}")
        print(f"Total number of test samples is : {test_size}")

        odorant_idx = dataset_dict['odorant_idx']

        dt = dataset_dict['dt']
        output_dt = dataset_dict['output_dt']
        down_sampling_rate = output_dt / dt

        total_timesteps = dataset_dict['total_steps'] / down_sampling_rate

        dimx = dataset_dict['N_OR'] # n_ORs
        dimy = int(total_timesteps) # number of timesteps
        n_classes = dataset_dict['N_odorants']
        print(f'dimx = {dimx}, dimy = {dimy}, n_classes = {n_classes}')

        batch_size_power_of_2 = _select_batch_size(train_size)
        print(f"Using batch size: {batch_size_power_of_2}")

        test_acc = []
        for j in range(num_iteration):
            # Fresh random split for every run
            dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

            train_loader = DataLoader(dataset_train, batch_size = batch_size_power_of_2, shuffle=True)
            test_loader = DataLoader(dataset_test, batch_size = batch_size_power_of_2, shuffle=True)
            val_loader = DataLoader(dataset_val, batch_size = batch_size_power_of_2, shuffle=True)

            model = mlpsnn(n_inputs=dimx, hidden_dim=hidden_dim,  n_outputs=n_classes)
            model.to(device)

            # Only the two weight matrices are handed to the optimizer
            base_params = [model.i2h.weight, model.h2o.weight]
            optimizer = torch.optim.Adam([
                {'params': base_params},
                ],
                lr=learning_rate,
                # weight_decay = w_decay
                )

            # Rescale the initial weights of both layers jointly: flatten and
            # concatenate them, pass them through normalize_tensor_safe with the
            # bounds (-max_v, max_v), then split them back into their shapes.
            # NOTE(review): normalize_tensor_safe is defined in another cell;
            # presumably it maps the values into [-max_v, max_v] - confirm.
            i2h_weight = model.i2h.weight.data
            h2o_weight = model.h2o.weight.data

            stacked_weights_norm = torch.cat([model.i2h.weight.data.flatten(), model.h2o.weight.data.flatten()])
            stacked_weights_norm = normalize_tensor_safe(stacked_weights_norm, -1 * max_v, max_v)
            model.i2h.weight.data = stacked_weights_norm[:i2h_weight.numel()].reshape(i2h_weight.shape)
            model.h2o.weight.data = stacked_weights_norm[i2h_weight.numel():].reshape(h2o_weight.shape)

            temp, w_tracker, w_tracker_q = train_test_val(config, model, train_loader, test_loader, val_loader, odorant_idx, optimizer, min_v, max_v, num_epochs, criterion=criterion, n_bits=n_bits, gamma=config['gamma'], min_lr=min_lr, max_lr=max_lr, base_temp=base_temp, scale_factor=scale_factor)
            test_acc.append(temp)

            print(f"The {j+1} iteration completed.\n")
        print(f"Test accuracy for {num_iteration} runs are : {test_acc}")
        result_set.append([float(element) for element in test_acc])

    print(f'Final result : {result_set}')
    return result_set

# Param Sweep for optimization

In [27]:
import numpy as np

def remove_outliers_and_calculate_mean(data):
    """Average the per-dataset means after IQR-based outlier removal.

    For every dataset the values are flattened, values outside
    ``[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`` are dropped, and the mean of the
    remaining values is taken. The returned score is the plain mean of these
    per-dataset means (no outlier removal across datasets).

    The previous implementation computed the filtered means but then returned
    the mean of the raw, unfiltered values, so outliers were never removed.

    Args:
        data: sequence of datasets, shape (num_dataset, num_iteration); each
            dataset is a sequence of scalars or array-likes.

    Returns:
        The mean of the outlier-filtered dataset means (NumPy float).

    Raises:
        ValueError: if ``data`` is empty or one of its datasets is empty.
    """
    if len(data) == 0:
        raise ValueError("data must contain at least one dataset")

    dataset_means = []
    for dataset in data:
        if len(dataset) == 0:
            raise ValueError("every dataset must contain at least one value")

        # Flatten the dataset across iterations (items may be scalars or arrays)
        values = np.concatenate([np.ravel(item) for item in dataset])

        # Quartiles and interquartile range of this dataset
        q1 = np.percentile(values, 25)
        q3 = np.percentile(values, 75)
        iqr = q3 - q1

        # Keep only values inside the 1.5 * IQR fences
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
        filtered = values[(values >= lower_bound) & (values <= upper_bound)]

        dataset_means.append(np.mean(filtered))

    # Overall score: mean of the filtered per-dataset means
    return np.mean(dataset_means)

In [None]:
import itertools

# Score a single hyper-parameter configuration
def objective(config):
    """Run the batch experiment for ``config`` and reduce its results to one score.

    The nested result list returned by ``run_batch`` is handed directly to
    ``remove_outliers_and_calculate_mean``; that value is the score.
    """
    return remove_outliers_and_calculate_mean(run_batch(config))

# Exhaustive grid search over the parameter grid
def run_sweep(parameter_grid):
    """Evaluate ``objective`` on every combination of the grid values.

    Args:
        parameter_grid: dict mapping a parameter name to the iterable of
            values to try for it.

    Returns:
        list of ``{'config': dict, 'score': value}`` entries, one per
        combination, ordered from the highest score to the lowest.
    """
    scored_configs = []

    # Cartesian product of all value lists, in the grid's key order
    for values in itertools.product(*parameter_grid.values()):
        config = dict(zip(parameter_grid.keys(), values))

        # Announce the combination that is about to be evaluated
        print(f"Testing configuration: {config}")

        scored_configs.append({'config': config, 'score': objective(config)})

    # Best (highest) score first
    return sorted(scored_configs, key=lambda entry: entry['score'], reverse=True)

# Parameter grid for sweeping: every key maps to the list of values to try;
# run_sweep evaluates the Cartesian product of all lists.
# run_batch itself reads only 'lr' and 'gamma'; the full config dict is also
# forwarded to train_test_val, which presumably consumes the remaining keys
# (defined in another cell - confirm there).
parameter_grid = {
    "lr": [0.05, 0.06],                  # alternative grid left by the author: [0.005, 0.01, 0.05]
    "use_scheduler": [True],
    "use_regularizer": [False],
    "use_quantization": [True],
    "quantization_type": ["linear"],
    "use_lr_scaling": [False],
    "gamma": [0.9],
    "amp": [5000],             # alternative grid left by the author: [0.001e6, 0.005e6, 0.01e6, 0.08e6, 0.12e6]
}

# Run the sweep; each combination triggers a full run_batch call, so this cell
# is long-running. The returned list is sorted by score, highest first.
results = run_sweep(parameter_grid)

# Report the best-scoring configuration (first entry of the sorted results)
best_result = results[0]
print(f"Best Configuration: {best_result['config']}")
print(f"Best Score: {best_result['score']}")


Testing configuration: {'lr': 0.05, 'use_scheduler': True, 'use_regularizer': False, 'use_quantization': True, 'quantization_type': 'linear', 'use_lr_scaling': False, 'gamma': 0.9, 'amp': 5000}
Loading dataset_template_0.json -> saved_datasets/group1/dataset_size50_Nodor4_NOR3_voltage.pkl completed
Total number of train samples is : 140
Total number of validation samples is : 30
Total number of test samples is : 30
dimx = 3, dimy = 100, n_classes = 4
Using batch size: 16


  traces = torch.tensor(traces).to(device)#view((-1, x_emp.shape[0], input_dim)).requires_grad_().to(device)


Validation acc before training: 0.3392857142857143
Test acc before training: 0.26339285714285715
epoch:   0, Train Loss: 0.0762, Train Acc: 0.4571, Current val_accuracy: 0.7054, Current test_accuracy: 0.4955
epoch:   1, Train Loss: 0.0706, Train Acc: 0.4571, Current val_accuracy: 0.7054, Current test_accuracy: 0.5045
epoch:   2, Train Loss: 0.0709, Train Acc: 0.4571, Current val_accuracy: 0.7009, Current test_accuracy: 0.5000
epoch:   3, Train Loss: 0.0711, Train Acc: 0.4571, Current val_accuracy: 0.7098, Current test_accuracy: 0.5089
epoch:   4, Train Loss: 0.0710, Train Acc: 0.4571, Current val_accuracy: 0.6920, Current test_accuracy: 0.4955
epoch:   5, Train Loss: 0.0705, Train Acc: 0.4571, Current val_accuracy: 0.6964, Current test_accuracy: 0.4955
epoch:   6, Train Loss: 0.0711, Train Acc: 0.4571, Current val_accuracy: 0.7098, Current test_accuracy: 0.5089
epoch:   7, Train Loss: 0.0709, Train Acc: 0.4571, Current val_accuracy: 0.7098, Current test_accuracy: 0.5045
epoch:   8, Tra