# Imports

In [1]:
import os

import torch as th
from torch.utils.serialization import load_lua

from _asm import AdaptiveLogSoftmaxWithLoss

# Settings

In [2]:
# Configuration shared by the Lua export and the PyTorch comparison below.
path = "/tmp/AdaptiveSoftMax.t7"   # file save_asm.lua writes and load_lua reads back
nhid, batch = 32, 4096             # input feature size / number of rows in the batch
cutoff = [10, 50]                  # cluster cutoffs (the Lua side also gets `targets` appended)
targets = 100                      # total number of classes (n_classes)

# Save the Lua module to disk

In [3]:
cutoff_str = ','.join(str(item) for item in cutoff + [targets])
!th save_asm.lua -nhid $nhid -batch $batch -cutoff $cutoff_str -path $path

# Load the Lua module

In [4]:
data = load_lua(path, unknown_classes=True)

# Lua modules serialized by save_asm.lua
asm_lua, crit_lua = data['decoder'], data['criterion']

# Input batch, class targets and the reference log-probabilities from Lua
input, target, logprob = (data[key] for key in ('input', 'target', 'logprob'))

# Outputs and input gradient recorded on the Lua modules
grad_input, asm_output = asm_lua['gradInput'], asm_lua['output']
crit_output = crit_lua['output']

# Transfer parameters from the Lua module to the Python module

In [5]:
asm_py = AdaptiveLogSoftmaxWithLoss(in_features=nhid, n_classes=targets, cutoffs=cutoff).to(th.double)

# Map each PyTorch parameter name to its Lua weight. Tail cluster i on the Lua
# side exposes two weight-bearing children, `.get(0)` and `.get(1)`, which
# correspond to PyTorch's `tail.<i>.0` and `tail.<i>.1`. Building the mapping
# from `cutoff` keeps it valid when the number of clusters changes.
mapping = {'head.weight': asm_lua['head'].weight}
for i in range(len(cutoff)):
    for j in range(2):
        mapping[f'tail.{i}.{j}.weight'] = asm_lua['tail'][i].get(j).weight

py_params = dict(asm_py.named_parameters())
# A parameter missing from either side would silently keep its random init.
assert set(py_params) == set(mapping), \
    f'parameter names differ: {sorted(set(py_params) ^ set(mapping))}'

with th.no_grad():
    for name, py_param in py_params.items():
        lua_param = mapping[name]
        # copy_ broadcasts, so check shapes explicitly to catch a mismatched layout.
        assert py_param.shape == lua_param.shape, \
            f'{name}: py {tuple(py_param.shape)} vs lua {tuple(lua_param.shape)}'
        py_param.copy_(lua_param)

# Make sure the computed loss is the same

In [6]:
# `target` comes from Lua, which numbers classes from 1; PyTorch expects 0-based indices.
py_output = asm_py(input, target - 1)
py_loss = py_output.loss.item()
print(f'py_loss = {py_loss}, lua_loss = {crit_output}')

loss_gap = abs(py_loss - crit_output)
assert loss_gap < 1e-12

py_loss = 5.901809036515506, lua_loss = 5.901809036515515


# Make sure the computed logprobs are the same

In [7]:
# Forward only; no autograd graph is needed for the comparison.
with th.no_grad():
    py_logprob = asm_py.log_prob(input)
diff = th.sum(th.abs(py_logprob - logprob)).item()

print(f'Lua shape: {logprob.shape}, Py shape: {py_logprob.shape}')
print(f'Sum of absolute differences: {diff}')

# Enforce the match instead of only printing it. Check the shape first, because
# allclose broadcasts. Use an element-wise tolerance so the check does not
# depend on the batch size.
assert py_logprob.shape == logprob.shape
assert th.allclose(py_logprob, logprob, rtol=0, atol=1e-9)

Lua shape: torch.Size([4096, 100]), Py shape: torch.Size([4096, 100])
Sum of absolute differences: 8.469758228102364e-11


# Make sure the grads wrt weights are the same

In [8]:
# Lua gradient for every PyTorch parameter, using the same naming scheme as the
# parameter transfer: tail cluster i has children `.get(0)` / `.get(1)`.
mapping = {'head.weight': asm_lua['head'].gradWeight}
for i in range(len(cutoff)):
    for j in range(2):
        mapping[f'tail.{i}.{j}.weight'] = asm_lua['tail'][i].get(j).gradWeight
assert set(mapping) == {name for name, _ in asm_py.named_parameters()}

# Make `input` a fresh leaf so that its .grad holds only this backward pass.
input = input.detach().requires_grad_(True)
# Parameter .grad buffers accumulate across backward() calls. Clear them so
# that re-running this cell does not compare doubled gradients.
asm_py.zero_grad()
asm_py(input, target - 1).loss.backward()

diff = 0.
for name, param in asm_py.named_parameters():
    lua_grad = mapping[name]
    assert param.grad.shape == lua_grad.shape, name
    diff += th.sum(th.abs(param.grad - lua_grad)).item()

print(f'Sum of absolute gradient differences: {diff}')
assert diff < 1e-12

Sum of absolute gradient differences: 6.161169427562011e-17


# Make sure the grads wrt inputs are the same

In [9]:
# `input.grad` is filled by the backward() call in the weight-gradient cell.
assert input.grad is not None, 'run the weight-gradient cell first (it calls backward())'
assert input.grad.shape == grad_input.shape
diff = th.sum(th.abs(input.grad - grad_input)).item()

print(f'Sum of absolute gradient differences: {diff}')
assert diff < 1e-12

Sum of absolute gradient differences: 2.0263020031955953e-16
