In [None]:
from tarp.model.backbone.untrained.lstm import LstmEncoder
from tarp.model.backbone.untrained.hyena import HyenaEncoder
from tarp.model.backbone import Encoder
from tarp.model.backbone.untrained.transformer import TransformerEncoder
from tarp.model.finetuning.classification import ClassificationModel

from tarp.services.datasets.classification.multilabel import MultiLabelClassificationDataset
from tarp.services.tokenizers.pretrained.dnabert import Dnabert2Tokenizer
from tarp.services.datasource.sequence import TabularSequenceSource, CombinationSource, FastaSliceSource


from tarp.services.preprocessing.augmentation import (
    CombinationTechnique,
    RandomMutation,
    InsertionDeletion,
    ReverseComplement,
)

from tarp.services.datasets.metric.triplet import MultiLabelOfflineTripletDataset


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import polars as pl
import numpy as np

from sklearn.neighbors import NearestNeighbors

from pathlib import Path

In [None]:
# Label vocabulary: the first column of labels.csv, in file order.
# This list is passed to the dataset as `label_columns` and reused as report target names below,
# so its order presumably defines the label axis of every multi-hot vector in this notebook.
label_columns = pl.read_csv(
    Path("../temp/data/processed/labels.csv")
).to_series(0).to_list()

In [None]:
# Raw sequences: a parquet table of CARD AMR genes combined with non-AMR gene slices
# cut out of FASTA files according to the coordinates in the metadata parquet.
sequence_source = CombinationSource(
    [
        TabularSequenceSource(source=Path("../temp/data/processed/card_amr.parquet")),
        FastaSliceSource(
            directory=Path("../temp/data/external/sequences"),
            metadata=Path("../temp/data/processed/non_amr_genes_10000.parquet"),
            key_column="genomic_nucleotide_accession.version",
            start_column="start_position_on_the_genomic_accession",
            end_column="end_position_on_the_genomic_accession",
            orientation_column="orientation",
        ),
    ]
)
sequence_tokenizer = Dnabert2Tokenizer()
# NOTE(review): this augmentation is handed to the dataset used for *evaluation* below
# (embeddings, label lookups, test predictions). If the dataset applies it on every
# __getitem__ call, as seems likely, those results are computed on randomly perturbed,
# non-reproducible sequences — confirm whether evaluation should disable augmentation.
sequence_augmentation = CombinationTechnique(
    [RandomMutation(), InsertionDeletion(), ReverseComplement(0.5)]
)

dataset = MultiLabelClassificationDataset(
    sequence_source,
    sequence_tokenizer,
    sequence_column="sequence",
    label_columns=label_columns,
    maximum_sequence_length=512,
    augmentation=sequence_augmentation,
)

# Triplet-sampling wrapper over the same dataset; `label_cache` is presumably where
# per-sample labels are cached between runs — verify against the dataset implementation.
metric_dataset = MultiLabelOfflineTripletDataset(
    base_dataset=dataset,
    label_cache=Path("../temp/data/cache/labels_cache.parquet"),
)

In [None]:
from tarp.config import HyenaConfig, TransformerConfig

In [None]:
# Rebuild the fine-tuning architecture with the same hyper-parameters used in training,
# then restore the most recent checkpoint into it.
encoder = TransformerEncoder(
    vocabulary_size=dataset.tokenizer.vocab_size,
    embedding_dimension=TransformerConfig.embedding_dimension,
    hidden_dimension=TransformerConfig.hidden_dimension,
    padding_id=dataset.tokenizer.pad_token_id,
    number_of_layers=TransformerConfig.number_of_layers,
    number_of_heads=TransformerConfig.number_of_heads,
    dropout=TransformerConfig.dropout,
)

classification_model = ClassificationModel(
    encoder=encoder,
    number_of_classes=len(label_columns),
)

# Filename of the checkpoint to load
# Should be the latest checkpoint saved during training (newest modification time wins)
from pathlib import Path

checkpoint_directory = Path("../temp/checkpoints/")
checkpoint_candidates = list(checkpoint_directory.glob("*.pt"))
if not checkpoint_candidates:
    # max() on an empty sequence would only say "max() arg is an empty sequence"
    raise FileNotFoundError(
        f"No '*.pt' checkpoints found in {checkpoint_directory.resolve()}"
    )
latest_checkpoint = max(checkpoint_candidates, key=lambda p: p.stat().st_mtime)

# map_location="cpu" lets a checkpoint written from a GPU run load on a CPU-only machine;
# the model is moved to the target device later in the notebook.
# NOTE(review): torch.load unpickles the file — only load checkpoints you produced yourself.
classification_model.load_state_dict(
    torch.load(latest_checkpoint.as_posix(), map_location="cpu")
)

print(f"Loaded model from checkpoint: {latest_checkpoint}")

# Get the encoder part of the model (same object as the `encoder` built above)
encoder = classification_model.encoder

In [None]:
# Apply the encoder to the whole dataset (in order) to get one embedding row per sample.
from tqdm.auto import tqdm

# Preallocate numpy array for all embeddings
num_samples = len(dataset)
batch_size = 32
embedding_dim = encoder.encoding_size  # Output dimension of encoder.encode
embeddings = np.empty((num_samples, embedding_dim), dtype=np.float32)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(DEVICE)

# shuffle=False keeps row k of `embeddings` aligned with dataset[k], which the index-based
# train/test split below relies on.
# NOTE(review): `dataset` was built with augmentation; if it is applied per item, these
# embeddings come from perturbed sequences — confirm this is intended for evaluation.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
start_idx = 0
encoder.eval()
with torch.no_grad():
    for batch in tqdm(dataloader, desc="Embedding"):
        input_ids = batch["sequence"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        batch_embeddings = encoder.encode(input_ids, attention_mask)
        batch_size_actual = batch_embeddings.shape[0]
        # .float() guards against half/bfloat16 outputs, which .numpy() cannot convert directly
        embeddings[start_idx:start_idx + batch_size_actual] = (
            batch_embeddings.float().cpu().numpy()
        )
        start_idx += batch_size_actual

# np.empty leaves uninitialised memory, so verify that every row was actually written.
assert start_idx == num_samples, (
    f"Filled {start_idx} embedding rows but the dataset has {num_samples} samples"
)

In [None]:
# Sanity check: expected to be (number of samples, encoder.encoding_size).
print("Embeddings shape:", np.shape(embeddings))

In [None]:
from sklearn.model_selection import train_test_split
SEED = 69420

# 80 / 10 / 10 split of sample positions: hold out 20%, then halve the hold-out into
# validation and test. Both splits share SEED so the partition is reproducible.
indices = list(range(len(dataset)))
train_indices, temp_indices = train_test_split(indices, random_state=SEED, test_size=0.2)
valid_indices, test_indices = train_test_split(temp_indices, random_state=SEED, test_size=0.5)

In [None]:
# Select the train and test rows of the precomputed embedding matrix by split position
# (integer-array indexing, so both results are independent copies).
train_embeddings, test_embeddings = embeddings[train_indices], embeddings[test_indices]

print("Train embeddings shape:", train_embeddings.shape)
print("Test embeddings shape:", test_embeddings.shape)


In [None]:
# Multi-hot label matrix for each split, read back item by item from the dataset.
# NOTE(review): each dataset[...] call also builds the tokenized "sequence"/"attention_mask"
# (and presumably re-applies augmentation) just to read "labels"; a label-only accessor
# on the dataset would avoid that work.
train_labels, test_labels = (
    np.array([dataset[position]["labels"].numpy() for position in split_positions])
    for split_positions in (train_indices, test_indices)
)

In [None]:
# k-nearest-neighbour probe on the frozen encoder embeddings, with multi-label targets.
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import SGDClassifier

knn = KNeighborsClassifier(
    metric="cosine",
    weights="distance",  # nearer neighbours get proportionally larger votes
    n_neighbors=5,
)

knn.fit(train_embeddings, train_labels)


In [None]:
from sklearn.metrics import classification_report

# Per-label report for the kNN probe on the *training* split.
# NOTE(review): with weights="distance" a training point is presumably its own zero-distance
# neighbour, so expect these scores to be optimistic; the test-split report is the meaningful one.
print(
    classification_report(
        y_true=train_labels,
        y_pred=knn.predict(train_embeddings),
        target_names=label_columns,
        zero_division=0,
    )
)

In [None]:
# Per-label report for the kNN probe on the held-out test split.
print(
    classification_report(
        y_true=test_labels,
        y_pred=knn.predict(test_embeddings),
        target_names=label_columns,
        zero_division=0,
    )
)

In [None]:
# Column positions of every AMR label, i.e. every label whose name does not contain "non_amr".
amr_label_indices = [
    position for position, name in enumerate(label_columns) if "non_amr" not in name
]

# Collapse all AMR classes into a single binary target:
# 1 = at least one AMR label is active ("AMR"), 0 = none are ("non-AMR").
train_amr_binary = (np.sum(train_labels[:, amr_label_indices], axis=1) > 0).astype(int)
test_amr_binary = (np.sum(test_labels[:, amr_label_indices], axis=1) > 0).astype(int)

In [None]:
# Binary AMR vs non-AMR kNN, using the same hyper-parameters as the multi-label probe.
binary_knn = KNeighborsClassifier(metric="cosine", weights="distance", n_neighbors=5)

binary_knn.fit(train_embeddings, train_amr_binary)

In [None]:
# Binary AMR vs non-AMR report on the training split.
# `labels=[1, 0]` pins both classes explicitly, so `target_names` stays aligned even if a
# split happens to contain only one class (otherwise classification_report raises a
# size-mismatch ValueError). It also matches the AMR-first order of the test report.
print(
    classification_report(
        train_amr_binary,
        binary_knn.predict(train_embeddings),
        zero_division=0,
        target_names=["AMR", "non-AMR"],
        labels=[1, 0],
    )
)

In [None]:
# Binary AMR vs non-AMR report on the held-out test split.
# `labels=[1, 0]` lists AMR (label 1) first and keeps both rows in the report.
print(
    classification_report(
        y_true=test_amr_binary,
        y_pred=binary_knn.predict(test_embeddings),
        labels=[1, 0],
        target_names=["AMR", "non-AMR"],
        zero_division=0,
    )
)

In [None]:
import plotly.express as px

from umap import UMAP

# 2-D UMAP projection of the training embeddings for visual inspection.
# UMAP is stochastic; seeding it with SEED makes the plot reproducible across runs
# (umap-learn disables its parallel path when random_state is set, so this is slower).
umap = UMAP(n_components=2, random_state=SEED)
train_embeddings_2d = umap.fit_transform(train_embeddings)


In [None]:
# Long-format plotting table: one row per (sample, active label) pair, so a multi-label
# sample appears once per label at identical coordinates.
rows = [
    {
        "x": train_embeddings_2d[i, 0],
        "y": train_embeddings_2d[i, 1],
        "label": label_columns[j],
        "sample_index": i,  # to identify duplicates
    }
    for i in range(len(train_embeddings))
    for j in np.where(train_labels[i] > 0)[0]
]

df_vis = pl.DataFrame(rows)


In [None]:
# Scatter of the UMAP projection, coloured by gene-family label.
# The title is derived from the encoder class so it stays correct when the backbone changes
# (this notebook loads a TransformerEncoder, not a Hyena model). The `labels` mapping is
# keyed by column name, so the colour column is "label", not "color".
fig = px.scatter(
    df_vis,
    x="x",
    y="y",
    color="label",
    hover_data=["sample_index"],
    title=f"UMAP visualization of {type(encoder).__name__} embeddings (multi-label)",
    labels={"x": "UMAP dim 1", "y": "UMAP dim 2", "label": "Gene family"},
    opacity=0.7,
)

fig.update_layout(
    width=1024,
    height=768,
    legend_title="Gene family",
)
fig.show()

In [None]:
# t-SNE projection of the same training embeddings, for comparison with UMAP.
from sklearn.manifold import TSNE

tsne = TSNE(random_state=SEED, n_components=2)

# NOTE: this rebinds `train_embeddings_2d`, replacing the UMAP coordinates; re-run the
# UMAP cell before re-running the UMAP plotting cells.
train_embeddings_2d = tsne.fit_transform(train_embeddings)

In [None]:
# Long-format plotting table: one row per (sample, active label) pair, so a multi-label
# sample appears once per label at identical coordinates.
rows = [
    {
        "x": train_embeddings_2d[i, 0],
        "y": train_embeddings_2d[i, 1],
        "label": label_columns[j],
        "sample_index": i,  # to identify duplicates
    }
    for i in range(len(train_embeddings))
    for j in np.where(train_labels[i] > 0)[0]
]
df_vis = pl.DataFrame(rows)

# Title comes from the encoder class (this notebook loads a TransformerEncoder, not Hyena);
# the `labels` mapping is keyed by column name, so the colour column is "label".
fig = px.scatter(
    df_vis,
    x="x",
    y="y",
    color="label",
    hover_data=["sample_index"],
    title=f"t-SNE visualization of {type(encoder).__name__} embeddings (multi-label)",
    labels={"x": "t-SNE dim 1", "y": "t-SNE dim 2", "label": "Gene family"},
    opacity=0.7,
)
fig.update_layout(
    width=1024,
    height=768,
    legend_title="Gene family",
)
fig.show()

In [None]:
# Evaluate the fine-tuned ClassificationModel directly on the held-out test split.
# Raw logits are collected per batch; sigmoid and thresholding happen in later cells.
from torch.utils.data import DataLoader, Subset

logits = []
labels = []

test_dataloader = DataLoader(
    Subset(dataset=dataset, indices=test_indices), batch_size=32, shuffle=False
)

classification_model.to(DEVICE)
# Earlier cells only called encoder.eval(); switch the *whole* model (including the
# classification head) to inference mode so dropout and similar layers are disabled.
classification_model.eval()

# no_grad avoids building an autograd graph for every batch (less memory, faster).
with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Test inference"):
        input_ids = batch["sequence"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        batch_logits = classification_model(input_ids, attention_mask)
        logits.append(batch_logits.cpu())
        labels.append(batch["labels"].cpu())

In [None]:
# Sanity check: the first logits batch and the first labels batch should have matching shapes.
print(*(first.shape for first in (logits[0], labels[0])), sep="\n")

In [None]:
# Concatenate the per-batch outputs into single 2-D tensors.
predictions = torch.cat(logits, dim=0)  # raw logits, [N, num_labels]
# Guard keeps the cell re-runnable: after the first run `labels` is already a tensor.
if isinstance(labels, list):
    labels = torch.cat(labels, dim=0)  # [N, num_labels]
# Fixed 0.5 probability threshold for every label.
preds = (torch.sigmoid(predictions) >= 0.5).int().cpu().numpy()

from sklearn.metrics import classification_report
print(classification_report(labels, preds, zero_division=0, target_names=label_columns))


In [None]:
import numpy as np
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
import plotly.graph_objects as go

# Sweep one global decision threshold (shared by all labels) and plot micro-averaged metrics.
# `predictions` holds raw logits, so they must go through a sigmoid before being compared
# with thresholds in (0, 1) — the same conversion used for the 0.5 report.
probabilities = torch.sigmoid(predictions)

thresholds = np.linspace(0.05, 0.95, 19)
f1_scores = []
accuracy_scores = []
precision_scores = []
recall_scores = []

for t in thresholds:
    preds = (probabilities >= t).int().cpu().numpy()
    # zero_division=0: at high thresholds nothing may be predicted positive.
    f1_scores.append(f1_score(labels, preds, average="micro", zero_division=0))
    # For multi-label input accuracy_score is *subset* accuracy (every label must match).
    accuracy_scores.append(accuracy_score(labels, preds))
    precision_scores.append(precision_score(labels, preds, average="micro", zero_division=0))
    recall_scores.append(recall_score(labels, preds, average="micro", zero_division=0))

best_idx = np.argmax(f1_scores)
best_thr = thresholds[best_idx]
best_f1 = f1_scores[best_idx]

fig = go.Figure()
fig.add_trace(go.Scatter(x=thresholds, y=f1_scores, mode='lines+markers', name='F1 Score'))
fig.add_trace(go.Scatter(x=thresholds, y=accuracy_scores, mode='lines+markers', name='Accuracy'))
fig.add_trace(go.Scatter(x=thresholds, y=precision_scores, mode='lines+markers', name='Precision'))
fig.add_trace(go.Scatter(x=thresholds, y=recall_scores, mode='lines+markers', name='Recall'))
fig.update_layout(
    title='Classification Metrics vs. Threshold',
    xaxis_title='Threshold',
    yaxis_title='Score',
    legend_title='Metrics',
    width=800,
    height=600
)
fig.show()
print(f"Best global threshold = {best_thr:.2f} (F1={best_f1:.4f})")

In [None]:
import numpy as np
# Fraction of positive entries in the test label matrix (mean of per-label positive rates).
# .float() lets the mean work even if labels use an integer dtype; .item() prints a plain number.
print("Positive label ratio:", labels.float().mean(dim=0).mean().item())

In [None]:
# Histogram of every per-label sigmoid probability on the test split, pooled across labels.
probs = predictions.sigmoid().cpu().numpy()
fig = px.histogram(
    probs.ravel(),
    title="Distribution of predicted probabilities",
    nbins=50,
    labels={"value": "Sigmoid output", "count": "Frequency"},
)
fig.show()

In [None]:
# Number of (sample, label) entries predicted positive at each probability cut-off.
thresholds = np.linspace(0.05, 0.95, 19)
pred_counts = [(probs >= cutoff).sum() for cutoff in thresholds]
print(list(zip(thresholds, pred_counts)))
