In [1]:
pip install torch torchvision tqdm scipy

Note: you may need to restart the kernel to use updated packages.


In [2]:
### BesselTorch_rbf

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselTorch_rbf(self, distances):
        return torch.special.bessel_j0(self.alpha * distances)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselTorch_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|█████████████| 938/938 [00:04<00:00, 195.91it/s, accuracy=0.188, loss=1.85]


Epoch 1, Train Loss: 1.8491423546886647, Train Accuracy: 0.3093350213219616, Val Loss: 1.637801615295896, Val Accuracy: 0.4254578025477707


100%|█████████████| 938/938 [00:04<00:00, 202.65it/s, accuracy=0.469, loss=1.54]


Epoch 2, Train Loss: 1.4611397613086172, Train Accuracy: 0.5031816364605544, Val Loss: 1.2612032381592282, Val Accuracy: 0.5824044585987261


100%|█████████████| 938/938 [00:04<00:00, 203.57it/s, accuracy=0.75, loss=0.859]


Epoch 3, Train Loss: 1.170762886497766, Train Accuracy: 0.617204157782516, Val Loss: 1.053619878686917, Val Accuracy: 0.6705812101910829


100%|█████████████| 938/938 [00:04<00:00, 189.15it/s, accuracy=0.375, loss=1.43]


Epoch 4, Train Loss: 1.0146674231044266, Train Accuracy: 0.6745235874200426, Val Loss: 1.0150204756457335, Val Accuracy: 0.6713773885350318


100%|█████████████| 938/938 [00:04<00:00, 207.33it/s, accuracy=0.625, loss=1.18]


Epoch 5, Train Loss: 0.9117738707487517, Train Accuracy: 0.7115371801705757, Val Loss: 0.8570217661037567, Val Accuracy: 0.7227308917197452


100%|████████████| 938/938 [00:04<00:00, 207.54it/s, accuracy=0.656, loss=0.935]


Epoch 6, Train Loss: 0.8345351500678927, Train Accuracy: 0.7385894189765458, Val Loss: 0.8207709975303359, Val Accuracy: 0.7471138535031847


100%|████████████| 938/938 [00:04<00:00, 190.12it/s, accuracy=0.688, loss=0.947]


Epoch 7, Train Loss: 0.7817760809207521, Train Accuracy: 0.7569629530916845, Val Loss: 0.7140475562803305, Val Accuracy: 0.7884156050955414


100%|████████████| 938/938 [00:04<00:00, 190.17it/s, accuracy=0.812, loss=0.627]


Epoch 8, Train Loss: 0.7362334903941226, Train Accuracy: 0.7741537846481876, Val Loss: 0.730167004523004, Val Accuracy: 0.7898089171974523


100%|████████████| 938/938 [00:05<00:00, 172.89it/s, accuracy=0.812, loss=0.636]


Epoch 9, Train Loss: 0.7033162705107793, Train Accuracy: 0.7850646321961621, Val Loss: 0.6546746264597413, Val Accuracy: 0.8074243630573248


100%|████████████| 938/938 [00:05<00:00, 176.11it/s, accuracy=0.656, loss=0.775]


Epoch 10, Train Loss: 0.6832666634115329, Train Accuracy: 0.7906616471215352, Val Loss: 0.6668376351238057, Val Accuracy: 0.7894108280254777


In [4]:
### BesselScipy_rbf

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0, n=0):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|█████████████| 938/938 [00:07<00:00, 127.90it/s, accuracy=0.375, loss=1.59]


Epoch 1, Train Loss: 1.8383450393737761, Train Accuracy: 0.3138825959488273, Val Loss: 1.657223371943091, Val Accuracy: 0.4131170382165605


100%|█████████████| 938/938 [00:07<00:00, 133.38it/s, accuracy=0.656, loss=1.11]


Epoch 2, Train Loss: 1.4630956729846214, Train Accuracy: 0.4810434434968017, Val Loss: 1.2891415251288445, Val Accuracy: 0.5661823248407644


100%|█████████████| 938/938 [00:07<00:00, 123.02it/s, accuracy=0.594, loss=1.19]


Epoch 3, Train Loss: 1.1875712949075679, Train Accuracy: 0.5989805437100213, Val Loss: 1.1215891925392636, Val Accuracy: 0.6159434713375797


100%|████████████| 938/938 [00:07<00:00, 130.95it/s, accuracy=0.781, loss=0.651]


Epoch 4, Train Loss: 1.0120267848978672, Train Accuracy: 0.6637460021321961, Val Loss: 0.8993738574586856, Val Accuracy: 0.7136743630573248


100%|████████████| 938/938 [00:06<00:00, 135.84it/s, accuracy=0.719, loss=0.854]


Epoch 5, Train Loss: 0.9073988695515752, Train Accuracy: 0.7046075426439232, Val Loss: 0.9126652016001902, Val Accuracy: 0.6909832802547771


100%|█████████████| 938/938 [00:07<00:00, 126.25it/s, accuracy=0.719, loss=1.02]


Epoch 6, Train Loss: 0.8357455570306351, Train Accuracy: 0.730460421108742, Val Loss: 0.7690950873171448, Val Accuracy: 0.7654259554140127


100%|████████████| 938/938 [00:07<00:00, 129.17it/s, accuracy=0.688, loss=0.724]


Epoch 7, Train Loss: 0.7869509586901553, Train Accuracy: 0.7483841950959488, Val Loss: 0.7274813437537783, Val Accuracy: 0.7718949044585988


100%|████████████| 938/938 [00:07<00:00, 130.84it/s, accuracy=0.781, loss=0.713]


Epoch 8, Train Loss: 0.746294835260682, Train Accuracy: 0.7643756663113006, Val Loss: 0.73796828062671, Val Accuracy: 0.7534832802547771


100%|████████████| 938/938 [00:07<00:00, 133.00it/s, accuracy=0.906, loss=0.522]


Epoch 9, Train Loss: 0.7211696076621887, Train Accuracy: 0.7707389392324094, Val Loss: 0.6917945691354715, Val Accuracy: 0.7708996815286624


100%|████████████| 938/938 [00:07<00:00, 132.99it/s, accuracy=0.594, loss=0.892]


Epoch 10, Train Loss: 0.6938371279282864, Train Accuracy: 0.7817497334754797, Val Loss: 0.6694020243587008, Val Accuracy: 0.7853304140127388


In [6]:
### Yukawa function

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=0.5, beta=1.0):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.beta = beta

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def yukawa_rbf(self, distances):
        return (self.beta / distances) * torch.exp(-self.alpha * distances)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.yukawa_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output

class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)
        
    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|█████████████| 938/938 [00:06<00:00, 151.48it/s, accuracy=0.0938, loss=2.3]


Epoch 1, Train Loss: 2.302585129544679, Train Accuracy: 0.09886393923240938, Val Loss: 2.302585128006662, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:05<00:00, 158.05it/s, accuracy=0.0938, loss=2.3]


Epoch 2, Train Loss: 2.302585066508637, Train Accuracy: 0.09871401918976545, Val Loss: 2.3025849427387213, Val Accuracy: 0.09783041401273886


100%|██████████████████| 938/938 [00:06<00:00, 150.56it/s, accuracy=0, loss=2.3]


Epoch 3, Train Loss: 2.3025848156353557, Train Accuracy: 0.09866404584221748, Val Loss: 2.3025845995374548, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:06<00:00, 144.58it/s, accuracy=0.0938, loss=2.3]


Epoch 4, Train Loss: 2.3025837564773397, Train Accuracy: 0.09871401918976545, Val Loss: 2.302581847852962, Val Accuracy: 0.09783041401273886


100%|██████████████| 938/938 [00:06<00:00, 154.32it/s, accuracy=0.25, loss=2.04]


Epoch 5, Train Loss: 2.278119439763555, Train Accuracy: 0.09879730810234541, Val Loss: 2.186510521894807, Val Accuracy: 0.09783041401273886


 69%|████████▎   | 649/938 [00:04<00:01, 168.11it/s, accuracy=0.0938, loss=2.15]