In [None]:
# Explore query-variant text embeddings: PCA and densMAP projections, plus per-query correlation of mean embeddings.

import pandas as pd
import torch

In [None]:
# Load the two inputs of this rule: the query-variant table (CSV) and the
# saved text-embedding tensor, remapped onto the CPU whatever device it was
# saved from.
# NOTE(review): torch.load unpickles the file, so it must only ever point at
# trusted pipeline outputs. Rows of the table and the tensor are presumably
# aligned one-to-one — later cells index embeddings with masks built from the table.
cpu_device = torch.device("cpu")
query_variants = pd.read_csv(snakemake.input.query_variants)
text_embeddings = torch.load(snakemake.input.text_embeddings, map_location=cpu_device)

In [None]:
# Preview the query-variant table; its "query" column drives grouping and colouring below.
query_variants

In [None]:
# The plotting and correlation cells select embedding rows with boolean masks
# built from query_variants["query"], so row i of the tensor must describe
# row i of the table. Fail early with a readable message if the counts differ.
assert text_embeddings.shape[0] == len(query_variants), (
    f"text_embeddings has {text_embeddings.shape[0]} rows but "
    f"query_variants has {len(query_variants)} rows"
)
text_embeddings.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# 2-D PCA projection of the text embeddings, one colour per distinct value of
# query_variants["query"]. Axis labels report the share of variance that each
# principal component explains.
embedding_matrix = text_embeddings.cpu().numpy()
pca = PCA(n_components=2)
principal_components = pca.fit_transform(embedding_matrix)
explained_variance = pca.explained_variance_ratio_

query_column = query_variants["query"]
unique_keys = np.unique(query_column)  # sorted distinct queries
colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_keys)))

plt.figure(figsize=(5, 5))
for position, key in enumerate(unique_keys):
    mask = query_column == key
    xs = principal_components[mask, 0]
    ys = principal_components[mask, 1]
    plt.scatter(xs, ys, c=[colors[position]], label=key, alpha=0.7)

plt.xlabel(f"Principal Component 1 ({explained_variance[0]:.2%} variance)")
plt.ylabel(f"Principal Component 2 ({explained_variance[1]:.2%} variance)")
plt.title("PCA of Dataset")
# No legend is drawn on this figure.
plt.grid(True)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import umap  # Import UMAP from umap-learn

# Fit a 2-D densMAP embedding (UMAP with densmap=True) of the text embeddings.
# random_state is pinned so repeated runs produce the same layout.
umap_model = umap.UMAP(densmap=True, n_components=2, random_state=42)
umap_embedding = umap_model.fit_transform(
    text_embeddings.cpu().numpy()
)

In [None]:
# Scatter the densMAP embedding, one colour per distinct query, and save the
# figure as this rule's output plot.
plt.figure(figsize=(5, 5))
unique_keys = np.unique(query_variants["query"])
colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_keys)))

for key, color in zip(unique_keys, colors):
    indices = query_variants["query"] == key
    plt.scatter(
        umap_embedding[indices, 0],
        umap_embedding[indices, 1],
        c=[color],
        label=key,
        alpha=0.7,
    )

plt.xlabel("UMAP 1")
plt.ylabel("UMAP 2")
plt.title("densMAP of Dataset")
# Place the legend to the right of the axes. One call is enough: a preceding
# bare plt.legend() would simply be replaced by this one.
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

plt.grid(True)
# The legend lies outside the axes, mostly beyond the 5x5 figure canvas.
# bbox_inches="tight" grows the saved image to include it; without it the
# legend is clipped from the file (the inline notebook preview typically crops
# tightly, which can hide the problem).
plt.savefig(snakemake.output.plot, bbox_inches="tight")

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Correlation between per-query mean embeddings.
#
# Embedding rows are grouped by query_variants["query"] (assumes row i of
# text_embeddings belongs to row i of query_variants), each group is averaged
# into one mean vector, and the heatmap shows the Pearson correlation between
# every pair of those mean vectors.

# Convert the tensor once instead of once per key.
# NOTE(review): .numpy() assumes the tensor does not require grad — confirm.
embeddings = text_embeddings.cpu().numpy()
query_labels = query_variants["query"].to_numpy()
unique_keys = np.unique(query_labels)  # sorted distinct queries

grouped_data = {key: embeddings[query_labels == key] for key in unique_keys}

# Each group's mean is computed exactly once, stacked one row per key in
# unique_keys order.
group_means = np.stack([grouped_data[key].mean(axis=0) for key in unique_keys])

# np.corrcoef treats each row as a variable, so one call yields the full
# key-by-key matrix instead of recomputing both means for every (i, j) pair.
# With a single key, corrcoef returns a 0-d value; atleast_2d keeps it (1, 1).
correlation_matrix = np.atleast_2d(np.corrcoef(group_means))

# Convert to DataFrame so seaborn labels rows/columns with the query keys.
correlation_df = pd.DataFrame(
    correlation_matrix, index=unique_keys, columns=unique_keys
)

# Plot heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(correlation_df, annot=True, cmap="coolwarm", fmt=".1f")
plt.title("Correlation Heatmap by Keys")
plt.xlabel("Keys")
plt.ylabel("Keys")
plt.tight_layout()