# Post-Training Weight Handling

After training completes, process the saved weights before using them for evaluation or deployment. For detailed steps, refer to the `pretrained_model_transfer.ipynb` notebook.

Two points require attention during weight processing:

1. **For Multi-Agent Environments**: Retain only the weights trained for the primary agent (index 0). Discard or ignore the weights of the other agents, since the evaluation pipeline typically measures only the primary agent's performance.
2. **For Auxiliary Tasks**: If auxiliary tasks (e.g., pedestrian future trajectory prediction) were used during training, remove the auxiliary network components from the weight file. Auxiliary inputs such as pedestrian future trajectory data are not available during evaluation, so keeping the auxiliary network can cause runtime errors or incorrect inference results.
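
As a minimal sketch of both points, the following runs on a toy checkpoint that uses plain Python lists in place of tensors. The layout (integer agent indices at the top level, each holding a `state_dict`) and the `aux_loss_modules` prefix follow the notebook cell in this section. The helper name `strip_checkpoint`, the `config` entry, and the parameter names are hypothetical and used only for illustration.

```python
def strip_checkpoint(checkpoint, primary_agent=0, aux_prefix="aux_loss_modules"):
    """Return a copy that keeps only the primary agent and drops auxiliary weights.

    Assumption: integer top-level keys are agent indices; any other key
    (for example a config entry) is passed through unchanged.
    """
    stripped = {}
    for key, value in checkpoint.items():
        if isinstance(key, int):
            if key != primary_agent:
                continue  # discard the weights of non-primary agents
            agent_state = dict(value)  # shallow copy, the input is not modified
            agent_state["state_dict"] = {
                name: weight
                for name, weight in value["state_dict"].items()
                if not name.startswith(aux_prefix)
            }
            stripped[key] = agent_state
        else:
            stripped[key] = value
    return stripped


# Toy checkpoint: two agents, one auxiliary parameter, one non-agent entry.
toy_checkpoint = {
    0: {"state_dict": {"actor_critic.net.fc.weight": [0.1, 0.2],
                       "aux_loss_modules.traj_head.weight": [0.3]}},
    1: {"state_dict": {"actor_critic.net.fc.weight": [0.4, 0.5]}},
    "config": {"num_agents": 2},
}

stripped = strip_checkpoint(toy_checkpoint)
print(sorted(stripped, key=str))
print(list(stripped[0]["state_dict"]))
```

Returning a new dictionary rather than editing in place keeps the original checkpoint available for comparison. Whether a real checkpoint carries additional top-level entries depends on the training code, so inspect its keys before applying a filter like this.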

import torch

# Load the raw training checkpoint onto the CPU so that no GPU is required.
checkpoint_path_v0 = "/app/Falcon/pretrained_model/falcon_pretrained_25.pth"
pretrained_state = torch.load(checkpoint_path_v0, map_location="cpu")

# Strip the auxiliary-network parameters from the primary agent's (index 0) state dict.
filtered_pretrained_state_dict = {
    k: v for k, v in pretrained_state[0]['state_dict'].items()
    if not k.startswith('aux_loss_modules')
}

# Write the filtered weights back and save them as a new checkpoint.
pretrained_state[0]['state_dict'] = filtered_pretrained_state_dict
torch.save(pretrained_state, '/app/Falcon/pretrained_model/falcon_noaux_25_v2.pth')

# Sanity check: compare the new file against the existing falcon_noaux_25.pth.
checkpoint_path_v1 = "/app/Falcon/pretrained_model/falcon_noaux_25.pth"
checkpoint_path_v2 = "/app/Falcon/pretrained_model/falcon_noaux_25_v2.pth"
processed_state_v1 = torch.load(checkpoint_path_v1, map_location="cpu")
processed_state_v2 = torch.load(checkpoint_path_v2, map_location="cpu")

# Compare parameter names rather than only the number of entries.
keys_v1 = set(processed_state_v1[0]['state_dict'])
keys_v2 = set(processed_state_v2[0]['state_dict'])
if keys_v1 == keys_v2:
    print("processed_state_v1 and processed_state_v2 contain the same parameter names")
else:
    print("Parameter names differ between processed_state_v1 and processed_state_v2")
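
A comparison of entry counts or parameter names does not confirm that the stored values themselves match. The sketch below shows one way to report both kinds of difference. It uses plain lists so that it runs without PyTorch; the helper name `compare_state_dicts` and the sample parameter names are hypothetical.

```python
import operator


def compare_state_dicts(a, b, equal=operator.eq):
    """Report name and value differences between two state dicts.

    `equal` decides whether two entries match. For PyTorch tensors,
    pass `torch.equal`, because `==` on tensors is elementwise.
    """
    only_in_a = sorted(set(a) - set(b))
    only_in_b = sorted(set(b) - set(a))
    # Among the shared names, collect those whose values differ.
    mismatched = sorted(k for k in set(a) & set(b) if not equal(a[k], b[k]))
    return only_in_a, only_in_b, mismatched


# Toy state dicts with lists standing in for tensors.
v1 = {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}
v2 = {"fc.weight": [1.0, 2.5], "fc.bias": [0.0], "aux_loss_modules.head": [9.0]}

only_v1, only_v2, mismatched = compare_state_dicts(v1, v2)
print("only in v1:", only_v1)
print("only in v2:", only_v2)
print("different values:", mismatched)
```

With the real checkpoints, and assuming every entry in both state dicts is a tensor on the same device, the call would be `compare_state_dicts(processed_state_v1[0]['state_dict'], processed_state_v2[0]['state_dict'], equal=torch.equal)`. Three empty lists indicate that the two files hold identical weights for the primary agent.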