In [None]:
import warnings

from git import Repo
import pandas as pd
import seaborn as sns
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm, trange

from data import load, Dream
from models import VaishnavCNN, SimpleCNN, fix_seeds
from models.utils import pearsonr, spearmanr


# Open the enclosing git repository, searching parent directories for it.
repo = Repo(search_parent_directories=True)

# Keep the HEAD commit hash only when the working tree is clean. A dirty tree
# leaves `sha` as None, and the training cell skips checkpointing in that case.
if repo.is_dirty():
    sha = None
else:
    sha = repo.head.object.hexsha

if not sha:
    warnings.warn("Uncommitted changes. The model parameters won't be saved.")

# Model and hyperparameters

In [None]:
# model hyperparameters
n_epochs = 10
batch_size = 1024
seed = 0

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# fix_seeds runs before the network is constructed (presumably so that weight
# initialisation is reproducible -- confirm against models.fix_seeds).
fix_seeds(seed)

# model specification
net = SimpleCNN()
net = net.to(device)

# Mean-squared-error objective, optimised with Adam at a fixed learning rate.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# Data loading

In [None]:
# train and validation
val_size = 10000  # number of samples held out of the training file for validation

# A different cache file is used on CPU ("train_dev.pt") than on GPU ("train.pt").
tr_cached = "train_dev.pt" if device.type == "cpu" else "train.pt"
tr = load("train_sequences.txt", tr_cached, Dream, path="../data/dream")

# Split with a dedicated, explicitly seeded generator so that the held-out set
# does not depend on how many random numbers were drawn before this cell ran
# (e.g. by a different model's weight initialisation) and stays identical when
# the cell is re-executed.
split_rng = torch.Generator().manual_seed(seed)
tr, val = torch.utils.data.random_split(
    tr, [len(tr) - val_size, val_size], generator=split_rng
)

# Training batches are shuffled and a trailing partial batch is dropped; the
# validation loader yields the whole validation split as a single batch.
tr_loader = DataLoader(tr, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val, batch_size=val_size)

# test (unlabelled)
te = load("test_sequences.txt", "test.pt", Dream, path="../data/dream")

In [None]:
# Create a smaller training set for development (kept commented out; it writes the train_dev.pt cache that is loaded on CPU)
# tr = load("train_sequences.txt", tr_cached, Dream, path="../data/dream")
# tr, val = torch.utils.data.random_split(tr, [len(tr) - val_size, val_size])

# d = Dream([''], [0])
# d.sequences, d.rc_sequences, d.expression = next(iter(DataLoader(tr, batch_size=100000 + val_size, shuffle=True)))
# d.expression = d.expression.flatten()

# torch.save(d, '../data/dream/train_dev.pt')

# Model training

In [None]:
# Training loss is recorded once per batch, validation loss once per epoch.
tr_losses = []
val_losses = []

# val_loader was built with batch_size=val_size, so its first batch is the
# whole validation split.
val_seqs, val_rc, val_expression = next(iter(val_loader))
# val_expression = val_expression[None, :].T

for epoch in range(n_epochs):

    # The validation step below switches the network to eval mode, so training
    # mode is restored once per epoch instead of on every batch.
    net.train()

    with tqdm(tr_loader) as tepoch:
        for seq, rc, y in tepoch:

            # NOTE(review): rc is moved to the device but is not passed to the
            # network here -- confirm whether it is still needed.
            seq, rc, y = seq.to(device), rc.to(device), y.to(device)

            optimizer.zero_grad()
            y_pred = net(seq)
            # NOTE(review): assumes y_pred and y have the same shape; MSELoss
            # would otherwise broadcast them -- confirm against model/dataset.
            tr_loss = criterion(y_pred, y)
            tr_loss.backward()
            optimizer.step()

            # Extract the Python scalar once and reuse it for both the history
            # and the progress bar.
            batch_loss = tr_loss.item()
            tr_losses.append(batch_loss)

            # The correlations are only displayed, so compute them on detached
            # predictions to keep them out of the autograd graph.
            y_shown = y_pred.detach()
            tepoch.set_postfix(
                tr_loss=batch_loss,
                r=pearsonr(y, y_shown),
                rho=spearmanr(y, y_shown),
            )

        net.eval()

        # Evaluate on the full validation batch without tracking gradients; the
        # predictions are moved back to the CPU, where val_expression lives.
        with torch.no_grad():
            val_pred = net(val_seqs.to(device)).cpu()
            val_loss = criterion(val_pred, val_expression)
            val_losses.append(val_loss.item())

        # store model iff on a cuda environment and if the repo is clean
        # (files are overwritten after every epoch, so they hold the latest state)
        if device.type == "cuda" and sha:
            net_name = net.__class__.__name__
            torch.save(net.state_dict(), f"../results/models/{net_name}.pt")
            torch.save(
                {
                    "commit": sha,
                    "train_loss": tr_losses,
                    "val_loss": val_losses,
                    "val_r": pearsonr(val_expression, val_pred),
                    "val_rho": spearmanr(val_expression, val_pred),
                },
                f"../results/models/{net_name}_stats.pt",
            )

# Model analysis

In [None]:
# Validation targets against the predictions of the last completed epoch.
# Both tensors are converted to flat NumPy arrays so that x and y reach
# seaborn in the same form (previously y was passed as a raw torch tensor).
ax = sns.scatterplot(
    x=val_expression.detach().numpy().flatten(),
    y=val_pred.detach().numpy().flatten(),
)
ax.set(xlabel="validation expression", ylabel="predicted expression");

In [None]:
# Training loss, one point per batch. x and y are passed by keyword because
# seaborn >= 0.12 accepts only `data` positionally.
ax = sns.lineplot(x=list(range(len(tr_losses))), y=tr_losses)
ax.set(xlabel="training step (batch)", ylabel="training loss (MSE)");

In [None]:
# Validation loss, one point per epoch. x and y are passed by keyword because
# seaborn >= 0.12 accepts only `data` positionally.
ax = sns.lineplot(x=list(range(len(val_losses))), y=val_losses)
ax.set(xlabel="epoch", ylabel="validation loss (MSE)");