fix some typos and add new operators
m0saan committed Jun 7, 2023
1 parent ef83684 commit 4ada44c
Showing 5 changed files with 259 additions and 55 deletions.
9 changes: 5 additions & 4 deletions minima/nn.py
@@ -122,7 +122,7 @@ def parameters(self) -> List[Parameter]:
"""
return _unpack_params(self.__dict__)

def _children(self) -> List[Parameter]:
def _children(self) -> List["Module"]:
"""
Returns a list of all child `Module` instances in the module.
This is done by unpacking the modules from the module's dictionary.
@@ -266,6 +266,7 @@ def forward(self, X: Tensor) -> Tensor:

        out = X @ self.weight
        out = out + self.bias.broadcast_to(out.shape) if self.bias else out
        return out
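
A minimal usage sketch of the forward pass above; the `Linear(in_features, out_features)` constructor and the `init.rand` helper are assumed from context rather than quoted from this commit:

# Illustrative sketch; the constructor signature is an assumption, not from this diff.
layer = Linear(20, 30)      # weight: (20, 30), bias: (30,) when enabled
X = init.rand(8, 20)        # batch of 8 rows, 20 features each
out = layer(X)              # X @ weight, bias broadcast over the batch
assert out.shape == (8, 30)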

# %% ../nbs/03_nn.ipynb 11
class Flatten(Module):
@@ -290,7 +291,7 @@ def forward(self, X: Tensor) -> Tensor:
        Returns:
            Tensor: The output tensor, which is a 2D tensor with the same number of elements as the input tensor.
        """
-        return X.reshape(X.shape[0], -1)
+        return X.reshape((X.shape[0], -1))
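
The fix passes the target shape as a single tuple, which suggests minima's `reshape` takes one shape argument (NumPy's `ndarray.reshape` accepts this form too). A NumPy sketch of the same flattening semantics:

import numpy as np

X = np.zeros((32, 3, 28, 28))
flat = X.reshape((X.shape[0], -1))  # keep the batch dim, collapse the rest
assert flat.shape == (32, 3 * 28 * 28)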


# %% ../nbs/03_nn.ipynb 12
@@ -461,7 +462,7 @@ def update_stats(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        x_centered = x - mean.broadcast_to(x.shape)
        std = ((x_centered ** 2).sum(axes=axes) / bs)
        self.running_mean = self.momentum * mean.data + (1 - self.momentum) * self.running_mean
-        self.running_var = self.momentum * std.data + (1 - self.momentum) * self.running_var
+        self.running_std = self.momentum * std.data + (1 - self.momentum) * self.running_std
        return mean, std
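
Note that `std` holds the biased batch variance here (no square root is taken), so the rename to `running_std` follows the variable name rather than a change of statistic. The update itself is an exponential moving average in which `momentum` weights the current batch; a NumPy sketch of the rule:

import numpy as np

m = 0.1                                        # momentum
running = np.zeros(4)
batch_stat = np.array([1.0, 2.0, 3.0, 4.0])
running = m * batch_stat + (1 - m) * running   # -> [0.1, 0.2, 0.3, 0.4]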

    def forward(self, x: Tensor) -> Tensor:
@@ -479,7 +480,7 @@ def __init__(self, p = 0.5):
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
-        binary_mask = np.random.binomial(n=1, p=self.p, size=x.shape)
+        binary_mask = init.randb(*x.shape, p=self.p)
        if self.training:
            return (binary_mask * x) / (1 - self.p)
        return x
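
For reference, standard inverted dropout zeroes activations with drop probability p during training and rescales survivors by 1/(1 - p) so the expected activation is unchanged. Whether the `init.randb` mask above keeps or drops with probability p depends on that helper's convention, which this diff does not show; the self-contained NumPy sketch below assumes p is the drop probability:

import numpy as np

def inverted_dropout(x, p=0.5, training=True):
    if not training:
        return x
    keep = (np.random.rand(*x.shape) >= p).astype(x.dtype)  # 1 = keep; drop prob is p
    return keep * x / (1.0 - p)                             # rescale so E[output] == x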
61 changes: 59 additions & 2 deletions minima/operators.py
@@ -1021,15 +1021,72 @@ def broadcast_to(a: Tensor, shape: Tuple[int, ...]) -> Tensor:

# %% ../nbs/01_operators.ipynb 101
class LogSumExp(TensorOp):
    """
    A Tensor operation class for performing LogSumExp computation.

    Attributes
    ----------
    axes : tuple, optional
        The axes along which the operation is performed. If not specified, the operation is
        performed along all axes.

    Methods
    -------
    compute(Z):
        Computes the LogSumExp operation on the input tensor Z.

    gradient(out_grad, node):
        Computes the gradient of the LogSumExp operation with respect to its input.
    """

    def __init__(self, axes: Optional[tuple] = None):
        """
        Initializes the LogSumExp operation with the specified axes.

        Parameters
        ----------
        axes : tuple, optional
            The axes along which the operation is performed.
        """
        self.axes = axes

    def compute(self, Z):
        """
        Computes the LogSumExp operation on the input tensor Z.

        Parameters
        ----------
        Z : Tensor
            The input tensor on which the operation is to be performed.

        Returns
        -------
        out : Tensor
            The result of the LogSumExp operation on the input tensor.
        """
        max_z = ARRAY_API.max(Z, axis=self.axes, keepdims=True)
        self.out = ARRAY_API.squeeze(ARRAY_API.log(ARRAY_API.sum(ARRAY_API.exp(Z - max_z), axis=self.axes, keepdims=True)) + max_z)
        return self.out
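
The `max_z` subtraction in `compute` is the standard numerical-stability trick: for any constant $m$ (here the per-axis maximum),

$$\log \sum_i e^{z_i} = m + \log \sum_i e^{z_i - m},$$

so every exponent is at most zero and the sum cannot overflow, while the identity keeps the result exact.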

    def gradient(self, out_grad, node):
        """
        Computes the gradient of the LogSumExp operation with respect to its input.

        Parameters
        ----------
        out_grad : Tensor
            The gradient of the output of the LogSumExp operation.
        node : Tensor
            The input tensor of the LogSumExp operation.

        Returns
        -------
        grad : tuple of Tensor
            The gradient of the LogSumExp operation with respect to its input.
        """
        new_shape = list(node.inputs[0].shape)

        # If axes were specified, set those dimensions to 1 in the new shape
@@ -1047,6 +1104,6 @@ def gradient(self, out_grad, node):
            broadcasted_out = broadcast_to(reshaped_out, node.inputs[0].shape)
            return (exp(node.inputs[0] - broadcasted_out) * broadcasted_grad, )
        return (exp(node.inputs[0] - self.out) * out_grad, )
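
The identity behind both branches: the gradient of LogSumExp is the softmax of its input,

$$\frac{\partial}{\partial z_i} \mathrm{LSE}(z) = \frac{e^{z_i}}{\sum_j e^{z_j}} = e^{z_i - \mathrm{LSE}(z)},$$

which is the `exp(node.inputs[0] - self.out)` factor above; the reshape-and-broadcast branch only restores axes that were reduced before applying it.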

def logsumexp(a, axes=None):
    return LogSumExp(axes=axes)(a)
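
A quick check of what the stable form buys, in plain NumPy so it runs standalone (feeding the same values through `logsumexp(Tensor(...))` should agree, assuming the usual Tensor constructor):

import numpy as np

z = np.array([1000.0, 1000.0])
naive = np.log(np.sum(np.exp(z)))           # exp overflows, result is inf
m = np.max(z)
stable = m + np.log(np.sum(np.exp(z - m)))  # 1000 + log(2) ~= 1000.6931
print(naive, stable)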
63 changes: 60 additions & 3 deletions nbs/01_operators.ipynb
@@ -2668,21 +2668,78 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f45b842f-49bb-4450-b030-f03e79f23905",
"id": "73b751a1-b637-41d8-8661-ea9f9ac95615",
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"class LogSumExp(TensorOp):\n",
" \"\"\"\n",
" A Tensor operation class for performing LogSumExp computation.\n",
"\n",
" Attributes\n",
" ----------\n",
" axes : tuple, optional\n",
" The axes along which the operation is performed. If not specified, the operation is \n",
" performed along all axes.\n",
"\n",
" Methods\n",
" -------\n",
" compute(Z):\n",
" Computes the LogSumExp operation on the input tensor Z.\n",
"\n",
" gradient(out_grad, node):\n",
" Computes the gradient of the LogSumExp operation with respect to its input.\n",
" \"\"\"\n",
" \n",
" def __init__(self, axes: Optional[tuple] = None):\n",
" \"\"\"\n",
" Initializes the LogSumExp operation with the specified axes.\n",
" \n",
" Parameters\n",
" ----------\n",
" axes : tuple, optional\n",
" The axes along which the operation is performed.\n",
" \"\"\"\n",
" \n",
" self.axes = axes\n",
"\n",
" def compute(self, Z):\n",
" \"\"\"\n",
" Computes the LogSumExp operation on the input tensor Z.\n",
" \n",
" Parameters\n",
" ----------\n",
" Z : Tensor\n",
" The input tensor on which the operation is to be performed.\n",
" \n",
" Returns\n",
" -------\n",
" out : Tensor\n",
" The result of the LogSumExp operation on the input tensor.\n",
" \"\"\"\n",
" \n",
" max_z = ARRAY_API.max(Z, axis=self.axes, keepdims=True)\n",
" self.out = ARRAY_API.squeeze(ARRAY_API.log(ARRAY_API.sum(ARRAY_API.exp(Z - max_z), axis=self.axes, keepdims=True)) + max_z)\n",
" return self.out\n",
"\n",
" \n",
" def gradient(self, out_grad, node):\n",
" \"\"\"\n",
" Computes the gradient of the LogSumExp operation with respect to its input.\n",
" \n",
" Parameters\n",
" ----------\n",
" out_grad : Tensor\n",
" The gradient of the output of the LogSumExp operation.\n",
" \n",
" node : Tensor\n",
" The input tensor of the LogSumExp operation.\n",
" \n",
" Returns\n",
" -------\n",
" grad : tuple of Tensor\n",
" The gradient of the LogSumExp operation with respect to its input.\n",
" \"\"\"\n",
" new_shape = list(node.inputs[0].shape)\n",
"\n",
" # If axes were specified, set those dimensions to 1 in the new shape\n",
@@ -2700,7 +2757,7 @@
" broadcasted_out = broadcast_to(reshaped_out, node.inputs[0].shape)\n",
" return (exp(node.inputs[0] - broadcasted_out) * broadcasted_grad, )\n",
" return (exp(node.inputs[0] - self.out) * out_grad, )\n",
"\n",
" \n",
"def logsumexp(a, axes=None): \n",
" return LogSumExp(axes=axes)(a)"
]
113 changes: 107 additions & 6 deletions nbs/02_init.ipynb
@@ -87,7 +87,28 @@
"execution_count": null,
"id": "93b68976-7696-43bf-b0ca-f7935eda9331",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"minima.Tensor(\n",
"[[0.724743 0.236047 0.065364 0.855749 0.91864 ]\n",
" [0.027286 0.800982 0.64705 0.968984 0.489134]\n",
" [0.931952 0.758258 0.887548 0.867378 0.108525]\n",
" [0.03819 0.143269 0.210976 0.892318 0.069396]\n",
" [0.257424 0.554249 0.235325 0.064803 0.843057]\n",
" [0.696038 0.812699 0.54037 0.754445 0.385663]\n",
" [0.461943 0.538387 0.582451 0.802216 0.6077 ]\n",
" [0.045212 0.726626 0.886866 0.190699 0.00549 ]\n",
" [0.685753 0.342417 0.554111 0.813416 0.375196]\n",
" [0.170601 0.631679 0.474656 0.363225 0.162466]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rand(10,5)"
]
@@ -107,7 +128,18 @@
"execution_count": null,
"id": "0fc7790c-bd97-454b-8bb0-a1fdc94f7144",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"(dtype('float32'), minima.cpu(), False)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t.dtype, t.device, t.requires_grad"
]
@@ -173,7 +205,23 @@
"execution_count": null,
"id": "69baef10-0dfe-43d8-b999-9e58ac9d619f",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"minima.Tensor(\n",
"[[ 0.439627 -0.765065 -0.323372 -1.518769 -0.563107]\n",
" [-0.255756 0.96025 1.512503 -0.662302 -1.201184]\n",
" [ 0.650412 0.263193 1.310423 1.383127 1.237785]\n",
" [-0.008076 0.028429 1.874965 0.977454 -0.068408]\n",
" [-1.75604 0.546302 0.359429 0.864159 1.347796]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t"
]
@@ -183,7 +231,18 @@
"execution_count": null,
"id": "010d6791-37d8-4b0f-addc-8947697af949",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"((5, 5), dtype('float32'), minima.cpu(), True)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t.shape, t.dtype, t.device, t.requires_grad"
]
@@ -541,7 +600,28 @@
"execution_count": null,
"id": "f43d2073-07e0-4e88-8f9c-2ff7563c1f67",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"minima.Tensor(\n",
"[[-0.362476 0.425931 -0.604767 -0.585322 -0.469232]\n",
" [-0.207212 -0.106902 -0.349397 -0.082876 0.145293]\n",
" [ 0.274891 -0.136795 -0.210589 0.275138 0.481523]\n",
" [ 0.258619 0.307068 0.458597 0.236569 0.462849]\n",
" [ 0.029899 -0.559897 0.514509 0.32062 0.208706]\n",
" [-0.488839 -0.433476 0.089545 0.466521 -0.407354]\n",
" [-0.243207 0.266691 0.27616 0.263078 -0.267017]\n",
" [-0.61643 -0.143201 0.083898 0.366265 0.022065]\n",
" [ 0.149755 -0.155406 -0.494278 0.481983 -0.509169]\n",
" [ 0.159343 0.597055 -0.36376 0.376093 -0.399417]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W"
]
@@ -561,7 +641,28 @@
"execution_count": null,
"id": "5572ebce-3093-4421-ad12-e37c8260dcc0",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"minima.Tensor(\n",
"[[ 0.176985 0.475322 -0.472049 -0.197204 -0.124735]\n",
" [-0.021712 0.164783 -0.76619 -0.109438 0.25406 ]\n",
" [ 0.747072 0.355753 -0.609391 0.107684 -0.27346 ]\n",
" [ 0.418525 -0.260188 -0.422199 -0.179245 -0.042915]\n",
" [ 0.167997 -0.004962 0.17853 -0.520827 0.349572]\n",
" [-0.019444 -0.406027 0.332068 -0.591041 -0.408733]\n",
" [-0.535821 -0.790127 -0.098206 0.25483 0.509668]\n",
" [-0.401008 -0.051625 0.376056 0.1105 -1.083598]\n",
" [ 0.06975 0.191631 0.233941 0.101632 0.235047]\n",
" [-0.034112 -0.465086 0.053479 0.444051 -0.654049]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W"
]
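
The two weight matrices above are consistent with the usual fan-based schemes; which cells generated them is not shown in this excerpt, so the attribution is an assumption. For a (fan_in, fan_out) = (10, 5) matrix, Xavier uniform draws

$$W_{ij} \sim \mathcal{U}(-a, a), \qquad a = \text{gain} \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}},$$

while Kaiming normal draws

$$W_{ij} \sim \mathcal{N}(0, \sigma^2), \qquad \sigma = \frac{\text{gain}}{\sqrt{\text{fan\_in}}}$$

with gain = sqrt(2) for ReLU. The first matrix stays within plus or minus 0.63, roughly sqrt(6/15), while the second shows the heavier tails of a normal draw.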