In [1]:
import sys
sys.path.append("../..")
import os.path

from hydra import initialize, compose
from hydra.utils import instantiate

from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning import Trainer

import torch

In [2]:
def load_model(
        model_checkpoint: str,
) -> torch.nn.Module:
    model_dir = os.path.dirname(model_checkpoint)
    with initialize(config_path=os.path.join(model_dir, 'hydra')):
        cfg = compose('config.yaml')
    if "model" in cfg.keys():
        cfg["model"]["_target_"] = cfg["model"]["_target_"].replace(
            ".model.", ".network."
        )
        cfg["model"]["backbone"]["_target_"] = cfg["model"]["backbone"]["_target_"].replace(
            ".model.", ".network."
        )
        cfg["model"]["backbone"]["cartesian_weights_path"] = "../../" + cfg["model"]["backbone"]["cartesian_weights_path"]
        model = instantiate(
            cfg.model,
            optimizer_config=cfg.optimizer,
            _recursive_=False
        )
    else:
        cfg["network"]["backbone"]["cartesian_weights_path"] = "../../" + cfg["network"]["backbone"]["cartesian_weights_path"]
        model = instantiate(
            cfg.network,
            optimizer_config=cfg.optimizer,
            _recursive_=False
        )    
    state_dict = torch.load(model_checkpoint, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict["state_dict"], strict=False)
    model = model.eval().cpu()
    model = model.requires_grad_(False)
    return model

In [3]:
network = load_model("../../data/models_jeanzay/unext_small/9/last.ckpt")

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path=os.path.join(model_dir, 'hydra')):


In [4]:
summary = ModelSummary(max_depth=5)
trainer = Trainer()
summary.on_fit_start(trainer, network)

  rank_zero_warn(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

   | Name                               | Type               | Params
---------------------------------------------------------------------------
0  | loss_func                          | LaplaceNLLLoss     | 9     
1  | fixed_laplace                      | LaplaceNLLLoss     | 9     
2  | fixed_gaussian                     | GaussianNLLLoss    | 9     
3  | best_metric                        | MinMetric          | 0     
4  | backbone                           | UNextBackbone      | 1.2 M 
5  | backbone.to_cartesian              | ToCartesianLayer   | 0     
6  | backbone.from_cartesian            | FromCartesianLayer | 0     
7  | backbone.init_layer                | ConvNextBlock      | 23.1 K
8  | backbone.init_layer.ds_conv        | Conv2d             | 1.0 K 
9  | backbone.init_layer.net            |

In [5]:
print(sum(p.numel() for p in network.backbone.parameters()))

1224848


In [6]:
print(sum(p.numel() for p in network.backbone.init_layer.parameters()))
print(sum(p.numel() for p in network.backbone.down_layers[0].pooling.parameters()))
print(sum(p.numel() for p in network.backbone.down_layers[0].out_layer[0].parameters()))
print(sum(p.numel() for p in network.backbone.down_layers[0].out_layer[1].parameters()))
print(sum(p.numel() for p in network.backbone.bottleneck_layer.parameters()))
print(sum(p.numel() for p in network.backbone.up_layers[0].upscaling.parameters()))
print(sum(p.numel() for p in network.backbone.up_layers[0].out_layer[0].parameters()))
print(sum(p.numel() for p in network.backbone.up_layers[0].out_layer[1].parameters()))
print(sum(p.numel() for p in network.backbone.out_layer.parameters()))

23056
295424
145152
145152
145152
295552
95744
39808
39808
