PyTorch

In [None]:
# Install CUDA 11.6
# Ask the Debian tooling used below to run without interactive prompts.
%env DEBIAN_FRONTEND=noninteractive

# Download the NVIDIA apt pin file and the local CUDA 11.6.2 repository package,
# register that repository with dpkg, and add its signing key.
!wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin
!sudo mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600
!wget https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda-repo-ubuntu2004-11-6-local_11.6.2-510.47.03-1_amd64.deb
!sudo dpkg -i cuda-repo-ubuntu2004-11-6-local_11.6.2-510.47.03-1_amd64.deb
!sudo apt-key add /var/cuda-repo-ubuntu2004-11-6-local/7fa2af80.pub

# liburcu6 comes from a PPA and is installed together with the cuda-11-6 packages.
# NOTE(review): presumably liburcu6 is a dependency of the CUDA packages on this image - confirm.
!sudo add-apt-repository -y ppa:cloudhan/liburcu6
!sudo apt-get update
!sudo apt-get -y install liburcu6 cuda-11-6

import os

%env PATH=/usr/local/cuda-11.6/bin:{os.getenv('PATH')}
%env BNB_CUDA_VERSION=116
%env CUDA_VERSION=116
%env LD_LIBRARY_PATH=/usr/local/cuda-11.6/lib64:{os.environ['LD_LIBRARY_PATH']}
!nvcc --version

In [None]:
# Print the PATH seen by shell commands (presumably to check that the CUDA bin directory is on it).
!echo $PATH

In [None]:
# Remove any already-installed PyTorch packages, then install the cu116 builds of
# torch 1.13.1 / torchvision 0.14.1 (plus torchaudio 0.13.1 and timm) from the PyTorch wheel index.
!python3 -m pip uninstall -y torch torchvision torchaudio torchtext
!python3 -m pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 timm --extra-index-url https://download.pytorch.org/whl/cu116



> Install the OpenMMLab dependencies, then patch mmaction2 with the SqueezeTime backbone and head.



In [None]:
# Install mmaction2 dependencies
# openmim provides the `mim` command used to install the OpenMMLab packages below.
!pip install -U openmim
!mim install mmengine
# mmcv is the only package pinned to an exact version here.
!mim install mmcv==2.1.0
!mim install mmdet
!mim install mmpose

In [None]:
# cleanup directories.
# Remove earlier checkouts so the clones below start from a clean state.
!rm -rf mmaction2
!rm -rf SqueezeTime

# Clone upstream mmaction2 and the SqueezeTime repository, then copy the SqueezeTime
# backbone and head sources into the mmaction2 source tree.
!git clone https://github.com/open-mmlab/mmaction2.git
!git clone https://github.com/xinghaochen/SqueezeTime.git
!cp SqueezeTime/mmaction/models/backbones/SqueezeTime.py mmaction2/mmaction/models/backbones/
!cp SqueezeTime/mmaction/models/backbones/SqueezeTime_ava.py mmaction2/mmaction/models/backbones/
!cp SqueezeTime/mmaction/models/heads/i2d_head.py mmaction2/mmaction/models/heads/

# Append imports of the copied modules to the package __init__ files.
!echo "from .SqueezeTime import SqueezeTime" >> mmaction2/mmaction/models/backbones/__init__.py
!echo "from .SqueezeTime_ava import SqueezeTime_ava" >> mmaction2/mmaction/models/backbones/__init__.py
!echo "from .i2d_head import I2DHead" >> mmaction2/mmaction/models/heads/__init__.py
# Overwrite the DRN localizer package __init__ with an empty file.
# NOTE(review): presumably this avoids an import problem in that subpackage - confirm.
!echo "" > mmaction2/mmaction/models/localizers/drn/__init__.py
# Install the patched checkout.
!pip install ./mmaction2


Flax implementation.

In [None]:
# Install (or upgrade, because of -U) the libraries imported by the Flax model code: flax, chex, einops.
!pip install -U flax chex einops

In [None]:
# Download PyTorch checkpoints
!curl -LO https://github.com/xinghaochen/SqueezeTime/releases/download/ckpts/SqueezeTime_in1k_pretrain.pth
!curl -LO https://github.com/xinghaochen/SqueezeTime/releases/download/ckpts/SqueezeTime_K400_71.64.pth
!curl -LO https://github.com/xinghaochen/SqueezeTime/releases/download/ckpts/SqueezeTime_K600_76.06.pth
!curl -LO https://github.com/xinghaochen/SqueezeTime/releases/download/ckpts/SqueezeTime_HMDB51_65.56.pth
# !curl -LO https://github.com/xinghaochen/SqueezeTime/releases/download/ckpts/SqueezeTime-AVA2.1.pth

%mkdir ckpts
!mv *.pth ckpts/

In [None]:
from typing import Any, Sequence
from functools import partial

import jax
import jax.numpy as jnp
from flax import linen
import chex
import einops

# Annotation alias for module classes/factories passed around as attributes
# (e.g. linen.Conv, linen.BatchNorm, or a functools.partial wrapping one of them).
ModuleDef = Any


def to_2tuple(x: Any | Sequence[Any]) -> tuple[Any, Any]:
    if isinstance(x, Sequence):
        assert len(x) == 2
        return x
    return (x, x)


def to_padding(padding: int | tuple[int, int] | Sequence[tuple[int, int]]) -> list[tuple[int, int]]:
    if isinstance(padding, int):
        return [(padding, padding), (padding, padding)]
    elif isinstance(padding, tuple) and isinstance(padding[0], int):
        return [padding, padding]
    else:
        assert isinstance(padding, Sequence) and len(padding) == 2
        for x in padding:
            assert isinstance(x, Sequence) and len(x) == 2
        return [tuple[p] for p in padding]


def resize_with_aligned_corners(
    image: chex.Array,
    shape: Sequence[int],
    method: str | jax.image.ResizeMethod,
    antialias: bool,
):
    """Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's
    interpolation functions.

    Copy from https://github.com/google/jax/issues/11206#issuecomment-1423140760

    Args:
        image: Array to resize.
        shape: Target shape; must have one entry per axis of ``image``.
        method: Interpolation method forwarded to ``jax.image.scale_and_translate``.
        antialias: Forwarded to ``jax.image.scale_and_translate``.

    Returns:
        The resized array of shape ``shape``.

    Raises:
        ZeroDivisionError: If an axis that has to be resized has length 1 in ``image``
            (the aligned-corner scale divides by ``image.shape[i] - 1``).
    """
    # Only axes whose length changes are resampled. A plain `!=` comparison is used
    # instead of `jax.core.symbolic_equal_dim`, which JAX deprecated in favour of `==`
    # and which newer JAX releases no longer expose.
    spatial_dims = tuple(i for i in range(len(shape)) if image.shape[i] != shape[i])
    # Map the first/last input sample centres onto the first/last output sample centres.
    scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
    translation = -(scale / 2.0 - 0.5)
    return jax.image.scale_and_translate(
        image,
        shape,
        method=method,
        scale=scale,
        spatial_dims=spatial_dims,
        translation=translation,
        antialias=antialias,
    )


class GlobalConv(linen.Module):
    """Global (top) branch of the IOI module; emits a sigmoid-activated map.

    Inputs are channel-last: the last axis is channels and the two axes before it
    are height and width.

    Attributes:
        features: Channel count produced by the final convolution.
        num_frames: Channel count of the two inner convolutions and of the
            positional embedding's last axis.
        pos_dim: Height and width of the learned positional embedding.
        conv: Convolution module factory.
        norm: Normalization module factory.
    """

    features: int
    num_frames: int = 16
    pos_dim: int = 16
    conv: ModuleDef = linen.Conv
    norm: ModuleDef = linen.BatchNorm

    @linen.compact
    def __call__(self, x: chex.Array, param: chex.Array) -> chex.Array:
        # Temporal focus convolution on the re-weighted input.
        weighted = x * param
        h = self.conv(self.num_frames, kernel_size=(3, 3), padding=to_padding(1), name="conv1")(weighted)
        h = linen.relu(self.norm(name="norm1")(h))

        # Time encoding: resize the learned embedding to the feature-map size and add it.
        *_, height, width, _ = jnp.shape(h)
        pos_embed = self.param(
            "pos_embed",
            linen.initializers.kaiming_normal(),
            (self.pos_dim, self.pos_dim, self.num_frames),
        )
        h = h + resize_with_aligned_corners(
            pos_embed,
            shape=(height, width, self.num_frames),
            method="bilinear",
            antialias=False,
        )

        # Wide 7x7 convolution over the position-aware features.
        h = self.conv(self.num_frames, kernel_size=(7, 7), padding=to_padding(3), name="conv2")(h)
        h = linen.relu(self.norm(name="norm2")(h))

        # Project to the requested channel count and squash into (0, 1).
        h = self.conv(self.features, kernel_size=(3, 3), padding=to_padding(1), name="conv3")(h)
        return linen.sigmoid(h)


class IOI(linen.Module):
    """Inter-temporal object interaction module.

    Multiplies a plain 3x3 convolution of the input (bottom branch) by the
    output of ``GlobalConv`` (top branch).

    Attributes:
        features: Output channel count of both branches.
        num_frames: Forwarded to ``GlobalConv``.
        pos_dim: Forwarded to ``GlobalConv``.
        conv: Convolution module factory.
        norm: Normalization module factory.
    """

    features: int
    num_frames: int = 16
    pos_dim: int = 16
    conv: ModuleDef = linen.Conv
    norm: ModuleDef = linen.BatchNorm

    @linen.compact
    def __call__(self, x: chex.Array, param: chex.Array) -> chex.Array:
        # Top branch: map computed from the param-weighted input.
        top_branch = GlobalConv(
            self.features,
            num_frames=self.num_frames,
            pos_dim=self.pos_dim,
            conv=self.conv,
            norm=self.norm,
            name="glo_conv",
        )
        gate = top_branch(x, param)

        # Bottom branch: 3x3 convolution of the unweighted input.
        bottom_branch = self.conv(
            self.features,
            kernel_size=(3, 3),
            padding=to_padding(1),
            name="short_conv",
        )
        local = bottom_branch(x)

        return local * gate


class ParamConv(linen.Module):
    """Computes the temporal-adaptive weights from a channel-last feature map.

    The input is averaged over its height and width axes (kept as size-1 axes),
    then passed through conv1x1 -> norm -> ReLU -> conv1x1 -> sigmoid. Both
    convolutions are bias-free and keep the input channel count.

    Attributes:
        conv: Convolution module factory.
        norm: Normalization module factory.
    """

    conv: ModuleDef = linen.Conv
    norm: ModuleDef = linen.BatchNorm

    @linen.compact
    def __call__(self, x: chex.Array) -> chex.Array:
        channels = jnp.size(x, axis=-1)

        # Spatial mean, keeping singleton height/width axes.
        pooled = einops.reduce(x, "... h w c -> ... 1 1 c", "mean")

        squeeze = self.conv(channels, kernel_size=(1, 1), use_bias=False, name="conv1")
        hidden = linen.relu(self.norm(name="norm1")(squeeze(pooled)))

        excite = self.conv(channels, kernel_size=(1, 1), use_bias=False, name="conv2")
        return linen.sigmoid(excite(hidden))


class CTL(linen.Module):
    """Channel-Time Learning block.

    Returns the sum of a "temporal" convolution applied to the re-weighted input
    and the ``IOI`` module applied to the input and the same weights.

    Attributes:
        features: Output channel count.
        num_frames: Forwarded to ``IOI``.
        pos_dim: Forwarded to ``IOI``.
        kernel_size: Kernel size of the temporal convolution (int or pair).
        padding: Padding of the temporal convolution.
        feature_group_count: Group count of the temporal convolution.
        use_bias: Whether the temporal convolution has a bias.
        conv: Convolution module factory.
        norm: Normalization module factory.
    """

    features: int
    num_frames: int = 16
    pos_dim: int = 7
    kernel_size: int | tuple[int, int] = 1
    padding: int = 0
    feature_group_count: int = 1
    use_bias: bool = True
    conv: ModuleDef = linen.Conv
    norm: ModuleDef = linen.BatchNorm

    @linen.compact
    def __call__(self, x: chex.Array) -> chex.Array:
        # Temporal-adaptive weights shared by both branches.
        weights = ParamConv(conv=self.conv, norm=self.norm, name="param_conv")(x)

        # Temporal focus convolution on the re-weighted input.
        temporal_branch = self.conv(
            self.features,
            kernel_size=to_2tuple(self.kernel_size),
            padding=to_padding(self.padding),
            feature_group_count=self.feature_group_count,
            use_bias=self.use_bias,
            name="temporal_conv",
        )
        temporal_out = temporal_branch(x * weights)

        # Inter-temporal object interaction branch.
        spatial_branch = IOI(
            self.features,
            num_frames=self.num_frames,
            pos_dim=self.pos_dim,
            conv=self.conv,
            norm=self.norm,
            name="spatial_conv",
        )
        spatial_out = spatial_branch(x, weights)

        return temporal_out + spatial_out


class BasicBlock(linen.Module):
    """Residual block made of two CTL layers, with an optional 2x downsampling stem.

    Attributes:
        features: Output channel count of both CTL layers.
        num_frames: Forwarded to the CTL layers.
        stride: 1 keeps the resolution; 2 prepends a depthwise 2x2, stride-2
            convolution followed by a norm layer.
        pos_dim: Positional-embedding size forwarded to the second CTL layer only.
        conv: Convolution module factory.
        norm: Normalization module factory.
    """

    features: int
    num_frames: int = 16
    stride: int = 1
    pos_dim: int = 7
    conv: ModuleDef = linen.Conv
    norm: ModuleDef = linen.BatchNorm

    @linen.compact
    def __call__(self, x: chex.Array) -> chex.Array:
        if self.stride != 1:
            assert self.stride == 2
            channels = jnp.size(x, axis=-1)
            # Depthwise (one group per channel) strided conv halves the resolution.
            downsample = self.conv(
                channels,
                kernel_size=(2, 2),
                strides=self.stride,
                feature_group_count=channels,
                padding="VALID",
                name="downsample.0",
            )
            x = downsample(x)
            x = self.norm(name="downsample.1")(x)

        # Keyword arguments shared by both CTL layers.
        make_ctl = partial(
            CTL,
            self.features,
            num_frames=self.num_frames,
            kernel_size=1,
            padding=0,
            use_bias=False,
            conv=self.conv,
            norm=self.norm,
        )

        # FIXME: pos_dim is not set here in official impl.
        # is it a bug?
        # (the first CTL therefore falls back to CTL's default pos_dim.)
        h = make_ctl(name="conv1")(x)
        h = linen.relu(self.norm(name="bn1")(h))

        h = make_ctl(pos_dim=self.pos_dim, name="conv2")(h)
        h = self.norm(name="bn2")(h)

        # Project the shortcut whenever its shape differs from the residual path.
        if jnp.shape(x) != jnp.shape(h):
            shortcut = self.conv(self.features, kernel_size=(1, 1), use_bias=False, name="shortcut_conv.0")(x)
            x = self.norm(name="shortcut_conv.1")(shortcut)

        return linen.relu(x + h)


class Bottleneck(linen.Module):
    """Bottleneck residual block: conv1x1 -> CTL -> conv1x1 with channel expansion.

    Attributes:
        features: Channel count of the first convolution and of the CTL layer.
        num_frames: Forwarded to the CTL layer.
        stride: 1 keeps the resolution; 2 prepends a depthwise 2x2, stride-2
            convolution followed by a norm layer.
        pos_dim: Positional-embedding size forwarded to the CTL layer.
        expansion: Multiplier applied to ``features`` for the block output.
        conv: Convolution module factory.
        norm: Normalization module factory.
    """

    features: int
    num_frames: int = 16
    stride: int = 1
    pos_dim: int = 7
    expansion: int = 4
    conv: ModuleDef = linen.Conv
    norm: ModuleDef = linen.BatchNorm

    @linen.compact
    def __call__(self, x: chex.Array) -> chex.Array:
        out_features = self.features * self.expansion
        # All pointwise convolutions in this block are 1x1 and bias-free.
        pointwise = partial(self.conv, kernel_size=(1, 1), use_bias=False)

        if self.stride != 1:
            assert self.stride == 2
            channels = jnp.size(x, axis=-1)
            # Depthwise (one group per channel) strided conv halves the resolution.
            downsample = self.conv(
                channels,
                kernel_size=(2, 2),
                strides=self.stride,
                feature_group_count=channels,
                padding="VALID",
                name="downsample.0",
            )
            x = downsample(x)
            x = self.norm(name="downsample.1")(x)

        # Reduce.
        h = pointwise(self.features, name="conv1")(x)
        h = linen.relu(self.norm(name="bn1")(h))

        # Channel-time learning.
        ctl = CTL(
            self.features,
            num_frames=self.num_frames,
            pos_dim=self.pos_dim,
            kernel_size=1,
            padding=0,
            use_bias=False,
            conv=self.conv,
            norm=self.norm,
            name="conv2",
        )
        h = linen.relu(self.norm(name="bn2")(ctl(h)))

        # Expand.
        h = pointwise(out_features, name="conv3")(h)
        h = self.norm(name="bn3")(h)

        # Project the shortcut whenever its shape differs from the residual path.
        if jnp.shape(x) != jnp.shape(h):
            shortcut = pointwise(out_features, name="shortcut_conv.0")(x)
            x = self.norm(name="shortcut_conv.1")(shortcut)

        return linen.relu(x + h)


class ResNet(linen.Module):
    """SqueezeTime ResNet model.

    The frame axis of the input is folded into the channel axis before the stem,
    so the network itself only uses 2D convolutions.

    Attributes:
        stage_sizes: Number of blocks in each stage.
        block_cls: Residual block class.
        num_classes: Number of classes.
            If 0, the final dense layer is not added.
        num_frames: Number of frames.
        drop_rate: Dropout rate.
        widen_factor: Width factor (scales the base width of 64).
        pos_dims: Positional embedding dimensions, one per stage.
        dtype: Data type for computation.
        norm_dtype: Data type for normalization.
        param_dtype: Data type for parameters.
    """

    stage_sizes: list[int]
    block_cls: ModuleDef
    num_classes: int = 400
    num_frames: int = 16
    drop_rate: float = 0.5
    widen_factor: float = 1.0
    pos_dims: list[int] = (56, 28, 14, 7)
    dtype: chex.ArrayDType = jnp.float32
    norm_dtype: chex.ArrayDType = jnp.float32
    param_dtype: chex.ArrayDType = jnp.float32

    @linen.compact
    def __call__(self, x: chex.Array, is_training: bool = False) -> chex.Array:
        width = int(64 * self.widen_factor)

        # Layer factories with the dtype configuration baked in.
        conv = partial(linen.Conv, dtype=self.dtype, param_dtype=self.param_dtype)
        norm = partial(
            linen.BatchNorm,
            use_running_average=not is_training,
            dtype=self.norm_dtype,
            param_dtype=self.param_dtype,
        )

        # Fold the frame axis into the channel axis.
        x = einops.rearrange(x, "... T H W C -> ... H W (C T)")

        # Stem: strided 5x5 conv, norm, ReLU, then 3x3 max-pool with stride 2.
        stem = conv(width, kernel_size=(5, 5), strides=2, padding=to_padding(2), use_bias=False, name="conv1")
        x = linen.relu(norm(name="bn1")(stem(x)))
        x = linen.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=to_padding(1))

        # Residual stages; every stage after the first starts with a stride-2 block
        # and doubles the width.
        for stage_idx, num_blocks in enumerate(self.stage_sizes):
            stage_width = width * 2**stage_idx
            stage_pos_dim = self.pos_dims[stage_idx]
            for block_idx in range(num_blocks):
                is_downsampling = stage_idx > 0 and block_idx == 0
                block = self.block_cls(
                    features=stage_width,
                    num_frames=self.num_frames,
                    stride=2 if is_downsampling else 1,
                    pos_dim=stage_pos_dim,
                    conv=conv,
                    norm=norm,
                    name=f"layer{stage_idx + 1}.{block_idx}",
                )
                x = block(x)

        # Global average pooling over height and width, then dropout.
        x = einops.reduce(x, "... H W C -> ... C", "mean")
        x = linen.Dropout(rate=self.drop_rate, deterministic=not is_training)(x)

        if self.num_classes <= 0:
            return x

        classifier = linen.Dense(self.num_classes, dtype=self.dtype, param_dtype=self.param_dtype, name="fc")
        return classifier(x)


def resnet18(**kwargs) -> ResNet:
    """Build a ``ResNet`` of ``BasicBlock``s with stage depths 2-2-2-2.

    Extra keyword arguments are forwarded to ``ResNet``.
    """
    depths = [2, 2, 2, 2]
    return ResNet(block_cls=BasicBlock, stage_sizes=depths, **kwargs)


def resnet34(**kwargs) -> ResNet:
    """Build a ``ResNet`` of ``BasicBlock``s with stage depths 3-4-6-3.

    Extra keyword arguments are forwarded to ``ResNet``.
    """
    depths = [3, 4, 6, 3]
    return ResNet(block_cls=BasicBlock, stage_sizes=depths, **kwargs)


def resnet50(**kwargs) -> ResNet:
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths 3-4-6-3.

    Extra keyword arguments are forwarded to ``ResNet``.
    """
    depths = [3, 4, 6, 3]
    return ResNet(block_cls=Bottleneck, stage_sizes=depths, **kwargs)


def resnet101(**kwargs) -> ResNet:
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths 3-4-23-3.

    Extra keyword arguments are forwarded to ``ResNet``.
    """
    depths = [3, 4, 23, 3]
    return ResNet(block_cls=Bottleneck, stage_sizes=depths, **kwargs)


def resnet152(**kwargs) -> ResNet:
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths 3-8-36-3.

    Extra keyword arguments are forwarded to ``ResNet``.
    """
    depths = [3, 8, 36, 3]
    return ResNet(block_cls=Bottleneck, stage_sizes=depths, **kwargs)


def resnet200(**kwargs) -> ResNet:
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths 3-24-36-3.

    Extra keyword arguments are forwarded to ``ResNet``.
    """
    depths = [3, 24, 36, 3]
    return ResNet(block_cls=Bottleneck, stage_sizes=depths, **kwargs)



In [None]:
# Assign variables.
from torch import nn
from einops import rearrange
import jax.numpy as jnp
from jax import tree_util
from flax import traverse_util

from mmaction.models.backbones.SqueezeTime import Conv2d as SqueezeTimeConv2d
from mmaction.models.backbones.SqueezeTime_ava import Conv2d as SqueezeTimeAvaConv2d


def tensor_to_array(tensor):
    """Copy a PyTorch tensor into a JAX array.

    The tensor is detached from autograd and moved to the CPU before its
    NumPy view is handed to ``jnp.array``.
    """
    host_values = tensor.detach().cpu().numpy()
    return jnp.array(host_values)


def convert_dense(m):
    """Convert a linear layer's state dict into Flax ``Dense`` parameters.

    Returns:
        ``(params, batch_stats)`` where ``params["kernel"]`` is the weight with its
        two axes swapped to (in, out), ``params["bias"]`` is present only if the
        layer has a bias, and ``batch_stats`` is always empty.
    """
    arrays = tree_util.tree_map(tensor_to_array, m.state_dict())
    kernel = rearrange(arrays["weight"], "outC inC -> inC outC")
    converted = {"kernel": kernel}
    if "bias" in arrays:
        converted["bias"] = arrays["bias"]
    return converted, {}

def convert_conv(m):
    """Convert a convolution layer's state dict into Flax ``Conv`` parameters.

    Returns:
        ``(params, batch_stats)`` where ``params["kernel"]`` is the weight with the
        (out, in) axes moved behind the spatial axes, ``params["bias"]`` is present
        only if the layer has a bias, and ``batch_stats`` is always empty.
    """
    arrays = tree_util.tree_map(tensor_to_array, m.state_dict())
    kernel = rearrange(arrays["weight"], "outC inC ... -> ... inC outC")
    converted = {"kernel": kernel}
    if "bias" in arrays:
        converted["bias"] = arrays["bias"]
    return converted, {}


def convert_bn(m):
    """Convert a batch-norm layer's state dict into Flax ``BatchNorm`` variables.

    Returns:
        ``(params, batch_stats)``: ``weight``/``bias`` become ``scale``/``bias`` in
        ``params`` and ``running_mean``/``running_var`` become ``mean``/``var`` in
        ``batch_stats``. Entries missing from the state dict are skipped.
    """
    arrays = tree_util.tree_map(tensor_to_array, m.state_dict())

    # (torch key, flax key) pairs for each Flax collection.
    param_keys = (("weight", "scale"), ("bias", "bias"))
    stat_keys = (("running_mean", "mean"), ("running_var", "var"))

    params = {flax_key: arrays[torch_key] for torch_key, flax_key in param_keys if torch_key in arrays}
    batch_stats = {flax_key: arrays[torch_key] for torch_key, flax_key in stat_keys if torch_key in arrays}
    return params, batch_stats


def convert_ctl(m):
    """Convert a PyTorch SqueezeTime CTL module into variables for the Flax ``CTL``.

    The torch module is read through its ``param_conv`` (indices 1, 2, 4),
    ``temporal_conv`` and ``spatial_conv`` (``short_conv``, ``glo_conv`` indices
    0, 1, 3, 4, 6 and ``pos_embed``) attributes.

    Returns:
        ``(params, batch_stats)`` nested dicts keyed like the Flax ``CTL`` submodules.
    """
    # Temporal-adaptive weight branch.
    param_norm_params, param_norm_stats = convert_bn(m.param_conv[2])
    param_conv_params = {
        "conv1": convert_conv(m.param_conv[1])[0],
        "norm1": param_norm_params,
        "conv2": convert_conv(m.param_conv[4])[0],
    }

    # Temporal focus convolution.
    temporal_params = convert_conv(m.temporal_conv)[0]

    # IOI module: short (bottom) branch and global (top) branch.
    spatial = m.spatial_conv
    short_params = convert_conv(spatial.short_conv)[0]

    glo = spatial.glo_conv
    glo_norm1_params, glo_norm1_stats = convert_bn(glo[1])
    glo_norm2_params, glo_norm2_stats = convert_bn(glo[4])
    glo_params = {
        "conv1": convert_conv(glo[0])[0],
        "norm1": glo_norm1_params,
        "conv2": convert_conv(glo[3])[0],
        "norm2": glo_norm2_params,
        "conv3": convert_conv(glo[6])[0],
        # Positional encoding: drop the leading singleton axis and move frames last.
        "pos_embed": rearrange(tensor_to_array(spatial.pos_embed), "1 t h w -> h w t"),
    }
    glo_stats = {"norm1": glo_norm1_stats, "norm2": glo_norm2_stats}

    params = {
        "param_conv": param_conv_params,
        "temporal_conv": temporal_params,
        "spatial_conv": {"short_conv": short_params, "glo_conv": glo_params},
    }

    batch_stats = {}
    if param_norm_stats:
        batch_stats["param_conv"] = {"norm1": param_norm_stats}
    batch_stats["spatial_conv"] = {"glo_conv": glo_stats}

    return params, batch_stats


def get_variables_from_torch_model(torch_model: nn.Module):
    """Collect Flax-style variables from a PyTorch module tree.

    The module's children are walked recursively. Linear, Conv2d, BatchNorm2d and
    SqueezeTime CTL modules are converted by their dedicated converters; any other
    module is treated as a container and descended into.

    Args:
        torch_model: Root PyTorch module.

    Returns:
        ``{"params": ..., "batch_stats": ...}`` as nested dicts keyed by child module
        names. Sub-trees that contain no arrays are pruned entirely.
    """

    def collect(module):
        params_tree = {}
        stats_tree = {}
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                params, batch_stats = convert_dense(child)
            elif isinstance(child, nn.Conv2d):
                params, batch_stats = convert_conv(child)
            elif isinstance(child, nn.BatchNorm2d):
                params, batch_stats = convert_bn(child)
            elif isinstance(child, (SqueezeTimeConv2d, SqueezeTimeAvaConv2d)):
                params, batch_stats = convert_ctl(child)
            else:
                # Unknown module type: treat it as a container.
                params, batch_stats = collect(child)

            params_tree[name] = params
            stats_tree[name] = batch_stats
        return params_tree, stats_tree

    def prune_empty(tree):
        pruned = {}
        for key, value in tree.items():
            if isinstance(value, dict):
                # Prune the sub-tree first, then drop it if nothing is left. Checking
                # emptiness only before recursing would keep dicts such as
                # {"a": {"b": {}}} as {"a": {}}.
                value = prune_empty(value)
                if not value:
                    continue
            pruned[key] = value
        return pruned

    params, batch_stats = collect(torch_model)
    return {"params": prune_empty(params), "batch_stats": prune_empty(batch_stats)}


def assign_variables_from_torch_model(variables, torch_model):
    """Replace every array in ``variables`` with its counterpart from ``torch_model``.

    Flax variable paths are joined with "." and looked up among the converted
    PyTorch variables, which are flattened with the same separator.

    Args:
        variables: Flax variable collections (e.g. produced by ``Module.init``).
        torch_model: PyTorch module providing the values.

    Returns:
        A dict with the same collections and tree structure as ``variables``,
        holding the converted PyTorch arrays.

    Raises:
        RuntimeError: If a Flax variable has no counterpart in the PyTorch model.
        AssertionError: If a matched pair of arrays differs in shape.
    """
    converted = get_variables_from_torch_model(torch_model)

    result = {}
    for collection, tree in variables.items():
        torch_flat = traverse_util.flatten_dict(converted[collection], sep=".")

        picked = {}
        for path, target in traverse_util.flatten_dict(tree).items():
            dotted = ".".join(path)
            if dotted not in torch_flat:
                raise RuntimeError(f"{dotted} is not found from PyTorch model.")

            source = torch_flat[dotted]
            assert (
                target.shape == source.shape
            ), f"Shape mismatch: {dotted} ({target.shape} vs {source.shape})."
            # Module names contain ".", so "/" is used as the separator when rebuilding.
            picked["/".join(path)] = source

        result[collection] = traverse_util.unflatten_dict(picked, sep="/")
    return result


In [None]:
from mmaction.apis import init_recognizer

def create_torch_model(config_file, checkpoint_file):
    """Load an mmaction2 recognizer on CPU and return its backbone network.

    The classification head's ``avg_pool`` and ``fc_cls`` modules are registered
    on the returned network under the names ``avg_pool`` and ``fc``.
    """
    recognizer = init_recognizer(config_file, checkpoint_file, device="cpu")

    net = recognizer.backbone.net
    head = recognizer.cls_head
    net.add_module("avg_pool", head.avg_pool)
    net.add_module("fc", head.fc_cls)
    return net

In [None]:
from pathlib import Path
import pickle
import numpy
import torch
import jax
import jax.random as jr
from einops import rearrange


def convert(out_file, config_file, checkpoint_file, num_classes):
    """Convert a PyTorch SqueezeTime checkpoint to Flax variables and pickle them.

    A Flax ``resnet50`` is initialised, its variables are replaced with the values
    from the PyTorch model, and both models are run on the same random clip. The
    maximum and mean absolute differences of their outputs are printed before the
    variables are written to ``out_file``.

    Args:
        out_file: Destination path of the pickle file; parent directories are created.
        config_file: mmaction2 config file of the PyTorch model.
        checkpoint_file: PyTorch checkpoint to load.
        num_classes: Number of output classes of the Flax classifier.
    """
    flax_model = resnet50(num_classes=num_classes, num_frames=16, pos_dims=(56, 28, 14, 7))
    torch_model = create_torch_model(
        config_file=config_file,
        checkpoint_file=checkpoint_file,
    )

    # The same random clip feeds both models: (B, T, H, W, C) for Flax,
    # rearranged to (B, C, T, H, W) for PyTorch.
    input_array = jr.uniform(jr.PRNGKey(0), (1, 16, 224, 224, 3))
    input_tensor = torch.from_numpy(numpy.array(input_array))
    input_tensor = rearrange(input_tensor, "b t h w c -> b c t h w").contiguous()

    # initialize.
    variables = flax_model.init(jr.PRNGKey(0), input_array)
    variables = assign_variables_from_torch_model(variables, torch_model)

    # compute outputs.
    output_array = flax_model.apply(variables, input_array)

    # Inference only: skip building an autograd graph for the reference forward pass.
    with torch.no_grad():
        features = torch_model(input_tensor)
        pooled = torch_model.avg_pool(features)
        logits = torch_model.fc(pooled.view(pooled.shape[0], -1))
    output_tensor = logits.cpu().numpy()

    # compare outputs.
    abs_diff = jnp.abs(output_array - output_tensor)
    print(checkpoint_file)
    print("=> Max diff.", abs_diff.max())
    print("=> Avg diff.", abs_diff.mean())

    out_path = Path(out_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Store plain NumPy arrays so the pickle holds no JAX array objects.
    variables = jax.tree_util.tree_map(numpy.array, variables)
    out_path.write_bytes(pickle.dumps(variables))


In [None]:
# One entry per released checkpoint: output pickle, mmaction2 config, checkpoint, class count.
_conversion_jobs = (
    (
        "converted/k400.pkl",
        "SqueezeTime/configs/recognition/SqueezeTime/SqueezeTime_K400.py",
        "ckpts/SqueezeTime_K400_71.64.pth",
        400,
    ),
    (
        "converted/k600.pkl",
        "SqueezeTime/configs/recognition/SqueezeTime/SqueezeTime_K600.py",
        "ckpts/SqueezeTime_K600_76.06.pth",
        600,
    ),
    (
        "converted/hmdb51.pkl",
        "SqueezeTime/configs/recognition/SqueezeTime/SqueezeTime_HMDB51.py",
        "ckpts/SqueezeTime_HMDB51_65.56.pth",
        51,
    ),
)

# Convert the checkpoints in the listed order.
for _out_file, _config_file, _checkpoint_file, _num_classes in _conversion_jobs:
    convert(
        _out_file,
        config_file=_config_file,
        checkpoint_file=_checkpoint_file,
        num_classes=_num_classes,
    )