In [233]:
!pip install matplotlib tensorboard
import torch
import torch.nn.functional as F
import math
import torch.nn
import matplotlib.pyplot as plt
import numpy as np

Collecting tensorboard
  Downloading tensorboard-2.13.0-py3-none-any.whl (5.6 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0mm
Collecting tensorboard-data-server<0.8.0,>=0.7.0
  Using cached tensorboard_data_server-0.7.0-py3-none-any.whl (2.4 kB)
Collecting markdown>=2.6.8
  Using cached Markdown-3.4.3-py3-none-any.whl (93 kB)
Collecting absl-py>=0.4
  Using cached absl_py-1.4.0-py3-none-any.whl (126 kB)
Collecting protobuf>=3.19.6
  Downloading protobuf-4.23.0-cp37-abi3-macosx_10_9_universal2.whl (400 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m400.3/400.3 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m[31m5.2 MB/s[0m eta [36m0:00:01[0m
[?25hCollecting grpcio>=1.48.2
  Downloading grpcio-1.54.2-cp310-cp310-macosx_12_0_universal2.whl (8.6 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.6/8.6 

In [350]:
def plot_tensor(tensor):
    fig, ax = plt.subplots()
    im = ax.imshow(tensor, cmap='viridis') # You can use other colormaps like 'plasma', 'inferno', 'magma', etc.
    plt.show()

class Register(object):
    def __init__(self, name, size):
        self.name = name
        self.size = size
        self.offset = None

class Embedding(object):
    def __init__(self, tokens: list[str], registers: list[Register]):
        self.tokens = tokens
        self.token_map = { t: i for i,t in enumerate(tokens) }
        self.registers = registers
        self.register_map = {}
        self.register_size = 0
        
        if len(registers) == 0 or registers[0].name != 'pos':
            raise Exception("First register must be 'pos'") 
        
        offset = len(tokens)
        for reg in registers:
            reg.offset = offset
            offset += reg.size
            self.register_size += reg.size
            self.register_map[reg.name] = reg
            
        self.dim = len(tokens) + self.register_size

    def tokenize(self, string: str):
        return F.one_hot(torch.tensor([self.token_map[c] for c in string]), num_classes=len(self.tokens)).float()

    def embed(self, sequence):
        # We want to create additional space to store the registers
        extension_tensor = torch.zeros(*sequence.shape[:-1], self.register_size)

        # Encode position in the first extra embedding dimension
        for i in range(sequence.shape[0]):
            extension_tensor[i, 0] = math.sin(i*(2*math.pi)/100)
            extension_tensor[i, 1] = math.cos(i*(2*math.pi)/100)

        sequence = torch.cat((sequence, extension_tensor), dim=-1)

        return sequence
    
    def predict(self, sequence):
        return self.tokens[torch.argmax(sequence[-1,:len(self.tokens)])]

class AttentionLayer(torch.nn.Module):
    def __init__(self, instruction):
        super(AttentionLayer, self).__init__()
        self.instruction = instruction
        
        self.key = torch.nn.Parameter(instruction.key)
        self.value = torch.nn.Parameter(instruction.value)
        self.query = torch.nn.Parameter(instruction.query)
        
        self.softmax = torch.nn.Softmax(2)
        
    def forward(self, seq):
        batch_size, seq_length, dim = seq.shape
        
        query = seq @ self.query
        key = seq @ self.key
        value = seq @ self.value

        causal_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1)*-1e10
        norm = np.sqrt(dim)
        
        kq = self.softmax(query @ key.transpose(-2, -1) / norm + causal_mask)
        
        s = kq @ value
        
        return (seq + s)
    
    def size(self):
        return torch.numel(self.instruction.query) + torch.numel(self.instruction.query) + torch.numel(self.instruction.query)
    
    def reset(self):
        torch.nn.init.xavier_uniform_(self.key)
        torch.nn.init.xavier_uniform_(self.query)
        torch.nn.init.xavier_uniform_(self.value)
    
class MLPLayer(torch.nn.Module):
    def __init__(self, instruction, debug=False):
        super(MLPLayer, self).__init__()
        self.instruction = instruction
        self.debug = debug
        
        self.first_weights = torch.nn.Parameter(instruction.first_weights)
        self.first_bias = torch.nn.Parameter(instruction.first_bias)
        self.second_weights = torch.nn.Parameter(instruction.second_weights)
        self.second_bias = torch.nn.Parameter(instruction.second_bias)
        
        self.gelu = torch.nn.GELU()
        
    def forward(self, seq):
        a = self.gelu(seq @ self.first_weights + self.first_bias)
        if self.debug:
            plot_tensor(a)
        b = (a @ self.second_weights)
        if self.debug:
            plot_tensor(b)
        x = b + self.second_bias
        if self.debug:
            plot_tensor(x)
        return seq + x
    
    def size(self):
        return (torch.numel(self.instruction.first_weights) + torch.numel(self.instruction.first_bias) 
                + torch.numel(self.instruction.second_weights) + torch.numel(self.instruction.second_bias))
    
    def reset(self):
        torch.nn.init.xavier_uniform_(self.first_weights)
        torch.nn.init.zeros_(self.first_bias)
        torch.nn.init.xavier_uniform_(self.second_weights)
        torch.nn.init.zeros_(self.second_bias)
    
tokens = list('0123456789+= \n')
pos = Register('pos', 2)
left_pos = Register('left_pos', 2)
right_pos = Register('right_pos', 2)
out_pos = Register('out_pos', 2)
left_digit = Register('left', len(tokens))
right_digit = Register('right', len(tokens))
out_digit = Register('out', len(tokens))
final_digit = Register('final', len(tokens))
carry = Register('carry', 1)
distance = Register('distance', 1)

embedding = Embedding(tokens, [pos, left_pos, right_pos, out_pos, left_digit, right_digit, out_digit, carry, distance, final_digit])
        
class FindAndStore(AttentionLayer):
    def __init__(self, embedding: Embedding, token: str, register: Register):
        pos_reg = embedding.register_map['pos']
        
        # No matter the current token, we attend to if the attended token is the given token
        token_select = torch.zeros(embedding.dim, embedding.dim) - 1e10
        token_select[:, int(embedding.token_map[token])] = 1e10

        position_select = torch.zeros(embedding.dim, embedding.dim)
        position_select[pos_reg.offset, register.offset] = 1.0
        position_select[pos_reg.offset + 1, register.offset + 1] = 1.0
        
        self.key = torch.eye(embedding.dim)
        self.query = token_select
        self.value = position_select        
        
        super(FindAndStore, self).__init__(self)

ex = embedding.embed(embedding.tokenize('10+10=2111')) 
        
class GetRelativeToken(AttentionLayer):
    def __init__(self, embedding: Embedding, pos_reg: Register, steps: int, out: Register):
        tpos_reg = embedding.register_map['pos']
        
        position_select = torch.zeros(embedding.dim, embedding.dim)
        position_select[tpos_reg.offset, tpos_reg.offset] = 1e10
        position_select[tpos_reg.offset + 1, tpos_reg.offset + 1] = 1e10

        i = -steps
        sin = math.sin(i*(2*math.pi)/100)*1
        cos = math.cos(i*(2*math.pi)/100)*1

        rotation = torch.zeros(embedding.dim, embedding.dim)
        rotation[pos_reg.offset, tpos_reg.offset] = cos
        rotation[pos_reg.offset + 1, tpos_reg.offset] = -sin
        rotation[pos_reg.offset, tpos_reg.offset + 1] = sin
        rotation[pos_reg.offset + 1, tpos_reg.offset + 1] = cos
        #plot_tensor(rotation)
        
        token_copy = torch.zeros(embedding.dim, embedding.dim)
        for i in range(len(embedding.tokens)):
            token_copy[i, i + out.offset] = 1.0
            
        self.query = rotation
        self.key = position_select
        self.value = token_copy
        
        super(GetRelativeToken, self).__init__(self)
        
class Multiply(MLPLayer):
    def __init__(self, embedding: Embedding, left_token: Register, right_token: Register, carry: Register, out_token: Register):
        width = 10*10*2

        self.first_weights = torch.zeros(embedding.dim, width)
        self.first_bias = torch.zeros(width)
        
        self.second_weights = torch.zeros(width, embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)

        idx = 0
        for i in range(0, 10):
            for j in range(0, 10):
                self.first_weights[left_token.offset + i, idx] = 500
                self.first_weights[right_token.offset + j, idx] = 500
                self.first_weights[carry.offset, idx] = -1000
                self.first_bias[idx] = -900
                self.second_weights[idx, out_token.offset + (i + j) % 10] = 0.01
                if (i + j) >= 10:
                    self.second_weights[idx, carry.offset] = 0.01
                idx += 1
                
                self.first_weights[left_token.offset + i, idx] = 333
                self.first_weights[right_token.offset + j, idx] = 333
                self.first_weights[carry.offset, idx] = 333
                self.first_bias[idx] = -900
                self.second_weights[idx, out_token.offset + (i + j + 1) % 10] = 0.01 * (1.0/0.99)
                # If we need to carry, there was already a carry so we change nothing 
                if (i + j + 1) >= 10:
                    self.second_weights[idx, carry.offset] = 0.0
                # If we don't need to carry, we need to clear the carry bit,
                else:
                    self.second_weights[idx, carry.offset] = -0.01 * (1.0/0.99)
                idx += 1
                
        super(Multiply, self).__init__(self)

class Clear(MLPLayer):
    def __init__(self, embedding: Embedding, registers: list[Register]):
        self.first_weights = torch.zeros(embedding.dim, embedding.dim)
        self.first_bias = torch.zeros(embedding.dim)
        
        for reg in registers:
            for i in range(reg.size):
                self.first_weights[reg.offset + i, reg.offset + i] = 100.0
        
        self.second_weights = torch.zeros(embedding.dim, embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        for reg in registers:
            for i in range(reg.size):
                self.second_weights[reg.offset + i, reg.offset + i] = -0.01
                
        super(Clear, self).__init__(self)
                
class DiffPos(MLPLayer):
    def __init__(self, embedding: Embedding, left_pos: Register, right_pos: Register, distance: Register):
        self.first_weights = torch.zeros(embedding.dim, 2)
        self.first_bias = torch.zeros(2)
        
        # Note: it's important that the x and y are multiplied by different numbers,
        # otherwise 1, 2 computes to be the same as 2, 1
        
        self.first_weights[left_pos.offset, 0] = 1e2
        self.first_weights[left_pos.offset + 1, 0] = 1e3
        self.first_weights[right_pos.offset, 0] = -1e2
        self.first_weights[right_pos.offset + 1, 0] = -1e3
        
        self.first_weights[left_pos.offset, 1] = -1e2
        self.first_weights[left_pos.offset + 1, 1] = -1e3
        self.first_weights[right_pos.offset, 1] = 1e2
        self.first_weights[right_pos.offset + 1, 1] = 1e3
        
        self.second_weights = torch.zeros(2, embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        
        self.second_weights[0, distance.offset] = 1.0
        self.second_weights[1, distance.offset] = 1.0
        
        super(DiffPos, self).__init__(self)
        
class IsZero(MLPLayer):
    def __init__(self, embedding: Embedding, zero: Register):
        self.first_weights = torch.zeros(embedding.dim, 2)
        self.first_bias = torch.zeros(1)
        
        self.first_weights[zero.offset, 0] = -100.0
        self.first_weights[zero.offset, 1] = 1.0
        self.first_bias[0] = 10.0
        
        self.second_weights = torch.zeros(2, embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        
        self.second_weights[0, zero.offset] = 0.1
        self.second_weights[1, zero.offset] = -1.0
        self.second_bias[zero.offset] = 10.0
        
        super(IsZero, self).__init__(self)

class CopyOnMatch(MLPLayer):
    def __init__(self, embedding: Embedding, match: Register, src: Register, dst: Register):
        self.first_weights = torch.zeros(embedding.dim, src.size)
        self.first_bias = torch.ones(src.size)
        
        for i in range(src.size):
            self.first_weights[match.offset, i] = 100
            self.first_weights[src.offset + i, i] = 100
            self.first_bias[i] = -190
        
        self.second_weights = torch.zeros(src.size, embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        
        for i in range(src.size):
            self.second_weights[i, dst.offset + i] = 0.1
            
        super(CopyOnMatch, self).__init__(self)

class StepPosition(MLPLayer):
    def __init__(self, embedding: Embedding, positions: list[Register], offsets: list[Register]):

        rotation = torch.zeros(embedding.dim, embedding.dim)
        
        n = 0
        for reg in positions:
            i = offsets[n]
            sin = math.sin(-i*(2*math.pi)/100)*1
            cos = math.cos(-i*(2*math.pi)/100)*1
        
            rotation[reg.offset, reg.offset] = cos - 1
            rotation[reg.offset + 1, reg.offset] = -sin
            rotation[reg.offset, reg.offset + 1] = sin
            rotation[reg.offset + 1, reg.offset + 1] = cos - 1
            
            n += 1

        self.first_weights = rotation
        self.first_bias = torch.ones(embedding.dim)*1000
        
        self.second_weights = torch.eye(embedding.dim, embedding.dim)
        self.second_bias = torch.ones(embedding.dim)*-1000
        
        super(StepPosition, self).__init__(self)

In [351]:
import tqdm

class PaddedAddition(torch.nn.Module):
    def __init__(self):
        super(PaddedAddition, self).__init__()

        # Attention layer to search for the location of various tokens
        self.find_plus = FindAndStore(embedding, '+', left_pos)
        self.find_equal = FindAndStore(embedding, '=', right_pos)
        self.find_out = FindAndStore(embedding, '=', out_pos)

        # Fully connected layers to adjust positions by fixed amounts
        self.step_out = StepPosition(embedding, [out_pos], [-1])
        
        # Note: I've unrolled this because I was concerned the reuse of weights would do weird things with
        # backprop
        self.clear_lro_1 = Clear(embedding, [left_digit, right_digit, out_digit])
        self.step_pos_1 = StepPosition(embedding, [left_pos, right_pos, out_pos], [-1, -1, 1])
        self.read_left_1 = GetRelativeToken(embedding, left_pos, 0, left_digit)
        self.read_right_1 = GetRelativeToken(embedding, right_pos, 0, right_digit)
        self.multiply_1 = Multiply(embedding, left_digit, right_digit, carry, out_digit)
        self.diff_pos_1 = DiffPos(embedding, pos, out_pos, distance)
        self.is_zero_1 = IsZero(embedding, distance)
        self.copy_out_1 = CopyOnMatch(embedding, distance, out_digit, final_digit)
        
        self.clear_lro_2 = Clear(embedding, [left_digit, right_digit, out_digit])
        self.step_pos_2 = StepPosition(embedding, [left_pos, right_pos, out_pos], [-1, -1, 1])
        self.read_left_2 = GetRelativeToken(embedding, left_pos, 0, left_digit)
        self.read_right_2 = GetRelativeToken(embedding, right_pos, 0, right_digit)
        self.multiply_2 = Multiply(embedding, left_digit, right_digit, carry, out_digit)
        self.diff_pos_2 = DiffPos(embedding, pos, out_pos, distance)
        self.is_zero_2 = IsZero(embedding, distance)
        self.copy_out_2 = CopyOnMatch(embedding, distance, out_digit, final_digit)
        
        self.clear_lro_3 = Clear(embedding, [left_digit, right_digit, out_digit])
        self.step_pos_3 = StepPosition(embedding, [left_pos, right_pos, out_pos], [-1, -1, 1])
        self.read_left_3 = GetRelativeToken(embedding, left_pos, 0, left_digit)
        self.read_right_3 = GetRelativeToken(embedding, right_pos, 0, right_digit)
        self.multiply_3 = Multiply(embedding, left_digit, right_digit, carry, out_digit)
        self.diff_pos_3 = DiffPos(embedding, pos, out_pos, distance)
        self.is_zero_3 = IsZero(embedding, distance)
        self.copy_out_3 = CopyOnMatch(embedding, distance, out_digit, final_digit)
        
        self.clear_lro_4 = Clear(embedding, [left_digit, right_digit, out_digit])
        self.step_pos_4 = StepPosition(embedding, [left_pos, right_pos, out_pos], [-1, -1, 1])
        self.read_left_4 = GetRelativeToken(embedding, left_pos, 0, left_digit)
        self.read_right_4 = GetRelativeToken(embedding, right_pos, 0, right_digit)
        self.multiply_4 = Multiply(embedding, left_digit, right_digit, carry, out_digit)
        self.diff_pos_4 = DiffPos(embedding, pos, out_pos, distance)
        self.is_zero_4 = IsZero(embedding, distance)
        self.copy_out_4 = CopyOnMatch(embedding, distance, out_digit, final_digit)

        # Linear layer to move the final digit to the out digit
        self.final_projection = torch.zeros(embedding.dim, embedding.dim)
        for i in range(final_digit.size):
            self.final_projection[final_digit.offset + i, i] = 1.0e6
        self.final_projection = torch.nn.Parameter(self.final_projection)
            
        self.weights = 0
        self.weights += find_plus.size() + find_equal.size() + find_out.size() + step_out.size()
        self.weights += 4*(clear_lro.size() + step_pos.size() + read_left.size() + read_right.size() + multiply.size() + diff_pos.size() + is_zero.size() + copy_out.size())
        self.weights += torch.numel(final_projection)
        
        self.softmax = torch.nn.Softmax(1)

    def generate(self, input_string):
        # First we embed the original string
        x = embedding.embed(embedding.tokenize(input_string))
        x = self(torch.unsqueeze(x, 0))
        return input_string + embedding.predict(x[0])

    def forward(self, x):
        # Then we look for the various symbols that direct us to
        x = self.find_plus.forward(x)
        x = self.find_equal.forward(x)
        x = self.find_out.forward(x)
        x = self.step_out.forward(x)

        x = self.clear_lro_1.forward(x)
        x = self.step_pos_1.forward(x)
        x = self.read_left_1.forward(x)
        x = self.read_right_1.forward(x)
        x = self.multiply_1.forward(x)
        x = self.diff_pos_1.forward(x)
        x = self.is_zero_1.forward(x)
        x = self.copy_out_1.forward(x)
        
        x = self.clear_lro_2.forward(x)
        x = self.step_pos_2.forward(x)
        x = self.read_left_2.forward(x)
        x = self.read_right_2.forward(x)
        x = self.multiply_2.forward(x)
        x = self.diff_pos_2.forward(x)
        x = self.is_zero_2.forward(x)
        x = self.copy_out_2.forward(x)
        
        x = self.clear_lro_3.forward(x)
        x = self.step_pos_3.forward(x)
        x = self.read_left_3.forward(x)
        x = self.read_right_3.forward(x)
        x = self.multiply_3.forward(x)
        x = self.diff_pos_3.forward(x)
        x = self.is_zero_3.forward(x)
        x = self.copy_out_3.forward(x)
        
        x = self.clear_lro_4.forward(x)
        x = self.step_pos_4.forward(x)
        x = self.read_left_4.forward(x)
        x = self.read_right_4.forward(x)
        x = self.multiply_4.forward(x)
        x = self.diff_pos_4.forward(x)
        x = self.is_zero_4.forward(x)
        x = self.copy_out_4.forward(x)

        x =  x @ self.final_projection
        return x #self.softmax(x)
        #return (x)
    
    def reset(self):
        torch.nn.init.xavier_uniform_(self.final_projection)
        for child in self.children():
            if hasattr(child, 'reset'):
                child.reset()

def read_out(x):
    print('Out:', torch.argmax(x[-1,out_digit.offset:out_digit.offset + out_digit.size]))
    
model = PaddedAddition()

In [352]:
test = ' 00010+   0111='
for i in range(10):
    test = model.generate(test)
    print(test)

 00010+   0111=1
 00010+   0111=12
 00010+   0111=121
 00010+   0111=1210
 00010+   0111=12100
 00010+   0111=121000
 00010+   0111=1210000
 00010+   0111=12100000
 00010+   0111=121000000
 00010+   0111=1210000000


In [330]:
def random_addition():
    left = int(np.random.rand()*999)
    right = int(np.random.rand()*999)

    out = str(left + right).zfill(4)[::-1]
    left = str(left).zfill(4)
    right = str(right).zfill(4)
    
    # Note: I've made the strings of constant length to make batch SGD easier
    leftpadding = np.random.randint(0, 5)*' '
    rightpadding = (5 - len(leftpadding))*' ' # np.random.randint(0, 5)*' '
    return leftpadding + left + '+' + rightpadding + right + '=' + out
    
random_addition()

' 0299+    0061=0630'

In [187]:
model.reset()

In [319]:
for param in model.parameters():
    p = param.detach()
    if len(p.shape) == 1:
        p = torch.unsqueeze(p, 0)
    plot_tensor(p)

In [None]:
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import TensorDataset, DataLoader

model = PaddedAddition()
model.reset()

loss_fn = torch.nn.CrossEntropyLoss()
loss_fn.train()
optimizer = optim.SGD(model.parameters(), lr=1e-6)
writer = SummaryWriter()

batched_inputs = []
batched_outputs = []

for i in range(1000):
    # Generate a random addition
    trim = 3 #np.random.randint(0, 5)
    if trim:
        input_string = random_addition()[:-trim]
    else:
        input_string = random_addition()
    inputs = embedding.embed(embedding.tokenize(input_string))
    
    # The output is the last character with the position masked out
    outputs = inputs[-1,:]
    outputs[pos.offset] = 0.0
    outputs[pos.offset + 1] = 0.0
    outputs = outputs*1e6
    batched_outputs.append(outputs)
    
    # The input is the sequence with the last bit removed
    inputs = inputs[:-1,:]   
    batched_inputs.append(inputs)
    
dataset = TensorDataset(torch.stack(batched_inputs), torch.stack(batched_outputs))
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

step = 0
for epoch in range(5000):
    for batch_inputs, batch_outputs in dataloader:
        pred_output = model(batch_inputs)[:, -1,:]

        # Compute and print loss
        #print(pred_output.shape, batch_outputs.shape)
        loss = loss_fn(pred_output, batch_outputs)

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1000.0)

        optimizer.step()

        writer.add_scalar('Loss/train', loss, step)
        step += 1

    if epoch % 500 == 0:
        print('Epoch {} loss: {}'.format(epoch, loss.item()))
        #print(loss, pred_output, outputs)
        for param in model.parameters():
            if False and param.grad is not None:
                print(f'Gradient norm: {param.grad.data.norm(2)}')
        

Epoch 0 loss: 139288368.0


In [348]:
print(list(model.parameters()))

[Parameter containing:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], requires_grad=True), Parameter containing:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], requires_grad=True), Parameter containing:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], requires_grad=True), 