In [None]:
! pip install git+https://github.com/unifyai/models.git

In [2]:
import torch
import ivy
import ivy_models

# layers

In [1]:
from typing import Optional
import ivy


def conv3x3(
    in_planes: int,
    out_planes: int,
    stride: int = 1,
    dilations: int = 1,
) -> ivy.Conv2D:
    """
    Build a bias-free 3x3 convolution layer.

    Args::
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution. Defaults to 1.
        dilations (int): Dilation rate of the kernel. The same value is also
            used as the padding, which keeps the spatial size unchanged for a
            3x3 kernel at stride 1. Defaults to 1.

    Returns::
        ivy.Conv2D: The configured convolution layer.

    """
    return ivy.Conv2D(
        in_planes,
        out_planes,
        [3, 3],
        stride,
        dilations,  # fifth positional argument of ivy.Conv2D is the padding
        with_bias=False,
        # The dilation itself must be passed by keyword; without it the layer
        # would be padded for a dilated kernel but convolve with a dense one.
        dilations=dilations,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> ivy.Conv2D:
    """
    Build a bias-free 1x1 convolution layer.

    Args::
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution. Defaults to 1.

    Returns::
        ivy.Conv2D: The configured convolution layer.

    """
    kernel_shape = [1, 1]
    padding = 0
    return ivy.Conv2D(
        in_planes, out_planes, kernel_shape, stride, padding, with_bias=False
    )


class BasicBlock(ivy.Module):
    """
    Residual block made of two 3x3 convolutions (the "basic" ResNet block).

    Args::
        inplanes (int): Number of input channels.
        planes (int): Number of output channels.
        stride (int): Stride of the first convolution. Defaults to 1.
        downsample (Optional[ivy.Module]): Module applied to the block input
            before it is added to the output. When None the input is added
            unchanged.
        base_width (int): Must be 64; any other value raises ValueError.
        dilation (int): Must not exceed 1; larger values raise
            NotImplementedError.

    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[ivy.Module] = None,
        base_width: int = 64,
        dilation: int = 1,
    ) -> None:
        self.norm_layer = ivy.BatchNorm2D

        # Reject configurations this block cannot express.
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        if base_width != 64:
            raise ValueError("BasicBlock only supports base_width=64")

        # Stored before the parent constructor runs so `_build` can read them.
        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.downsample = downsample
        super().__init__()

    def _build(self, *args, **kwargs):
        """Create the convolution, normalisation and activation sub-layers."""
        channels = self.planes
        self.conv1 = conv3x3(self.inplanes, channels, self.stride)
        self.bn1 = self.norm_layer(channels)
        self.relu = ivy.ReLU()
        self.conv2 = conv3x3(channels, channels)
        self.bn2 = self.norm_layer(channels)
        # Attributes are re-bound to themselves here, as in the original build.
        self.downsample = self.downsample
        self.stride = self.stride

    def _forward(self, x):
        """Run conv-bn-relu, conv-bn, add the shortcut and apply a final relu."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        # The shortcut is the raw input unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class Bottleneck(ivy.Module):
    """
    Residual block made of a 1x1, a 3x3 and a 1x1 convolution.

    Args::
        inplanes (int): Number of input channels.
        planes (int): Base channel count; the block outputs
            ``planes * expansion`` channels.
        stride (int): Stride of the 3x3 convolution. Defaults to 1.
        downsample (Optional[ivy.Module]): Module applied to the block input
            before it is added to the output. When None the input is added
            unchanged.
        base_width (int): Scales the inner width, computed as
            ``int(planes * (base_width / 64.0))``. Defaults to 64.
        dilation (int): Dilation rate handed to the 3x3 convolution.
            Defaults to 1.

    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[ivy.Module] = None,
        base_width: int = 64,
        dilation: int = 1,
    ) -> None:
        self.norm_layer = ivy.BatchNorm2D
        # Channel count used by the two inner convolutions.
        self.width = int(planes * (base_width / 64.0))
        # Stored before the parent constructor runs so `_build` can read them.
        self.inplanes = inplanes
        self.planes = planes
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride
        super().__init__()

    def _build(self, *args, **kwargs):
        """Create the convolution, normalisation and activation sub-layers."""
        inner = self.width
        expanded = self.planes * self.expansion
        self.conv1 = conv1x1(self.inplanes, inner)
        self.bn1 = self.norm_layer(inner)
        self.conv2 = conv3x3(inner, inner, self.stride, self.dilation)
        self.bn2 = self.norm_layer(inner)
        self.conv3 = conv1x1(inner, expanded)
        self.bn3 = self.norm_layer(expanded)
        self.relu = ivy.ReLU()
        # Attribute is re-bound to itself here, as in the original build.
        self.downsample = self.downsample

    def _forward(self, x):
        """Run the three conv-bn stages, add the shortcut and apply relu."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # The shortcut is the raw input unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


# resnet

In [3]:

def _prune_keys(raw, ref, raw_keys_to_prune=[], ref_keys_to_prune=[]):
    pruned_ref = {}
    if raw_keys_to_prune:
        raw = raw.cont_prune_keys(raw_keys_to_prune)
    if ref_keys_to_prune:
        pruned_ref = ref.cont_at_keys(ref_keys_to_prune)
        ref = ref.cont_prune_keys(ref_keys_to_prune)
    return raw, ref, pruned_ref


def _map_weights(raw, ref, custom_mapping=None):
    mapping = {}
    for old_key, new_key in zip(
        raw.cont_sort_by_key().cont_to_iterator_keys(),
        ref.cont_sort_by_key().cont_to_iterator_keys(),
    ):
        new_mapping = new_key
        if custom_mapping is not None:
            new_mapping = custom_mapping(old_key, new_key)
        mapping[old_key] = new_mapping
    return mapping


In [5]:
import torch

In [6]:

def load_torch_weights(
    url,
    ref_model,
    raw_keys_to_prune=[],
    ref_keys_to_prune=[],
    custom_mapping=None,
    map_location=torch.device("cpu"),
):
    """
    Download a torch state dict and restructure it to match `ref_model.v`.

    Args::
        url: Location of the checkpoint, passed to
            ``torch.hub.load_state_dict_from_url``.
        ref_model: Model whose variable container ``ref_model.v`` provides the
            target key structure.
        raw_keys_to_prune: Keys removed from the downloaded weights before
            mapping.
        ref_keys_to_prune: Keys removed from the reference container before
            mapping; the removed entries are merged back into the result.
        custom_mapping: Optional callable ``(old_key, new_key) -> value`` used
            to build each mapping entry.
        map_location: Device onto which the checkpoint is loaded. Defaults to
            CPU.

    Returns::
        The restructured weights after conversion with ``ivy.asarray``.

    """
    ivy_torch = ivy.with_backend("torch")
    # NOTE(review): the checkpoint is deserialised by torch; only pass URLs
    # from a trusted source.
    weights = torch.hub.load_state_dict_from_url(url, map_location=map_location)

    # Convert through the torch backend to numpy so the values no longer
    # depend on torch tensors.
    weights_raw = ivy.Container(
        ivy_torch.to_numpy(ivy_torch.Container(weights)).cont_to_dict()
    )
    weights_raw, weights_ref, pruned_ref = _prune_keys(
        weights_raw, ref_model.v, raw_keys_to_prune, ref_keys_to_prune
    )
    mapping = _map_weights(weights_raw, weights_ref, custom_mapping=custom_mapping)
    w_clean = weights_raw.cont_restructure(mapping, keep_orig=False)
    if ref_keys_to_prune:
        # Re-attach the reference entries that were set aside before mapping.
        w_clean = ivy.Container.cont_combine(w_clean, pruned_ref)
    return ivy.asarray(w_clean)


In [8]:
# global
from typing import List, Optional, Type, Union
import builtins

import ivy
import ivy_models
from ivy_models.resnet.layers import conv1x1, BasicBlock, Bottleneck
from ivy_models.base import BaseSpec, BaseModel


class ResNetSpec(BaseSpec):
    """
    Specification holding the constructor arguments of a ResNet.

    Every argument is forwarded by keyword to the ``BaseSpec`` constructor.

    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        base_width: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
    ) -> None:
        spec_fields = dict(
            block=block,
            layers=layers,
            num_classes=num_classes,
            base_width=base_width,
            replace_stride_with_dilation=replace_stride_with_dilation,
        )
        super().__init__(**spec_fields)


class ResNet(BaseModel):
    """
    Residual Neural Network (ResNet) architecture.

    Args::
        block (Type[Union[BasicBlock, Bottleneck]]):
            The block type used in the ResNet architecture.
        layers: List of integers specifying the number of blocks in each layer.
        num_classes (int): Number of output classes. Defaults to 1000.
        base_width (int): The base width of the ResNet. Defaults to 64.
        replace_stride_with_dilation (Optional[List[bool]]):
            List indicating whether to replace stride with dilation. One entry
            each is read for layer2, layer3 and layer4; None is treated as
            three False values.
        spec: Optional pre-built ResNetSpec. When it is a truthy ResNetSpec
            instance it is used as-is and the arguments above are not used to
            build a new spec.
        v (ivy.Container): Forwarded to the BaseModel constructor as `v`.

    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        base_width: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        spec=None,
        v: ivy.Container = None,
    ) -> None:
        # Prefer a caller-supplied spec; otherwise build one from the arguments.
        self.spec = (
            spec
            if spec and isinstance(spec, ResNetSpec)
            else ResNetSpec(
                block, layers, num_classes, base_width, replace_stride_with_dilation
            )
        )

        super(ResNet, self).__init__(v=v)

    def _build(self, *args, **kwargs):
        """Create the stem, the four residual stages, the pooling and the head."""
        # Running state read and updated by `_make_layer` as stages are created.
        self.inplanes = 64
        self.dilation = 1
        if self.spec.replace_stride_with_dilation is None:
            # Note: this writes the default back onto the spec object.
            self.spec.replace_stride_with_dilation = [False, False, False]

        # Stem: 7x7 convolution, batch norm, relu and max pooling.
        self.conv1 = ivy.Conv2D(3, self.inplanes, [7, 7], 2, 3, with_bias=False)
        self.bn1 = ivy.BatchNorm2D(self.inplanes)
        self.relu = ivy.ReLU()
        self.maxpool = ivy.MaxPool2D(3, 2, 1)
        # Residual stages. The order of these calls matters because each one
        # updates self.inplanes (and possibly self.dilation) for the next.
        self.layer1 = self._make_layer(self.spec.block, 64, self.spec.layers[0])
        self.layer2 = self._make_layer(
            self.spec.block,
            128,
            self.spec.layers[1],
            stride=2,
            dilate=self.spec.replace_stride_with_dilation[0],
        )
        self.layer3 = self._make_layer(
            self.spec.block,
            256,
            self.spec.layers[2],
            stride=2,
            dilate=self.spec.replace_stride_with_dilation[1],
        )
        self.layer4 = self._make_layer(
            self.spec.block,
            512,
            self.spec.layers[3],
            stride=2,
            dilate=self.spec.replace_stride_with_dilation[2],
        )
        self.avgpool = ivy.AdaptiveAvgPool2d((1, 1))
        # Classifier input width is the last stage's output channel count.
        self.fc = ivy.Linear(512 * self.spec.block.expansion, self.spec.num_classes)

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> ivy.Sequential:
        """
        Build one residual stage as a sequence of `blocks` blocks.

        Args::
            block: Block class to instantiate.
            planes (int): Base channel count passed to every block.
            blocks (int): Number of blocks in the stage.
            stride (int): Stride of the first block. Defaults to 1.
            dilate (bool): When True the stride is folded into self.dilation
                and the first block uses stride 1 instead.

        Returns::
            ivy.Sequential: The blocks of the stage, in order.

        Side effects: updates self.inplanes, and self.dilation when `dilate`
        is True.

        """
        downsample = None
        # The first block receives the dilation in effect before this stage.
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # A projection shortcut is needed whenever the first block changes the
        # stride or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = ivy.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                ivy.BatchNorm2D(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.spec.base_width,
                previous_dilation,
            )
        )
        # Remaining blocks take the first block's output channel count as input.
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    base_width=self.spec.base_width,
                    dilation=self.dilation,
                )
            )

        return ivy.Sequential(*layers)

    # Decorated as a classmethod, so the first parameter receives the class
    # even though it is named `self`.
    @classmethod
    def get_spec_class(self):
        """Return the spec class associated with this model (ResNetSpec)."""
        return ResNetSpec

    def _forward(self, x):
        """Run the stem, the four stages, pooling and the linear head on `x`."""
        # Remember the input dtype so it can be restored after the stem.
        dtype = x.dtype
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # Convert back to the input dtype captured above.
        # NOTE(review): presumably guards against a dtype change in the stem
        # layers; confirm.
        x = ivy.asarray(x, dtype=dtype)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Moves the last axis to position 1 before pooling.
        # NOTE(review): presumably channels-last to channels-first for the
        # adaptive pool; confirm.
        x = ivy.permute_dims(x, (0, 3, 1, 2))
        x = self.avgpool(x)
        # Flatten everything after the batch axis for the linear head.
        x = x.reshape((x.shape[0], -1))
        x = self.fc(x)
        return x


def _resnet_torch_weights_mapping(old_key, new_key):
    W_KEY = ["conv1/weight", "conv2/weight", "conv3/weight", "downsample/0/weight"]
    new_mapping = new_key
    if builtins.any([kc in old_key for kc in W_KEY]):
        new_mapping = {"key_chain": new_key, "pattern": "b c h w -> h w c b"}
    return new_mapping


def _load_pretrained_weights(model, url):
    """
    Download the torch checkpoint at `url` and assign it to `model.v`.

    All ResNet variants share this one loading path, so they use the same
    loader, the same pruned key and the same key mapping.
    """
    model.v = ivy_models.helpers.load_torch_weights(
        url,
        model,
        # Dropped from the checkpoint before keys are matched to the model.
        raw_keys_to_prune=["num_batches_tracked"],
        custom_mapping=_resnet_torch_weights_mapping,
    )
    return model


def resnet_18(pretrained=True):
    """
    ResNet-18 model.

    Args::
        pretrained (bool): When True, download the torch checkpoint and assign
            the converted weights to ``model.v``. Defaults to True.

    Returns::
        ResNet: Model built from BasicBlock with layers [2, 2, 2, 2].

    """
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        _load_pretrained_weights(
            model, "https://download.pytorch.org/models/resnet18-f37072fd.pth"
        )
    return model


def resnet_34(pretrained=True):
    """
    ResNet-34 model.

    Args::
        pretrained (bool): When True, download the torch checkpoint and assign
            the converted weights to ``model.v``. Defaults to True.

    Returns::
        ResNet: Model built from BasicBlock with layers [3, 4, 6, 3].

    """
    model = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        _load_pretrained_weights(
            model, "https://download.pytorch.org/models/resnet34-333f7ec4.pth"
        )
    return model


def resnet_50(pretrained=True):
    """
    ResNet-50 model.

    Args::
        pretrained (bool): When True, download the torch checkpoint and assign
            the converted weights to ``model.v``. Defaults to True.

    Returns::
        ResNet: Model built from Bottleneck with layers [3, 4, 6, 3].

    """
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        _load_pretrained_weights(
            model, "https://download.pytorch.org/models/resnet50-11ad3fa6.pth"
        )
    return model


def resnet_101(pretrained=True):
    """
    ResNet-101 model.

    Args::
        pretrained (bool): When True, download the torch checkpoint and assign
            the converted weights to ``model.v``. Defaults to True.

    Returns::
        ResNet: Model built from Bottleneck with layers [3, 4, 23, 3].

    """
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        _load_pretrained_weights(
            model, "https://download.pytorch.org/models/resnet101-cd907fc2.pth"
        )
    return model


def resnet_152(pretrained=True):
    """
    ResNet-152 model.

    Args::
        pretrained (bool): When True, download the torch checkpoint and assign
            the converted weights to ``model.v``. Defaults to True.

    Returns::
        ResNet: Model built from Bottleneck with layers [3, 8, 36, 3].

    """
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        _load_pretrained_weights(
            model, "https://download.pytorch.org/models/resnet152-f82ba261.pth"
        )
    return model


# test

In [None]:
import os
import ivy
import random
import numpy as np
import jax

# Enable x64 support in JAX
jax.config.update("jax_enable_x64", True)
from ivy_models_tests import helpers
from ivy_models import (
    resnet_18,
    resnet_34,
    resnet_50,
    resnet_101,
    resnet_152,
)


# Constructor for every ResNet depth covered by the test, keyed by a short tag.
VARIANTS = dict(
    r18=resnet_18,
    r34=resnet_34,
    r50=resnet_50,
    r101=resnet_101,
    r152=resnet_152,
)

# Per-variant reference values, compared in the test against the softmax of
# the three highest-scoring outputs of the pretrained model.
LOGITS = dict(
    r18=np.array([0.7069, 0.2663, 0.0231]),
    r34=np.array([0.8507, 0.1351, 0.0069]),
    r50=np.array([0.3429, 0.0408, 0.0121]),
    r101=np.array([0.7834, 0.0229, 0.0112]),
    r152=np.array([0.8051, 0.0473, 0.0094]),
)


# NOTE: no seed is set, so both the variant and whether pretrained weights are
# loaded change from run to run.
load_weights = random.choice([False, True])
model_var = random.choice(list(VARIANTS))
model = VARIANTS[model_var](pretrained=load_weights)
# Numpy copy of the model variables; the test re-assigns it to the model.
v = ivy.to_numpy(model.v)


def test_resnet_img_classification(device, fw):
    """
    Test image classification with the randomly selected ResNet variant.

    Uses the module-level `model`, `v`, `load_weights` and `model_var`. The
    output shape is always checked; the top-3 class indices and their softmax
    values are checked only when pretrained weights were loaded.

    Args::
        device: Not used in the body.
        fw: Not used in the body.

    """
    num_classes = 1000
    batch_shape = [1]
    # `__file__` is not defined when this runs as a notebook cell, so fall
    # back to the current working directory instead of raising NameError.
    try:
        this_dir = os.path.dirname(os.path.realpath(__file__))
    except NameError:
        this_dir = os.getcwd()

    # Load image
    img = ivy.asarray(
        helpers.load_and_preprocess_img(
            os.path.join(this_dir, "..", "..", "images", "cat.jpg"),
            256,
            224,
            data_format="NHWC",
            to_ivy=True,
        ),
    )

    # Restore the variables from the numpy copy taken at module level.
    model.v = ivy.asarray(v)
    output = model(img)

    # Cardinality test
    assert output.shape == tuple([ivy.to_scalar(batch_shape), num_classes])

    # Value test
    if load_weights:
        output = output[0]
        true_indices = ivy.array([282, 281, 285])
        calc_indices = ivy.argsort(output, descending=True)[:3]

        assert np.array_equal(true_indices, calc_indices)

        # Compare softmax values at the predicted top-3 indices with the
        # reference values stored for this variant.
        true_logits = LOGITS[model_var]
        calc_logits = np.take(
            helpers.np_softmax(ivy.to_numpy(output)), ivy.to_numpy(calc_indices)
        )

        assert np.allclose(true_logits, calc_logits, rtol=0.005)
