# Imports

In [None]:
import os
import torch
import pytorch_lightning as pl
from time import time_ns
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from rnaquanet.network.graph_regression_network import GraphRegressionNetwork
from rnaquanet.network.grn_data_module import GRNDataModule
from rnaquanet.utils.rnaquanet_config import RnaquanetConfig
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
from torch.optim import Adam
from torch.nn import (
    BatchNorm1d,
    Identity,
    ReLU,
    LeakyReLU,
    Linear,
    MSELoss
)
import torch.nn.functional as F
from torch_geometric.nn import (
    GATConv,
    GCNConv,
    Sequential,
    global_mean_pool,
    BatchNorm,
)
from torch_geometric.loader import DataLoader
import numpy as np
from rnaquanet.data.preprocessing.hdf5_utils import load_data_from_hdf5
from IPython.display import clear_output
from tqdm import tqdm
import matplotlib.pyplot as plt


In [None]:
# Train and validation splits come from the ARES dataset; the evaluation set is
# read from the rnaquadataset file.
# NOTE(review): the evaluation file is named train.h5 — confirm it is really
# intended to serve as the held-out test set.
train_data, val_data, test_data = (
    load_data_from_hdf5(h5_path)
    for h5_path in (
        '/app/data/ares/train.h5',
        '/app/data/ares/val.h5',
        '/app/data/rnaquadataset/train.h5',
    )
)

In [None]:
torch.set_float32_matmul_precision('high')
torch.manual_seed(2137)
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

batch_size = 100
max_epochs = 30
# Stop once the best validation loss is at least `patience` epochs old.
# NOTE(review): with patience >= max_epochs early stopping can never trigger.
patience = 100

model = Sequential('x, edge_index, edge_attr, batch', [
    (GCNConv(in_channels=96, out_channels=8192), 'x, edge_index -> x'),
    (ReLU(), 'x -> x'),
    (GCNConv(in_channels=8192, out_channels=1), 'x, edge_index -> x'),
    (global_mean_pool, 'x, batch -> x'),
]).to(device)

loss_fn = MSELoss(reduction='mean').to(device)
mse_no_reduce = MSELoss(reduction='none').to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
train_losses_epoch = []
val_losses_epoch = []


def plot_loss_chart():
    """Plot the per-epoch mean train/validation losses collected so far."""
    plt.plot(train_losses_epoch, label='Train Loss', color='blue', alpha=0.5)
    plt.plot(val_losses_epoch, label='Validation Loss', color='red', alpha=0.5)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.tight_layout()
    plt.show()


def get_desc(epoch_index=None):
    """Build the progress-bar description for the 0-based ``epoch_index``.

    When ``epoch_index`` is omitted, the module-level ``epoch`` loop variable is
    used, matching the behaviour of the original in-loop helper.
    """
    if epoch_index is None:
        epoch_index = epoch
    desc = f'Epoch {epoch_index+1},'
    if train_losses_epoch:
        desc += f' Previous training Loss {train_losses_epoch[-1]:.2f}, Best training Loss {np.min(train_losses_epoch):.2f},'
    if val_losses_epoch:
        desc += f' Previous validation Epoch Loss {val_losses_epoch[-1]:.2f}, Best validation Loss {np.min(val_losses_epoch):.2f},'
    return desc.rstrip(',')


def _predict(item):
    """Move a batch to ``device``; return (flattened predictions, flattened targets)."""
    item = item.to(device)
    y_pred = model(
        x=item.x,
        edge_index=item.edge_index,
        edge_attr=item.edge_attr,
        batch=item.batch,
    ).view(-1)
    # Flatten the targets as well: if y is stored as (N, 1), MSELoss would
    # otherwise broadcast (N,) against (N, 1) into an (N, N) error matrix.
    return y_pred, item.y.view(-1)


def _evaluate(dataset, desc):
    """Return the per-sample squared errors of ``model`` on ``dataset`` (no gradients)."""
    model.eval()
    losses = []
    with tqdm(total=len(dataset), desc=desc, position=0, leave=True) as progressbar:
        with torch.no_grad():
            for item in DataLoader(dataset, batch_size=batch_size, shuffle=False):
                y_pred, target = _predict(item)
                losses.extend(mse_no_reduce(y_pred, target).cpu().tolist())
                # .shape is metadata; no need to copy the tensor to the host.
                progressbar.update(target.shape[0])
    return losses


# Weights of the epoch with the lowest validation loss seen so far.
best_state = None
for epoch in range(max_epochs):
    model.train()
    train_losses_step = []
    with tqdm(total=len(train_data), desc=get_desc(epoch), position=0, leave=True) as progressbar:
        for item in DataLoader(train_data, batch_size=batch_size, shuffle=True):
            y_pred, target = _predict(item)
            loss = loss_fn(y_pred, target)
            # Detach for logging so no second autograd graph is built.
            train_losses_step.extend(mse_no_reduce(y_pred.detach(), target).cpu().tolist())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            progressbar.update(target.shape[0])
    train_losses_epoch.append(np.mean(train_losses_step))
    clear_output(wait=True)
    plot_loss_chart()

    val_losses_step = _evaluate(val_data, get_desc(epoch))
    val_losses_epoch.append(np.mean(val_losses_step))

    index = np.argmin(val_losses_epoch)
    if index == epoch:
        # New best validation loss: snapshot the weights so the test set is
        # scored with the best model, not whatever the final epoch produced.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    clear_output(wait=True)
    plot_loss_chart()
    if index <= (epoch - patience):
        print(get_desc(epoch))
        break

if best_state is not None:
    model.load_state_dict(best_state)

test_losses_step = _evaluate(test_data, 'Test')
print(np.mean(test_losses_step))