Skip to content

Commit

Permalink
Merge pull request #81 from Kalsir/awd_lstm_fix
Browse files Browse the repository at this point in the history
AWD-LSTM fix
  • Loading branch information
jumelet committed Apr 13, 2021
2 parents 76d9fe1 + e22fb12 commit 0b4cc13
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions diagnnose/models/wrappers/awd_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

@staticmethod
def param_names(layer: int, rnn_name: str, **kwargs) -> Dict[str, str]:
    """Return the parameter-name strings for one layer of the wrapped RNN.

    Args:
        layer: Index of the layer, interpolated into each name.
        rnn_name: Prefix under which the layers are registered.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        A dict with the keys ``weight_hh``, ``weight_ih``, ``bias_hh`` and
        ``bias_ih``, each mapped to ``{rnn_name}.{layer}.module.<name>_l0``.
        Only the ``weight_hh`` entry carries the extra ``_raw`` suffix.
    """
    names: Dict[str, str] = {}
    # Insertion order matches the original literal: weights first, then biases.
    names["weight_hh"] = f"{rnn_name}.{layer}.module.weight_hh_l0_raw"
    names["weight_ih"] = f"{rnn_name}.{layer}.module.weight_ih_l0"
    names["bias_hh"] = f"{rnn_name}.{layer}.module.bias_hh_l0"
    names["bias_ih"] = f"{rnn_name}.{layer}.module.bias_ih_l0"
    return names
def param_names(
    layer: int, rnn_name: str, no_suffix: bool = False, **kwargs
) -> Dict[str, str]:
    """Return the parameter-name strings for one layer of the wrapped RNN.

    Args:
        layer: Index of the layer, interpolated into each name.
        rnn_name: Prefix under which the layers are registered.
        no_suffix: When truthy, every key maps to the empty string instead
            of a layer-specific name.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        A dict with the keys ``weight_hh``, ``weight_ih``, ``bias_hh`` and
        ``bias_ih``. Unless ``no_suffix`` is set, each value has the form
        ``{rnn_name}.{layer}.module.<name>_l0``, with an extra ``_raw``
        suffix on ``weight_hh`` only.
    """
    if not no_suffix:
        names: Dict[str, str] = {}
        names["weight_hh"] = f"{rnn_name}.{layer}.module.weight_hh_l0_raw"
        names["weight_ih"] = f"{rnn_name}.{layer}.module.weight_ih_l0"
        names["bias_hh"] = f"{rnn_name}.{layer}.module.bias_hh_l0"
        names["bias_ih"] = f"{rnn_name}.{layer}.module.bias_ih_l0"
        return names

    # The AWD-LSTM has no separate weight names for a single layer LSTM,
    # so all four keys map to an empty name in that case.
    return dict.fromkeys(("weight_hh", "weight_ih", "bias_hh", "bias_ih"), "")

0 comments on commit 0b4cc13

Please sign in to comment.