### Utility Section

In [None]:
# clone main repo
%cd /content
!git clone https://github.com/har07/ngld-experiments.git
%cd ngld-experiments
!git checkout calibration_metrics
# %cd calibration
# !mkdir plots
!mkdir trained

# connect gdrive because training result will be saved to gdrive
# and codes for plotting will read the result from gdrive
from google.colab import drive
drive.mount('/content/drive')

# load tensorboard extension
%load_ext tensorboard

### Training

In [None]:
!mkdir trained

In [None]:
# Run batch training with the CIFAR-10 config. -p points at the trained/ directory;
# the archive cell below collects *.pt and *.txt files from there.
# Swap in the commented line to use the MNIST config instead.
# !python train_batch.py -y config/batch_mnist.yaml -p /content/ngld-experiments/trained
!python train_batch.py -y config/batch_cifar10.yaml -p /content/ngld-experiments/trained

optimizer:  optim.SGD
optimizer params:  {'lr': 0.1}
current_lr:  0.1
Epoch: 1	Train Sec: 156.776	Loss: 1.597	Acc: 43.500	Val Acc: 42.320
current_lr:  0.1
Epoch: 2	Train Sec: 145.631	Loss: 0.938	Acc: 66.500	Val Acc: 64.300
current_lr:  0.1
Epoch: 3	Train Sec: 151.298	Loss: 0.718	Acc: 75.000	Val Acc: 70.140
current_lr:  0.1
Epoch: 4	Train Sec: 145.708	Loss: 0.835	Acc: 73.000	Val Acc: 66.630
current_lr:  0.1
Epoch: 5	Train Sec: 151.187	Loss: 0.708	Acc: 75.000	Val Acc: 75.100
current_lr:  0.1
Epoch: 6	Train Sec: 145.567	Loss: 0.769	Acc: 76.000	Val Acc: 79.090
current_lr:  0.1
Epoch: 7	Train Sec: 151.022	Loss: 0.500	Acc: 83.500	Val Acc: 79.530
current_lr:  0.1
Epoch: 8	Train Sec: 145.686	Loss: 0.440	Acc: 84.000	Val Acc: 78.860
current_lr:  0.1
Epoch: 9	Train Sec: 151.175	Loss: 0.363	Acc: 89.000	Val Acc: 80.370
current_lr:  0.1
Epoch: 10	Train Sec: 145.462	Loss: 0.407	Acc: 86.500	Val Acc: 82.250
epoch duration (mean +/- std): 148.95 +/- 3.70


In [None]:
!mkdir /content/drive/MyDrive/MyExperiments/cifar10_resnet18/runs
!mkdir /content/drive/MyDrive/MyExperiments/cifar10_resnet18/notes
!mv trained/*.pt /content/drive/MyDrive/MyExperiments/cifar10_resnet18
!mv trained/*.txt /content/drive/MyDrive/MyExperiments/cifar10_resnet18/notes
!mv runs/* /content/drive/MyDrive/MyExperiments/cifar10_resnet18/runs
!cp config/batch_cifar10.yaml /content/drive/MyDrive/MyExperiments/cifar10_resnet18/notes

### Evaluate Accuracy and Calibration

In [None]:
# evaluate.py and the other evaluation scripts below are run from calibration/
# (relative path: assumes the current directory is the repo root)
%cd calibration

In [None]:
# Recreate plots/ empty; its contents are copied to Drive after evaluation.
!rm -rf plots
!mkdir plots

In [None]:
# python evaluate.py -n 10 -ds mnist -m lib.model.MnistModel -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
#   -d /content/drive/MyDrive/MyExperiments/mnist_lenet4

# python evaluate.py -n 10 -ds cifar10 -m resnet.LeNet -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
#   -d /content/drive/MyDrive/MyExperiments/cifar10_lenet5

python evaluate.py -n 10 -ds cifar10 -m resnet.ResNet18 -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
  -d /content/drive/MyDrive/MyExperiments/cifar10_resnet18

In [None]:
# !cp plots/* /content/drive/MyDrive/MyExperiments/mnist_lenet4/plots
# !cp plots/* /content/drive/MyDrive/MyExperiments/cifar10_lenet5/plots
!cp plots/* /content/drive/MyDrive/MyExperiments/cifar10_resnet18/plots

### Evaluate Performance on OOD

In [None]:
# Recreate the local ood/ and img_plots/ directories empty.
# Relative paths: the commented %cd shows the expected working directory.
# %cd /content/ngld-experiments/calibration
!rm -rf ood
!rm -rf img_plots
!mkdir ood
!mkdir img_plots

#### Calculate Metrics

In [None]:
# prepare notMNIST raw data:
# The archive is extracted into /content/ngld-experiments/calibration, and the
# relative mv below therefore assumes that is also the current directory.
# NOTE(review): presumably only needed for the `-ds notmnist` run — confirm in ood.py.
!tar -xzf /content/drive/MyDrive/Tesis/dataset/notMNIST_small.tar.gz \
  -C /content/ngld-experiments/calibration
!mv notMNIST_small notmnist_data

In [None]:
# Run ood.py for the selected model; exactly one of the three commands should be active.
# The plotting cells below read thresholds, entropies and sample counts from the
# stats_*.pt file expected under <experiment dir>/ood/ on Drive.
# %run ood.py -n 10 -ds notmnist -m lib.model.MnistModel -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
#   -d /content/drive/MyDrive/MyExperiments/mnist_lenet4

# %run ood.py -n 10 -ds svhn -m lib.model.LeNet -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
#   -d /content/drive/MyDrive/MyExperiments/cifar10_lenet5

%run ood.py -n 10 -ds svhn -m resnet.ResNet18 -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
  -d /content/drive/MyDrive/MyExperiments/cifar10_resnet18

In [None]:
# Move the local ood/ directory into the experiment folder on Drive, where the plotting cell looks for it.
# NOTE(review): if an ood/ directory already exists at the destination, mv will not replace it cleanly — confirm before re-running.
!mv ood /content/drive/MyDrive/MyExperiments/cifar10_resnet18

#### Plotting

In [None]:
# prepare plot configs
import matplotlib.pyplot as plt
import torch
import numpy as np
import glob

# dataset_model = "mnist_lenet4"
# dataset_model = "cifar10_lenet5"
dataset_model = "cifar10_resnet18"
stats_path = f'/content/drive/MyDrive/MyExperiments/{dataset_model}/ood/stats_*.pt'

# glob returns matches in arbitrary order: sort so the same file is picked on
# every run, and fail with a readable message instead of a bare IndexError
# when no stats file has been produced yet.
stats_files = sorted(glob.glob(stats_path))
if not stats_files:
    raise FileNotFoundError(f"no stats file matches {stats_path}")
path = stats_files[0]

# One "scalar_data" entry per optimizer: legend label and line marker.
config = {
    "legend_loc": "upper left",
    "ylabel": r"accuracy on examples $p(y|x) \geq \tau $",
    "xlabel": r"$ \tau $",
    "scalar_data": [
        {"plot_label": "EKSGLD", "marker": "o"},
        {"plot_label": "KSGLD", "marker": "*"},
        {"plot_label": "pSGLD", "marker": "D"},
        {"plot_label": "ASGLD", "marker": ">"},
        {"plot_label": "SGLD", "marker": "<"},
        {"plot_label": "SGD", "marker": "s"},
    ]
}

# Load the saved statistics; the plotting cells below index `samples` and
# `entropies` by optimizer label.
chk = torch.load(path)
thresholds = chk['thresholds']
entropies = chk['entropies']
samples = chk['samples']

In [None]:
# plot conf vs number of samples
# One curve per optimizer; the last threshold is left out of every curve.
plt.figure(dpi=600)
for cfg in config["scalar_data"]:
    name = cfg["plot_label"]
    line_style = {"label": name, "marker": cfg["marker"], "markersize": 4}
    if name != "EKSGLD":
        # every optimizer except EKSGLD is drawn dotted and semi-transparent
        line_style.update(linestyle=':', alpha=0.5)
    plt.plot(thresholds[:-1], samples[name][:-1], **line_style)

plt.xlabel(config["xlabel"])
plt.ylabel(r'number of samples $p(y|x) \geq \tau $')
# plt.ylim(bottom=minvalue, top=maxvalue)
plt.legend(loc='best')
for ext in ("png", "pdf"):
    plt.savefig(f"img_plots/samples_confthres_{dataset_model}.{ext}", bbox_inches='tight')
plt.show()

In [None]:
# plot entropy
# Entropy distribution per optimizer, drawn as a line through the histogram bin centers.

plt.figure(dpi=600)
def_cycler = plt.rcParams['axes.prop_cycle']
cycle_iter = iter(def_cycler)
for cfg in config["scalar_data"]:
  name = cfg["plot_label"]
  color = next(cycle_iter)['color']
  entropy_cumm_np = torch.cat(entropies[name]).cpu().numpy()

  # The histogram is drawn fully transparent (alpha=0.); the call is used for
  # the bin counts and bin edges it returns.
  n, x, line = plt.hist(entropy_cumm_np, histtype='step', bins=20, alpha=0., color=color)

  # x-coordinates: bin midpoints, padded with 0 on the left and the last edge on
  # the right; the counts get matching zeros so the curve starts and ends at 0.
  bin_centers = np.concatenate(([0], 0.5 * (x[1:] + x[:-1]), [x[-1]]))
  n = np.concatenate(([0], n, [0]))

  line_style = {"label": name}
  if name != "EKSGLD":
      line_style.update(linestyle=':', alpha=0.5)
  plt.plot(bin_centers, n, **line_style)

plt.xlabel('entropy')
plt.ylabel(r'number of samples')
# plt.ylim(bottom=minvalue, top=maxvalue)
plt.legend(loc='best')
for ext in ("png", "pdf"):
    plt.savefig(f"img_plots/samples_entropy_{dataset_model}.{ext}", bbox_inches='tight')
plt.show()

In [None]:
!cp img_plots/*.* /content/drive/MyDrive/MyExperiments/plots_img

### Evaluate Performance under Distribution Shift

In [None]:
# Recreate the local distrib_shift/ and img_plots/ directories empty.
# Relative paths: the commented %cd lines show the expected working directory.
# %cd /content/ngld-experiments/calibration
# %cd calibration
!rm -rf distrib_shift
!rm -rf img_plots
!mkdir distrib_shift
!mkdir img_plots

#### Calculate Metrics

In [None]:
# Run distrib_shift.py for the selected model; exactly one of the three commands should be active.
# The plotting cells below read rotations, accuracies, labels, predicted probabilities
# and NLL from the stats_*.pt file expected under <experiment dir>/distrib_shift/ on Drive.
# %run distrib_shift.py -n 10 -ds mnist -m lib.model.MnistModel -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
#   -d /content/drive/MyDrive/MyExperiments/mnist_lenet4

# %run distrib_shift.py -n 10 -ds cifar10 -m lib.model.LeNet -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
#   -d /content/drive/MyDrive/MyExperiments/cifar10_lenet5

%run distrib_shift.py -n 10 -ds cifar10 -m resnet.ResNet18 -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
  -d /content/drive/MyDrive/MyExperiments/cifar10_resnet18

In [None]:
# Move the local distrib_shift/ directory into the matching experiment folder on Drive;
# keep the active line in sync with the model used in the cell above.
# !mv distrib_shift /content/drive/MyDrive/MyExperiments/mnist_lenet4
# !mv distrib_shift /content/drive/MyDrive/MyExperiments/cifar10_lenet5
!mv distrib_shift /content/drive/MyDrive/MyExperiments/cifar10_resnet18

#### Plotting

In [None]:
# prepare plot configs
import metrics
import matplotlib.pyplot as plt
import torch
import numpy as np
import glob

# dataset_model = "mnist_lenet4"
# dataset_model = "cifar10_lenet5"
dataset_model = "cifar10_resnet18"
stats_path = f'/content/drive/MyDrive/MyExperiments/{dataset_model}/distrib_shift/stats_*.pt'

# glob returns matches in arbitrary order: sort so the same file is picked on
# every run, and fail with a readable message instead of a bare IndexError
# when no stats file has been produced yet.
stats_files = sorted(glob.glob(stats_path))
if not stats_files:
    raise FileNotFoundError(f"no stats file matches {stats_path}")
path = stats_files[0]

# One "scalar_data" entry per optimizer: legend label and line marker.
config = {
    "legend_loc": "upper left",
    "ylabel": r"accuracy on examples $p(y|x) \geq \tau $",
    "xlabel": r"$ \tau $",
    "scalar_data": [
        {"plot_label": "EKSGLD", "marker": "o"},
        {"plot_label": "KSGLD", "marker": "*"},
        {"plot_label": "pSGLD", "marker": "D"},
        {"plot_label": "ASGLD", "marker": ">"},
        {"plot_label": "SGLD", "marker": "<"},
        {"plot_label": "SGD", "marker": "s"},
    ]
}

# Load the saved statistics; the plotting cells below index accuracies, labels
# and pred_probs by optimizer label and then by rotation index.
chk = torch.load(path)
rotations = chk['rotations']
accuracies = chk['accuracies']
labels = chk['labels']
pred_probs = chk['pred_probs']
nlls = chk['nll']

In [None]:
# plot ECE vs shift intensity
# One curve per optimizer; rotations are used as categorical x tick labels.

plt.figure(dpi=600)
x = list(map(str, rotations))
for cfg in config["scalar_data"]:
  optimizer = cfg["plot_label"]

  # One ECE value per rotation, each computed with a freshly built criterion.
  # NOTE(review): 15 is presumably the number of calibration bins — confirm in metrics.ECELoss.
  ece_scores = [
    metrics.ECELoss().loss(pred_probs[optimizer][i], labels[optimizer][i], 15, logits=False)
    for i in range(len(rotations))
  ]

  line_style = {"label": optimizer, "marker": cfg["marker"], "markersize": 4}
  if optimizer != "EKSGLD":
      # every optimizer except EKSGLD is drawn dotted and semi-transparent
      line_style.update(linestyle=':', alpha=0.5)
  plt.plot(x, ece_scores, **line_style)

plt.xlabel('rotational degree')
plt.ylabel(r'ECE')
plt.legend(loc='best')
for ext in ("png", "pdf"):
    plt.savefig(f"img_plots/ece_shift_{dataset_model}.{ext}", bbox_inches='tight')
plt.show()

In [None]:
# plot Accuracy vs shift intensity
# One curve per optimizer; rotations are used as categorical x tick labels.

plt.figure(dpi=600)
x = list(map(str, rotations))
for cfg in config["scalar_data"]:
  optimizer = cfg["plot_label"]

  line_style = {"label": optimizer, "marker": cfg["marker"], "markersize": 4}
  if optimizer != "EKSGLD":
      # every optimizer except EKSGLD is drawn dotted and semi-transparent
      line_style.update(linestyle=':', alpha=0.5)
  plt.plot(x, accuracies[optimizer], **line_style)


plt.xlabel('rotational degree')
plt.ylabel(r'Accuracy')
plt.legend(loc='best')
for ext in ("png", "pdf"):
    plt.savefig(f"img_plots/acc_shift_{dataset_model}.{ext}", bbox_inches='tight')
plt.show()

In [None]:
# plot Brier vs shift intensity
def brier_multi(targets, probs):
  """Return the multi-class Brier score.

  The squared difference between `probs` and `targets` is summed along axis 1
  and the per-row sums are averaged. The result is symmetric in its two
  arguments because only the squared difference is used.
  """
  squared_error = (probs - targets)**2
  per_row = np.sum(squared_error, axis=1)
  return np.mean(per_row)

# One Brier curve per optimizer; rotations are used as categorical x tick labels.
plt.figure(dpi=600)
x = list(map(str, rotations))
for cfg in config["scalar_data"]:
  optimizer = cfg["plot_label"]

  brier_scores = []
  for i in range(len(rotations)):
    # one-hot encode the labels; the 10 hard-codes the number of classes
    onehot_label = np.eye(10)[labels[optimizer][i]]
    # The arguments are passed as (probabilities, targets), the reverse of the
    # brier_multi(targets, probs) signature; the score only uses the squared
    # difference, so the value is the same either way.
    brier_scores.append(brier_multi(pred_probs[optimizer][i], onehot_label))

  line_style = {"label": optimizer, "marker": cfg["marker"], "markersize": 4}
  if optimizer != "EKSGLD":
      # every optimizer except EKSGLD is drawn dotted and semi-transparent
      line_style.update(linestyle=':', alpha=0.5)
  plt.plot(x, brier_scores, **line_style)

plt.xlabel('rotational degree')
plt.ylabel(r'Brier')
plt.legend(loc='best')
for ext in ("png", "pdf"):
    plt.savefig(f"img_plots/brier_shift_{dataset_model}.{ext}", bbox_inches='tight')
plt.show()

In [None]:
!cp img_plots/*.* /content/drive/MyDrive/MyExperiments/plots_img

### Plot Performance from TFEvents

#### Accuracy vs Epoch

In [None]:
!mkdir img_plots

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import numpy as np
import argparse
import yaml
import glob
import matplotlib.pyplot as plt

# Loading too much data is slow...
tf_size_guidance = {
    'compressedHistograms': 10,
    'images': 0,
    'scalars': 100,
    'histograms': 1
}

minvalue = 92
maxvalue = 99.6

# dataset_model = "mnist_lenet4"
# dataset_model = "cifar10_lenet5"
dataset_model = "cifar10_resnet18"
base_path = f"/content/drive/MyDrive/MyExperiments/{dataset_model}/runs"

# One "scalar_data" entry per optimizer: legend label and line marker.
config = {
    "legend_loc": "lower right",
    "ylabel": "accuracy",
    "xlabel": "epoch",
    "scalar_data": [
        {"plot_label": "EKSGLD", "marker": "o"},
        {"plot_label": "KSGLD", "marker": "*"},
        {"plot_label": "pSGLD", "marker": "D"},
        {"plot_label": "ASGLD", "marker": ">"},
        {"plot_label": "SGLD", "marker": "<"},
        {"plot_label": "SGD", "marker": "s"},
    ]
}

plt.figure(dpi=600)
for cfg in config["scalar_data"]:
    path = base_path + f'/*.{cfg["plot_label"]}_*/*.*'

    line_style = {"marker": cfg["marker"], "markersize": 4}
    if cfg["plot_label"] != "EKSGLD":
        # every optimizer except EKSGLD is drawn dotted and semi-transparent
        line_style.update(linestyle=':', alpha=0.5)

    labelled = False
    # sorted(): glob order is arbitrary; keep the plotting order deterministic
    for data_path in sorted(glob.glob(path)):
        event_acc = EventAccumulator(data_path, tf_size_guidance)
        event_acc.Reload()

        # NOTE(review): the variable is called val_acc but the tag read is
        # "Acc/train" — confirm which curve is intended.
        val_acc = event_acc.Scalars("Acc/train")

        # Build the series per event file. Accumulating x/y across files (as
        # before) re-plotted earlier files' points together with each new file.
        x = [event.step for event in val_acc]
        y = [event.value for event in val_acc]

        # Only the first curve of an optimizer gets a legend entry; labels that
        # start with an underscore are skipped by the legend.
        label = "_nolegend_" if labelled else cfg["plot_label"]
        labelled = True
        plt.plot(x, y, label=label, **line_style)

plt.xlabel("epoch")
plt.ylabel("accuracy")
# plt.ylim(bottom=minvalue, top=maxvalue)
plt.legend(loc="best")
plt.savefig(f"img_plots/acc_epoch_{dataset_model}.png", bbox_inches='tight')
plt.savefig(f"img_plots/acc_epoch_{dataset_model}.pdf", bbox_inches='tight')
plt.show()

#### Accuracy vs Wall-Clock Time

In [None]:
# Recreate img_plots/ empty before rendering the wall-clock figure.
# %cd calibration
!rm -rf img_plots
!mkdir img_plots

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import numpy as np
import argparse
import yaml
import glob
import matplotlib.pyplot as plt

# Loading too much data is slow...
tf_size_guidance = {
    'compressedHistograms': 10,
    'images': 0,
    'scalars': 100,
    'histograms': 1
}

minvalue = 92
maxvalue = 99.6

# dataset_model = "mnist_lenet4"
# dataset_model = "cifar10_lenet5"
dataset_model = "cifar10_resnet18"
base_path = f"/content/drive/MyDrive/MyExperiments/{dataset_model}/runs"

# One "scalar_data" entry per optimizer: legend label and line marker.
config = {
    "legend_loc": "lower right",
    "ylabel": "accuracy",
    "xlabel": "epoch",
    "scalar_data": [
        {"plot_label": "EKSGLD", "marker": "o"},
        {"plot_label": "KSGLD", "marker": "*"},
        {"plot_label": "pSGLD", "marker": "D"},
        {"plot_label": "ASGLD", "marker": ">"},
        {"plot_label": "SGLD", "marker": "<"},
        {"plot_label": "SGD", "marker": "s"},
    ]
}

plt.figure(dpi=600)
for cfg in config["scalar_data"]:
    path = base_path + f'/*.{cfg["plot_label"]}_*/*.*'

    line_style = {"marker": cfg["marker"], "markersize": 4}
    if cfg["plot_label"] != "EKSGLD":
        # every optimizer except EKSGLD is drawn dotted and semi-transparent
        line_style.update(linestyle=':', alpha=0.5)

    labelled = False
    # sorted(): glob order is arbitrary; keep the plotting order deterministic
    for data_path in sorted(glob.glob(path)):
        event_acc = EventAccumulator(data_path, tf_size_guidance)
        event_acc.Reload()

        # NOTE(review): the variable is called val_acc but the tag read is
        # "Acc/train" — confirm which curve is intended.
        val_acc = event_acc.Scalars("Acc/train")
        if not val_acc:
            # no scalar events in this file: nothing to plot and no start time
            continue

        # Elapsed seconds since the first logged point of this event file.
        # The series is built per file; accumulating x/y across files (as
        # before) re-plotted earlier files' points together with each new file.
        start = val_acc[0].wall_time
        x = [event.wall_time - start for event in val_acc]
        y = [event.value for event in val_acc]

        # Only the first curve of an optimizer gets a legend entry; labels that
        # start with an underscore are skipped by the legend.
        label = "_nolegend_" if labelled else cfg["plot_label"]
        labelled = True
        plt.plot(x, y, label=label, **line_style)

plt.xlabel("wall clock time in seconds")
plt.ylabel("accuracy")
# plt.ylim(bottom=minvalue, top=maxvalue)
plt.legend(loc="best")
plt.savefig(f"img_plots/acc_clock_{dataset_model}.png", bbox_inches='tight')
plt.savefig(f"img_plots/acc_clock_{dataset_model}.pdf", bbox_inches='tight')
plt.show()

In [None]:
!cp img_plots/*.* /content/drive/MyDrive/MyExperiments/plots_img

### Plot Accuracy vs Confidence Threshold

In [None]:
# Recreate the local conf_threshold/ and img_plots/ directories empty.
# %cd calibration
!rm -rf conf_threshold
!mkdir conf_threshold
!rm -rf img_plots
!mkdir img_plots

In [None]:
# Run conf_threshold.py for the selected model; exactly one of the three commands should be active.
# The plotting cells below read thresholds, accuracies, entropies and sample counts
# from a stats_*.pt file under <experiment dir>/conf_threshold/ on Drive.
# NOTE(review): unlike the OOD and distribution-shift sections, no cell moves the local
# conf_threshold/ directory to Drive — confirm where conf_threshold.py writes its output.
# %run conf_threshold.py -n 10 -ds mnist -m lib.model.MnistModel -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
#   -d /content/drive/MyDrive/MyExperiments/mnist_lenet4

# %run conf_threshold.py -n 10 -ds cifar10 -m lib.model.LeNet -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
#   -d /content/drive/MyDrive/MyExperiments/cifar10_lenet5

%run conf_threshold.py -n 10 -ds cifar10 -m resnet.ResNet18 -o EKSGLD,KSGLD,pSGLD,ASGLD,SGLD,SGD \
  -d /content/drive/MyDrive/MyExperiments/cifar10_resnet18

#### Plotting

In [None]:
# prepare plot configs
import matplotlib.pyplot as plt
import torch
import numpy as np
import glob

# dataset_model = "mnist_lenet4"
# dataset_model = "cifar10_lenet5"
dataset_model = "cifar10_resnet18"
base_path = f"/content/drive/MyDrive/MyExperiments/{dataset_model}/conf_threshold/stats_*.pt"

# glob returns matches in arbitrary order: sort so the same file is picked on
# every run, and fail with a readable message instead of a bare IndexError
# when no stats file has been produced yet.
stats_files = sorted(glob.glob(base_path))
if not stats_files:
    raise FileNotFoundError(f"no stats file matches {base_path}")
path = stats_files[0]

# One "scalar_data" entry per optimizer: legend label and line marker.
config = {
    "legend_loc": "upper left",
    "ylabel": r"accuracy on examples $p(y|x) \geq \tau $",
    "xlabel": r"$ \tau $",
    "scalar_data": [
        {"plot_label": "EKSGLD", "marker": "o"},
        {"plot_label": "KSGLD", "marker": "*"},
        {"plot_label": "pSGLD", "marker": "D"},
        {"plot_label": "ASGLD", "marker": ">"},
        {"plot_label": "SGLD", "marker": "<"},
        {"plot_label": "SGD", "marker": "s"},
    ]
}

# Load the saved statistics; the plotting cells below index accuracies,
# samples and entropies by optimizer label.
chk = torch.load(path)
thresholds = chk['thresholds']
accuracies = chk['accuracies']
entropies = chk['entropies']
samples = chk['samples']

In [None]:
# plot conf vs accuracy
# One curve per optimizer over all confidence thresholds.
plt.figure(dpi=600)
for cfg in config["scalar_data"]:
  name = cfg["plot_label"]
  line_style = {"label": name, "marker": cfg["marker"], "markersize": 4}
  if name != "EKSGLD":
      # every optimizer except EKSGLD is drawn dotted and semi-transparent
      line_style.update(linestyle=':', alpha=0.5)
  plt.plot(thresholds, accuracies[name], **line_style)

plt.xlabel(config["xlabel"])
plt.ylabel(config["ylabel"])
plt.legend(loc='best')
for ext in ("png", "pdf"):
    plt.savefig(f"img_plots/indistrib_acc_confthres_{dataset_model}.{ext}", bbox_inches='tight')
plt.show()

In [None]:
# plot conf vs number of samples
# One curve per optimizer over all confidence thresholds.
plt.figure(dpi=600)
for cfg in config["scalar_data"]:
  name = cfg["plot_label"]
  line_style = {"label": name, "marker": cfg["marker"], "markersize": 4}
  if name != "EKSGLD":
      # every optimizer except EKSGLD is drawn dotted and semi-transparent
      line_style.update(linestyle=':', alpha=0.5)
  plt.plot(thresholds, samples[name], **line_style)

plt.xlabel(config["xlabel"])
plt.ylabel(r'number of samples $p(y|x) \geq \tau $')
plt.legend(loc='best')
for ext in ("png", "pdf"):
    plt.savefig(f"img_plots/indistrib_samples_confthres_{dataset_model}.{ext}", bbox_inches='tight')
plt.show()

In [None]:
# plot entropy
# Entropy distribution per optimizer, drawn as a line through the histogram bin centers.
plt.figure(dpi=600)
def_cycler = plt.rcParams['axes.prop_cycle']
cycle_iter = iter(def_cycler)
for cfg in config["scalar_data"]:
  name = cfg["plot_label"]
  color = next(cycle_iter)['color']
  entropy_cumm_np = torch.cat(entropies[name]).cpu().numpy()

  # The histogram is drawn fully transparent (alpha=0.); the call is used for
  # the bin counts and bin edges it returns.
  n, x, _ = plt.hist(entropy_cumm_np, histtype='step', bins=20, alpha=0., color=color)

  # x-coordinates: bin midpoints, padded with 0 on the left and the last edge on
  # the right; the counts get matching zeros so the curve starts and ends at 0.
  bin_centers = np.concatenate(([0], 0.5 * (x[1:] + x[:-1]), [x[-1]]))
  n = np.concatenate(([0], n, [0]))

  line_style = {"label": name, "color": color}
  if name != "EKSGLD":
      # every optimizer except EKSGLD is drawn dotted and semi-transparent
      line_style.update(linestyle=':', alpha=0.5)
  plt.plot(bin_centers, n, **line_style)

plt.xlabel('entropy')
plt.ylabel(r'number of samples')
plt.legend(loc='best')
for ext in ("png", "pdf"):
    plt.savefig(f"img_plots/indistrib_entropy_{dataset_model}.{ext}", bbox_inches='tight')
plt.show()

In [None]:
!cp img_plots/*.* /content/drive/MyDrive/MyExperiments/plots_img