In [2]:
from tarp.model.backbone.untrained.lstm import LstmEncoder
from tarp.model.backbone.untrained.hyena import HyenaEncoder
from tarp.model.finetuning.classification import ClassificationModel

from tarp.services.datasets.classification.multilabel import MultiLabelClassificationDataset
from tarp.services.tokenizers.pretrained.dnabert import Dnabert2Tokenizer
from tarp.services.datasource.sequence import TabularSequenceSource, CombinationSource, FastaSliceSource


from tarp.services.preprocessing.augmentation import (
    CombinationTechnique,
    RandomMutation,
    InsertionDeletion,
    ReverseComplement,
)

from tarp.services.datasets.metric.triplet import MultiLabelOfflineTripletDataset


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import polars as pl
import numpy as np

from sklearn.neighbors import NearestNeighbors

from pathlib import Path

In [3]:
# Read the label names: `to_series()` takes the first column of labels.csv and
# `to_list()` turns it into a plain Python list (order preserved).
labels_table = pl.read_csv(Path("../temp/data/processed/labels.csv"))
label_columns = labels_table.to_series().to_list()

In [4]:
# Source 1: sequences stored in a tabular parquet file (AMR entries, judging by the file name).
amr_source = TabularSequenceSource(
    source=Path("../temp/data/processed/card_amr.parquet")
)

# Source 2: slices cut out of FASTA files; the metadata table supplies the
# accession key, start/end coordinates and strand orientation for each slice.
non_amr_source = FastaSliceSource(
    directory=Path("../temp/data/external/sequences"),
    metadata=Path("../temp/data/processed/non_amr_genes_10000.parquet"),
    key_column="genomic_nucleotide_accession.version",
    start_column="start_position_on_the_genomic_accession",
    end_column="end_position_on_the_genomic_accession",
    orientation_column="orientation",
)

combined_source = CombinationSource([amr_source, non_amr_source])
sequence_tokenizer = Dnabert2Tokenizer()

# NOTE(review): this augmentation pipeline stays attached to `dataset`, which is
# also the dataset iterated later to extract embeddings — presumably that makes
# the embeddings stochastic between runs; confirm this is intended.
sequence_augmentation = CombinationTechnique(
    [
        RandomMutation(),
        InsertionDeletion(),
        ReverseComplement(0.5),
    ]
)

dataset = MultiLabelClassificationDataset(
    combined_source,
    sequence_tokenizer,
    sequence_column="sequence",
    label_columns=label_columns,
    maximum_sequence_length=512,
    augmentation=sequence_augmentation,
)

# Triplet view over the same base dataset, with labels cached on disk.
metric_dataset = MultiLabelOfflineTripletDataset(
    base_dataset=dataset,
    label_cache=Path("../temp/data/cache/labels_cache.parquet"),
)

[96m[DEBUG]	2025-10-27 14:47:08,766 - Checking for label cache at: ..\temp\data\cache\labels_cache.parquet[0m
[92m[INFO]	2025-10-27 14:47:08,776 - Loaded labels from cache (aligned to label_columns): ..\temp\data\cache\labels_cache.parquet[0m


In [5]:
from tarp.config import HyenaConfig

In [6]:
# Rebuild the encoder with the hyper-parameters from HyenaConfig so that the
# module layout matches the checkpoint loaded below.
encoder = HyenaEncoder(
    vocabulary_size=dataset.tokenizer.vocab_size,
    embedding_dimension=HyenaConfig.embedding_dimension,
    hidden_dimension=HyenaConfig.hidden_dimension,
    padding_id=dataset.tokenizer.pad_token_id,
    number_of_layers=HyenaConfig.number_of_layers,
    dropout=HyenaConfig.dropout,
)

# Wrap the encoder in the classification model: the checkpoint holds the state
# dict of the full ClassificationModel, not of the bare encoder.
classification_model = ClassificationModel(
    encoder=encoder,
    number_of_classes=len(label_columns),
)

# map_location="cpu": tensors saved from a CUDA device would otherwise fail to
# deserialize on a machine without a GPU; the encoder is moved to the target
# device explicitly later on.
# weights_only=True: restrict unpickling to tensors/containers, since the file
# is only expected to hold a state dict.
checkpoint_state = torch.load(
    "../temp/checkpoints/HyenaEncoder_20251027_111621.pt",
    map_location="cpu",
    weights_only=True,
)
classification_model.load_state_dict(checkpoint_state)

# Get the encoder part of the model (now carrying the fine-tuned weights)
encoder: HyenaEncoder = classification_model.encoder

In [7]:
# Apply the model to the dataset to get the embeddings
from tqdm.auto import tqdm

# Preallocate numpy array for all embeddings.
# np.empty does NOT zero the buffer, so every row must be written by the loop
# below; this is verified by the assertion at the end of the cell.
num_samples = len(dataset)
batch_size = 32
embedding_dim = encoder.encoding_size  # Output dimension of encoder.encode
embeddings = np.empty((num_samples, embedding_dim), dtype=np.float32)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(DEVICE)

# shuffle=False keeps row i of `embeddings` aligned with dataset[i], which the
# index-based train/test split further down relies on.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
start_idx = 0
encoder.eval()
with torch.no_grad():
    for batch in tqdm(dataloader):
        input_ids = batch["sequence"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        batch_embeddings = encoder.encode(input_ids, attention_mask)
        # The last batch may be smaller than batch_size, so use the real size.
        batch_size_actual = batch_embeddings.shape[0]
        embeddings[start_idx:start_idx + batch_size_actual] = batch_embeddings.cpu().numpy()
        start_idx += batch_size_actual

# Guard against silently keeping uninitialised rows from np.empty.
assert start_idx == num_samples, (
    f"Expected {num_samples} embeddings but the loader produced {start_idx}"
)

  0%|          | 0/825 [00:00<?, ?it/s]

In [8]:
# Sanity check: report (number of samples, embedding dimension).
embeddings_shape = embeddings.shape
print("Embeddings shape:", embeddings_shape)

Embeddings shape: (26392, 256)


In [9]:
from sklearn.model_selection import train_test_split

SEED = 69420

# Split on sample positions rather than on the data itself, so the same index
# lists can be used to slice both embeddings and labels.
indices = [*range(len(dataset))]

# First cut: 80% train, 20% held out.
train_indices, temp_indices = train_test_split(
    indices,
    random_state=SEED,
    test_size=0.2,
)

# Second cut: halve the held-out part into validation and test (10% each overall).
valid_indices, test_indices = train_test_split(
    temp_indices,
    random_state=SEED,
    test_size=0.5,
)

In [10]:
# Slice the embedding matrix with the train/test index lists
# (indexing with a list of positions returns a copy, not a view).
train_embeddings, test_embeddings = (
    embeddings[split_indices] for split_indices in (train_indices, test_indices)
)

print("Train embeddings shape:", train_embeddings.shape)
print("Test embeddings shape:", test_embeddings.shape)


Train embeddings shape: (21113, 256)
Test embeddings shape: (2640, 256)


In [11]:
def _labels_for(sample_indices):
    """Stack the 'labels' entry of each listed dataset sample into one numpy array."""
    per_sample = [dataset[idx]["labels"].numpy() for idx in sample_indices]
    return np.array(per_sample)


# Label rows, in the same order as the corresponding embedding rows.
train_labels = _labels_for(train_indices)
test_labels = _labels_for(test_indices)

In [12]:
# KNN search
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import SGDClassifier

# 5 nearest neighbours under cosine distance; closer neighbours get a larger vote.
knn_settings = {
    "n_neighbors": 5,
    "metric": "cosine",
    "weights": "distance",
}
knn = KNeighborsClassifier(**knn_settings)

# Fit on the multi-label target matrix (one column per entry of label_columns).
knn.fit(train_embeddings, train_labels)


0,1,2
,n_neighbors,5
,weights,'distance'
,algorithm,'auto'
,leaf_size,30
,p,2
,metric,'cosine'
,metric_params,
,n_jobs,


In [13]:
from sklearn.metrics import classification_report

# NOTE(review): `knn` was fitted on these same embeddings with weights='distance',
# so each sample is presumably its own zero-distance neighbour; near-perfect
# scores here are expected and say little about generalisation — confirm.
train_predictions = knn.predict(train_embeddings)
train_report = classification_report(
    train_labels,
    train_predictions,
    target_names=label_columns,
    zero_division=0,
)
print(train_report)

                                           precision    recall  f1-score   support

      disinfecting agents and antiseptics       1.00      1.00      1.00        67
                            glycylcycline       1.00      1.00      1.00        30
                     rifamycin antibiotic       1.00      1.00      1.00        54
                     macrolide antibiotic       1.00      1.00      1.00       179
                 streptogramin antibiotic       1.00      1.00      1.00        83
                      pyrazine antibiotic       1.00      1.00      1.00        11
                  tetracycline antibiotic       1.00      1.00      1.00       154
             bicyclomycin-like antibiotic       0.00      0.00      0.00         0
                isoniazid-like antibiotic       1.00      1.00      1.00        17
                nitroimidazole antibiotic       1.00      1.00      1.00        16
                  orthosomycin antibiotic       1.00      1.00      1.00         1
   

In [14]:
# Per-label precision/recall/F1 on the held-out test split.
test_predictions = knn.predict(test_embeddings)
test_report = classification_report(
    test_labels,
    test_predictions,
    target_names=label_columns,
    zero_division=0,
)
print(test_report)

                                           precision    recall  f1-score   support

      disinfecting agents and antiseptics       0.00      0.00      0.00         4
                            glycylcycline       0.00      0.00      0.00         3
                     rifamycin antibiotic       0.00      0.00      0.00         3
                     macrolide antibiotic       1.00      0.06      0.12        16
                 streptogramin antibiotic       0.00      0.00      0.00         7
                      pyrazine antibiotic       0.00      0.00      0.00         1
                  tetracycline antibiotic       0.00      0.00      0.00        18
             bicyclomycin-like antibiotic       0.00      0.00      0.00         1
                isoniazid-like antibiotic       0.00      0.00      0.00         2
                nitroimidazole antibiotic       0.00      0.00      0.00         1
                  orthosomycin antibiotic       0.00      0.00      0.00         0
   

In [15]:
# Column positions (within label_columns) of every label whose name does not
# contain "non_amr", i.e. the labels treated as AMR classes here.
amr_label_indices = [
    position for position, name in enumerate(label_columns) if "non_amr" not in name
]

# Collapse all AMR label columns into one binary target:
# 1 if the sample carries at least one AMR label, 0 otherwise.
train_amr_counts = train_labels[:, amr_label_indices].sum(axis=1)
train_amr_binary = (train_amr_counts > 0).astype(int)

test_amr_counts = test_labels[:, amr_label_indices].sum(axis=1)
test_amr_binary = (test_amr_counts > 0).astype(int)

In [16]:
# Same KNN configuration as the multi-label model, fitted on the binary AMR target.
binary_knn_settings = {
    "n_neighbors": 5,
    "metric": "cosine",
    "weights": "distance",
}
binary_knn = KNeighborsClassifier(**binary_knn_settings)

binary_knn.fit(train_embeddings, train_amr_binary)

0,1,2
,n_neighbors,5
,weights,'distance'
,algorithm,'auto'
,leaf_size,30
,p,2
,metric,'cosine'
,metric_params,
,n_jobs,


In [17]:
# Binary AMR / non-AMR report on the training split.
# `labels` is pinned explicitly so that target_names cannot be mis-assigned:
# without it, names are matched to whatever classes happen to be present (in
# sorted order), which silently breaks if one class is missing. This also
# mirrors the explicit `labels=` used for the test-split report.
train_binary_predictions = binary_knn.predict(train_embeddings)
print(
    classification_report(
        train_amr_binary,
        train_binary_predictions,
        labels=[0, 1],
        target_names=["non-AMR", "AMR"],
        zero_division=0,
    )
)

              precision    recall  f1-score   support

     non-AMR       1.00      1.00      1.00     15989
         AMR       1.00      1.00      1.00      5124

    accuracy                           1.00     21113
   macro avg       1.00      1.00      1.00     21113
weighted avg       1.00      1.00      1.00     21113



In [18]:
# Binary AMR / non-AMR report on the test split.
# labels=[1, 0] puts the positive (AMR) class first and fixes which class each
# entry of target_names refers to.
test_binary_predictions = binary_knn.predict(test_embeddings)
test_binary_report = classification_report(
    test_amr_binary,
    test_binary_predictions,
    labels=[1, 0],
    target_names=["AMR", "non-AMR"],
    zero_division=0,
)
print(test_binary_report)

              precision    recall  f1-score   support

         AMR       0.91      0.78      0.84       611
     non-AMR       0.94      0.98      0.96      2029

    accuracy                           0.93      2640
   macro avg       0.92      0.88      0.90      2640
weighted avg       0.93      0.93      0.93      2640



In [19]:
import plotly.express as px

# Import UMAP (t-SNE is imported in its own cell further down)
from umap import UMAP

# Run UMAP on the train embeddings.
# random_state=SEED makes the 2-D layout reproducible across runs, matching the
# seeded t-SNE projection; note that umap-learn runs single-threaded when a
# random_state is given, so this can be slower than the unseeded call.
umap = UMAP(n_components=2, random_state=SEED)
train_embeddings_2d = umap.fit_transform(train_embeddings)


In [20]:
# Build one plot row per (sample, active label) pair: a sample carrying several
# labels is emitted several times at the same 2-D position.
rows = []
for i in range(len(train_embeddings)):
    x_coord = train_embeddings_2d[i, 0]
    y_coord = train_embeddings_2d[i, 1]
    active_labels = [label_columns[j] for j in np.nonzero(train_labels[i] > 0)[0]]
    rows.extend(
        {
            "x": x_coord,
            "y": y_coord,
            "label": active_label,
            "sample_index": i,  # to identify duplicates
        }
        for active_label in active_labels
    )

df_vis = pl.DataFrame(rows)


In [21]:
# Scatter plot of the 2-D projection, one colour per label.
# The keys of `labels` must be column names of df_vis: the colour column is
# "label" (a "color" key matches no column and is silently ignored).
# "Label" matches the legend title set in update_layout below.
fig = px.scatter(
    df_vis,
    x="x",
    y="y",
    color="label",
    hover_data=["sample_index"],
    title="UMAP visualization of Hyena embeddings (multi-label)",
    labels={"x": "UMAP dim 1", "y": "UMAP dim 2", "label": "Label"},
    opacity=0.7,
)

fig.update_layout(
    width=1024,
    height=768,
    legend_title="Label",
)
fig.show()

In [22]:
# Second projection: t-SNE, seeded for reproducibility.
from sklearn.manifold import TSNE

tsne = TSNE(random_state=SEED, n_components=2)

# NOTE: this rebinds `train_embeddings_2d`, replacing the UMAP coordinates
# computed earlier; re-running the UMAP plotting cells after this one would
# plot t-SNE coordinates instead.
train_embeddings_2d = tsne.fit_transform(train_embeddings)

In [23]:
# Build one plot row per (sample, active label) pair from the t-SNE coordinates;
# multi-label samples appear once per label at the same position.
rows = []
for i in range(len(train_embeddings)):
    active_labels = [label_columns[j] for j in np.where(train_labels[i] > 0)[0]]
    for label in active_labels:
        rows.append({
            "x": train_embeddings_2d[i, 0],
            "y": train_embeddings_2d[i, 1],
            "label": label,
            "sample_index": i,  # to identify duplicates
        })
df_vis = pl.DataFrame(rows)

# The keys of `labels` must be column names of df_vis: the colour column is
# "label" (a "color" key matches no column and is silently ignored).
# "Label" matches the legend title set in update_layout below.
fig = px.scatter(
    df_vis,
    x="x",
    y="y",
    color="label",
    hover_data=["sample_index"],
    title="t-SNE visualization of Hyena embeddings (multi-label)",
    labels={"x": "t-SNE dim 1", "y": "t-SNE dim 2", "label": "Label"},
    opacity=0.7,
)
fig.update_layout(
    width=1024,
    height=768,
    legend_title="Label",
)
fig.show()