In [12]:
import os
import sys
import pandas as pd
pd.set_option('display.max_rows', 200)

current = os.getcwd()
parent = os.path.dirname(current)
sys.path.append(parent)

from swp.utils.setup import seed_everything, set_device
from swp.datasets.phonemes import get_phoneme_to_id

seed_everything()
device = set_device()
phoneme_to_id = get_phoneme_to_id()

Using MPS device


In [None]:
import torch 

def cache_lstm_weights(layer):
    """Cache LSTM weights and biases for later restoration"""
    weights = {
        "weight_ih_l0": layer.weight_ih_l0.clone(),
        "weight_hh_l0": layer.weight_hh_l0.clone(),
        "bias_ih_l0": layer.bias_ih_l0.clone(),
        "bias_hh_l0": layer.bias_hh_l0.clone(),
    }
    return weights


def restore_lstm_weights(layer, weights):
    """Restore LSTM weights and biases from cache"""
    with torch.no_grad():
        layer.weight_ih_l0.copy_(weights["weight_ih_l0"])
        layer.weight_hh_l0.copy_(weights["weight_hh_l0"])
        layer.bias_ih_l0.copy_(weights["bias_ih_l0"])
        layer.bias_hh_l0.copy_(weights["bias_hh_l0"])


def ablate_lstm_neuron(layer, neuron_idx, num_neurons):
    """Zero out all four gates of a single LSTM neuron"""
    with torch.no_grad():
        # Compute row indices for all four gates
        gate_indices = torch.tensor(
            [
                neuron_idx,
                neuron_idx + num_neurons,
                neuron_idx + num_neurons * 2,
                neuron_idx + num_neurons * 3,
            ]
        )
        # Zero out corresponding rows in weights and biases
        layer.weight_ih_l0[gate_indices] = 0
        layer.weight_hh_l0[gate_indices] = 0
        layer.bias_ih_l0[gate_indices] = 0
        layer.bias_hh_l0[gate_indices] = 0





In [None]:
import copy
from swp.utils.models import get_model, load_weights

model_name = "Ua_LSTM_h256_l1_v42_d0.5_t0.1_s1"
train_name = "b1024_l0.001_f0_sn"

model = get_model(model_name)
load_weights(model, model_name, train_name, "50", device)

for name, param in model.named_parameters():
    print(name, param.shape)

print(model.encoder.recurrent.bias_ih_l0.shape)

print(model.named_parameters())


encoder.embedding.weight torch.Size([42, 256])
encoder.recurrent.weight_ih_l0 torch.Size([1024, 256])
encoder.recurrent.weight_hh_l0 torch.Size([1024, 256])
encoder.recurrent.bias_ih_l0 torch.Size([1024])
encoder.recurrent.bias_hh_l0 torch.Size([1024])
decoder.recurrent.weight_ih_l0 torch.Size([1024, 256])
decoder.recurrent.weight_hh_l0 torch.Size([1024, 256])
decoder.recurrent.bias_ih_l0 torch.Size([1024])
decoder.recurrent.bias_hh_l0 torch.Size([1024])
torch.Size([1024])
<generator object Module.named_parameters at 0x140808cf0>


In [2]:
from swp.utils.datasets import get_train_data, get_phoneme_statistics
from swp.utils.plots import enrich_for_plotting
from ast import literal_eval

# train_df = get_train_data()
# phoneme_statistics = get_phoneme_statistics(train_df)

test_df = pd.read_csv('../results/gridsearch/test/Ua_LSTM_h256_l1_v42_d0.5_t0.1_s1~b1024_l0.001_f0_sn/50.csv')
test_df["No Stress"] = test_df["No Stress"].apply(literal_eval)
test_df["Prediction"] = test_df["Prediction"].apply(literal_eval)

df = enrich_for_plotting(test_df, phoneme_to_id, False)
df

Unnamed: 0.1,Unnamed: 0,Word,Size,Length,Frequency,Zipf Frequency,Morphology,Lexicality,Part of Speech,Phonemes,No Stress,Prediction,Edit Distance,Insertions,Deletions,Substitutions,Sequence Length,Error Indices,Bigram Frequency
0,0,bathmat,long,7,low,1.55,complex,real,NOUN,"['B', 'AE1', 'TH', 'M', 'AH0', 'T']","[B, AE, TH, M, AH, T]","[B, AE, TH, M, AH, T]",0,0,0,0,6,[],0.003116
1,1,decoder,long,7,low,2.84,complex,real,NOUN,"['D', 'IH0', 'K', 'OW1', 'D', 'ER0']","[D, IH, K, OW, D, ER]","[D, IH, K, OW, D, ER]",0,0,0,0,6,[],0.002864
2,2,defiant,long,7,low,3.21,complex,real,ADJ,"['D', 'IH0', 'F', 'AY1', 'AH0', 'N', 'T']","[D, IH, F, AY, AH, N, T]","[D, IH, F, AY, AH, N, T]",0,0,0,0,7,[],0.009445
3,3,padlock,long,7,low,2.68,complex,real,NOUN,"['P', 'AE1', 'D', 'L', 'AA2', 'K']","[P, AE, D, L, AA, K]","[P, AE, D, L, AA, K]",0,0,0,0,6,[],0.000953
4,4,immoral,long,7,low,3.46,complex,real,ADJ,"['IH0', 'M', 'AO1', 'R', 'AH0', 'L']","[IH, M, AO, R, AH, L]","[IH, M, AO, R, AH, L]",0,0,0,0,6,[],0.007904
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1649,340,unrich,short,6,,,complex,pseudo,,"['AH1', 'N', 'R', 'IH0', 'CH']","[AH, N, R, IH, CH]","[AH, N, R, IH, CH]",0,0,0,0,5,[],0.010701
1650,341,upholt,short,6,,,complex,pseudo,,"['AH0', 'P', 'OW1', 'L', 'T']","[AH, P, OW, L, T]","[AH, P, OW, L, T]",0,0,0,0,5,[],0.001672
1651,342,warels,short,6,,,complex,pseudo,,"['W', 'EH1', 'R', 'AH0', 'L', 'Z']","[W, EH, R, AH, L, Z]","[W, EH, R, AH, L, Z]",0,0,0,0,6,[],0.006488
1652,343,wately,short,6,,,complex,pseudo,,"['W', 'EY1', 'T', 'L', 'IY0']","[W, EY, T, L, IY]","[W, EY, T, L, IY]",0,0,0,0,5,[],0.003228


In [3]:
import torch

# Example preds and target tensors
preds = torch.tensor([[1, 0, 2, 0],   # Matches
                      [0, 1, 1, 1],   # One mismatch
                      [2, 2, 0, 0]])  # Mismatches but padded at the end
target = torch.tensor([[1, 0, 2, -100],  # Matches (ignore -100)
                       [0, 1, 0, -100],  # Mismatch at position 2
                       [2, 2, -100, -100]])  # All mismatches are ignored

# Mask for valid positions
mask = target != -100

print(preds.shape, target.shape)

# Element-wise comparison
comparison = (preds != target) & mask

# Check per sequence if there's any mismatch
per_sequence_error = torch.any(comparison, dim=1)

# Count the number of sequences with errors
num_errors = per_sequence_error.sum().item()

print(f"Per-sequence errors: {per_sequence_error}")  # tensor([False,  True, False])
print(f"Number of sequences with errors: {num_errors}")  # Output: 1

torch.Size([3, 4]) torch.Size([3, 4])
Per-sequence errors: tensor([False,  True, False])
Number of sequences with errors: 1
