In [1]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd import Function
import torch.optim as optim

In [111]:
import torch
from torch.utils.data import Dataset

class MulticlassDataset(Dataset):
    def __init__(self, num_samples, num_classes, num_features, means, std=1):
        super().__init__()
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.num_features = num_features
        self.means = means
        self.std = std
        self.features = torch.zeros((num_samples, num_features))
        self.labels = torch.zeros(num_samples, dtype=torch.long)

        for i in range(num_samples):
            # Randomly choose a class label
            class_label = torch.randint(low=0, high=num_classes, size=(1,)).item()
            self.labels[i] = class_label

            # Generate features based on the mean and standard deviation of the chosen class
            self.features[i] = torch.normal(mean=self.means[class_label], std=self.std)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # Normalize the feature vector to have unit length
        features = self.features[index]
        features = features / torch.norm(features)

        return features, self.labels[index]

In [112]:
# Single RBF Neuron
class RBFNeuron(nn.Module):

    # mu: RBF mu vector
    # sig: RBF sigma
    def __init__(self, mu, sig):
        super(RBFNeuron, self).__init__()
        self.mu = mu
        self.sig = sig

    def __call__(self, x):
        top = torch.linalg.norm(x-self.mu, dim=1)
        return torch.exp((-0.5)*(top.pow(2) / self.sig))

# Layer of RBF Neurons
class RBFLayer(nn.Module):

    # nin: input dim
    # nout: output dim
    # mus: list of mean vectors for RBF neurons
    # sigs: list of sigmas for RBF neurons
    def __init__(self, nin, nout, mus, sigs):
        super(RBFLayer, self).__init__()
        self.neurons = nn.ModuleList([RBFNeuron(mus[i],sigs[i]) for i in range(nout)])

    def __call__(self, x):
        return torch.tensor([f(x).detach().numpy() for f in self.neurons], dtype=torch.float32, requires_grad=False).transpose(0,1)

# Full RBF Network
class RBFNet(nn.Module):

    # mus: list of means to use in basis functions
    # sigs: list of sigmas to use in basis functions
    # n_classes: num of classes to pred
    def __init__(self, mus, sigs, n_classes=10):
        super(RBFNet, self).__init__()
        self.K = len(mus) # number of RBFs
        self.mus = nn.Parameter(mus)#, requires_grad=False)
        self.sigs = nn.Parameter(sigs)#, requires_grad=False)
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            RBFLayer(2, self.K, self.mus, self.sigs),
            nn.Linear(self.K, n_classes)
        )

    def forward(self, x):
        #x = (x-x.min())/(x.max()-x.min())
        #x = self.flatten(x)
        return self.layers(x)

In [113]:
from tqdm.notebook import tqdm # status bar

def train(model, data, loss_fn, optimizer, epochs=5):

    for epoch in range(epochs):

        epoch_loss = []

        for batch, (samples, labels) in enumerate(tqdm(data)):

            # we need to convert these into tensors
            #samples = samples.type('torch.FloatTensor')
            #labels = labels.type('torch.LongTensor')

            # forward pass
            prediction = model(samples)
            loss = loss_fn(prediction, labels)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # record loss
            epoch_loss.append(loss.item())

        # keep track of loss over our batches
        #epoch_loss = statistics.mean(epoch_loss)
        print(loss)

In [114]:
def test(model, data, loss_fn):

    for batch, (samples, labels) in enumerate(tqdm(data)):

        # we need to convert these into tensors
        #samples = samples.type('torch.FloatTensor')
        #labels = labels.type('torch.LongTensor')

        # forward pass
        prediction = model(samples)
        loss = loss_fn(prediction, labels)

    # test loss
    print(loss)

In [115]:
loss_fn = nn.CrossEntropyLoss()
# hyperparams
learning_rate = 1e-3
momentum = 0.3
epochs = 3

In [177]:
num_samples = 10000
num_classes = 10
num_features = 784
means = torch.randn(num_classes, num_features)
std = 0.5

dataset = MulticlassDataset(num_samples, num_classes, num_features, means, std)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [178]:
# mus = means
mus = torch.div(means, torch.linalg.vector_norm(means, dim=1).view(-1,1)) # unit norm means
sigs = torch.ones(num_classes)*std

rbf_model = RBFNet(mus=mus, sigs=sigs, n_classes=num_classes)

optimizer = optim.Adam(rbf_model.parameters(), lr=learning_rate)

In [181]:
train(rbf_model, dataloader, loss_fn, optimizer, epochs=epochs)

  0%|          | 0/313 [00:00<?, ?it/s]

tensor(0.8154, grad_fn=<NllLossBackward0>)


  0%|          | 0/313 [00:00<?, ?it/s]

tensor(0.6717, grad_fn=<NllLossBackward0>)


  0%|          | 0/313 [00:00<?, ?it/s]

tensor(0.5555, grad_fn=<NllLossBackward0>)
