In [52]:
import torch
import torch.nn as nn
import torch.optim as optim
import snntorch as snn
from snntorch import surrogate

from train_test import get_all_data_adj
from test_data_encoding_ZHU import compute_spiking_node_representation

In [53]:
data_folder = "BRCA"
view_list = [1, 2, 3]
num_class = 5

(
    data_tr_list,
    data_trte_list,
    trte_idx,
    labels_trte,
    labels_tr_tensor,
    onehot_labels_tr_tensor,
    adj_tr_list,
    adj_te_list,
    dim_list,
) = get_all_data_adj(
    data_folder,
    view_list,
    num_class,
)

In [54]:
class SimpleSNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleSNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.lif1 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.bn2 = nn.BatchNorm1d(output_size)
        self.lif2 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())
        
        # Weight initialization
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        spk_rec = []
        mem_rec = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        for step in range(x.size(0)):
            cur1 = self.fc1(x[step])
            cur1 = self.bn1(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            cur2 = self.bn2(cur2)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk_rec.append(spk2)
            mem_rec.append(mem2)

        return torch.stack(spk_rec), torch.stack(mem_rec)


In [55]:
K = 3
num_steps = 100

H_encoded_tr_view1 = compute_spiking_node_representation(
    data_tr_list[0], adj_tr_list[0], K, num_steps
)
print(H_encoded_tr_view1.shape, type(H_encoded_tr_view1))

H_encoded_te_view1 = compute_spiking_node_representation(
    data_trte_list[0], adj_te_list[0], K, num_steps
)
print(H_encoded_te_view1.shape)

torch.Size([100, 612, 1000]) <class 'torch.Tensor'>
torch.Size([100, 875, 1000])


In [56]:
print(labels_tr_tensor.shape)
labels_trte_tensor = torch.tensor(labels_trte)
print(labels_trte_tensor.shape)

torch.Size([612])
torch.Size([875])


In [57]:
inputs=H_encoded_tr_view1
targets=labels_tr_tensor

input_size = inputs.shape[2]  # input features
hidden_size = 50  # neurons in the hidden layer
output_size = 5  # output classes
num_steps = inputs.shape[0]  # time steps
num_samples = inputs.shape[1]  # samples
num_epochs = 100  
batch_size = 32  

snn = SimpleSNN(input_size, hidden_size, output_size)
optimizer = optim.Adam(snn.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

In [58]:
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs, _ = snn(inputs)
    outputs = outputs.mean(dim=0)  # average over time steps

    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}')

Epoch 10/100, Loss: 1.1341798305511475
Epoch 20/100, Loss: 1.069413423538208
Epoch 30/100, Loss: 1.0412991046905518
Epoch 40/100, Loss: 1.0281991958618164
Epoch 50/100, Loss: 1.0180375576019287
Epoch 60/100, Loss: 1.0075780153274536
Epoch 70/100, Loss: 0.9971830248832703
Epoch 80/100, Loss: 0.9871007204055786
Epoch 90/100, Loss: 0.9723661541938782
Epoch 100/100, Loss: 0.9645165205001831


In [59]:
for name, param in snn.named_parameters():
    if param.grad is not None:
        print(f"{name}: {param.grad.abs().mean().item()}")

fc1.weight: 1.8895843822974712e-06
fc1.bias: 6.611375401137376e-13
bn1.weight: 1.933348539751023e-05
bn1.bias: 2.1655901946360245e-05
fc2.weight: 7.120329246390611e-05
fc2.bias: 1.8616042735120075e-11
bn2.weight: 0.0008424957050010562
bn2.bias: 0.0008054388454183936


In [60]:
# eval on trte data
test_inputs=H_encoded_te_view1
test_targets=labels_trte_tensor

outputs, _ = snn(test_inputs)
outputs = outputs.mean(dim=0)  # average over time steps
_, predictions = torch.max(outputs, 1)

test_accuracy = (predictions == test_targets).float().mean().item()
print("Testing Accuracy:", test_accuracy)

# sample
print(predictions[:10])
print(test_targets[:10])
# print(outputs[:10])

print(outputs.shape)

from sklearn.metrics import f1_score
#f1 weighted
f1 = f1_score(test_targets, predictions, average='weighted')
print(f1)
#f1 macro
f1 = f1_score(test_targets, predictions, average='macro')
print(f1)



Testing Accuracy: 0.8777142763137817
tensor([2, 3, 0, 3, 1, 3, 3, 1, 1, 3])
tensor([2, 4, 0, 3, 1, 3, 3, 1, 1, 3])
torch.Size([875, 5])
0.8856865086943118
0.8670578647077098
