<a href="https://colab.research.google.com/github/el-eshaano/autograd/blob/main/Tensors.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [100]:
def get_size_of(arr):
    def size_of(a):
        if not type(a) == list:
            return []
        return [len(a)] + size_of(a[0])

    s = size_of(arr)
    return TensorSize(*s)


In [107]:
class BroadcastException(Exception):

    def __init__(self, size1, size2):
        self.message = f"{size1} cannot be broadcast to {size2}"
        super().__init__(self.message)

In [137]:
class TensorSize():

    def __init__(self, *args):
        self.shape = list(args)

    def __repr__(self):
        return f'TensorSize({self.shape})'

    def __getitem__(self, idx):
        return self.shape[idx]

    def __len__(self):
        return len(self.shape)

    def can_broadcast_to(self, new_size):

        # Get just the data
        new_size = TensorSize(*new_size) if not isinstance(new_size, TensorSize) else new_size

        # Cannot broadcast to lower dimensions
        if len(self) > len(new_size):
            return False

        my_size = self.shape.copy()[::-1]
        other_size = new_size.shape.copy()[::-1]

        for a, b in zip(my_size, other_size):
            if a == 1 or b == 1 or a == b:
                continue
            else:
                return False

        return True

    def get_broadcasted_size(self, new_size):

        my_size = self.shape.copy()[::-1]
        other_size = new_size.shape.copy()[::-1]

        while len(my_size) < len(other_size):
            my_size.append(1)

        target_size = []
        for a, b in zip(my_size, other_size):
            if a == 1:
                target_size.append(b)
            elif b == 1:
                target_size.append(a)
            else:
                target_size.append(a)

        return target_size[::-1]

    @staticmethod
    def are_broadcast_compatible(T1, T2):
        print(T1.can_broadcast_to(T2), T2.can_broadcast_to(T1))
        return T1.can_broadcast_to(T2) or T2.can_broadcast_to(T1)


In [138]:
class Tensor:

    def __init__(self, data):
        self.data = data
        self.size = get_size_of(data)

    def __repr__(self):
        return f"tensor({self.data})"

    def size(self):
        return self.size

    def broadcast_to(self, new_size):

        new_size = TensorSize(*new_size) if not isinstance(new_size, TensorSize) else new_size

        if not self.size.can_broadcast_to(new_size):
            raise BroadcastException(self.size, new_size)

        target_size = self.size.get_broadcasted_size(new_size)
        print(target_size)





In [139]:
a = Tensor([[3, 4, 5], [4, 5, 6], [7, 8, 9]])
B = Tensor([1, 2])
print(B.size)
B = B.broadcast_to([3, 2])

TensorSize([2])
[3, 2]


In [140]:
import torch
a = torch.ones(2, 2)
b = torch.broadcast_to(a, (2, 2, 2))
print(b)

tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])
