## Slim down saved RoFormer weights (e.g., the word embeddings)

In [58]:
import os
from pathlib import Path

import torch

from swissrivernetwork.benchmark.model import TransformerEmbeddingModel


def prune_roformer_word_embeddings(
    input_path: Path,
    weights_to_prune: list[str],
    output_path: Path | None = None,
    map_location=None,
):
    """Shrink selected tensors of a saved RoFormer ``TransformerEmbeddingModel`` state dict.

    The model hyper-parameters are inferred from the tensor shapes stored in the
    checkpoint, and a fresh ``TransformerEmbeddingModel`` (RoPE positional encoding)
    is built from them. Every checkpoint tensor whose key starts with one of the
    prefixes in ``weights_to_prune`` AND whose shape differs from the rebuilt model's
    tensor is replaced by the rebuilt model's (freshly initialised) tensor. The result
    is verified by loading it into the rebuilt model, then written to ``output_path``.

    NOTE(review): the replacement values are the new model's random initialisation,
    so this is only appropriate for weights the model does not actually use
    (the intended case being RoFormer's word-embedding table) — confirm per prefix.

    Args:
        input_path: Checkpoint file containing a saved ``state_dict``.
        weights_to_prune: Key prefixes of tensors eligible for replacement, e.g.
            ``['transformer.embeddings.word_embeddings']``.
        output_path: Destination file; defaults to ``input_path`` (overwrite in place).
            Parent directories are created if needed.
        map_location: Forwarded to ``torch.load`` (e.g. ``'cpu'`` to open GPU-saved
            checkpoints on a CPU-only machine). ``None`` keeps the stored devices.

    Raises:
        KeyError: If a key needed to infer the architecture is missing.
        RuntimeError: If the pruned state dict does not load into the rebuilt model.
    """
    input_path = Path(input_path)
    output_path = input_path if output_path is None else Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    state_dict = torch.load(input_path, map_location=map_location)

    # --- Infer the architecture from the stored tensor shapes. ---
    num_embeddings = state_dict['embedding.weight'].shape[0]
    embedding_size = state_dict['embedding.weight'].shape[1]
    # input_proj's fan-in is assumed to be input_size + embedding_size.
    input_size = state_dict['input_proj.weight'].shape[1] - embedding_size

    # Use the largest layer index rather than relying on the key order of the dict.
    layer_prefix = 'transformer.encoder.layer.'
    layer_ids = {
        int(key[len(layer_prefix):].split('.', 1)[0])
        for key in state_dict
        if key.startswith(layer_prefix)
    }
    if not layer_ids:
        raise KeyError(f'No {layer_prefix!r}* weights found in {input_path}')
    num_layers = max(layer_ids) + 1

    dim_feedforward = state_dict['transformer.encoder.layer.0.intermediate.dense.weight'].shape[0]
    d_model = state_dict['transformer.embeddings.word_embeddings.weight'].shape[1]
    # Rows = positions; the column count is assumed to be d_model // num_heads.
    position_table = state_dict['transformer.encoder.embed_positions.weight']
    num_heads = d_model // position_table.shape[1]
    max_len = position_table.shape[0]
    missing_value_method = 'mask_embedding' if 'mask_embedding' in state_dict else None

    model = TransformerEmbeddingModel(
        input_size, num_embeddings=num_embeddings,
        embedding_size=embedding_size,
        num_heads=num_heads,
        num_layers=num_layers,
        dim_feedforward=dim_feedforward,
        dropout=0.1,  # nn.Dropout has no state, so the rate does not matter for loading
        d_model=d_model,
        max_len=max_len,
        missing_value_method=missing_value_method,
        use_current_x=True,  # This does not matter for loading weights
        positional_encoding='rope',
    )

    # --- Replace mismatching tensors under the requested prefixes. ---
    prefixes = tuple(weights_to_prune)
    model_state = model.state_dict()  # detached tensors, not live Parameters
    for key, stored in list(state_dict.items()):
        replacement = model_state.get(key)
        if not key.startswith(prefixes) or replacement is None or replacement.shape == stored.shape:
            continue
        # Copy onto the checkpoint tensor's device so the saved dict stays device-consistent.
        state_dict[key] = replacement.to(device=stored.device, copy=True)

    # Consistency check: strict loading raises if the pruned dict no longer fits the model.
    model.load_state_dict(state_dict)

    # Save to a sibling temp file and atomically swap it in, so an interrupted
    # save cannot corrupt the (possibly overwritten-in-place) checkpoint.
    tmp_path = output_path.with_name(output_path.name + '.tmp')
    try:
        torch.save(state_dict, tmp_path)
        os.replace(tmp_path, output_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise

In [64]:
# Locate the Ray Tune result folders relative to this notebook's working directory.
CUR_ABS_DIR = Path.cwd().resolve()
PROJ_DIR = (CUR_ABS_DIR / '../../../').resolve()
root_dir = (PROJ_DIR / 'swissrivernetwork/benchmark/outputs/ray_results/').resolve()

# Only the RoPE variants of the transformer-embedding experiments are processed.
ray_exp_dirs = sorted(
    exp_dir for exp_dir in root_dir.iterdir()
    if exp_dir.name.startswith('transformer_embedding') and 'rope' in exp_dir.name
)
weights_to_prune = ['transformer.embeddings.word_embeddings']


def _iter_checkpoint_files(exp_dir: Path):
    """Yield every ``*.pth`` file in ``<exp_dir>/<trial>/checkpoint_*/``.

    Trial folders, checkpoint folders and files are each visited in sorted order;
    each level is listed when it is entered.
    """
    trials = sorted(entry for entry in exp_dir.iterdir() if entry.is_dir())
    for trial in trials:
        checkpoint_dirs = sorted(
            entry for entry in trial.iterdir()
            if entry.is_dir() and entry.name.startswith('checkpoint_')
        )
        for checkpoint_dir in checkpoint_dirs:
            yield from sorted(
                entry for entry in checkpoint_dir.iterdir()
                if entry.is_file() and entry.name.endswith('.pth')
            )


# NOTE(review): the first experiment folder (in sorted order) is skipped — presumably
# it was already processed in an earlier run; confirm before re-running.
for input_dir in ray_exp_dirs[1:]:
    print('\n====================')
    print(input_dir)
    count = 0  # checkpoints successfully pruned in this experiment
    for ckp_path in _iter_checkpoint_files(input_dir):
        try:
            # Overwrites the checkpoint file in place.
            prune_roformer_word_embeddings(ckp_path, weights_to_prune, output_path=ckp_path)
        except Exception as e:
            # Best-effort batch job: report the failure and continue with the next file.
            print(f'Error processing {ckp_path}: {e}')
        else:
            print(f'Pruned weights for the {count}-th checkpoint.')
            count += 1

print('\nAll done!')

/mnt/832acd65-7396-480e-aa76-dca6765861b0/research-repo/projects/2025.09_hydrology_Switzerland/codes/swiss-river-network-benchmark/swissrivernetwork/benchmark/outputs/ray_results/transformer_embedding-swiss-2010-rope-2025-10-01_16-04-25
Pruned weights for the 0-th checkpoint.
Pruned weights for the 1-th checkpoint.
Pruned weights for the 2-th checkpoint.
Pruned weights for the 3-th checkpoint.
Pruned weights for the 4-th checkpoint.
Pruned weights for the 5-th checkpoint.
Pruned weights for the 6-th checkpoint.
Pruned weights for the 7-th checkpoint.
Pruned weights for the 8-th checkpoint.
Pruned weights for the 9-th checkpoint.
Pruned weights for the 10-th checkpoint.
Pruned weights for the 11-th checkpoint.
Pruned weights for the 12-th checkpoint.
Pruned weights for the 13-th checkpoint.
Pruned weights for the 14-th checkpoint.
Pruned weights for the 15-th checkpoint.
Pruned weights for the 16-th checkpoint.
Pruned weights for the 17-th checkpoint.
Pruned weights for the 18-th checkp

In [6]:
# Sanity check: does nn.Dropout carry any state? (Its state dict prints as empty,
# so the dropout probability need not match when loading a saved state dict.)
import torch.nn as nn

m = nn.Dropout(p=0.3)
dropout_state = m.state_dict()
print(dropout_state)

OrderedDict()
