## RevIN 연산 사용법
- RevIN: Reversible Instance Normalization

In [47]:
import torch
import torch.nn as nn
from revin.revin_torch import RevIN

import pdb

In [25]:
# Toy input: the integers 0..23 scaled into [0, 1), laid out as a (4, 3, 2) tensor.
x = torch.arange(0, 24).reshape(4, 3, 2) / 24
x.shape

torch.Size([4, 3, 2])

In [51]:
class MyRevIN(nn.Module):
    """Reversible Instance Normalization (RevIN).

    Kim et al., "Reversible Instance Normalization for Accurate Time-Series
    Forecasting against Distribution Shift", ICLR 2022.

    Features are taken to be on the LAST axis, e.g. (batch, length, num_features).
    Mean and standard deviation are computed per sample over every axis except
    the first (batch) and the last (features). The learnable affine weight and
    bias have shape (num_features,) and broadcast along that last axis.

    Parameters
    ----------
    num_features: int, the number of features or channels.
    eps: float, a value added for numerical stability, default 1e-5.
    affine: bool, if True (default), RevIN has learnable affine parameters.
    verbose: bool, if True (default), print the computed statistics (and the
        affine parameters, when they exist) on every 'norm' call.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, verbose: bool = True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.verbose = verbose
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (mode='norm') or denormalize (mode='denorm') ``x``.

        'norm' computes and stores the statistics of ``x`` and returns the
        normalized tensor. 'denorm' reuses the statistics stored by the most
        recent 'norm' call, so 'norm' must have been called first.

        Raises
        ------
        NotImplementedError: if ``mode`` is neither 'norm' nor 'denorm'.
        """
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError('Only modes norm and denorm are supported.')
        return x

    def _init_params(self):
        # Identity affine transform at init: scale (alpha) = 1, shift (beta) = 0.
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Store per-instance mean/std of ``x`` on ``self.mean`` / ``self.stdev``."""
        # Reduce over every axis except batch (0) and features (-1).
        dim2reduce = tuple(range(1, x.ndim - 1))
        # detach(): the statistics are not learnable parameters, so they are
        # excluded from the autograd graph.
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        # Population variance (unbiased=False); eps keeps sqrt/division stable.
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
        ).detach()
        if self.verbose:
            print('Mean: ', self.mean)
            print('Std: ', self.stdev)
            # The affine parameters only exist when affine=True. Printing them
            # unconditionally raised AttributeError for affine=False.
            if self.affine:
                print('Alpha: ', self.affine_weight)
                print('beta: ', self.affine_bias)
            print('\n')

    def _normalize(self, x):
        """Standardize ``x`` with the stored statistics, then apply the affine map."""
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        """Invert ``_normalize`` using the statistics from the last 'norm' call."""
        if self.affine:
            x = x - self.affine_bias
            # eps**2 guards against division by zero should the learned
            # weight reach 0, while barely perturbing ordinary weights.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x

### Instance Normalization 테스트

In [52]:
# Assume batch size = 2, length = 3, number of features = 2.
# Instance normalization gathers, for each sample and each feature, the values
# along the length axis and normalizes them with their own mean and std.
# (Example) The first sample is [[0.0000, 0.0833], [0.1667, 0.2500], [0.3333, 0.4167]].
#      Collecting the 3 values of the first feature along the length axis gives
#      [0.0, 0.1667, 0.3333], whose mean is 0.1667 and (population) std is 0.1361.
#      Hence (0.0-0.1667)/0.1361=-1.2244, (0.1667-0.1667)/0.1361=0, (0.3333-0.1667)/0.1361=1.2244
#      The learnable scale alpha (init 1) and shift beta (init 0) are also defined
#      here, to be trained later.
x = torch.reshape(torch.arange(0, 12), shape=(2, 3, 2))/12

layer = MyRevIN(2)
y = layer(x, mode='norm')
z = layer(y, mode='denorm')

print(x)
print(y)
print(z)

Mean:  tensor([[[0.1667, 0.2500]],

        [[0.6667, 0.7500]]])
Std:  tensor([[[0.1361, 0.1361]],

        [[0.1361, 0.1361]]])
Alpha:  Parameter containing:
tensor([1., 1.], requires_grad=True)
beta:  Parameter containing:
tensor([0., 0.], requires_grad=True)


tensor([[[0.0000, 0.0833],
         [0.1667, 0.2500],
         [0.3333, 0.4167]],

        [[0.5000, 0.5833],
         [0.6667, 0.7500],
         [0.8333, 0.9167]]])
tensor([[[-1.2244, -1.2244],
         [ 0.0000,  0.0000],
         [ 1.2244,  1.2244]],

        [[-1.2244, -1.2244],
         [ 0.0000,  0.0000],
         [ 1.2244,  1.2244]]], grad_fn=<AddBackward0>)
tensor([[[0.0000, 0.0833],
         [0.1667, 0.2500],
         [0.3333, 0.4167]],

        [[0.5000, 0.5833],
         [0.6667, 0.7500],
         [0.8333, 0.9167]]], grad_fn=<AddBackward0>)


In [28]:
# Build the packaged RevIN layer for 2 features and display it.
revinlayer = RevIN(num_features=2)
revinlayer

RevIN()

### Net에 넣어서 사용할 때

In [31]:
class Net(nn.Module):
    """Toy model: RevIN-normalize, 1x1 Conv1d + ReLU, then RevIN-denormalize.

    Assumes ``x`` is laid out as (batch, length, 2) with features on the last
    axis, as MyRevIN in this notebook expects. TODO: confirm this holds for
    the packaged RevIN too.
    """

    def __init__(self):
        super().__init__()
        self.revinlayer = RevIN(num_features=2)
        self.conv1d = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=1)

    def forward(self, x):
        x = self.revinlayer(x, mode='norm')
        # nn.Conv1d expects (batch, channels, length). Move the feature axis to
        # dim 1 for the convolution, then move it back so 'denorm' sees the
        # same layout that 'norm' computed its statistics on.
        x = self.conv1d(x.transpose(1, 2)).transpose(1, 2)
        # nn.ReLU(x) would construct a ReLU *module* (with inplace=x) rather
        # than apply the activation, so use the functional form.
        x = torch.relu(x)
        x = self.revinlayer(x, mode='denorm')
        return x


model = Net()

In [33]:
# Forward pass of `model` on `x` (left disabled in this notebook); uncomment to run.
#res = model(x)
#res.shape

In [None]:
x.size()