In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
"""
Real Phase Encoder with Mag with Power Law Compression
Inspired by Shetu, Shrishti Saha, et al. "Ultra Low Complexity Deep Learning Based Noise Suppression." arXiv preprint arXiv:2312.08132 (2023).
"""
# NOTE : torch.sign is not differentiable
# https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
class GradSign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x = torch.sign(x)
        #ctx.save_for_backward(x)
        #return torch.tanh(x / epsilon)    
        return x

    @staticmethod
    def backward(ctx, grad_output):
        #x, = ctx.saved_tensors
        grad_input = grad_output
        return grad_input

# NOTE : torch.abs is not differentiable
# https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
class GradAbs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x = torch.abs(x)
        #ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        #x, = ctx.saved_tensors
        #grad_input = grad_output * GradSign(x) 
        grad_input = grad_output
        return grad_input

class PowerLawCompression(nn.Module):
    def __init__(self, alpha=0.3, **kwargs):
        super(PowerLawCompression, self).__init__()
        self.alpha = alpha

    def forward(self, X):
        """
        X.shape == (B,F,T,2)
        """
        # Power Law Compression
        # NOTE : PLC on compelx 
        print(GradSign.apply(X[:,:,:,:]))
        X[:,:,:,:] = GradSign.apply(X[:,:,:,:]) * torch.pow(GradAbs.apply(X[:,:,:,:]), self.alpha)
        return X

class PowerLawDecompression(nn.Module):
    def __init__(self, alpha=0.3, **kwargs):
        super(PowerLawDecompression, self).__init__()
        self.alpha = alpha
        #self.eps = 1e-7
        self.eps = 0

    def forward(self, X):
        """
        X.shape == (B,F,T,2)
        """
        X[:,:,:,:] = GradSign.apply(X[:,:,:,:]) * torch.pow(GradAbs.apply(X[:,:,:,:]), 1/self.alpha)

        return X


In [17]:
X = (torch.rand(1,2,3,2)-0.5)*10
X[0,0,0,0] = 0
print(X)



m1 = PowerLawCompression()
m2 = PowerLawDecompression()

Y = m1(X)
print(Y)

Z = m2(Y)
print(Z)

tensor([[[[ 0.0000,  2.9881],
          [-2.5757,  4.0735],
          [-3.9225, -3.0302]],

         [[ 0.5688, -0.3853],
          [ 0.7594,  2.4158],
          [-4.9838, -4.0396]]]])
tensor([[[[ 0.,  1.],
          [-1.,  1.],
          [-1., -1.]],

         [[ 1., -1.],
          [ 1.,  1.],
          [-1., -1.]]]])
tensor([[[[ 0.0000,  1.3887],
          [-1.3282,  1.5240],
          [-1.5068, -1.3946]],

         [[ 0.8443, -0.7512],
          [ 0.9207,  1.3029],
          [-1.6191, -1.5202]]]])
tensor([[[[ 0.0000,  2.9881],
          [-2.5757,  4.0735],
          [-3.9225, -3.0302]],

         [[ 0.5688, -0.3853],
          [ 0.7594,  2.4158],
          [-4.9838, -4.0396]]]])
