In [None]:
import pandas as pd
import seaborn as sns
from sklearn.metrics import r2_score
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import transforms
from tqdm.notebook import tqdm, trange

from data import Vaishnav, ReverseComplement
from models import OneStrandCNN
from models.utils import fix_seeds

In [None]:
class Model(nn.Module):
    """Small 1-D CNN regressor: two conv/pool stages followed by an MLP head.

    Input must be batched as ``(batch, 4, length)``. Each stage keeps the length
    (``padding="same"``) and then halves it with ``MaxPool1d(2)``. The first
    linear layer expects 640 = 32 channels * 20 positions, so ``length`` must
    be 80-83. Output has shape ``(batch, 1)``.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # Conv (length preserved) -> ReLU -> halve the length.
            return [
                nn.Conv1d(in_channels, out_channels, 5, padding="same"),
                nn.ReLU(),
                nn.MaxPool1d(2),
            ]

        # Layer order/indices are kept so state_dict keys stay conv.0 / conv.3.
        self.conv = nn.Sequential(*conv_stage(4, 16), *conv_stage(16, 32))

        # MLP head: 640 -> 256 -> 96 -> 1 (keys fc.0 / fc.2 / fc.4).
        self.fc = nn.Sequential(
            nn.Linear(640, 256), nn.ReLU(),
            nn.Linear(256, 96), nn.ReLU(),
            nn.Linear(96, 1),
        )

    def forward(self, x):
        """Map a ``(batch, 4, length)`` tensor to ``(batch, 1)`` predictions."""
        features = self.conv(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

In [None]:
# Run configuration.
n_epochs = 20
batch_size = 1024
device = "cpu"
seed = 0

data_dir = "../data/dream"

# These files appear to hold pickled dataset objects rather than plain tensors:
# `te` is accessed via attributes (`.sequences`, `.expression`).
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which rejects
# arbitrary objects, so opt out explicitly (the argument needs torch>=1.13).
# SECURITY: weights_only=False unpickles and can execute code -- only load
# files you trust.
tr = torch.load(f"{data_dir}/train.pt", weights_only=False)
tr_loader = DataLoader(tr, batch_size=batch_size, shuffle=True, drop_last=True)
te = torch.load(f"{data_dir}/test.pt", weights_only=False)

In [None]:
fix_seeds(seed)
tr_losses = []  # one MSE value per training batch
te_losses = []  # one MSE value per epoch, over the whole test set

net = Model().to(device)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

for epoch in range(n_epochs):
    net.train()
    # len(DataLoader) is already the number of batches; do not divide by batch_size.
    with tqdm(tr_loader, total=len(tr_loader), unit="batch",
              desc=f"epoch {epoch + 1}/{n_epochs}") as tepoch:
        for seq, y in tepoch:
            seq, y = seq.to(device), y.to(device)

            optimizer.zero_grad()
            y_pred = net(seq)
            # Model outputs (batch, 1). Align the targets to that shape so MSELoss
            # compares element-wise instead of broadcasting (batch,) vs (batch, 1)
            # into a (batch, batch) matrix. This is a no-op if y is already (batch, 1).
            tr_loss = criterion(y_pred, y.reshape(y_pred.shape))
            tr_loss.backward()
            optimizer.step()

            tr_losses.append(tr_loss.item())

            # .cpu() so the progress metric also works when device is not "cpu".
            tepoch.set_postfix(tr_loss=tr_loss.item(),
                               r2=r2_score(y.detach().cpu().numpy().reshape(-1),
                                           y_pred.detach().cpu().numpy().reshape(-1)))

    # Full test-set evaluation once per epoch. no_grad avoids building an
    # autograd graph over the entire test set.
    net.eval()
    with torch.no_grad():
        te_pred = net(te.sequences.to(device))
        te_loss = criterion(te_pred,
                            te.expression.to(device).reshape(te_pred.shape))
    te_losses.append(te_loss.item())

In [None]:
# Targets of the last training batch (`y` is left over from the training loop).
# .cpu() keeps this working when `device` is not "cpu".
y.detach().cpu().numpy()

In [None]:
# Removed a stray expression `d`: no variable `d` is defined in this notebook,
# so the cell raised NameError on a fresh kernel (Restart & Run All).

In [None]:
# R^2 on the last training batch only (`y` / `y_pred` are left over from the
# training loop). .cpu() keeps this working when `device` is not "cpu".
r2_score(y.detach().cpu().numpy(),
         y_pred.detach().cpu().numpy())

In [None]:
# Predict on the full test set. no_grad avoids building an autograd graph;
# .cpu() keeps .numpy() working for any `device`. Both arrays are flattened to 1-D.
net.eval()
with torch.no_grad():
    y_pred = net(te.sequences).cpu().numpy().reshape(-1)
y = te.expression.detach().cpu().numpy().reshape(-1)

ax = sns.scatterplot(x=y, y=y_pred)
ax.set(xlabel="measured expression", ylabel="predicted expression",
       title="Test set: predicted vs measured");

In [None]:
# Coefficient of determination of the current `y_pred` against `y`.
r2_score(y_true=y, y_pred=y_pred)

In [None]:
# Training MSE, one point per batch. x/y must be keyword arguments: seaborn >= 0.12
# makes them keyword-only, and positional use raises TypeError.
ax = sns.lineplot(x=list(range(len(tr_losses))), y=tr_losses)
ax.set(xlabel="training batch", ylabel="MSE loss", title="Training loss");

In [None]:
# Test-set MSE, one point per epoch (0-based). x/y must be keyword arguments:
# seaborn >= 0.12 makes them keyword-only, and positional use raises TypeError.
ax = sns.lineplot(x=list(range(len(te_losses))), y=te_losses)
ax.set(xlabel="epoch", ylabel="MSE loss", title="Test loss");