In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from alpaca.dataloader.builder import build_dataset

import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt

import pdb
from tqdm import tqdm

from torchsummary import summary

%matplotlib inline

In [2]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torchType = torch.float32

### Data

```python 
DATASETS = {
    'boston_housing': BostonHousingData,
    'concrete': ConcreteData,
    'energy_efficiency': EnergyEfficiencyData,
    'kin8nm': Kin8nmData,
    'naval_propulsion': NavalPropulsionData,
    'ccpp': CCPPData,
    'protein_structure': ProteinStructureData,
    'red_wine': RedWineData,
    'yacht_hydrodynamics': YachtHydrodynamicsData,
    'year_prediction_msd': YearPredictionMSDData,
    'mnist': MnistData,
    'fashion_mnist': FashionMnistData,
    'cifar_10': Cifar10,
    'svhn': SVHN
}
```

In [3]:
num_epoches = 20001
print_info = 1000

train_batch_size = 250
val_dataset = 10
val_batch_size = 10
test_batch_size = 10

In [4]:
class Dataset():
    def __init__(self, dataset_name='mnist'):
        self.dataset_name = dataset_name

        try:
            dataset = build_dataset(self.dataset_name, val_size=val_dataset)
        except TypeError:
            dataset = build_dataset(self.dataset_name, val_split=val_dataset)
        x_train, y_train = dataset.dataset('train')
        print(f'Train data shape {x_train.shape[0]}')
        self.train_ans = y_train
        self.in_features = x_train.shape[1:]
        x_val, y_val = dataset.dataset('val')
        
        if self.dataset_name in ['mnist', 'fashion_mnist']:
            x_train /= x_train.max()
            x_val /= x_val.max()
            x_shape = (-1, 1, 28, 28)
        else:
            x_shape = (-1, *x_train.shape[1:])
            
        self.train = TensorDataset(torch.tensor(x_train.reshape(x_shape), dtype=torch.float32, device=device), torch.tensor(y_train, dtype=torch.float32, device=device))
        self.validation = TensorDataset(torch.tensor(x_val.reshape(x_shape), dtype=torch.float32, device=device), torch.tensor(y_val, dtype=torch.float32, device=device))
        self.train_dataloader = DataLoader(self.train, batch_size=train_batch_size)
        self.val_dataloader = DataLoader(self.validation, batch_size=val_batch_size)
        
        if self.dataset_name=='mnist':
            test = datasets.MNIST(root=f'./data/{dataset_name}', download=True, train=False)
            data_test = test.test_data.type(torchType).to(device)
            labels_test = test.test_labels.type(torchType).to(device)

            self.test = data_test.data
            self.test_labels = labels_test.data

            test_data = []
            for i in range(self.test.shape[0]):
                test_data.append([self.test[i], self.test_labels[i]])
            self.test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False)
    
    def next_train_batch(self):
        for train_batch in self.train_dataloader:
            batch = train_batch[0]
            labels = train_batch[1]
            if self.dataset_name in ['mnist', 'fashion_mnist']:
                batch = torch.distributions.Binomial(probs=batch).sample()
            yield batch, labels

    def next_val_batch(self):
        for val_batch in self.val_dataloader:
            batch = val_batch[0]
            labels = val_batch[1]
            yield batch, labels

    def next_test_batch(self):
        for test_batch in self.test_dataloader:
            batch = test_batch[0]
            labels = test_batch[1]
            if self.dataset_name in ['mnist', 'fashion_mnist']:
                batch = torch.distributions.Binomial(probs=batch).sample()
                batch = batch.view([-1, 1, 28, 28])
            yield batch, labels

In [5]:
dataset = Dataset(dataset_name='mnist')

Train data shape 69990




### NN definition

In [6]:
num_classes = len(np.unique(dataset.train_ans))
problem = 'classification' if num_classes < 20 else 'regression'
if problem == 'regression':
    num_classes = 1
problem

'regression'

In [7]:
in_features = dataset.in_features[0]
in_features

13

In [8]:
last_features = 10

In [9]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        if problem == 'classification':
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=2, padding=2)
            self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2)
            self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=2, padding=2)
            self.linear1 = nn.Linear(in_features=1024, out_features=256)
            self.linear2 = nn.Linear(in_features=256, out_features=last_features)
        else:
            self.linear1 = nn.Linear(in_features=in_features, out_features=10*in_features)
            self.linear2 = nn.Linear(in_features=10*in_features, out_features=last_features)

    def forward(self, x):
        if problem == 'classification':
            h1 = torch.relu(self.conv1(x))
            h2 = torch.relu(self.conv2(h1))
            h3 = torch.relu(self.conv3(h2))
            h3_flat = h3.view(h3.shape[0], -1)
        else:
            h3_flat = x
        h4 = torch.relu(self.linear1(h3_flat))
        h5 = torch.relu(self.linear2(h4))
        return h5

In [10]:
model = Net().to(device)

In [11]:
if problem == 'classification':
    summary(model, input_size=(1, 28, 28))
else:
    summary(model, input_size=(1, in_features))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1               [-1, 1, 130]           1,820
            Linear-2                [-1, 1, 10]           1,310
Total params: 3,130
Trainable params: 3,130
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.01
Estimated Total Size (MB): 0.01
----------------------------------------------------------------


### Bayesian last layer definition

In [12]:
last_weight_mu = nn.Parameter(torch.randn((last_features, num_classes), device=device, dtype=torchType))
last_weight_logvar = nn.Parameter(torch.randn((last_features, num_classes), device=device, dtype=torchType))

last_bias_mu = nn.Parameter(torch.randn((1, num_classes), device=device, dtype=torchType))
last_bias_logvar = nn.Parameter(torch.randn((1, num_classes), device=device, dtype=torchType))

### Define optimizer

In [13]:
params = list(model.parameters()) + [last_weight_mu, last_weight_logvar] + [last_bias_mu, last_bias_logvar]
optimizer = torch.optim.Adam(params)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, np.linspace(start=10, stop=num_epoches, num=50), gamma=0.9)

### Training

In [14]:
std_normal = torch.distributions.Normal(loc=torch.tensor(0., device=device, dtype=torchType),
                                       scale=torch.tensor(1., device=device, dtype=torchType),)

In [15]:
for ep in tqdm(range(num_epoches)):
    for x_train, y_train_labels in dataset.next_train_batch():
        emb = model(x_train)
        last_weight = last_weight_mu + std_normal.sample(last_weight_mu.shape) * torch.exp(0.5 * last_weight_logvar)
        last_bias = last_bias_mu + std_normal.sample(last_bias_mu.shape) * torch.exp(0.5 * last_bias_logvar)
        preds = emb @ last_weight + last_bias
        
        if problem == 'classification':
            log_likelihood = torch.distributions.Categorical(logits=preds).log_prob(y_train_labels).sum()
        else:
            log_likelihood = torch.distributions.Normal(loc=preds, scale=torch.tensor(1., device=device,
                                                                                      dtype=torchType)).log_prob(y_train_labels).sum()
    
        KL = (0.5 * (-last_weight_logvar + torch.exp(last_weight_logvar) + last_weight_mu ** 2 - 1.)).mean() \
                        + (0.5 * (-last_bias_logvar + torch.exp(last_bias_logvar) + last_bias_mu ** 2 - 1.)).mean()
        
        elbo = log_likelihood - KL
        (-elbo).backward()
        
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step()
        
    if ep % print_info == 0:
        print(f'ELBO value is {elbo.cpu().detach().numpy()} on epoch number {ep}')
        score_total = []
        with torch.no_grad():
            for x_val, y_val_labels in dataset.next_val_batch():
                emb = model(x_val)
                last_weight = last_weight_mu
                last_bias = last_bias_mu
                logits = emb @ last_weight + last_bias
                if problem == 'classification':
                    probs = torch.softmax(logits, dim=-1)
                    y_pred = torch.argmax(probs, dim=-1)
                    score = (y_pred==y_val_labels).to(torchType).cpu().mean().numpy()
                    score_total.append(score)
                else:
                    score = ((logits - y_val_labels)**2).mean().cpu().numpy()
                    score_total.append(score)
        
        if problem == 'classification':
            print(f"Mean validation accuracy at epoch number {ep} is {np.array(score_total).mean()}")
        else:
            print(f"Mean validation MSE at epoch number {ep} is {np.array(score_total).mean()}")
        print(f'Current KL is {KL.cpu().detach().numpy()}')

  0%|          | 15/20001 [00:00<02:13, 149.35it/s]

ELBO value is -1969046.875 on epoch number 0
Mean validation MSE at epoch number 0 is 7665.44775390625
Current KL is 0.8030207753181458


  5%|▌         | 1023/20001 [00:06<02:10, 144.96it/s]

ELBO value is -6680.2119140625 on epoch number 1000
Mean validation MSE at epoch number 1000 is 57.852699279785156
Current KL is 0.6091501116752625


 10%|█         | 2015/20001 [00:13<01:56, 154.87it/s]

ELBO value is -4507.89892578125 on epoch number 2000
Mean validation MSE at epoch number 2000 is 59.2878303527832
Current KL is 0.8268250823020935


 15%|█▌        | 3023/20001 [00:19<01:50, 154.23it/s]

ELBO value is -2641.317626953125 on epoch number 3000
Mean validation MSE at epoch number 3000 is 20.325830459594727
Current KL is 1.0989570617675781


 20%|██        | 4031/20001 [00:26<01:43, 154.36it/s]

ELBO value is -2929.76025390625 on epoch number 4000
Mean validation MSE at epoch number 4000 is 12.463539123535156
Current KL is 1.3552286624908447


 25%|██▌       | 5023/20001 [00:32<01:37, 154.36it/s]

ELBO value is -2743.01904296875 on epoch number 5000
Mean validation MSE at epoch number 5000 is 11.994623184204102
Current KL is 1.6025416851043701


 30%|███       | 6031/20001 [00:39<01:29, 156.75it/s]

ELBO value is -1863.131103515625 on epoch number 6000
Mean validation MSE at epoch number 6000 is 11.634230613708496
Current KL is 1.838252305984497


 35%|███▌      | 7023/20001 [00:45<01:22, 156.43it/s]

ELBO value is -1965.470947265625 on epoch number 7000
Mean validation MSE at epoch number 7000 is 13.208304405212402
Current KL is 2.0364298820495605


 40%|████      | 8031/20001 [00:52<01:16, 156.56it/s]

ELBO value is -1206.333740234375 on epoch number 8000
Mean validation MSE at epoch number 8000 is 12.864912986755371
Current KL is 2.238433361053467


 45%|████▌     | 9023/20001 [00:58<01:10, 156.53it/s]

ELBO value is -1112.089599609375 on epoch number 9000
Mean validation MSE at epoch number 9000 is 11.413199424743652
Current KL is 2.423302412033081


 50%|█████     | 10031/20001 [01:05<01:03, 156.37it/s]

ELBO value is -939.5399780273438 on epoch number 10000
Mean validation MSE at epoch number 10000 is 11.291762351989746
Current KL is 2.5674638748168945


 55%|█████▌    | 11023/20001 [01:11<00:57, 156.41it/s]

ELBO value is -905.5546264648438 on epoch number 11000
Mean validation MSE at epoch number 11000 is 12.479893684387207
Current KL is 2.6978085041046143


 60%|██████    | 12031/20001 [01:18<00:50, 156.38it/s]

ELBO value is -718.1116333007812 on epoch number 12000
Mean validation MSE at epoch number 12000 is 13.121752738952637
Current KL is 2.7885448932647705


 65%|██████▌   | 13023/20001 [01:24<00:44, 155.73it/s]

ELBO value is -662.1700439453125 on epoch number 13000
Mean validation MSE at epoch number 13000 is 13.6211576461792
Current KL is 2.89694881439209


 70%|███████   | 14031/20001 [01:31<00:38, 155.91it/s]

ELBO value is -670.2958374023438 on epoch number 14000
Mean validation MSE at epoch number 14000 is 14.030771255493164
Current KL is 3.037389039993286


 75%|███████▌  | 15023/20001 [01:37<00:32, 155.10it/s]

ELBO value is -581.7822265625 on epoch number 15000
Mean validation MSE at epoch number 15000 is 15.036531448364258
Current KL is 3.159209966659546


 80%|████████  | 16031/20001 [01:44<00:25, 154.30it/s]

ELBO value is -531.6737670898438 on epoch number 16000
Mean validation MSE at epoch number 16000 is 15.670501708984375
Current KL is 3.244945526123047


 85%|████████▌ | 17023/20001 [01:50<00:19, 153.56it/s]

ELBO value is -500.3828430175781 on epoch number 17000
Mean validation MSE at epoch number 17000 is 15.912866592407227
Current KL is 3.3410842418670654


 90%|█████████ | 18031/20001 [01:57<00:12, 152.08it/s]

ELBO value is -494.6930847167969 on epoch number 18000
Mean validation MSE at epoch number 18000 is 17.519559860229492
Current KL is 3.412440299987793


 95%|█████████▌| 19023/20001 [02:03<00:06, 150.28it/s]

ELBO value is -482.97991943359375 on epoch number 19000
Mean validation MSE at epoch number 19000 is 17.103857040405273
Current KL is 3.4698357582092285


100%|██████████| 20001/20001 [02:09<00:00, 154.04it/s]

ELBO value is -465.5469055175781 on epoch number 20000
Mean validation MSE at epoch number 20000 is 17.824260711669922
Current KL is 3.5273096561431885





# Inference for classification

### Inference

In [16]:
val_image_id = 2


for batch in dataset.next_val_batch():
    test_image = batch[0][val_image_id].squeeze()
    test_label = batch[1][val_image_id]

plt.title(f"{test_label.cpu().numpy()}")
plt.imshow(test_image.cpu().numpy());

In [17]:
n_samples = 100

results = []
with torch.no_grad():
    for _ in range(n_samples):
        emb = model(test_image[None, None, ...])
        last_weight = last_weight_mu + std_normal.sample(last_weight_mu.shape) * torch.exp(0.5 * last_weight_logvar)
        last_bias = last_bias_mu + std_normal.sample(last_bias_mu.shape) * torch.exp(0.5 * last_bias_logvar)
        
        logits = emb @ last_weight + last_bias
        probs = torch.softmax(logits, dim=-1)
        y_pred = torch.argmax(probs, dim=-1)
        results.append(y_pred.cpu().item())
        

labels, counts = np.unique(results, return_counts=True)
plt.bar(labels, counts, align='center')
plt.xticks(ticks=np.arange(10))
plt.xlim((-1, 10));

### Find digits with non-trivial distribution

In [18]:
n_samples = 100

for val_batch in dataset.next_val_batch():
    val_images = val_batch[0]
    val_labels = val_batch[1]
    for i in range(val_images.shape[0]):
        test_image = val_images[i].squeeze()
        test_label = val_labels[i].squeeze()
        plt.close()
        results = []
        with torch.no_grad():
            for _ in range(n_samples):
                emb = model(test_image[None, None, ...])
                last_weight = last_weight_mu + std_normal.sample(last_weight_mu.shape) * torch.exp(0.5 * last_weight_logvar)
                last_bias = last_bias_mu + std_normal.sample(last_bias_mu.shape) * torch.exp(0.5 * last_bias_logvar)

                logits = emb @ last_weight + last_bias
                probs = torch.softmax(logits, dim=-1)
                y_pred = torch.argmax(probs, dim=-1)
                results.append(y_pred.cpu().item())
        if np.unique(results).shape[0] > 1: # or np.unique(results)[0] != test_label:
            print('-' * 100)
            plt.title(f"{test_label.cpu().numpy()}")
            plt.imshow(test_image.cpu().numpy());
            plt.show()

            labels, counts = np.unique(results, return_counts=True)
            plt.bar(labels, counts, align='center')
            plt.xticks(ticks=np.arange(10))
            plt.xlim((-1, 10));
            plt.show();

# Inference for Regression

In [23]:
val_image_id = 0


for batch in dataset.next_val_batch():
    test_image = batch[0][val_image_id].squeeze()
    test_label = batch[1][val_image_id]

print(test_label)

tensor([16.7000], device='cuda:0')


In [24]:
n_samples = 100

results = []
with torch.no_grad():
    for _ in range(n_samples):
        if problem == 'classification':
            emb = model(test_image[None, None, ...])
        else:
            emb = model(test_image[None,  ...])
        last_weight = last_weight_mu + std_normal.sample(last_weight_mu.shape) * torch.exp(0.5 * last_weight_logvar)
        last_bias = last_bias_mu + std_normal.sample(last_bias_mu.shape) * torch.exp(0.5 * last_bias_logvar)
        
        logits = emb @ last_weight + last_bias
        results.append(logits.cpu().item())
        
print('Marginalized answer ', np.mean(results))
print('True answer', test_label)
print('-' * 100)
print(f'Parameters are:')
print('Weight:')
print(f'Mean is \n {last_weight_mu.cpu().detach().numpy()}')
print(f'Variance is \n {np.exp(last_weight_logvar.cpu().detach().numpy())}')
print('-' * 100)
print(f'Parameters are:')
print('Bias:')
print(f'Mean is \n {last_bias_mu.cpu().detach().numpy()}')
print(f'Variance is \n {np.exp(last_bias_logvar.cpu().detach().numpy())}')

Marginalized answer  16.372964668273926
True answer tensor([16.7000], device='cuda:0')
----------------------------------------------------------------------------------------------------
Parameters are:
Weight:
Mean is 
 [[-5.7733497e-43]
 [-6.0255834e-44]
 [-1.2277958e-13]
 [-5.0446745e-44]
 [ 2.0074294e+00]
 [ 7.1607366e-19]
 [-2.4595265e-13]
 [-8.4077908e-45]
 [-8.4077908e-45]
 [ 1.1350518e-43]]
Variance is 
 [[1.0000000e+00]
 [1.0000001e+00]
 [1.0000001e+00]
 [1.0000001e+00]
 [2.7631626e-05]
 [9.9999994e-01]
 [1.0000001e+00]
 [9.9999994e-01]
 [1.0000001e+00]
 [1.0000001e+00]]
----------------------------------------------------------------------------------------------------
Parameters are:
Bias:
Mean is 
 [[1.2293063]]
Variance is 
 [[0.00560028]]


In [21]:
from sklearn.linear_model import LinearRegression, SGDRegressor
from sklearn.metrics import mean_squared_error

In [None]:
X = dataset.train.tensors[0].cpu().numpy()
y = dataset.train.tensors[1].cpu().numpy()

X_val = dataset.validation.tensors[0].cpu().numpy()
y_val = dataset.validation.tensors[1].cpu().numpy()

In [35]:
lrg = LinearRegression()
lrg.fit(X, y)

print(f'MSE of classical Linear regression is {mean_squared_error(y_val, lrg.predict(X_val))}')

MSE of classical Linear regression is 18.19565773010254


In [45]:
sgdlrg = SGDRegressor(penalty='none', n_iter=20000, early_stopping=True, learning_rate='constant', eta0=0.000001)
sgdlrg.fit(X, np.ravel(y))

print(f'MSE of SGD Linear regression is {mean_squared_error(y_val, sgdlrg.predict(X_val))}')



MSE of SGD Linear regression is 20.986923980669314
