In [None]:
from random import random
import warnings

from git import Repo
import matplotlib.pyplot as plt
import numba
import pandas as pd
from pyprojroot import here
import seaborn as sns
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm, trange
from umap import UMAP

from data import load, Dream, ReverseComplement
from models import TransformerCNN, fix_seeds
from models.utils import pearsonr, spearmanr, numpify

plt.rcParams["figure.figsize"] = [15, 7.5]

# Record the current commit hash only when the working tree has no uncommitted changes,
# so that any saved weights can be traced back to the exact code that produced them.
repo = Repo(search_parent_directories=True)
if repo.is_dirty():
    sha = None
    warnings.warn("Uncommitted changes. The model parameters won't be saved.")
else:
    sha = repo.head.object.hexsha

# Model and hyperparameters

In [None]:
# model hyperparameters
model_obj = TransformerCNN
cuda_device = 0
n_epochs = 10
batch_size = 1024
# use the selected GPU when one is visible, otherwise fall back to the CPU
device = torch.device(f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu")
seed = 0

# data transforms: when True, the training loop randomly swaps the two model inputs
# (seq <-> rc) for about half of the minibatches
rc_transform = False

fix_seeds(seed)

In [None]:
# model specification
model = model_obj().to(device)
model_name = model.__class__.__name__

## initialize last layer's bias to the average (presumably the mean training expression),
## so optimisation starts from the mean prediction.
# Fill the existing bias in place: this keeps its shape, dtype and device. Replacing it with
# a new 0-d CPU Parameter changes its shape relative to a freshly constructed model, which
# can make `load_state_dict` on a new instance fail with a size mismatch.
MEAN_EXPRESSION = 11.147
with torch.no_grad():
    model.fc[-1].bias.fill_(MEAN_EXPRESSION)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Data loading

In [None]:
# train and validation
val_size = 10000

# CPU (development) runs use the smaller cached training subset
tr_cached = "train_dev.pt" if device.type == "cpu" else "train.pt"
tr = load("train_sequences.txt", tr_cached, Dream, path=here("data/dream"))
# A dedicated generator seeded with `seed` makes the train/validation split depend only on
# the seed, not on how many random numbers were consumed earlier (e.g. by the model's weight
# initialisation). This keeps the validation set identical across architectures.
tr, val = torch.utils.data.random_split(
    tr,
    [len(tr) - val_size, val_size],
    generator=torch.Generator().manual_seed(seed),
)

# the whole validation set, materialised as a single batch
val_loader = DataLoader(val, batch_size=val_size)
val_seqs, val_rc, val_expression = next(iter(val_loader))

# test (unlabelled)
te = load("test_sequences.txt", "test.pt", Dream, path=here("data/dream"))

In [None]:
# data transformations (none for now; the strand swap controlled by `rc_transform` is done
# inside the training loop, so here it only tags the model name)
tf = []

model_name = f"{model_name}_t=rc" if rc_transform else model_name

# NOTE(review): `tr` is a torch `Subset` (from `random_split`), and `Subset.__getitem__`
# indexes `tr.dataset` directly, so this attribute is never consulted by the DataLoader.
# Harmless while `tf` is empty; any real transform would need to live on the underlying
# dataset (which `val` shares) or in a wrapper dataset.
tr.transforms = transforms.Compose(tf)
# the last incomplete minibatch of each epoch is dropped
tr_loader = DataLoader(tr, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# # create a smaller training set for development
# tr = load("train_sequences.txt", tr_cached, Dream, path=here("data/dream"))
# tr, val = torch.utils.data.random_split(tr, [len(tr) - val_size, val_size])

# d = Dream([''], [0])
# d.sequences, d.rc_sequences, d.expression = next(iter(DataLoader(tr, batch_size=100000 + val_size, shuffle=True)))
# d.expression = d.expression.flatten()

# torch.save(d, here('data/dream/train_dev.pt'))

# Model training

In [None]:
# Per-minibatch training losses and per-epoch validation losses (plotted further down).
tr_losses = []
val_losses = []
# Exponential moving averages (weight 0.1 on the newest minibatch) of the training
# correlations. They start at 0, so the first displayed values are biased low.
tr_pearson = 0
tr_spearman = 0

# scaler for mixed precision training. Loss scaling only applies on CUDA. Passing `enabled`
# explicitly avoids the "CUDA is not available" warning on CPU runs, where the scaler would
# disable itself anyway, so behaviour is unchanged.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

for epoch in range(n_epochs):
    # training mode for the whole epoch; the loop body runs no other forward pass
    model.train()

    with tqdm(tr_loader, desc=f"Epoch {epoch + 1}/{n_epochs}") as tepoch:
        for seq, rc, y in tepoch:
            seq, rc, y = seq.to(device), rc.to(device), y.to(device)

            # optional augmentation: swap the two model inputs for ~half of the minibatches
            if rc_transform and random() < 0.5:
                seq, rc = rc, seq

            # forward (mixed precision)
            with torch.autocast(device_type=device.type):
                y_pred = model(seq, rc)
                tr_loss = criterion(y_pred, y)

            # backward (with mixed precision)
            optimizer.zero_grad()
            scaler.scale(tr_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # running training metrics
            tr_loss_value = tr_loss.item()  # read the scalar once, reuse it below
            tr_losses.append(tr_loss_value)
            tr_pearson = 0.9 * tr_pearson + 0.1 * pearsonr(y, y_pred)
            tr_spearman = 0.9 * tr_spearman + 0.1 * spearmanr(y, y_pred)

            tepoch.set_postfix(tr_loss=tr_loss_value, r=tr_pearson, rho=tr_spearman)

    # Validation on the full held-out set once per epoch. Switching to eval mode explicitly
    # (instead of relying on the last minibatch) also covers an empty training loader.
    model.eval()
    with torch.no_grad():
        val_pred = model(val_seqs.to(device), val_rc.to(device)).cpu()
        val_loss = criterion(val_pred, val_expression)
        val_losses.append(val_loss.item())
        val_pearson = pearsonr(val_expression, val_pred)
        val_spearman = spearmanr(val_expression, val_pred)

    # store model iff on a cuda environment and if the repo is clean (traceable commit).
    # The files are overwritten every epoch, so they always hold the latest weights.
    if device.type == "cuda" and sha:
        here("results/models").mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), here(f"results/models/{model_name}.pt"))
        torch.save(
            {
                "commit": sha,
                "train_loss": tr_losses,
                "train_pearson": tr_pearson,
                "train_spearman": tr_spearman,
                "val_loss": val_losses,
                "val_pearson": val_pearson,
                "val_spearman": val_spearman,
            },
            here(f"results/models/{model_name}_stats.pt"),
        )

In [None]:
# Append one tab-separated row of final metrics to the results summary; like the checkpoint
# itself, this only happens for CUDA runs from a clean repo.
if device.type == "cuda" and sha:
    row = [
        model_name,
        f"{tr_pearson:.3f}",
        f"{tr_spearman:.3f}",
        f"{val_pearson:.3f}",
        f"{val_spearman:.3f}",
        sha,
    ]
    with open(here("results/models/summary.tsv"), "a") as summary_file:
        summary_file.write("\t".join(row) + "\n")

In [None]:
# Training loss per minibatch, with the validation loss overlaid at the end of each epoch.
# Both series use "minibatches seen so far" as x, so validation point i sits exactly
# after the last minibatch of epoch i.
fig, ax = plt.subplots(1, 1, figsize=(15, 6))
sns.lineplot(x=list(range(1, len(tr_losses) + 1)), y=tr_losses, ax=ax, label="train")
sns.lineplot(
    x=[(i + 1) * len(tr_loader) for i in range(len(val_losses))],
    y=val_losses,
    color="orange",
    ax=ax,
    label="validation",
)
ax.legend()
ax.set(xlabel="Minibatch", ylabel="Loss", title="Training and validation loss");

# Model analysis

In [None]:
# Reload the weights saved by the training loop so the analysis runs on the stored
# checkpoint. A checkpoint is only written for CUDA runs from a clean repo; otherwise keep
# the in-memory model instead of crashing on a missing file (or loading a stale one).
if device.type == "cuda" and sha:
    del model

    model = model_obj().to(device)
    model.load_state_dict(
        torch.load(here(f"results/models/{model_name}.pt"), map_location=device)
    )
model.eval()


def hook_fn(module, input, output):
    """Forward hook that stashes the hooked module's output in the module-level `embedding`.

    Args:
        module: the module the hook is registered on (unused).
        input: the module's positional inputs (unused).
        output: the module's forward output; stored as-is, without copying or detaching.
    """
    globals()["embedding"] = output


# Capture the output of `model.fc[-2]` (via `hook_fn`, into the global `embedding`) for one
# training minibatch and for the whole validation set.
hook = model.fc[-2].register_forward_hook(hook_fn)

try:
    with torch.no_grad():
        # training predictions (a single shuffled minibatch)
        tr_seqs, tr_rc, tr_expression = next(iter(tr_loader))
        tr_pred = model(tr_seqs.to(device), tr_rc.to(device)).cpu()
        tr_embedding = embedding

        # validation predictions
        val_pred = model(val_seqs.to(device), val_rc.to(device)).cpu()
        val_embedding = embedding
finally:
    # Detach the hook (even if a forward pass fails) so later forward passes don't keep
    # overwriting the global `embedding` behind the captured values.
    hook.remove()

In [None]:
def umap(x, y, **sns_kwargs):
    """Project `x` to 2-D with UMAP and draw it as a scatter plot coloured by `y`.

    Args:
        x: 2-D feature array (samples x features), anything `numpify` accepts.
        y: per-sample values used as the colour (`hue`); flattened to 1-D.
        **sns_kwargs: forwarded to `sns.scatterplot` (e.g. `ax=`).

    Returns:
        The Axes returned by `sns.scatterplot`, with UMAP axis labels set.
    """
    features = numpify(x)
    hue_values = numpify(y).flatten()
    projected = UMAP().fit_transform(features)

    # note: (re)applies the global seaborn theme on every call
    sns.set_theme()
    axes = sns.scatterplot(
        x=projected[:, 0], y=projected[:, 1], hue=hue_values, **sns_kwargs
    )
    axes.set(xlabel="UMAP 1", ylabel="UMAP 2")
    return axes


plt.rcParams["figure.figsize"] = 15, 6

# UMAP of the captured embeddings, coloured by the expression targets, one panel per split.
fig, ax = plt.subplots(1, 2)
for panel, (split_embedding, split_expression, split_name) in zip(
    ax,
    [
        (tr_embedding, tr_expression, "Train"),
        (val_embedding, val_expression, "Validation"),
    ],
):
    g = umap(split_embedding, split_expression, ax=panel)
    g.set_title(split_name)

In [None]:
def plt_predictions(y, y_pred, split, axes=None):
    """Scatter-plot predictions against targets for one data split.

    Args:
        y: targets (anything `numpify` accepts); flattened before plotting.
        y_pred: predictions, same number of elements as `y`; flattened before plotting.
        split: panel title, e.g. "Train" or "Validation".
        axes: matplotlib Axes to draw on. If omitted, falls back to the global `ax`
            array: `ax[0]` when `split == "Train"`, otherwise `ax[1]` (the original
            behaviour, kept for backward compatibility).

    Returns:
        The Axes that was drawn on.
    """
    y = numpify(y).flatten()
    y_pred = numpify(y_pred).flatten()
    if axes is None:
        axes = ax[0 if split == "Train" else 1]

    g = sns.scatterplot(x=y, y=y_pred, ax=axes)
    g.set(xlabel="y", ylabel="y_pred")
    g.set_title(split)
    # identity line y_pred == y (axline is infinite; any two diagonal points define it)
    axes.axline([0, 0], [17, 17], color="red")
    return axes


fig, ax = plt.subplots(1, 2)
plt_predictions(tr_expression, tr_pred, "Train", axes=ax[0])
plt_predictions(val_expression, val_pred, "Validation", axes=ax[1]);