In [None]:
import pandas as pd
import numpy as np
import scipy
import torch as T
from scipy.sparse import csr_matrix, dok_matrix
import sys
import cmath
import time
from datetime import datetime as dt
from datetime import date

In [None]:
# Two independent random binary matrices of shape (200, 100000); each entry is
# 1 with probability 0.01 (a Bernoulli draw, i.e. binomial with n=1).
# A seeded generator is used so that "Restart Kernel -> Run All" reproduces the
# same matrices and therefore the same downstream results.
rng = np.random.default_rng(42)
mat_1 = rng.binomial(1, 0.01, size=(200, 100000))
mat_2 = rng.binomial(1, 0.01, size=(200, 100000))

In [None]:
def sparse_ohe(df, col):
    """One-hot encode ``df[col]`` into a sparse-backed DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the column to encode.
    col : hashable
        Label of the column to encode.

    Returns
    -------
    pandas.DataFrame
        One row per row of ``df`` (same index) and one ``uint8`` sparse column
        per distinct non-missing value of ``df[col]``, named ``cat_<value>``
        in order of first appearance. A cell is 1 where the row holds that
        value and 0 elsewhere; rows whose value is missing are all zeros.
    """
    # factorize gives each row the integer code of its category and the
    # categories in order of first appearance; missing values get code -1.
    codes, cats = pd.factorize(df[col])
    observed = codes >= 0
    rows = np.flatnonzero(observed)

    # Build the matrix straight from (row, column) coordinates so that no dense
    # n_rows x n_categories intermediate is ever materialised.
    data = np.ones(rows.shape[0], dtype=np.uint8)
    spmtx = csr_matrix(
        (data, (rows, codes[observed])),
        shape=(df.shape[0], len(cats)),
        dtype=np.uint8,
    )

    # pd.SparseDataFrame was removed in pandas 1.0; the sparse accessor is the
    # supported way to wrap a scipy sparse matrix.
    return pd.DataFrame.sparse.from_spmatrix(
        spmtx,
        index=df.index,
        columns=['cat' + '_' + str(x) for x in cats],
    )

def get_cat_matches(df, id_col, cat_col, orig_match_ids=None):
    """Count shared categories between "original" ids and all other ids.

    Each distinct id gets a binary vector with one entry per category of
    ``df[cat_col]`` (1 if the id has at least one row in that category).
    The ids are then split into those listed in ``orig_match_ids`` and the
    rest, and the two groups of vectors are multiplied together.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format frame with one row per (id, category) observation.
    id_col : hashable
        Label of the id column.
    cat_col : hashable
        Label of the category column.
    orig_match_ids : list-like
        Ids forming the rows of the result. This used to be read from an
        undefined ``self``; it must now be passed explicitly.

    Returns
    -------
    numpy.ndarray
        Integer matrix of shape (m, p): m ids found in ``orig_match_ids``,
        p remaining ids, both in sorted id order. Entry (i, j) is the number
        of categories the two ids have in common.

    Raises
    ------
    ValueError
        If ``orig_match_ids`` is not supplied.
    """
    if orig_match_ids is None:
        raise ValueError(
            "orig_match_ids is required: pass the ids that form the rows of "
            "the similarity matrix"
        )

    # One indicator column per category. Working on the dummy frame alone means
    # an id column whose label happens to contain 'cat' cannot be mistaken for
    # a category column.
    dummies = pd.get_dummies(df[cat_col], prefix='cat')

    # Collapse to one row per id: True if any of the id's rows is in the category.
    membership = dummies.astype(bool).groupby(df[id_col].to_numpy()).any()
    indicator = membership.to_numpy(dtype=int)

    is_orig = membership.index.isin(orig_match_ids)

    # M is (m x n), N is (n x p), so M @ N is (m x p).
    M = indicator[is_orig]
    N = indicator[~is_orig].transpose()

    sim_matrix = M @ N
    return sim_matrix

def get_difference_matrix(mat_1, mat_2, power):
    """Pairwise soft "difference" between the rows of two matrices.

    For row i of ``mat_1`` and row j of ``mat_2`` the result is::

        | log_b( sum_k  b ** (mat_1[i, k] - mat_2[j, k]) ) |

    where the base ``b`` is ``power``, or e when ``power`` is None.

    Parameters
    ----------
    mat_1 : array-like, shape (m, n)
    mat_2 : array-like, shape (p, n)
    power : float or None
        Base of the exponent/logarithm; None selects the natural base.

    Returns
    -------
    numpy.ndarray
        Float matrix of shape (m, p).

    Raises
    ------
    ZeroDivisionError
        If ``power`` equals 1 (the logarithm base 1 is undefined).
    """
    # Work in float64: negating an unsigned array wraps around (-1 -> 255 for
    # uint8) and negating a bool array raises, so cast before the sign flip.
    M = np.asarray(mat_1, dtype=np.float64)
    N = -np.asarray(mat_2, dtype=np.float64).transpose(1, 0)

    if power is None:
        return np.abs(np.log(np.matmul(np.exp(M), np.exp(N))))

    if power == 1:
        raise ZeroDivisionError("logarithm base 1 is undefined")

    prod = np.float_power(power, M) @ np.float_power(power, N)
    # A complex logarithm divided by the complex log of the base mirrors
    # cmath.log(x, power), including for negative entries, but runs as one
    # vectorised call and also accepts empty inputs.
    return np.abs(np.log(prod.astype(complex)) / cmath.log(power))

def get_inv_difference_matrix(mat_1, mat_2, power=1.01, eta=1.0):
    """Turn the pairwise difference matrix into a similarity-like score.

    Computes ``1 / (difference + eta)`` element-wise, where ``difference`` is
    ``get_difference_matrix(mat_1, mat_2, power=power)``. ``eta`` is the offset
    added to every difference before the reciprocal is taken.
    """
    shifted = get_difference_matrix(mat_1, mat_2, power=power) + eta
    return 1 / shifted

def get_topk(mat, k=5):
    """Return the ``k`` largest entries of every row of ``mat``.

    Parameters
    ----------
    mat : array-like, shape (n_rows, n_cols)
    k : int, default 5
        Number of entries to keep per row. Values larger than ``n_cols`` are
        clipped to ``n_cols``.

    Returns
    -------
    topk : numpy.ndarray, shape (n_rows, min(k, n_cols))
        Selected values, in ascending order within each row (largest last).
    indices : numpy.ndarray, shape (n_rows, min(k, n_cols))
        Column positions in ``mat`` of the selected values.

    Raises
    ------
    ValueError
        If ``k`` is negative.
    """
    if k < 0:
        raise ValueError("k must be non-negative")

    mat = np.asarray(mat)
    n_cols = mat.shape[1]
    k = min(k, n_cols)

    # Slice from an explicit start position: ``[:, -k:]`` would return every
    # column when k == 0, because -0 is 0.
    indices = np.argsort(mat, axis=1)[:, n_cols - k:]
    # Gather from the matrix that was passed in, not from a notebook global.
    topk = np.take_along_axis(mat, indices, axis=1)
    return topk, indices

In [None]:
# Score every row of mat_1 against every row of mat_2, then keep the five
# highest scores per row together with their column positions.
# The keyword values below are the defaults of the two functions, spelled out.
inv_diff_mat = get_inv_difference_matrix(mat_1, mat_2, power=1.01, eta=1.0)
topk, indices = get_topk(inv_diff_mat, k=5)

In [None]:
# Timestamp of "now", floor-divided by the number of seconds in 30 days.
# The quotient is neither stored nor shown: only the last expression of a
# cell is displayed.
date_1 = dt.now()
date_1.timestamp() // (86400 * 30)  ## months

# Timestamp of today's date at 00:00, built as a naive datetime.
date_2 = date.today()
dt.combine(date_2, dt.min.time()).timestamp()

In [None]:
# Time the dense product mat_1 @ mat_2.T. The cell's output is the elapsed
# time in seconds.
# perf_counter() is used instead of time() because it is monotonic and
# high-resolution, so system clock adjustments cannot distort the measurement.
initial = time.perf_counter()
m1 = mat_1
m2 = mat_2.transpose()
m1@m2
time.perf_counter() - initial

In [None]:
# Time the same product using CSR sparse matrices. The timer starts before the
# conversions, so the printed figure includes building the CSR matrices and
# densifying the result.
# int32 is used for the operands: with int8 operands the product is also int8
# and silently wraps as soon as an entry exceeds 127.
# perf_counter() is monotonic and high-resolution, unlike time().
initial = time.perf_counter()
m1 = csr_matrix(mat_1, dtype=np.int32)
m2 = csr_matrix(mat_2.transpose(), dtype=np.int32)
val = (m1@m2).toarray()
print(time.perf_counter() - initial)