# Module 2.3: Combining Tensors

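Learn how to combine tensors (`cat`, `stack`, `hstack`, `vstack`), split them back apart (`chunk`, `split`), and let broadcasting match shapes automatically.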

In [None]:
import torch

# Report which PyTorch release this notebook is running against.
print("PyTorch version:", torch.__version__)

## 1. torch.cat() - Concatenate along existing dimension

In [None]:
# Two 2x2 integer matrices that the torch.cat() cells below join together.
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

print(f"Tensor a:\n{a}")
print(f"\nTensor b:\n{b}")

In [None]:
# dim=0 is the row axis: b's rows are appended below a's rows,
# so the two 2x2 matrices defined above give a (4, 2) result.
cat_dim0 = torch.cat((a, b), dim=0)
print(f"Concatenate dim=0 (stack vertically):\n{cat_dim0}")
print(f"Shape: {cat_dim0.shape}")

In [None]:
# dim=1 is the column axis: b's columns are placed to the right of a's columns,
# so the two 2x2 matrices defined above give a (2, 4) result.
cat_dim1 = torch.cat((a, b), dim=1)
print(f"Concatenate dim=1 (stack horizontally):\n{cat_dim1}")
print(f"Shape: {cat_dim1.shape}")

In [None]:
# torch.cat accepts any number of tensors; here a third 2x2 matrix is joined row-wise
# with a and b.
c = torch.tensor([[9, 10], [11, 12]])
cat_three = torch.cat((a, b, c), dim=0)
print(f"Concatenate three tensors:\n{cat_three}")

## 2. torch.stack() - Stack along NEW dimension

In [None]:
# Three equal-length 1-D vectors that the torch.stack() cell below combines.
x, y, z = torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])

print(f"x: {x}")
print(f"y: {y}")
print(f"z: {z}")

In [None]:
# torch.stack inserts a NEW axis at `dim`; all inputs must share the same shape.
# dim=0: each input vector becomes one row of the result.
stacked_0 = torch.stack((x, y, z), dim=0)
print(f"Stack dim=0:\n{stacked_0}")
print(f"Shape: {stacked_0.shape}")  # (3, 3)

# dim=1: each input vector becomes one column, i.e. the transpose of stacked_0.
stacked_1 = torch.stack((x, y, z), dim=1)
print(f"\nStack dim=1:\n{stacked_1}")
print(f"Shape: {stacked_1.shape}")  # (3, 3)

### Difference between cat and stack:

- **cat**: joins along EXISTING dimension (no new dimension)
- **stack**: creates a NEW dimension

Example with 1D tensors `[1,2,3]` and `[4,5,6]`:
- `cat(dim=0)` -> `[1,2,3,4,5,6]` shape: `(6,)`
- `stack(dim=0)` -> `[[1,2,3],[4,5,6]]` shape: `(2, 3)`

## 3. Practical example: Creating batches

In [None]:
# Four random stand-ins for grayscale 28x28 images, each laid out as [channels, height, width].
image1, image2, image3, image4 = (torch.randn(1, 28, 28) for _ in range(4))

# Stacking on a new leading axis produces the [batch, channels, height, width] layout.
batch = torch.stack((image1, image2, image3, image4), dim=0)
print(f"Single image shape: {image1.shape}")
print(f"Batch shape: {batch.shape}")  # [batch, channels, height, width]

## 4. torch.chunk() and torch.split() - Divide tensors

In [None]:
# The integers 0..11 arranged row by row into a 4x3 matrix; input for the chunk/split cells.
data = torch.arange(12).reshape((4, 3))
print(f"Original data:\n{data}")

In [None]:
# chunk: request a number of pieces along `dim`. The pieces are equal-sized only when the
# dimension length divides evenly; otherwise the last piece is smaller, and fewer pieces
# than requested may be returned. Here 4 rows / 2 chunks -> two (2, 3) tensors.
chunks = torch.chunk(data, chunks=2, dim=0)
print("Chunk into 2 parts (dim=0):")  # plain string: there is nothing to interpolate
for i, chunk in enumerate(chunks):
    print(f"  Chunk {i}:\n{chunk}")

In [None]:
# split: pass a list of sizes to get pieces of different lengths along `dim`.
# The sizes must add up to the dimension's length (1 + 3 == 4 rows here).
splits = torch.split(data, split_size_or_sections=[1, 3], dim=0)
print("Split [1, 3] (dim=0):")  # plain string: there is nothing to interpolate
for i, split in enumerate(splits):
    print(f"  Split {i}:\n{split}")

## 5. Convenience functions: hstack, vstack

In [None]:
# Dedicated names for the 1-D inputs: rebinding `a`/`b` here would replace the 2x2
# matrices from section 1, and re-running the torch.cat(dim=1) cell would then fail.
vec_a = torch.tensor([1, 2, 3])
vec_b = torch.tensor([4, 5, 6])

print(f"vec_a: {vec_a}")
print(f"vec_b: {vec_b}")

# hstack: for 1-D inputs the vectors are joined end to end -> shape (6,)
print(f"hstack: {torch.hstack([vec_a, vec_b])}")

# vstack: each 1-D input becomes one row -> shape (2, 3)
print(f"vstack:\n{torch.vstack([vec_a, vec_b])}")

In [None]:
# With 2-D inputs, hstack joins along columns (dim=1) and vstack joins along rows (dim=0).
m1 = torch.tensor([[1, 2],
                   [3, 4]])
m2 = torch.tensor([[5, 6],
                   [7, 8]])

print(f"Matrix m1:\n{m1}")
print(f"Matrix m2:\n{m2}")
print(f"\nhstack:\n{torch.hstack((m1, m2))}")
print(f"\nvstack:\n{torch.vstack((m1, m2))}")

## 6. Broadcasting (Automatic size matching)

In [None]:
# Broadcasting lets tensors of different shapes take part in one element-wise operation.
# Fresh names are used so the `a`/`b` tensors from earlier cells are not silently replaced.
column = torch.tensor([[1], [2], [3]])  # Shape: (3, 1)
row = torch.tensor([10, 20, 30, 40])  # Shape: (4,)

print(f"column (shape {column.shape}):\n{column}")
print(f"row (shape {row.shape}): {row}")

# No explicit expand is needed: (3, 1) + (4,) -> (3, 4), where entry [i, j] is
# column[i, 0] + row[j].
result = column + row
print(f"\ncolumn + row (shape {result.shape}):\n{result}")

### Broadcasting rules:

1. Align shapes from the right (compare trailing dimensions first)
2. Each pair of aligned dimensions must be equal, or one of them must be 1
3. Missing leading dimensions are treated as 1

**Example:** `(3, 1) + (4,)`
- Step 1: Align -> `(3, 1)` and `(1, 4)`
- Step 2: Expand -> `(3, 4)` and `(3, 4)`
- Step 3: Add element-wise

## Summary

### Combining
| Method | Description |
|--------|-------------|
| `torch.cat([a, b], dim)` | Join along existing dimension |
| `torch.stack([a, b], dim)` | Join along NEW dimension |
| `torch.hstack([a, b])` | Horizontal stack |
| `torch.vstack([a, b])` | Vertical stack |

### Splitting
| Method | Description |
|--------|-------------|
| `torch.chunk(x, n, dim)` | Split into n chunks of equal size (the last chunk is smaller if the size is not divisible by n) |
| `torch.split(x, sizes, dim)` | Split by specific sizes |

### Broadcasting
- Automatic size matching for operations
- Shapes aligned from right
- Size-1 dims are expanded

### Key Insight
- Use **STACK** to create batches from individual samples
- Use **CAT** to combine batches into larger batches

---
**Next:** Move to Module 3 to learn about autograd!