In [None]:
from statistics import mode

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import load_from_disk

In [None]:
# Load the PSSM dataset and preview a few samples: name, sequence, PSSM shape and the matrix itself.
dataset = load_from_disk("../tmp/data/pssm/pssm_dataset_0_only")
print(dataset)

# Every 4th sample among the first 18 (indices 0, 4, ..., 16), clipped so a small dataset does not IndexError.
for sample_idx in range(0, min(18, len(dataset)), 4):
    sample = dataset[sample_idx]  # decode the row once instead of once per field
    print(sample["name"])
    print(sample["sequence"])
    print(torch.tensor(sample["pssm_features"]).shape)
    # Inside the loop so every previewed sample's matrix is shown, not just the last one.
    display(pd.DataFrame(sample["pssm_features"]))

In [None]:
# One-letter codes of the 20 standard amino acids in alphabetical order; a code's
# position in `tokens` is its integer index.
# NOTE(review): the heatmap x-labels in this notebook assume the PSSM feature columns
# follow this same order — confirm against whatever generated "pssm_features".
tokens = list("ACDEFGHIKLMNPQRSTVWY")
token2idx = {token: position for position, token in enumerate(tokens)}
idx2token = {position: token for token, position in token2idx.items()}

print(token2idx)
print(idx2token)

In [4]:
def plot_pssm_heatmap(dataset, sample_idx, figsize=(12, 8)):
    """Plot one sample's PSSM feature matrix as a heatmap and show it.

    Args:
        dataset: Indexable collection of records; ``dataset[sample_idx]`` must
            provide ``"pssm_features"`` (a 2-D nested list) and ``"name"``.
            Presumably shaped (sequence_length, 20) — rows are labelled as
            sequence positions and 20 columns get amino-acid labels; TODO confirm.
        sample_idx: Index of the sample to plot.
        figsize: Matplotlib figure size in inches.
    """
    # Decode the row once instead of once per field access.
    sample = dataset[sample_idx]
    pssm = torch.tensor(sample["pssm_features"])

    plt.figure(figsize=figsize)
    ax = sns.heatmap(pssm, cmap="YlOrRd")
    plt.title(f"PSSM Features Heatmap for Sample {sample_idx} ({sample['name']})")
    plt.xlabel("PSSM Position")
    plt.ylabel("Sequence Position")

    # seaborn draws heatmap cell j over [j, j + 1), so labels must sit at j + 0.5;
    # integer tick positions would put each label on a cell border.
    n_cols = len(idx2token)
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_xticklabels([f"{i} ({idx2token[i]})" for i in range(n_cols)], rotation=45)

    plt.show()


# for x in torch.randint(0, len(dataset), (10,)).tolist():
#     plot_pssm_heatmap(dataset, sample_idx=x)

In [None]:
# Distribution of sequence lengths. Reading only the "sequence" column avoids
# decoding every row's (large) PSSM matrix just to measure a length.
sequence_lengths = [len(seq) for seq in dataset["sequence"]]
assert sequence_lengths, "dataset is empty — nothing to summarise"

plt.figure(figsize=(10, 6))
plt.hist(sequence_lengths, bins=50, edgecolor="black")
plt.title("Distribution of Sequence Lengths")
plt.xlabel("Sequence Length")
plt.ylabel("Count")
plt.grid(True, alpha=0.3)
plt.show()

# Summary statistics
print(f"Mean sequence length: {np.mean(sequence_lengths):.2f}")
print(f"Median sequence length: {np.median(sequence_lengths):.2f}")
print(f"Min sequence length: {min(sequence_lengths)}")
print(f"Max sequence length: {max(sequence_lengths)}")


In [None]:
# Most frequent sequence length. On ties, `statistics.mode` (Python >= 3.8)
# returns the first tied value encountered in `sequence_lengths`.
mode_length = mode(sequence_lengths)
print(f"Mode sequence length: {mode_length}")


In [None]:
# Average the PSSMs of every sequence with length TARGET_LENGTH and plot the result.
# 57 is presumably the modal length reported above — TODO confirm / consider `mode_length`.
TARGET_LENGTH = 57

# Filter on the "sequence" column alone so the PSSM of every non-matching row is not decoded.
length_57_indices = [i for i, seq in enumerate(dataset["sequence"]) if len(seq) == TARGET_LENGTH]
assert length_57_indices, f"no sequences of length {TARGET_LENGTH} in dataset"
length_57_pssms = [dataset[i]["pssm_features"] for i in length_57_indices]

# Equal lengths make the matrices stackable; float32 because torch.mean rejects integer tensors.
length_57_pssms_tensor = torch.stack([torch.tensor(pssm, dtype=torch.float32) for pssm in length_57_pssms])
mean_pssm = torch.mean(length_57_pssms_tensor, dim=0)

# Plot average PSSM
plt.figure(figsize=(10, 6))
ax = sns.heatmap(mean_pssm, cmap="YlOrRd")
plt.title(f"Average PSSM Features for Sequences of Length {TARGET_LENGTH} (n={len(length_57_indices)})")
plt.xlabel("PSSM Position")
plt.ylabel("Sequence Position")

# Heatmap cell j spans [j, j + 1): centre each label at j + 0.5 rather than on the cell border.
x_labels = [f"{i} ({idx2token[i]})" for i in range(20)]
ax.set_xticks(np.arange(len(x_labels)) + 0.5)
ax.set_xticklabels(x_labels, rotation=45)

plt.show()
