In [None]:
import torch
import torch.nn as nn
from diffusion.diffusion_model import DiffusionAttnUnet1DCond

# Load the state_dict of the previously trained (conditional) model.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict_old = torch.load("trained_models/diffusion_model_original.pth", map_location="cpu")

# Number of channels kept when converting to the unconditional model.
# NOTE(review): trimming keeps the *leading* channels, which assumes the channels
# being dropped are the trailing ones in the original layout. Confirm this against
# the model definition before reusing the script with a different checkpoint.
N_IN_KEEP = 19   # input channels of net.0 (originally 22)
N_OUT_KEEP = 3   # output channels of net.6 / input channels of lastconv (originally 6)


def _trim(state_dict, key, dim, keep):
    """Keep the first ``keep`` entries of ``state_dict[key]`` along ``dim``, in place.

    Raises:
        KeyError: if ``key`` is not in ``state_dict``.
        ValueError: if the tensor has fewer than ``keep`` entries along ``dim``.
            Plain slicing would silently return fewer channels in that case.
    """
    tensor = state_dict[key]
    if tensor.shape[dim] < keep:
        raise ValueError(
            f"{key}: expected at least {keep} channels on dim {dim}, "
            f"got shape {tuple(tensor.shape)}"
        )
    # .clone() gives the slice its own compact storage. torch.save on a bare view
    # would serialize the whole original storage.
    state_dict[key] = tensor.narrow(dim, 0, keep).clone()


# Cut the weights not needed for unconditional diffusion. Both prefixes are
# processed (presumably the live and the EMA copies of the network).
for prefix in ["diffusion.", "diffusion_ema."]:
    # net.0: drop input channels (dim 1). A Conv1d bias has one entry per output
    # channel, and the output count (128) is unchanged, so the bias is kept as is.
    _trim(state_dict_old, f"{prefix}net.0.main.0.weight", 1, N_IN_KEEP)  # [128, 22, 5]
    _trim(state_dict_old, f"{prefix}net.0.skip.weight", 1, N_IN_KEEP)    # [128, 22, 1]

    # net.6: drop output channels (dim 0) of main.3 (weight and bias) and of the skip conv.
    _trim(state_dict_old, f"{prefix}net.6.main.3.weight", 0, N_OUT_KEEP)  # [6, 128, 5]
    _trim(state_dict_old, f"{prefix}net.6.main.3.bias", 0, N_OUT_KEEP)
    _trim(state_dict_old, f"{prefix}net.6.skip.weight", 0, N_OUT_KEEP)    # [6, 128, 1]

    # lastconv: trim its input channels (dim 1) to match net.6's reduced output.
    # Its output channels, and therefore its bias, are unchanged.
    _trim(state_dict_old, f"{prefix}lastconv.weight", 1, N_OUT_KEEP)      # [3, 6, 3]

# Save the trimmed state_dict.
torch.save(state_dict_old, "trained_models/Model_trimmed_NOcond.pth")

print(f"Total number of keys: {len(state_dict_old)}")