<a href="https://colab.research.google.com/github/edmundlth/local_learning_coefficient_estimation/blob/main/Bert_RLCT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
from google.colab import drive

# Later cells read their inputs through the relative path "drive/MyDrive/...",
# which resolves under this mount point.
MOUNT_POINT = "/content/drive"
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [3]:
import torch
!pip install transformers datasets --quiet
import transformers
from transformers import AutoModel, BertForSequenceClassification, AutoTokenizer
from datasets import load_dataset
import pandas as pd

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/7.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/7.2 MB[0m [31m89.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m7.2/7.2 MB[0m [31m132.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m86.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m54.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m116.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m89.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━

In [4]:
!pip install engineering_notation
from engineering_notation import EngNumber

Collecting engineering_notation
  Downloading engineering_notation-0.8.0-py3-none-any.whl (6.6 kB)
Installing collected packages: engineering_notation
Successfully installed engineering_notation-0.8.0


In [5]:
import decimal
import torch
from copy import deepcopy
import torch
import numpy as np
import time

class MNISTExperiment(object):
    """Train a model and estimate its local learning coefficient (hat-lambda) with SGLD.

    Despite the name, the class only requires that ``net(inputs, labels=labels)``
    returns an object with a ``.loss`` attribute, and that ``net(inputs)`` returns
    either class scores or an object with a ``.logits`` attribute (as Hugging Face
    sequence-classification models do).

    Args:
        net: model under study.
        trainloader: DataLoader yielding ``(inputs, labels)`` training batches.
        testloader: DataLoader yielding ``(inputs, labels)`` test batches.
        optimizer: optimizer used by ``run_sgd`` / ``run_entropy_sgd``.
        device: device that every batch is moved to.
        sgld_num_chains: number of independent SGLD chains per estimate.
        sgld_num_iter: SGLD steps per chain.
        sgld_gamma: localization strength; ``None`` means ``100 / ||w||`` per tensor.
        sgld_noise_std: passed to SGLD as the step size ``epsilon``; the injected
            Gaussian noise has standard deviation ``sqrt(epsilon)``.
        sgld_verbose: print per-chain and summary SGLD diagnostics.
    """

    def __init__(
        self,
        net,
        trainloader,
        testloader,
        optimizer,
        device,
        sgld_num_chains=4,
        sgld_num_iter=100,
        sgld_gamma=None,
        sgld_noise_std=1e-5,
        sgld_verbose=True,
    ):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader
        self.optimizer = optimizer
        self.device = device

        self.sgld_num_chains = sgld_num_chains
        self.sgld_num_iter = sgld_num_iter
        self.sgld_gamma = sgld_gamma
        self.sgld_noise_std = sgld_noise_std
        self.sgld_verbose = sgld_verbose

        self.batch_size = trainloader.batch_size
        self.total_train = len(self.trainloader.dataset)

        # Persistent iterator so SGLD and Entropy-SGD can draw one minibatch at a time.
        self.trainloader_iter = iter(self.trainloader)

        self.records = {
            "lfe": [],
            "energy": [],
            "hatlambda": [],
            "hatlambda_lower": [],
            "hatlambda_upper": [],
            "test_error": [],
            "train_error": []
        }

    def eval(self, dataloader):
        """Return the classification accuracy of ``self.net`` on ``dataloader``.

        The network is switched to eval mode (disabling dropout) for the pass, and
        its previous train/eval mode is restored afterwards.
        """
        was_training = self.net.training
        self.net.eval()
        correct = 0
        total = 0
        try:
            with torch.no_grad():
                for data in dataloader:
                    inputs, labels = data[0].to(self.device), data[1].to(self.device)
                    outputs = self.net(inputs)
                    # Hugging Face models return a ModelOutput; plain nets return the scores.
                    logits = getattr(outputs, "logits", outputs)
                    _, predicted = torch.max(logits, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
        finally:
            self.net.train(was_training)
        return correct / total

    def _generate_next_training_batch(self):
        """Return the next ``(inputs, labels)`` training batch, moved to ``self.device``.

        The training iterator is restarted when exhausted, so minibatches can be
        drawn indefinitely.
        """
        try:
            data = next(self.trainloader_iter)
        except StopIteration:
            self.trainloader_iter = iter(self.trainloader)
            data = next(self.trainloader_iter)
        inputs, labels = data[0].to(self.device), data[1].to(self.device)
        return inputs, labels

    def closure(self):
        """Zero grads, then run forward/backward on one training minibatch.

        Passed to ``self.optimizer.step`` by ``run_entropy_sgd``; presumably the
        Entropy-SGD optimizer consumes this tuple (TODO confirm against its source).

        Returns:
            ``(loss, inputs, labels)`` for the minibatch.
        """
        inputs, labels = self._generate_next_training_batch()
        self.optimizer.zero_grad()
        outputs = self.net(inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        return loss, inputs, labels

    def compute_energy(self):
        """Return ``n * L_n(w*)``, the summed training loss at the current weights.

        Assumes ``outputs.loss`` is the batch-mean loss. Each batch is weighted by
        its actual size, so a smaller final batch is not over-counted.
        """
        energy = 0.0
        with torch.no_grad():
            for data in self.trainloader:
                inputs, labels = data[0].to(self.device), data[1].to(self.device)
                outputs = self.net(inputs, labels=labels)
                energy += outputs.loss.item() * labels.size(0)
        return energy

    def compute_local_free_energy(
        self, num_iter=100, num_chains=1, gamma=None, epsilon=1e-5, verbose=True
    ):
        """Estimate the local free energy ``n * E[L_m]`` around the current weights via SGLD.

        Runs ``num_chains`` independent chains of ``num_iter`` steps. Each chain starts
        from a fresh copy of ``self.net`` (which is never modified) and is localized
        towards those starting weights.

        Args:
            num_iter: SGLD steps per chain.
            num_chains: number of independent chains.
            gamma: localization strength; ``None`` uses ``100 / ||w||`` per parameter tensor.
            epsilon: SGLD step size; the injected noise has standard deviation ``sqrt(epsilon)``.
            verbose: print each chain's mean minibatch loss and a summary line.

        Returns:
            ``(local_free_energy, chain_std)``, where ``chain_std`` is the standard
            deviation across chains of ``n * mean(L_m)``.
        """
        # Per-tensor localization strengths, computed from the base model (a deep copy
        # would have identical values, so no extra copy is needed).
        gamma_dict = {}
        if gamma is None:
            with torch.no_grad():
                for name, param in self.net.named_parameters():
                    gamma_dict[name] = (100.0 / torch.linalg.norm(param)).item()

        chain_Lms = []
        for chain in range(num_chains):
            model_copy = deepcopy(self.net)
            og_params = deepcopy(dict(model_copy.named_parameters()))
            Lms = []
            for _ in range(num_iter):
                with torch.enable_grad():
                    # Gradient of the average minibatch loss with respect to w'.
                    inputs, labels = self._generate_next_training_batch()
                    outputs = model_copy(inputs, labels=labels)
                    loss = outputs.loss
                    loss.backward()
                for name, w in model_copy.named_parameters():
                    if not w.requires_grad:
                        # Frozen parameters are not sampled.
                        continue
                    w_og = og_params[name]
                    # A parameter unused in the forward pass has no gradient: treat it as zero.
                    grad = w.grad.data if w.grad is not None else torch.zeros_like(w.data)
                    # Tempered likelihood drift: -(n / log n) * grad of the mean minibatch loss.
                    dw = -grad / np.log(self.total_train) * self.total_train
                    prior_weight = gamma_dict[name] if gamma is None else gamma
                    # Localizing Gaussian prior pulls w back towards the chain's start w_og.
                    dw.add_(w.data - w_og.data, alpha=-prior_weight)
                    w.data.add_(dw, alpha=epsilon / 2)
                    gaussian_noise = torch.empty_like(w)
                    gaussian_noise.normal_()
                    w.data.add_(gaussian_noise, alpha=np.sqrt(epsilon))
                    if w.grad is not None:
                        w.grad.zero_()
                Lms.append(loss.item())
            chain_Lms.append(Lms)
            if verbose:
                print(f"Chain {chain + 1}: L_m = {np.mean(Lms)}")

        chain_Lms = np.array(chain_Lms)
        local_free_energy = self.total_train * np.mean(chain_Lms)
        # Computed unconditionally: it is part of the return value, not just the log line.
        chain_std = np.std(self.total_train * np.mean(chain_Lms, axis=1))
        if verbose:
            print(
                f"LFE: {EngNumber(local_free_energy)} (std: {EngNumber(chain_std)}, n_chain={num_chains})"
            )
        return local_free_energy, chain_std

    def _record_epoch(self):
        """Compute and store this epoch's LFE, energy, hat-lambda (with bounds) and errors."""
        (
            local_free_energy,
            energy,
            hatlambda,
            hatlambda_lower,
            hatlambda_upper,
        ) = self.compute_fenergy_energy_rlct()
        self.records["lfe"].append(local_free_energy)
        self.records["energy"].append(energy)
        self.records["hatlambda"].append(hatlambda)
        self.records["hatlambda_lower"].append(hatlambda_lower)
        self.records["hatlambda_upper"].append(hatlambda_upper)
        test_err = 1 - self.eval(self.testloader)
        train_err = 1 - self.eval(self.trainloader)

        self.records["test_error"].append(test_err)
        self.records["train_error"].append(train_err)
        epoch = len(self.records["test_error"])
        print(
            f"Epoch: {epoch} "
            f"energy: {energy:.4f} "
            f"hatlambda: {hatlambda:.4f} "
            f"test error: {test_err:.4f} "
            f"train error: {train_err:.4f} "
        )
        return

    def compute_fenergy_energy_rlct(self):
        """Estimate hat-lambda = (LFE - energy) / log n with an approximate 95% interval.

        Returns:
            ``(local_free_energy, energy, hatlambda, hatlambda_lower, hatlambda_upper)``.
            The bounds use the LFE +/- 2 standard errors across SGLD chains.
        """
        energy = self.compute_energy()
        local_free_energy, local_free_energy_std = self.compute_local_free_energy(
            self.sgld_num_iter,
            self.sgld_num_chains,
            self.sgld_gamma,
            self.sgld_noise_std,
            verbose=self.sgld_verbose,
        )
        log_n = np.log(self.total_train)
        lfe_standard_error = local_free_energy_std / (self.sgld_num_chains) ** 0.5

        local_free_energy_lower_bound = local_free_energy - lfe_standard_error * 2
        local_free_energy_upper_bound = local_free_energy + lfe_standard_error * 2

        hatlambda = (local_free_energy - energy) / log_n
        hatlambda_lower = (local_free_energy_lower_bound - energy) / log_n
        hatlambda_upper = (local_free_energy_upper_bound - energy) / log_n
        return local_free_energy, energy, hatlambda, hatlambda_lower, hatlambda_upper

    def run_entropy_sgd(self, esgd_L, num_epoch):
        """Train with ``self.optimizer.step(self.closure)`` and record metrics every epoch.

        Each epoch takes ``len(trainloader) // esgd_L`` optimizer steps, to match the
        number of data passes of plain SGD.
        """
        print("Running Entropy-SGD optimizer")
        for epoch in range(num_epoch):
            start_time = time.time()
            for _ in range(len(self.trainloader) // esgd_L):
                self.optimizer.step(self.closure)
            self._record_epoch()
            print(f"Finished epoch {epoch + 1} / {num_epoch}, time taken: {time.time() - start_time:.3f}")
        return self.records

    def run_sgd(self, num_epoch):
        """Train with one optimizer step per training batch and record metrics every epoch."""
        print("Running SGD optimizer")
        # SGD should be run L times longer to be a fair comparison with Entropy-SGD.
        for epoch in range(num_epoch):
            start_time = time.time()
            for data in self.trainloader:
                inputs, labels = data[0].to(self.device), data[1].to(self.device)
                self.optimizer.zero_grad()
                outputs = self.net(inputs, labels=labels)
                loss = outputs.loss
                loss.backward()
                self.optimizer.step()
            self._record_epoch()
            print(f"Finished epoch {epoch + 1} / {num_epoch}, time taken: {time.time() - start_time:.3f}")
        return self.records


In [12]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Relative to the notebook's working directory (assumed to be /content on Colab,
# where Google Drive is mounted at /content/drive).
DATA_DIR = "drive/MyDrive/bert_slt_rlct"
BATCH_SIZE = 32
NUM_MODELS = 5

# Training split: inputs (presumably token-id tensors for the BERT models — TODO
# confirm) and their class labels, saved as tensors.
train_data = torch.load(f"{DATA_DIR}/mnli_training_data_subset.pth")
train_labels = torch.load(f"{DATA_DIR}/mnli_training_labels_subset.pth")

train_dataset = TensorDataset(train_data, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Test split: built from the *test* tensors, not the training ones.
test_data = torch.load(f"{DATA_DIR}/mnli_testing_data_subset.pth")
test_labels = torch.load(f"{DATA_DIR}/mnli_testing_labels_subset.pth")

test_dataset = TensorDataset(test_data, test_labels)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# The model files hold whole pickled nn.Modules, so they need weights_only=False
# (PyTorch 2.6 changed the default to True). Unpickling can execute arbitrary
# code: only load these files from a trusted source.
models = list()
for i in range(1, NUM_MODELS + 1):
  models.append(
      torch.load(f"{DATA_DIR}/non_overfit_mnli_small_bert_{i}.pth", weights_only=False)
  )

In [26]:
DEVICE = "cuda:0"
SEED = 0  # fixes SGLD noise and minibatch order so each estimate is reproducible

device = DEVICE
model_rlct_low = list()
model_rlct_high = list()
model_rlct_mid = list()
# Label models via enumerate instead of a loop variable leaked from another cell.
for model_number, model in enumerate(models, start=1):
  torch.manual_seed(SEED)
  # MNISTExperiment moves every batch to `device`; keep the weights there too
  # (a no-op if the model is already on that device).
  model.to(device)
  # transformers' AdamW is deprecated in favour of torch.optim.AdamW. The optimizer
  # is only used by MNISTExperiment's training loops, not by the estimate below.
  optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=0)
  net = MNISTExperiment(model,
                        train_dataloader,
                        test_dataloader,
                        optimizer,
                        device,
                        sgld_num_chains = 200,
                        sgld_num_iter = 5,
                        sgld_noise_std=1e-5,
                        )

  fenergy, energy, rlct, lower, upper = net.compute_fenergy_energy_rlct()
  print()
  print("---")
  print(f"Data for model {model_number}:")
  print(f"95% lambdahat confidence interval: {EngNumber(lower)}-{EngNumber(upper)}")
  print(f"Mean lambdahat: {EngNumber(rlct)}")
  print(f"Free energy: {EngNumber(fenergy)}")
  print(f"Energy: {EngNumber(energy)}")
  print("---")
  print()
  model_rlct_low.append(lower)
  model_rlct_high.append(upper)
  model_rlct_mid.append(rlct)



Chain 1: L_m = 1.009008002281189
Chain 2: L_m = 1.269812786579132
Chain 3: L_m = 0.3481311917304993
Chain 4: L_m = 1.4019922733306884
Chain 5: L_m = 1.0784075856208801
Chain 6: L_m = 0.6516258955001831
Chain 7: L_m = 1.0560375094413756
Chain 8: L_m = 0.2801062077283859
Chain 9: L_m = 0.8994292616844177
Chain 10: L_m = 0.2767772823572159
Chain 11: L_m = 0.32988073527812956
Chain 12: L_m = 0.5943172633647918
Chain 13: L_m = 0.3725042551755905
Chain 14: L_m = 0.3763133525848389
Chain 15: L_m = 0.6130687952041626
Chain 16: L_m = 0.3781674563884735
Chain 17: L_m = 1.8827911168336868
Chain 18: L_m = 0.3250880569219589
Chain 19: L_m = 0.8363217681646347
Chain 20: L_m = 0.5841808140277862
Chain 21: L_m = 0.46314270198345187
Chain 22: L_m = 0.2825740724802017
Chain 23: L_m = 0.9193332076072693
Chain 24: L_m = 0.3925476402044296
Chain 25: L_m = 0.5183261096477508
Chain 26: L_m = 0.3125071346759796
Chain 27: L_m = 0.2931938409805298
Chain 28: L_m = 0.25702306926250457
Chain 29: L_m = 0.5200364738



Chain 1: L_m = 0.8862992465496063
Chain 2: L_m = 0.4826162546873093
Chain 3: L_m = 0.5302879333496093
Chain 4: L_m = 0.9021453320980072
Chain 5: L_m = 1.0589762806892395
Chain 6: L_m = 0.4235114872455597
Chain 7: L_m = 4.8776222437620165
Chain 8: L_m = 1.6952634632587433
Chain 9: L_m = 2.981574684381485
Chain 10: L_m = 1.449630504846573
Chain 11: L_m = 1.8420335054397583
Chain 12: L_m = 2.05920824110508
Chain 13: L_m = 4.852935808897018
Chain 14: L_m = 2.967543566226959
Chain 15: L_m = 4.030357921123505
Chain 16: L_m = 2.2906035602092745
Chain 17: L_m = 0.47008504569530485
Chain 18: L_m = 1.2645684957504273
Chain 19: L_m = 0.7836362779140472
Chain 20: L_m = 3.8778687179088593
Chain 21: L_m = 0.6816185861825943
Chain 22: L_m = 1.6130731284618378
Chain 23: L_m = 3.7531589567661285
Chain 24: L_m = 2.3230449497699737
Chain 25: L_m = 3.7553028374910356
Chain 26: L_m = 3.398603343963623
Chain 27: L_m = 4.950607270002365
Chain 28: L_m = 3.4939595639705656
Chain 29: L_m = 1.9512363374233246
Ch



Chain 1: L_m = 1.0050874888896941
Chain 2: L_m = 1.6022550642490387
Chain 3: L_m = 2.009458029270172
Chain 4: L_m = 0.568588238954544
Chain 5: L_m = 0.6960787355899811
Chain 6: L_m = 0.6445960402488708
Chain 7: L_m = 0.9098561942577362
Chain 8: L_m = 0.691326642036438
Chain 9: L_m = 1.236388510465622
Chain 10: L_m = 1.5756377935409547
Chain 11: L_m = 1.4509209543466568
Chain 12: L_m = 0.6778219699859619
Chain 13: L_m = 1.5216062307357787
Chain 14: L_m = 1.2966041147708893
Chain 15: L_m = 0.7088881492614746
Chain 16: L_m = 0.7702719211578369
Chain 17: L_m = 0.7817010581493378
Chain 18: L_m = 0.7715225100517273
Chain 19: L_m = 2.249972552061081
Chain 20: L_m = 0.3956453502178192
Chain 21: L_m = 0.7386801064014434
Chain 22: L_m = 1.9504074811935426
Chain 23: L_m = 0.2651714369654655
Chain 24: L_m = 0.5072966635227203
Chain 25: L_m = 0.6755021095275879
Chain 26: L_m = 0.7268991708755493
Chain 27: L_m = 0.33813768029212954
Chain 28: L_m = 0.375960373878479
Chain 29: L_m = 0.9734169542789459



Chain 1: L_m = 4.512161856889724
Chain 2: L_m = 3.2913641899824144
Chain 3: L_m = 3.9946631610393526
Chain 4: L_m = 1.3746885716915132
Chain 5: L_m = 1.413570660352707
Chain 6: L_m = 3.412365901470184
Chain 7: L_m = 2.0339708983898164
Chain 8: L_m = 0.691975599527359
Chain 9: L_m = 0.4212522566318512
Chain 10: L_m = 0.9782685846090317
Chain 11: L_m = 5.244519692659378
Chain 12: L_m = 3.306220281124115
Chain 13: L_m = 1.7976462692022324
Chain 14: L_m = 2.0055090337991714
Chain 15: L_m = 0.9875315964221955
Chain 16: L_m = 2.908234643936157
Chain 17: L_m = 2.026265561580658
Chain 18: L_m = 1.5316135466098786
Chain 19: L_m = 0.5079347610473632
Chain 20: L_m = 2.1143519312143324
Chain 21: L_m = 2.898058557510376
Chain 22: L_m = 0.7913927376270294
Chain 23: L_m = 2.5729620933532713
Chain 24: L_m = 1.6496045470237732
Chain 25: L_m = 5.118790417909622
Chain 26: L_m = 0.7938552618026733
Chain 27: L_m = 1.6792375564575195
Chain 28: L_m = 0.4000787615776062
Chain 29: L_m = 4.949735134840012
Chain



Chain 1: L_m = 1.7143892109394074
Chain 2: L_m = 1.6269792318344116
Chain 3: L_m = 0.21123700886964797
Chain 4: L_m = 0.44388504326343536
Chain 5: L_m = 0.3481792688369751
Chain 6: L_m = 1.2627573311328888
Chain 7: L_m = 0.3580059617757797
Chain 8: L_m = 1.4979170203208922
Chain 9: L_m = 0.9597122550010682
Chain 10: L_m = 1.094135332107544
Chain 11: L_m = 1.0743216514587401
Chain 12: L_m = 1.7775719463825226
Chain 13: L_m = 1.929034447669983
Chain 14: L_m = 0.9286267161369324
Chain 15: L_m = 1.4787593007087707
Chain 16: L_m = 0.23279042541980743
Chain 17: L_m = 1.043607661128044
Chain 18: L_m = 0.38135826587677
Chain 19: L_m = 0.6841808587312699
Chain 20: L_m = 0.8642332136631012
Chain 21: L_m = 1.4391672253608703
Chain 22: L_m = 1.899801766872406
Chain 23: L_m = 1.5323917806148528
Chain 24: L_m = 1.4584864318370818
Chain 25: L_m = 1.5233963966369628
Chain 26: L_m = 0.47550772726535795
Chain 27: L_m = 1.687919980287552
Chain 28: L_m = 0.31120317578315737
Chain 29: L_m = 0.9499550700187

In [28]:
# One summary block per model: approximate 95% interval, then the point estimate.
for i, rlct_mean in enumerate(model_rlct_mid):
  print(f"Model {i+1}:")
  ci_low = EngNumber(model_rlct_low[i])
  ci_high = EngNumber(model_rlct_high[i])
  print(f"95% lambda hat CI: {ci_low} - {ci_high}")
  print(f"Mean lambda hat: {EngNumber(rlct_mean)}")
  print()

Model 1:
95% lambda hat CI: 2.76k - 3.59k
Mean lambda hat: 3.17k

Model 2:
95% lambda hat CI: 14.57k - 18.31k
Mean lambda hat: 16.44k

Model 3:
95% lambda hat CI: 4.55k - 5.59k
Mean lambda hat: 5.07k

Model 4:
95% lambda hat CI: 19.23k - 23.39k
Mean lambda hat: 21.31k

Model 5:
95% lambda hat CI: 5.82k - 6.85k
Mean lambda hat: 6.33k

