Post-hoc Laplace approximation of the LoRA parameters at a model checkpoint $\theta_{\mathrm{MAP}}$ obtained from standard fine-tuning.

In [None]:
import torch
from tqdm.auto import tqdm
from torch.optim import AdamW
from transformers import get_scheduler
from optree import tree_map_, tree_map
import pickle
import matplotlib.pyplot as plt

import uqlib

from load import load_dataloaders, load_model

In [None]:
# Build the train/eval dataloaders (small=True, batch size 32)
train_dataloader, eval_dataloader = load_dataloaders(
    small=True,
    batch_size=32,
)

# Number of training examples; load_model takes this as num_data
num_data = len(train_dataloader.dataset)
print("Training data size: ", num_data)

In [None]:
# Build the model together with its log-posterior function.
# NOTE(review): prior_sd=1e3 is forwarded to load_model -- presumably the
# standard deviation of a Gaussian prior over the parameters; confirm in load.py.
model, param_to_log_posterior, target_module_names = load_model(
    num_data=num_data,
    prior_sd=1e3,
    target_modules="last_layer",
)

# Evaluation mode, so Dropout is inactive
model.eval()

# Prefer the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device);

In [None]:
# Extract only the parameters to be trained
# The helper receives every named parameter of the model plus the full
# log-posterior function and returns a (params, function) pair.
# NOTE(review): presumably it keeps the requires_grad=True entries and returns a
# log posterior that takes only that subset -- confirm against uqlib.
sub_params, sub_param_to_log_posterior = uqlib.extract_requires_grad_and_func(dict(model.named_parameters()), param_to_log_posterior)

In [None]:
# Store initial values of sub_params to check against later.
# detach().clone() gives each leaf its own storage, so later in-place updates
# to sub_params do not alter this snapshot.
init_sub_params = tree_map(lambda x: x.detach().clone(), sub_params)

In [None]:
# MAP estimation with an ordinary PyTorch loop.
# The objective is a log posterior, so the optimizer is created with
# maximize=True (gradient ascent rather than descent).
num_epochs = 30
num_training_steps = num_epochs * len(train_dataloader)

optimizer = AdamW(sub_params.values(), lr=1e-5, maximize=True)

# Linear learning-rate decay over the whole run, without warm-up
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# One log-posterior value is recorded per optimizer step
log_posts = []
progress_bar = tqdm(range(num_training_steps))

# model.train() is deliberately not called: the model stays in eval mode
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Move every tensor of the batch onto the model's device
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        log_post, out = sub_param_to_log_posterior(sub_params, batch)
        log_post.backward()
        log_posts.append(log_post.item())

        # Parameter update, schedule update, then clear gradients for the next step
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        progress_bar.set_postfix(loss=log_posts[-1])

In [None]:
# Plot convergence: the log-posterior values recorded during MAP training,
# one per optimizer step (the x-axis is the step index)
plt.plot(log_posts);

In [None]:
# Compare, for every targeted module, the frozen base weights W with the
# LoRA-merged weights W + B @ A, as two overlaid histograms.
import re  # the standard-library module is enough here; `regex` is not needed


def _find_param(named_tensors, module_name, suffix):
    """Return the first tensor whose key refers to `module_name` and `suffix`.

    A key matches when it is `module_name`, optionally preceded by
    "base_model.model.model." and optionally followed by `suffix`
    (dots are matched literally).

    Raises:
        KeyError: if no key matches, naming the module and suffix looked for.
    """
    key_pattern = re.compile(
        r"(?:base_model\.model\.model\.)*"
        + re.escape(module_name)
        + "(?:" + re.escape(suffix) + ")*"
    )
    for key, tensor in named_tensors.items():
        if key_pattern.fullmatch(key):
            return tensor
    raise KeyError(f"no parameter found for {module_name!r} with suffix {suffix!r}")


# detach() shares storage with the live parameters. Cloning every parameter of
# the model just to read a few matrices would double its memory footprint.
final_sub_params = {k: v.detach() for k, v in model.named_parameters()}

base = []
final = []
for weights_matrix in target_module_names:
    W = _find_param(final_sub_params, weights_matrix, ".base_layer.weight")
    A = _find_param(final_sub_params, weights_matrix, ".lora_A.default.weight")
    B = _find_param(final_sub_params, weights_matrix, ".lora_B.default.weight")

    # NOTE(review): LoRA implementations usually scale the update by
    # lora_alpha / r; no scaling is applied here -- confirm this matches the
    # adapter configuration used by load_model.
    W_del = B @ A
    W_new = W + W_del

    base.append(W)
    final.append(W_new)

base = torch.cat(base).flatten()
final = torch.cat(final).flatten()

# Cast to float32 on the CPU: NumPy has no bfloat16 dtype, so .numpy() would
# fail on bfloat16 weights.
plt.hist(base.cpu().float().numpy(), bins=100, alpha=0.5, label='Init', density=True)
plt.hist(final.cpu().float().numpy(), bins=100, alpha=0.5, label='Final', density=True)
plt.legend();

In [None]:
# Jacobian requires more memory, so we'll use a smaller batch size for the Laplace approximation
# (batch_size=8 here versus 32 for MAP training; the second returned dataloader is discarded)
laplace_train_dataloader, _ = load_dataloaders(small=True, batch_size=8)

In [None]:
# Diagonal-Fisher Laplace approximation via uqlib, starting from the trained sub_params.
# The resulting state exposes the accumulated diagonal as `prec_diag`.
laplace_approx_transform = uqlib.laplace.diag_fisher.build(sub_param_to_log_posterior)
laplace_state = laplace_approx_transform.init(sub_params)

# Fold the training batches into the state one at a time
for batch in tqdm(laplace_train_dataloader):
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    laplace_state = laplace_approx_transform.update(laplace_state, batch)

In [None]:
# Save state
# tree_map (not the in-place tree_map_) is required here: tree_map_ ignores the
# function's return value and hands back the original tree, so the tensors
# would never actually be moved to the CPU.
laplace_state = tree_map(lambda x: x.detach().cpu(), laplace_state)

# The context manager closes the file even if pickling fails
with open("guanaco_laplace_state.pkl", "wb") as state_file:
    pickle.dump(laplace_state, state_file)

# To reload a previously saved state (only unpickle files you trust):
# with open("guanaco_laplace_state.pkl", "rb") as state_file:
#     laplace_state = pickle.load(state_file)

In [None]:
# Histogram of the per-parameter standard deviations of the Laplace approximation.
# All diagonal precision entries are gathered into a single flat NumPy array.
prec_diag = torch.cat(
    [precision.detach().cpu().flatten() for precision in laplace_state.prec_diag.values()]
).numpy()

# Standard deviation = precision ** -1/2
# NOTE(review): a precision entry of exactly zero would give inf here -- confirm none occur
sd_diag = prec_diag ** -0.5

plt.hist(sd_diag, bins=100, density=True);