# imports

In [2]:
import os, importlib

import torch
import torch.nn as nn
import torch.optim as optim
import torch.onnx

from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split

from train import *
from grow_train import *
from model import *

from torchsummary import summary
from torchviz import make_dot

import pdfkit
import matplotlib.pyplot as plt

class MNISTDataset(Dataset):
    """Map-style dataset that delegates to ``datasets.MNIST``.

    The wrapped dataset is stored on ``self.dataset``; length and item
    lookups are forwarded to it unchanged.
    """

    def __init__(self, root, train=True, transform=None):
        """Build the wrapped MNIST dataset.

        Args:
            root: Directory handed to ``datasets.MNIST`` as ``root``.
            train: Forwarded as ``train`` to ``datasets.MNIST``.
            transform: Optional transform forwarded to ``datasets.MNIST``.

        ``download=True`` is always passed to ``datasets.MNIST``.
        """
        mnist_kwargs = dict(root=root, train=train, download=True, transform=transform)
        self.dataset = datasets.MNIST(**mnist_kwargs)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        n_samples = len(self.dataset)
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at ``idx``."""
        sample = self.dataset[idx]
        image, label = sample
        return (image, label)

if __name__ == '__main__':
    # Export every listed model to ONNX under ``save_dict``.
    # The dataset, split and loaders are built here as well, although the
    # export loop below does not read them.

    # Use the GPU when one is visible, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize((0.5,), (0.5,))
    ])

    full_dataset = MNISTDataset(root='./data', train=True, transform=transform)

    # 70% train, 10% validation, and whatever is left over for test.
    train_size = int(0.7 * len(full_dataset))
    val_size = int(0.1 * len(full_dataset))
    test_size = len(full_dataset) - train_size - val_size
    print(train_size, val_size, test_size)
    train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])


    num_epochs = 10
    BATCH_SIZE = 16
    learning_rate = 0.0001

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Each name is both the module (``model.<name>``) and the class inside it.
    models = [
        "CNN_MLP",
        "CNN_ReLUKAN_grow",
        "CNN_ReLUKAN_nogrow",
        "CNN_MLP_grow",
        "CNN_MLP_nogrow",

        "CNN_ReLUKAN",
        "CNN_SiLUKAN",
        "DenseMLP",
        "DenseReLUKAN",
        "DenseSiLUKAN",
        "ReLUKAN_Conv_MLP",
        "ReLUKAN_Conv_ReLUKAN",
        "Spiking_DenseMLP",
        "Spiking_DenseReLUKAN",
        "ViT"
    ]

    save_dict = "./summary"
    # torch.onnx.export does not create the target directory; make sure it
    # exists so a fresh checkout does not fail on the first model.
    os.makedirs(save_dict, exist_ok=True)

    # Constructor argument passed to every model whose name contains "grow".
    max_count = 3

    for model_name in models:
        print(model_name)
        module = importlib.import_module(f"model.{model_name}")
        model_class = getattr(module, model_name)

        # The substring test also matches the "*_nogrow" variants, so those
        # are constructed with ``max_count`` too.
        if "grow" in model_name:
            model = model_class(max_count).to(device)
        else:
            model = model_class().to(device)
            # NOTE(review): only non-grow models get a fresh optimizer, and
            # nothing in this block uses it; confirm whether grow models
            # should get one as well.
            optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        params = model.state_dict()

        # Tracing input. torch.empty returns uninitialised memory (possibly
        # NaN/Inf), so use well-defined random values instead.
        dummy_data = torch.randn(16, 1, 28, 28, dtype=torch.float32, device=device)

        torch.onnx.export(model, dummy_data, f"{save_dict}/{model_name}.onnx")

42000 6000 12000
CNN_MLP
CNN_ReLUKAN_grow


  x = x.reshape((len(x), 1, self.g + self.k, self.input_size))
  x = x.reshape((len(x), self.output_size))


CNN_ReLUKAN_nogrow
CNN_MLP_grow
CNN_MLP_nogrow
CNN_ReLUKAN
CNN_SiLUKAN


  x = x.reshape((len(x), 1, self.g + self.k, self.input_size))
  x = x.reshape((len(x), self.output_size))


DenseMLP
DenseReLUKAN
DenseSiLUKAN
ReLUKAN_Conv_MLP
ReLUKAN_Conv_ReLUKAN




Spiking_DenseMLP




Spiking_DenseReLUKAN
ViT


In [None]:
import os, importlib

import torch
import torch.nn as nn
import torch.optim as optim
import torch.onnx
from torchinfo import summary

if __name__ == '__main__':
    # Write a torchinfo summary of every listed model to
    # ``./torchinfo/<model_name>.txt``.

    # Use the GPU when one is visible, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Each name is both the module (``model.<name>``) and the class inside it.
    models = [
        "CNN_MLP",
        "CNN_ReLUKAN_grow",
        "CNN_ReLUKAN_nogrow",
        "CNN_MLP_grow",
        "CNN_MLP_nogrow",

        "CNN_ReLUKAN",
        "CNN_SiLUKAN",
        "DenseMLP",
        "DenseReLUKAN",
        "ReLUKAN_Conv_MLP",
        "ReLUKAN_Conv_ReLUKAN",
        "ViT"
    ]

    save_dict = "./summary"  # not read in this block

    # open(..., "w") does not create missing directories; create the report
    # directory up front so the first write cannot fail on a fresh checkout.
    report_dir = "./torchinfo"
    os.makedirs(report_dir, exist_ok=True)

    # Constructor argument passed to every model whose name contains "grow".
    max_count = 3

    for model_name in models:
        print(model_name)
        module = importlib.import_module(f"model.{model_name}")
        model_class = getattr(module, model_name)

        # The substring test also matches the "*_nogrow" variants, so those
        # are constructed with ``max_count`` too.
        if "grow" in model_name:
            model = model_class(max_count).to(device)
        else:
            model = model_class().to(device)

        params = model.state_dict()

        # Not passed to summary(): the call below is given the input shape.
        dummy_data_1 = torch.empty(16, 1, 28, 28, dtype = torch.float32).to(device)
        ss = str(summary(model, [(16, 1, 28, 28)]))

        # The context manager closes the file even if write() raises.
        # Explicit UTF-8 keeps the output independent of the platform's
        # default encoding.
        with open(f"{report_dir}/{model_name}.txt", "w", encoding="utf-8") as text_file:
            text_file.write(ss)


CNN_MLP
CNN_ReLUKAN_grow
CNN_ReLUKAN_nogrow
CNN_MLP_grow
CNN_MLP_nogrow
CNN_ReLUKAN
CNN_SiLUKAN
DenseMLP
DenseReLUKAN
DenseSiLUKAN
ReLUKAN_Conv_MLP
ReLUKAN_Conv_ReLUKAN
Spiking_DenseMLP
Spiking_DenseReLUKAN
ViT


In [17]:
# Display the first element of `output_spikes`; that name is not defined in
# any of the cells visible here.
output_spikes[0]

tensor([-0.0206,  0.0456, -0.0606, -0.0844, -0.0785, -0.0943,  0.0374,  0.0430,
        -0.0038, -0.0741], device='cuda:0')