In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from copy import deepcopy
from tqdm.auto import tqdm
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix

from transformer import Transformer

In [2]:
dataset = load_dataset('GleghornLab/SS3')
dataset

README.md:   0%|          | 0.00/505 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/3.58M [00:00<?, ?B/s]

data/valid-00000-of-00001.parquet:   0%|          | 0.00/180k [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/19.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10792 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/626 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/50 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['seqs', 'labels'],
        num_rows: 10792
    })
    valid: Dataset({
        features: ['seqs', 'labels'],
        num_rows: 626
    })
    test: Dataset({
        features: ['seqs', 'labels'],
        num_rows: 50
    })
})

First let's try one hot encoding as our feature inputs

In [3]:
seq = dataset['train'][0]['seqs']
label = dataset['train'][0]['labels']
print(seq)
print(label)


MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKRQTLGQHDFSAGEGLYTHMKALRPDEDRLSPLHSVYVDQWDWERVMGDGERQFSTLKSTVEAIWAGIKATEAAVSEEFGLAPFLPDQIHFVHSQELLSRYPDLDAKGRERAIAKDLGAVFLVGIGGKLSDGHRHDVRAPDYDDWSTPSELGHAGLNGDILVWNPVLEDAFELSSMGIRVDADTLKHQLALTGDEDRLELEWHQALLRGEMPQTIGGGIGQSRLTMLLLQLPHIGQVQAGVWPAAVRESVPSLL
DDDCHHHHHHHHHHHHHHHHHHHHHHHCEEECCCCCEEECCCCCCCCCCCCCCCCEECCCCCCCCCEEECCCCCCHHHHHHHHCCCCCCCEEEEEEEEECCCCCCCCCCCCCEEEEEEEEEECCCCCCCHHHHHHHHHHHHHHHHHHHHHHHHHCCCCCCCCCCCEEEEHHHHHHHCCCCCHHHHHHHHHHHHCEEEEECCCCCCCCCCCCCCCCCCCECCCCECCCCCECCEEEEEEEECCCCEEEEEEEEEEECCHHHHHHHHHHHCCCCHHHCHHHHHHHCCCCCCEEEEEEEHHHHHHHHHCCCCHHHCCCCCCCHHHHHHCCCCC


In [4]:
label_dict = {label: i for i, label in enumerate(sorted(set(label)))}
label_dict

{'C': 0, 'D': 1, 'E': 2, 'H': 3}

In [5]:
# Here's one way to make label vectors
dataset = dataset.map(lambda x: {'label_vector': [label_dict[y] for y in x['labels']]})

Map:   0%|          | 0/10792 [00:00<?, ? examples/s]

Map:   0%|          | 0/626 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

In [6]:
# This way looks at all examples so nothing is missed
# # Collect sequence alphabet
all_seqs = []
all_labels = []
for split in ["train", "valid", "test"]:
    all_seqs += dataset[split]["seqs"]
    # dataset uses key 'labels'; if 'label' also exists, prefer 'labels'
    if "labels" in dataset[split].features:
        all_labels += dataset[split]["labels"]
    else:
        all_labels += [ex["label"] for ex in dataset[split]]

seq_vocab = sorted(list(set("".join(all_seqs))))
label_vocab = sorted(list(set("".join(all_labels))))
char2idx = {ch: i for i, ch in enumerate(seq_vocab)}
label2idx = {ch: i for i, ch in enumerate(label_vocab)}
idx2label = {i: ch for ch, i in label2idx.items()}

vocab_size = len(seq_vocab)
num_labels = len(label_vocab)

print({"vocab_size": vocab_size, "num_labels": num_labels})

{'vocab_size': 24, 'num_labels': 4}


In [7]:
# Dataloaders with padding and collate
PAD = "<PAD>"
if PAD not in seq_vocab:
    seq_vocab = [PAD] + seq_vocab
    char2idx = {ch: i for i, ch in enumerate(seq_vocab)}
    vocab_size = len(seq_vocab)
    print("Added PAD token to seq vocab.")


pad_idx = char2idx[PAD]
ignore_index = -100


def to_indices(s: str, mapper: dict[str, int]) -> list[int]:
    return [mapper[c] for c in s]


class HFDatasetWrapper(Dataset):
    def __init__(self, hf_ds):
        self.ds = hf_ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        seq = item["seqs"]
        labs = item["labels"] if "labels" in item else item["label"]
        return seq, labs


def collate_fn(batch):
    seqs, labs = zip(*batch)
    seq_ids = [torch.tensor(to_indices(s, char2idx), dtype=torch.long) for s in seqs]
    lab_ids = [torch.tensor(to_indices(s, label2idx), dtype=torch.long) for s in labs]
    max_len = max(x.size(0) for x in seq_ids)

    input_ids = torch.full((len(seqs), max_len), fill_value=pad_idx, dtype=torch.long)
    labels = torch.full((len(seqs), max_len), fill_value=ignore_index, dtype=torch.long)
    attn_mask = torch.zeros((len(seqs), max_len), dtype=torch.long)

    for i, (s_ids, l_ids) in enumerate(zip(seq_ids, lab_ids)):
        L = s_ids.size(0)
        input_ids[i, :L] = s_ids
        labels[i, :L] = l_ids
        attn_mask[i, :L] = 1  # 1=valid, 0=pad

    # Convert to boolean mask where True means MASKED (padding) for the attention module
    attn_mask_bool = attn_mask == 0
    return input_ids, labels, attn_mask_bool


train_ds = HFDatasetWrapper(dataset["train"]) 
valid_ds = HFDatasetWrapper(dataset["valid"]) 
test_ds  = HFDatasetWrapper(dataset["test"]) 

batch_size = 4
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0)
valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
len(train_loader), len(valid_loader), len(test_loader)

Added PAD token to seq vocab.


(2698, 157, 13)

In [8]:
# Model wrapper around provided Transformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SeqLabelTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        num_labels: int,
        hidden_size: int = 128,
        n_heads: int = 2,
        n_layers: int = 1,
        expansion_ratio: float = 2.0,
        dropout: float = 0.1,
        rotary: bool = True
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_proj = nn.Linear(vocab_size, hidden_size)
        self.backbone = Transformer(hidden_size=hidden_size, n_heads=n_heads, n_layers=n_layers, expansion_ratio=expansion_ratio, dropout=dropout, rotary=rotary)
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        # input_ids: (B, L) of token indices
        # one-hot -> project
        x = F.one_hot(input_ids, num_classes=vocab_size).float()  # (B, L, V)
        x = self.input_proj(x)  # (B, L, H)
        x = self.backbone(x, attention_mask=attention_mask)  # (B, L, H)
        x = self.norm(x)
        logits = self.classifier(x)  # (B, L, C)
        return logits


model = SeqLabelTransformer(
    vocab_size=vocab_size,
    num_labels=num_labels,
    hidden_size=128,
    n_heads=2,
    n_layers=1,
    expansion_ratio=2.0,
    dropout=0.1,
    rotary=True
).to(device)
model


SeqLabelTransformer(
  (input_proj): Linear(in_features=25, out_features=128, bias=True)
  (backbone): Transformer(
    (layers): ModuleList(
      (0): TransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=128, out_features=384, bias=False)
          )
          (out_proj): Linear(in_features=128, out_features=128, bias=False)
          (q_ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (k_ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (rotary): RotaryEmbedding()
        )
        (ffn): Sequential(
          (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=128, out_features=512, bias=False)
          (2): SwiGLU()
          (3): Dropout(p=0.1, inplace=False)
          (4): Linear(in_features=256, out_features=128, bias=False)
        )
      )
    )
  )
 

In [9]:
# Training & evaluation with early stopping on validation F1
criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)


def run_epoch(dl, train: bool):
    if train:
        model.train()
    else:
        model.eval()
    total_loss = 0.0
    all_preds = []
    all_trues = []
    for input_ids, labels, attn_mask in tqdm(dl, desc=f"{'Training' if train else 'Evaluating'}", leave=False):
        input_ids = input_ids.to(device)
        labels = labels.to(device)
        attn_mask = attn_mask.to(device)

        with torch.set_grad_enabled(train):
            logits = model(input_ids, attention_mask=attn_mask)
            # reshape for token-level CE
            B, L, C = logits.shape
            loss = criterion(logits.view(B*L, C), labels.view(B*L))

        if train:
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        total_loss += loss.item()

        # collect predictions ignoring pad positions
        preds = logits.argmax(dim=-1)
        mask = labels != ignore_index
        all_preds.extend(preds[mask].detach().cpu().tolist())
        all_trues.extend(labels[mask].detach().cpu().tolist())

    avg_loss = total_loss / max(1, len(dl))
    f1 = f1_score(all_trues, all_preds, average='weighted')
    acc = accuracy_score(all_trues, all_preds)
    conf_mat = confusion_matrix(all_trues, all_preds)
    print(f"F1: {f1:.4f}, Accuracy: {acc:.4f}")
    return avg_loss, f1, acc, conf_mat


best_state = None
best_val_f1 = -1.0
patience = 3
stale = 0
max_epochs = 20

for epoch in range(1, max_epochs+1):
    train_loss, train_f1, train_acc, train_conf_mat = run_epoch(train_loader, train=True)
    val_loss, val_f1, val_acc, val_conf_mat = run_epoch(valid_loader, train=False)
    scheduler.step()
    print({"epoch": epoch, "train_loss": round(train_loss, 4), "train_f1": round(train_f1, 4), "val_loss": round(val_loss, 4), "val_f1": round(val_f1, 4)})

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_state = deepcopy(model.state_dict())
        stale = 0
    else:
        stale += 1
        if stale >= patience:
            print(f"Early stopping at epoch {epoch} (best val_f1={best_val_f1:.4f})")
            break

# Load best and evaluate on test
assert best_state is not None, "No best state captured."
model.load_state_dict(best_state)
test_loss, test_f1, test_acc, test_conf_mat = run_epoch(test_loader, train=False)
print(f"Test Loss: {test_loss:.4f}, Test F1: {test_f1:.4f}, Test Accuracy: {test_acc:.4f}")
print(f"Test Confusion Matrix:\n{test_conf_mat}")
print(f"Best Validation F1: {best_val_f1:.4f}")


Training:   0%|          | 0/2698 [00:00<?, ?it/s]

F1: 0.4206, Accuracy: 0.4523


Evaluating:   0%|          | 0/157 [00:00<?, ?it/s]

F1: 0.4452, Accuracy: 0.4697
{'epoch': 1, 'train_loss': 1.193, 'train_f1': 0.4206, 'val_loss': 1.1383, 'val_f1': 0.4452}


Training:   0%|          | 0/2698 [00:00<?, ?it/s]

F1: 0.4261, Accuracy: 0.4564


Evaluating:   0%|          | 0/157 [00:00<?, ?it/s]

F1: 0.4395, Accuracy: 0.4684
{'epoch': 2, 'train_loss': 1.1871, 'train_f1': 0.4261, 'val_loss': 1.1365, 'val_f1': 0.4395}


Training:   0%|          | 0/2698 [00:00<?, ?it/s]

F1: 0.4261, Accuracy: 0.4570


Evaluating:   0%|          | 0/157 [00:00<?, ?it/s]

F1: 0.4553, Accuracy: 0.4685
{'epoch': 3, 'train_loss': 1.1857, 'train_f1': 0.4261, 'val_loss': 1.1352, 'val_f1': 0.4553}


Training:   0%|          | 0/2698 [00:00<?, ?it/s]

F1: 0.4274, Accuracy: 0.4571


Evaluating:   0%|          | 0/157 [00:00<?, ?it/s]

F1: 0.4496, Accuracy: 0.4708
{'epoch': 4, 'train_loss': 1.1851, 'train_f1': 0.4274, 'val_loss': 1.1361, 'val_f1': 0.4496}


Training:   0%|          | 0/2698 [00:00<?, ?it/s]

F1: 0.4262, Accuracy: 0.4574


Evaluating:   0%|          | 0/157 [00:00<?, ?it/s]

F1: 0.4584, Accuracy: 0.4740
{'epoch': 5, 'train_loss': 1.185, 'train_f1': 0.4262, 'val_loss': 1.1365, 'val_f1': 0.4584}


Training:   0%|          | 0/2698 [00:00<?, ?it/s]

F1: 0.4262, Accuracy: 0.4578


Evaluating:   0%|          | 0/157 [00:00<?, ?it/s]

F1: 0.4411, Accuracy: 0.4732
{'epoch': 6, 'train_loss': 1.1851, 'train_f1': 0.4262, 'val_loss': 1.1363, 'val_f1': 0.4411}


Training:   0%|          | 0/2698 [00:00<?, ?it/s]

F1: 0.4247, Accuracy: 0.4577


Evaluating:   0%|          | 0/157 [00:00<?, ?it/s]

F1: 0.4495, Accuracy: 0.4706
{'epoch': 7, 'train_loss': 1.1839, 'train_f1': 0.4247, 'val_loss': 1.1396, 'val_f1': 0.4495}


Training:   0%|          | 0/2698 [00:00<?, ?it/s]

F1: 0.4240, Accuracy: 0.4580


Evaluating:   0%|          | 0/157 [00:00<?, ?it/s]

F1: 0.4411, Accuracy: 0.4732
{'epoch': 8, 'train_loss': 1.1842, 'train_f1': 0.424, 'val_loss': 1.1336, 'val_f1': 0.4411}
Early stopping at epoch 8 (best val_f1=0.4584)


Evaluating:   0%|          | 0/13 [00:00<?, ?it/s]

F1: 0.3895, Accuracy: 0.4217
Test Loss: 1.2608, Test F1: 0.3895, Test Accuracy: 0.4217
Test Confusion Matrix:
[[2202    0  418 1551]
 [ 626    0  132  516]
 [ 715    0  538  984]
 [ 946    0  529 1940]]
Best Validation F1: 0.4584


Second, let's try protein language model embeddings as our input

In [10]:
from transformers import AutoModel, AutoTokenizer

model_path = 'Synthyra/ESM2-8M'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).eval()
model = model.to(device)
tokenizer = model.tokenizer

Some weights of FastEsmModel were not initialized from the model checkpoint at Synthyra/ESM2-8M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
all_seqs = list(set(list(dataset['train']['seqs']) + list(dataset['valid']['seqs']) + list(dataset['test']['seqs'])))
all_seqs = sorted(all_seqs, key=len)

embedding_dict = model.embed_dataset(
    sequences=all_seqs,
    tokenizer=tokenizer,
    batch_size=4, # adjust for your GPU memory
    max_len=512, # adjust for your needs
    full_embeddings=True, # if True, no pooling is performed
    embed_dtype=torch.float32, # cast to what dtype you want
    #pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
    num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
    sql=False, # if True, embeddings will be stored in SQLite database
    sql_db_path='embeddings.db',
    save=True, # if True, embeddings will be saved as a .pth file
    save_path='embeddings.pth',
)
# embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
print(len(embedding_dict))

Embedding 11380 new sequences


Embedding batches:   0%|          | 0/2845 [00:00<?, ?it/s]

11380


In [16]:
# Embedding-based DataLoaders (use per-residue embeddings instead of one-hot)
assert 'embedding_dict' in globals(), "Run the embedding cell to create embedding_dict first."

class EmbeddingDataset(Dataset):
    def __init__(self, hf_ds, embedding_lookup: dict):
        self.ds = hf_ds
        self.lookup = embedding_lookup

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        seq = item['seqs']
        labs = item['labels'] if 'labels' in item else item['label']
        emb = self.lookup[seq]
        if not torch.is_tensor(emb):
            emb = torch.tensor(emb)
        # Handle potential special tokens (e.g., BOS/EOS) in embeddings
        if emb.dim() == 2 and emb.size(0) == len(labs) + 2:
            emb = emb[1:-1]
        # Align lengths conservatively
        K = min(emb.size(0), len(labs))
        return emb[:K].float(), labs[:K]


def collate_fn_emb(batch):
    embs, lab_strs = zip(*batch)
    lengths = [e.size(0) for e in embs]
    max_len = max(lengths)
    embed_dim = embs[0].size(1)

    inputs = torch.zeros((len(embs), max_len, embed_dim), dtype=torch.float32)
    labels = torch.full((len(embs), max_len), fill_value=ignore_index, dtype=torch.long)
    attn_mask = torch.zeros((len(embs), max_len), dtype=torch.long)

    for i, (e, labs) in enumerate(zip(embs, lab_strs)):
        L = e.size(0)
        inputs[i, :L] = e
        labels[i, :L] = torch.tensor([label2idx[c] for c in labs], dtype=torch.long)
        attn_mask[i, :L] = 1

    attn_mask_bool = attn_mask == 0
    return inputs, labels, attn_mask_bool

train_ds_emb = EmbeddingDataset(dataset['train'], embedding_dict)
valid_ds_emb = EmbeddingDataset(dataset['valid'], embedding_dict)
test_ds_emb  = EmbeddingDataset(dataset['test'],  embedding_dict)

batch_size_emb = 4
train_loader_emb = DataLoader(train_ds_emb, batch_size=batch_size_emb, shuffle=True,  collate_fn=collate_fn_emb, num_workers=0)
valid_loader_emb = DataLoader(valid_ds_emb, batch_size=batch_size_emb, shuffle=False, collate_fn=collate_fn_emb, num_workers=0)
test_loader_emb  = DataLoader(test_ds_emb,  batch_size=batch_size_emb, shuffle=False, collate_fn=collate_fn_emb, num_workers=0)

len(train_loader_emb), len(valid_loader_emb), len(test_loader_emb)


(2698, 157, 13)

In [17]:
# Model that consumes embeddings directly
class EmbSeqLabelTransformer(nn.Module):
    def __init__(self, embed_dim: int, num_labels: int, hidden_size: int = 128, n_heads: int = 2, n_layers: int = 1, expansion_ratio: float = 2.0, dropout: float = 0.1, rotary: bool = True):
        super().__init__()
        self.input_proj = nn.Linear(embed_dim, hidden_size, bias=False)
        self.backbone = Transformer(hidden_size=hidden_size, n_heads=n_heads, n_layers=n_layers, expansion_ratio=expansion_ratio, dropout=dropout, rotary=rotary)
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_labels, bias=True)

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        x = self.input_proj(inputs)
        x = self.backbone(x, attention_mask=attention_mask)
        x = self.norm(x)
        return self.classifier(x)

# Infer embed_dim from one sample
tmp_e, _ = train_ds_emb[0]
embed_dim = tmp_e.size(1)

emb_model = EmbSeqLabelTransformer(
    embed_dim=embed_dim,
    num_labels=num_labels,
    hidden_size=128,
    n_heads=2,
    n_layers=1,
    expansion_ratio=2.0,
    dropout=0.1,
    rotary=True
).to(device)
emb_model


EmbSeqLabelTransformer(
  (input_proj): Linear(in_features=320, out_features=128, bias=False)
  (backbone): Transformer(
    (layers): ModuleList(
      (0): TransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=128, out_features=384, bias=False)
          )
          (out_proj): Linear(in_features=128, out_features=128, bias=False)
          (q_ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (k_ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (rotary): RotaryEmbedding()
        )
        (ffn): Sequential(
          (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=128, out_features=512, bias=False)
          (2): SwiGLU()
          (3): Dropout(p=0.1, inplace=False)
          (4): Linear(in_features=256, out_features=128, bias=False)
        )
      )
    )


In [18]:
# Training with embeddings + early stopping on validation F1
emb_criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
emb_optimizer = torch.optim.AdamW(emb_model.parameters(), lr=1e-4, weight_decay=0.01)
emb_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(emb_optimizer, T_max=10)


def run_epoch_emb(dl, train: bool):
    if train:
        emb_model.train()
    else:
        emb_model.eval()
    total_loss = 0.0
    all_preds, all_trues = [], []
    for inputs, labels, attn_mask in tqdm(dl, desc=f"{'Training' if train else 'Evaluating'} (emb)", leave=False):
        inputs = inputs.to(device)
        labels = labels.to(device)
        attn_mask = attn_mask.to(device)
        with torch.set_grad_enabled(train):
            logits = emb_model(inputs, attention_mask=attn_mask)
            B, L, C = logits.shape
            loss = emb_criterion(logits.view(B*L, C), labels.view(B*L))
        if train:
            emb_optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(emb_model.parameters(), 1.0)
            emb_optimizer.step()
        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        mask = labels != ignore_index
        all_preds.extend(preds[mask].detach().cpu().tolist())
        all_trues.extend(labels[mask].detach().cpu().tolist())
    avg_loss = total_loss / max(1, len(dl))
    f1 = f1_score(all_trues, all_preds, average='weighted')
    return avg_loss, f1

best_state_emb = None
best_val_f1_emb = -1.0
patience_emb = 3
stale_emb = 0
max_epochs_emb = 20

for epoch in range(1, max_epochs_emb+1):
    train_loss, train_f1 = run_epoch_emb(train_loader_emb, train=True)
    val_loss, val_f1 = run_epoch_emb(valid_loader_emb, train=False)
    emb_scheduler.step()
    print({"epoch": epoch, "train_loss": round(train_loss, 4), "train_f1": round(train_f1, 4), "val_loss": round(val_loss, 4), "val_f1": round(val_f1, 4)})
    if val_f1 > best_val_f1_emb:
        best_val_f1_emb = val_f1
        best_state_emb = deepcopy(emb_model.state_dict())
        stale_emb = 0
    else:
        stale_emb += 1
        if stale_emb >= patience_emb:
            print(f"Early stopping (emb) at epoch {epoch} (best val_f1={best_val_f1_emb:.4f})")
            break

# Load best and evaluate on test (embeddings)
assert best_state_emb is not None, "No best embedding model state captured."
emb_model.load_state_dict(best_state_emb)
_, test_f1_emb = run_epoch_emb(test_loader_emb, train=False)
print({"best_val_f1_emb": round(best_val_f1_emb, 4), "test_f1_emb": round(test_f1_emb, 4)})


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 1, 'train_loss': 0.6829, 'train_f1': 0.7295, 'val_loss': 0.6592, 'val_f1': 0.7364}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 2, 'train_loss': 0.6431, 'train_f1': 0.7443, 'val_loss': 0.651, 'val_f1': 0.7403}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 3, 'train_loss': 0.6329, 'train_f1': 0.7479, 'val_loss': 0.6438, 'val_f1': 0.7423}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 4, 'train_loss': 0.6262, 'train_f1': 0.7502, 'val_loss': 0.6398, 'val_f1': 0.7439}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 5, 'train_loss': 0.6225, 'train_f1': 0.7518, 'val_loss': 0.6407, 'val_f1': 0.7437}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 6, 'train_loss': 0.6181, 'train_f1': 0.753, 'val_loss': 0.6371, 'val_f1': 0.745}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 7, 'train_loss': 0.6155, 'train_f1': 0.754, 'val_loss': 0.6381, 'val_f1': 0.745}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 8, 'train_loss': 0.6144, 'train_f1': 0.7547, 'val_loss': 0.636, 'val_f1': 0.7454}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 9, 'train_loss': 0.6122, 'train_f1': 0.7553, 'val_loss': 0.6378, 'val_f1': 0.7453}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 10, 'train_loss': 0.6122, 'train_f1': 0.7558, 'val_loss': 0.6361, 'val_f1': 0.7453}


Training (emb):   0%|          | 0/2698 [00:00<?, ?it/s]

Evaluating (emb):   0%|          | 0/157 [00:00<?, ?it/s]

{'epoch': 11, 'train_loss': 0.6115, 'train_f1': 0.7557, 'val_loss': 0.6361, 'val_f1': 0.7453}
Early stopping (emb) at epoch 11 (best val_f1=0.7454)


Evaluating (emb):   0%|          | 0/13 [00:00<?, ?it/s]

{'best_val_f1_emb': 0.7454, 'test_f1_emb': 0.6645}
