### Weight Distribution Comparison Metrics

In [21]:
import tensorflow as tf
import torch
import torchvision
import pytorch_lightning as pl
from torch import nn
import statistics
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from math import log2

In [22]:
config = {'learning_rate': .001, 'dropout': 0.2, 'batch_size': 64, 'epochs': 25}

In [23]:
def get_tf_weights(config):
    mnist = tf.keras.datasets.mnist

    (x_train, y_train),(x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28)),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dropout(config['dropout']),
      tf.keras.layers.Dense(10, activation='softmax')
    ])

    opt = tf.keras.optimizers.Adam(learning_rate=config['learning_rate'])

    model.compile(optimizer=opt,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    res = model.fit(x_train, y_train, epochs=config['epochs'], batch_size=config['batch_size'])
    res_test = model.evaluate(x_test, y_test)
    just_tf_weights = list()
    # get weights
    for w in model.weights:
        just_tf_weights.extend(w.numpy().flatten())
    # scale the weights
    scaled_weights = MinMaxScaler().fit_transform(np.array(just_tf_weights).reshape(-1, 1))+1
    return scaled_weights

In [24]:
class NumberNet(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(), 
            nn.Linear(784, 128), 
            nn.ReLU(), 
            nn.Dropout(config['dropout']), 
            nn.Linear(128, 10)) ### no softmax because it's included in cross entropy loss
        self.criterion = nn.CrossEntropyLoss()
        self.config = config
        self.test_loss = None
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(torchvision.datasets.MNIST("~/resiliency/", train=True, 
                                                                      transform=torchvision.transforms.ToTensor(), target_transform=None, download=True), 
                                           batch_size=int(self.config['batch_size']))
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(torchvision.datasets.MNIST("~/resiliency/", train=True, 
                                                                      transform=torchvision.transforms.ToTensor(), target_transform=None, download=True), 
                                           batch_size=int(self.config['batch_size']))
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config['learning_rate'])
        return optimizer
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        logs = {'train_loss': loss}
        return {'loss': loss}
    
    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        logs = {'test_loss': loss}
        return {'test_loss': loss, 'logs': logs}
    
    def test_epoch_end(self, outputs):
        loss = []
        for x in outputs:
            loss.append(float(x['test_loss']))
        avg_loss = statistics.mean(loss)
        tensorboard_logs = {'test_loss': avg_loss}
        self.test_loss = avg_loss
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

In [25]:
def get_pt_weights(config):
    model = NumberNet(config)
    trainer = pl.Trainer(max_epochs=config['epochs'])
    trainer.fit(model)
    trainer.test(model)
    pt_model_weights = list(model.parameters())
    just_pt_weights = list()
    for w in pt_model_weights:
        just_pt_weights.extend(w.detach().numpy().flatten())
    pt_weights_scaled = MinMaxScaler().fit_transform(np.array(just_pt_weights).reshape(-1, 1))+1
    return pt_weights_scaled

In [26]:
all_tf_weights = list()
for i in range(10):
    all_tf_weights.append(get_tf_weights(config))

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoc

In [27]:
mean_tf_weights = np.mean(all_tf_weights, axis=0)

In [None]:
all_pt_weights = list()
for i in range(10):
    all_pt_weights.append(get_pt_weights(config))

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     


Epoch 0:   0%|          | 4/938 [00:00<00:31, 29.43it/s, loss=2.227, v_num=23]



Epoch 24: 100%|██████████| 938/938 [00:09<00:00, 93.92it/s, loss=0.023, v_num=23] 

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:09<00:00, 93.89it/s, loss=0.023, v_num=23]
Testing:   2%|▏         | 17/938 [00:00<00:05, 164.33it/s]



Testing:  99%|█████████▉| 933/938 [00:05<00:00, 170.08it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.006322422512482839, 'test_loss': 0.006322422512482839}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:05<00:00, 173.89it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   1%|▏         | 14/938 [00:00<00:10, 88.58it/s, loss=2.056, v_num=24]



Epoch 24: 100%|██████████| 938/938 [00:13<00:00, 67.67it/s, loss=0.024, v_num=24]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:13<00:00, 67.61it/s, loss=0.024, v_num=24]
Testing:   1%|▏         | 12/938 [00:00<00:08, 111.34it/s]



Testing:  99%|█████████▉| 929/938 [00:08<00:00, 110.99it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.007391922231718301, 'test_loss': 0.007391922231718301}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:08<00:00, 113.19it/s]


GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     


Epoch 0:   1%|          | 9/938 [00:00<00:13, 68.29it/s, loss=2.149, v_num=26]



Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 92.55it/s, loss=0.020, v_num=26] 

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 92.45it/s, loss=0.020, v_num=26]
Testing:   2%|▏         | 15/938 [00:00<00:06, 145.73it/s]



Testing:  98%|█████████▊| 923/938 [00:06<00:00, 169.16it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.007391686294188387, 'test_loss': 0.007391686294188387}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:06<00:00, 153.89it/s]


GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     


Epoch 0:   2%|▏         | 16/938 [00:00<00:09, 101.38it/s, loss=1.998, v_num=27]



Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 92.85it/s, loss=0.015, v_num=27] 

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 92.75it/s, loss=0.015, v_num=27]
Testing:   2%|▏         | 19/938 [00:00<00:05, 181.41it/s]



Testing:  99%|█████████▉| 933/938 [00:05<00:00, 139.06it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.0068636377817224295, 'test_loss': 0.0068636377817224295}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:06<00:00, 155.51it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   2%|▏         | 16/938 [00:00<00:08, 104.07it/s, loss=2.035, v_num=28]



Epoch 24: 100%|██████████| 938/938 [00:08<00:00, 108.98it/s, loss=0.022, v_num=28]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:08<00:00, 108.93it/s, loss=0.022, v_num=28]
Testing:   2%|▏         | 20/938 [00:00<00:04, 189.66it/s]



Testing: 100%|█████████▉| 936/938 [00:05<00:00, 165.44it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.008644895186758976, 'test_loss': 0.008644895186758976}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:05<00:00, 171.56it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   2%|▏         | 16/938 [00:00<00:10, 90.73it/s, loss=2.017, v_num=32]



Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 86.74it/s, loss=0.016, v_num=32] 

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:10<00:00, 86.72it/s, loss=0.016, v_num=32]
Testing:   2%|▏         | 18/938 [00:00<00:05, 172.34it/s]



Testing:  99%|█████████▉| 929/938 [00:06<00:00, 186.94it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.005707650057138574, 'test_loss': 0.005707650057138574}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:06<00:00, 153.24it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   2%|▏         | 17/938 [00:00<00:08, 106.91it/s, loss=1.938, v_num=35]



Epoch 24: 100%|██████████| 938/938 [00:11<00:00, 81.50it/s, loss=0.020, v_num=35] 

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:11<00:00, 81.48it/s, loss=0.020, v_num=35]
Testing:   2%|▏         | 16/938 [00:00<00:06, 153.02it/s]



Testing:  99%|█████████▉| 930/938 [00:05<00:00, 174.09it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.00758886054557349, 'test_loss': 0.00758886054557349}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:05<00:00, 171.03it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   2%|▏         | 17/938 [00:00<00:08, 106.86it/s, loss=1.980, v_num=37]



Epoch 24: 100%|██████████| 938/938 [00:09<00:00, 97.40it/s, loss=0.013, v_num=37]

Saving latest checkpoint..


Epoch 24: 100%|██████████| 938/938 [00:09<00:00, 97.37it/s, loss=0.013, v_num=37]
Testing:   2%|▏         | 19/938 [00:00<00:05, 183.30it/s]



Testing:  99%|█████████▉| 928/938 [00:05<00:00, 187.28it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': 0.006858440873819841, 'test_loss': 0.006858440873819841}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 938/938 [00:05<00:00, 184.29it/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 101 K 
1 | criterion | CrossEntropyLoss | 0     



Epoch 0:   2%|▏         | 16/938 [00:00<00:09, 100.63it/s, loss=1.986, v_num=40]



Epoch 15:  51%|█████▏    | 481/938 [00:05<00:05, 86.81it/s, loss=0.035, v_num=40] 

In [None]:
all_pt_weights

In [None]:
mean_pt_weights = np.mean(all_pt_weights, axis=0)

In [None]:
# calculate the kl divergence
def kl_divergence(p, q):
	return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))

In [None]:
kl_divergence(mean_tf_weights, mean_pt_weights)