# Supervised Dictionary Learning for Sentence Decomposition

We have had success with linear probes. We can now try to further decompose it into further atoms, which would show broader structure.

In order to align these atoms with interpretable properties, we train the dictionary with a classification task added to the reconstruction loss.


The goal is to check if we can linearly decompose the embedding back into words with the part of speech and dependencies.

In [1]:
!pip install stanza -q
!pip install nltk -q
!pip install transformers datasets


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Collecting transformers
  Downloading transformers-4.50.3-py3-none-any.whl.metadata (39 kB)
Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting huggingface-hub<1.0,>=0.26.0 (from transformers)
  Downloading huggingface_hub-0.29.3-py3-none-any.whl.metadata (13 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Downloading tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)

In [2]:

import nltk
nltk.download('brown')
import stanza
from nltk.corpus import brown
stanza.download('en')


def reconstruct_sentence(tokens):
    sentence = " ".join(tokens)
    sentence = sentence.replace('``', '').replace("''", "").replace(
        " ,", ",").replace(" .", ".").replace(" ?", "?").replace(" !", "!")
    return sentence

brown_sentences = [reconstruct_sentence(tokens) for tokens in brown.sents()]
brown_sentences = brown_sentences[:20000]



[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Unzipping corpora/brown.zip.


Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.10.0.json:   0%|  …

2025-03-28 22:22:07 INFO: Downloaded file to /root/stanza_resources/resources.json
2025-03-28 22:22:07 INFO: Downloading default packages for language: en (English) ...


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/default.zip:   0%|          | …

2025-03-28 22:22:11 INFO: Downloaded file to /root/stanza_resources/en/default.zip
2025-03-28 22:22:14 INFO: Finished downloading models and saved to /root/stanza_resources


In [3]:
from transformers import AutoTokenizer, AutoModel
import torch
import stanza
import pandas as pd

# Load tokenizer and model
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 384, padding_idx=0)
    (position_embeddings): Embedding(512, 384)
    (token_type_embeddings): Embedding(2, 384)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-5): 6 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
    

In [3]:


# Load stanza for offset alignment
stanza.download('en')
nlp = stanza.Pipeline('en', processors='tokenize,pos,depparse,lemma')

def get_word_embeddings_aligned(sentence: str):
    """
    Given a sentence, aligns subword embeddings from MiniLM to words using char offsets from Stanza.
    Returns a list of dicts with word, embedding, POS, dependency, and position.
    """
    doc = nlp(sentence)
    word_spans = [(word.text, word.start_char, word.end_char, word.upos, word.deprel) 
                  for sent in doc.sentences for word in sent.words]

    # Tokenize with offset mapping, no special tokens
    encoding = tokenizer(
        sentence,
        return_offsets_mapping=True,
        return_tensors="pt",
        add_special_tokens=False
    )
    offsets = encoding["offset_mapping"][0].tolist()
    input_ids = encoding["input_ids"]

    # Get subword embeddings
    with torch.no_grad():
        output = model(**{k: v for k, v in encoding.items() if k != 'offset_mapping'})
        subword_embeddings = output.last_hidden_state.squeeze(0)  # [seq_len, dim]

    # Align subwords to words
    aligned_data = []
    for i, (word, w_start, w_end, upos, deprel) in enumerate(word_spans):
        matching_sub_idxs = [j for j, (s, e) in enumerate(offsets) if s < w_end and e > w_start and s != e]

        if matching_sub_idxs:
            embs = [subword_embeddings[j] for j in matching_sub_idxs]
            word_embedding = torch.stack(embs).mean(dim=0)
            aligned_data.append({
                "word": word,
                "embedding": word_embedding,
                "pos": upos,
                "dep": deprel,
                "position": i
            })

    return aligned_data

from tqdm import tqdm

all_rows = []
for i, sent in tqdm(enumerate(brown_sentences), total=len(brown_sentences), desc="Processing sentences"):
    try:
        aligned = get_word_embeddings_aligned(sent)
        for row in aligned:
            row["sentence_id"] = i
            row["sentence"] = sent
            all_rows.append(row)
    except: 
        continue

# Convert to DataFrame
df = pd.DataFrame(all_rows)



Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.10.0.json:   0%|  …

2025-03-27 11:15:17 INFO: Downloaded file to /root/stanza_resources/resources.json
2025-03-27 11:15:17 INFO: Downloading default packages for language: en (English) ...
2025-03-27 11:15:18 INFO: File exists: /root/stanza_resources/en/default.zip
2025-03-27 11:15:21 INFO: Finished downloading models and saved to /root/stanza_resources
2025-03-27 11:15:21 INFO: Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES


Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.10.0.json:   0%|  …

2025-03-27 11:15:21 INFO: Downloaded file to /root/stanza_resources/resources.json
2025-03-27 11:15:22 INFO: Loading these models for language: en (English):
| Processor | Package           |
---------------------------------
| tokenize  | combined          |
| mwt       | combined          |
| pos       | combined_charlm   |
| lemma     | combined_nocharlm |
| depparse  | combined_charlm   |

2025-03-27 11:15:22 INFO: Using device: cuda
2025-03-27 11:15:22 INFO: Loading: tokenize
  return self.fget.__get__(instance, owner)()
2025-03-27 11:15:22 INFO: Loading: mwt
2025-03-27 11:15:22 INFO: Loading: pos
2025-03-27 11:15:23 INFO: Loading: lemma
2025-03-27 11:15:24 INFO: Loading: depparse
2025-03-27 11:15:24 INFO: Done loading processors!
Processing sentences: 100%|██████████| 20000/20000 [31:05<00:00, 10.72it/s]


In [4]:
df.to_pickle("./dataset.pkl")  

In [4]:
%%capture
!pip install scikit-learn

In [5]:
import pandas as pd

# Load the DataFrame from a pickle file
df = pd.read_pickle("./dataset.pkl")

# Adaptative Softwax

In [6]:
len(df["word"].unique())

31484

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

# ---- Prepare Data ----
X = torch.stack(df['embedding'].tolist())

le_pos = LabelEncoder().fit(df["pos"])
le_dep = LabelEncoder().fit(df["dep"])
le_word = LabelEncoder().fit(df["word"])

y_pos = torch.tensor(le_pos.transform(df['pos'].values))
y_dep = torch.tensor(le_dep.transform(df['dep'].values))
y_word = torch.tensor(le_word.transform(df['word'].values))
y_position = torch.tensor(df['position'].values)

# ---- Define Probes ----
class LinearProbe(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

class AdaptiveSoftmaxProbe(nn.Module):
    def __init__(self, input_dim, n_classes):
        super().__init__()
        cutoffs = [1000, min(10000, n_classes - 2)] if n_classes > 10000 else [1000]
        self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
            in_features=input_dim,
            n_classes=n_classes,
            cutoffs=cutoffs,
            div_value=4.0
        )

    def forward(self, x, target=None):
        if target is not None:
            return self.adaptive_softmax(x, target)
        else:
            return self.adaptive_softmax.log_prob(x)

# ---- Training Functions ----
def train_linear_probe(X, y, num_classes, task_name="TASK", epochs=10):
    model = LinearProbe(X.shape[1], num_classes)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0
        for xb, yb in loader:
            optimizer.zero_grad()
            logits = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"{task_name} - Epoch {epoch+1}, Loss: {total_loss:.4f}")

    return model

def train_adaptive_probe(X, y, num_classes, task_name="TASK", epochs=10):
    model = AdaptiveSoftmaxProbe(X.shape[1], num_classes)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0
        for xb, yb in loader:
            optimizer.zero_grad()
            out = model(xb, yb)
            loss = out.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"{task_name} - Epoch {epoch+1}, Loss: {total_loss:.4f}")

    return model

# ---- Evaluation Functions ----
def evaluate_linear_probe(model, X, y):
    model.eval()
    with torch.no_grad():
        preds = model(X).argmax(dim=1)
        accuracy = (preds == y).float().mean().item()
    return accuracy

def evaluate_adaptive_probe(model, X, y, batch_size=64):
    model.eval()
    correct = 0
    total = 0
    dataset = DataLoader(TensorDataset(X, y), batch_size=batch_size)
    with torch.no_grad():
        for xb, yb in dataset:
            out = model(xb, yb)
            # Get predicted class by getting the max log probability from log_prob(x)
            # Even though log_prob is memory-intensive, we can do it in batches.
            log_probs = model.adaptive_softmax.log_prob(xb)
            preds = log_probs.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)
    return correct / total


# ---- Train Probes ----
pos_model = train_linear_probe(X, y_pos, len(le_pos.classes_), task_name="POS")
dep_model = train_linear_probe(X, y_dep, len(le_dep.classes_), task_name="DEP")
position_model = train_linear_probe(X, y_position, y_position.max().item() + 1, task_name="POSITION")
word_model = train_adaptive_probe(X, y_word, len(le_word.classes_), task_name="WORD")

# ---- Evaluate All Probes ----
print("\n--- Evaluation ---")
print(f"POS Accuracy:       {evaluate_linear_probe(pos_model, X, y_pos):.2%}")
print(f"DEP Accuracy:       {evaluate_linear_probe(dep_model, X, y_dep):.2%}")
print(f"POSITION Accuracy:  {evaluate_linear_probe(position_model, X, y_position):.2%}")
print(f"WORD Accuracy:      {evaluate_adaptive_probe(word_model, X, y_word):.2%}")


POS - Epoch 1, Loss: 4730.9234
POS - Epoch 2, Loss: 3247.7045
POS - Epoch 3, Loss: 3069.8149
POS - Epoch 4, Loss: 2991.8502
POS - Epoch 5, Loss: 2948.3470
POS - Epoch 6, Loss: 2920.6060
POS - Epoch 7, Loss: 2901.9377
POS - Epoch 8, Loss: 2887.0535
POS - Epoch 9, Loss: 2876.1811
POS - Epoch 10, Loss: 2868.2626
DEP - Epoch 1, Loss: 9753.6313
DEP - Epoch 2, Loss: 7845.6147
DEP - Epoch 3, Loss: 7579.3977
DEP - Epoch 4, Loss: 7462.0545
DEP - Epoch 5, Loss: 7393.7948
DEP - Epoch 6, Loss: 7352.5016
DEP - Epoch 7, Loss: 7323.5441
DEP - Epoch 8, Loss: 7301.7198
DEP - Epoch 9, Loss: 7285.5186
DEP - Epoch 10, Loss: 7273.5649
POSITION - Epoch 1, Loss: 19297.2871
POSITION - Epoch 2, Loss: 16700.9510
POSITION - Epoch 3, Loss: 16192.9907
POSITION - Epoch 4, Loss: 15955.9640
POSITION - Epoch 5, Loss: 15813.7989
POSITION - Epoch 6, Loss: 15723.1971
POSITION - Epoch 7, Loss: 15657.1624
POSITION - Epoch 8, Loss: 15610.5736
POSITION - Epoch 9, Loss: 15572.2952
POSITION - Epoch 10, Loss: 15542.8545
WORD - 

# MODEL PROBE

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

X = torch.stack(df['embedding'].tolist())

le_pos = LabelEncoder().fit(df["pos"])
le_dep = LabelEncoder().fit(df["dep"])
le_word = LabelEncoder().fit(df["word"])

y_pos = torch.tensor(le_pos.transform(df['pos'].values))
y_dep = torch.tensor(le_dep.transform(df['dep'].values))
y_word = torch.tensor(le_word.transform(df['word'].values))
y_position = torch.tensor(df['position'].values)

class LinearProbe(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

def train_probe(X, y, num_classes, task_name="TASK", epochs=10):
    model = LinearProbe(X.shape[1], num_classes)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0
        for xb, yb in loader:
            optimizer.zero_grad()
            logits = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"{task_name} - Epoch {epoch+1}, Loss: {total_loss:.4f}")

    return model

def evaluate_probe(model, X, y):
    model.eval()
    with torch.no_grad():
        preds = model(X).argmax(dim=1)
        accuracy = (preds == y).float().mean().item()
    return accuracy


pos_model = train_probe(X, y_pos, len(le_pos.classes_), task_name="POS")
dep_model = train_probe(X, y_dep, len(le_dep.classes_), task_name="DEP")
word_model = train_probe(X, y_word, len(le_word.classes_), task_name="WORD")
position_model = train_probe(X, y_position, y_position.max().item() + 1, task_name="POSITION")

print("\n--- Evaluation ---")
print(f"POS Accuracy:       {evaluate_probe(pos_model, X, y_pos):.2%}")
print(f"DEP Accuracy:       {evaluate_probe(dep_model, X, y_dep):.2%}")
print(f"WORD Accuracy:      {evaluate_probe(word_model, X, y_word):.2%}")
print(f"POSITION Accuracy:  {evaluate_probe(position_model, X, y_position):.2%}")


POS - Epoch 1, Loss: 4740.1016
POS - Epoch 2, Loss: 3248.5331
POS - Epoch 3, Loss: 3069.9618
POS - Epoch 4, Loss: 2992.9384
POS - Epoch 5, Loss: 2948.8804
POS - Epoch 6, Loss: 2920.8930
POS - Epoch 7, Loss: 2901.8045
POS - Epoch 8, Loss: 2887.5567
POS - Epoch 9, Loss: 2876.2220
POS - Epoch 10, Loss: 2868.3626
DEP - Epoch 1, Loss: 9744.6328
DEP - Epoch 2, Loss: 7843.7364
DEP - Epoch 3, Loss: 7578.9932
DEP - Epoch 4, Loss: 7461.9076
DEP - Epoch 5, Loss: 7394.7150
DEP - Epoch 6, Loss: 7351.7902
DEP - Epoch 7, Loss: 7324.2117
DEP - Epoch 8, Loss: 7300.6181
DEP - Epoch 9, Loss: 7285.8807
DEP - Epoch 10, Loss: 7273.1770
WORD - Epoch 1, Loss: 22264.9924


# Dict

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

X = torch.stack(df['embedding'].tolist())

le_pos = LabelEncoder().fit(df["pos"])
le_dep = LabelEncoder().fit(df["dep"])
le_word = LabelEncoder().fit(df["word"])

y_pos = torch.tensor(le_pos.transform(df['pos'].values))
y_dep = torch.tensor(le_dep.transform(df['dep'].values))
y_word = torch.tensor(le_word.transform(df['word'].values))
y_position = torch.tensor(df['position'].values)

class DictionaryLearner(nn.Module):
    def __init__(self, input_dim, dict_size):
        super().__init__()
        self.dictionary = nn.Parameter(torch.randn(dict_size, input_dim))

    def forward(self, x):
        codes = torch.matmul(x, self.dictionary.T)
        recon = torch.matmul(codes, self.dictionary)
        return codes, recon

class LinearProbe(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

def train_probe_with_dictionary(X, y, num_classes, task_name="TASK", dict_size=128, epochs=10):
    dict_learner = DictionaryLearner(X.shape[1], dict_size)
    probe = LinearProbe(dict_size, num_classes)

    recon_loss_fn = nn.MSELoss()
    clf_loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(dict_learner.parameters()) + list(probe.parameters()), lr=1e-3)

    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0
        for xb, yb in loader:
            codes, recon = dict_learner(xb)
            logits = probe(codes)

            loss_recon = recon_loss_fn(recon, xb)
            loss_clf = clf_loss_fn(logits, yb)
            loss = loss_recon + loss_clf

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"{task_name} - Epoch {epoch+1}, Loss: {total_loss:.4f}")

    return probe, dict_learner

# 5. Evaluation function
def evaluate_probe(model, dict_learner, X, y):
    model.eval()
    dict_learner.eval()
    with torch.no_grad():
        codes, _ = dict_learner(X)
        preds = model(codes).argmax(dim=1)
        accuracy = (preds == y).float().mean().item()
    return accuracy

# 6. Train dictionary-augmented probes
pos_model, pos_dict = train_probe_with_dictionary(X, y_pos, len(le_pos.classes_), task_name="POS")
dep_model, dep_dict = train_probe_with_dictionary(X, y_dep, len(le_dep.classes_), task_name="DEP")
word_model, word_dict = train_probe_with_dictionary(X, y_word, len(le_word.classes_), task_name="WORD")
position_model, posn_dict = train_probe_with_dictionary(X, y_position, y_position.max().item() + 1, task_name="POSITION")

# 7. Evaluate all
print("\n--- Evaluation ---")
print(f"POS Accuracy:       {evaluate_probe(pos_model, pos_dict, X, y_pos):.2%}")
print(f"DEP Accuracy:       {evaluate_probe(dep_model, dep_dict, X, y_dep):.2%}")
print(f"WORD Accuracy:      {evaluate_probe(word_model, word_dict, X, y_word):.2%}")
print(f"POSITION Accuracy:  {evaluate_probe(position_model, posn_dict, X, y_position):.2%}")


POS - Epoch 1, Loss: 5085862.2582


KeyboardInterrupt: 

# Shared Dictionary

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

X = torch.stack(df['embedding'].tolist())

le_pos = LabelEncoder().fit(df["pos"])
le_dep = LabelEncoder().fit(df["dep"])
le_word = LabelEncoder().fit(df["word"])

y_pos = torch.tensor(le_pos.transform(df['pos'].values))
y_dep = torch.tensor(le_dep.transform(df['dep'].values))
y_word = torch.tensor(le_word.transform(df['word'].values))
y_position = torch.tensor(df['position'].values)

class DictionaryLearner(nn.Module):
    def __init__(self, input_dim, dict_size):
        super().__init__()
        self.dictionary = nn.Parameter(torch.randn(dict_size, input_dim))

    def forward(self, x):
        codes = torch.matmul(x, self.dictionary.T)
        recon = torch.matmul(codes, self.dictionary)
        return codes, recon

class LinearProbe(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

def train_multitask_probes_with_shared_dictionary(X, y_dict, probes, dict_learner, epochs=10):
    loss_fn_cls = nn.CrossEntropyLoss()
    loss_fn_recon = nn.MSELoss()
    probe_params = []
    for probe in probes.values():
        probe_params += list(probe.parameters())

    optimizer = optim.Adam(probe_params + list(dict_learner.parameters()), lr=1e-3)
    dataset = TensorDataset(X, y_dict['pos'], y_dict['dep'], y_dict['word'], y_dict['position'])
    loader = DataLoader(dataset, batch_size=64, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0
        for xb, y_pos_b, y_dep_b, y_word_b, y_posn_b in loader:
            codes, recon = dict_learner(xb)

            loss_pos = loss_fn_cls(probes['pos'](codes), y_pos_b)
            loss_dep = loss_fn_cls(probes['dep'](codes), y_dep_b)
            loss_word = loss_fn_cls(probes['word'](codes), y_word_b)
            loss_posn = loss_fn_cls(probes['position'](codes), y_posn_b)
            loss_recon = loss_fn_recon(recon, xb)

            total = loss_pos + loss_dep + loss_word + loss_posn + loss_recon

            optimizer.zero_grad()
            total.backward()
            optimizer.step()

            total_loss += total.item()

        print(f"Epoch {epoch+1}, Total Loss: {total_loss:.4f}")

def evaluate_probe(model, dict_learner, X, y):
    model.eval()
    dict_learner.eval()
    with torch.no_grad():
        codes, _ = dict_learner(X)
        preds = model(codes).argmax(dim=1)
        accuracy = (preds == y).float().mean().item()
    return accuracy

dict_size = 128
shared_dict = DictionaryLearner(X.shape[1], dict_size)

probes = {
    'pos': LinearProbe(dict_size, len(le_pos.classes_)),
    'dep': LinearProbe(dict_size, len(le_dep.classes_)),
    'word': LinearProbe(dict_size, len(le_word.classes_)),
    'position': LinearProbe(dict_size, y_position.max().item() + 1),
}

y_dict = {
    'pos': y_pos,
    'dep': y_dep,
    'word': y_word,
    'position': y_position,
}

train_multitask_probes_with_shared_dictionary(X, y_dict, probes, shared_dict)

print("\n--- Evaluation with Shared Dictionary ---")
print(f"POS Accuracy:       {evaluate_probe(probes['pos'], shared_dict, X, y_pos):.2%}")
print(f"DEP Accuracy:       {evaluate_probe(probes['dep'], shared_dict, X, y_dep):.2%}")
print(f"WORD Accuracy:      {evaluate_probe(probes['word'], shared_dict, X, y_word):.2%}")
print(f"POSITION Accuracy:  {evaluate_probe(probes['position'], shared_dict, X, y_position):.2%}")

In [17]:
import pickle
probes_to_pickle = {
    'pos_model': pos_model, 
    'dep_model': dep_model, 
    'position_model': position_model,
    'word_model': word_model
}

with open('probes.pkl', 'wb') as f:
    pickle.dump(probes_to_pickle, f)