In [16]:
import torch
import copy
import os

def get_model_paths(path_to_results, model_name_starts_with):
    all_files = os.listdir(path_to_results)
    model_files = [file for file in all_files if file.startswith(model_name_starts_with)]
    model_files = [os.path.join(path_to_results, filename) for filename in model_files]
    print(f"Loading {len(model_files)} files from model:")
    for file in model_files:
        print(file)
    print("")
    return model_files

def load_state_dicts(file_paths):
    state_dicts = []
    print("Loading model state dicts:")
    for file in file_paths:
        loaded_dict = torch.load(file, map_location="cpu")
        if "model_state_dict" in loaded_dict:
            print(f"Extracting 'model_state_dict' from checkpoint in {file}")
            state_dicts.append(loaded_dict["model_state_dict"])
    print("")
    return state_dicts

def average_model_weights(state_dicts):
    """
    Averages the weights from multiple state_dicts.

    Args:
        state_dicts: List of state_dict dictionaries

    Returns:
        Averaged state_dict.
    """
    avg_state_dict = copy.deepcopy(state_dicts[0])
    print("Averaging model...")
    for key in avg_state_dict.keys():
        for state_dict in state_dicts[1:]:
            avg_state_dict[key] += state_dict[key]
        avg_state_dict[key] /= len(state_dicts)
    return avg_state_dict

def save_avg_model_weights(save_path, save_name, avg_state_dict):
    final_save_path = os.path.join(save_path, save_name)
    torch.save({'model_state_dict': avg_state_dict}, final_save_path)
    print(f"Averaged model saved to {final_save_path}\n")

# Loading, averaging and saving the weights

In [17]:
training_results_path = "training-results/models/"
save_path = "test-results/models"

model_1_files = get_model_paths(training_results_path, "ex_config-1")
model_1_state_dicts = load_state_dicts(model_1_files)
model_1_avg_weights = average_model_weights(model_1_state_dicts)
save_avg_model_weights(save_path, "ex_config-1_avg_weights_model.pth", model_1_avg_weights)

Loading 19 files from model:
training-results/models/ex_config-1_model_epoch_5_step_7216.pth
training-results/models/ex_config-1_model_epoch_6.pth
training-results/models/ex_config-1_model_epoch_0_step_8228.pth
training-results/models/ex_config-1_model_epoch_2_step_7243.pth
training-results/models/ex_config-1_model_epoch_1_step_7807.pth
training-results/models/ex_config-1_model_epoch_8.pth
training-results/models/ex_config-1_model_epoch_3.pth
training-results/models/ex_config-1_model_epoch_9.pth
training-results/models/ex_config-1_model_epoch_8_step_7207.pth
training-results/models/ex_config-1_model_epoch_5.pth
training-results/models/ex_config-1_model_epoch_7_step_7241.pth
training-results/models/ex_config-1_model_epoch_9_step_6907.pth
training-results/models/ex_config-1_model_epoch_0.pth
training-results/models/ex_config-1_model_epoch_2.pth
training-results/models/ex_config-1_model_epoch_end.pth
training-results/models/ex_config-1_model_epoch_1.pth
training-results/models/ex_config-1

In [18]:
model_2_files = get_model_paths(training_results_path, "ex_config-2")
model_2_state_dicts = load_state_dicts(model_2_files)
model_2_avg_weights = average_model_weights(model_2_state_dicts)
save_avg_model_weights(save_path, "ex_config-2_avg_weights_model.pth", model_2_avg_weights)

Loading 17 files from model:
training-results/models/ex_config-2_model_epoch_7.pth
training-results/models/ex_config-2_model_epoch_3_step_7324.pth
training-results/models/ex_config-2_model_epoch_2_step_8687.pth
training-results/models/ex_config-2_model_epoch_0.pth
training-results/models/ex_config-2_model_epoch_9.pth
training-results/models/ex_config-2_model_epoch_end.pth
training-results/models/ex_config-2_model_epoch_7_step_8245.pth
training-results/models/ex_config-2_model_epoch_0_step_8089.pth
training-results/models/ex_config-2_model_epoch_5.pth
training-results/models/ex_config-2_model_epoch_2.pth
training-results/models/ex_config-2_model_epoch_8_step_8177.pth
training-results/models/ex_config-2_model_epoch_4.pth
training-results/models/ex_config-2_model_epoch_6.pth
training-results/models/ex_config-2_model_epoch_1.pth
training-results/models/ex_config-2_model_epoch_3.pth
training-results/models/ex_config-2_model_epoch_5_step_7041.pth
training-results/models/ex_config-2_model_epo

In [None]:
# model_3_files = get_model_paths(training_results_path, "ex_config-3")
# model_3_state_dicts = load_state_dicts(model_3_files)
# model_3_avg_weights = average_model_weights(model_3_state_dicts)
# save_avg_model_weights(save_path, "ex_config-3_avg_weights_model", model_3_avg_weights)

In [19]:
model_4_files = get_model_paths(training_results_path, "ex_config-4")
model_4_state_dicts = load_state_dicts(model_4_files)
model_4_avg_weights = average_model_weights(model_4_state_dicts)
save_avg_model_weights(save_path, "ex_config-4_avg_weights_model.pth", model_4_avg_weights)

Loading 14 files from model:
training-results/models/ex_config-4_model_epoch_8.pth
training-results/models/ex_config-4_model_epoch_0.pth
training-results/models/ex_config-4_model_epoch_9_step_8276.pth
training-results/models/ex_config-4_model_epoch_2.pth
training-results/models/ex_config-4_model_epoch_3.pth
training-results/models/ex_config-4_model_epoch_2_step_5966.pth
training-results/models/ex_config-4_model_epoch_9.pth
training-results/models/ex_config-4_model_epoch_7.pth
training-results/models/ex_config-4_model_epoch_6.pth
training-results/models/ex_config-4_model_epoch_4.pth
training-results/models/ex_config-4_model_epoch_end.pth
training-results/models/ex_config-4_model_epoch_8_step_5661.pth
training-results/models/ex_config-4_model_epoch_6_step_5725.pth
training-results/models/ex_config-4_model_epoch_1.pth

Loading model state dicts:
Extracting 'model_state_dict' from checkpoint in training-results/models/ex_config-4_model_epoch_8.pth
Extracting 'model_state_dict' from checkpo

In [20]:
model_5_files = get_model_paths(training_results_path, "ex_config-5")
model_5_state_dicts = load_state_dicts(model_5_files)
model_5_avg_weights = average_model_weights(model_5_state_dicts)
save_avg_model_weights(save_path, "ex_config-5_avg_weights_model.pth", model_5_avg_weights)

Loading 12 files from model:
training-results/models/ex_config-5_model_epoch_1_step_6631.pth
training-results/models/ex_config-5_model_epoch_0.pth
training-results/models/ex_config-5_model_epoch_8.pth
training-results/models/ex_config-5_model_epoch_3.pth
training-results/models/ex_config-5_model_epoch_1.pth
training-results/models/ex_config-5_model_epoch_4.pth
training-results/models/ex_config-5_model_epoch_5.pth
training-results/models/ex_config-5_model_epoch_9_step_5655.pth
training-results/models/ex_config-5_model_epoch_9.pth
training-results/models/ex_config-5_model_epoch_2_step_5444.pth
training-results/models/ex_config-5_model_epoch_end.pth
training-results/models/ex_config-5_model_epoch_2.pth

Loading model state dicts:
Extracting 'model_state_dict' from checkpoint in training-results/models/ex_config-5_model_epoch_1_step_6631.pth
Extracting 'model_state_dict' from checkpoint in training-results/models/ex_config-5_model_epoch_0.pth
Extracting 'model_state_dict' from checkpoint i

In [1]:
model_1_ext_files = get_model_paths(training_results_path, "ex_config-1-extension")
model_1_ext_state_dicts = load_state_dicts(model_1_ext_files)
model_1_ext_avg_weights = average_model_weights(model_1_ext_state_dicts)
save_avg_model_weights(save_path, "ex_config-1-extension_avg_weights_model", model_1_ext_avg_weights)

NameError: name 'get_model_paths' is not defined