*Copyright 2023 Modular, Inc: Licensed under the Apache License v2.0 with LLVM Exceptions.*

In [3]:
from memory import memset_zero
from algorithm import vectorize, parallelize
from sys.intrinsics import strided_load
from math import trunc
from random import rand


struct Matrix[dtype: DType = DType.float32]:
    """A dense, row-major 2-D matrix that owns a heap buffer of
    `dim0 * dim1` elements of `dtype`.

    Element `(x, y)` is stored at flat offset `x * dim1 + y`. Newly
    constructed matrices are filled via `random.rand`.
    """

    var dim0: Int  # number of rows
    var dim1: Int  # number of columns
    var _data: DTypePointer[dtype]  # owned buffer, released in __del__
    alias simd_width: Int = simdwidthof[dtype]()

    fn __init__(inout self, *dims: Int):
        """Allocate a `dims[0] x dims[1]` matrix filled with random values.

        Only the first two entries of `dims` are read.
        """
        self.dim0 = dims[0]
        self.dim1 = dims[1]
        self._data = DTypePointer[dtype].alloc(dims[0] * dims[1])
        rand(self._data, dims[0] * dims[1])

    fn __copyinit__(inout self, other: Self):
        """Deep copy: the new matrix gets its own buffer with the same values.

        Sharing `other._data` would leave two owners of one allocation, so
        nobody could free it safely. A private copy lets `__del__` free every
        buffer exactly once.
        """
        self.dim0 = other.dim0
        self.dim1 = other.dim1
        var size = other.dim0 * other.dim1
        self._data = DTypePointer[dtype].alloc(size)
        for i in range(size):
            self._data.store[width=1](i, other._data.load[width=1](i))

    fn __moveinit__(inout self, owned existing: Self):
        """Take over `existing`'s buffer without copying it."""
        self.dim0 = existing.dim0
        self.dim1 = existing.dim1
        self._data = existing._data

    fn __del__(owned self):
        """Release the owned element buffer."""
        self._data.free()

    fn _adjust_slice_(self, inout span: Slice, dim: Int):
        """Normalize `span` in place against an axis of length `dim`.

        Negative bounds count back from `dim`, and a missing end means `dim`.
        The start is clamped to be >= 0 and the end to be <= `dim`. If the end
        then precedes the start, the span collapses to the empty range `0:0`.
        Negative steps are not handled here.
        """
        if span.start < 0:
            span.start = dim + span.start
            # A start further left than -dim would otherwise stay negative
            # and make the copy read before the start of the buffer.
            if span.start < 0:
                span.start = 0
        if not span._has_end():
            span.end = dim
        elif span.end < 0:
            span.end = dim + span.end
        if span.end > dim:
            span.end = dim
        if span.end < span.start:
            span.start = 0
            span.end = 0

    fn __getitem__(self, x: Int, y: Int) -> SIMD[dtype, 1]:
        """Return the element at row `x`, column `y` (no bounds checking)."""
        return self._data.load[width=1](x * self.dim1 + y)

    fn __getitem__(self, owned row_slice: Slice, col: Int) -> Self:
        """Return the rows selected by `row_slice` from the single column `col`."""
        return self.__getitem__(row_slice, slice(col, col + 1))

    fn __getitem__(self, row: Int, owned col_slice: Slice) -> Self:
        """Return the columns selected by `col_slice` from the single row `row`."""
        return self.__getitem__(slice(row, row + 1), col_slice)

    fn __getitem__(self, owned row_slice: Slice, owned col_slice: Slice) -> Self:
        """Return a new matrix holding a copy of the selected rows and columns.

        Both slices are normalized with `_adjust_slice_` first. Rows are copied
        in parallel, and each row uses vectorized strided loads along the
        column step.
        """
        self._adjust_slice_(row_slice, self.dim0)
        self._adjust_slice_(col_slice, self.dim1)

        var n_rows = row_slice.__len__()
        var n_cols = col_slice.__len__()
        var sliced_mat = Self(n_rows, n_cols)

        @parameter
        fn slice_column(idx_rows: Int):
            # The read cursor must be local to each parallel worker. A cursor
            # shared by all workers would race, and a row could be copied from
            # another worker's position.
            var src_ptr = self._data.offset(
                row_slice[idx_rows] * self.dim1 + col_slice[0]
            )

            @parameter
            fn slice_row[simd_width: Int](idx: Int) -> None:
                sliced_mat._data.store[width=simd_width](
                    idx + idx_rows * n_cols,
                    strided_load[dtype, simd_width](src_ptr, col_slice.step),
                )
                # Advance past the `simd_width` source elements just consumed.
                src_ptr = src_ptr.offset(simd_width * col_slice.step)

            vectorize[slice_row, self.simd_width](n_cols)

        parallelize[slice_column](n_rows, n_rows)
        return sliced_mat

    fn print(self, prec: Int = 4) -> None:
        """Print the elements followed by a shape/dtype summary line.

        Each element is written as its integer part, '.', and up to `prec`
        fractional digits. Digits are truncated, not rounded. A matrix with
        exactly one row prints as one bracketed row. Otherwise the rows are
        wrapped in an extra outer pair of brackets. If either dimension is 0,
        only the summary line is printed.
        """
        var rank: Int = 1 if self.dim0 == 1 else 2
        var val: Scalar[dtype] = 0.0
        if self.dim0 > 0 and self.dim1 > 0:
            for j in range(self.dim0):
                if rank > 1:
                    if j == 0:
                        print("  [", end="")
                    else:
                        print("\n   ", end="")
                print("[", end="")
                for k in range(self.dim1):
                    val = self[j, k]
                    var int_str: String
                    if val >= 0:
                        int_str = String(trunc(val).cast[DType.int32]())
                    else:
                        # Format the magnitude and add the sign once. Formatting
                        # the negative value directly would already contain a '-'
                        # for values <= -1 and yield "--1.xxxx".
                        val = -val
                        int_str = "-" + String(trunc(val).cast[DType.int32]())
                    # `val` is non-negative here. Assumes `val % 1` renders as
                    # "0.ddd..." so the slice skips "0." and keeps `prec` digits.
                    # NOTE(review): tiny fractions may render in exponent form.
                    var float_str = String(val % 1)
                    var s = int_str + "." + float_str[2 : prec + 2]
                    if k == 0:
                        print(s, end="")
                    else:
                        print("  ", s, end="")
                print("]", end="")
            if rank > 1:
                print("]", end="")
            print()
        print("  Matrix:", self.dim0, "x", self.dim1, ",", "DType:", dtype.__str__())
        print()


fn main():
    """Demo driver: print a random 8x5 matrix, then several slices of it.

    The last two slices end before they start once negative indices are
    resolved, so they print only their empty shapes.
    """
    var m = Matrix(8, 5)
    m.print()

    # Contiguous, open-ended and negative-index slices.
    m[2:4, -3:].print()
    m[1:3, :].print()
    m[0:3, 0:3].print()
    # Every other row starting at 1, every other column starting at 0.
    m[1::2, ::2].print()

    # Empty after normalization: columns 4..2 and rows 7..2.
    m[:, -1:2].print()
    m[-1:2, :].print()


main()

  [[0.8018   0.1558   0.4714   0.2557   0.9939]
   [0.6987   0.8473   0.0693   0.1917   0.0081]
   [0.4396   0.3896   0.1356   0.2492   0.4613]
   [0.6983   0.2068   0.1822   0.3943   0.1424]
   [0.7156   0.4731   0.3617   0.0003   0.9721]
   [0.2425   0.6542   0.9402   0.5962   0.9038]
   [0.3315   0.3040   0.6753   0.4497   0.4090]
   [0.6183   0.7445   0.0473   0.6760   0.0181]]
  Matrix: 8 x 5 , DType: float32

  [[0.1356   0.2492   0.4613]
   [0.1822   0.3943   0.1424]]
  Matrix: 2 x 3 , DType: float32

  [[0.6987   0.8473   0.0693   0.1917   0.0081]
   [0.4396   0.3896   0.1356   0.2492   0.4613]]
  Matrix: 2 x 5 , DType: float32

  [[0.8018   0.1558   0.4714]
   [0.6987   0.8473   0.0693]
   [0.4396   0.3896   0.1356]]
  Matrix: 3 x 3 , DType: float32

  [[0.6987   0.0693   0.0081]
   [0.6983   0.1822   0.1424]
   [0.2425   0.9402   0.9038]
   [0.6183   0.0473   0.0181]]
  Matrix: 4 x 3 , DType: float32

  Matrix: 8 x 0 , DType: float32

  Matrix: 0 x 5 , DType: float32



In [4]:
# Create the random 8x5 matrix that the following cells slice and print.
var mat = Matrix(8, 5)
mat.print()

  [[0.9665   0.6154   0.4682   0.6210   0.1587]
   [0.1424   0.8650   0.0967   0.4542   0.3257]
   [0.5143   0.1336   0.0697   0.0295   0.9835]
   [0.9586   0.0020   0.3794   0.4057   0.2009]
   [0.3561   0.7346   0.6252   0.3567   0.1171]
   [0.6604   0.8526   0.1731   0.1798   0.6336]
   [0.4802   0.4312   0.7539   0.0289   0.0820]
   [0.6080   0.6512   0.8540   0.2766   0.6509]]
  Matrix: 8 x 5 , DType: float32



In [5]:
# Rows 2-3, last three columns (-3: resolves to columns 2-4 of 5).
mat[2:4, -3:].print()

  [[0.0697   0.0295   0.9835]
   [0.3794   0.4057   0.2009]]
  Matrix: 2 x 3 , DType: float32



In [6]:
# Rows 1-2, every column.
mat[1:3, :].print()

  [[0.1424   0.8650   0.0967   0.4542   0.3257]
   [0.5143   0.1336   0.0697   0.0295   0.9835]]
  Matrix: 2 x 5 , DType: float32



In [7]:
# Top-left 3x3 block.
mat[0:3, 0:3].print()

  [[0.9665   0.6154   0.4682]
   [0.1424   0.8650   0.0967]
   [0.5143   0.1336   0.0697]]
  Matrix: 3 x 3 , DType: float32



In [8]:
# Strided: rows 1, 3, 5, 7 and columns 0, 2, 4.
mat[1::2, ::2].print()

  [[0.1424   0.0967   0.3257]
   [0.9586   0.3794   0.2009]
   [0.6604   0.1731   0.6336]
   [0.6080   0.8540   0.6509]]
  Matrix: 4 x 3 , DType: float32



In [9]:
# Both slices resolve to an end before their start, so each result is empty
# and only its shape line is printed.
mat[:, -1:2].print()
mat[-1:2, :].print()

  Matrix: 8 x 0 , DType: float32

  Matrix: 0 x 5 , DType: float32

