<a href="https://colab.research.google.com/github/maxtost10/Playground-vs/blob/main/Linear_LSTM_and_Masking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import numpy as np

In [2]:
x_batch = torch.rand(1, 4, 2)
x_batch.shape

torch.Size([1, 4, 2])

In [3]:
torch.zeros(x_batch.shape[0], 3, x_batch.shape[2])

tensor([[[0., 0.],
         [0., 0.],
         [0., 0.]]])

In [8]:
ll = torch.nn.Linear(8, 1)

# Like this we can build the tensor but it will not be possible to use it with autograd

In [42]:
import torch

x_batch = torch.rand(1, 4, 2)

T = x_batch.shape[1]
results = []                    # a Python list

for n in range(T + 1):
    # copy and zero-out positions [n:] along dim=1
    masked = x_batch.clone()
    masked[:, n:, :] = 0

    # flatten and evaluate ll
    val = ll(masked.view(-1))
    print(val)

    results.append(val)

# stack into a 1D tensor
output = torch.tensor(results)
print("output shape:", output.shape)   # torch.Size([25])

tensor([0.0483], grad_fn=<ViewBackward0>)
tensor([-0.0707], grad_fn=<ViewBackward0>)
tensor([-0.0354], grad_fn=<ViewBackward0>)
tensor([0.0631], grad_fn=<ViewBackward0>)
tensor([0.1893], grad_fn=<ViewBackward0>)
output shape: torch.Size([5])


In [43]:
# Nach dem Ausführen des Codes
output.sum().backward()
print("x_batch.grad:", x_batch.grad)  # Gibt RuntimeError :(

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

# Crash course torch masking

In [19]:
import torch

# Beispiel-Daten: Batch von 3 Sequenzen je Länge 5
x = torch.arange(15).view(3,5)
# x = tensor([[ 0,  1,  2,  3,  4],
#             [ 5,  6,  7,  8,  9],
#             [10, 11, 12, 13, 14]])

# Maske: nur erste 3 Token jeder Sequenz behalten
mask = torch.arange(5) < 3
mask = mask.repeat(3, 1)
print(mask)


tensor([[ True,  True,  True, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True, False, False]])


In [23]:
# Ergebnis: maskiert (gefiltert) nur die True-Positionen,
# liefert flach die Werte [0,1,2, 5,6,7, 10,11,12]
selected = x[mask].view(3, -1)
print(selected)


tensor([[ 0,  1,  2],
        [ 5,  6,  7],
        [10, 11, 12]])

In [24]:
mask = mask.float()
print(x*mask)

tensor([[ 0.,  1.,  2.,  0.,  0.],
        [ 5.,  6.,  7.,  0.,  0.],
        [10., 11., 12.,  0.,  0.]])


In [34]:
mask = mask == 1
x.masked_fill(~mask, 111)

tensor([[  0,   1,   2, 111, 111],
        [  5,   6,   7, 111, 111],
        [ 10,  11,  12, 111, 111]])

# With these tools we can now build a differentiable tensor

In [40]:
import torch

# Beispiel-Batch mit requires_grad, damit x_batch.grad später nicht None ist
x_batch = torch.rand(1, 4, 2, requires_grad=True)

T = x_batch.shape[1]
results = []  # Liste für deine ll()-Rückgaben, jeweils 0-dim Tensoren

for n in range(T + 1):
    # Baue eine Maske: für jede Zeile ok, für Spalten ≥ n nullen
    # Das ist fully differentiable
    mask = (torch.arange(T)
        .unsqueeze(0)   # Form (1, T)
        .lt(n)          # True für Spalten < n, sonst False
    ).unsqueeze(-1).float()        # in FloatTensor: 1.0 oder 0.0
    masked = x_batch * mask  # Form (8, T, 3)

    # ll liefert einen 0-dim Tensor, der ins Gradient-Graph eingebunden ist
    val = ll(masked.view(-1))
    print(val)
    results.append(val)

# Hier verbinden wir alle 0-dim-Tensoren zu einem 1-dim-Tensor
output = torch.stack(results)  # shape: (T+1,)

# Jetzt kann man problemlos loss = output.sum(); loss.backward()
print("Output shape:", output.shape)


tensor([0.0483], grad_fn=<ViewBackward0>)
tensor([0.0205], grad_fn=<ViewBackward0>)
tensor([-0.0024], grad_fn=<ViewBackward0>)
tensor([0.1152], grad_fn=<ViewBackward0>)
tensor([0.3323], grad_fn=<ViewBackward0>)
Output shape: torch.Size([5, 1])


In [41]:
# Nach dem Ausführen des Codes
output.sum().backward()
print("x_batch.grad:", x_batch.grad)  # Sollte nicht None sein, wenn differenzierbar

x_batch.grad: tensor([[[-0.6544, -0.1412],
         [ 0.6882, -0.6431],
         [-0.1688,  0.5499],
         [ 0.0247,  0.2187]]])
