In [None]:
from __future__ import annotations

import time

import numpy as np
from implementations import new_mttkrp, old_mttkrp
from memory_profiler import memory_usage

%load_ext memory_profiler

In [None]:
def get_tens(n, R, shape, seed=123):
    """Build reproducible random inputs for a mode-``n`` MTTKRP benchmark.

    Parameters
    ----------
    n : int
        Mode index; splits ``shape`` into the dims before and after ``n``.
    R : int
        Number of columns of the two random factor matrices.
    shape : sequence of int
        Shape of the random tensor ``Y``.
    seed : int
        Seed for ``np.random.default_rng`` so repeated calls give identical data.

    Returns
    -------
    Ul, Ur : ndarray
        Random matrices of shape ``(szr, R)`` and ``(szl, R)`` respectively.
    Y : ndarray
        Random array of shape ``shape``.
    szl, szr, szn : int
        Product of ``shape[:n]``, product of ``shape[n + 1:]``, and ``shape[n]``.
    """
    # ``int`` matters for the first/last mode: ``np.prod`` of an empty slice
    # is the float 1.0, which numpy rejects as an array dimension.
    szl = int(np.prod(shape[0:n]))
    szr = int(np.prod(shape[n + 1 :]))
    szn = shape[n]

    np_rng = np.random.default_rng(seed)
    # NOTE(review): Ul gets szr rows and Ur gets szl rows, the reverse of what
    # the names suggest. Left unchanged because these are passed positionally
    # to the implementations, whose signature isn't visible here -- confirm
    # against `implementations` before swapping.
    Ul = np_rng.random((szr, R))
    Ur = np_rng.random((szl, R))
    Y = np_rng.random(shape)
    return Ul, Ur, Y, szl, szr, szn

In [None]:
def time_mttkrp(version, n, r, shape=(20, 30, 40, 3), repeats=10):
    """Time repeated calls of an MTTKRP implementation on fixed random inputs.

    Parameters
    ----------
    version : callable
        Implementation under test, called as
        ``version(Ul, Y, Ur, szl, szr, szn, r)``.
    n : int
        Mode index forwarded to ``get_tens``.
    r : int
        Rank (number of columns of the random factor matrices).
    shape : sequence of int
        Tensor shape forwarded to ``get_tens``.
    repeats : int
        Number of timed calls.

    Returns
    -------
    times : list of float
        Elapsed seconds for each call, in call order. The first call may
        include one-off warm-up costs.
    result
        Return value of the last call (``None`` when ``repeats`` is 0).
    """
    Ul, Ur, Y, szl, szr, szn = get_tens(n, r, shape)
    times = []
    result = None
    for _ in range(repeats):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with system clock adjustments and is too coarse for short calls.
        start = time.perf_counter()
        result = version(Ul, Y, Ur, szl, szr, szn, r)
        times.append(time.perf_counter() - start)
    return times, result

In [None]:
def mem_mttkrp(version, n, r, shape=(20, 30, 40, 3), repeats=10):
    """Measure peak memory of repeated MTTKRP calls on fixed random inputs.

    Parameters
    ----------
    version : callable
        Implementation under test, called as
        ``version(Ul, Y, Ur, szl, szr, szn, r)``.
    n : int
        Mode index forwarded to ``get_tens``.
    r : int
        Rank (number of columns of the random factor matrices).
    shape : sequence of int
        Tensor shape forwarded to ``get_tens``.
    repeats : int
        Number of measured calls.

    Returns
    -------
    mem_measurements : list
        Peak memory reported by ``memory_profiler.memory_usage`` for each
        call. This is a process-wide figure, so it includes the kernel's
        baseline footprint.
    result
        Return value of the last call (``None`` when ``repeats`` is 0).
    """
    Ul, Ur, Y, szl, szr, szn = get_tens(n, r, shape)
    args = (Ul, Y, Ur, szl, szr, szn, r)
    mem_measurements = []
    result = None
    for _ in range(repeats):
        # retval=True hands back the function's result with the peak memory,
        # so each repetition runs `version` once instead of twice.
        peak, result = memory_usage((version, args), max_usage=True, retval=True)
        mem_measurements.append(peak)
    return mem_measurements, result

## Timings

In [None]:
# Pass the shape explicitly so the printed header always matches what is
# actually benchmarked (it previously printed 50 for the last dim while the
# functions ran with their default of 3).
bench_shape = [20, 30, 40, 3]
print("shape", bench_shape)
print("---------------------")
for n in [1, 2]:
    print(f"mode-{n} mttkrp")
    for r in [2, 10, 50, 100, 200]:
        print(f"rank-{r}")
        old_time, old_result = time_mttkrp(old_mttkrp, n, r, bench_shape)
        new_time, new_result = time_mttkrp(new_mttkrp, n, r, bench_shape)
        # Two implementations may accumulate in a different order, so compare
        # within floating-point tolerance rather than bit-for-bit.
        print("results equal:", np.allclose(old_result, new_result))
        # Skip the first run of each series to exclude warm-up effects.
        old_mean = np.mean(old_time[1:])
        new_mean = np.mean(new_time[1:])
        print("old", old_mean)
        print("new", new_mean)
        print("speedup", old_mean / new_mean)
        print("-----------------------------------------------------------------------")

## Memory

In [None]:
# Pass the shape explicitly so the printed header always matches what is
# actually measured (it previously printed 50 for the last dim while the
# functions ran with their default of 3).
bench_shape = [20, 30, 40, 3]
print("shape", bench_shape)
print("---------------------")
for n in [1, 2]:
    print(f"mode-{n} mttkrp")
    for r in [2, 10, 50, 100, 200]:
        print(f"rank-{r}")
        old_mem, old_result = mem_mttkrp(old_mttkrp, n, r, bench_shape)
        new_mem, new_result = mem_mttkrp(new_mttkrp, n, r, bench_shape)
        # Two implementations may accumulate in a different order, so compare
        # within floating-point tolerance rather than bit-for-bit.
        print("results equal:", np.allclose(old_result, new_result))
        # Skip the first measurement of each series to exclude warm-up effects.
        # Peaks are process-wide, so both include the same baseline footprint;
        # a ratio > 1 means the new implementation peaked lower.
        old_mean = np.mean(old_mem[1:])
        new_mean = np.mean(new_mem[1:])
        print("old", old_mean)
        print("new", new_mean)
        print("ratio", old_mean / new_mean)
        print("-----------------------------------------------------------------------")