In [1]:
from typing import Tuple

import torch
import numpy as np

import minitorch.autodiff.tensor_functions as tf
from minitorch import operators
from minitorch.module import LinearTensorLayer, Parameter
from minitorch.autodiff import Context, Tensor, topological_sort
from minitorch.autodiff.tensor_ops import SimpleBackend

In [None]:
# Hand-rolled 2 -> 1 linear layer + sigmoid in minitorch; backprop so the parameter
# gradients can be checked against the torch version of the same layer further down.
input_dim, output_dim = 2, 1
weights = LinearTensorLayer._initialise_parameter(input_dim, output_dim).value
bias = LinearTensorLayer._initialise_parameter(output_dim).value

# Generate some input data
n_samples = 10
inputs = tf.rand((n_samples, input_dim))
targets = tf.tensor([1, 1, 1, 0, 0, 0, 1, 0, 1, 0])

# Forward: broadcast (n, in, 1) * (1, in, out) and reduce over `in`, i.e. inputs @ weights.
# Use new names for the reshaped views so `inputs` keeps its original (n, in) shape
# for later cells that read it.
inputs_3d = inputs.view(*inputs.shape, 1)
weights_3d = weights.view(1, *weights.shape)

out = (inputs_3d * weights_3d).sum(dim=1)
predictions = out.view(n_samples, bias.size) + bias
predictions = predictions.view(targets.size).sigmoid()

# Backward on sum(p_true), where p_true = p if target == 1 else 1 - p.
# NOTE: this is deliberately not the BCE loss (that would be mean(-log(p_true)));
# summing the probabilities keeps the gradient simple to compare against torch.
predictions_ = (predictions * targets) + (predictions - 1.0) * (targets - 1.0)
predictions_sum = predictions_.sum()
predictions_sum.backward()

# Display the minitorch parameter gradients for comparison with torch_weights.grad / torch_bias.grad.
weights.grad, bias.grad


In [None]:
# Test loss: gradient of sum(p_true) through a sigmoid, w.r.t. random logits `out`.
out = tf.rand((10,), requires_grad=True)
predictions = out.sigmoid()
targets = tf.tensor([1, 1, 1, 0, 0, 0, 1, 0, 1, 0])

# Branch-free p_true: p where target == 1, (p - 1) * (0 - 1) = 1 - p where target == 0.
predictions_ = predictions * targets + (predictions - 1.0) * (targets - 1.0)
predictions_.sum().backward()

out.grad

In [None]:
## Compare to torch: replay the sigmoid test loss on the same logits and targets
torch_out = torch.tensor(out.data.storage, requires_grad=True)
torch_predictions = torch_out.sigmoid()
torch_targets = torch.tensor(targets.data.storage)

torch_predictions_ = (torch_predictions * torch_targets) + (torch_predictions - 1.0) * (torch_targets - 1.0)
torch_predictions_.sum().backward()

# Fail loudly on a mismatch instead of relying on eyeballing two printed gradients.
np.testing.assert_allclose(
    torch_out.grad.detach().numpy(),
    np.asarray(out.grad.data.storage),
    rtol=1e-6,
)
torch_out.grad

In [2]:
# Test loss without sigmoid: the random values are used directly as probabilities.
predictions = tf.rand((10,), requires_grad=True)
targets = tf.tensor([1, 1, 1, 0, 0, 0, 1, 0, 1, 0])

# d sum(p_true) / dp is +1 where target == 1 and -1 where target == 0.
predictions_ = predictions * targets + (predictions - 1.0) * (targets - 1.0)
predictions_.sum().backward()

predictions.grad


[1.00000 1.00000 1.00000 -1.00000 -1.00000 -1.00000 1.00000 -1.00000 1.00000 -1.00000]

In [3]:
## Compare to torch: replay the no-sigmoid test loss on the same predictions and targets
torch_predictions = torch.tensor(predictions.data.storage, requires_grad=True)
torch_targets = torch.tensor(targets.data.storage)

torch_predictions_ = (torch_predictions * torch_targets) + (torch_predictions - 1.0) * (torch_targets - 1.0)
torch_predictions_.sum().backward()

# Assert agreement with minitorch rather than comparing the printed outputs by eye.
np.testing.assert_allclose(
    torch_predictions.grad.detach().numpy(),
    np.asarray(predictions.grad.data.storage),
)
torch_predictions.grad

tensor([ 1.,  1.,  1., -1., -1., -1.,  1., -1.,  1., -1.], dtype=torch.float64)

### Compare to torch implementation

In [None]:
import torch

In [None]:
# Torch replay of the no-sigmoid check; the resulting gradient is displayed in the next cell.
torch_targets = torch.tensor(targets.data.storage)
torch_predictions = torch.tensor(predictions.data.storage, requires_grad=True)

torch_predictions_ = torch_predictions * torch_targets + (torch_predictions - 1.0) * (torch_targets - 1.0)
torch_predictions_.sum().backward()

In [None]:
# Torch gradient of sum(p_true) w.r.t. the raw (no-sigmoid) predictions: +1 / -1 per target.
torch_predictions.grad

In [None]:
# Torch replay of the sigmoid test loss on the same logits `out`.
torch_out = torch.tensor(out.data.storage, requires_grad=True)
# Flatten with -1 instead of a hard-coded 10 so the cell works for any sample count.
torch_out_flatten = torch_out.view(-1)
torch_predictions = torch_out_flatten.sigmoid()
torch_targets = torch.tensor(targets.data.storage)

torch_predictions_ = (torch_predictions * torch_targets) + (torch_predictions - 1.0) * (torch_targets - 1.0)
torch_predictions_.sum().backward()

In [None]:
# Torch gradient w.r.t. the sigmoid inputs; compare with minitorch's `out.grad`.
torch_out.grad

In [None]:
# Same linear layer in torch, rebuilt from the minitorch storages so both frameworks
# see identical parameters, inputs and targets.
torch_weights = torch.tensor(weights.data.storage, requires_grad=True)
torch_bias = torch.tensor(bias.data.storage, requires_grad=True)
torch_inputs = torch.tensor(inputs.data.storage).view((n_samples, input_dim))
torch_targets = torch.tensor(targets.data.storage).view((n_samples, ))

# Reshape the flat parameter leaves; gradients flow back through the views to
# torch_weights / torch_bias.
torch_weights_ = torch_weights.view((input_dim, output_dim))
torch_bias_ = torch_bias.view((output_dim, ))

# Forward
torch_out = torch.matmul(torch_inputs, torch_weights_) + torch_bias_
torch_predictions = torch.sigmoid(torch_out).view((n_samples, ))

# Backward on sum(p_true), the same objective as the minitorch layer cell
torch_predictions_ = torch_predictions * torch_targets + (torch_predictions - 1.0) * (torch_targets - 1.0)
torch_predictions_.sum().backward()

In [None]:
# Torch weight gradient (flat, in `weights` storage order); compare with minitorch's weights gradient.
torch_weights.grad

In [None]:
# Torch bias gradient (flat, in `bias` storage order); compare with minitorch's bias gradient.
torch_bias.grad