In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Strain metadata for the influenza-like population: a tab-separated table
# read from the Snakemake input and indexed by its "strain" column.
influenza_metadata = pd.read_csv(snakemake.input.influenza_metadata, sep="\t", index_col="strain")

In [None]:
# Preview the first rows of the influenza metadata table.
influenza_metadata.head()

In [None]:
# (rows, columns) of the influenza metadata table.
influenza_metadata.shape

In [None]:
# Influenza PCA embedding, indexed by strain, with the strain metadata
# left-joined onto it by index (every embedding row is kept).
influenza_pca = pd.read_csv(snakemake.input.influenza_pca, index_col="strain")
influenza_pca = influenza_pca.join(influenza_metadata, how="left")

In [None]:
# Preview the first rows of the influenza PCA embedding joined with metadata.
influenza_pca.head()

In [None]:
# (rows, columns) of the joined influenza PCA table.
influenza_pca.shape

In [None]:
# Influenza MDS embedding, indexed by strain, with the strain metadata
# left-joined onto it by index (every embedding row is kept).
influenza_mds = pd.read_csv(snakemake.input.influenza_mds, index_col="strain")
influenza_mds = influenza_mds.join(influenza_metadata, how="left")

In [None]:
# Preview the first rows of the influenza MDS embedding joined with metadata.
influenza_mds.head()

In [None]:
# Influenza t-SNE embedding, indexed by strain, with the strain metadata
# left-joined onto it by index (every embedding row is kept).
influenza_tsne = pd.read_csv(snakemake.input.influenza_tsne, index_col="strain")
influenza_tsne = influenza_tsne.join(influenza_metadata, how="left")

In [None]:
# Preview the first rows of the influenza t-SNE embedding joined with metadata.
influenza_tsne.head()

In [None]:
# Influenza UMAP embedding, indexed by strain, with the strain metadata
# left-joined onto it by index (every embedding row is kept).
influenza_umap = pd.read_csv(snakemake.input.influenza_umap, index_col="strain")
influenza_umap = influenza_umap.join(influenza_metadata, how="left")

In [None]:
# Preview the first rows of the influenza UMAP embedding joined with metadata.
influenza_umap.head()

In [None]:
# Strain metadata for the coronavirus-like population: a tab-separated table
# read from the Snakemake input and indexed by its "strain" column.
coronavirus_metadata = pd.read_csv(snakemake.input.coronavirus_metadata, sep="\t", index_col="strain")

In [None]:
# Preview the first rows of the coronavirus metadata table.
coronavirus_metadata.head()

In [None]:
# Coronavirus PCA embedding, indexed by strain, with the strain metadata
# left-joined onto it by index (every embedding row is kept).
coronavirus_pca = pd.read_csv(snakemake.input.coronavirus_pca, index_col="strain")
coronavirus_pca = coronavirus_pca.join(coronavirus_metadata, how="left")

In [None]:
# Preview the first rows of the coronavirus PCA embedding joined with metadata.
coronavirus_pca.head()

In [None]:
# (rows, columns) of the joined coronavirus PCA table.
coronavirus_pca.shape

In [None]:
# Sum of the "is_recombinant" column in the joined coronavirus PCA table.
# NOTE(review): presumably a boolean/0-1 flag, in which case this is the number
# of recombinant strains — confirm against the input files.
coronavirus_pca["is_recombinant"].sum()

In [None]:
# Coronavirus MDS embedding, indexed by strain, with the strain metadata
# left-joined onto it by index (every embedding row is kept).
coronavirus_mds = pd.read_csv(snakemake.input.coronavirus_mds, index_col="strain")
coronavirus_mds = coronavirus_mds.join(coronavirus_metadata, how="left")

In [None]:
# Preview the first rows of the coronavirus MDS embedding joined with metadata.
coronavirus_mds.head()

In [None]:
# Coronavirus t-SNE embedding, indexed by strain, with the strain metadata
# left-joined onto it by index (every embedding row is kept).
coronavirus_tsne = pd.read_csv(snakemake.input.coronavirus_tsne, index_col="strain")
coronavirus_tsne = coronavirus_tsne.join(coronavirus_metadata, how="left")

In [None]:
# Preview the first rows of the coronavirus t-SNE embedding joined with metadata.
coronavirus_tsne.head()

In [None]:
# Coronavirus UMAP embedding, indexed by strain, with the strain metadata
# left-joined onto it by index (every embedding row is kept).
coronavirus_umap = pd.read_csv(snakemake.input.coronavirus_umap, index_col="strain")
coronavirus_umap = coronavirus_umap.join(coronavirus_metadata, how="left")

In [None]:
# Preview the first rows of the coronavirus UMAP embedding joined with metadata.
coronavirus_umap.head()

In [None]:
# Overview figure: a 2 x 4 grid of scatterplots with one row per population
# (influenza-like, coronavirus-like) and one column per embedding method.
# Points are coloured by the "generation" column of each joined table.
fig, axes = plt.subplots(
    nrows=2,
    ncols=4,
    figsize=(8, 4),
    dpi=150,
    gridspec_kw={
        "wspace": 0.0,
        "hspace": 0.0
    },
    constrained_layout=False,
)

# Marker styling shared by every panel.
alpha = 0.25
size = 12
# NOTE(review): not used anywhere in this cell; kept only so the notebook
# namespace is unchanged.
recombinant_size = 150

# Column layout: (panel title, x column, y column) for each embedding method.
methods = [
    ("PCA", "pca1", "pca2"),
    ("MDS", "mds1", "mds2"),
    ("t-SNE", "tsne_x", "tsne_y"),
    ("UMAP", "umap_x", "umap_y"),
]

# Row layout: (row label, embeddings in the same order as `methods`).
populations = [
    (
        "Influenza-like",
        [influenza_pca, influenza_mds, influenza_tsne, influenza_umap],
    ),
    # Coronavirus-like (moderate recombination rate)
    (
        "Coronavirus-like",
        [coronavirus_pca, coronavirus_mds, coronavirus_tsne, coronavirus_umap],
    ),
]

# The only panel that asks seaborn for a legend (coronavirus-like PCA); its
# legend is replaced further down by a compact two-entry version.
legend_panel = (1, 0)

for row, (population_label, embeddings) in enumerate(populations):
    for column, (embedding, method) in enumerate(zip(embeddings, methods)):
        method_title, x_column, y_column = method
        ax = axes[row][column]

        # scatterplot draws onto the Axes passed via `ax`, so its return
        # value does not need to be stored back into `axes`.
        sns.scatterplot(
            data=embedding,
            x=x_column,
            y=y_column,
            hue="generation",
            ax=ax,
            alpha=alpha,
            s=size,
            legend=(row, column) == legend_panel,
        )

        # Embedding coordinates are arbitrary, so hide axis labels and ticks.
        # Only the first column keeps a y label naming the population.
        ax.set_xlabel("")
        ax.set_ylabel(population_label if column == 0 else "")
        ax.set_xticks([])
        ax.set_yticks([])

        # Method names go above the top row only.
        if row == 0:
            ax.set_title(method_title)

# Replace seaborn's generated legend with two entries built from its first and
# last handles.
# NOTE(review): assumes seaborn orders the hue entries by increasing
# generation, so first/last correspond to "early"/"late" — confirm.
legend_ax = axes[legend_panel[0]][legend_panel[1]]
coronavirus_handles, coronavirus_labels = legend_ax.get_legend_handles_labels()
legend_ax.legend(
    [coronavirus_handles[0], coronavirus_handles[-1]],
    ["early", "late"],
    loc="upper center",
    title="Generation",
    frameon=False,
    columnspacing=0.5,
    handletextpad=0.5,
    title_fontsize=10,
)

# Lay out and save through the figure object itself rather than pyplot's
# implicit "current figure".
fig.tight_layout()
fig.savefig(snakemake.output.all_embeddings)

In [None]:
# MDS detail figure: a 2 x 2 grid with one row per population and one column
# per pair of MDS components (1 vs 2, then 2 vs 3). Points are coloured by the
# "generation" column.
fig, axes = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(5, 4),
    dpi=150,
    constrained_layout=False,
)

# Marker styling shared by every panel.
alpha = 0.25
size = 12
# NOTE(review): not used anywhere in this cell; kept only so the notebook
# namespace is unchanged.
recombinant_size = 150

# Row layout: (label prefix for the y axis, MDS embedding joined with metadata).
populations = [
    ("Influenza-like", influenza_mds),
    # Coronavirus-like (moderate recombination rate)
    ("Coronavirus-like", coronavirus_mds),
]

# Column layout: (x column, y column, x axis label, y axis label suffix).
component_pairs = [
    ("mds1", "mds2", "MDS 1", "MDS 2"),
    ("mds2", "mds3", "MDS 2", "MDS 3"),
]

# Only the right-hand column asks seaborn for a legend; it is replaced below.
legend_column = 1

for row, (population_label, mds_embedding) in enumerate(populations):
    for column, component_pair in enumerate(component_pairs):
        x_column, y_column, x_label, y_label = component_pair
        ax = axes[row][column]

        # scatterplot draws onto the Axes passed via `ax`, so its return
        # value does not need to be stored back into `axes`.
        sns.scatterplot(
            data=mds_embedding,
            x=x_column,
            y=y_column,
            hue="generation",
            ax=ax,
            alpha=alpha,
            s=size,
            legend=column == legend_column,
        )

        ax.set_xlabel(x_label)
        ax.set_ylabel(population_label + " " + y_label)

        # Tick values are not shown for the embedding coordinates.
        ax.set_xticks([])
        ax.set_yticks([])

# Collect the seaborn-generated legend entries of the right-hand panels.
influenza_mds_handles, influenza_mds_labels = axes[0][legend_column].get_legend_handles_labels()
coronavirus_mds_handles, coronavirus_mds_labels = axes[1][legend_column].get_legend_handles_labels()

# Replace each generated legend with two entries built from its first and last
# handles, anchored just outside the panel's upper-right corner.
# NOTE(review): assumes seaborn orders the hue entries by increasing
# generation, so first/last correspond to "early"/"late" — confirm.
for legend_ax, handles in (
    (axes[0][legend_column], influenza_mds_handles),
    (axes[1][legend_column], coronavirus_mds_handles),
):
    legend_ax.legend(
        [handles[0], handles[-1]],
        ["early", "late"],
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        title="Generation",
        frameon=False,
        columnspacing=0.5,
    )

# Lay out and save through the figure object itself rather than pyplot's
# implicit "current figure".
fig.tight_layout()
fig.savefig(snakemake.output.mds_embeddings)