In [None]:
'''
PyTorch's 2D convolutional layers, such as nn.Conv2d, expect input tensors in the NCHW format. This format orders the dimensions as follows:
N (batch size): the number of independent samples (images) in a batch.
C (channels): the number of channels per image (e.g., 3 for RGB, 1 for grayscale).
H (height): the height of each image in pixels.
W (width): the width of each image in pixels.
Therefore, a batched input for a 2D convolution in PyTorch should have the shape (N, C, H, W). For example, a batch of 10 RGB images, each 256 pixels high and 256 pixels wide, has the shape (10, 3, 256, 256).
'''
import torch
import torchvision.transforms.functional as TF

print("\n\n----------------------------------------------------------")
print("TOPIC 1: DEMONSTRATING WHAT tensor_a.shape[2:] DOES")
print("----------------------------------------------------------")
# Create a random 4-D tensor laid out as (N, C, H, W) = (3, 2, 16, 16).
tensor_a = torch.randn(3, 2, 16, 16)

# .shape is an attribute returning a torch.Size (a subclass of tuple).
print("Tensor a")
shape_tuple = tensor_a.shape
print(f"Shape using .shape: {shape_tuple}")  # output: Shape using .shape: torch.Size([3, 2, 16, 16])

# .size() is the method form of .shape and returns the same torch.Size.
size_tuple = tensor_a.size()
print(f"Shape using .size(): {size_tuple}")  # output: Shape using .size(): torch.Size([3, 2, 16, 16])

# Slicing a torch.Size yields another torch.Size. For this 4-D tensor,
# shape[2:] keeps dims 2 and 3, i.e. (H, W) -> torch.Size([16, 16]).
print(f"tensor_a.shape[2:] prints the last two dimensions of the tensor: {tensor_a.shape[2:]}")


print("\n\n----------------------------------------------------------")
print("TOPIC 2: RESIZING A BATCH OF TENSORS THAT ARE IMAGES ")
print("----------------------------------------------------------")
# Implicit string concatenation; the printed text is identical to the
# original backslash-continued literal.
print(
    "torchvision.transforms.functional.resize is specifically for image tensors "
    "(there could be other types of resizes in pytorch). "
    "it is expected to have […, H, W] shape, where … means an arbitrary number of "
    "leading dimensions. "
    "i.e the last two dimensions are assumed to be H,W and resizing occurs along "
    "these dimensions"
)
# Tensor b: (N, C, H, W) = (3, 5, 10, 10).
tensor_b = torch.randn(3, 5, 10, 10)
print("\nTensor b: before resizing ")
print(tensor_b.shape)
# Resize tensor_b's trailing (H, W) dims to tensor_a's (H, W), which is
# tensor_a.shape[2:] (tensor_a is created in TOPIC 1). The leading (N, C)
# dims are not resized.
tensor_b = TF.resize(tensor_b, size=tensor_a.shape[2:])
print("\nTensor b: after resizing its last two dimensions (H, W) to tensor_a.shape[2:]")
print(tensor_b.shape)


print("\n\n----------------------------------------------------------")
print("TOPIC 3: CONCATENATING ALONG VARIOUS DIMENSIONS")
print("----------------------------------------------------------")
# Implicit string concatenation; the printed text is identical to the
# original backslash-continued literal.
print(
    "When concatenating along a particular dimension. That dimension can be different. "
    "But all other dimensions should be the same. "
    "For example if you are concatenating along dim1 = channels. "
    "Number of channels can be different. "
    "But dim0=number of batch samples should be the same. "
    "dim2,dim3= (H,W) should also be the same"
)

# torch.cat((x, y), dim=k) requires every dimension except k to match; the
# sizes along dim k are added. The tensors here are laid out as
# (N, C, H, W) = (dim0, dim1, dim2, dim3).

# c1: concatenate tensor_a and tensor_b along channels (dim 1).
# tensor_a is (3, 2, 16, 16) from TOPIC 1; tensor_b was resized in TOPIC 2 to
# (3, 5, 16, 16). Channels: 2 + 5 = 7 -> tensor_c1 is (3, 7, 16, 16).
tensor_c1 = torch.cat((tensor_a, tensor_b), dim=1)
print("\nTensor c1: concatenate along num of channels(dimension 1). Example for concatenating skip connections")
print(f"tensor_a.shape: {tensor_a.shape}")
print(f"tensor_b.shape: {tensor_b.shape}")
print(f"tensor_c1.shape: {tensor_c1.shape}")

# c2: concatenate tensor_a and tensor_d along height (dim 2).
# Height: 16 + 30 = 46 -> tensor_c2 is (3, 2, 46, 16).
print("\nTensor c2: concatenate along height(dimension 2). You will never need this LOL. Why will you want to increase the size of the image")
tensor_d = torch.randn(3, 2, 30, 16)
tensor_c2 = torch.cat((tensor_a, tensor_d), dim=2)
print(f"tensor_a.shape:  {tensor_a.shape}")
print(f"tensor_d.shape:  {tensor_d.shape}")
print(f"tensor_c2.shape: {tensor_c2.shape}")

# c3: concatenate tensor_a and tensor_e along the batch dimension (dim 0).
# Batch: 3 + 105 = 108 -> tensor_c3 is (108, 2, 16, 16).
print("\nTensor c3: concatenate along batch samples(dimension 0). Example when u are increasing the samples in a batch")
tensor_e = torch.randn(105, 2, 16, 16)
tensor_c3 = torch.cat((tensor_a, tensor_e), dim=0)
print(f"tensor_a.shape:  {tensor_a.shape}")
print(f"tensor_e.shape:  {tensor_e.shape}")
print(f"tensor_c3.shape: {tensor_c3.shape}")