# Tensor Graph Optimization in PyTorch

## Introduction

This notebook demonstrates optimization rules for computation graphs in PyTorch. These rules identify and eliminate redundant operations, improving computational efficiency without changing the final results.

In [3]:
import torch
from time import time

## Rule 1: Shape-Changing Operations Before Reductions

Rule Explanation:

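Shape-changing operations (e.g., reshape, transpose) applied before a full reduction (e.g., sum or mean with no dim argument) are redundant, because a full reduction aggregates every element regardless of how the elements are arranged. This does not hold for reductions along a specific dimension, whose result depends on the shape.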

In [11]:
# Input tensor
x = torch.randn(64, 64)

# Redundant reshape before reduction
print("Before Optimization:")
start = time()
x_reshaped = x.view(-1, 64*64)  # Change shape to (1, 4096)
result = x_reshaped.sum()
end = time()
print(f"Result: {result}")
print(f"Time taken for computation: {(end - start)*1000} ms")

# Optimized version
print("\nAfter Optimization:")
start = time()
result_optimized = x.sum()
end = time()
print(f"Result: {result_optimized}")
print(f"Time taken for computation: {(end - start)*1000} ms")

Before Optimization:
Result: 80.32412719726562
Time taken for computation: 0.3600120544433594 ms

After Optimization:
Result: 80.32412719726562
Time taken for computation: 0.06413459777832031 ms


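- The view operation changes the shape to (1, 64*64), but the sum operation ignores this shape.
- Applying sum directly skips the extra view call. Because view only creates new tensor metadata and copies no data, the saving is a small dispatch overhead; the larger gap in the single-run timings above most likely reflects first-call cost.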
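The rule depends on the reduction covering all elements. The following is a minimal sketch, assuming a small integer-valued tensor so that float32 sums are exact; it checks that a full sum is unchanged by view and transpose, while a sum along dim 0 is not.

```python
import torch

# Small integer-valued tensor so that float32 sums are exact.
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Full reductions: the arrangement of the elements does not matter.
full = x.sum()
assert torch.equal(x.view(-1).sum(), full)
assert torch.equal(x.t().sum(), full)

# Reductions along a dimension: the arrangement does matter.
sum_dim0 = x.sum(dim=0)          # reduces over the 3 rows, shape (4,)
sum_dim0_t = x.t().sum(dim=0)    # reduces over the 4 columns, shape (3,)
print(sum_dim0.shape, sum_dim0_t.shape)
```

A graph rewriter applying this rule would therefore need to check that the reduction has no dim argument before dropping the preceding shape change.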

## Rule 2: Redundant Element-Wise Operations

Rule Explanation:
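- Consecutive element-wise operations (e.g., addition, multiplication, exponentiation) can be fused into a single kernel to reduce execution time and the memory traffic caused by intermediate tensors. In eager mode, writing the chain as one expression does not fuse anything by itself: each operator still runs separately and allocates its own output. Actual fusion requires a compiler such as torch.compile.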

In [13]:
# Input tensor
x = torch.randn(4, 4)

# Redundant element-wise operations
print("Before Optimization:")
start = time()
y = torch.exp(x)
z = y * 2
result = z + 1
end = time()
print(result)
print(f"Time taken for computation: {(end - start)*1000} ms")

# Optimized version
# NOTE: in eager mode this single expression still runs exp, mul, and add as
# three separate operations; it only avoids naming the intermediates.
print("\nAfter Optimization:")
start = time()
result_optimized = torch.exp(x) * 2 + 1
end = time()
print(result_optimized)
print(f"Time taken for computation: {(end - start)*1000} ms")

Before Optimization:
tensor([[ 3.8547, 10.6835,  3.8553,  2.1362],
        [ 1.9351,  4.4751,  9.1259,  8.2489],
        [ 5.7428,  3.9397,  4.6948,  3.8339],
        [ 1.3104,  4.2059,  8.3593,  1.2723]])
Time taken for computation: 0.5605220794677734 ms

After Optimization:
tensor([[ 3.8547, 10.6835,  3.8553,  2.1362],
        [ 1.9351,  4.4751,  9.1259,  8.2489],
        [ 5.7428,  3.9397,  4.6948,  3.8339],
        [ 1.3104,  4.2059,  8.3593,  1.2723]])
Time taken for computation: 0.07748603820800781 ms
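As a hedged illustration of what can be saved without a compiler, the sketch below compares the step-by-step form, the single-expression form, and an in-place variant. The in-place chain is not fusion; it only reuses the tensor allocated by exp instead of allocating two more intermediates. Real kernel fusion would need something like torch.compile (PyTorch 2.x), which is not used here.

```python
import torch

x = torch.randn(4, 4)

# Step-by-step and single-expression forms: the same three eager operations.
y = torch.exp(x)
z = y * 2
stepwise = z + 1
single_expr = torch.exp(x) * 2 + 1

# In-place variant: exp allocates one tensor, mul_ and add_ then reuse it.
in_place = torch.exp(x).mul_(2).add_(1)

assert torch.equal(stepwise, single_expr)
assert torch.allclose(stepwise, in_place)
```

The in-place form is only safe when the intermediate is not needed elsewhere, for example by autograd or by another consumer of the same tensor.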


## Rule 3: Consecutive Shape Transformations

Rule Explanation:
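- Consecutive shape-transforming operations (e.g., reshape, permute, transpose) can often be combined into a single transformation: two permutes compose into one permute, and a reshape followed by another reshape reduces to the last one. A permute followed by a reshape, as in the example below, generally cannot be merged into one call, because the reshape has to copy the non-contiguous permuted data.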

In [14]:
# Input tensor
x = torch.randn(2, 3, 4)

# Consecutive transformations
print("Before Optimization:")
start = time()
x_transformed = x.permute(2, 0, 1).reshape(4, 6)
end = time()
print(x_transformed.shape)
print(f"Time taken for computation: {(end - start)*1000} ms")

# Optimized version
# NOTE: this expression is identical to the one above, so nothing has been
# merged here; the timing gap in the output cannot come from the rewrite.
print("\nAfter Optimization:")
start = time()
x_optimized = x.permute(2, 0, 1).reshape(4, 6)
end = time()
print(x_optimized.shape)
print(f"Time taken for computation: {(end - start)*1000} ms")

Before Optimization:
torch.Size([4, 6])
Time taken for computation: 0.46706199645996094 ms

After Optimization:
torch.Size([4, 6])
Time taken for computation: 0.1068115234375 ms
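A chain of shape transformations can be merged only when the composite is itself expressible as one operation. The sketch below shows two such cases: two permutes composed into one, and two reshapes reduced to the last. The helper compose_permutations is hypothetical, introduced here for illustration and not part of PyTorch.

```python
import torch

def compose_permutations(first, second):
    """Hypothetical helper: the single permutation equivalent to
    .permute(first).permute(second)."""
    return tuple(first[i] for i in second)

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# Two permutes collapse into one.
two_step = x.permute(1, 0, 2).permute(2, 1, 0)
merged = compose_permutations((1, 0, 2), (2, 1, 0))
one_step = x.permute(*merged)
assert torch.equal(two_step, one_step)

# Consecutive reshapes: only the final shape matters.
assert torch.equal(x.reshape(6, 4).reshape(4, 6), x.reshape(4, 6))
```

Both rewrites preserve element order exactly, so torch.equal can be used to confirm the merged form before substituting it.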


## Rule 4: Eliminating No-Op Transformations

Rule Explanation:
- No-op transformations (e.g., reshaping a tensor to the shape it already has, or permuting its axes into their existing order) can be removed because they do not alter the tensor.

In [15]:
# Input tensor
x = torch.randn(4, 4)

# No-op reshape
print("Before Optimization:")
start = time()
x_no_op = x.view(4, 4)
result = x_no_op.mean()
end = time()
print(result)
print(f"Time taken for computation: {(end - start)*1000} ms")

# Optimized version
print("\nAfter Optimization:")
start = time()
result_optimized = x.mean()
end = time()
print(result_optimized)
print(f"Time taken for computation: {(end - start)*1000} ms")

Before Optimization:
tensor(0.4144)
Time taken for computation: 0.9541511535644531 ms

After Optimization:
tensor(0.4144)
Time taken for computation: 0.059604644775390625 ms


- The view operation in the first case does nothing since the shape remains unchanged.
- Removing it leads to the same result while eliminating unnecessary steps.
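One way to recognize a no-op transformation, sketched here under the assumption that comparing shape, strides, and the underlying data pointer is a sufficient check for this simple case, is to confirm that the result describes exactly the same memory layout as its input.

```python
import torch

x = torch.randn(4, 4)

same_shape = x.view(4, 4)      # reshape to the shape x already has
same_order = x.permute(0, 1)   # axes kept in their existing order

for candidate in (same_shape, same_order):
    # Same shape, same strides, same storage: nothing has changed.
    assert candidate.shape == x.shape
    assert candidate.stride() == x.stride()
    assert candidate.data_ptr() == x.data_ptr()
    assert torch.equal(candidate, x)

# The reduction therefore gives the same value with or without the no-op.
assert torch.equal(same_shape.mean(), x.mean())
```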

## Rule 5: Simplifying Reductions

Rule Explanation:
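- Reductions applied consecutively or redundantly can often be simplified into a single reduction, provided the simplified form computes the same value. For example, a reduction applied to a result that is already a scalar is a no-op and can be dropped.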

In [16]:
# Input tensor
x = torch.randn(4, 4)

# Redundant reductions
print("Before Optimization:")
start = time()
result = x.sum().mean()
end = time()
print(result)
print(f"Time taken for computation: {(end - start)*1000} ms")

# Optimized version
# NOTE: x.mean() divides by the 16 elements, so it is not equivalent to
# x.sum().mean(); the result-preserving simplification is x.sum().
print("\nAfter Optimization:")
start = time()
result_optimized = x.mean()
end = time()
print(result_optimized)
print(f"Time taken for computation: {(end - start)*1000} ms")

Before Optimization:
tensor(-0.0030)
Time taken for computation: 22.881269454956055 ms

After Optimization:
tensor(-0.0002)
Time taken for computation: 0.1373291015625 ms


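- In the first case, the sum already produces a scalar, and the mean of a scalar is the scalar itself, so x.sum().mean() equals x.sum(). The trailing mean is the redundant reduction.
- Replacing the chain with x.mean() changes the result: the two outputs above differ by a factor of 16, the number of elements (-0.0030 vs. -0.0002). The simplification that preserves the result is x.sum().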
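Which reduction chains can be collapsed is easy to check numerically. The sketch below uses an integer-valued tensor (an assumption made so that the float32 results are exact) and confirms a few identities that do hold, along with the one substitution that does not.

```python
import torch

x = torch.arange(16, dtype=torch.float32).reshape(4, 4)

# The mean of a scalar is the scalar: the trailing mean can be dropped.
assert torch.equal(x.sum().mean(), x.sum())

# A partial sum followed by a full sum equals one full sum.
assert torch.equal(x.sum(dim=0).sum(), x.sum())

# A mean of per-column means equals the full mean (all columns have equal size).
assert torch.allclose(x.mean(dim=0).mean(), x.mean())

# sum-then-mean is NOT the same as mean.
assert not torch.equal(x.sum().mean(), x.mean())
print(x.sum().item(), x.mean().item())
```

The mean-of-means identity relies on every group having the same number of elements, which is always true when reducing along a dimension of a dense tensor.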

## Conclusion

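Applying these rules simplifies computation graphs in PyTorch: redundant views, no-op transformations, and redundant reductions can be removed, which saves operator dispatch overhead and gives cleaner graph representations. Fusing element-wise chains can additionally reduce memory traffic, but that requires a compiler rather than a rewritten expression. The timings in this notebook are single-run wall-clock measurements in which the unoptimized version always runs first, so they include first-call overhead and overstate the gains.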
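For more stable numbers than a single time() call gives, a measurement can warm up first and then average over many iterations. The following is a minimal sketch for CPU tensors; bench is a hypothetical helper, and its warmup and iters defaults are arbitrary choices rather than recommended values.

```python
import time
import torch

def bench(fn, warmup=10, iters=1000):
    """Hypothetical helper: average wall-clock milliseconds per call of fn."""
    for _ in range(warmup):          # absorb first-call overhead
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) * 1000 / iters

x = torch.randn(64, 64)
t_view = bench(lambda: x.view(-1, 64 * 64).sum())
t_direct = bench(lambda: x.sum())
print(f"view + sum: {t_view:.4f} ms, sum only: {t_direct:.4f} ms")
```

On a GPU the timed region would also need explicit synchronization, since CUDA kernels launch asynchronously; that case is not covered by this sketch.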