### Test the optimizer and scheduler setup

In [None]:
# IPython setup: turn off Jedi-based tab completion, and enable autoreload in
# mode 2 so edited project modules are reloaded before each cell executes.
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import lightning as pl
from omegaconf import DictConfig
from lightning import Callback, LightningDataModule, LightningModule, Trainer
import torch
import hydra
import omegaconf
import pyrootutils
import matplotlib.pyplot as plt

In [None]:
# Build the model straight from its Hydra config and ask it for its
# optimizer / scheduler setup (kept in `d` for the cells below).
# NOTE(review): absolute, machine-specific path -- adjust for your checkout.
cfg = omegaconf.OmegaConf.load("/home/ubuntu/FGVC11/configs/model/plant_traits_model.yaml")
model = hydra.utils.instantiate(cfg)
d = model.configure_optimizers()

# Collect the trainable parameters once, print their names in registration
# order, then print the total number of trainable elements.
trainable = [(pname, p) for pname, p in model.named_parameters() if p.requires_grad]
total_params = sum(p.numel() for _, p in trainable)
for pname, _ in trainable:
    print(pname)
print(total_params)

In [None]:
# Legend labels for the entries returned by model.configure_optimizers(),
# in the same order as those entries.
schd_names = [
    "head",
    "blend weights",
    "block7",
    "block8",
    "block9",
    "block10",
    "block11",
    "block12",
    "tokens",
]
# Fail loudly (with both counts) if the optimizer setup and the label list drift apart.
assert len(d) == len(schd_names), (
    f"configure_optimizers() returned {len(d)} entries, "
    f"but {len(schd_names)} scheduler names are defined"
)

In [None]:
import matplotlib.pyplot as plt

import copy
import warnings


def _resolve_scheduler(sched_dict):
    """Return the scheduler object stored in one ``configure_optimizers`` entry.

    ``sched_dict['lr_scheduler']`` may be the scheduler itself or a Lightning
    scheduler-config dict holding it under ``'scheduler'``; both are handled.
    """
    scheduler = sched_dict['lr_scheduler']
    if isinstance(scheduler, dict):
        scheduler = scheduler['scheduler']
    return scheduler


def _simulate_lrs(scheduler, total_epochs):
    """Dry-run *scheduler* for ``total_epochs`` steps and return the LR per step.

    Element ``i`` is the first param group's LR in effect during step ``i``
    (element 0 is the initial LR). ``scheduler.step()`` mutates both the
    scheduler and its optimizer's param groups in place, so both are
    snapshotted first and restored afterwards (the optimizer only if the
    scheduler exposes one as ``.optimizer``).
    """
    optimizer = getattr(scheduler, 'optimizer', None)
    sched_state = copy.deepcopy(scheduler.state_dict())
    optim_state = copy.deepcopy(optimizer.state_dict()) if optimizer is not None else None

    learning_rates = []
    try:
        with warnings.catch_warnings():
            # No optimizer.step() happens in a dry run, so PyTorch's
            # "scheduler stepped before optimizer" warning is expected noise.
            warnings.filterwarnings(
                'ignore',
                message=r'Detected call of `lr_scheduler\.step\(\)` before',
                category=UserWarning,
            )
            for _ in range(total_epochs):
                # Record before stepping so the initial LR is part of the curve.
                learning_rates.append(scheduler.get_last_lr()[0])
                scheduler.step()
    finally:
        # Undo the dry run so re-running the cell (or training) starts fresh.
        scheduler.load_state_dict(sched_state)
        if optimizer is not None:
            optimizer.load_state_dict(optim_state)
    return learning_rates


def plot_learning_rates_for_all(schedulers, total_epochs, names=None):
    """Plot the learning-rate curve of every scheduler on a single figure.

    Each scheduler is stepped ``total_epochs`` times on a snapshot of its
    state; it and its optimizer are restored afterwards, so the call has no
    lasting effect on them and can be repeated.

    Args:
        schedulers: Iterable of dicts, each with an ``'lr_scheduler'`` entry
            (a scheduler object, or a config dict with a ``'scheduler'`` key).
        total_epochs: Number of ``scheduler.step()`` calls to simulate.
        names: Optional legend labels, one per scheduler. Defaults to the
            module-level ``schd_names`` if it is defined; schedulers without
            a label are labelled by their index.
    """
    if names is None:
        names = globals().get('schd_names', ())

    plt.figure(figsize=(12, 7))

    for idx, sched_dict in enumerate(schedulers):
        scheduler = _resolve_scheduler(sched_dict)
        # Only the first param group's LR is plotted for each scheduler.
        learning_rates = _simulate_lrs(scheduler, total_epochs)
        label = names[idx] if idx < len(names) else idx
        plt.plot(learning_rates, label=f'Scheduler {label}')

    plt.title('Learning Rate Schedules for Multiple Schedulers Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.legend()
    plt.grid(True)
    plt.show()

# `d` holds the output of model.configure_optimizers(); plot 120 scheduler
# steps for each of its entries.
plot_learning_rates_for_all(d, total_epochs=120)
