In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path
import skimage
import skimage.segmentation
import sklearn.preprocessing
import sklearn.model_selection
import math
import shutil
import pathlib
import glob
import shutil
import uuid
import random
import platform
import torch
import torchvision
import numpy as np
import scipy as sp
import scipy.io
import scipy.signal
import pandas as pd
import networkx
import wfdb
import json
import tqdm
import dill
import pickle
import matplotlib.pyplot as plt

import scipy.stats

import src.data
import sak
import sak.signal.wavelet
import sak.data
import sak.data.augmentation
import sak.visualization
import sak.visualization.signal
import sak.torch
import sak.torch.nn
import sak.torch.nn as nn
import sak.torch.train
import sak.torch.data
import sak.data.preprocessing
import sak.torch.models
import sak.torch.models.lego
import sak.torch.models.variational
import sak.torch.models.classification

from sak.signal import StandardHeader

def smooth(x: np.ndarray, window_size: int, conv_mode: str = 'same'):
    x = np.pad(np.copy(x),(window_size,window_size),'edge')
    window = np.hamming(window_size)/(window_size//2)
    x = np.convolve(x, window, mode=conv_mode)
    x = x[window_size:-window_size]
    return x

# Train

In [3]:
import json
with open('./configurations/WNeXt6Levels.json', 'r') as f:
    execution = json.load(f)
# Define model
model = sak.from_dict(execution["model"]).float()
# plt.figure(figsize=(50,50));model.draw_networkx()

In [4]:
input_files = './pickle/'

##### 2. Load synthetic dataset #####
# 2.1. Load individual segments
P = sak.pickleload(os.path.join(input_files,"Psignal_new.pkl"))
PQ = sak.pickleload(os.path.join(input_files,"PQsignal_new.pkl"))
QRS = sak.pickleload(os.path.join(input_files,"QRSsignal_new.pkl"))
ST = sak.pickleload(os.path.join(input_files,"STsignal_new.pkl"))
T = sak.pickleload(os.path.join(input_files,"Tsignal_new.pkl"))
TP = sak.pickleload(os.path.join(input_files,"TPsignal_new.pkl"))

Pamplitudes = sak.pickleload(os.path.join(input_files,"Pamplitudes_new.pkl"))
PQamplitudes = sak.pickleload(os.path.join(input_files,"PQamplitudes_new.pkl"))
QRSamplitudes = sak.pickleload(os.path.join(input_files,"QRSamplitudes_new.pkl"))
STamplitudes = sak.pickleload(os.path.join(input_files,"STamplitudes_new.pkl"))
Tamplitudes = sak.pickleload(os.path.join(input_files,"Tamplitudes_new.pkl"))
TPamplitudes = sak.pickleload(os.path.join(input_files,"TPamplitudes_new.pkl"))

# 2.2. Get amplitude distribution
Pdistribution   = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(Pamplitudes.values()))))
PQdistribution  = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(PQamplitudes.values()))))
QRSdistribution = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.hstack((np.array(list(QRSamplitudes.values())), 2-np.array(list(QRSamplitudes.values()))))))
STdistribution  = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(STamplitudes.values()))))
Tdistribution   = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(Tamplitudes.values()))))
TPdistribution  = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(TPamplitudes.values()))))

# 2.3. Smooth all
window = 5
P   = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth(  P[k],window)),metric=sak.signal.abs_max) for k in   P}
PQ  = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth( PQ[k],window)),metric=sak.signal.abs_max) for k in  PQ}
QRS = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth(QRS[k],window)),metric=sak.signal.abs_max) for k in QRS}
ST  = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth( ST[k],window)),metric=sak.signal.abs_max) for k in  ST}
T   = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth(  T[k],window)),metric=sak.signal.abs_max) for k in   T}
TP  = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth( TP[k],window)),metric=sak.signal.abs_max) for k in  TP}


##### 3. Load QTDB #####
dataset             = pd.read_csv(os.path.join(input_files,'QTDB','Dataset.csv'), index_col=0)
dataset             = dataset.sort_index(axis=1)
labels              = np.asarray(list(dataset)) # In case no data augmentation is applied
description         = dataset.describe()
group               = {k: '_'.join(k.split('_')[:-1]) for k in dataset}
unique_ids          = list(set([k.split('_')[0] for k in dataset]))

# Load validity
validity            = sak.load_data(os.path.join(input_files,'QTDB','validity.csv'))

# Load fiducials
Pon_QTDB            = sak.load_data(os.path.join(input_files,'QTDB','PonNew.csv'))
Poff_QTDB           = sak.load_data(os.path.join(input_files,'QTDB','PoffNew.csv'))
QRSon_QTDB          = sak.load_data(os.path.join(input_files,'QTDB','QRSonNew.csv'))
QRSoff_QTDB         = sak.load_data(os.path.join(input_files,'QTDB','QRSoffNew.csv'))
Ton_QTDB            = sak.load_data(os.path.join(input_files,'QTDB','TonNew.csv'))
Toff_QTDB           = sak.load_data(os.path.join(input_files,'QTDB','ToffNew.csv'))

# Generate masks & signals
signal_QTDB = {}
segmentation_QTDB = {}
for k in tqdm.tqdm(QRSon_QTDB):
    # Check file exists and all that
    if k not in validity:
        print("Issue with file {}, continuing...".format(k))
        continue

    # Store signal
    signal = dataset[k][validity[k][0]:validity[k][1]].values
    signal = sak.signal.on_off_correction(signal)
    amplitude = np.median(sak.signal.moving_lambda(signal,200,sak.signal.abs_max))
    signal = signal/amplitude
    signal_QTDB[k] = signal[None,]

    # Generate boolean mask
    segmentation = np.zeros((3,dataset.shape[0]),dtype=bool)
    if k in Pon_QTDB:
        for on,off in zip(Pon_QTDB[k],Poff_QTDB[k]):
            segmentation[0,on:off] = True
    if k in QRSon_QTDB:
        for on,off in zip(QRSon_QTDB[k],QRSoff_QTDB[k]):
            segmentation[1,on:off] = True
    if k in Ton_QTDB:
        for on,off in zip(Ton_QTDB[k],Toff_QTDB[k]):
            segmentation[2,on:off] = True

    segmentation_QTDB[k] = segmentation[:,validity[k][0]:validity[k][1]]


##### 4. Generate random splits #####
# 4.1. Split into train and test
all_keys_synthetic = {}
for k in list(P) + list(PQ) + list(QRS) + list(ST) + list(T) + list(TP):
    uid = k.split("###")[0].split("_")[0].split("-")[0]
    if uid not in all_keys_synthetic:
        all_keys_synthetic[uid] = [k]
    else:
        all_keys_synthetic[uid].append(k)

all_keys_real = {}
for k in list(signal_QTDB) + list(segmentation_QTDB):
    uid = k.split("###")[0].split("_")[0].split("-")[0]
    if uid not in all_keys_real:
        all_keys_real[uid] = [k]
    else:
        all_keys_real[uid].append(k)

# 4.2. Get database and file
filenames = []
database = []
for k in all_keys_synthetic:
    filenames.append(k)
    if k.startswith("SOO"):
        database.append(0)
    elif k.startswith("sel"):
        database.append(1)
    else:
        database.append(2)
filenames = np.array(filenames)
database = np.array(database)

# Set random seed for the execution and perform train/test splitting
random.seed(execution["seed"])
np.random.seed(execution["seed"])
torch.random.manual_seed(execution["seed"])
splitter = sklearn.model_selection.StratifiedKFold(5).split(filenames,database)
splits = list(splitter)
indices_train = [s[0] for s in splits]
indices_valid = [s[1] for s in splits]

##### 5. Train folds #####
# 5.1. Save model-generating files
target_path = execution["save_directory"] # Store original output path for future usage
original_length = execution["dataset"]["length"]

# 5.2. Save folds of valid files
all_folds_test = {"fold_{}".format(i+1): np.array(filenames)[ix_valid] for i,ix_valid in enumerate(indices_valid)}

# 5.3. Iterate over folds
for i,(ix_train,ix_valid) in enumerate(zip(indices_train,indices_valid)):
    print("################# FOLD {} #################".format(i+1))
    # Synthetic keys
    train_keys_synthetic, valid_keys_synthetic = ([],[])
    for k in np.array(filenames)[ix_train]: 
        train_keys_synthetic += all_keys_synthetic[k]
    for k in np.array(filenames)[ix_valid]: 
        valid_keys_synthetic += all_keys_synthetic[k]

    # Real keys
    train_keys_real, valid_keys_real = ([],[])
    for k in np.array(filenames)[ix_train]: 
        if k in all_keys_real: train_keys_real += all_keys_real[k]
    for k in np.array(filenames)[ix_valid]: 
        if k in all_keys_real: valid_keys_real += all_keys_real[k]

    # Avoid repetitions
    train_keys_synthetic = list(set(train_keys_synthetic))
    valid_keys_synthetic = list(set(valid_keys_synthetic))
    train_keys_real = list(set(train_keys_real))
    valid_keys_real = list(set(valid_keys_real))

    # ~~~~~~~~~~~~~~~~~~~~ Refine synthetic set ~~~~~~~~~~~~~~~~~~~~
    # Divide train/valid segments
    Ptrain   = {k:   P[k] for k in   P if k in train_keys_synthetic}
    PQtrain  = {k:  PQ[k] for k in  PQ if k in train_keys_synthetic}
    QRStrain = {k: QRS[k] for k in QRS if k in train_keys_synthetic}
    STtrain  = {k:  ST[k] for k in  ST if k in train_keys_synthetic}
    Ttrain   = {k:   T[k] for k in   T if k in train_keys_synthetic}
    TPtrain  = {k:  TP[k] for k in  TP if k in train_keys_synthetic}

    Pvalid   = {k:   P[k] for k in   P if k in valid_keys_synthetic}
    PQvalid  = {k:  PQ[k] for k in  PQ if k in valid_keys_synthetic}
    QRSvalid = {k: QRS[k] for k in QRS if k in valid_keys_synthetic}
    STvalid  = {k:  ST[k] for k in  ST if k in valid_keys_synthetic}
    Tvalid   = {k:   T[k] for k in   T if k in valid_keys_synthetic}
    TPvalid  = {k:  TP[k] for k in  TP if k in valid_keys_synthetic}

    # ~~~~~~~~~~~~~~~~~~~~~~ Refine real set ~~~~~~~~~~~~~~~~~~~~~~~
    signal_QTDB_train       = {k:       signal_QTDB[k] for k in       signal_QTDB if k in train_keys_real}
    signal_QTDB_valid       = {k:       signal_QTDB[k] for k in       signal_QTDB if k in valid_keys_real}
    segmentation_QTDB_train = {k: segmentation_QTDB[k] for k in segmentation_QTDB if k in train_keys_real}
    segmentation_QTDB_valid = {k: segmentation_QTDB[k] for k in segmentation_QTDB if k in valid_keys_real}

    # Define synthetic datasets
    dataset_train_synthetic = src.data.Dataset(Ptrain, QRStrain, Ttrain, PQtrain, STtrain, TPtrain, 
                                               Pdistribution, QRSdistribution, Tdistribution, PQdistribution, 
                                               STdistribution, TPdistribution, **execution["dataset"])
    execution["dataset"]["length"] = execution["dataset"]["length"]//4 # On synthetic data, not so useful to do intensive validation
    dataset_valid_synthetic = src.data.Dataset(Pvalid, QRSvalid, Tvalid, PQvalid, STvalid, TPvalid, 
                                               Pdistribution, QRSdistribution, Tdistribution, PQdistribution, 
                                               STdistribution, TPdistribution, **execution["dataset"])
    execution["dataset"]["length"] = original_length # On synthetic data, not so useful to do intensive validation

    # Define real datasets
    dataset_train_real = src.data.DatasetQTDB(signal_QTDB_train,segmentation_QTDB_train,execution["dataset"]["N"],128)
    dataset_valid_real = src.data.DatasetQTDB(signal_QTDB_valid,segmentation_QTDB_valid,execution["dataset"]["N"],128)

    # Define merging dataset
    dataset_train = sak.torch.data.UniformMultiDataset((dataset_train_synthetic,dataset_train_real),[10,1],[1,10],return_weights=True)
    sampler_train = sak.torch.data.UniformMultiSampler(dataset_train)
    dataset_valid = sak.torch.data.UniformMultiDataset((dataset_valid_synthetic,dataset_valid_real),[10,1],[1,10],return_weights=True)
    sampler_valid = sak.torch.data.UniformMultiSampler(dataset_valid)

    # Create dataloaders
    loader_train = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train, **execution["loader"])
    loader_valid = torch.utils.data.DataLoader(dataset_valid, sampler=sampler_valid, **execution["loader"])

    break

100%|██████████| 206/206 [00:00<00:00, 2520.50it/s]


Issue with file sel35_0, continuing...
Issue with file sel35_1, continuing...
################# FOLD 1 #################


In [5]:
for tmp in loader_train:
    break

In [15]:
import timm
import timm.models
import timm.models.layers
import timm.models.layers

In [21]:
import torch

In [None]:
torch.nn.LayerNorm

In [3]:
from torch import Tensor
from sak.__ops import required
from sak.__ops import check_required
from sak import class_selector

class ConvNeXtBlockNd(torch.nn.Module):
    def __init__(self, in_channels: int = required, layer_scale_init_value: float = 1e-6, 
                 drop_path: float = 0., dim: int = required, **kwargs: dict):
        super(ConvNeXtBlockNd, self).__init__()
        # Check required inputs
        check_required(self, {"in_channels":in_channels, "dim":dim})

        # Establish default inputs
        kwargs["groups"] = in_channels
        kwargs["kernel_size"] = kwargs.get("kernel_size",7)
        kwargs["padding"] = kwargs.get("padding", (kwargs["kernel_size"]-1)//2)
        activation = kwargs.pop("activation","torch.nn.GELU")
        initializer = kwargs.pop("initializer","timm.models.layers.trunc_normal_")

        # Declare operations
        if   dim == 1: 
            self.depthwise_conv = torch.nn.Conv1d(in_channels, in_channels, **kwargs)
            self.permute_in,self.permute_out = [0,2,1],     [0,2,1]
        elif dim == 2: 
            self.depthwise_conv = torch.nn.Conv2d(in_channels, in_channels, **kwargs)
            self.permute_in,self.permute_out = [0,2,3,1],   [0,3,1,2]
        elif dim == 3: 
            self.depthwise_conv = torch.nn.Conv3d(in_channels, in_channels, **kwargs)
            self.permute_in,self.permute_out = [0,2,3,4,1], [0,4,1,2,3]
        else: raise ValueError("Invalid number of dimensions: {}".format(dim))
        self.normalization    = torch.nn.LayerNorm(in_channels,eps=1e-6)
        self.pointwise_conv_1 = torch.nn.Linear(in_channels, 4*in_channels)
        self.activation       = class_selector(activation)()
        self.pointwise_conv_2 = torch.nn.Linear(4*in_channels, in_channels)
        self.gamma            = nn.Parameter(layer_scale_init_value * torch.ones((in_channels,)), 
                                             requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path        = timm.models.layers.DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
        # Initialize weights values
        initializer = class_selector(initializer)
        initializer(self.depthwise_conv.weight)
        initializer(self.pointwise_conv_1.weight)
        initializer(self.pointwise_conv_2.weight)

    def forward(self, x: Tensor) -> Tensor:
        x_prev = x
        x = self.depthwise_conv(x)
        x = x.permute(*self.permute_in) # (N, C, ...) -> (N, ..., C)
        x = self.normalization(x)
        x = self.pointwise_conv_1(x)
        x = self.activation(x)
        x = self.pointwise_conv_2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(*self.permute_out) # (N, ..., C) -> (N, C, ...)
        return x_prev + self.drop_path(x) # Residual

class ConvNeXtBlock1d(ConvNeXtBlockNd):
    def __init__(self, in_channels: int = required, layer_scale_init_value: float = 1e-6, drop_path: float = 0., **kwargs):
        super(ConvNeXtBlock1d, self).__init__(in_channels, layer_scale_init_value, drop_path, dim=1, **kwargs)

class ConvNeXtBlock2d(ConvNeXtBlockNd):
    def __init__(self, in_channels: int = required, layer_scale_init_value: float = 1e-6, drop_path: float = 0., **kwargs):
        super(ConvNeXtBlock2d, self).__init__(in_channels, layer_scale_init_value, drop_path, dim=2, **kwargs)

class ConvNeXtBlock3d(ConvNeXtBlockNd):
    def __init__(self, in_channels: int = required, layer_scale_init_value: float = 1e-6, drop_path: float = 0., **kwargs):
        super(ConvNeXtBlock3d, self).__init__(in_channels, layer_scale_init_value, drop_path, dim=3, **kwargs)



In [4]:
ConvNeXtBlock1d(32,)

ConvNeXtBlock1d(
  (depthwise_conv): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(3,), groups=32)
  (normalization): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
  (pointwise_conv_1): Linear(in_features=32, out_features=128, bias=True)
  (activation): GELU()
  (pointwise_conv_2): Linear(in_features=128, out_features=32, bias=True)
  (drop_path): Identity()
)

In [14]:
from typing import Union, List

import torch
from torch import Tensor
from torch import exp
from torch import ones
from torch import ones_like
from torch import log
from torch.nn import Module
from torch.nn import Parameter
from torch.nn import Identity
from torch.nn import BatchNorm1d
from torch.nn import BatchNorm2d
from torch.nn import BatchNorm3d
from torch.nn import Dropout2d
from torch.nn import Dropout3d
from torch.nn import Conv1d
from torch.nn import Conv2d
from torch.nn import Conv3d
from torch.nn import ConvTranspose1d
from torch.nn import ConvTranspose2d
from torch.nn import ConvTranspose3d
from torch.nn import AdaptiveAvgPool1d
from torch.nn import AdaptiveAvgPool2d
from torch.nn import AdaptiveAvgPool3d
from torch.nn import Linear
from torch.nn import Sigmoid
from torch.nn import ReLU
from torch.nn import LayerNorm

from torch.nn import init

from timm.models.layers import DropPath

from torch.nn.functional import interpolate
from sak.torch.nn.modules.utils import Concatenate
from sak import class_selector
from sak import from_dict
from sak.__ops import required
from sak.__ops import check_required
from sak.torch.nn import DepthwiseConv1d,DepthwiseConv2d,DepthwiseConv3d,PointwiseConv1d,PointwiseConv2d,PointwiseConv3d

class ConvNeXtBlockNd(Module):
    """Partially based on: 
    * https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py
    * https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py
    
    """
    
    def __init__(self, in_channels: int = required, hidden_channels: int = None, gamma_init: float = 1e-6, 
                 drop_proba: float = 0., dim: int = required, **kwargs: dict):
        super(ConvNeXtBlockNd, self).__init__()
        # Check required inputs
        check_required(self, {"in_channels":in_channels, "dim":dim})
        hidden_channels = hidden_channels or 4*in_channels

        # Establish default inputs
        kwargs["groups"] = in_channels
        kwargs["kernel_size"] = kwargs.get("kernel_size",7)
        kwargs["padding"] = kwargs.get("padding", (kwargs["kernel_size"]-1)//2)
        kwargs["initializer"] = kwargs.get("initializer", "timm.models.layers.trunc_normal_")
        activation = kwargs.pop("activation",{"class": "torch.nn.GELU", "arguments": {}})

        # Declare operations
        if   dim == 1: depth_op,norm_op,point_op = DepthwiseConv1d,LayerNorm1d,PointwiseConv1d
        elif dim == 2: depth_op,norm_op,point_op = DepthwiseConv2d,LayerNorm2d,PointwiseConv2d
        elif dim == 3: depth_op,norm_op,point_op = DepthwiseConv3d,LayerNorm3d,PointwiseConv3d
        else: raise ValueError("Invalid number of dimensions: {}".format(dim))
        self.depthwise_conv   = depth_op(in_channels, **kwargs)
        self.normalization    = norm_op(in_channels)
        self.pointwise_conv_1 = point_op(in_channels, hidden_channels, **kwargs)
        self.activation       = from_dict(activation)
        self.pointwise_conv_2 = point_op(hidden_channels, in_channels, **kwargs)
        self.gamma            = Parameter(gamma_init * ones((in_channels,)), requires_grad=True).reshape(1, -1, *[1]*dim) if gamma_init > 0 else None
        self.drop_path        = DropPath(drop_proba) if drop_proba > 0. else lambda x: x

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        x = self.depthwise_conv(x)
        x = self.normalization(x)
        x = self.pointwise_conv_1(x)
        x = self.activation(x)
        x = self.pointwise_conv_2(x)
        if self.gamma is not None:
            x = self.gamma * x
        return self.drop_path(x) + residual

class ConvNeXtBlock1d(ConvNeXtBlockNd):
    def __init__(self, in_channels: int = required, hidden_channels: int = None, gamma_init: float = 1e-6, drop_path: float = 0., **kwargs):
        super(ConvNeXtBlock1d, self).__init__(in_channels, hidden_channels, gamma_init, drop_path, dim=1, **kwargs)

class ConvNeXtBlock2d(ConvNeXtBlockNd):
    def __init__(self, in_channels: int = required, hidden_channels: int = None, gamma_init: float = 1e-6, drop_path: float = 0., **kwargs):
        super(ConvNeXtBlock2d, self).__init__(in_channels, hidden_channels, gamma_init, drop_path, dim=2, **kwargs)

class ConvNeXtBlock3d(ConvNeXtBlockNd):
    def __init__(self, in_channels: int = required, hidden_channels: int = None, gamma_init: float = 1e-6, drop_path: float = 0., **kwargs):
        super(ConvNeXtBlock3d, self).__init__(in_channels, hidden_channels, gamma_init, drop_path, dim=3, **kwargs)

        
class LayerNormNd(torch.nn.LayerNorm):
    r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
    """

    def __init__(self, normalized_shape: Union[int, List[int], torch.Size], dim: int = required, **kwargs):
        kwargs["eps"] = kwargs.get("eps",1e-6) # Fix epsilon
        super().__init__(normalized_shape, **kwargs)
        if   dim == 1: self.permute_in,self.permute_out = [0,2,1],     [0,2,1]
        elif dim == 2: self.permute_in,self.permute_out = [0,2,3,1],   [0,3,1,2]
        elif dim == 3: self.permute_in,self.permute_out = [0,2,3,4,1], [0,4,1,2,3]
        else: raise ValueError("Invalid number of dimensions: {}".format(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.is_contiguous():
            return F.layer_norm(
                x.permute(*self.permute_in), self.normalized_shape, self.weight, self.bias, self.eps).permute(*self.permute_in)
        else:
            s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
            x = (x - u) * torch.rsqrt(s + self.eps)
            x = x * self.weight[:, None, None] + self.bias[:, None, None]
            return x
        

class LayerNorm1d(LayerNormNd):
    def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, dim=1, **kwargs)

class LayerNorm2d(LayerNormNd):
    def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, dim=2, **kwargs)

class LayerNorm3d(LayerNormNd):
    def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, dim=3, **kwargs)


In [15]:
hidden_channels = None
in_channels = 32

ConvNeXtBlock1d(32,initializer="timm.models.layers.trunc_normal_")

ConvNeXtBlock1d(
  (depthwise_conv): DepthwiseConv1d(
    (depthwise_conv): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(3,), groups=32)
  )
  (normalization): LayerNorm1d((32,), eps=1e-06, elementwise_affine=True)
  (pointwise_conv_1): PointwiseConv1d(
    (pointwise_conv): Conv1d(32, 128, kernel_size=(1,), stride=(1,))
  )
  (activation): GELU()
  (pointwise_conv_2): PointwiseConv1d(
    (pointwise_conv): Conv1d(128, 32, kernel_size=(1,), stride=(1,))
  )
)

In [42]:
from typing import Union, List

class LayerNormNd(torch.nn.LayerNorm):
    r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
    """

    def __init__(self, normalized_shape: Union[int, List[int], torch.Size], dim: int = required, **kwargs):
        kwargs["eps"] = kwargs.get("eps",1e-6) # Fix epsilon
        super().__init__(normalized_shape, **kwargs)
        if   dim == 1: self.permute_in,self.permute_out = [0,2,1],     [0,2,1]
        elif dim == 2: self.permute_in,self.permute_out = [0,2,3,1],   [0,3,1,2]
        elif dim == 3: self.permute_in,self.permute_out = [0,2,3,4,1], [0,4,1,2,3]
        else: raise ValueError("Invalid number of dimensions: {}".format(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.is_contiguous():
            return F.layer_norm(
                x.permute(*self.permute_in), self.normalized_shape, self.weight, self.bias, self.eps).permute(*self.permute_in)
        else:
            s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
            x = (x - u) * torch.rsqrt(s + self.eps)
            x = x * self.weight[:, None, None] + self.bias[:, None, None]
            return x
        

class LayerNorm1d(LayerNormNd):
    def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, dim=1, **kwargs)

class LayerNorm2d(LayerNormNd):
    def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, dim=2, **kwargs)

class LayerNorm3d(LayerNormNd):
    def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, dim=3, **kwargs)


In [43]:
aaa = LayerNorm1d(32)

In [17]:
x = torch.rand(1,3,4)

In [19]:
x.is_contiguous()

True

In [48]:
class ConvNeXtBlockNd(torch.nn.Module):
    r"""https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
    
    ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
    
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, channels, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels) # depthwise conv
        self.norm = torch.nn.LayerNorm(channels, eps=1e-6)
        self.pwconv1 = nn.Linear(channels, 4 * channels) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * channels, channels)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((channels)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = timm.models.layers.DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
            print(f"x = self.gamma * x: \n\t{x.shape}\n")
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x
    



In [None]:
class ConvNeXtBlockNd(torch.nn.Module):
    r"""https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
    
    ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
    
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, channels, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels) # depthwise conv
        self.norm = torch.nn.LayerNorm(channels, eps=1e-6)
        self.pwconv1 = nn.Linear(channels, 4 * channels) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * channels, channels)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((channels)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = timm.models.layers.DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
            print(f"x = self.gamma * x: \n\t{x.shape}\n")
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


In [151]:
"""https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
    
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x

class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s`  -
          https://arxiv.org/pdf/2201.03545.pdf
    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(self, in_chans=3, num_classes=1000, 
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 
                 layer_scale_init_value=1e-6, head_init_scale=1.,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                    LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                    nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


model_urls = {
    "convnext_tiny_1k":    "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k":   "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k":    "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k":   "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_tiny_22k":   "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
    "convnext_small_22k":  "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
    "convnext_base_22k":   "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k":  "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}

@register_model
def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_small(pretrained=False,in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if pretrained:
        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
        url = model_urls['convnext_xlarge_22k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

In [160]:
model = ConvNeXt()
model

ConvNeXt(
  (downsample_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm()
    )
    (1): Sequential(
      (0): LayerNorm()
      (1): Conv2d(96, 192, kernel_size=(2, 2), stride=(2, 2))
    )
    (2): Sequential(
      (0): LayerNorm()
      (1): Conv2d(192, 384, kernel_size=(2, 2), stride=(2, 2))
    )
    (3): Sequential(
      (0): LayerNorm()
      (1): Conv2d(384, 768, kernel_size=(2, 2), stride=(2, 2))
    )
  )
  (stages): ModuleList(
    (0): Sequential(
      (0): Block(
        (dwconv): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
        (norm): LayerNorm()
        (pwconv1): Linear(in_features=96, out_features=384, bias=True)
        (act): GELU()
        (pwconv2): Linear(in_features=384, out_features=96, bias=True)
        (drop_path): Identity()
      )
      (1): Block(
        (dwconv): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
