In [1]:
#Prints **all** console output, not just last item in cell 
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

**Notebook author:** emeinhardt@ucsd.edu

<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Overview" data-toc-modified-id="Overview-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Overview</a></span><ul class="toc-item"><li><span><a href="#Requirements" data-toc-modified-id="Requirements-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Requirements</a></span></li><li><span><a href="#Usage" data-toc-modified-id="Usage-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Usage</a></span></li></ul></li><li><span><a href="#Parameters" data-toc-modified-id="Parameters-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Parameters</a></span></li><li><span><a href="#Imports-/-load-data" data-toc-modified-id="Imports-/-load-data-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Imports / load data</a></span><ul class="toc-item"><li><span><a href="#Load-sanity-checking-metadata" data-toc-modified-id="Load-sanity-checking-metadata-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Load sanity-checking metadata</a></span></li><li><span><a href="#Load-segmental-sequence-channel-matrices-$p(Y_0^f|X_0^k)$" data-toc-modified-id="Load-segmental-sequence-channel-matrices-$p(Y_0^f|X_0^k)$-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Load segmental sequence channel matrices $p(Y_0^f|X_0^k)$</a></span></li><li><span><a href="#Load-contextual-distribution-on-segmental-wordforms-$p(W|C)$" data-toc-modified-id="Load-contextual-distribution-on-segmental-wordforms-$p(W|C)$-3.3"><span class="toc-item-num">3.3&nbsp;&nbsp;</span>Load contextual distribution on segmental wordforms $p(W|C)$</a></span></li><li><span><a href="#Load-lexicon-metadata" data-toc-modified-id="Load-lexicon-metadata-3.4"><span class="toc-item-num">3.4&nbsp;&nbsp;</span>Load lexicon metadata</a></span></li></ul></li><li><span><a href="#Calculation" data-toc-modified-id="Calculation-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Calculation</a></span></li></ul></div>

# Overview

Given a choice of parameters $\epsilon$ and $n$, and given
 - wordform channel matrices $p(Y_0^f|X_0^f)$
 - a contextual distribution on segmental wordforms $p(X_0^f|C)$
 - segmental lexicon metadata pre-calculating $k$-cousins/$k$-spheres up to $k=4$
 
Calculate

$$\hat{p}(\hat{X}_0^f = x_0^{'f}|X_0^f = x_0^{*f}, c) = \frac{1}{n} \sum\limits_{y_0^f \in S} p(\hat{X}_0^f = x_0^{'f}|y_0^f, c)$$
 where 
  - edit distance $d(x_0^{'f}, x_0^{*f}) \leq 4$
  - $S = $ a set of $n$ samples from $p(Y_0^f|x_0^{*f})$. In practice an $n \approx 50$ seems to result in estimates that are within $10^{-6}$ of the true estimate. 
  - $p(\hat{X}_0^f = x_0^{'f}|Y_0^f = y_0^f, c) = \frac{p(y_0^f|x_0^{'f})p(x_0^{'f}|c)}{p(y_0^f | c)}$
  - $p(y_0^f| c) = \sum\limits_{v', x_0^{''f}} p(y_0^f|x_0^{''f})p(x_0^{''f}|v')p(v'|c) = \sum\limits_{x_0^{''f}} p(y_0^f|x_0^{''f})p(x_0^{''f}|c)$

## Requirements

#FIXME

## Usage

#FIXME

# Parameters

In [2]:
from os import getcwd, chdir, listdir, path, mkdir, makedirs

In [3]:
from boilerplate import *

In [4]:
repo_dir = getcwd(); repo_dir

'/mnt/cube/home/AD/emeinhar/wr'

In [5]:
# Parameters

#p(Y_0^f|X_0^k)
c = ''
c = 'CM_AmE_destressed_aligned_w_LTR_Buckeye_pseudocount0.01/LTR_Buckeye_aligned_CM_filtered_LM_filtered_CMs_by_length_by_prefix_index.pickle'

#p(X_0^f|C)
w = ''
w = 'LD_Fisher_vocab_in_Buckeye_contexts/LTR_Buckeye_aligned_CM_filtered_LM_filtered_in_buckeye_contexts.pW_C.npy'

# LTR metadata directory
m = ''
m = 'LTR_Buckeye_aligned_w_GD_AmE_destressed'

# output filepath prefix for pW_WC
o = ''
o = 'LD_Fisher_vocab_in_Buckeye_contexts/LTR_Buckeye_aligned_CM_filtered_LM_filtered_in_buckeye_contexts.pW_WC'

n = ''
n = '50'

k = ''
k = '2'

r = ''
r = 'False' #if 'False', only calculate p(\hat{W}|W = w, c), i.e. don't calculate p(\hat{W}|P = p, c)

s = ''
s = 'True' #if r='False' and s='True', only calculate p(\hat{W} = w*| W = w*, c) ∀w ∈ W

In [6]:
output_dir = path.dirname(o)
if not path.exists(output_dir):
    print('Making output path {0}'.format(output_dir))
    makedirs(output_dir)

In [7]:
if n == '':
    n = 50
else:
    n = int(n)

In [8]:
if k == '':
    k = 2
else:
    k = int(k)

In [9]:
if r == '' or r == 'False':
    r = False
elif r == 'True':
    r = True
else:
    raise Exception(f"r must be one of {'','True','False'}, got '{r}' instead")

In [10]:
if not r and (s == '' or s == 'True'):
    s = True
elif r and (s == '' or s == 'True'):
    raise Exception("s can only be True if r is False")
elif s == 'False':
    s = False
else:
    raise Exception(f"s must be one of {'','True','False'}, got '{s}'")

# Imports / load data

In [11]:
import pickle

In [12]:
import numpy as np
import torch
import sparse

In [None]:
# from probdist import *

In [None]:
# from tqdm import tqdm

In [None]:
# from joblib import Parallel, delayed

# J = -1
# BACKEND = 'multiprocessing'
# # BACKEND = 'loky'
# V = 10
# PREFER = 'processes'
# # PREFER = 'threads'

# def identity(x):
#     return x

# def par(gen_expr):
#     return Parallel(n_jobs=J, backend=BACKEND, verbose=V, prefer=PREFER)(gen_expr)

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')
    print(torch.cuda.get_device_name(1))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(1)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(1)/1024**3,1), 'GB')

Using device: cuda

GeForce RTX 2070
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB
GeForce RTX 2070
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [14]:
gpu = torch.device('cuda')
cpu = torch.device('cpu')

my_device = cpu

In [15]:
cuda_ft = torch.cuda.FloatTensor
cuda_dt = torch.cuda.DoubleTensor

ft = torch.FloatTensor
dt = torch.DoubleTensor

my_ft = ft
my_dt = dt

torch.set_default_tensor_type(my_ft)

## Load sanity-checking metadata

We want to be able to make queries, spot checks, and sanity checks. That means we want to be able to reference
 1. the set of strings constituting segmental wordforms and prefixes
 2. the source and channel alphabets
 3. the channel distribution's conditioning triphones ∩ the triphones in the lexicon = the triphones in the lexicon
 4.  

Segmental wordforms were necessary for 
 - the lexicon metadata calculation (step 4b)
 - the contextual distribution on segmental wordforms (step 4c)
 - the definition of the segmental sequence channel matrices (step 4d)
 
What does each use as input? (Pass that to this notebook.)

Each notebook uses ...pW_V.json (or something slightly downstream of that)

In [None]:
#load the ...pW_V.json file
# extract W

# construct Ws_t

# extract P

# construct Ps_t

# extract source alphabet

# extract triphones

In [None]:
#load the triphone channel distribution used to construct the segment sequence channel matrices

# extract channel alphabet

In [None]:
# construct what you need to convert to/from one-hot representations

# look at segment sequence channel matrix notebook

In [None]:
# FIXME update this cell once you add in cells for doing calculations the old-fashioned way 
# and verifying/demonstrating that the calculation cell works correctly

## Load segmental sequence channel matrices $p(Y_0^f|X_0^k)$

In [None]:
#FIXME load stuff per the segmental sequence channel matrices so you can spot check/sanity test things
# probably pW_V

In [17]:
pY0f_X0ks = pickle.load(open(c, 'rb'))
len(pY0f_X0ks)

17

In [21]:
pY0f_X0ks[0].shape
pY0f_X0ks[1].shape
pY0f_X0ks[2].shape
pY0f_X0ks[3].shape

(6403, 38, 1)

(6403, 38, 2)

(6403, 38, 3)

(6403, 38, 4)

In [59]:
pY0f_X0ks[0][0]

array([[0.65264711],
       [0.00700978],
       [0.02545656],
       [0.02545656],
       [0.00682531],
       [0.06235012],
       [0.00682531],
       [0.00700978],
       [0.00682531],
       [0.00700978],
       [0.0051651 ],
       [0.00682531],
       [0.00700978],
       [0.00700978],
       [0.00700978],
       [0.00700978],
       [0.00700978],
       [0.00682531],
       [0.00700978],
       [0.00700978],
       [0.00700978],
       [0.00682531],
       [0.00664084],
       [0.00682531],
       [0.00700978],
       [0.00682531],
       [0.00700978],
       [0.00700978],
       [0.00700978],
       [0.00682531],
       [0.00700978],
       [0.00700978],
       [0.00700978],
       [0.00682531],
       [0.00700978],
       [0.00700978],
       [0.00682531],
       [0.00700978]])

In [22]:
pY0f_X0ks[3][0]

array([[0.65264711, 0.02141693, 0.00537084, 0.01037752],
       [0.00700978, 0.00855777, 0.00537084, 0.0054456 ],
       [0.02545656, 0.00855777, 0.01023504, 0.0054456 ],
       [0.02545656, 0.58122414, 0.0052997 , 0.00537347],
       [0.00682531, 0.00821919, 0.00515835, 0.00523015],
       [0.06235012, 0.00855777, 0.00537084, 0.0054456 ],
       [0.00682531, 0.00833257, 0.0052295 , 0.0053023 ],
       [0.00700978, 0.02552273, 0.00537084, 0.0054456 ],
       [0.00682531, 0.00844442, 0.0052997 , 0.00537347],
       [0.00700978, 0.00855777, 0.79005246, 0.0054456 ],
       [0.0051651 , 0.00734595, 0.0046103 , 0.00467448],
       [0.00682531, 0.00844442, 0.0052997 , 0.00537347],
       [0.00700978, 0.00855777, 0.00537084, 0.0054456 ],
       [0.00700978, 0.00844442, 0.0052997 , 0.00537347],
       [0.00700978, 0.01630829, 0.00537084, 0.0054456 ],
       [0.00700978, 0.00855777, 0.00537084, 0.0054456 ],
       [0.00700978, 0.00844442, 0.0052997 , 0.00537347],
       [0.00682531, 0.00844442,

## Load contextual distribution on segmental wordforms $p(W|C)$

In [24]:
pW_C = np.load(w)
pW_C.shape
pW_C.dtype
pW_C.nbytes / 1e9

(6403, 16443)

dtype('float64')

0.842276232

## Load lexicon metadata

In [33]:
cousin_fn_map = {i:'{0}cousinsOf.npz'.format(i) 
                 for i in range(5)}
sphere_fn_map = {i:'{0}spheresOf.npz'.format(i) 
                 for i in range(5)}

In [35]:
cousin_fn_map

{0: '0cousinsOf.npz',
 1: '1cousinsOf.npz',
 2: '2cousinsOf.npz',
 3: '3cousinsOf.npz',
 4: '4cousinsOf.npz'}

In [36]:
sphere_fn_map

{0: '0spheresOf.npz',
 1: '1spheresOf.npz',
 2: '2spheresOf.npz',
 3: '3spheresOf.npz',
 4: '4spheresOf.npz'}

In [34]:
assert all(fn in listdir(m) for fn in cousin_fn_map.values())
assert all(fn in listdir(m) for fn in sphere_fn_map.values())

In [42]:
chdir(m)

In [43]:
cousin_mats = mapValues(sparse.load_npz, cousin_fn_map)

In [44]:
sphere_mats = mapValues(sparse.load_npz, sphere_fn_map)

In [45]:
chdir(repo_dir)

In [48]:
cousin_mats
mapValues(lambda m: m.nbytes / 1e9,
          cousin_mats)

{0: <COO: shape=(21475, 6403), dtype=uint8, nnz=49429, fill_value=0>,
 1: <COO: shape=(21475, 6403), dtype=uint8, nnz=590534, fill_value=0>,
 2: <COO: shape=(21475, 6403), dtype=uint8, nnz=4878514, fill_value=0>,
 3: <COO: shape=(21475, 6403), dtype=uint8, nnz=17429734, fill_value=0>,
 4: <COO: shape=(21475, 6403), dtype=uint8, nnz=26080551, fill_value=0>}

{0: 0.000840293,
 1: 0.010039078,
 2: 0.082934738,
 3: 0.296305478,
 4: 0.443369367}

In [49]:
sphere_mats
mapValues(lambda m: m.nbytes / 1e9,
          sphere_mats)

{0: <COO: shape=(21475, 6403), dtype=uint8, nnz=6403, fill_value=0>,
 1: <COO: shape=(21475, 6403), dtype=uint8, nnz=14910, fill_value=0>,
 2: <COO: shape=(21475, 6403), dtype=uint8, nnz=173130, fill_value=0>,
 3: <COO: shape=(21475, 6403), dtype=uint8, nnz=914686, fill_value=0>,
 4: <COO: shape=(21475, 6403), dtype=uint8, nnz=1476146, fill_value=0>}

{0: 0.000108851, 1: 0.00025347, 2: 0.00294321, 3: 0.015549662, 4: 0.025094482}

In [50]:
m

'LTR_Buckeye_aligned_w_GD_AmE_destressed'

In [52]:
segmental_wordforms = importSeqs(path.join(m, 'LTR_Buckeye_aligned_CM_filtered_LM_filtered.pW_V_Transcriptions.txt'))
len(segmental_wordforms)

6403

In [55]:
list(segmental_wordforms)[0]

'd.aʊ.n.t.aʊ.n'

In [89]:
Ws_t = tuple(sorted(list(map(padInputSequenceWithBoundaries,
                             segmental_wordforms))))
len(Ws_t)

6403

In [58]:
Ps = union(map(lambda w: getPrefixes(padInputSequenceWithBoundaries(w)), segmental_wordforms))
len(Ps)

21475

In [86]:
Ps_t = tuple(sorted(list(Ps)))
len(Ps_t)

21475

# Calculation

In [84]:
from random import choice

In [94]:
random_wordform = choice(Ws_t)
random_wordform_length = len(ds2t(random_wordform))
random_wordform_idx = Ws_t.index(random_wordform)
random_wordform, random_wordform_idx, random_wordform_length

('⋊.s.p.ɑ.n.s.ɚ.d.⋉', 3801, 9)

In [81]:
type(pW_C)
pW_C.shape
pW_C[:,0].shape

numpy.ndarray

(6403, 16443)

(6403,)

In [83]:
pX0f = pW_C[:,0]
pX0f_torch = torch.from_numpy(pX0f)

In [77]:
if not r:
    CMsByLengthByWordformIndex = pY0f_X0ks
    CMsByLengthByWordformIndex_torch = [torch.from_numpy(each) for each in CMsByLengthByWordformIndex]
else:
#     xCMsByLengthByWordformIndex_torch
    pass

In [104]:
random_wordform
ds2t(random_wordform)
random_wordform_idx
random_wordform_length
random_wordform_CM = CMsByLengthByWordformIndex_torch[random_wordform_length-3][random_wordform_idx]
CMsByLengthByWordformIndex_torch[random_wordform_length-1].shape
random_wordform_CM

'⋊.s.p.ɑ.n.s.ɚ.d.⋉'

('⋊', 's', 'p', 'ɑ', 'n', 's', 'ɚ', 'd', '⋉')

3801

9

torch.Size([6403, 38, 9])

tensor([[0.0111, 0.0094, 0.1207, 0.0090, 0.0091, 0.0071, 0.0070],
        [0.0111, 0.0094, 0.0076, 0.0090, 0.0091, 0.0071, 0.0070],
        [0.0111, 0.0094, 0.0076, 0.0090, 0.0091, 0.0071, 0.0070],
        [0.0111, 0.0093, 0.0075, 0.0089, 0.0090, 0.0070, 0.5602],
        [0.0108, 0.0090, 0.0073, 0.0168, 0.0088, 0.0068, 0.0066],
        [0.0111, 0.0094, 0.0076, 0.0090, 0.0091, 0.0071, 0.0070],
        [0.0108, 0.0091, 0.0074, 0.0088, 0.0089, 0.0069, 0.0068],
        [0.0111, 0.0094, 0.0076, 0.0090, 0.0091, 0.0071, 0.0624],
        [0.0108, 0.0413, 0.0075, 0.0089, 0.0173, 0.0070, 0.0070],
        [0.0111, 0.0094, 0.0076, 0.0090, 0.0091, 0.0071, 0.0070],
        [0.0082, 0.0080, 0.0065, 0.0077, 0.0078, 0.0061, 0.0070],
        [0.0108, 0.0093, 0.0075, 0.0089, 0.0090, 0.0070, 0.0624],
        [0.0111, 0.0094, 0.0076, 0.0090, 0.0091, 0.0071, 0.0070],
        [0.0111, 0.0093, 0.0075, 0.0089, 0.0090, 0.0070, 0.0068],
        [0.0111, 0.0094, 0.0076, 0.6556, 0.0091, 0.0071, 0.0070],
        [0

In [78]:
if not r:
    def depthSampler2a_t(CM, m=1):
        stack = torch.zeros((m, CM.shape[1], CM.shape[0]))
        for eachStack in torch.arange(m):
            for i in torch.arange(CM.shape[1]):
                stack[eachStack, i] = torch.distributions.Multinomial(1, CM[:,i]).sample()
        return stack
else:
    def depthSampler2a_t(xCM, m=1):
        stack = torch.zeros((m, xCM.shape[1], xCM.shape[0]))
        for eachStack in torch.arange(m):
            for i in torch.arange(xCM.shape[1]):
                stack[eachStack, i] = torch.distributions.Multinomial(1, xCM[:,i]).sample()
        return stack

In [None]:
#FIXME copy what you need from the segmental sequence notebook to interpret these samples

In [105]:
depthSampler2a_t(random_wordform_CM)

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       

In [None]:
if not r:
    def pXhat0fX0k_pxt(xhat0f_idx, x0k_CM, m = 50):
    #     xhat0f_idx = Wmap[xhat0f]

    #     l = len(ds2t(x0k))
    #     x0k_CM = CMsByPrefixIndex[prefixMap[x0k]]

    #     my_Q_l = CMsByLengthByWordformIndex[l - 2]

        Y_prime = depthSampler2a_t(x0k_CM, m)#.float()
        l = Y_prime.shape[1] + 1

        my_Q_l = CMsByLengthByWordformIndex_torch[l - 3]#.float()

        # NORMALIZATION
        V_prime = torch.einsum('mli,kil->mkl', Y_prime, my_Q_l)  # :: (m,n,l)
        M_prime = torch.prod(V_prime, 2) # :: (m,n)
        N_prime = torch.matmul(M_prime, pX0f_torch) # :: (m, 1) <- prior probabilities of each of the m sampled channel prefixes
        Z_prime = 1.0 / N_prime # :: (m, 1)

        # NUMERATOR
        L_w = my_Q_l[xhat0f_idx]#.float()
        V_prime_w = torch.einsum('mij,ji->mi',Y_prime, L_w)
        O_w = torch.prod(V_prime_w, 1) # :: (m,1) likelihoods of each of the m sampled channel prefixes
        U_w = pX0f_torch[xhat0f_idx] * O_w ## :: (m,1) joint probabilities of xhat0f with each of the m sampled channel prefixes

        return torch.dot(Z_prime, U_w) / m

    #     return torch.dot( 1.0 / torch.matmul(torch.prod(torch.einsum('mli,kil->mkl', Y_prime, my_Q_l), 2), pX0f_torch) , 
    #        pX0f_torch[xhat0f_idx] * torch.prod(torch.einsum('mij,ji->mi',Y_prime, L_w), 1) ) / m
else:
    def pXhat0fX0k_pxt(xhat0f_idx, x0k_xCM, m = 50):
    #     xhat0f_idx = Wmap[xhat0f]

    #     l = len(ds2t(x0k))
    #     x0k_xCM = xCMsByPrefixIndex[prefixMap[x0k]]

    #     my_Q_l = xCMsByLengthByWordformIndex[l - 2]

        Y_prime = depthSampler2a_t(x0k_xCM, m)#.float()
        l = Y_prime.shape[1] + 1

        my_Q_l = xCMsByLengthByWordformIndex_torch[l - 3]#.float()

        # NORMALIZATION
        V_prime = torch.einsum('mli,kil->mkl', Y_prime, my_Q_l)  # :: (m,n,l)
        M_prime = torch.prod(V_prime, 2) # :: (m,n)
        N_prime = torch.matmul(M_prime, pX0f_torch) # :: (m, 1) <- prior probabilities of each of the m sampled channel prefixes
        Z_prime = 1.0 / N_prime # :: (m, 1)

        # NUMERATOR
        L_w = my_Q_l[xhat0f_idx]#.float()
        V_prime_w = torch.einsum('mij,ji->mi',Y_prime, L_w)
        O_w = torch.prod(V_prime_w, 1) # :: (m,1) likelihoods of each of the m sampled channel prefixes
        U_w = pX0f_torch[xhat0f_idx] * O_w ## :: (m,1) joint probabilities of xhat0f with each of the m sampled channel prefixes

        return torch.dot(Z_prime, U_w) / m

    #     return torch.dot( 1.0 / torch.matmul(torch.prod(torch.einsum('mli,kil->mkl', Y_prime, my_Q_l), 2), pX0f_torch) , 
    #        pX0f_torch[xhat0f_idx] * torch.prod(torch.einsum('mij,ji->mi',Y_prime, L_w), 1) ) / m