# weights checking

In [None]:
def _RegNet_Y_400MF_torch_weights_mapping(old_key, new_key):
    W_KEY = ["conv1/weight", "conv2/weight", "conv3/weight", "downsample/0/weight"]
    new_mapping = new_key
    if builtins.any([kc in old_key for kc in W_KEY]):
        new_mapping = {"key_chain": new_key, "pattern": "b c h w -> h w c b"}
    return new_mapping

In [None]:
import ivy, torch


def _prune_keys(raw, ref, raw_keys_to_prune=[], ref_keys_to_prune=[]):
    pruned_ref = {}
    if raw_keys_to_prune:
        raw = raw.cont_prune_keys(raw_keys_to_prune)
    if ref_keys_to_prune:
        pruned_ref = ref.cont_at_keys(ref_keys_to_prune)
        ref = ref.cont_prune_keys(ref_keys_to_prune)
    return raw, ref, pruned_ref


def _map_weights(raw, ref, custom_mapping=None):
    mapping = {}
    for old_key, new_key in zip(
        raw.cont_sort_by_key().cont_to_iterator_keys(),
        ref.cont_sort_by_key().cont_to_iterator_keys(),
    ):
        new_mapping = new_key
        if custom_mapping is not None:
            new_mapping = custom_mapping(old_key, new_key)
        mapping[old_key] = new_mapping
    return mapping


def load_torch_weights(
    url,
    ref_model,
    raw_keys_to_prune=[],
    ref_keys_to_prune=[],
    custom_mapping=None,
    map_location=torch.device("cpu"),
):
    ivy_torch = ivy.with_backend("torch")
    weights = torch.hub.load_state_dict_from_url(url, map_location=map_location)

    from pprint import pprint as pp

    pp(weights)

    weights_raw = ivy.Container(
        ivy_torch.to_numpy(ivy_torch.Container(weights)).cont_to_dict()
    )
    weights_raw, weights_ref, pruned_ref = _prune_keys(
        weights_raw, ref_model.v, raw_keys_to_prune, ref_keys_to_prune
    )
    mapping = _map_weights(weights_raw, weights_ref, custom_mapping=custom_mapping)
    w_clean = weights_raw.cont_restructure(mapping, keep_orig=False)
    if ref_keys_to_prune:
        w_clean = ivy.Container.cont_combine(w_clean, pruned_ref)
    return ivy.asarray(w_clean)

In [None]:
url = "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth"
w_clean = load_torch_weights(
    url,
    model,
    raw_keys_to_prune=["num_batches_tracked"],
    custom_mapping=_RegNet_Y_400MF_torch_weights_mapping,
)

# testing model blocks saperatly

In [1]:
from typing import Optional, List, Callable, Any, Tuple, Union, Sequence
import ivy
import math
import warnings
import collections
from itertools import repeat

In [None]:
class AnyStage(ivy.Sequential):
    """AnyNet stage (sequence of blocks w/ the same output shape)."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        block_constructor: Callable[..., ivy.Module],
        norm_layer: Callable[..., ivy.Module],
        activation_layer: Callable[..., ivy.Module],
        group_width: int,
        bottleneck_multiplier: float,
        se_ratio: Optional[float] = None,
        stage_index: int = 0,
    ) -> None:
        super().__init__()

        for i in range(depth):
            block = block_constructor(
                width_in if i == 0 else width_out,
                width_out,
                stride if i == 0 else 1,
                norm_layer,
                activation_layer,
                group_width,
                bottleneck_multiplier,
                se_ratio,
            )

            self.add_module(f"block{stage_index}-{i}", block)


def test_AnyStage():
    # N x 768 x 5 x 5
    random_test_tensor = ivy.random_normal(shape=(1, 5, 5, 768))
    display(f"random_test_tensor shape is: {random_test_tensor.shape}")

    block = AnyStage(
        32,
        width_out,
        stride,
        depth,
        block_type,
        norm_layer,
        activation,
        group_width,
        bottleneck_multiplier,
        block_params.se_ratio,
        stage_index=i + 1,
    )
    block(random_test_tensor)
    # N x 128 x 5 x 5
    display("Test Successfull!")


test_AnyStage()

# testing BlockPrams

In [13]:
class BlockParams:
    def __init__(
        self,
        depths: List[int],
        widths: List[int],
        group_widths: List[int],
        bottleneck_multipliers: List[float],
        strides: List[int],
        se_ratio: Optional[float] = None,
    ) -> None:
        self.depths = depths
        self.widths = widths
        self.group_widths = group_widths
        self.bottleneck_multipliers = bottleneck_multipliers
        self.strides = strides
        self.se_ratio = se_ratio

    @classmethod
    def from_init_params(
        cls,
        depth: int,
        w_0: int,
        w_a: float,
        w_m: float,
        group_width: int,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
        **kwargs: Any,
    ) -> "BlockParams":
        """
        Programmatically compute all the per-block settings,
        given the RegNet parameters.

        The first step is to compute the quantized linear block parameters,
        in log space. Key parameters are:
        - `w_a` is the width progression slope
        - `w_0` is the initial width
        - `w_m` is the width stepping in the log space

        In other terms
        `log(block_width) = log(w_0) + w_m * block_capacity`,
        with `bock_capacity` ramping up following the w_0 and w_a params.
        This block width is finally quantized to multiples of 8.

        The second step is to compute the parameters per stage,
        taking into account the skip connection and the final 1x1 convolutions.
        We use the fact that the output width is constant within a stage.
        """

        QUANT = 8
        STRIDE = 2

        if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0:
            raise ValueError("Invalid RegNet settings")
        # Compute the block widths. Each stage has one unique block width
        widths_cont = ivy.arange(depth) * w_a + w_0
        block_capacity = ivy.round(ivy.log(widths_cont / w_0) / math.log(w_m))
        block_widths = ivy.to_list(
            ivy.int32(
                ivy.round(ivy.divide(w_0 * ivy.pow(w_m, block_capacity), QUANT)) * QUANT
            )
        )
        num_stages = len(set(block_widths))

        # Convert to per stage parameters
        split_helper = zip(
            block_widths + [0],
            [0] + block_widths,
            block_widths + [0],
            [0] + block_widths,
        )
        splits = [w != wp or r != rp for w, wp, r, rp in split_helper]

        stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t]
        stage_depths = (
            ivy.diff(ivy.array([d for d, t in enumerate(splits) if t])).int().tolist()
        )

        strides = [STRIDE] * num_stages
        bottleneck_multipliers = [bottleneck_multiplier] * num_stages
        group_widths = [group_width] * num_stages

        # Adjust the compatibility of stage widths and group widths
        stage_widths, group_widths = cls._adjust_widths_groups_compatibilty(
            stage_widths, bottleneck_multipliers, group_widths
        )

        return cls(
            depths=stage_depths,
            widths=stage_widths,
            group_widths=group_widths,
            bottleneck_multipliers=bottleneck_multipliers,
            strides=strides,
            se_ratio=se_ratio,
        )

    def _get_expanded_params(self):
        return zip(
            self.widths,
            self.strides,
            self.depths,
            self.group_widths,
            self.bottleneck_multipliers,
        )

    def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
        """
        This function is taken from the original tf repo.
        It ensures that all layers have a channel number that is divisible by 8
        It can be seen here:
        https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
        """
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    @staticmethod
    def _adjust_widths_groups_compatibilty(
        stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int]
    ) -> Tuple[List[int], List[int]]:
        """
        Adjusts the compatibility of widths and groups,
        depending on the bottleneck ratio.
        """
        # Compute all widths for the current settings
        widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)]
        group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)]

        # Compute the adjusted widths so that stage and group widths fit
        ws_bot = [
            _make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)
        ]
        stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)]
        return stage_widths, group_widths_min

In [16]:
BlockParams.from_init_params(
    depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25
)

TypeError: 'IntDtype' object is not callable