In [None]:
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm, trange

from data import load, Dream
from models import BaselineCNN, SimpleCNN, fix_seeds

In [None]:
# Run configuration: training length, batch size, compute device and RNG seed.
n_epochs = 10
batch_size = 1024
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
seed = 0

# On CPU, load the smaller "*_dev.pt" caches; on GPU, load the full caches.
# Compare `device.type` (a str) rather than the `torch.device` object itself:
# a `torch.device` does not compare equal to the string "cpu".
tr_cached = "train_dev.pt" if device.type == "cpu" else "train.pt"
tr = load("train_sequences.txt", tr_cached, Dream, path="../data/dream")

# drop_last=True keeps every training batch at exactly `batch_size` samples.
tr_loader = DataLoader(tr, batch_size=batch_size, shuffle=True, drop_last=True)

te_cached = "test_dev.pt" if device.type == "cpu" else "test.pt"
te = load("test_sequences.txt", te_cached, Dream, path="../data/dream")

In [None]:
# Train SimpleCNN with MSE loss and Adam. The loop records the training loss
# for every batch and the full-test-set loss once per epoch. After every epoch
# it overwrites the weights and loss-history checkpoints under ../results/models/.
fix_seeds(seed)
tr_losses = []  # one entry per training batch
te_losses = []  # one entry per epoch

net = SimpleCNN().to(device)
net_name = net.__class__.__name__

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

for epoch in range(n_epochs):
    # Switch back to training mode once per epoch; the evaluation at the end of
    # the previous epoch left the network in eval mode.
    net.train()

    # len(tr_loader) is already the number of batches per epoch.
    with tqdm(
        tr_loader,
        total=len(tr_loader),
        unit="batch",
        desc=f"epoch {epoch + 1}/{n_epochs}",
    ) as tepoch:
        for seq, _rc, y in tepoch:
            # The third loader field (`_rc`) is not fed to this model, so it is
            # not copied to the device.
            seq, y = seq.to(device), y.to(device)

            optimizer.zero_grad()
            y_pred = net(seq)
            tr_loss = criterion(y_pred, y)
            tr_loss.backward()
            optimizer.step()

            tr_losses.append(tr_loss.item())

            # Flatten both arrays so Pearson and Spearman see the same 1-D inputs.
            y = y.detach().cpu().numpy().ravel()
            y_pred = y_pred.detach().cpu().numpy().ravel()

            tepoch.set_postfix(
                tr_loss=tr_losses[-1],
                r=pearsonr(y, y_pred)[0],
                rho=spearmanr(y, y_pred)[0],
            )

    # Evaluate on the whole test set in a single forward pass.
    net.eval()
    with torch.no_grad():
        te_pred = net(te.sequences.to(device)).cpu()
        te_loss = criterion(te_pred, te.expression[None, :].T)
    te_losses.append(te_loss.item())

    # Checkpoint after every epoch (overwrites the previous epoch's files).
    torch.save(net.state_dict(), f"../results/models/{net_name}.pt")
    torch.save(
        {"train_loss": tr_losses, "val_loss": te_losses, "val_pred": te_pred},
        f"../results/models/{net_name}_stats.pt",
    )

In [None]:
# Predicted vs. observed expression for the first 50,000 training sequences.
n_plot = 50000

net.eval()
with torch.no_grad():
    # Inputs must be on the same device as the model, or the forward pass fails on CUDA.
    y_pred = net(tr.sequences[:n_plot].to(device)).cpu().numpy().T[0]
y = tr.expression[:n_plot].detach().numpy().T

ax = sns.scatterplot(x=y, y=y_pred)
ax.set(xlabel="observed expression", ylabel="predicted expression");

In [None]:
# Training MSE per batch; x/y are keyword-only in seaborn >= 0.12.
ax = sns.lineplot(x=list(range(len(tr_losses))), y=tr_losses)
ax.set(xlabel="training batch", ylabel="train MSE loss");

In [None]:
# Test-set MSE, one point per epoch; x/y are keyword-only in seaborn >= 0.12.
ax = sns.lineplot(x=list(range(len(te_losses))), y=te_losses)
ax.set(xlabel="epoch", ylabel="test MSE loss");

In [None]:
# Per-batch training losses as a tensor of the default floating dtype.
torch.tensor(tr_losses, dtype=torch.get_default_dtype())