In [7]:
from tarp.model.backbone.untrained.lstm import LstmEncoder
from tarp.model.backbone.untrained.hyena import HyenaEncoder
from tarp.model.finetuning.classification import ClassificationModel

from tarp.services.datasets.classification.multilabel import MultiLabelClassificationDataset
from tarp.services.tokenizers.pretrained.dnabert import Dnabert2Tokenizer
from tarp.services.datasource.sequence import TabularSequenceSource, CombinationSource, FastaSliceSource


from tarp.services.preprocessing.augumentation import (
    CombinationTechnique,
    RandomMutation,
    InsertionDeletion,
    ReverseComplement,
)

from tarp.services.datasets.metric.triplet import MultilabelOfflineTripletDataset


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import polars as pl
import numpy as np

from sklearn.neighbors import NearestNeighbors

from pathlib import Path

In [8]:
# Ordered list of label names, read from the first column of labels.csv.
# This order defines the column order of every label vector used below.
LABELS_CSV = Path("../temp/data/processed/labels.csv")
label_columns = pl.read_csv(LABELS_CSV).to_series().to_list()

In [9]:
# Sequence sources, in the order CombinationSource receives them:
#   1. a tabular parquet of sequences (card_amr.parquet);
#   2. sequence slices cut from FASTA files under `sequences/`, located via the
#      accession / start / end / orientation columns of the metadata parquet.
amr_sequences = TabularSequenceSource(
    source=Path("../temp/data/processed/card_amr.parquet")
)
non_amr_sequences = FastaSliceSource(
    directory=Path("../temp/data/external/sequences"),
    metadata=Path("../temp/data/processed/non_amr_genes_10000.parquet"),
    key_column="genomic_nucleotide_accession.version",
    start_column="start_position_on_the_genomic_accession",
    end_column="end_position_on_the_genomic_accession",
    orientation_column="orientation",
)
combined_source = CombinationSource([amr_sequences, non_amr_sequences])

tokenizer = Dnabert2Tokenizer()

# Augmentations composed by CombinationTechnique (ReverseComplement receives 0.5,
# presumably an application probability — TODO confirm).
# NOTE(review): this augmented dataset is also the one later fed to the encoder
# for embedding extraction, KNN evaluation and t-SNE. If augmentation is applied
# on every __getitem__ (likely, but not visible here), embeddings and metrics are
# computed on randomly perturbed sequences; consider an un-augmented dataset for
# inference.
augmentation_pipeline = CombinationTechnique(
    [
        RandomMutation(),
        InsertionDeletion(),
        ReverseComplement(0.5),
    ]
)

dataset = MultiLabelClassificationDataset(
    combined_source,
    tokenizer,
    sequence_column="sequence",
    label_columns=label_columns,
    maximum_sequence_length=512,
    augumentation=augmentation_pipeline,
)

# Triplet view over the same dataset, backed by a label cache parquet.
# Not used further in the cells shown here.
metric_dataset = MultilabelOfflineTripletDataset(
    base_dataset=dataset, label_cache="../temp/data/interim/labels_cache.parquet"
)

[96m[DEBUG]	2025-10-13 22:49:45,118 - Checking for label cache at: ../temp/data/interim/labels_cache.parquet[0m
[93m[WARN]	2025-10-13 22:49:45,132 - Label cache mismatch — columns or size differ. (cache columns: ['disinfecting agents and antiseptics', 'glycylcycline', 'rifamycin antibiotic', 'macrolide antibiotic', 'streptogramin antibiotic', 'pyrazine antibiotic', 'tetracycline antibiotic', 'bicyclomycin-like antibiotic', 'isoniazid-like antibiotic', 'nitroimidazole antibiotic', 'orthosomycin antibiotic', 'nitrofuran antibiotic', 'carbapenem', 'pactamycin-like antibiotic', 'moenomycin antibiotic', 'cycloserine-like antibiotic', 'cephalosporin', 'diaminopyrimidine antibiotic', 'fluoroquinolone antibiotic', 'antibiotic without defined classification', 'phosphonic acid antibiotic', 'pleuromutilin antibiotic', 'elfamycin antibiotic', 'nucleoside antibiotic', 'peptide antibiotic', 'phenicol antibiotic', 'streptogramin A antibiotic', 'aminocoumarin antibiotic', 'streptogramin B antibioti

In [10]:
# Rebuild the encoder with the hyper-parameters used for training; the
# shape-affecting ones must match the checkpoint or load_state_dict (strict by
# default) raises.
encoder = HyenaEncoder(
    vocabulary_size=dataset.tokenizer.vocab_size,
    embedding_dimension=128,
    hidden_dimension=256,
    padding_id=dataset.tokenizer.pad_token_id,
    number_of_layers=2,
    dropout=0.2,
)

classification_model = ClassificationModel(
    encoder=encoder,
    number_of_classes=len(label_columns),
)

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; device placement is handled separately.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (weights_only=True is available on recent torch versions).
CHECKPOINT_PATH = Path("../temp/checkpoints/HyenaEncoder_20251013_205824.pt")
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
classification_model.load_state_dict(state_dict)

# Get the encoder part of the model (now carrying the trained weights).
encoder: HyenaEncoder = classification_model.encoder

In [11]:
# Apply the encoder to every sample and store one embedding row per sample,
# in dataset order (row i corresponds to dataset[i]).
from tqdm.auto import tqdm

# Preallocate numpy array for all embeddings
num_samples = len(dataset)
batch_size = 32
embedding_dim = encoder.encoding_size  # Output dimension of encoder.encode
embeddings = np.empty((num_samples, embedding_dim), dtype=np.float32)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(DEVICE)
encoder.eval()  # disable dropout so inference is not randomly perturbed

# shuffle=False keeps rows aligned with dataset indices, which later cells rely on.
# NOTE(review): `dataset` carries an augmentation pipeline; if it runs on every
# __getitem__, these embeddings are computed on perturbed sequences.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
start_idx = 0
with torch.no_grad():
    for batch in tqdm(dataloader, desc="embedding"):
        input_ids = batch["sequence"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        batch_embeddings = encoder.encode(input_ids, attention_mask)
        batch_size_actual = batch_embeddings.shape[0]
        # .float() guards against half/bfloat16 outputs (bfloat16 has no numpy dtype).
        embeddings[start_idx:start_idx + batch_size_actual] = (
            batch_embeddings.float().cpu().numpy()
        )
        start_idx += batch_size_actual

# np.empty leaves unwritten rows as garbage; make sure every row was filled.
assert start_idx == num_samples, f"filled {start_idx} of {num_samples} embedding rows"

100%|██████████| 513/513 [00:22<00:00, 22.57it/s]


In [12]:
# Sanity check: one row per dataset sample, one column per embedding dimension.
print(f"Embeddings shape: {embeddings.shape}")

Embeddings shape: (16392, 128)


In [13]:
# Hold out the last 20% of rows (in dataset order) as the test split.
# NOTE(review): this split is contiguous and unshuffled. `dataset` is built from
# [AMR parquet source, FASTA slice source] in that order; if CombinationSource
# concatenates them, the held-out tail comes mostly or entirely from the second
# source. That would be consistent with the zero `support` values in the
# classification report. A shuffled/stratified split is advisable, but the KNN,
# report and label-collection cells index `dataset` positionally and must be
# changed together with this one.
train_size = int(0.8 * len(dataset))
train_embeddings, test_embeddings = embeddings[:train_size], embeddings[train_size:]

for split_name, split_array in (("Train", train_embeddings), ("Test", test_embeddings)):
    print(f"{split_name} embeddings shape:", split_array.shape)


Train embeddings shape: (13113, 128)
Test embeddings shape: (3279, 128)


In [14]:
# Multi-label KNN: for each label, a test embedding gets the majority vote of
# its 5 nearest training embeddings under cosine distance.
from sklearn.neighbors import KNeighborsClassifier

# Stack the per-sample label vectors into an (n_train, n_labels) indicator matrix.
# NOTE(review): each dataset[i] call re-runs the full item pipeline (tokenization
# and presumably augmentation) just to read the labels.
train_targets = np.stack([dataset[i]['labels'].numpy() for i in range(train_size)])

knn = KNeighborsClassifier(n_neighbors=5, metric='cosine')
knn.fit(train_embeddings, train_targets)

# Predict label indicators for the held-out rows.
predictions = knn.predict(test_embeddings)



In [15]:
# Per-label precision / recall / F1 of the KNN predictions on the held-out rows.
from sklearn.metrics import classification_report

# Ground-truth indicator rows for the test split (dataset rows after train_size).
test_targets = np.stack(
    [dataset[train_size + offset]['labels'].numpy() for offset in range(len(test_embeddings))]
)
report_text = classification_report(
    test_targets,
    predictions,
    zero_division=0,
    target_names=label_columns,
)
print(report_text)

                                           precision    recall  f1-score   support

      disinfecting agents and antiseptics       0.00      0.00      0.00         0
                            glycylcycline       0.00      0.00      0.00         0
                     rifamycin antibiotic       0.00      0.00      0.00         0
                     macrolide antibiotic       0.00      0.00      0.00         0
                 streptogramin antibiotic       0.00      0.00      0.00         0
                      pyrazine antibiotic       0.00      0.00      0.00         0
                  tetracycline antibiotic       0.00      0.00      0.00         0
             bicyclomycin-like antibiotic       0.00      0.00      0.00         0
                isoniazid-like antibiotic       0.00      0.00      0.00         0
                nitroimidazole antibiotic       0.00      0.00      0.00         0
                  orthosomycin antibiotic       0.00      0.00      0.00         0
   

In [16]:
# Project the training embeddings to 2-D with t-SNE for visual inspection.
import plotly.express as px
from sklearn.manifold import TSNE

TSNE_SEED = 102  # fixed seed for a repeatable layout
tsne = TSNE(n_components=2, perplexity=30, random_state=TSNE_SEED)
train_embeddings_2d = tsne.fit_transform(train_embeddings)


In [17]:
from torch.utils.data import DataLoader

# Collect label vectors for the training split: the first `train_size` rows in
# dataset order, matching `train_embeddings`.
# A single constant drives both the DataLoader and the early-exit check, so the
# two cannot silently drift apart (previously 32 was hard-coded twice).
LABEL_BATCH_SIZE = 32
labels = []
dataloader_labels = DataLoader(dataset, batch_size=LABEL_BATCH_SIZE, shuffle=False)

for batch_index, batch in enumerate(dataloader_labels):
    # Stop once the batches collected so far already cover the train split.
    if batch_index * LABEL_BATCH_SIZE >= train_size:
        break
    labels.append(batch["labels"].numpy())

train_labels = np.vstack(labels)[:train_size]  # shape (N, num_labels)
assert train_labels.shape[0] == train_size, "label rows do not match the train split"
print("Train labels shape:", train_labels.shape)

# Convert multilabel → single label for visualization.
# Caveat: argmax returns 0 for rows with no positive label, so such rows map to
# label_columns[0] unless handled explicitly downstream.
train_labels_simple = train_labels.argmax(axis=1)

Train labels shape: (13113, 47)


In [18]:
# Colour each point by its first positive label (argmax). argmax returns 0 for a
# row with no positive entry, which would wrongly attribute unlabeled sequences
# to label_columns[0]; give those rows their own "no label" category instead.
has_positive_label = train_labels.any(axis=1)
point_colors = [
    label_columns[label_index] if labelled else "no label"
    for label_index, labelled in zip(train_labels_simple, has_positive_label)
]

fig = px.scatter(
    x=train_embeddings_2d[:, 0],
    y=train_embeddings_2d[:, 1],
    color=point_colors,
    title="t-SNE visualization of Hyena embeddings",
    # Labels are drug classes (e.g. "macrolide antibiotic"), not gene families.
    labels={"x": "t-SNE dim 1", "y": "t-SNE dim 2", "color": "Drug class"},
    opacity=0.7,
)
# 4:3 landscape figure
fig.update_layout(width=1024, height=768)
fig.show()

