# 3 Diagonal Augment

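Given a 2-D tensor with $H$ rows and a non-negative integer $n$, return a tensor whose width is increased by $n$. Row $i$ (0-indexed) is shifted right by $\max(n - i, 0)$ columns, and the vacated positions are filled with zeros. When $n \ge H - 1$, this creates a triangle of padded zeros in the upper-left corner and another in the lower-right corner.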

```
[[1, 2, 3],
 [4, 5, 6],
 [7, 8, 9]]
```

$n = 2$ →

```
[[0, 0, 1, 2, 3],
 [0, 4, 5, 6, 0],
 [7, 8, 9, 0, 0]]
```

<div class="alert alert-block alert-warning">
<b>Note:</b> This is not a very good question statement. Needs to be reworked.
</div>

## Implement the Function

In [60]:
import torch

def diagonal_augment(t: torch.Tensor, n: int) -> torch.Tensor:
    """Exercise placeholder: widen ``t`` by ``n`` columns of zero padding.

    The finished function should return a tensor of shape ``(H, W + n)``
    in which row ``i`` of ``t`` starts at column ``n - i`` (see the worked
    example above). Until it is implemented, this stub returns ``None``.
    """
    return None  # TODO: replace this placeholder with an implementation

## Check Test Cases

In [None]:
test_func = diagonal_augment
test_cases = {
    # Given example
    (torch.tensor([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]]), 2):
        torch.tensor([[0, 0, 1, 2, 3],
                      [0, 4, 5, 6, 0],
                      [7, 8, 9, 0, 0]]),

    # n = 0 (no change)
    (torch.tensor([[1, 2],
                   [3, 4]]), 0):
        torch.tensor([[1, 2],
                      [3, 4]]),

    # 1x1, n = 3
    (torch.tensor([[7]]), 3):
        torch.tensor([[0, 0, 0, 7]]),

    # Single row, n = 2
    (torch.tensor([[1, 2, 3, 4]]), 2):
        torch.tensor([[0, 0, 1, 2, 3, 4]]),

    # Single column, n = 3
    (torch.tensor([[1],
                   [2],
                   [3],
                   [4]]), 3):
        torch.tensor([[0, 0, 0, 1],
                      [0, 0, 2, 0],
                      [0, 3, 0, 0],
                      [4, 0, 0, 0]]),

    # Inputs with n < height - 1 and n > 0 are deliberately not tested:
    # the problem statement does not pin down where rows with index > n
    # should start.

    # Rectangular 2x3, n = 1
    (torch.tensor([[1, 2, 3],
                   [4, 5, 6]]), 1):
        torch.tensor([[0, 1, 2, 3],
                      [4, 5, 6, 0]]),

    # Rectangular 3x2, n = 2
    (torch.tensor([[1, 2],
                   [3, 4],
                   [5, 6]]), 2):
        torch.tensor([[0, 0, 1, 2],
                      [0, 3, 4, 0],
                      [5, 6, 0, 0]]),

    # n larger than height - 1 (row i still starts at column n - i; width grows by n)
    (torch.tensor([[1, 2],
                   [3, 4]]), 5):
        torch.tensor([[0, 0, 0, 0, 0, 1, 2],
                      [0, 0, 0, 0, 3, 4, 0]]),

    # Includes negatives, n = 2
    (torch.tensor([[-1, -2],
                   [-3, -4],
                   [-5, -6]]), 2):
        torch.tensor([[0, 0, -1, -2],
                      [0, -3, -4, 0],
                      [-5, -6, 0, 0]]),

    # Includes zeros in input (should remain, distinct from padding), n = 1
    (torch.tensor([[0, 1, 0],
                   [2, 0, 3]]), 1):
        torch.tensor([[0, 0, 1, 0],
                      [2, 0, 3, 0]]),

    # Larger square, n = 3
    (torch.tensor([[1,  2,  3,  4],
                   [5,  6,  7,  8],
                   [9, 10, 11, 12],
                   [13,14, 15,16]]), 3):
        torch.tensor([[0, 0, 0, 1,  2,  3,  4],
                      [0, 0, 5, 6,  7,  8,  0],
                      [0, 9,10,11, 12, 0,  0],
                      [13,14,15,16, 0, 0,  0]]),
}

error_list = []
correct = 0
num = len(test_cases)
for k, v in test_cases.items():
    try:
        result = test_func(*k)
    except Exception as exc:  # top-level harness: record the failure, keep checking other cases
        error_list.append(f'Raised {exc!r} on input {k}. Expected {v}.')
        continue
    # Check type and shape explicitly: `result == v` broadcasts, so a
    # wrongly shaped result could otherwise pass or crash the loop.
    if (
        isinstance(result, torch.Tensor)
        and result.shape == v.shape
        and bool((result == v).all())
    ):
        correct += 1
    else:
        error_list.append(f'Received {result} on input {k}. Expected {v}.')

print(f'Test cases passed: {correct}/{num}')

if correct == num:
    print('Success!')
else:
    for e in error_list:
        print(e)

Test cases passed: 11/11
Success!


## Solutions

```python
def diagonal_augment(t: torch.Tensor, n: int) -> torch.Tensor:
    """Widen a 2-D tensor by ``n`` columns, shifting each row diagonally.

    Row ``i`` of ``t`` is written into a zero tensor of shape ``(H, W + n)``
    starting at column ``max(n - i, 0)``. For ``n >= H - 1`` this leaves a
    triangle of zeros in the upper-left and lower-right corners. Rows with
    index ``i > n`` are left-aligned instead of being wrapped around.

    Args:
        t: 2-D input tensor of shape ``(H, W)``.
        n: Number of columns to add; must be non-negative.

    Returns:
        A new tensor of shape ``(H, W + n)`` with ``t``'s dtype and device.
        The input is never returned or modified, even when ``n == 0``.

    Raises:
        ValueError: If ``t`` is not 2-D or ``n`` is negative.
    """
    if t.dim() != 2:
        raise ValueError(f'expected a 2-D tensor, got shape {tuple(t.shape)}')
    if n < 0:
        raise ValueError(f'n must be non-negative, got {n}')

    H, W = t.shape
    out = t.new_zeros((H, W + n))
    # Starting column of each row. Clamping at 0 keeps rows with i > n inside
    # the output instead of wrapping them back to the right edge.
    starts = (n - torch.arange(H, device=t.device)).clamp(min=0)
    cols = starts[:, None] + torch.arange(W, device=t.device)
    # Place every row at its target columns in one vectorized call.
    out.scatter_(1, cols, t)
    return out
```

This is what ChatGPT came up with:

```python
def diagonal_augment(t: torch.Tensor, n: int) -> torch.Tensor:
    """Copy each row of ``t`` into a zero canvas that is ``n`` columns wider.

    Row ``i`` begins at column ``n - i`` while ``i < n``; from row ``n`` on,
    rows begin at column 0. The canvas shares ``t``'s dtype and device.
    """
    rows, cols = t.shape
    canvas = t.new_zeros((rows, cols + n))

    for row_idx, row in enumerate(t):
        offset = n - row_idx if row_idx < n else 0
        canvas[row_idx, offset:offset + cols] = row

    return canvas
```