### Weight Distribution Comparison Metrics

In [None]:
import tensorflow as tf
import torch
import torchvision
import pytorch_lightning as pl
from torch import nn
import statistics
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from math import log2

In [4]:
config = config = {'learning_rate': .001, 'dropout': 0.2, 'batch_size': 64, 'epochs': 25}

In [1]:
def get_tf_weights(config):
    mnist = tf.keras.datasets.mnist

    (x_train, y_train),(x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28)),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dropout(config['dropout']),
      tf.keras.layers.Dense(10, activation='softmax')
    ])

    opt = tf.keras.optimizers.Adam(learning_rate=config['learning_rate'])

    model.compile(optimizer=opt,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    res = model.fit(x_train, y_train, epochs=config['epochs'], batch_size=config['batch_size'])
    res_test = model.evaluate(x_test, y_test)
    just_tf_weights = list()
    # get weights
    for w in model.weights:
        just_tf_weights.extend(w.numpy().flatten())
    # scale the weights
    scaled_weights = MinMaxScaler().fit_transform(np.array(just_tf_weights).reshape(-1, 1))+1
    return scaled_weights

In [5]:
class NumberNet(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(), 
            nn.Linear(784, 128), 
            nn.ReLU(), 
            nn.Dropout(config['dropout']), 
            nn.Linear(128, 10)) ### no softmax because it's included in cross entropy loss
        self.criterion = nn.CrossEntropyLoss()
        self.config = config
        self.test_loss = None
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(torchvision.datasets.MNIST("~/resiliency/", train=True, 
                                                                      transform=torchvision.transforms.ToTensor(), target_transform=None, download=True), 
                                           batch_size=int(self.config['batch_size']))
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(torchvision.datasets.MNIST("~/resiliency/", train=True, 
                                                                      transform=torchvision.transforms.ToTensor(), target_transform=None, download=True), 
                                           batch_size=int(self.config['batch_size']))
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config['learning_rate'])
        return optimizer
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        logs = {'train_loss': loss}
        return {'loss': loss}
    
    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        logs = {'test_loss': loss}
        return {'test_loss': loss, 'logs': logs}
    
    def test_epoch_end(self, outputs):
        loss = []
        for x in outputs:
            loss.append(float(x['test_loss']))
        avg_loss = statistics.mean(loss)
        tensorboard_logs = {'test_loss': avg_loss}
        self.test_loss = avg_loss
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

In [7]:
def get_pt_weights(config):
    model = NumberNet(config)
    trainer = pl.Trainer(max_epochs=config['epochs'])
    trainer.fit(model)
    trainer.test(model)
    pt_model_weights = list(model.parameters())
    just_pt_weights = list()
    for w in pt_model_weights:
        just_pt_weights.extend(w.detach().numpy().flatten())
    pt_weights_scaled = MinMaxScaler().fit_transform(np.array(just_pt_weights).reshape(-1, 1))+1

In [8]:
all_tf_weights = list()
for i in range(10):
    all_tf_weights.append(get_tf_weights(config))

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoc

In [15]:
mean_tf_weights = np.mean(all_tf_weights, axis=0)

In [16]:
all_pt_weights = list()
for i in range(10):
    all_pt_weights.append(get_pt_weights(config))

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     


Epoch 0:   0%|          | 3/938 [00:00<00:28, 32.63it/s, loss=2.252, v_num=3]



Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 85.54it/s, loss=0.015, v_num=3]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 85.52it/s, loss=0.015, v_num=3]
Testing:   2%|▏         | 17/938 [00:00<00:05, 169.43it/s]



Testing: 100%|█████████▉| 937/938 [00:06<00:00, 123.93it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.008191703880836199, 'test_loss': 0.008191703880836199}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:06<00:00, 148.89it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   1%|          | 9/938 [00:00<00:13, 66.75it/s, loss=2.134, v_num=4]



Epoch 24: 100%|██████████| 938/938 [00:12<00:00, 74.55it/s, loss=0.016, v_num=4]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:12<00:00, 74.54it/s, loss=0.016, v_num=4]
Testing: 0it [00:00, ?it/s]



Testing:  99%|█████████▉| 932/938 [00:05<00:00, 172.85it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.007606163811752636, 'test_loss': 0.007606163811752636}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:05<00:00, 165.05it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores






  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     


Epoch 0:   1%|▏         | 12/938 [00:00<00:11, 79.73it/s, loss=2.127, v_num=5]



Epoch 24: 100%|██████████| 938/938 [00:13<00:00, 71.89it/s, loss=0.015, v_num=5]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:13<00:00, 71.88it/s, loss=0.015, v_num=5]
Testing:   1%|▏         | 12/938 [00:00<00:07, 119.28it/s]



Testing:  99%|█████████▉| 931/938 [00:06<00:00, 141.98it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.007847462001889358, 'test_loss': 0.007847462001889358}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:06<00:00, 138.11it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   1%|          | 7/938 [00:00<00:16, 56.42it/s, loss=2.170, v_num=6]



Epoch 24: 100%|██████████| 938/938 [00:12<00:00, 75.61it/s, loss=0.011, v_num=6]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:12<00:00, 75.59it/s, loss=0.011, v_num=6]
Testing:   1%|▏         | 13/938 [00:00<00:07, 128.47it/s]



Testing:  99%|█████████▉| 930/938 [00:05<00:00, 155.24it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.006437031994640481, 'test_loss': 0.006437031994640481}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:05<00:00, 159.79it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores






  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     


Epoch 0:   1%|          | 11/938 [00:00<00:13, 70.74it/s, loss=2.111, v_num=7]



Epoch 24: 100%|██████████| 938/938 [00:12<00:00, 73.03it/s, loss=0.023, v_num=7]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:12<00:00, 73.01it/s, loss=0.023, v_num=7]
Testing:   1%|▏         | 12/938 [00:00<00:07, 118.54it/s]



Testing:  99%|█████████▊| 924/938 [00:05<00:00, 176.22it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.007282857068880953, 'test_loss': 0.007282857068880953}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:05<00:00, 160.81it/s]


GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     


Epoch 0:   1%|          | 11/938 [00:00<00:12, 73.70it/s, loss=2.102, v_num=8]



Epoch 24: 100%|██████████| 938/938 [00:12<00:00, 73.06it/s, loss=0.012, v_num=8]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:12<00:00, 73.04it/s, loss=0.012, v_num=8]
Testing:   1%|▏         | 12/938 [00:00<00:08, 115.50it/s]



Testing:  99%|█████████▉| 930/938 [00:06<00:00, 121.01it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.008231987057416054, 'test_loss': 0.008231987057416054}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:06<00:00, 134.17it/s]


GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     


Epoch 0:   1%|          | 9/938 [00:00<00:14, 64.80it/s, loss=2.149, v_num=9]



Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 88.98it/s, loss=0.020, v_num=9]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 88.96it/s, loss=0.020, v_num=9]
Testing:   2%|▏         | 17/938 [00:00<00:05, 168.69it/s]



Testing: 100%|█████████▉| 936/938 [00:04<00:00, 189.46it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.007183375410660567, 'test_loss': 0.007183375410660567}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:04<00:00, 188.11it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   1%|▏         | 13/938 [00:00<00:10, 85.73it/s, loss=2.082, v_num=10]



Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 88.01it/s, loss=0.016, v_num=10]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 87.98it/s, loss=0.016, v_num=10]
Testing:   2%|▏         | 17/938 [00:00<00:05, 163.49it/s]



Testing:  99%|█████████▉| 928/938 [00:04<00:00, 187.14it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.006369947246236282, 'test_loss': 0.006369947246236282}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:04<00:00, 190.06it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   1%|▏         | 13/938 [00:00<00:10, 86.72it/s, loss=2.056, v_num=11]



Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 88.24it/s, loss=0.011, v_num=11]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 88.21it/s, loss=0.011, v_num=11]
Testing:   2%|▏         | 15/938 [00:00<00:06, 149.77it/s]



Testing: 100%|█████████▉| 937/938 [00:04<00:00, 188.83it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.00686777872997404, 'test_loss': 0.00686777872997404}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:04<00:00, 187.94it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   1%|▏         | 14/938 [00:00<00:10, 90.06it/s, loss=2.045, v_num=12]



Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 86.69it/s, loss=0.018, v_num=12]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 86.67it/s, loss=0.018, v_num=12]
Testing:   2%|▏         | 17/938 [00:00<00:05, 164.02it/s]



Testing:  99%|█████████▉| 930/938 [00:05<00:00, 185.87it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.0066498999786107115, 'test_loss': 0.0066498999786107115}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:05<00:00, 184.75it/s]


In [17]:
mean_pt_weights = np.mean(all_pt_weights, axis=0)

TypeError: unsupported operand type(s) for +: 'NoneType' and 'NoneType'

In [None]:
# calculate the kl divergence
def kl_divergence(p, q):
	return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))

In [None]:
kl_divergence(mean_tf_weights, mean_pt_weights)