In [None]:
from random import random
import warnings

from git import Repo
import pandas as pd
from pyprojroot import here
import seaborn as sns
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm, trange

from data import load, Dream, ReverseComplement
from models import SimpleCNN_GELU, fix_seeds
from models.utils import pearsonr, spearmanr


# Open the git repository that contains this notebook (parent directories are
# searched for it).
repo = Repo(search_parent_directories=True)

# Keep the HEAD commit hash only when the repository is clean; a dirty
# repository leaves `sha` as None.
if repo.is_dirty():
    sha = None
else:
    sha = repo.head.object.hexsha

if not sha:
    warnings.warn("Uncommitted changes. The model parameters won't be saved.")

# Model and hyperparameters

In [None]:
# model hyperparameters
model = SimpleCNN_GELU
n_epochs = 10
batch_size = 1024
seed = 0

# use the selected GPU when CUDA is available, otherwise fall back to the CPU
cuda_device = 0
device = torch.device(f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu")

# data transforms
rc_transform = True

# fix the seed up front, before the model is built and the data is split
fix_seeds(seed)

In [None]:
# model specification: build the selected architecture and move it to `device`
net = model()
net = net.to(device)

# the class name is the base of every file name used when saving results
net_name = type(net).__name__

# mean-squared-error objective, optimised with Adam
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

# Data loading

In [None]:
# train and validation
val_size = 10000
dream_dir = here("data/dream")

# a CPU session reads the `train_dev.pt` cache, any other device `train.pt`
tr_cached = "train.pt" if device.type != "cpu" else "train_dev.pt"
tr = load("train_sequences.txt", tr_cached, Dream, path=dream_dir)

# hold out `val_size` randomly chosen examples for validation
n_train = len(tr) - val_size
tr, val = torch.utils.data.random_split(tr, [n_train, val_size])

# pull the whole validation split out as one batch, reused after every epoch
val_loader = DataLoader(val, batch_size=val_size)
val_seqs, val_rc, val_expression = next(iter(val_loader))

# test (unlabelled)
te = load("test_sequences.txt", "test.pt", Dream, path=dream_dir)

In [None]:
# data transformations
tf = []

# Tag the model name with the reverse-complement augmentation exactly once.
# Without the `endswith` guard, every re-run of this cell would append the
# suffix again ("..._t=rc_t=rc") and change the file names used for saving.
rc_suffix = "_t=rc"
if rc_transform and not net_name.endswith(rc_suffix):
    net_name += rc_suffix

# NOTE(review): `tr` is the object returned by `random_split`, so this sets the
# attribute on the split rather than on the underlying `Dream` dataset.
# `tf` is empty here; confirm the attribute is picked up before adding
# transforms to it.
tr.transforms = transforms.Compose(tf)

# shuffled mini-batches; the last incomplete batch of each epoch is dropped
tr_loader = DataLoader(tr, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# # create a smaller training set for development
# # NOTE: one-off helper, kept commented out so that "Run All" does not execute
# # it. When uncommented it re-splits the training data, gathers one shuffled
# # batch of 100000 + val_size examples into a `Dream` object and saves it as
# # `train_dev.pt`, the cache the data-loading cell selects on the CPU.
# tr = load("train_sequences.txt", tr_cached, Dream, path=here("data/dream"))
# tr, val = torch.utils.data.random_split(tr, [len(tr) - val_size, val_size])

# d = Dream([''], [0])
# d.sequences, d.rc_sequences, d.expression = next(iter(DataLoader(tr, batch_size=100000 + val_size, shuffle=True)))
# d.expression = d.expression.flatten()

# torch.save(d, here('data/dream/train_dev.pt'))

# Model training

In [None]:
# Per-batch training losses and per-epoch validation losses.
tr_losses = []
val_losses = []

# Exponential moving averages (decay 0.9) of the per-batch training
# correlations; both start from 0.
tr_pearson = 0
tr_spearman = 0

# Scaler for mixed precision training. Gradient scaling is only enabled on
# CUDA; stating that explicitly avoids constructing an "enabled" scaler in a
# CPU-only session, where it would just warn and disable itself.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

for epoch in range(n_epochs):
    # The validation pass at the end of the previous epoch left the network in
    # eval mode, so switch back to train mode once per epoch.
    net.train()

    with tqdm(tr_loader) as tepoch:
        for seq, rc, y in tepoch:
            seq, rc, y = seq.to(device), rc.to(device), y.to(device)

            # reverse-complement augmentation: swap the two inputs for a
            # random half of the batches
            if rc_transform and random() < 0.5:
                seq, rc = rc, seq

            # forward
            with torch.autocast(device_type=device.type):  # mixed precision
                y_pred = net(seq, rc)
                tr_loss = criterion(y_pred, y)

            # backward (with mixed precision)
            optimizer.zero_grad()
            scaler.scale(tr_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # evaluation
            tr_loss_value = tr_loss.item()
            tr_losses.append(tr_loss_value)

            # `y_pred` is still attached to the autograd graph. Computing the
            # running averages without gradient tracking keeps each step's
            # graph from being chained into `tr_pearson` / `tr_spearman`
            # (history accumulation); the values themselves are unchanged.
            with torch.no_grad():
                tr_pearson = 0.9 * tr_pearson + 0.1 * pearsonr(y, y_pred)
                tr_spearman = 0.9 * tr_spearman + 0.1 * spearmanr(y, y_pred)

            tepoch.set_postfix(
                tr_loss=tr_loss_value,
                r=tr_pearson,
                rho=tr_spearman,
            )

        # validation on the held-out batch, once per epoch
        net.eval()

        with torch.no_grad():
            val_pred = net(val_seqs.to(device), val_rc.to(device)).cpu()
            val_loss = criterion(val_pred, val_expression)
            val_losses.append(val_loss.item())
            val_pearson = pearsonr(val_expression, val_pred)
            val_spearman = spearmanr(val_expression, val_pred)

        # store model iff on a cuda environment and if the repo is clean;
        # the files are overwritten after every epoch
        if device.type == "cuda" and sha:
            torch.save(net.state_dict(), here(f"results/models/{net_name}.pt"))
            torch.save(
                {
                    "commit": sha,
                    "train_loss": tr_losses,
                    "train_pearson": tr_pearson,
                    "train_spearman": tr_spearman,
                    "val_loss": val_losses,
                    "val_pearson": val_pearson,
                    "val_spearman": val_spearman,
                },
                here(f"results/models/{net_name}_stats.pt"),
            )

In [None]:
# Append this run's final metrics to the shared summary table, under the same
# condition used for the model checkpoint (CUDA run on a clean repository).
if device.type == "cuda" and sha:
    summary_path = here("results/models/summary.tsv")
    with open(summary_path, "a") as summary_file:
        # one tab-separated row: name, train r, train rho, val r, val rho, commit
        summary_file.write(
            f"{net_name}\t{tr_pearson}\t{tr_spearman}\t{val_pearson}\t{val_spearman}\t{sha}\n"
        )

In [None]:
# Training loss, one point per batch. x and y are passed by keyword because
# seaborn >= 0.12 accepts only `data` positionally; the positional form raises
# a TypeError there.
sns.lineplot(x=list(range(len(tr_losses))), y=tr_losses).set(
    title="Train loss", xlabel="batch", ylabel="MSE loss"
)

In [None]:
# Validation loss, one point per epoch. x and y are passed by keyword because
# seaborn >= 0.12 accepts only `data` positionally; the positional form raises
# a TypeError there.
sns.lineplot(x=list(range(len(val_losses))), y=val_losses).set(
    title="Validation loss", xlabel="epoch", ylabel="MSE loss"
)

# Model analysis

In [None]:
# Discard the in-memory network and rebuild it from the saved checkpoint, so the
# analysis below runs on exactly what was written to disk.
# NOTE: the training cell only writes this checkpoint on a CUDA device with a
# clean repository; otherwise the file may be missing or left from an earlier run.
del net

net = model().to(device)

# `map_location` remaps the stored tensors onto the current device, so a
# checkpoint written on the GPU can also be loaded in a CPU-only session.
state_dict = torch.load(here(f"results/models/{net_name}.pt"), map_location=device)
net.load_state_dict(state_dict)
net.eval()

with torch.no_grad():
    # training predictions (a single batch drawn from the training loader)
    tr_seqs, tr_rc, tr_expression = next(iter(tr_loader))
    tr_pred = net(tr_seqs.to(device), tr_rc.to(device)).cpu()

    # validation predictions (the full held-out batch)
    val_pred = net(val_seqs.to(device), val_rc.to(device)).cpu()

In [None]:
# Measured (x) against predicted (y) expression for one training batch.
tr_observed = tr_expression.detach().numpy().flatten()
tr_predicted = tr_pred.detach().numpy().flatten()

tr_axes = sns.scatterplot(x=tr_observed, y=tr_predicted)
tr_axes.set(title="Train predictions")

In [None]:
# Measured (x) against predicted (y) expression for the validation split.
val_observed = val_expression.detach().numpy().flatten()
val_predicted = val_pred.detach().numpy().flatten()

val_axes = sns.scatterplot(x=val_observed, y=val_predicted)
val_axes.set(title="Validation predictions")