## Expand vs Unsqueeze

Explanation:
- x: A 1D tensor of shape [4].
- x.unsqueeze(0): Adds a batch dimension at the beginning → shape [1, 4].
- x.expand((1, -1)): Adds a new leading dimension of size 1 → shape [1, 4]. Same shape and values as unsqueeze(0).
- x.expand((2, -1)): Broadcasts the single row to 2 rows → shape [2, 4].
- x.expand((3, -1)): Broadcasts the single row to 3 rows → shape [3, 4].

Important: expand does not copy data; it returns a view. The expanded dimension has stride 0, so all rows in x_expand2 and x_expand3 point to the same memory.
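You can check the no-copy claim by comparing storage addresses and strides. The minimal sketch below uses the same 1D `x` as the next cell and relies only on `Tensor.data_ptr()` and `Tensor.stride()`. A stride of 0 along dimension 0 is what lets every row read the same four elements.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
x_expand3 = x.expand((3, -1))

# Same starting address: the expanded view reuses x's storage, nothing is copied
print(x_expand3.data_ptr() == x.data_ptr())  # True

# Stride 0 along dimension 0: moving to the next "row" does not advance in memory
print(x_expand3.stride())  # (0, 1)

# A write through the original tensor shows up in every row of the view
x[0] = 100
print(x_expand3[:, 0])  # tensor([100, 100, 100])
```

Because the rows alias each other, the PyTorch documentation for `expand` recommends cloning an expanded tensor (for example `x_expand3.clone()`) before writing to it in place.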

In [None]:
import torch

print('----------x-----------')
x = torch.tensor([1, 2, 3, 4])  
print(x.shape)
print(x)

print('----------x_unsqueezed-----------')
# Add a batch dimension at the beginning (position 0)
x_unsqueezed = x.unsqueeze(0)
print(x_unsqueezed.shape)
print(x_unsqueezed)

print('----------x_expand1-----------')
# x is 1D, so expand() adds a new leading dimension 0 and sets its size to 1 (one row)
# -1 keeps the existing size of the last dimension, i.e. 4 columns
x_expand1 = x.expand((1, -1))
print(x_expand1.shape)
print(x_expand1)

print('----------x_expand2-----------')
x_expand2 = x.expand((2,-1))
print(x_expand2.shape)
print(x_expand2)

print('----------x_expand3-----------')
x_expand3 = x.expand((3, -1))
print(x_expand3.shape)
print(x_expand3)
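You can check the results above directly with `torch.equal`. For a 1D `x`, `unsqueeze(0)` and `expand((1, -1))` give the same shape and values. Expanding to more rows gives the same result whether or not you first add the leading dimension with `unsqueeze(0)`, because `expand()` prepends the new dimension itself. A short sketch:

```python
import torch

x = torch.tensor([1, 2, 3, 4])

# Adding a leading dimension of size 1: both routes give a [1, 4] tensor
print(torch.equal(x.unsqueeze(0), x.expand((1, -1))))  # True

# Broadcasting to 3 rows: expand() prepends the new dimension on its own,
# so an explicit unsqueeze(0) first gives the same result
a = x.expand((3, -1))
b = x.unsqueeze(0).expand((3, -1))
print(a.shape, b.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
print(torch.equal(a, b))  # True
```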

## Case where unsqueeze != expand 
Notice that in the cell below, unsqueeze(0) and expand((1, -1)) do not behave the same for a 2D tensor: unsqueeze(0) works, while expand((1, -1)) errors out.

What is happening here?

Your tensor x is:
```
tensor([[1, 2, 3],
        [4, 5, 6]])
```
and its shape is (2, 3).
Now, you try:
```
x_expand1 = x.expand((1, -1))
```
You're asking PyTorch to expand a (2, 3) tensor to (1, 3). That doesn't work because:
- expand() can only change the size of dimensions that have size 1.
- The first dimension has size 2, and you're asking for size 1, which is not allowed.

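**Rule of expand():** \
For each existing dimension, the requested size must be one of:
- -1 or the current size, which keeps the dimension unchanged.
- Any larger size, but only if the current size is 1 (a singleton dimension).

New dimensions can also be added, but only at the front, and each one needs an explicit size (-1 is not allowed there).
So expand() can't shrink dimensions or expand non-singleton (non-size-1) dimensions.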
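You can write the rule as a small checker. `can_expand` below is a hypothetical helper for illustration, not a PyTorch function, and it only handles positive sizes and -1. The sketch compares its answer with what `Tensor.expand()` actually does by catching the `RuntimeError` that PyTorch raises for an invalid target shape. It also encodes the detail that new dimensions can only be added at the front and need an explicit size.

```python
import torch

def can_expand(shape, target):
    """Hypothetical helper: apply the expand() rules to decide whether a
    tensor of size `shape` can be expanded to size `target`."""
    # expand() never removes dimensions
    if len(target) < len(shape):
        return False
    n_new = len(target) - len(shape)
    # New dimensions go at the front and need an explicit size (-1 is not allowed)
    if any(t == -1 for t in target[:n_new]):
        return False
    # Existing dimensions line up with the trailing entries of `target`
    for old, new in zip(shape, target[n_new:]):
        if new == -1 or new == old:
            continue      # dimension kept as-is
        if old != 1:
            return False  # only a size-1 dimension can be broadcast
    return True

def torch_allows(shape, target):
    """Ask PyTorch directly: does expand() succeed?"""
    try:
        torch.zeros(shape).expand(target)
        return True
    except RuntimeError:
        return False

cases = [
    ((2, 3), (1, -1)),      # shrink a size-2 dimension: not allowed
    ((2, 3), (2, -1)),      # same size: allowed
    ((2, 3), (3, -1)),      # grow a size-2 dimension: not allowed
    ((2, 3), (1, -1, -1)),  # new leading dimension of size 1: allowed
    ((4,), (3, -1)),        # new leading dimension of size 3: allowed
    ((4,), (-1, -1)),       # -1 for a new dimension: not allowed
    ((2, 1), (2, 5)),       # broadcast a size-1 dimension: allowed
    ((2, 3), (3,)),         # fewer sizes than dimensions: not allowed
]
for shape, target in cases:
    print(shape, '->', target, can_expand(shape, target), torch_allows(shape, target))
```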

**Syntax: adding a batch dimension with expand**
```
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print('----------x_expand1-----------')

# Original x.shape = (2, 3)
# New shape (1, 2, 3):
# dimension 0 is new, with size 1 (the batch dimension)
# dimension 1: -1 keeps the original size 2
# dimension 2: -1 keeps the original size 3

x_expand1 = x.expand((1, -1, -1))
print(x_expand1.shape)
print(x_expand1)
```

**Output**
```
----------x_expand1-----------
torch.Size([1, 2, 3])
tensor([[[1, 2, 3],
         [4, 5, 6]]])
```
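The same idea extends to batches larger than 1. In the sketch below, which reuses the 2×3 `x` from this section, `expand((1, -1, -1))` matches `unsqueeze(0)`. `expand((4, -1, -1))` repeats the 2×3 matrix four times along a new leading dimension without copying it. The batch size 4 is an arbitrary choice for illustration.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Batch of one: expand with a new leading size-1 dimension matches unsqueeze(0)
print(torch.equal(x.expand((1, -1, -1)), x.unsqueeze(0)))  # True

# Batch of four (arbitrary size): every batch entry is the same 2x3 matrix
batch = x.expand((4, -1, -1))
print(batch.shape)  # torch.Size([4, 2, 3])
print(all(torch.equal(batch[i], x) for i in range(4)))  # True

# Stride 0 on the batch dimension: all four entries share x's memory
print(batch.stride())  # (0, 3, 1)
```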


In [None]:
print('----------x-----------')
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x.shape)
print(x)

print('----------x_unsqueezed-----------')
# Add a batch dimension at the beginning (position 0)
x_unsqueezed = x.unsqueeze(0)
print(x_unsqueezed.shape)
print(x_unsqueezed)

print('----------x_expand1-----------')
# Dimension 0: asks for 1 row, but x already has 2 rows (not a size-1 dimension)
# Dimension 1: -1 keeps the 3 columns
# i.e. this asks PyTorch to turn a (2, 3) tensor into (1, 3), so expand() raises a RuntimeError
try:
    x_expand1 = x.expand((1, -1))
    print(x_expand1.shape)
    print(x_expand1)
except RuntimeError as e:
    print('RuntimeError:', e)

print('----------x_expand2-----------')
# Dimension 0 already has size 2, so this succeeds and returns a (2, 3) view of x
x_expand2 = x.expand((2, -1))
print(x_expand2.shape)
print(x_expand2)

print('----------x_expand3-----------')
# Dimension 0 has size 2 (not 1), so it cannot be expanded to 3 either
try:
    x_expand3 = x.expand((3, -1))
    print(x_expand3.shape)
    print(x_expand3)
except RuntimeError as e:
    print('RuntimeError:', e)

In [None]:
print('----------x-----------')
x = torch.tensor([[1,2,3], [4,5,6]])  
print(x.shape)
print(x)

print('----------x_expand1-----------')
# Original shape = (2, 3)
# New shape (1, 2, 3):
# dimension 0 is new, with size 1 (the batch dimension)
# dimension 1: -1 keeps the original size 2
# dimension 2: -1 keeps the original size 3
x_expand1 = x.expand((1, -1, -1))
print(x_expand1.shape)
print(x_expand1)


In [None]:
print('----------x-----------')
x = torch.tensor([[1,2,3], [4,5,6]])  
print(x.shape)
print(x)

print('----------x_unsqueezed-----------')
# Add a new size-1 dimension at position 1: (2, 3) -> (2, 1, 3)
x_unsqueezed = x.unsqueeze(1)
print(x_unsqueezed.shape)
print(x_unsqueezed)
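Inserting a size-1 dimension in the middle lets you broadcast along an axis that `expand()` cannot add on its own, because `expand()` only adds new dimensions at the front. In this sketch, the repeat count 4 is chosen arbitrarily:

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# (2, 3) -> (2, 1, 3): a singleton dimension now sits between rows and columns
x_mid = x.unsqueeze(1)

# (2, 1, 3) -> (2, 4, 3): broadcast the singleton; -1 keeps the other two sizes
x_rep = x_mid.expand((-1, 4, -1))
print(x_rep.shape)  # torch.Size([2, 4, 3])

# Every entry along the new axis is a view of the matching row of x
print(torch.equal(x_rep[0, 3], x[0]), torch.equal(x_rep[1, 0], x[1]))  # True True
```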

In [None]:
print('----------x-----------')
x = torch.tensor([[1,2,3], [4,5,6]])  
print(x.shape)
print(x)

print('----------x_unsqueezed-----------')
# Add a new size-1 dimension at the end (position -1): (2, 3) -> (2, 3, 1)
x_unsqueezed = x.unsqueeze(-1)
print(x_unsqueezed.shape)
print(x_unsqueezed)
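Positions passed to `unsqueeze()` can be negative. For a tensor with `x.dim()` dimensions, a negative position `d` is treated as `d + x.dim() + 1`, so `-1` means "after the last dimension". The sketch below collects the resulting shape for every valid position of the 2×3 `x`:

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Valid positions run from -(x.dim() + 1) to x.dim(), i.e. -3 to 2 here
shapes = {d: tuple(x.unsqueeze(d).shape) for d in range(-x.dim() - 1, x.dim() + 1)}
print(shapes)
# {-3: (1, 2, 3), -2: (2, 1, 3), -1: (2, 3, 1), 0: (1, 2, 3), 1: (2, 1, 3), 2: (2, 3, 1)}
```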