In [2]:
import torch
import copy
import sys
sys.path.append('../training')
from model_utils import ProteinMPNN

def average_models(model_list):
    """
    Average the state (parameters and buffers) of several same-architecture models.

    A deep copy of ``model_list[0]`` is used as the container for the result, so
    the returned model has the same class, configuration and device as the first
    model; the input models are not modified.

    Every entry of the first model's ``state_dict()`` is replaced by the
    element-wise mean of that entry across all models. Sums are accumulated in
    at least float32 (float16/bfloat16 sums could otherwise overflow to inf or
    lose precision), and the mean is cast back to the entry's original dtype.
    Non-floating entries (e.g. integer counters) are therefore averaged and then
    truncated toward zero by that cast.

    Parameters:
    -----------
    model_list : list
        A non-empty list of PyTorch models (all having the same architecture).

    Returns:
    --------
    averaged_model : torch.nn.Module
        A new model whose parameters are the average of the parameters of the input models.

    Raises:
    -------
    ValueError
        If ``model_list`` is empty, or if the models' state dicts differ in
        keys or tensor shapes.
    """
    if len(model_list) == 0:
        raise ValueError("model_list must contain at least one model")
    num_models = len(model_list)

    # Get the state dictionaries from all models; the first one defines the
    # expected keys, shapes and output dtypes.
    state_dicts = [model.state_dict() for model in model_list]
    reference = state_dicts[0]

    # Fail loudly instead of silently ignoring extra keys or raising a bare KeyError.
    for idx, state_dict in enumerate(state_dicts[1:], start=1):
        if state_dict.keys() != reference.keys():
            raise ValueError(
                f"model {idx} has different state_dict keys than model 0"
            )

    averaged_state_dict = {}
    for key, ref_tensor in reference.items():
        # float16/bfloat16/int/bool -> float32; float32/float64/complex keep their own type.
        acc_dtype = torch.promote_types(ref_tensor.dtype, torch.float32)
        total = torch.zeros_like(ref_tensor, dtype=acc_dtype)
        for idx, state_dict in enumerate(state_dicts):
            tensor = state_dict[key]
            # Without this check, a differently-shaped tensor could broadcast silently.
            if tensor.shape != ref_tensor.shape:
                raise ValueError(
                    f"shape mismatch for '{key}': model {idx} has "
                    f"{tuple(tensor.shape)}, model 0 has {tuple(ref_tensor.shape)}"
                )
            total += tensor.to(acc_dtype)
        # Cast back so the averaged state dict matches the model's own dtypes.
        averaged_state_dict[key] = (total / num_models).to(ref_tensor.dtype)

    # Create a deep copy of the first model and load the averaged parameters into it.
    averaged_model = copy.deepcopy(model_list[0])
    averaged_model.load_state_dict(averaged_state_dict)

    return averaged_model

# Example usage:
# Suppose you have 5 models: model1, model2, model3, model4, model5
# models = [model1, model2, model3, model4, model5]
# averaged_model = average_models(models)


In [6]:
# Load the per-fold `nolinear_*` fine-tuned ProteinMPNN checkpoints, average
# their weights with `average_models`, and save the averaged state dict.
CHECKPOINT_DIR = '../cache/megascale_finetuned'
CHECKPOINT_EPOCH = 199
N_FOLDS = 5
OUTPUT_PATH = 'nolinear_rocklin_avg.pt'
# Fall back to CPU so the cell still runs on a machine without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

models = []
for fold in range(N_FOLDS):
    # Architecture hyperparameters must match the saved checkpoints: the default
    # (strict) load_state_dict raises on missing/unexpected keys or shape mismatches.
    model = ProteinMPNN(node_features=128,
                        edge_features=128,
                        hidden_dim=128,
                        num_encoder_layers=3,
                        num_decoder_layers=3,
                        k_neighbors=48,
                        dropout=0.0,
                        augment_eps=0.05)
    model.to(DEVICE)

    # Checkpoint files are numbered from 1 while `fold` starts at 0.
    checkpoint_path = f'{CHECKPOINT_DIR}/nolinear_fold_{fold + 1}_epoch{CHECKPOINT_EPOCH}.pt'
    # map_location moves tensors saved on another device onto DEVICE.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    # Accept either a dict holding the weights under 'model_state_dict' or a bare state dict.
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    models.append(model)

averaged_model = average_models(models)
torch.save(averaged_model.state_dict(), OUTPUT_PATH)

  checkpoint = torch.load(f'../cache/megascale_finetuned/nolinear_fold_{fold+1}_epoch199.pt')
