In [2]:
pip install torch torchvision tqdm scipy

Note: you may need to restart the kernel to use updated packages.


In [3]:
### BesselTorch_rbf

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselTorch_rbf(self, distances):
        return torch.special.bessel_j0(self.alpha * distances)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselTorch_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:04<00:00, 209.50it/s, accuracy=0.438, loss=1.67] 


Epoch 1, Train Loss: 1.8678033762395, Train Accuracy: 0.29907382729211085, Val Loss: 1.695045542565121, Val Accuracy: 0.34683519108280253


100%|██████████| 938/938 [00:04<00:00, 216.75it/s, accuracy=0.438, loss=1.6] 


Epoch 2, Train Loss: 1.5371385780987201, Train Accuracy: 0.45535714285714285, Val Loss: 1.3470757675778335, Val Accuracy: 0.5703622611464968


100%|██████████| 938/938 [00:04<00:00, 220.01it/s, accuracy=0.469, loss=1.16] 


Epoch 3, Train Loss: 1.2356506009091701, Train Accuracy: 0.5963152985074627, Val Loss: 1.2474449166826382, Val Accuracy: 0.5053742038216561


100%|██████████| 938/938 [00:04<00:00, 218.47it/s, accuracy=0.719, loss=1.08] 


Epoch 4, Train Loss: 1.0334866963215728, Train Accuracy: 0.6733575426439232, Val Loss: 0.9161583796428268, Val Accuracy: 0.7369625796178344


100%|██████████| 938/938 [00:04<00:00, 211.42it/s, accuracy=0.75, loss=0.805] 


Epoch 5, Train Loss: 0.9043250197032368, Train Accuracy: 0.7179670842217484, Val Loss: 0.8497062200193952, Val Accuracy: 0.7236265923566879


100%|██████████| 938/938 [00:04<00:00, 217.49it/s, accuracy=0.812, loss=0.766]


Epoch 6, Train Loss: 0.8248198475919044, Train Accuracy: 0.7448527452025586, Val Loss: 0.791245003035114, Val Accuracy: 0.7373606687898089


100%|██████████| 938/938 [00:04<00:00, 210.05it/s, accuracy=0.781, loss=0.701]


Epoch 7, Train Loss: 0.7744617322995972, Train Accuracy: 0.7619269722814499, Val Loss: 0.7305128982492314, Val Accuracy: 0.7737858280254777


100%|██████████| 938/938 [00:04<00:00, 213.41it/s, accuracy=0.844, loss=0.64] 


Epoch 8, Train Loss: 0.7358860552056766, Train Accuracy: 0.7717717217484008, Val Loss: 0.7045335460240674, Val Accuracy: 0.7890127388535032


100%|██████████| 938/938 [00:04<00:00, 217.15it/s, accuracy=0.75, loss=0.747] 


Epoch 9, Train Loss: 0.7052716451730809, Train Accuracy: 0.7818996535181236, Val Loss: 0.7653810668523144, Val Accuracy: 0.7411425159235668


100%|██████████| 938/938 [00:04<00:00, 219.91it/s, accuracy=0.938, loss=0.363]


Epoch 10, Train Loss: 0.6797795266802631, Train Accuracy: 0.7906783049040512, Val Loss: 0.6987739554636038, Val Accuracy: 0.7665207006369427


In [5]:
### BesselScipy_rbf

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0, n=0):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:06<00:00, 141.24it/s, accuracy=0.219, loss=1.91]


Epoch 1, Train Loss: 1.8671511615008942, Train Accuracy: 0.2974580223880597, Val Loss: 1.689373737687518, Val Accuracy: 0.3609673566878981


100%|██████████| 938/938 [00:06<00:00, 147.99it/s, accuracy=0.438, loss=1.51]


Epoch 2, Train Loss: 1.5423423262801506, Train Accuracy: 0.4489772121535181, Val Loss: 1.3726611972614458, Val Accuracy: 0.5444864649681529


100%|██████████| 938/938 [00:06<00:00, 146.61it/s, accuracy=0.594, loss=0.968]


Epoch 3, Train Loss: 1.247132884858768, Train Accuracy: 0.5827558635394456, Val Loss: 1.0949852815858878, Val Accuracy: 0.6426154458598726


100%|██████████| 938/938 [00:06<00:00, 148.05it/s, accuracy=0.719, loss=1.04] 


Epoch 4, Train Loss: 1.0348994812604462, Train Accuracy: 0.6676272654584222, Val Loss: 0.9134260336304926, Val Accuracy: 0.7232285031847133


100%|██████████| 938/938 [00:06<00:00, 138.84it/s, accuracy=0.719, loss=0.69] 


Epoch 5, Train Loss: 0.9039512032003545, Train Accuracy: 0.712769856076759, Val Loss: 0.8885481946027962, Val Accuracy: 0.6951632165605095


100%|██████████| 938/938 [00:06<00:00, 140.06it/s, accuracy=0.625, loss=0.883]


Epoch 6, Train Loss: 0.8302906726850375, Train Accuracy: 0.7341584488272921, Val Loss: 0.7634475048939893, Val Accuracy: 0.7678144904458599


100%|██████████| 938/938 [00:07<00:00, 133.59it/s, accuracy=0.719, loss=0.832]


Epoch 7, Train Loss: 0.7789393939189057, Train Accuracy: 0.7505663646055437, Val Loss: 0.7133746680560386, Val Accuracy: 0.7788614649681529


100%|██████████| 938/938 [00:06<00:00, 142.01it/s, accuracy=0.781, loss=0.705]


Epoch 8, Train Loss: 0.7398545205720198, Train Accuracy: 0.7631596481876333, Val Loss: 0.6618061999606478, Val Accuracy: 0.8026472929936306


100%|██████████| 938/938 [00:06<00:00, 140.97it/s, accuracy=0.594, loss=0.858]


Epoch 9, Train Loss: 0.717081808332187, Train Accuracy: 0.7707389392324094, Val Loss: 0.7141549063336318, Val Accuracy: 0.7627388535031847


100%|██████████| 938/938 [00:07<00:00, 133.83it/s, accuracy=0.812, loss=0.524]


Epoch 10, Train Loss: 0.6927612559246356, Train Accuracy: 0.7809335021321961, Val Loss: 0.641260493143349, Val Accuracy: 0.7990644904458599


In [7]:
### Yukawa function

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=0.5, beta=1.0):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.beta = beta

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def yukawa_rbf(self, distances):
        return (self.beta / distances) * torch.exp(-self.alpha * distances)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.yukawa_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output

class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)
        
    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:05<00:00, 186.55it/s, accuracy=0.0625, loss=2.3]


Epoch 1, Train Loss: 2.302585125732015, Train Accuracy: 0.09869736140724947, Val Loss: 2.3025851234508927, Val Accuracy: 0.09783041401273886


100%|██████████| 938/938 [00:04<00:00, 189.93it/s, accuracy=0.156, loss=2.3] 


Epoch 2, Train Loss: 2.302585053291402, Train Accuracy: 0.09874733475479744, Val Loss: 2.302584894143852, Val Accuracy: 0.09783041401273886


100%|██████████| 938/938 [00:04<00:00, 189.40it/s, accuracy=0.0625, loss=2.3]


Epoch 3, Train Loss: 2.3025848219897957, Train Accuracy: 0.09869736140724947, Val Loss: 2.302584672429759, Val Accuracy: 0.09783041401273886


100%|██████████| 938/938 [00:04<00:00, 192.30it/s, accuracy=0.188, loss=2.3] 


Epoch 4, Train Loss: 2.302583904408697, Train Accuracy: 0.09876399253731344, Val Loss: 2.302582303429865, Val Accuracy: 0.09783041401273886


100%|██████████| 938/938 [00:05<00:00, 178.49it/s, accuracy=0.125, loss=2.18] 


Epoch 5, Train Loss: 2.281276743549274, Train Accuracy: 0.09873067697228145, Val Loss: 2.1850998842032854, Val Accuracy: 0.09783041401273886


100%|██████████| 938/938 [00:05<00:00, 170.71it/s, accuracy=0.0625, loss=2.18]


Epoch 6, Train Loss: 2.1555443624697768, Train Accuracy: 0.09869736140724947, Val Loss: 2.1409531452093913, Val Accuracy: 0.09783041401273886


100%|██████████| 938/938 [00:05<00:00, 187.31it/s, accuracy=0.125, loss=2.04] 


Epoch 7, Train Loss: 2.131535403636981, Train Accuracy: 0.09873067697228145, Val Loss: 2.1263978944462574, Val Accuracy: 0.09783041401273886


100%|██████████| 938/938 [00:05<00:00, 185.69it/s, accuracy=0.0312, loss=2.26]


Epoch 8, Train Loss: 2.1195674377209595, Train Accuracy: 0.09868070362473348, Val Loss: 2.115885824914191, Val Accuracy: 0.09783041401273886


100%|██████████| 938/938 [00:05<00:00, 172.05it/s, accuracy=0.125, loss=2.03] 


Epoch 9, Train Loss: 2.1103934211009094, Train Accuracy: 0.09873067697228145, Val Loss: 2.1079943931786116, Val Accuracy: 0.09783041401273886


100%|██████████| 938/938 [00:05<00:00, 176.44it/s, accuracy=0.0625, loss=2.13]


Epoch 10, Train Loss: 2.1027972432596087, Train Accuracy: 0.09869736140724947, Val Loss: 2.1007082211743495, Val Accuracy: 0.09783041401273886


In [9]:
### yukawa_rbf when beta value is greater than alpha by a huge difference

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=0.2, beta=10):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.beta = beta

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def yukawa_rbf(self, distances):
        return (self.beta / distances) * torch.exp(-self.alpha * distances)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.yukawa_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output

class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)
        
    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:05<00:00, 177.49it/s, accuracy=0.719, loss=0.835]


Epoch 1, Train Loss: 1.6269024302964525, Train Accuracy: 0.45269189765458423, Val Loss: 0.7357678225465641, Val Accuracy: 0.7724920382165605


100%|██████████| 938/938 [00:05<00:00, 182.94it/s, accuracy=0.875, loss=0.626]


Epoch 2, Train Loss: 0.5784810103460162, Train Accuracy: 0.8349546908315565, Val Loss: 0.4543609149326944, Val Accuracy: 0.8678343949044586


100%|██████████| 938/938 [00:05<00:00, 168.29it/s, accuracy=0.938, loss=0.381]


Epoch 3, Train Loss: 0.41752364769228484, Train Accuracy: 0.8819129797441365, Val Loss: 0.3617260032778333, Val Accuracy: 0.8942078025477707


100%|██████████| 938/938 [00:05<00:00, 178.72it/s, accuracy=0.844, loss=0.369]


Epoch 4, Train Loss: 0.35362812948188804, Train Accuracy: 0.8997368070362474, Val Loss: 0.3166522380249326, Val Accuracy: 0.9065485668789809


100%|██████████| 938/938 [00:05<00:00, 174.06it/s, accuracy=0.938, loss=0.385]


Epoch 5, Train Loss: 0.3138428313343891, Train Accuracy: 0.9108642057569296, Val Loss: 0.28401078786819606, Val Accuracy: 0.9167993630573248


100%|██████████| 938/938 [00:05<00:00, 175.56it/s, accuracy=0.938, loss=0.253]


Epoch 6, Train Loss: 0.2841971624872959, Train Accuracy: 0.9193430170575693, Val Loss: 0.26134039949222354, Val Accuracy: 0.9226711783439491


100%|██████████| 938/938 [00:05<00:00, 170.21it/s, accuracy=0.875, loss=0.399] 


Epoch 7, Train Loss: 0.2603268955371527, Train Accuracy: 0.9263392857142857, Val Loss: 0.24516886982260047, Val Accuracy: 0.9263535031847133


100%|██████████| 938/938 [00:05<00:00, 177.50it/s, accuracy=0.906, loss=0.198] 


Epoch 8, Train Loss: 0.24135178266398943, Train Accuracy: 0.9317863805970149, Val Loss: 0.22629109293127517, Val Accuracy: 0.9328224522292994


100%|██████████| 938/938 [00:05<00:00, 174.70it/s, accuracy=0.969, loss=0.14]  


Epoch 9, Train Loss: 0.2257564180075868, Train Accuracy: 0.9359341684434968, Val Loss: 0.2134697505741552, Val Accuracy: 0.9370023885350318


100%|██████████| 938/938 [00:05<00:00, 172.56it/s, accuracy=1, loss=0.0863]    


Epoch 10, Train Loss: 0.2123655556305957, Train Accuracy: 0.9400819562899787, Val Loss: 0.20233065098713918, Val Accuracy: 0.9410828025477707


In [11]:
### BesselScipy_rbf when n=1

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0, n=1):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:06<00:00, 137.10it/s, accuracy=0.0312, loss=2.33]


Epoch 1, Train Loss: 2.2882166928065613, Train Accuracy: 0.11290644989339019, Val Loss: 2.268941546701322, Val Accuracy: 0.11544585987261147


100%|██████████| 938/938 [00:06<00:00, 143.46it/s, accuracy=0.156, loss=2.21] 


Epoch 2, Train Loss: 2.232599365939972, Train Accuracy: 0.1460720948827292, Val Loss: 2.1923634687047096, Val Accuracy: 0.1744625796178344


100%|██████████| 938/938 [00:06<00:00, 142.88it/s, accuracy=0.125, loss=2.24] 


Epoch 3, Train Loss: 2.1734009033072987, Train Accuracy: 0.18216950959488273, Val Loss: 2.1471669954858768, Val Accuracy: 0.18929140127388536


100%|██████████| 938/938 [00:06<00:00, 144.60it/s, accuracy=0.375, loss=1.92] 


Epoch 4, Train Loss: 2.1302306172944334, Train Accuracy: 0.20012659914712153, Val Loss: 2.0322551150230845, Val Accuracy: 0.2607484076433121


100%|██████████| 938/938 [00:06<00:00, 144.10it/s, accuracy=0.25, loss=1.99]  


Epoch 5, Train Loss: 2.0003949918472435, Train Accuracy: 0.25474746801705755, Val Loss: 1.976503000897207, Val Accuracy: 0.26910828025477707


100%|██████████| 938/938 [00:06<00:00, 143.94it/s, accuracy=0.156, loss=2]   


Epoch 6, Train Loss: 1.9628334483866499, Train Accuracy: 0.26687433368869934, Val Loss: 1.940046135786992, Val Accuracy: 0.2701035031847134


100%|██████████| 938/938 [00:06<00:00, 145.81it/s, accuracy=0.25, loss=1.97] 


Epoch 7, Train Loss: 1.944800921848842, Train Accuracy: 0.27238805970149255, Val Loss: 1.926380433854024, Val Accuracy: 0.28125


100%|██████████| 938/938 [00:06<00:00, 142.80it/s, accuracy=0.281, loss=2.02] 


Epoch 8, Train Loss: 1.9310542127725159, Train Accuracy: 0.27901785714285715, Val Loss: 1.9158973731812399, Val Accuracy: 0.2862261146496815


100%|██████████| 938/938 [00:06<00:00, 145.61it/s, accuracy=0.281, loss=1.92]


Epoch 9, Train Loss: 1.923161168469549, Train Accuracy: 0.27860141257995735, Val Loss: 1.9082460008609068, Val Accuracy: 0.28562898089171973


100%|██████████| 938/938 [00:06<00:00, 144.68it/s, accuracy=0.25, loss=1.98] 


Epoch 10, Train Loss: 1.915750351160574, Train Accuracy: 0.28081689765458423, Val Loss: 1.9041143002783416, Val Accuracy: 0.28662420382165604


In [13]:
### BesselScipy_rbf when n=2

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0, n=2):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:06<00:00, 145.12it/s, accuracy=0.312, loss=1.9] 


Epoch 1, Train Loss: 1.8340437501224118, Train Accuracy: 0.3158315565031983, Val Loss: 1.6089767407459818, Val Accuracy: 0.4096337579617834


100%|██████████| 938/938 [00:06<00:00, 146.39it/s, accuracy=0.406, loss=1.46]


Epoch 2, Train Loss: 1.455144373974058, Train Accuracy: 0.4896721748400853, Val Loss: 1.2450035973718971, Val Accuracy: 0.6142515923566879


100%|██████████| 938/938 [00:06<00:00, 149.05it/s, accuracy=0.656, loss=1.13] 


Epoch 3, Train Loss: 1.188550821372441, Train Accuracy: 0.6000299840085288, Val Loss: 1.0611771257819644, Val Accuracy: 0.6328622611464968


100%|██████████| 938/938 [00:06<00:00, 149.22it/s, accuracy=0.688, loss=0.764]


Epoch 4, Train Loss: 1.0208333355150243, Train Accuracy: 0.6665611673773987, Val Loss: 0.8980562834982659, Val Accuracy: 0.7169585987261147


100%|██████████| 938/938 [00:06<00:00, 153.26it/s, accuracy=0.719, loss=1.09] 


Epoch 5, Train Loss: 0.8853271438368856, Train Accuracy: 0.7227978411513859, Val Loss: 0.8072411214849752, Val Accuracy: 0.7541799363057324


100%|██████████| 938/938 [00:06<00:00, 144.43it/s, accuracy=0.656, loss=0.918]


Epoch 6, Train Loss: 0.7956605427491386, Train Accuracy: 0.7533482142857143, Val Loss: 0.7295777000439395, Val Accuracy: 0.7706011146496815


100%|██████████| 938/938 [00:06<00:00, 150.32it/s, accuracy=0.75, loss=0.685] 


Epoch 7, Train Loss: 0.742443427062238, Train Accuracy: 0.767573960554371, Val Loss: 0.6854701435110372, Val Accuracy: 0.7954816878980892


100%|██████████| 938/938 [00:06<00:00, 147.23it/s, accuracy=0.875, loss=0.409]


Epoch 8, Train Loss: 0.6967978584550337, Train Accuracy: 0.7844149786780383, Val Loss: 0.6604142320953357, Val Accuracy: 0.7977707006369427


100%|██████████| 938/938 [00:06<00:00, 143.31it/s, accuracy=0.781, loss=0.645]


Epoch 9, Train Loss: 0.6633231620798742, Train Accuracy: 0.7962086886993603, Val Loss: 0.6952171328530949, Val Accuracy: 0.7709992038216561


100%|██████████| 938/938 [00:06<00:00, 145.16it/s, accuracy=0.719, loss=0.808]


Epoch 10, Train Loss: 0.6453505696645424, Train Accuracy: 0.8015058635394456, Val Loss: 0.6191531448227585, Val Accuracy: 0.8072253184713376


In [15]:
### BesselScipy_rbf when n=3

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0, n=3):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:06<00:00, 145.33it/s, accuracy=0.188, loss=2.25] 


Epoch 1, Train Loss: 2.292624119502395, Train Accuracy: 0.11225679637526652, Val Loss: 2.269129139602564, Val Accuracy: 0.11504777070063694


100%|██████████| 938/938 [00:06<00:00, 145.90it/s, accuracy=0.188, loss=2.17] 


Epoch 2, Train Loss: 2.222855493712273, Train Accuracy: 0.14935367803837954, Val Loss: 2.172526423338872, Val Accuracy: 0.1819267515923567


100%|██████████| 938/938 [00:06<00:00, 138.84it/s, accuracy=0.125, loss=2.28] 


Epoch 3, Train Loss: 2.1533773181789213, Train Accuracy: 0.18410181236673773, Val Loss: 2.1279912442918034, Val Accuracy: 0.18272292993630573


100%|██████████| 938/938 [00:06<00:00, 141.19it/s, accuracy=0.0938, loss=2.28]


Epoch 4, Train Loss: 2.123523823742165, Train Accuracy: 0.1943463486140725, Val Loss: 2.1041943670078447, Val Accuracy: 0.1972531847133758


100%|██████████| 938/938 [00:06<00:00, 141.76it/s, accuracy=0.156, loss=2.19] 


Epoch 5, Train Loss: 2.109947318334315, Train Accuracy: 0.1958622068230277, Val Loss: 2.1004028487357362, Val Accuracy: 0.19974124203821655


100%|██████████| 938/938 [00:06<00:00, 142.82it/s, accuracy=0.25, loss=1.94]  


Epoch 6, Train Loss: 2.1017113976132897, Train Accuracy: 0.1964952025586354, Val Loss: 2.1086333602856677, Val Accuracy: 0.17277070063694266


 80%|███████▉  | 747/938 [00:05<00:01, 133.50it/s, accuracy=0.203, loss=2.07] 