# Squeeze layers

> Spatial squeeze/unsqueeze helpers and the `SqueezeLayer` module

In [1]:
#| default_exp layers.squeeze

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
from fastai.vision.all import nn

In [4]:
from torch import randn as torch_randn
from fastai.vision.all import test_eq

In [5]:
#| export
def squeeze_spatial(input, factor=2):
    """Fold every spatial dimension by `factor` into the channel dimension.

    A `(B, C, *S)` tensor with one to three spatial dims becomes
    `(B, C * factor**len(S), *[s // factor for s in S])`.

    NOTE(review): this is a plain row-major reshape, not a space-to-depth
    (pixel-unshuffle) rearrangement. Each output channel holds a contiguous
    chunk of one input channel's flattened spatial data, rather than a
    `factor`-wide neighbourhood of pixels. `unsqueeze_spatial` reverses it
    exactly; confirm that losing spatial locality is intended.

    Args:
        input: tensor of shape `(B, C, *S)` with `3 <= input.dim() <= 5`.
        factor: positive int; every spatial size must be divisible by it.

    Returns:
        The reshaped tensor, or `input` itself when `factor == 1`.

    Raises:
        AssertionError: if `factor` is not a positive int, the rank is
            unsupported, or a spatial size is not divisible by `factor`.
    """
    # Type-check first so a non-numeric factor trips the assertion instead of
    # raising TypeError from the `>=` comparison.
    assert isinstance(factor, int) and factor >= 1, "Factor must be a positive integer."

    if factor == 1: return input

    assert 3 <= input.dim() <= 5, "Input tensor must have 3D, 4D, or 5D dimensions."

    batch_size, channels, *spatial_dims = input.shape
    for dim_idx, size in enumerate(spatial_dims):
        assert size % factor == 0, "Spatial dimension {} is not divisible by the factor.".format(dim_idx)

    new_channels = channels * factor ** len(spatial_dims)
    new_spatial_dims = [size // factor for size in spatial_dims]

    # `reshape` (unlike `view`) also accepts non-contiguous inputs, e.g.
    # transposed or sliced tensors, and copies only when it must.
    return input.reshape(batch_size, new_channels, *new_spatial_dims)


In [6]:
x = torch_randn(16,1,64,64)
test_eq(squeeze_spatial(x).shape, [16,4,32,32])

x = torch_randn(16,1,64,64,32)
test_eq(squeeze_spatial(x).shape, [16,8,32,32,16])

# 3D input (a single spatial dimension)
x = torch_randn(16,1,64)
test_eq(squeeze_spatial(x).shape, [16,2,32])

# non-default factor: channels grow by factor**num_spatial_dims
x = torch_randn(16,1,64,64)
test_eq(squeeze_spatial(x, factor=4).shape, [16,16,16,16])

# factor=1 is the identity
test_eq(squeeze_spatial(x, factor=1), x)

In [7]:
#| export
def unsqueeze_spatial(input, factor=2):
    """Unfold channels back into the spatial dims, reversing `squeeze_spatial`.

    A `(B, C, *S)` tensor with one to three spatial dims becomes
    `(B, C // factor**len(S), *[s * factor for s in S])`. The reshape is
    row-major, matching the reshape used by `squeeze_spatial`.

    Args:
        input: tensor of shape `(B, C, *S)` with `3 <= input.dim() <= 5`.
        factor: positive int; `C` must be divisible by `factor**len(S)`.

    Returns:
        The reshaped tensor, or `input` itself when `factor == 1`.

    Raises:
        AssertionError: if `factor` is not a positive int, the rank is
            unsupported, or the channel count is not divisible by
            `factor**len(S)`.
    """
    # Type-check first so a non-numeric factor trips the assertion instead of
    # raising TypeError from the `>=` comparison.
    assert isinstance(factor, int) and factor >= 1, "Factor must be a positive integer."

    if factor == 1: return input

    assert 3 <= input.dim() <= 5, "Input tensor must have 3D, 4D, or 5D dimensions."

    batch_size, channels, *spatial_dims = input.shape
    fold = factor ** len(spatial_dims)
    # Without this check an indivisible channel count only surfaces as an
    # opaque size-mismatch RuntimeError from the reshape below.
    assert channels % fold == 0, "Channel dimension is not divisible by factor ** num_spatial_dims."

    new_channels = channels // fold
    new_spatial_dims = [size * factor for size in spatial_dims]

    # `reshape` (unlike `view`) also accepts non-contiguous inputs, e.g.
    # transposed or sliced tensors, and copies only when it must.
    return input.reshape(batch_size, new_channels, *new_spatial_dims)


In [8]:
x = torch_randn(16,4,32,32)
test_eq(unsqueeze_spatial(x).shape, [16,1,64,64])

x = torch_randn(16,8,32,32,16)
test_eq(unsqueeze_spatial(x).shape, [16,1,64,64,32])

# the two functions must exactly invert each other, in both directions
for shape in ([16,1,64], [16,1,64,64], [16,1,64,64,32]):
    x = torch_randn(*shape)
    test_eq(unsqueeze_spatial(squeeze_spatial(x)), x)

x = torch_randn(16,8,32,32,16)
test_eq(squeeze_spatial(unsqueeze_spatial(x)), x)

In [9]:
#| export
class SqueezeLayer(nn.Module):
    """Invertible layer that trades spatial resolution for channels.

    The forward direction applies `squeeze_spatial` with `factor`, and
    `_inverse` applies `unsqueeze_spatial` with the same factor.

    Args:
        factor: positive int forwarded to `squeeze_spatial` /
            `unsqueeze_spatial`.
        level: stored on the module but not used by this class. It is
            presumably the multi-scale level in an enclosing flow; TODO
            confirm against callers.
        name: identifier stored on the module (default `'squeeze'`).
    """

    def __init__(self, factor, level, name='squeeze'):
        super().__init__()
        self.factor = factor
        self.name = name
        self.level = level

    def _inverse(self, z, **kwargs):
        """Map `z` back with `unsqueeze_spatial`; extra kwargs are ignored."""
        return unsqueeze_spatial(z, self.factor)

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Squeeze `x` and return `(output, log_det)`; extra kwargs are ignored.

        The squeeze only rearranges elements, so the log-determinant of its
        Jacobian is the constant 0.
        """
        return squeeze_spatial(x, self.factor), 0


In [10]:
x = torch_randn(16,1,64,64)

tst = SqueezeLayer(2, 1)
z, log_det = tst._forward_and_log_det_jacobian(x)
test_eq(z.shape, [16,4,32,32])
# a pure rearrangement of elements has a zero log-determinant
test_eq(log_det, 0)
# the inverse must restore the original input exactly
test_eq(tst._inverse(z), x)

In [11]:
#| hide
import nbdev; nbdev.nbdev_export()