In [None]:
!pip install av

In [None]:
!pip install einops

In [None]:
f"""
Code inspired by:
https://github.com/Atze00/MoViNet-pytorch
https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv2.html
https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv3.html
"""
from collections import OrderedDict
import torch
from torch.nn.modules.utils import _triple, _pair
import torch.nn.functional as Ff
from typing import Any, Callable, Optional, Tuple, Union
from einops import rearrange
from torch import nn, Tensor


class Hardsigmoid(nn.Module):
    """Piecewise-linear sigmoid approximation: ``clamp(0.2 * x + 0.5, 0, 1)``.

    Note the slope is 0.2 (not PyTorch's ``nn.Hardsigmoid`` slope of 1/6).
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.clamp(x * 0.2 + 0.5, min=0.0, max=1.0)


class Swish(nn.Module):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        gate = torch.sigmoid(x)
        return gate * x


class CausalModule(nn.Module):
    """Base class for modules that keep state between successive forward calls.

    The carried state lives in ``activation`` (``None`` until a subclass
    stores something there); ``reset_activation`` clears it again.
    """

    def __init__(self) -> None:
        super(CausalModule, self).__init__()
        self.activation: Optional[Tensor] = None

    def reset_activation(self) -> None:
        """Drop any state carried over from previous forward calls."""
        self.activation = None


class TemporalCGAvgPool3D(CausalModule):
    """Causal cumulative average along dim 2 (the temporal axis).

    Output step ``t`` is the mean of every step seen so far, including steps
    from earlier calls: the running sum at the last step is kept in
    ``activation`` and the number of consumed steps in ``n_cumulated_values``
    until ``reset_activation`` is called.
    """

    def __init__(self,) -> None:
        super().__init__()
        self.n_cumulated_values = 0
        # After every call, detach the carried running sum from the graph.
        self.register_forward_hook(self._detach_activation)

    def forward(self, x: Tensor) -> Tensor:
        steps = x.shape[2]
        running = torch.cumsum(x, dim=2)
        if self.activation is not None:
            # Continue the running sum from the previous call.
            running = running + self.activation
        self.activation = running[:, :, -1:].clone()
        counts = torch.arange(1, steps + 1, device=x.device)
        counts = counts[None, None, :, None, None].expand(x.shape)
        averaged = running / (self.n_cumulated_values + counts)
        self.n_cumulated_values += steps
        return averaged

    @staticmethod
    def _detach_activation(module: CausalModule,
                           input: Tensor,
                           output: Tensor) -> None:
        """Forward hook: detach ``module.activation`` in place."""
        module.activation.detach_()

    def reset_activation(self) -> None:
        """Clear the running sum and the step counter."""
        super().reset_activation()
        self.n_cumulated_values = 0


class Conv2dBNActivation(nn.Sequential):
    """``Conv2d -> norm -> activation`` stack with fixed child names.

    Children are registered as ``conv2d``, ``norm`` and ``act``. A ``None``
    ``norm_layer`` / ``activation_layer`` becomes :class:`nn.Identity`. The
    norm layer is always constructed as ``norm_layer(out_planes, eps=0.001)``.
    Extra ``**kwargs`` (e.g. ``bias``) are forwarded to :class:`nn.Conv2d`.
    ``kernel_size``, ``stride`` and ``out_channels`` are exposed as attributes.
    """

    def __init__(
                 self,
                 in_planes: int,
                 out_planes: int,
                 *,
                 kernel_size: Union[int, Tuple[int, int]],
                 padding: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None,
                 **kwargs: Any,
                 ) -> None:
        kernel_pair = _pair(kernel_size)
        stride_pair = _pair(stride)
        norm_factory = nn.Identity if norm_layer is None else norm_layer
        act_factory = (nn.Identity if activation_layer is None
                       else activation_layer)
        conv = nn.Conv2d(in_planes, out_planes,
                         kernel_size=kernel_pair,
                         stride=stride_pair,
                         padding=_pair(padding),
                         groups=groups,
                         **kwargs)
        super().__init__(OrderedDict([
            ("conv2d", conv),
            ("norm", norm_factory(out_planes, eps=0.001)),
            ("act", act_factory()),
        ]))
        self.kernel_size = kernel_pair
        self.stride = stride_pair
        self.out_channels = out_planes


class Conv3DBNActivation(nn.Sequential):
    """``Conv3d -> norm -> activation`` stack with fixed child names.

    Children are registered as ``conv3d``, ``norm`` and ``act``. A ``None``
    ``norm_layer`` / ``activation_layer`` becomes :class:`nn.Identity`. The
    norm layer is always constructed as ``norm_layer(out_planes, eps=0.001)``.
    Extra ``**kwargs`` (e.g. ``bias``) are forwarded to :class:`nn.Conv3d`.
    ``kernel_size``, ``stride`` and ``out_channels`` are exposed as attributes.
    """

    def __init__(
                 self,
                 in_planes: int,
                 out_planes: int,
                 *,
                 kernel_size: Union[int, Tuple[int, int, int]],
                 padding: Union[int, Tuple[int, int, int]],
                 stride: Union[int, Tuple[int, int, int]] = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None,
                 **kwargs: Any,
                 ) -> None:
        kernel_triple = _triple(kernel_size)
        stride_triple = _triple(stride)
        norm_factory = nn.Identity if norm_layer is None else norm_layer
        act_factory = (nn.Identity if activation_layer is None
                       else activation_layer)
        conv = nn.Conv3d(in_planes, out_planes,
                         kernel_size=kernel_triple,
                         stride=stride_triple,
                         padding=_triple(padding),
                         groups=groups,
                         **kwargs)
        super().__init__(OrderedDict([
            ("conv3d", conv),
            ("norm", norm_factory(out_planes, eps=0.001)),
            ("act", act_factory()),
        ]))
        self.kernel_size = kernel_triple
        self.stride = stride_triple
        self.out_channels = out_planes


class ConvBlock3D(CausalModule):
    """Convolution (+ optional norm/activation) over ``(b, c, t, h, w)`` tensors.

    Two implementations are selected by ``conv_type``:

    * ``"3d"``: a single :class:`Conv3DBNActivation` (``conv_1``).
    * ``"2plus1d"``: a per-frame spatial 2-D conv (``conv_1``), followed, when
      the temporal kernel is larger than 1, by a temporal 2-D conv
      (``conv_2``) applied over ``(t, h*w)``.

    With ``causal=True`` the temporal padding is dropped. For temporal kernels
    larger than 1, the last ``kernel_size[0] - 1`` frames seen so far are kept
    in ``self.activation`` and prepended to the next input, so outputs only
    depend on current and past frames. With ``tf_like=True`` the temporal
    padding becomes ``(k_t - 1) // 2`` and the spatial padding is computed at
    run time by ``same_padding``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: int or ``(t, h, w)`` kernel size.
        tf_like: enable run-time "same" spatial padding. Requires an odd
            temporal kernel, temporal stride 1 and spatial stride <= kernel.
        causal: enable the streaming buffer described above.
        conv_type: ``"3d"`` or ``"2plus1d"``.
        padding: int or ``(t, h, w)`` padding (overridden by ``tf_like`` /
            ``causal`` as described above).
        stride: int or ``(t, h, w)`` stride.
        norm_layer: norm-layer factory; ``None`` means identity.
        activation_layer: activation factory; ``None`` means identity.
        bias: whether the convolutions have a bias.
        **kwargs: forwarded to the underlying convolutions (e.g. ``groups``).

    Raises:
        ValueError: on an unsupported ``conv_type`` or an invalid ``tf_like``
            kernel/stride combination.
    """

    def __init__(
            self,
            in_planes: int,
            out_planes: int,
            *,
            kernel_size: Union[int, Tuple[int, int, int]],
            tf_like: bool,
            causal: bool,
            conv_type: str,
            padding: Union[int, Tuple[int, int, int]] = 0,
            stride: Union[int, Tuple[int, int, int]] = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            activation_layer: Optional[Callable[..., nn.Module]] = None,
            bias: bool = False,
            **kwargs: Any,
            ) -> None:
        super().__init__()
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        self.conv_2 = None
        if tf_like:
            if kernel_size[0] % 2 == 0:
                raise ValueError('tf_like supports only odd'
                                 + ' kernels for temporal dimension')
            # Temporal padding keeps T unchanged; spatial padding is applied
            # dynamically in forward().
            padding = ((kernel_size[0]-1)//2, 0, 0)
            if stride[0] != 1:
                raise ValueError('illegal stride value, tf like supports'
                                 + ' only stride == 1 for temporal dimension')
            if stride[1] > kernel_size[1] or stride[2] > kernel_size[2]:
                raise ValueError('tf_like supports only'
                                 + '  stride <= of the kernel size')

        if causal:
            # Past context comes from the stream buffer instead of padding.
            padding = (0, padding[1], padding[2])
        if conv_type not in ("2plus1d", "3d"):
            raise ValueError("only 2plus1d or 3d are "
                             + "allowed as 3d convolutions")

        if conv_type == "2plus1d":
            self.conv_1 = Conv2dBNActivation(in_planes,
                                             out_planes,
                                             kernel_size=(kernel_size[1],
                                                          kernel_size[2]),
                                             padding=(padding[1],
                                                      padding[2]),
                                             stride=(stride[1], stride[2]),
                                             activation_layer=activation_layer,
                                             norm_layer=norm_layer,
                                             bias=bias,
                                             **kwargs)
            if kernel_size[0] > 1:
                # conv_2 consumes conv_1's output, so its input width is
                # out_planes (identical to the old in_planes when they match).
                self.conv_2 = Conv2dBNActivation(out_planes,
                                                 out_planes,
                                                 kernel_size=(kernel_size[0],
                                                              1),
                                                 padding=(padding[0], 0),
                                                 stride=(stride[0], 1),
                                                 activation_layer=activation_layer,
                                                 norm_layer=norm_layer,
                                                 bias=bias,
                                                 **kwargs)
        elif conv_type == "3d":
            self.conv_1 = Conv3DBNActivation(in_planes,
                                             out_planes,
                                             kernel_size=kernel_size,
                                             padding=padding,
                                             activation_layer=activation_layer,
                                             norm_layer=norm_layer,
                                             stride=stride,
                                             bias=bias,
                                             **kwargs)
        self.padding = padding
        self.kernel_size = kernel_size
        # Number of past frames the causal stream buffer must hold.
        self.dim_pad = self.kernel_size[0]-1
        self.stride = stride
        self.causal = causal
        self.conv_type = conv_type
        self.tf_like = tf_like

    def _forward(self, x: Tensor) -> Tensor:
        """Run the convolution(s), inserting the causal buffer where needed."""
        device = x.device
        # Without a separate temporal conv, buffer the raw input frames.
        if self.dim_pad > 0 and self.conv_2 is None and self.causal:
            x = self._cat_stream_buffer(x, device)
        shape_with_buffer = x.shape
        if self.conv_type == "2plus1d":
            x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.conv_1(x)
        if self.conv_type == "2plus1d":
            x = rearrange(x,
                          "(b t) c h w -> b c t h w",
                          t=shape_with_buffer[2])

            if self.conv_2 is not None:
                # The buffer is applied after the spatial conv, right before
                # the temporal conv that needs the past frames.
                if self.dim_pad > 0 and self.causal:
                    x = self._cat_stream_buffer(x, device)
                w = x.shape[-1]
                x = rearrange(x, "b c t h w -> b c t (h w)")
                x = self.conv_2(x)
                x = rearrange(x, "b c t (h w) -> b c t h w", w=w)
        return x

    def forward(self, x: Tensor) -> Tensor:
        if self.tf_like:
            x = same_padding(x, x.shape[-2], x.shape[-1],
                             self.stride[-2], self.stride[-1],
                             self.kernel_size[-2], self.kernel_size[-1])
        x = self._forward(x)
        return x

    def _cat_stream_buffer(self, x: Tensor, device: torch.device) -> Tensor:
        """Prepend the buffered past frames to ``x`` along dim 2 and refresh the buffer."""
        if self.activation is None:
            self._setup_activation(x.shape, device=device, dtype=x.dtype)
        x = torch.cat((self.activation.to(device=device, dtype=x.dtype), x), 2)
        self._save_in_activation(x)
        return x

    def _save_in_activation(self, x: Tensor) -> None:
        """Keep the last ``dim_pad`` frames of ``x`` (detached) for the next call."""
        assert self.dim_pad > 0
        self.activation = x[:, :, -self.dim_pad:, ...].clone().detach()

    def _setup_activation(self, input_shape: Tuple[int, ...],
                          device: Optional[torch.device] = None,
                          dtype: Optional[torch.dtype] = None) -> None:
        """Initialise the stream buffer with ``dim_pad`` zero frames."""
        assert self.dim_pad > 0
        self.activation = torch.zeros(*input_shape[:2],
                                      self.dim_pad,
                                      *input_shape[3:],
                                      device=device,
                                      dtype=dtype)


class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation gate for 5-D video features.

    The squeeze step pools ``input`` over dims 3 and 4 (and dim 2 in the
    non-causal case). In causal mode, the per-step spatial mean is averaged
    cumulatively over time by ``TemporalCGAvgPool3D`` and concatenated with
    the per-step spatial mean along channels, doubling ``fc1``'s input.
    The resulting scale (``activation_2`` output) multiplies ``input``.

    Args:
        input_channels: channels of the tensor being gated.
        activation_2: factory for the final gating activation.
        activation_1: factory for the activation between ``fc1`` and ``fc2``.
        conv_type: ``"3d"`` or ``"2plus1d"``, forwarded to ``ConvBlock3D``.
        causal: use the cumulative temporal squeeze described above.
        squeeze_factor: channel reduction factor for the bottleneck.
        bias: whether the 1x1x1 convolutions have a bias.
    """

    def __init__(self, input_channels: int,
                 activation_2: Callable[..., nn.Module],
                 activation_1: Callable[..., nn.Module],
                 conv_type: str,
                 causal: bool,
                 squeeze_factor: int = 4,
                 bias: bool = True) -> None:
        super().__init__()
        self.causal = causal
        # Causal mode concatenates two descriptors along channels.
        se_multiplier = 2 if causal else 1
        squeeze_channels = _make_divisible(input_channels
                                           // squeeze_factor
                                           * se_multiplier, 8)
        # NOTE: attribute name (including its spelling) kept unchanged so
        # external references keep working.
        self.temporal_cumualtive_GAvg3D = TemporalCGAvgPool3D()
        self.fc1 = ConvBlock3D(input_channels*se_multiplier,
                               squeeze_channels,
                               kernel_size=(1, 1, 1),
                               padding=0,
                               tf_like=False,
                               causal=causal,
                               conv_type=conv_type,
                               bias=bias)
        self.activation_1 = activation_1()
        self.activation_2 = activation_2()
        self.fc2 = ConvBlock3D(squeeze_channels,
                               input_channels,
                               kernel_size=(1, 1, 1),
                               padding=0,
                               tf_like=False,
                               causal=causal,
                               conv_type=conv_type,
                               bias=bias)

    def _scale(self, input: Tensor) -> Tensor:
        """Compute the per-channel gate for ``input``."""
        if self.causal:
            x_space = torch.mean(input, dim=[3, 4], keepdim=True)
            scale = self.temporal_cumualtive_GAvg3D(x_space)
            scale = torch.cat((scale, x_space), dim=1)
        else:
            # The file imports functional as `Ff`, so the bare name `F` was
            # undefined here; use the fully qualified function instead.
            scale = torch.nn.functional.adaptive_avg_pool3d(input, 1)
        scale = self.fc1(scale)
        scale = self.activation_1(scale)
        scale = self.fc2(scale)
        return self.activation_2(scale)

    def forward(self, input: Tensor) -> Tensor:
        scale = self._scale(input)
        return scale * input


def _make_divisible(v: float,
                    divisor: int,
                    min_value: Optional[int] = None
                    ) -> int:
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def same_padding(x: Tensor,
                 in_height: int, in_width: int,
                 stride_h: int, stride_w: int,
                 filter_height: int, filter_width: int) -> Tensor:
    """Zero-pad the last two dims of ``x`` in TensorFlow "SAME" style.

    The total padding per dim is ``max(k - s, 0)`` when the size is divisible
    by the stride, else ``max(k - size % s, 0)``. Any odd extra pixel goes to
    the bottom/right.
    """
    def _total_pad(size: int, stride: int, kernel: int) -> int:
        remainder = size % stride
        return max(kernel - (stride if remainder == 0 else remainder), 0)

    pad_h = _total_pad(in_height, stride_h, filter_height)
    pad_w = _total_pad(in_width, stride_w, filter_width)
    left, top = pad_w // 2, pad_h // 2
    return torch.nn.functional.pad(x, (left, pad_w - left, top, pad_h - top))


class tfAvgPool3D(nn.Module):
    """3x3 spatial average pool with stride 2 emulating TF "SAME" padding.

    Only square spatial inputs (``h == w``) are supported.

    * Odd sizes: padded by one on every side, with padding excluded from the
      average.
    * Even sizes: padded by one on the bottom/right, pooled with ``self.avgf``,
      then the last output row and column are rescaled so the zero padding
      does not dilute the average.

    Raises:
        RuntimeError: if the last two dimensions differ.
    """

    def __init__(self) -> None:
        super().__init__()
        self.avgf = nn.AvgPool3d((1, 3, 3), stride=(1, 2, 2))

    def forward(self, x: Tensor) -> Tensor:
        if x.shape[-1] != x.shape[-2]:
            raise RuntimeError('only same shape for h and w ' +
                               'are supported by avg with tf_like')
        if x.shape[-1] % 2 != 0:
            return torch.nn.functional.avg_pool3d(x,
                                                  (1, 3, 3),
                                                  stride=(1, 2, 2),
                                                  count_include_pad=False,
                                                  padding=(0, 1, 1))
        x = torch.nn.functional.pad(x, (0, 1, 0, 1))
        x = self.avgf(x)
        # Windows in the last row/column hold 3 padded zeros out of 9 cells,
        # so rescale by 9/6; the corner gets both factors (4 real cells).
        x[..., -1] = x[..., -1] * 9/6
        x[..., -1, :] = x[..., -1, :] * 9/6
        return x


class BasicBneck(nn.Module):
    """Bottleneck block: expand -> depthwise conv -> squeeze-excitation -> project.

    The output is ``residual + alpha * x``, where ``alpha`` is a learnable
    scalar initialised to 0. ``residual`` is the input itself when the stride
    is ``(1, 1, 1)`` and the channel count is unchanged. Otherwise it is an
    (optionally average-pooled) 1x1x1 projection of the input.

    ``cfg`` must provide ``input_channels``, ``out_channels``,
    ``expanded_channels``, ``kernel_size``, ``stride`` (a tuple), ``padding``
    and ``padding_avg``.

    Raises:
        ValueError: if the temporal stride is not 1 or a spatial stride is
            outside ``[1, 2]``.
    """

    def __init__(self,
                 cfg: "CfgNode",
                 causal: bool,
                 tf_like: bool,
                 conv_type: str,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None,
                 ) -> None:
        super().__init__()
        assert type(cfg.stride) is tuple
        if (not cfg.stride[0] == 1
                or not (1 <= cfg.stride[1] <= 2)
                or not (1 <= cfg.stride[2] <= 2)):
            raise ValueError('illegal stride value')
        self.res = None
        # Always define `expand` so forward() can test it; previously it was
        # missing (AttributeError) when expanded_channels == out_channels.
        self.expand = None

        layers = []
        if cfg.expanded_channels != cfg.out_channels:
            # expand
            self.expand = ConvBlock3D(
                in_planes=cfg.input_channels,
                out_planes=cfg.expanded_channels,
                kernel_size=(1, 1, 1),
                padding=(0, 0, 0),
                causal=causal,
                conv_type=conv_type,
                tf_like=tf_like,
                norm_layer=norm_layer,
                activation_layer=activation_layer
                )

        # Depthwise: one group per expanded channel.
        self.deep = ConvBlock3D(
            in_planes=cfg.expanded_channels,
            out_planes=cfg.expanded_channels,
            kernel_size=cfg.kernel_size,
            padding=cfg.padding,
            stride=cfg.stride,
            groups=cfg.expanded_channels,
            causal=causal,
            conv_type=conv_type,
            tf_like=tf_like,
            norm_layer=norm_layer,
            activation_layer=activation_layer
            )

        self.se = SqueezeExcitation(cfg.expanded_channels,
                                    causal=causal,
                                    activation_1=activation_layer,
                                    activation_2=(nn.Sigmoid
                                                  if conv_type == "3d"
                                                  else Hardsigmoid),
                                    conv_type=conv_type
                                    )

        # Linear projection back to out_channels (no activation).
        self.project = ConvBlock3D(
            cfg.expanded_channels,
            cfg.out_channels,
            kernel_size=(1, 1, 1),
            padding=(0, 0, 0),
            causal=causal,
            conv_type=conv_type,
            tf_like=tf_like,
            norm_layer=norm_layer,
            activation_layer=nn.Identity
            )

        if not (cfg.stride == (1, 1, 1)
                and cfg.input_channels == cfg.out_channels):
            # Shortcut must match the main path's resolution and channels.
            if cfg.stride != (1, 1, 1):
                if tf_like:
                    layers.append(tfAvgPool3D())
                else:
                    layers.append(nn.AvgPool3d((1, 3, 3),
                                  stride=cfg.stride,
                                  padding=cfg.padding_avg))
            layers.append(ConvBlock3D(
                    in_planes=cfg.input_channels,
                    out_planes=cfg.out_channels,
                    kernel_size=(1, 1, 1),
                    padding=(0, 0, 0),
                    norm_layer=norm_layer,
                    activation_layer=nn.Identity,
                    causal=causal,
                    conv_type=conv_type,
                    tf_like=tf_like
                    ))
            self.res = nn.Sequential(*layers)

        # Scales the main branch; starts at 0 so the block begins as the shortcut.
        self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def forward(self, input: Tensor) -> Tensor:
        if self.res is not None:
            residual = self.res(input)
        else:
            residual = input
        if self.expand is not None:
            x = self.expand(input)
        else:
            x = input
        x = self.deep(x)
        x = self.se(x)
        x = self.project(x)
        result = residual + self.alpha * x
        return result


class MoViNet(nn.Module):
    """MoViNet video classifier assembled from a config node.

    Pipeline: ``conv1`` -> ``blocks`` (``BasicBneck`` stages) -> ``conv7`` ->
    global average pool -> 1x1x1-conv classifier -> flatten.

    Args:
        cfg: model config providing ``conv1``, ``blocks``, ``conv7`` and
            ``dense9.hidden_dim``; for pretrained models also ``name``,
            ``weights`` and ``stream_weights``.
        causal: causal (streaming) mode.
        pretrained: load pretrained weights. If True, ``num_classes`` is set
            to 600, ``conv_type`` to ``"2plus1d"`` if ``causal`` else
            ``"3d"``, and ``tf_like`` to True.
        num_classes: number of classes for classification.
        conv_type: type of convolution, either ``"3d"`` or ``"2plus1d"``.
        tf_like: TF-like behaviour, i.e. "same" padding for convolutions.

    Raises:
        ValueError: if pretrained causal weights are requested for a config
            whose ``name`` is not A0, A1 or A2.
    """

    def __init__(self,
                 cfg: "CfgNode",
                 causal: bool = True,
                 pretrained: bool = False,
                 num_classes: int = 16,
                 conv_type: str = "3d",
                 tf_like: bool = False
                 ) -> None:
        super().__init__()
        if pretrained:
            # Pretrained checkpoints are loaded with these fixed settings.
            tf_like = True
            num_classes = 600
            conv_type = "2plus1d" if causal else "3d"
        blocks_dic = OrderedDict()

        norm_layer = nn.BatchNorm3d if conv_type == "3d" else nn.BatchNorm2d
        activation_layer = Swish if conv_type == "3d" else nn.Hardswish

        self.conv1 = ConvBlock3D(
            in_planes=cfg.conv1.input_channels,
            out_planes=cfg.conv1.out_channels,
            kernel_size=cfg.conv1.kernel_size,
            stride=cfg.conv1.stride,
            padding=cfg.conv1.padding,
            causal=causal,
            conv_type=conv_type,
            tf_like=tf_like,
            norm_layer=norm_layer,
            activation_layer=activation_layer
            )

        # Child names b{stage}_l{layer} must stay stable for checkpoint keys.
        for i, block in enumerate(cfg.blocks):
            for j, basicblock in enumerate(block):
                blocks_dic[f"b{i}_l{j}"] = BasicBneck(basicblock,
                                                      causal=causal,
                                                      conv_type=conv_type,
                                                      tf_like=tf_like,
                                                      norm_layer=norm_layer,
                                                      activation_layer=activation_layer
                                                      )
        self.blocks = nn.Sequential(blocks_dic)

        self.conv7 = ConvBlock3D(
            in_planes=cfg.conv7.input_channels,
            out_planes=cfg.conv7.out_channels,
            kernel_size=cfg.conv7.kernel_size,
            stride=cfg.conv7.stride,
            padding=cfg.conv7.padding,
            causal=causal,
            conv_type=conv_type,
            tf_like=tf_like,
            norm_layer=norm_layer,
            activation_layer=activation_layer
            )

        # Classification head built from 1x1x1 convolutions on the pooled map.
        self.classifier = nn.Sequential(
            ConvBlock3D(cfg.conv7.out_channels,
                        cfg.dense9.hidden_dim,
                        kernel_size=(1, 1, 1),
                        tf_like=tf_like,
                        causal=causal,
                        conv_type=conv_type,
                        bias=True),
            Swish(),
            nn.Dropout(p=0.5, inplace=True),
            ConvBlock3D(cfg.dense9.hidden_dim,
                        num_classes,
                        kernel_size=(1, 1, 1),
                        tf_like=tf_like,
                        causal=causal,
                        conv_type=conv_type,
                        bias=True),
        )
        if causal:
            self.cgap = TemporalCGAvgPool3D()
        if pretrained:
            if causal:
                if cfg.name not in ["A0", "A1", "A2"]:
                    raise ValueError("Only A0,A1,A2 streaming " +
                                     "networks are available pretrained")
                state_dict = (torch.hub
                              .load_state_dict_from_url(cfg.stream_weights))
            else:
                state_dict = torch.hub.load_state_dict_from_url(cfg.weights)
            self.load_state_dict(state_dict)
        else:
            self.apply(self._weight_init)
        self.causal = causal

    def avg(self, x: Tensor) -> Tensor:
        """Global average pool to a single spatio-temporal position.

        In causal mode the spatial mean is averaged cumulatively over time
        (across calls, via ``self.cgap``) and only the last step is kept.
        """
        # `F` is not bound in this file (functional is imported as `Ff`),
        # so use the fully qualified name.
        if self.causal:
            avg = torch.nn.functional.adaptive_avg_pool3d(x, (x.shape[2], 1, 1))
            avg = self.cgap(avg)[:, :, -1:]
        else:
            avg = torch.nn.functional.adaptive_avg_pool3d(x, 1)
        return avg

    @staticmethod
    def _weight_init(m):
        """Initialise conv, norm and linear layers (used when not pretrained)."""
        # Conv2d is included so 2plus1d models are initialised like 3d ones.
        if isinstance(m, (nn.Conv3d, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.blocks(x)
        x = self.conv7(x)
        x = self.avg(x)
        x = self.classifier(x)
        x = x.flatten(1)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    @staticmethod
    def _clean_activation_buffers(m):
        if isinstance(m, CausalModule):
            m.reset_activation()

    def clean_activation_buffers(self) -> None:
        """Reset the streaming state of every causal submodule."""
        self.apply(self._clean_activation_buffers)

In [None]:
!pip install fvcore

In [None]:
"""
Inspired by
https://github.com/PeizeSun/SparseR-CNN/blob/dff4c43a9526a6d0d2480abc833e78a7c29ddb1a/detectron2/config/defaults.py
"""
from fvcore.common.config import CfgNode as CN

def fill_SE_config(conf, input_channels,
                    out_channels,
                    expanded_channels,
                    kernel_size,
                    stride,
                    padding,
                    padding_avg,
):
    """Populate ``conf`` with the fields of one bottleneck-block config.

    Sets ``expanded_channels`` and ``padding_avg`` on ``conf``, then delegates
    the shared convolution fields to ``fill_conv``.
    """
    for field_name, field_value in (("expanded_channels", expanded_channels),
                                    ("padding_avg", padding_avg)):
        setattr(conf, field_name, field_value)
    fill_conv(conf,
              input_channels=input_channels,
              out_channels=out_channels,
              kernel_size=kernel_size,
              stride=stride,
              padding=padding)

def fill_conv(conf, input_channels,
                out_channels,
                kernel_size,
                stride,
                padding,):
    """Set the basic convolution fields on the config node ``conf``."""
    fields = (
        ("input_channels", input_channels),
        ("out_channels", out_channels),
        ("kernel_size", kernel_size),
        ("stride", stride),
        ("padding", padding),
    )
    for field_name, field_value in fields:
        setattr(conf, field_name, field_value)




# Root config node; model definitions live under _C.MODEL.
_C = CN()

_C.MODEL = CN()



###################
#### MoViNetA2 ####
###################

# MoViNet-A2: name plus URLs of the non-streaming (`weights`) and streaming
# (`stream_weights`) state dicts that MoViNet downloads when pretrained=True.
_C.MODEL.MoViNetA2 = CN()
_C.MODEL.MoViNetA2.name = "A2"
_C.MODEL.MoViNetA2.weights = "https://github.com/Atze00/MoViNet-pytorch/blob/main/weights/modelA2_statedict_v3?raw=true"
_C.MODEL.MoViNetA2.stream_weights = "https://github.com/Atze00/MoViNet-pytorch/blob/main/weights/modelA2_stream_statedict_v3?raw=true"

# Stem conv. fill_conv args: input_channels, out_channels, kernel_size,
# stride, padding (kernel/stride/padding are (t, h, w) triples).
_C.MODEL.MoViNetA2.conv1 = CN()
fill_conv(_C.MODEL.MoViNetA2.conv1, 3,16,(1,3,3),(1,2,2),(0,1,1))


# Five stages with 3, 5, 5, 6 and 7 bottleneck blocks respectively.
_C.MODEL.MoViNetA2.blocks = [ [CN() for _ in range(3)],
                              [CN() for _ in range(5)],
                              [CN() for _ in range(5)],
                              [CN() for _ in range(6)],
                              [CN() for _ in range(7)]]


# fill_SE_config args: input_channels, out_channels, expanded_channels,
# kernel_size, stride, padding, padding_avg (shortcut avg-pool padding).
# Stage 0
fill_SE_config(_C.MODEL.MoViNetA2.blocks[0][0], 16, 16, 40, (1,5,5), (1,2,2), (0,2,2), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[0][1], 16, 16, 40, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[0][2], 16, 16, 64, (3,3,3), (1,1,1), (1,1,1), (0,1,1))


# Stage 1
fill_SE_config(_C.MODEL.MoViNetA2.blocks[1][0], 16, 40, 96, (3,3,3), (1,2,2), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[1][1], 40, 40, 120, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[1][2], 40, 40, 96, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[1][3], 40, 40, 96, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[1][4], 40, 40, 120, (3,3,3), (1,1,1), (1,1,1), (0,1,1))


# Stage 2
fill_SE_config(_C.MODEL.MoViNetA2.blocks[2][0], 40, 72, 240, (5,3,3), (1,2,2), (2,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[2][1], 72, 72, 160, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[2][2], 72, 72, 240, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[2][3], 72, 72, 192, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[2][4], 72, 72, 240, (3,3,3), (1,1,1), (1,1,1), (0,1,1))


# Stage 3
fill_SE_config(_C.MODEL.MoViNetA2.blocks[3][0], 72, 72, 240, (5,3,3), (1,1,1), (2,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[3][1], 72, 72, 240, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[3][2], 72, 72, 240, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[3][3], 72, 72, 240, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[3][4], 72, 72, 144, (1,5,5), (1,1,1), (0,2,2), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[3][5], 72, 72, 240, (3,3,3), (1,1,1), (1,1,1), (0,1,1))


# Stage 4
fill_SE_config(_C.MODEL.MoViNetA2.blocks[4][0], 72 , 144, 480, (5,3,3), (1,2,2), (2,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[4][1], 144, 144, 384, (1,5,5), (1,1,1), (0,2,2), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[4][2], 144, 144, 384, (1,5,5), (1,1,1), (0,2,2), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[4][3], 144, 144, 480, (1,5,5), (1,1,1), (0,2,2), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[4][4], 144, 144, 480, (1,5,5), (1,1,1), (0,2,2), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[4][5], 144, 144, 480, (3,3,3), (1,1,1), (1,1,1), (0,1,1))
fill_SE_config(_C.MODEL.MoViNetA2.blocks[4][6], 144, 144, 576, (1,3,3), (1,1,1), (0,1,1), (0,1,1))

# Final 1x1x1 conv before pooling (144 -> 640 channels).
_C.MODEL.MoViNetA2.conv7= CN()
fill_conv(_C.MODEL.MoViNetA2.conv7, 144,640,(1,1,1),(1,1,1),(0,0,0))

# Width of the hidden layer in the classification head.
_C.MODEL.MoViNetA2.dense9= CN()
_C.MODEL.MoViNetA2.dense9.hidden_dim = 2048



In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_hub as hub
# TensorFlow device string (first GPU) used when loading the NST model below.
# NOTE(review): assumes a GPU is present in the runtime — confirm.
tensorflow_device = "/device:GPU:0"
def get_file_paths(directory):
    """Return the path of every file below ``directory``, recursively.

    Paths are produced in ``os.walk`` order. A missing directory yields an
    empty list (``os.walk`` ignores the error by default).
    """
    return [
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(directory)
        for name in names
    ]
# Collect the style-image paths and load the TF-Hub NST model under the
# chosen TensorFlow device.
with tf.device(tensorflow_device):
    # Root directory of the style-image dataset (a Kaggle input path);
    # every file below it is collected into `file_paths`.
    dataset_directory = '/kaggle/input/gan10-gan10/gan_10'
    file_paths = get_file_paths(dataset_directory)
    # Arbitrary-image-stylization model used for neural style transfer,
    # bound to the module-level name `hub_module`.
    hub_module = hub.load('https://kaggle.com/models/google/arbitrary-image-stylization-v1/frameworks/TensorFlow1/variations/256/versions/2')

In [None]:
import numpy as np
import tensorflow as tf
import cv2
import random

def _stylize_frames(nst_model, frames, style_image, gan_dim, output_dim):
    """Apply ``nst_model`` with ``style_image`` to every frame in ``frames``.

    Each frame is cast to float32, divided by 255, given a leading batch axis
    and resized to ``(gan_dim, gan_dim)`` before being passed to the model.
    The squeezed model output is resized to ``(output_dim, output_dim)``.

    Returns:
        A list with one stylized array per input frame.
    """
    stylized = []
    for frame in frames:
        content = tf.image.resize(
            tf.cast(frame, tf.float32)[tf.newaxis, ...] / 255,
            (gan_dim, gan_dim),
        )
        outputs = nst_model(tf.constant(content), tf.constant(style_image))
        stylized.append(cv2.resize(np.squeeze(outputs), (output_dim, output_dim)))
    return stylized


def nstAugment(nst_model, style_list, prob, gan_dim, output_dim, frames_front, frames_side, seed):
    """NST based augmentation.

    With probability ``prob``, picks one random style image and applies the
    same style to every frame of both views.

    Args:
        nst_model: NST model to be used for augmentation; called as
            ``nst_model(content, style)``.
        style_list: A list containing the path of all the style images.
        prob: Probability of augmentation.
        gan_dim: Input dimension of NST model.
        output_dim: Required dimension for each output frame.
        frames_front: List of frames to be augmented (Front View).
        frames_side: List of frames to be augmented (Side View).
        seed: Seed passed to ``torch.manual_seed`` and ``random.seed``.

    Returns:
        ``(frames_front, frames_side)`` as numpy arrays: stylized (and
        squeezed) when augmentation is applied, otherwise the original frames.
    """
    # These calls reseed torch's and Python's global RNGs (a side effect kept
    # for compatibility).
    torch.manual_seed(seed)
    random.seed(seed)
    # NOTE(review): the draws below use NumPy's global RNG, which `seed` does
    # not reset, so the augment decision and chosen style are not
    # reproducible from `seed`. Confirm whether that is intended.
    if np.random.random(1) < prob:
        style_path = style_list[np.random.randint(0, len(style_list))]
        raw_style = plt.imread(style_path)
        if raw_style.ndim == 3:
            # Drop a possible alpha channel; the NST model expects RGB.
            raw_style = raw_style[..., :3]
        style_image = raw_style.astype(np.float32)[np.newaxis, ...]
        # plt.imread already returns floats in [0, 1] for PNG files; only
        # integer images (e.g. JPEG) need rescaling from [0, 255].
        if np.issubdtype(raw_style.dtype, np.integer):
            style_image = style_image / 255.
        style_image = tf.image.resize(style_image, (gan_dim, gan_dim))

        frames_front = np.squeeze(np.array(
            _stylize_frames(nst_model, frames_front, style_image, gan_dim, output_dim)))
        frames_side = np.squeeze(np.array(
            _stylize_frames(nst_model, frames_side, style_image, gan_dim, output_dim)))
        return frames_front, frames_side
    # Augmentation not applied: return the original frames as arrays.
    return np.array(frames_front), np.array(frames_side)

In [None]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import pandas as pd
import random


class FrontSideVideoDataset(Dataset):
    """Paired front/side-view video clips for multi-view classification.

    Each data directory holds one sub-directory per class whose name is the
    integer label, containing ``.avi`` clips. Front and side clips are paired
    purely by their position in the sorted path lists, so both directory trees
    must contain corresponding files that sort in the same order.

    Args:
        side_data_dir: root directory of the side-view clips.
        front_data_dir: root directory of the front-view clips.
        num_frames: maximum number of frames decoded from the start of each clip.
        frame_rate, step_between_clips, fold, train: stored on the instance but
            not used by this class.
        transform: optional per-frame callable. The RNG seed is reset to the
            same value before transforming the front and the side clip, so
            random augmentations line up across the two views.
        nst: optional callable invoked as
            ``nst(frames_front=..., frames_side=..., seed=...)`` that returns
            ``(front_frames, side_frames)``; applied before ``transform``.

    ``__getitem__`` returns ``(side_tensor, front_tensor, label)``. Each video
    tensor is the stacked frames permuted with ``permute(1, 0, 2, 3)``, i.e.
    ``(C, T, H, W)`` assuming every frame is a ``(C, H, W)`` tensor.
    """

    def __init__(self, side_data_dir, front_data_dir, num_frames, frame_rate=1, step_between_clips=0, fold=1, train=True, transform=None, nst=None):
        self.front_data_dir = front_data_dir
        self.side_data_dir = side_data_dir
        self.num_frames = num_frames
        self.frame_rate = frame_rate
        self.step_between_clips = step_between_clips
        self.fold = fold
        self.train = train
        self.transform = transform
        self.nst = nst

        # Sorted (path, label) lists for both views.
        self.front_video_paths = self._get_video_paths(self.front_data_dir)
        self.side_video_paths = self._get_video_paths(self.side_data_dir)

        # Only as many pairs as the shorter view provides.
        self.total_clips = min(len(self.front_video_paths), len(self.side_video_paths))

    def _get_video_paths(self, data_dir):
        """Return sorted ``(path, int_label)`` tuples for ``data_dir/<label>/*.avi``."""
        video_paths = []
        for class_name in os.listdir(data_dir):
            class_dir = os.path.join(data_dir, class_name)
            # Stray files next to the class folders (README, .DS_Store, ...)
            # would make os.listdir below raise NotADirectoryError.
            if not os.path.isdir(class_dir):
                continue
            for filename in os.listdir(class_dir):
                if filename.endswith(".avi"):
                    video_paths.append((os.path.join(class_dir, filename), int(class_name)))
        # Sorting makes the index-based front/side pairing deterministic.
        video_paths.sort()
        return video_paths

    def __len__(self):
        return self.total_clips

    def __getitem__(self, idx):
        front_video_path, label1 = self.front_video_paths[idx]
        side_video_path, label2 = self.side_video_paths[idx]

        if label1 != label2:
            # Pairing is positional, so a mismatch means the two directory
            # trees are out of sync. It is reported but not treated as fatal.
            print("label mismatch")
            print(label1, label2)

        front_frames = self._read_video_frames(front_video_path)
        side_frames = self._read_video_frames(side_video_path)
        # torch.stack on an empty list fails with an unhelpful message; name
        # the offending file instead.
        for path, frames in ((front_video_path, front_frames), (side_video_path, side_frames)):
            if len(frames) == 0:
                raise RuntimeError(f"No frames could be decoded from {path}")

        # One seed shared by NST and by the per-view transforms so both views
        # receive the same random augmentation.
        seed = np.random.randint(2147483647)

        if self.nst is not None:
            front_frames, side_frames = self.nst(frames_front=front_frames, frames_side=side_frames, seed=seed)

        if self.transform is not None:
            torch.manual_seed(seed)
            random.seed(seed)
            front_frames = [self.transform(frame) for frame in front_frames]
            torch.manual_seed(seed)
            random.seed(seed)
            side_frames = [self.transform(frame) for frame in side_frames]

        # (T, C, H, W) -> (C, T, H, W)
        front_video_tensor = torch.stack([self._as_chw_tensor(f) for f in front_frames]).permute(1, 0, 2, 3)
        side_video_tensor = torch.stack([self._as_chw_tensor(f) for f in side_frames]).permute(1, 0, 2, 3)

        return side_video_tensor, front_video_tensor, label1

    @staticmethod
    def _as_chw_tensor(frame):
        """Return ``frame`` as a CHW tensor; HWC numpy frames are converted.

        Transformed frames are already tensors and pass through unchanged.
        Without a transform the frames are HWC numpy arrays, which torch.stack
        cannot consume directly.
        """
        if isinstance(frame, torch.Tensor):
            return frame
        return torch.from_numpy(np.ascontiguousarray(frame)).permute(2, 0, 1)

    def _read_video_frames(self, video_path):
        """Decode up to ``self.num_frames`` RGB frames from the start of a clip.

        Returns a list of RGB frames as returned by OpenCV (converted from
        BGR). It may be shorter than ``num_frames``, or empty if the file
        cannot be decoded.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            # Read sequentially instead of seeking to every index: per-frame
            # CAP_PROP_POS_FRAMES seeks are slow and can be inaccurate with
            # some codecs, and CAP_PROP_FRAME_COUNT may be 0 for some
            # containers, so simply read until EOF or num_frames.
            while len(frames) < self.num_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()
        return frames

In [None]:
import time
import torchvision
import torch.nn.functional as F
from torchvision.transforms import v2
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import torch
from functools import partial

from torchvision.transforms import functional as F
from PIL import Image
import numpy as np

# Global seed and clip/batch settings.
torch.manual_seed(97)
num_frames = 30
clip_steps = 0
Bs_Train = 2
Bs_Test = 2

# Neural-style-transfer (NST) augmentation settings. hub_module, file_paths and
# nstAugment are defined earlier in the notebook.
gan_model = hub_module
style_list = file_paths
prob = 0.7  # presumably the probability of applying NST to a clip -- confirm in nstAugment
input_gan = 256  # presumably the resolution fed to the style network -- confirm
frame_dim = 224  # presumably the frame size after style transfer -- confirm

# NOTE(review): v2.ToTensor rescales uint8 arrays to [0, 1] (float arrays are
# not rescaled), and Normalize(std=255) then divides by 255 again. Confirm this
# extra scaling is intended (the checkpoint loaded for training may depend on
# it) before changing it.
transform = v2.Compose([
    v2.ToTensor(),
    # RandomErasing expects ratio as a (min, max) aspect-ratio pair;
    # (0.3, 3.3) is torchvision's default range.
    v2.RandomErasing(p=0.5, scale=(0.01, 0.08), ratio=(0.3, 3.3), value=0, inplace=False),
    v2.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
])
transform_test = v2.Compose([
    v2.ToTensor(),
    v2.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
])

# Creating the nst instance (positional args bound in nstAugment's parameter order)
nst_instance = partial(nstAugment, gan_model, style_list, prob, input_gan, frame_dim)
train_data_dir_side = '/kaggle/input/3mdad-day-new/3mdad side 30/3mdad side 30/Train 30'
train_data_dir_front = '/kaggle/input/3mdad-day-new/3mdad front 30/3mdad front 30/Train 30'
train_dataset = FrontSideVideoDataset(train_data_dir_side, train_data_dir_front,
                                      num_frames,
                                      frame_rate=30,
                                      step_between_clips=clip_steps,
                                      fold=1,
                                      train=True,
                                      transform=transform,
                                      nst=nst_instance
                                      )
val_data_dir_side = '/kaggle/input/3mdad-day-new/3mdad side 30/3mdad side 30/Validation 30'
val_data_dir_front = '/kaggle/input/3mdad-day-new/3mdad front 30/3mdad front 30/Validation 30'
# Validation uses no NST and no random erasing.
val_dataset = FrontSideVideoDataset(val_data_dir_side, val_data_dir_front,
                                    num_frames,
                                    frame_rate=30,
                                    step_between_clips=clip_steps,
                                    fold=1,
                                    train=False,
                                    transform=transform_test,
                                    nst=None
                                    )
train_loader = DataLoader(train_dataset, batch_size=Bs_Train, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=Bs_Test, shuffle=False, num_workers=0)


In [None]:
!pip install adabelief-pytorch


In [None]:
# Calculating class-wise (inverse-frequency) weights:
#   weight_c = total_clips / (n_classes * clips_in_class_c)
# counted over the front-view training set.
path = '/kaggle/input/3mdad-day-new/3mdad front 30/3mdad front 30/Train 30'
clas = [0] * 16  # one counter per class label 0..15
for class_name in os.listdir(path):
    class_dir = os.path.join(path, class_name)
    if not os.path.isdir(class_dir):
        continue
    # Count only .avi clips, matching what FrontSideVideoDataset loads.
    clas[int(class_name)] += sum(1 for f in os.listdir(class_dir) if f.endswith(".avi"))
sum_ = sum(clas)
print(sum_)

# A class with no clips would otherwise divide by zero; it never appears as a
# target, so its weight is irrelevant and is set to 0.
class_wghts = torch.tensor(
    [sum_ / (count * len(clas)) if count else 0.0 for count in clas]
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MetaLearner(nn.Module):
    """Stacking ensemble that fuses a side-view and a front-view expert.

    The two experts' outputs are concatenated along dim 1 and passed through a
    small MLP head: ``fc1`` (1200 -> 64), ReLU, dropout(0.5), ``fc2``
    (64 -> 16), followed by ``log_softmax``.
    """

    # Kept for backward compatibility with code that reads this attribute;
    # forward() uses the device the module actually lives on instead.
    pytorch_device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    def __init__(self, expert_side, expert_front):
        super(MetaLearner, self).__init__()
        self.expert_side = expert_side
        self.expert_front = expert_front
        # 1200 = width of the concatenated expert outputs (presumably 600 per
        # MoViNet head -- confirm if the experts change).
        self.fc1 = nn.Linear(1200, 64)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(64, 16)

    def forward(self, side, front):
        """Return class log-probabilities for a batch of paired clips.

        Args:
            side: side-view batch, fed to ``expert_side``.
            front: front-view batch, fed to ``expert_front``.

        Returns:
            ``log_softmax`` over dim 1 of the 16-way head output; shape
            ``(batch, 16)`` assuming the experts return ``(batch, features)``.
        """
        # Follow the head's device rather than a hard-coded "cuda:1", so the
        # model also runs on CPU or on a different GPU.
        device = self.fc1.weight.device

        # Reset any cached activations before and after the pass so state does
        # not leak between batches.
        self.expert_side.clean_activation_buffers()
        self.expert_front.clean_activation_buffers()

        side_out = self.expert_side(side.to(device))
        front_out = self.expert_front(front.to(device))
        fused = torch.cat((side_out, front_out), dim=1)

        hidden = F.relu(self.fc1(fused))
        # nn.Dropout is already the identity in eval mode.
        hidden = self.dropout(hidden)
        out = F.log_softmax(self.fc2(hidden), dim=1)

        self.expert_side.clean_activation_buffers()
        self.expert_front.clean_activation_buffers()

        return out



    


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SymmetricCrossEntropyLoss(nn.Module):
    """Focal-style modulation of the (optionally class-weighted) mean cross-entropy.

    Computes ``alpha * (1 - exp(-CE)) ** beta * CE``, where ``CE`` is the
    batch-mean ``F.cross_entropy`` with ``weight=class_weights``.

    NOTE(review): despite the name, this is not the symmetric cross-entropy of
    Wang et al. (2019), which adds a reverse-CE term. Unlike the usual focal
    loss, the modulating factor is applied to the batch mean rather than per
    sample.

    Args:
        alpha: overall scale of the loss.
        beta: exponent of the modulating factor ``(1 - exp(-CE))``.
        class_weights: optional per-class weight tensor passed to
            ``F.cross_entropy``.
    """

    def __init__(self, alpha=0.1, beta=1.0, class_weights=None):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.class_weights = class_weights

    def forward(self, logits, targets):
        mean_ce = F.cross_entropy(logits, targets, weight=self.class_weights)
        modulating = (1 - torch.exp(-mean_ce)) ** self.beta
        return self.alpha * modulating * mean_ce

def train_iter(model, data_load, optimizer, loss_val, class_weights):
    """Run one training epoch of ``model`` over ``data_load``.

    Args:
        model: module called as ``model(side, front)``; its output feeds the
            loss and, via ``torch.max`` over dim 1, the predictions.
        data_load: DataLoader yielding ``(side, front, target)`` batches.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        loss_val: list that receives the loss of every 50th batch (mutated).
        class_weights: per-class weights forwarded to ``SymmetricCrossEntropyLoss``.

    Returns:
        Training accuracy in percent over the epoch (0.0 for an empty loader).
    """
    pytorch_device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    samples = len(data_load.dataset)
    model.train()
    model.to(pytorch_device)
    optimizer.zero_grad()
    # The criterion only holds its configuration; build it once per epoch.
    criterion = SymmetricCrossEntropyLoss(class_weights=class_weights)
    correct = 0
    total = 0

    for i, (side, front, target) in enumerate(data_load):
        target = target.to(pytorch_device)
        # Forward pass
        out = model(side.to(pytorch_device), front.to(pytorch_device))
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        _, predicted = torch.max(out, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

        if i % 50 == 0:
            print('[' + '{:5}'.format(i * len(side)) + '/' + '{:5}'.format(samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_load)) + '%)]  Loss: ' +
                  '{:6.4f}'.format(loss.item()))
            loss_val.append(loss.item())

    # Epoch-level accuracy, computed once the loop has finished.
    train_acc = 100 * correct / total if total else 0.0
    print(f"Training Accuracy: {train_acc:.2f}%")
    return train_acc

def evaluate(model, data_load, loss_val, class_weights):
    """Evaluate ``model`` on ``data_load`` and record the average loss.

    Args:
        model: module called as ``model(side, front)``.
        data_load: DataLoader yielding ``(side, front, target)`` batches.
        loss_val: list to which the per-sample average loss is appended (mutated).
        class_weights: per-class weights forwarded to ``SymmetricCrossEntropyLoss``.

    Returns:
        The per-sample average loss that was appended to ``loss_val``
        (0.0 for an empty dataset).
    """
    model.eval()
    pytorch_device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    samples = len(data_load.dataset)
    criterion = SymmetricCrossEntropyLoss(class_weights=class_weights)
    csamp = 0
    tloss = 0.0
    with torch.no_grad():
        for side, front, target in data_load:
            target = target.to(pytorch_device)
            out = model(side.to(pytorch_device), front.to(pytorch_device))
            loss = criterion(out, target)
            _, pred = torch.max(out, dim=1)
            # The criterion returns a batch mean. Weight it by the batch size so
            # that dividing by the sample count below gives a true per-sample
            # average rather than (sum of batch means) / samples.
            tloss += loss.item() * target.size(0)
            csamp += pred.eq(target).sum().item()
    aloss = tloss / samples if samples else 0.0
    accuracy = 100.0 * csamp / samples if samples else 0.0
    loss_val.append(aloss)
    print('\nAverage test loss: ' + '{:.4f}'.format(aloss) +
          '  Accuracy:' + '{:5}'.format(csamp) + '/' +
          '{:5}'.format(samples) + '(' +
          '{:4.2f}'.format(accuracy) + '%)\n')
    return aloss


In [None]:
from torch import optim
from adabelief_pytorch import AdaBelief
from torch.optim.lr_scheduler import ReduceLROnPlateau
import contextlib

# One pretrained MoViNet-A2 expert per camera view (MoViNet and _C are defined
# earlier in the notebook).
expert_side = MoViNet(_C.MODEL.MoViNetA2, causal=False, pretrained=True)  # Instantiate and load pre-trained model 1

expert_front = MoViNet(_C.MODEL.MoViNetA2, causal=False, pretrained=True)  # Instantiate and load pre-trained model 2

pytorch_device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# torch.cuda.device() accepts only CUDA devices and raises ValueError for
# "cpu", which made the CPU fallback above unusable; use a no-op context there.
device_context = (
    torch.cuda.device(pytorch_device)
    if pytorch_device.type == "cuda"
    else contextlib.nullcontext()
)

with device_context:
    best_val_loss = float('inf')
    best_model_path = "/kaggle/working/best_model.pth"

    num_epochs = 14
    traccuracy_val = []

    # Initialize lists to track training and val losses
    trloss, tvloss = [], []

    meta_learner = MetaLearner(expert_side, expert_front)
    meta_learner.to(pytorch_device)
    meta_learner.train()
    # Resume from a previous run's weights. map_location lets a checkpoint
    # saved from a GPU load on whichever device is available here.
    meta_learner.load_state_dict(torch.load(f='/kaggle/input/0-7probfrontside10ganrun1/best_model (8).pth', map_location=pytorch_device))
    start_time = time.time()

    # Use AdaBelief optimizer
    optimizer = AdaBelief(meta_learner.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, weight_decouple=True, rectify=True)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    patience = 5  # Number of epochs with no improvement to wait before stopping
    early_stopping_counter = 0
    # The weights never change, so move them to the device once.
    class_weights_on_device = class_wghts.to(pytorch_device)

    # Training loop
    for epoch in range(1, num_epochs + 1):
        print('Epoch:', epoch)

        optimizer.zero_grad()

        train_iter(meta_learner, train_loader, optimizer, trloss, class_weights_on_device)
        evaluate(meta_learner, val_loader, tvloss, class_weights_on_device)
        current_val_loss = tvloss[-1]

        scheduler.step(current_val_loss)

        if current_val_loss < best_val_loss:
            best_val_loss = current_val_loss
            # Save the model with the best validation loss
            torch.save(meta_learner.state_dict(), best_model_path)
            # Reset the early stopping counter
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= patience:
            print(f'Early stopping triggered after {patience} epochs without improvement.')
            break

    print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
