In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pyprojroot import here
import seaborn as sns
import torch

from data import load, Dream
from models.utils import numpify

# Run on the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Train + validation data. Without a GPU we load the "train_dev.pt" cache instead
# of "train.pt" (presumably a reduced development subset — confirm in data.load).
tr_cached = "train.pt" if device.type != "cpu" else "train_dev.pt"
tr = load("train_sequences.txt", tr_cached, Dream, path=here("data/dream"))

# Test data (no expression labels).
te = load("test_sequences.txt", "test.pt", Dream, path=here("data/dream"))

In [None]:
def plot_nt_freq(obj, title=None, **plt_kwargs):
    """Plot the per-position frequency of each base in a one-hot encoded dataset.

    Parameters
    ----------
    obj : Dream
        Dataset with a ``sequences`` attribute. Summing it over axis 0 and
        transposing must yield one column per base in the order A, C, T, G.
        Presumably ``sequences`` is (n_sequences, 4, seq_len); TODO confirm.
    title : str, optional
        Panel title. When omitted, it is "Test" if ``obj`` is the global test set
        ``te`` and "Train" otherwise (the original behaviour).
    **plt_kwargs
        Forwarded to ``seaborn.lineplot`` (e.g. ``ax=``).

    Returns
    -------
    matplotlib.axes.Axes
        The axes the line plot was drawn on.
    """
    # Per-position base counts, converted to NumPy before handing them to pandas.
    counts = numpify(obj.sequences.sum(axis=0)).T  # (seq_len, 4)
    seq_len = counts.shape[0]  # derived from the data instead of a hard-coded 80

    df = pd.DataFrame(counts / len(obj), columns=["A", "C", "T", "G"])
    df["Position"] = np.arange(seq_len)
    df = df.melt("Position", var_name="Base", value_name="Frequency")

    g = sns.lineplot(x="Position", y="Frequency", hue="Base", data=df, **plt_kwargs)
    # The fixed y-range appears tuned to the observed data; frequencies outside
    # [0.14, 0.33] are clipped from view.
    g.axis(ymin=0.14, ymax=0.33, xmin=0, xmax=seq_len - 1)
    if title is None:
        # Identity check: we only want to know whether this *is* the test set.
        title = "Test" if obj is te else "Train"
    g.set_title(title)

    return g


# Side-by-side nucleotide frequency profiles: train (left) and test (right).
fig, ax = plt.subplots(1, 2, figsize=(15, 6))
fig.suptitle("Nucleotide frequencies", fontsize=16)
for panel_ax, dataset in zip(ax, (tr, te)):
    plot_nt_freq(dataset, ax=panel_ax)
# Render the figure explicitly; this also keeps the Axes repr out of the cell output.
plt.show()

In [None]:
def plot_na(obj, title=None, **plt_kwargs):
    """Plot, per position, how many sequences have no base set (all-zero column).

    The count is ``len(obj)`` minus the number of ones at each position, which
    assumes at most one base is set per position. It is drawn on a log y-axis.

    Parameters
    ----------
    obj : Dream
        Dataset with a ``sequences`` attribute. Summing it over axes (0, 1) gives
        one value per position. Presumably ``sequences`` is
        (n_sequences, 4, seq_len); TODO confirm.
    title : str, optional
        Panel title. When omitted, it is "Test" if ``obj`` is the global test set
        ``te`` and "Train" otherwise (the original behaviour).
    **plt_kwargs
        Forwarded to ``seaborn.lineplot`` (e.g. ``ax=``).

    Returns
    -------
    matplotlib.axes.Axes
        The axes the line plot was drawn on.
    """
    missing = len(obj) - numpify(obj.sequences.sum(axis=(0, 1)))
    positions = np.arange(missing.shape[0])  # no hard-coded sequence length

    # The +0.99 offset keeps positions with zero missing samples finite on the
    # log axis (log(0) is undefined); they appear just below 1.
    g = sns.lineplot(x=positions, y=missing + 0.99, **plt_kwargs)
    g.set(yscale="log")
    if title is None:
        title = "Test" if obj is te else "Train"
    g.set_title(title)
    # Fixed y-range; counts above 2e6 would be clipped from view.
    g.axis(ymin=0.9, ymax=2e6)
    g.set(xlabel="Position", ylabel="Number of samples")

    return g


# NOTE: this changes the default figure size globally, so every later figure
# that does not pass its own figsize is also drawn at 15x6 inches.
plt.rcParams["figure.figsize"] = (15, 6)
fig, ax = plt.subplots(1, 2)
fig.suptitle("Number of missing values", fontsize=16)
for panel_ax, dataset in zip(ax, (tr, te)):
    plot_na(dataset, ax=panel_ax)
# Render explicitly and keep the Axes repr out of the cell output.
plt.show()

In [None]:
# Does the number of missing (all-zero) positions in a training sequence relate
# to its measured expression?
seq_len = tr.sequences.shape[-1]  # positions per sequence (was hard-coded to 80)
df = pd.DataFrame(
    {
        # Positions minus the ones set per sequence = positions with no base.
        "n_missing_vals": numpify(seq_len - tr.sequences.sum(axis=(1, 2))),
        "expression": numpify(tr.expression).flatten(),
    }
)

g = sns.scatterplot(x="n_missing_vals", y="expression", data=df)
g.set(xlabel="Number of missing positions", ylabel="Expression");

In [None]:
# Distribution of training expression values; log counts reveal the sparse tail.
expression_values = numpify(tr.expression).flatten()
g = sns.histplot(x=expression_values)
g.set(yscale="log", xlabel="Expression")
g.axis(xmin=-0.1, xmax=17.1)

In [None]:
# Summary statistics of the training expression values.
# NOTE: for an even number of elements, torch's median returns the lower of the
# two middle values rather than their average (unlike numpy.median).
expression_summary = (
    ("Min expression:", tr.expression.min()),
    ("Mean expression:", tr.expression.mean()),
    ("Median expression:", tr.expression.median()),
    ("Max expression:", tr.expression.max()),
)
for label, value in expression_summary:
    print(label, value.numpy())

In [None]:
# Base composition of the training sequences whose measured expression is exactly 0.
# Compute the mask and its count once. The original used Python's built-in sum()
# on the tensor twice, which iterates element by element in Python.
zero_mask = tr.expression == 0
n_zero = int(zero_mask.sum())
zero_sequences = tr.sequences[zero_mask, :]

# Dream container holding only the zero-expression subset (not used further in
# this cell; kept for later inspection).
no_expression = Dream([""], 0)
no_expression.sequences = zero_sequences
no_expression.expression = tr.expression[zero_mask]

base_counts = numpify(zero_sequences.sum(axis=0)).T  # (seq_len, 4)
seq_len = base_counts.shape[0]
df = pd.DataFrame(base_counts / n_zero, columns=["A", "C", "T", "G"])
df["Position"] = np.arange(seq_len)
df = df.melt("Position", var_name="Base", value_name="Frequency")

g = sns.lineplot(x="Position", y="Frequency", hue="Base", data=df)
g.axis(xmin=0, xmax=seq_len - 1)
g.set_title(
    f"Base frequencies in {n_zero} 0-expression sequences", fontsize=16
)

In [None]:
# Fractional part of every expression value. np.modf returns (fractional, integral);
# index instead of unpacking into `_`. In IPython, assigning `_` shadows the
# last-output variable, and IPython then stops updating it.
decimals = np.modf(numpify(tr.expression).flatten())[0]
g = sns.histplot(x=decimals)
g.set(yscale="log", xlabel="Decimals")
g.axis(xmin=-0.01, xmax=1.01)

In [None]:
# Top numbers: the ten most frequent non-zero fractional parts of the expression values.
u, counts = np.unique(decimals[decimals != 0], return_counts=True)
# Stable sort: values with equal counts keep np.unique's ascending order instead
# of an unspecified order from the default (non-stable) quicksort.
counts_sort_ind = np.argsort(-counts, kind="stable")

u[counts_sort_ind[:10]]