# Lazy Tensor
- High level overview https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/

### This colab is HEAVILY inspired by this tutorial:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md

In [1]:
import torch

# Conditional logic is hard for hardware vendors to account for

def add_two_maybe(t: torch.Tensor, maybe: torch.Tensor):
    """Return ``t + 2`` when ``maybe`` is truthy, otherwise ``t`` itself.

    The branch is decided by the Python truth value of ``maybe``, i.e. this
    is data-dependent control flow: which value is returned depends on the
    contents of ``maybe`` at call time.
    """
    result = t
    if maybe:
        result = t + 2
    return result

In [2]:
# Let's use our existing tracing system (torch.jit.trace)
t = torch.ones(1)
maybe_false = torch.BoolTensor([0])
good_inputs = (t, maybe_false)
# Trace add_two_maybe using example inputs whose flag is False.
jit = torch.jit.trace(add_two_maybe, good_inputs)
# let's check that the results match with eager
assert jit(*good_inputs) == add_two_maybe(*good_inputs)

  if maybe:


In [3]:
# Call the traced function with a True flag and compare against eager.
# The recorded output of this cell is an AssertionError: the traced result
# does not match the eager result for this input.
maybe_true = torch.BoolTensor([1])
assert jit(t, maybe_true) == add_two_maybe(t, maybe_true)

AssertionError: 

In [4]:
# Print the last executed optimized JIT graph:
print(torch.jit.last_executed_optimized_graph())

# There is no `if` statement in the graph; it just returns the else path

graph(%t : Tensor,
      %maybe : Tensor):
  return (%t)



In [5]:
# Lazy tensors to the rescue
# The lazy device remembers which aten ops were called with which inputs, as opposed to
# eagerly executing them
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


# Lazy imports
import torch._lazy
import torch._lazy.ts_backend
import torch._lazy.metrics
# Initialize the lazy tensor `ts_backend`; this runs before any cell below
# moves tensors to the "lazy" device.
torch._lazy.ts_backend.init()

In [6]:
# A virtual "lazy" device
dev = "lazy"
# Same inputs as the eager/JIT example above, but placed on the lazy device.
t_lazy = torch.ones(1).to(dev)
maybe_false_lazy = torch.BoolTensor([0]).to(dev)
# Run the unmodified Python function on the lazy tensors (False flag).
lazy_result = add_two_maybe(t_lazy, maybe_false_lazy)

In [7]:
# Printing triggers execution of the op
print(lazy_result)
# Copy the lazy result to the CPU and compare it with the eager result.
assert lazy_result.cpu() == add_two_maybe(t, maybe_false)

tensor([1.], device='lazy:0')


In [8]:
# Now for the case that the JIT trace couldn't handle (flag is True):
maybe_true_lazy = torch.BoolTensor([1]).to(dev)
lazy_result = add_two_maybe(t_lazy, maybe_true_lazy)
# Compare the lazy result (copied to the CPU) with the eager result.
assert lazy_result.cpu() == add_two_maybe(t, maybe_true)

## Downsides
- Overhead for backends to translate aten ops to a lower-level representation for the hardware
- The payoff depends on how dynamic the model is. The less dynamic the model, the greater the reward for generating a trace and compiling it; if the model is highly dynamic, a non-trivial amount of time will be spent re-tracing and re-compiling

In [9]:
class Net(nn.Module):
    """Small convolutional classifier with two conv layers and two linear layers.

    The first linear layer takes 9216 input features, which corresponds to
    64 channels of 12x12 after the two 3x3 convolutions and the 2x2 max
    pool (e.g. for 1x28x28 inputs). The network emits 10 log-probabilities
    per sample.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 32 -> 64 channels, 3x3 kernels, stride 1.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Dropout applied after pooling (p=0.25) and after the hidden linear layer (p=0.5).
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Classifier head: 9216 -> 128 -> 10.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Flatten everything except the batch dimension before the linear layers.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [28]:
def train(log_interval, model, device, train_loader, optimizer, epoch):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        log_interval: Print a progress line every ``log_interval`` batches.
        model: Module to train; it is switched to training mode first.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer used to update the model parameters.
        epoch: Epoch number, used only in the progress line.
    """
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        # The forward pass is traced.
        predictions = model(inputs)
        loss = F.nll_loss(predictions, labels)
        # The backward pass is traced as well.
        loss.backward()
        optimizer.step()

        # This is a regular training loop except for this call, which
        # instructs Lazy Tensor to break up the current trace
        # and start executing it asynchronously.
        # It could be called less often (e.g. only every second batch), but
        # there is no need to capture multiple forward/backward passes in
        # one trace.
        torch._lazy.mark_step()

        if step % log_interval != 0:
            continue

        # Printing is a blocking call: execution pauses so that the loss
        # can be computed and printed.
        samples_seen = step * len(inputs)
        samples_total = len(train_loader.dataset)
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, samples_total, percent_done, loss.item()))

In [None]:
# Hyperparameters and settings for the training run below.
bsz = 64
device = 'lazy'
epochs = 14
log_interval = 10
lr = 1
gamma = 0.7
train_kwargs = {'batch_size': bsz}


# Convert images to tensors and normalize them
# (the constants are presumably the MNIST mean/std - TODO confirm).
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
dataset1 = datasets.MNIST('./data', train=True, download=True,
                    transform=transform)

# my computer is fast but not that fast, so we shorten the data to the first 6000 images
dataset1.data = dataset1.data[:6000,:,:]
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)



# Move the model to the lazy device
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

# Train for `epochs` epochs, stepping the learning-rate scheduler once per epoch.
for epoch in range(1, epochs + 1):
    train(log_interval, model, device, train_loader, optimizer, epoch)
    scheduler.step()

## Caveats
- Op coverage is not complete: only the top 100 ops are covered. However, when an unsupported op is hit, the current trace will break: it will concretize all the inputs to the unsupported op, run the op on a supported device, and then shift everything back to the lazy device.


In [38]:
# Reset the lazy tensor metrics, then run one more training epoch so the
# counter names printed below come from that run.
torch._lazy.metrics.reset()
train(log_interval, model, device, train_loader, optimizer, 1)

# Any op with an aten prefix is not supported on the lazy tensor device
from pprint import pprint
pprint(torch._lazy.metrics.counter_names())

['CachedCompile',
 'CreateLtcTensor',
 'DestroyLtcTensor',
 'DeviceDataCacheMiss',
 'MarkStep',
 'UncachedCompile',
 'aten::_local_scalar_dense',
 'lazy::_copy_from',
 'lazy::_log_softmax',
 'lazy::_log_softmax_backward_data',
 'lazy::_to_copy',
 'lazy::add',
 'lazy::addcmul',
 'lazy::addmm',
 'lazy::convolution',
 'lazy::convolution_backward',
 'lazy::div',
 'lazy::fill_',
 'lazy::max_pool2d_with_indices',
 'lazy::max_pool2d_with_indices_backward',
 'lazy::mm',
 'lazy::mul',
 'lazy::native_dropout',
 'lazy::native_dropout_backward',
 'lazy::nll_loss_backward',
 'lazy::nll_loss_forward',
 'lazy::relu',
 'lazy::sqrt',
 'lazy::sum',
 'lazy::t',
 'lazy::threshold_backward',
 'lazy::view',
 'lazy::zero_functional']


In [39]:
# While mark_step executes ops asynchronously, wait_device_ops() is a blocking op
torch._lazy.wait_device_ops()

### Blog on implementation with torch xla
https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/