for LocalBins experiment
24-3-11

In [21]:
import torch
import torch.nn as nn

In [22]:
def log_binom(n, k, eps=1e-7):
    """Approximate log(n choose k) with the leading terms of Stirling's formula.

    Args:
        n (torch.Tensor): total count(s).
        k (torch.Tensor): chosen count(s); combined element-wise (with broadcasting) with ``n``.
        eps (float, optional): offset added to ``n``, ``k`` and once more to ``n - k`` so
            that no logarithm is evaluated at zero. Defaults to 1e-7.

    Returns:
        torch.Tensor: ``n*log(n) - k*log(k) - (n-k)*log(n-k)`` evaluated on the shifted inputs.
    """
    n_shifted = n + eps
    k_shifted = k + eps
    # Both operands carry the same shift, so the difference is n - k again.
    remainder = n_shifted - k_shifted

    term_n = n_shifted * torch.log(n_shifted)
    term_k = k_shifted * torch.log(k_shifted)
    term_remainder = remainder * torch.log(remainder + eps)
    return term_n - term_k - term_remainder


class LogBinomial(nn.Module):
    def __init__(self, n_classes=256, act=torch.softmax):
        """Layer that expands a probability map into a binomial distribution over classes.

        Args:
            n_classes (int, optional): number of output classes (K). Defaults to 256.
            act (callable, optional): normalisation applied to the log-scores, called as
                ``act(scores, dim=1)``. Defaults to torch.softmax.
        """
        super().__init__()
        self.K = n_classes
        self.act = act

        # Class indices 0..K-1 and the constant K-1, both shaped (1, K|1, 1, 1) so that
        # they broadcast against NCHW inputs.
        class_indices = torch.arange(0, n_classes)
        self.register_buffer('k_idx', class_indices.view(1, -1, 1, 1))
        largest_index = torch.Tensor([self.K - 1])
        self.register_buffer('K_minus_1', largest_index.view(1, -1, 1, 1))

    def forward(self, x, t=1., eps=1e-4):
        """Evaluate the (temperature-scaled) binomial distribution for every element of x.

        Args:
            x (torch.Tensor - NCHW or NHW): probabilities; a 3-D input gets a channel axis inserted.
            t (float, torch.Tensor - NCHW, optional): temperature dividing the log-scores. Defaults to 1..
            eps (float, optional): lower clamp for x and 1 - x before the logarithm. Defaults to 1e-4.

        Returns:
            torch.Tensor - NCHW: ``act`` applied over dim 1 to the binomial log-scores divided by t
        """
        if x.ndim == 3:
            x = x.unsqueeze(1)  # insert the channel axis -> NCHW

        # 1 - x is taken from the unclamped input, exactly like the clamp of x itself.
        log_q = torch.log(torch.clamp(1 - x, eps, 1))
        log_p = torch.log(torch.clamp(x, eps, 1))

        log_coefficient = log_binom(self.K_minus_1, self.k_idx)
        failures = self.K - 1 - self.k_idx
        scores = log_coefficient + self.k_idx * log_p + failures * log_q
        return self.act(scores / t, dim=1)


class ConditionalLogBinomial(nn.Module):
    def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax):
        """Log-binomial output head whose probability and temperature are predicted from
        a main feature map concatenated with a condition feature map.

        Args:
            in_features (int): number of input channels in main feature
            condition_dim (int): number of input channels in condition feature
            n_classes (int, optional): Number of classes. Defaults to 256.
            bottleneck_factor (int, optional): divisor giving the hidden width,
                (in_features + condition_dim) // bottleneck_factor. Defaults to 2.
            p_eps (float, optional): small value added before the ratio normalisation. Defaults to 1e-4.
            max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
            min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
            act (callable, optional): normalisation handed to LogBinomial. Defaults to torch.softmax.
        """
        super().__init__()
        self.p_eps = p_eps
        self.max_temp = max_temp
        self.min_temp = min_temp
        self.log_binomial_transform = LogBinomial(n_classes, act=act)

        joint_channels = in_features + condition_dim
        hidden_channels = joint_channels // bottleneck_factor
        self.mlp = nn.Sequential(
            nn.Conv2d(joint_channels, hidden_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            # four outputs: a pair that is normalised into p, a pair that is normalised into t
            nn.Conv2d(hidden_channels, 2 + 2, kernel_size=1, stride=1, padding=0),
            nn.Softplus()
        )

    def forward(self, x, cond):
        """Forward pass

        Args:
            x (torch.Tensor - NCHW): Main feature
            cond (torch.Tensor - NCHW): condition feature

        Returns:
            torch.Tensor: Output log binomial distribution
        """
        joint = torch.cat((x, cond), dim=1)
        predicted = self.mlp(joint)

        # Channels 0-1 drive the probability, channels 2-3 drive the temperature.
        p_pair = predicted[:, :2, ...] + self.p_eps
        t_pair = predicted[:, 2:, ...] + self.p_eps

        # Each pair is collapsed to first / (first + second), a value in (0, 1).
        p = p_pair[:, 0, ...] / (p_pair[:, 0, ...] + p_pair[:, 1, ...])
        t_fraction = t_pair[:, 0, ...] / (t_pair[:, 0, ...] + t_pair[:, 1, ...])

        # Restore the channel axis, then map the fraction linearly onto [min_temp, max_temp].
        t_fraction = t_fraction.unsqueeze(1)
        t = (self.max_temp - self.min_temp) * t_fraction + self.min_temp

        return self.log_binomial_transform(p, t)

In [23]:

@torch.jit.script
def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx, where dx = a - c,
    a = attractor point, c = bin center, dc = shift in bin center.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional attractor strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponent applied to |dx|; controls the "region of influence". Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc
    """
    distance_power = torch.abs(dx).pow(gamma)
    weight = torch.exp(-alpha * distance_power)
    return weight * dx


@torch.jit.script
def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c,
    a = attractor point, c = bin center, dc = shift in bin center.
    This is the default one according to the accompanying paper.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional attractor strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponent applied to dx in the denominator; controls the "region of influence". Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
    """
    denominator = 1 + alpha * torch.pow(dx, gamma)
    return torch.div(dx, denominator)

In [24]:
class SeedBinRegressor(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval.

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Min depth value. Defaults to 1e-3.
            max_depth (float, optional): Max depth value. Defaults to 10.
        """
        super().__init__()
        self.version = "1_1"
        self.min_depth = min_depth
        self.max_depth = max_depth

        # Per-pixel MLP built from 1x1 convolutions; the final ReLU keeps raw widths >= 0.
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Regress per-pixel bin widths and convert them into bin centers.

        Returns:
            tuple(torch.Tensor, torch.Tensor): the normalised bin widths (summing to 1 over
            dim 1) and the bin centers, both with n_bins channels.
        """
        # Small offset so that no bin collapses to zero width.
        raw_widths = self._net(x) + 1e-3
        widths_normed = raw_widths / raw_widths.sum(dim=1, keepdim=True)

        depth_range = self.max_depth - self.min_depth
        widths = depth_range * widths_normed  # .shape NCHW

        # Prepend one channel holding min_depth; pad order is
        # (left, right, top, bottom, front, back), so only the channel front grows.
        widths = nn.functional.pad(
            widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)

        # The running sum turns [min_depth, w_1, ..., w_n] into the bin edges.
        edges = widths.cumsum(dim=1)
        lower_edges = edges[:, :-1, ...]
        upper_edges = edges[:, 1:, ...]
        centers = 0.5 * (lower_edges + upper_edges)
        return widths_normed, centers

In [25]:
class SeedBinRegressorUnnormed(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are unbounded

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
            max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
        """
        super().__init__()
        self.version = "1_1"

        # Per-pixel MLP of 1x1 convolutions; Softplus makes every predicted center positive.
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, kernel_size=1, stride=1, padding=0),
            nn.Softplus()
        )

    def forward(self, x):
        """Predict one vector of bin centers for every pixel.

        Returns:
            tuple(torch.Tensor, torch.Tensor): the same centers tensor twice, mirroring the
            two-value return of SeedBinRegressor.
        """
        centers = self._net(x)
        return centers, centers


In [26]:
class Projector(nn.Module):
    def __init__(self, in_features, out_features, mlp_dim=128):
        """Two-layer per-pixel MLP (1x1 convolutions) mapping features to a new channel count.

        Args:
            in_features (int): input channels
            out_features (int): output channels
            mlp_dim (int, optional): hidden dimension. Defaults to 128.
        """
        super().__init__()

        # No activation after the second convolution: the projection output is unbounded.
        layers = [
            nn.Conv2d(in_features, mlp_dim, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, out_features, kernel_size=1, stride=1, padding=0),
        ]
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """Project x (N, in_features, H, W) to (N, out_features, H, W)."""
        projected = self._net(x)
        return projected

In [27]:
from pprojector import XcsProjector0 as XProjector
from pprojector import PositionWiseFeedForward as PFFProjector

In [28]:
class AttractorLayerUnnormed(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are unbounded

        Args:
            in_features (int): channels of the incoming feature block
            n_bins (int): number of bin centers (stored; not read by forward)
            n_attractors (int, optional): number of attractor points predicted per pixel. Defaults to 16.
            mlp_dim (int, optional): hidden dimension of the attractor MLP. Defaults to 128.
            min_depth (float, optional): stored; not read by forward. Defaults to 1e-3.
            max_depth (float, optional): stored; not read by forward. Defaults to 10.
            alpha (float, optional): proportional attractor strength passed to the attractor function. Defaults to 300.
            gamma (int, optional): exponential attractor strength passed to the attractor function. Defaults to 2.
            kind (str, optional): 'sum' or 'mean' - how the shifts of all attractors are combined. Defaults to 'sum'.
            attractor_type (str, optional): 'exp' selects exp_attractor, anything else inv_attractor. Defaults to 'exp'.
            memory_efficient (bool, optional): accumulate attractor by attractor instead of
                materialising the full attractor x bin difference tensor. Defaults to False.
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        # Per-pixel MLP of 1x1 convolutions; Softplus keeps the attractor points positive.
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
            nn.Softplus()
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor) : feature block; shape - n, c, h, w
            b_prev (torch.Tensor) : previous bin centers; shape - n, prev_nbins, h, w
            prev_b_embedding (torch.Tensor, optional) : previous embedding, added to x when given
            interpolate (bool, optional) : resize prev_b_embedding to the spatial size of x before adding
            is_for_query (bool, optional) : accepted but not used by this layer

        Returns:
            tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        A = self._net(x)  # attractor points; shape - n, n_attractors, h, w
        _, _, h, w = A.shape

        # Bring the previous centers to the resolution of the current feature block.
        b_centers = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)

        dist = exp_attractor if self.attractor_type == 'exp' else inv_attractor

        # Forward the configured strengths. Previously the attractor functions were
        # called with dx only, so they always ran with their own defaults and the
        # alpha/gamma constructor arguments were silently ignored. The casts match
        # the (float, int) annotations of the scripted attractor functions.
        alpha = float(self.alpha)
        gamma = int(self.gamma)

        if not self.memory_efficient:
            func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            # dx has shape n, n_attractors, nbins, h, w; reducing dim 1 gives n, nbins, h, w
            dx = A.unsqueeze(2) - b_centers.unsqueeze(1)
            delta_c = func(dist(dx, alpha, gamma), dim=1)
        else:
            delta_c = torch.zeros_like(b_centers)
            for i in range(self.n_attractors):
                # .shape n, nbins, h, w
                delta_c += dist(A[:, i, ...].unsqueeze(1) - b_centers, alpha, gamma)

            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors

        b_new_centers = b_centers + delta_c

        # The same tensor is returned twice to mirror the normed layer's (b, centers) pair.
        return b_new_centers, b_new_centers

In [29]:
# Tinkering copy of the unnormed attractor layer (work in progress).
class xcsAttractorLayerUnnormed(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are unbounded

        Args:
            in_features (int): channels of the incoming feature block
            n_bins (int): number of bin centers (stored; not read by forward)
            n_attractors (int, optional): number of attractor points predicted per pixel. Defaults to 16.
            mlp_dim (int, optional): hidden dimension of the attractor MLP. Defaults to 128.
            min_depth (float, optional): stored; not read by forward. Defaults to 1e-3.
            max_depth (float, optional): stored; not read by forward. Defaults to 10.
            alpha (float, optional): proportional attractor strength passed to the attractor function. Defaults to 300.
            gamma (int, optional): exponential attractor strength passed to the attractor function. Defaults to 2.
            kind (str, optional): 'sum' or 'mean' - how the shifts of all attractors are combined. Defaults to 'sum'.
            attractor_type (str, optional): 'exp' selects exp_attractor, anything else inv_attractor. Defaults to 'exp'.
            memory_efficient (bool, optional): accumulate attractor by attractor instead of
                materialising the full attractor x bin difference tensor. Defaults to False.
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        # Per-pixel MLP of 1x1 convolutions; Softplus keeps the attractor points positive.
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
            nn.Softplus()
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor) : feature block; shape - n, c, h, w
            b_prev (torch.Tensor) : previous bin centers; shape - n, prev_nbins, h, w
            prev_b_embedding (torch.Tensor, optional) : previous embedding, added to x when given
            interpolate (bool, optional) : resize prev_b_embedding to the spatial size of x before adding
            is_for_query (bool, optional) : accepted but not used by this layer

        Returns:
            tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        # Project the features to attractor points first; shape - n, n_attractors, h, w
        A = self._net(x)
        _, _, h, w = A.shape

        # Update the previous b: start by resizing it to the spatial size of the
        # current level. (Note to self: so b_centers is simply b?)
        b_centers = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)

        # exp_attractor: exp(-alpha*|dx|^gamma) * dx
        # inv_attractor: dx.div(1+alpha*dx.pow(gamma))
        dist = exp_attractor if self.attractor_type == 'exp' else inv_attractor

        # Forward the configured strengths. Previously the attractor functions were
        # called with dx only, so they always ran with their own defaults and the
        # alpha/gamma constructor arguments were silently ignored. The casts match
        # the (float, int) annotations of the scripted attractor functions.
        alpha = float(self.alpha)
        gamma = int(self.gamma)

        if not self.memory_efficient:
            func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            # dx has shape n, n_attractors, nbins, h, w; reducing dim 1 gives n, nbins, h, w
            dx = A.unsqueeze(2) - b_centers.unsqueeze(1)
            delta_c = func(dist(dx, alpha, gamma), dim=1)
        else:
            delta_c = torch.zeros_like(b_centers)
            for i in range(self.n_attractors):
                # .shape n, nbins, h, w
                delta_c += dist(A[:, i, ...].unsqueeze(1) - b_centers, alpha, gamma)

            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors

        b_new_centers = b_centers + delta_c

        # The same tensor is returned twice to mirror the normed layer's (b, centers) pair.
        return b_new_centers, b_new_centers

In [30]:
# Arguments and hyper-parameters used by the cells below.

# Backbone channel layout: the first entry is the bottleneck, the remaining
# entries are the decoder blocks.
core_output_channels = (512, 256, 128, 64, 64)
btlnck_features = core_output_channels[0]
num_out_features = core_output_channels[1:]

bin_embedding_dim = 128

# Depth bins
n_bins = 64
min_depth = 1e-3
max_depth = 10

# Attractor settings; n_attractors holds one entry per decoder block.
attractor_alpha = 300
attractor_gamma = 2
attractor_kind = 'sum'
n_attractors = [16, 8, 4, 1]
attractor_type = 'exp'
N_MIDAS_OUT = 32

# Temperature range of the conditional log-binomial head
min_temp = 5
max_temp = 50

In [31]:
# Random stand-ins for the backbone outputs, created in list order.
out = [
    torch.rand(shape, dtype=torch.float)
    for shape in (
        (1, 32, 384, 384),   # outconv_activation
        (1, 512, 12, 12),    # btlnck
        (1, 256, 24, 24),    # x_blocks[0]
        (1, 128, 48, 48),    # x_blocks[1]
        (1, 64, 96, 96),     # x_blocks[2]
        (1, 64, 192, 192),   # x_blocks[3]
    )
]

In [32]:
# Dummy core output.
#
#   outconv_activation: torch.Size([1, 32, 384, 384])
#   btlnck:             torch.Size([1, 512, 12, 12])
#   x_blocks:           four feature maps, (channels, side) listed below

outconv_activation = torch.rand(1, 32, 384, 384, dtype=torch.float)
btlnck = torch.rand(1, 512, 12, 12, dtype=torch.float)
x_blocks = [
    torch.rand(1, channels, side, side, dtype=torch.float)
    for channels, side in ((256, 24), (128, 48), (64, 96), (64, 192))
]
rel_depth = torch.rand(1, 384, 384, dtype=torch.float)

bin_centers_type = "softplus"



In [33]:
# 1x1 convolution over the bottleneck ("btlnck") features; channel count is unchanged.
conv2 = nn.Conv2d(
    in_channels=btlnck_features,
    out_channels=btlnck_features,
    kernel_size=1,
    stride=1,
    padding=0,
)

In [37]:
from pprojector import XcsSeedBinRegressorUnnormed
from pprojector import PositionWiseFeedForward1 as PFFProjector1
from pprojector import XcsAttractorLayerUnnormed

ImportError: cannot import name 'XcsAttractorLayerUnnormed' from 'pprojector' (d:\_2_workspace\day20icem\notes\pprojector.py)

In [36]:
# Layer variants used by this experiment (the unnormed / softplus flavour).
SeedBinRegressorLayer = SeedBinRegressorUnnormed
Attractor = AttractorLayerUnnormed

# Seed bins and seed embedding are both computed from the bottleneck features.
seed_bin_regressor = SeedBinRegressorLayer(
    btlnck_features,
    n_bins=n_bins,
    min_depth=min_depth,
    max_depth=max_depth,
)
seed_projector = Projector(btlnck_features, bin_embedding_dim)

# One projector per decoder block.
# Alternatives tried in place of Projector, each called as (num_out, bin_embedding_dim):
#   PFFProjector1, XProjector, PFFProjector
projectors = nn.ModuleList([
    Projector(num_out, bin_embedding_dim)
    for num_out in num_out_features
])

# One attractor layer per decoder block; block i uses n_attractors[i] attractor points.
attractors = nn.ModuleList([
    Attractor(
        bin_embedding_dim,
        n_bins,
        n_attractors=n_attractors[i],
        min_depth=min_depth,
        max_depth=max_depth,
        alpha=attractor_alpha,
        gamma=attractor_gamma,
        kind=attractor_kind,
        attractor_type=attractor_type,
    )
    for i, _ in enumerate(num_out_features)
])

last_in = N_MIDAS_OUT + 1  # +1 for relative depth

conditional_log_binomial = ConditionalLogBinomial(
    last_in,
    bin_embedding_dim,
    n_classes=n_bins,
    min_temp=min_temp,
    max_temp=max_temp,
)


In [None]:
# Reference snippet (not executed): how the real model obtains these tensors,
# and the shapes they have.
#
#     rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)
#
#     outconv_activation = out[0]
#     btlnck = out[1]   # btlnck.shape = (1, 512, 12, 12)
#     x_blocks = out[2:]
#
#     out.shape = [
#         (1, 32, 384, 384)
#         (1, 512, 12, 12)
#         (1, 256, 24, 24)
#         (1, 128, 48, 48)
#         (1, 64, 96, 96)
#         (1, 64, 192, 192)
#     ]
#
#     rel_depth.shape = (1, 384, 384)

# Random stand-ins with the shapes listed above.
outconv_activation = torch.rand(1, 32, 384, 384, dtype=torch.float)
btlnck = torch.rand(1, 512, 12, 12, dtype=torch.float)
x_d0 = torch.rand(1, 512, 12, 12, dtype=torch.float)

x_blocks = [
    torch.rand(1, channels, side, side, dtype=torch.float)
    for channels, side in ((256, 24), (128, 48), (64, 96), (64, 192))
]
rel_depth = torch.rand(1, 384, 384, dtype=torch.float)

bin_centers_type = "softplus"


'''
Dummy Forward:

'''
# Replays the forward pass step by step on the random tensors and modules
# created at notebook level. The printed messages ("报告" = "report") trace
# the tensor shapes after each stage.
# SETTING
print('开始报告！！')

return_final_centers=False
denorm=False
return_probs=False

print(f'报告!! btlnck的shape是\t\t{btlnck.shape}')
# The 1x1 conv output replaces whatever x_d0 held before this line.
x_d0 = conv2(btlnck)
print(f'报告!! x_d0的shape是\t\t{x_d0.shape}')

x = x_d0

# Seed bin centers from the bottleneck features; the first return value is discarded.
_, seed_b_centers = seed_bin_regressor(x)
print(f'报告!! seed_b_centers的shape是\t\t{seed_b_centers.shape}')

# Only the 'normed' / 'hybrid2' modes rescale the seed centers by the depth range;
# any other bin_centers_type uses them unchanged.
if bin_centers_type == 'normed' or bin_centers_type == 'hybrid2':
    b_prev = (seed_b_centers - min_depth) / \
            (max_depth - min_depth)
else:
    b_prev = seed_b_centers
print(f'报告!! b_prev的shape是\t\t{b_prev.shape}')

prev_b_embedding = seed_projector(x)
print(f'报告!! prev_b_embedding的shape是\t\t{prev_b_embedding.shape}')

print('---------------开转!!!---------------')
# Refine the bins once per decoder block. NOTE: the loop variable x rebinds the
# name x that held the bottleneck features above.
for projector, attractor, x in zip(projectors, attractors, x_blocks):
    print('----------------------------------')
    print(f'报告!! x的shape是\t\t{x.shape}')
    b_embedding = projector(x)
    print(f'报告!! b_embedding的shape是\t{b_embedding.shape}')
    print(f'报告!! b_prev的shape是\t\t{b_prev.shape}')
    print(f'报告!! prev_b_embedding的shape是\t{prev_b_embedding.shape}')
    # print('----------')
    b, b_centers = attractor(
        b_embedding, b_prev, prev_b_embedding, interpolate=True)
    # Carry this level's bins and embedding over to the next iteration.
    b_prev = b.clone()
    prev_b_embedding = b_embedding.clone()
    print(f'报告!! b的shape是\t\t{b.shape}')


# From here on, b_embedding and b_centers are the values left by the last loop iteration.
last = outconv_activation

# last.shape = [1, 32, 384, 384]
rel_cond = rel_depth.unsqueeze(1)
# rel_cond.shape = [1, 1, 384, 384]

# last.shape[2:] = [384, 384]
rel_cond = nn.functional.interpolate(
            rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True)
# Append the relative depth as one extra channel.
last = torch.cat([last, rel_cond], dim=1)

# Resize the last embedding to the spatial size of `last` before conditioning on it.
b_embedding = nn.functional.interpolate(
            b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
x = conditional_log_binomial(last, b_embedding)

b_centers = nn.functional.interpolate(
        b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
# Weighted sum of the bin centers over dim 1, using x as the per-pixel weights.
# NOTE: this rebinds `out`, which an earlier cell used for the list of dummy tensors.
out = torch.sum(x * b_centers, dim=1, keepdim=True)
output = dict(metric_depth=out)
if return_final_centers or return_probs:
    output['bin_centers'] = b_centers

if return_probs:
    output['probs'] = x



开始报告！！
报告!! btlnck的shape是		torch.Size([1, 512, 12, 12])
报告!! x_d0的shape是		torch.Size([1, 512, 12, 12])
报告!! seed_b_centers的shape是		torch.Size([1, 64, 12, 12])
报告!! b_prev的shape是		torch.Size([1, 64, 12, 12])
报告!! prev_b_embedding的shape是		torch.Size([1, 128, 12, 12])
---------------开转!!!---------------
----------------------------------
报告!! x的shape是		torch.Size([1, 256, 24, 24])
报告!! b_embedding的shape是	torch.Size([1, 128, 24, 24])
报告!! b_prev的shape是		torch.Size([1, 64, 12, 12])
报告!! prev_b_embedding的shape是	torch.Size([1, 128, 12, 12])
报告!! b的shape是		torch.Size([1, 64, 24, 24])
----------------------------------
报告!! x的shape是		torch.Size([1, 128, 48, 48])
报告!! b_embedding的shape是	torch.Size([1, 128, 48, 48])
报告!! b_prev的shape是		torch.Size([1, 64, 24, 24])
报告!! prev_b_embedding的shape是	torch.Size([1, 128, 24, 24])
报告!! b的shape是		torch.Size([1, 64, 48, 48])
----------------------------------
报告!! x的shape是		torch.Size([1, 64, 96, 96])
报告!! b_embedding的shape是	torch.Size([1, 128, 96, 96])
报告!! 

In [None]:
# Shape inspection. A notebook cell echoes only its last expression, so only
# last.shape is displayed.
seed_b_centers.shape
rel_cond.shape
last.shape

torch.Size([1, 33, 384, 384])

In [None]:
# Shape of the metric-depth tensor produced by the dummy forward cell.
out.shape

torch.Size([1, 1, 384, 384])