In [1]:
import os
import torch
import pytorch_lightning as pl
from time import time_ns
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from rnaquanet.network.graph_regression_network import GraphRegressionNetwork
from rnaquanet.network.grn_data_module import GRNDataModule
from rnaquanet.utils.rnaquanet_config import RnaquanetConfig
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
from torch.optim import Adam
from torch.nn import (
    BatchNorm1d,
    Identity,
    ReLU,
    LeakyReLU,
    Linear,
    MSELoss
)
import torch.nn.functional as F
from torch_geometric.nn import (
    GATConv,
    GCNConv,
    Sequential,
    global_mean_pool,
    BatchNorm,
    TransformerConv
)
from torch_geometric.nn.models import (
    GAT
)
from torch_geometric.loader import DataLoader
import numpy as np
from rnaquanet.data.preprocessing.hdf5_utils import load_data_from_hdf5
from IPython.display import clear_output
from tqdm import tqdm
import matplotlib.pyplot as plt
from rnaquanet.network.h5_graph_dataset import H5GraphDataset
import math
import pandas as pd

In [2]:
def get_empty_model(key: str) -> Sequential:
    conv = None
    first_dropout = None
    if key == 'ares' or key == 'seg1':
        conv = GATConv
        first_dropout = 0.5
    elif key == 'seg2' or key == 'transfer_seg2_ares':
        conv = TransformerConv
        first_dropout = 0.5
    elif key == 'seg3':
        conv = TransformerConv
        first_dropout = 0.8
    else:
        raise 'Unknown model key'


    return Sequential('x, edge_index, edge_attr, batch', [
        (conv(in_channels=99, out_channels=256, heads=4, dropout=first_dropout), f'x, edge_index{", edge_attr" if conv == GATConv else ""} -> x'),
        (BatchNorm(in_channels=256*4), 'x -> x'),
        (ReLU(), 'x -> x'),
        
        (conv(in_channels=256*4, out_channels=256, heads=8, dropout=0.5), f'x, edge_index{", edge_attr" if conv == GATConv else ""} -> x'),
        (BatchNorm(in_channels=256*8), 'x -> x'),
        (ReLU(), 'x -> x'),

        (GCNConv(in_channels=256*8, out_channels=256), 'x, edge_index -> x'),
        (global_mean_pool, 'x, batch -> x'),

        (Linear(in_features=256, out_features=64), 'x -> x'),
        (ReLU(), 'x -> x'),
        (Linear(in_features=64, out_features=1), 'x -> x'),
    ])

def get_model(key: str) -> Sequential:
    model = get_empty_model(key)
    model.load_state_dict(torch.load(os.path.join('..', 'results', 'model', f'{key}.pt')))

    if key == 'transfer_seg2_ares':
        # freeze all layers except MLP
        for child in model.children():
            if type(child) != Linear:
                for param in child.parameters():
                    param.requires_grad = False
                    
    return model

In [3]:
ares_test_data = H5GraphDataset('/app/data/ares/test.h5', return_key=True).__enter__()
rnaquadataset_test_data = H5GraphDataset('/app/data/rnaquadataset/test.h5', return_key=True).__enter__()
seg1_test_data = H5GraphDataset('/app/data/segments_1_coords/test.h5', return_key=True).__enter__()
seg2_test_data = H5GraphDataset('/app/data/segments_2_coords/test.h5', return_key=True).__enter__()
seg3_test_data = H5GraphDataset('/app/data/segments_3_coords/test.h5', return_key=True).__enter__()

ares_val_data = H5GraphDataset('/app/data/ares/val.h5', return_key=True).__enter__()
rnaquadataset_val_data = H5GraphDataset('/app/data/rnaquadataset/val.h5', return_key=True).__enter__()
seg1_val_data = H5GraphDataset('/app/data/segments_1_coords/val.h5', return_key=True).__enter__()
seg2_val_data = H5GraphDataset('/app/data/segments_2_coords/val.h5', return_key=True).__enter__()
seg3_val_data = H5GraphDataset('/app/data/segments_3_coords/val.h5', return_key=True).__enter__()

dfs1_data = H5GraphDataset('/app/data/dfs/dfs1.h5', return_key=True).__enter__()
dfs2_data = H5GraphDataset('/app/data/dfs/dfs2.h5', return_key=True).__enter__()
dfs3_data = H5GraphDataset('/app/data/dfs/dfs3.h5', return_key=True).__enter__()

In [4]:
torch.set_float32_matmul_precision('high')
device = torch.device('cuda:0')

In [5]:
df = pd.DataFrame({
    'training_dataset': [], # zbiór na którym był trenowany model
    'evaluation_dataset': [], # zbiór na którym weryfikowane są wyniki
    'structure': [], # klucz z pliku H5, nazwa struktury ze zbioru ewaluacyjnego
    'real_rmsd': [], # prawdziwe RMSD
    'predicted_rmsd': [], # przewidziane RMSD przez model
})

batch_size = 64

for key in ['ares', 'seg1', 'seg2', 'seg3', 'transfer_seg2_ares']:
    print(f'Training dataset: {key}')
    model = get_model(key)
    model = model.to(device)
    with torch.no_grad():
        for label, test_data in [
            ('ares val', ares_val_data),
            ('ares test', ares_test_data),
            # ('rnaquadataset val', rnaquadataset_val_data),
            # ('rnaquadataset test', rnaquadataset_test_data),
            ('seg1 val', seg1_val_data),
            ('seg1 test', seg1_test_data),
            ('seg2 val', seg2_val_data),
            ('seg2 test', seg2_test_data),
            ('seg3 val', seg3_val_data),
            ('seg3 test', seg3_test_data),
            # ('dfs1', dfs1_data),
            # ('dfs2', dfs2_data),
            # ('dfs3', dfs3_data),
        ]:
            print(f'Evaluation dataset: {label}')
            for keys, item in DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4):
                item = item.to(device)
                y_pred = model(x=item.x, edge_index=item.edge_index, edge_attr=item.edge_attr, batch=item.batch).view(-1)
                df = pd.concat([df, pd.DataFrame({
                    'training_dataset': [key]*len(keys),
                    'evaluation_dataset': [label]*len(keys),
                    'structure': keys,
                    'real_rmsd': item.y.cpu().tolist(),
                    'predicted_rmsd': y_pred.cpu().tolist()
                })])

filename = f'eval{time_ns()}.csv'
df.to_csv(filename)
df['squared rmsd difference'] = (df['real_rmsd'] - df['predicted_rmsd']) ** 2
df.to_csv(filename)

display(df)


Training dataset: ares
Evaluation dataset: ares val
Evaluation dataset: ares test
Evaluation dataset: seg1 val
Evaluation dataset: seg1 test
Evaluation dataset: seg2 val
Evaluation dataset: seg2 test
Evaluation dataset: seg3 val
Evaluation dataset: seg3 test
Training dataset: seg1
Evaluation dataset: ares val
Evaluation dataset: ares test
Evaluation dataset: seg1 val
Evaluation dataset: seg1 test
Evaluation dataset: seg2 val
Evaluation dataset: seg2 test
Evaluation dataset: seg3 val
Evaluation dataset: seg3 test
Training dataset: seg2
Evaluation dataset: ares val
Evaluation dataset: ares test
Evaluation dataset: seg1 val
Evaluation dataset: seg1 test
Evaluation dataset: seg2 val
Evaluation dataset: seg2 test
Evaluation dataset: seg3 val
Evaluation dataset: seg3 test
Training dataset: seg3
Evaluation dataset: ares val
Evaluation dataset: ares test
Evaluation dataset: seg1 val
Evaluation dataset: seg1 test
Evaluation dataset: seg2 val
Evaluation dataset: seg2 test
Evaluation dataset: seg

Unnamed: 0,training_dataset,evaluation_dataset,structure,real_rmsd,predicted_rmsd,squared rmsd difference
0,ares,ares val,157d_S_000004_minimize_001,6.463,6.936740,0.224430
1,ares,ares val,157d_S_000005_minimize_006,12.900,9.354218,12.572564
2,ares,ares val,157d_S_000010_minimize_007,5.600,6.185519,0.342833
3,ares,ares val,157d_S_000014_minimize_001,4.786,7.922315,9.836474
4,ares,ares val,157d_S_000014_minimize_006,8.363,7.841039,0.272444
...,...,...,...,...,...,...
60,transfer_seg2_ares,seg3 test,rs_8GLP_1_L7_A_9_C_8GLP_1_L7_5,1.244,9.900917,74.942214
61,transfer_seg2_ares,seg3 test,rs_8GLP_1_L7_A_9_C_8GLP_1_L7_6,1.033,12.863400,139.958375
62,transfer_seg2_ares,seg3 test,rs_8GLP_1_L7_A_9_C_8GLP_1_L7_7,1.296,11.193987,97.970144
63,transfer_seg2_ares,seg3 test,rs_8GLP_1_L7_A_9_C_8GLP_1_L7_8,1.028,16.551069,240.965679


In [8]:
ns = time_ns()
filename = f'loss{ns}.csv'
loss = df.groupby(['training_dataset', 'evaluation_dataset'])['squared rmsd difference'].mean().reset_index()
loss.to_csv(filename)
display(loss)

for training_dataset in loss['training_dataset'].unique():
    training_df = loss[loss['training_dataset'] == training_dataset]
    plt.figure(figsize=(10,6))
    rects = plt.bar(training_df["evaluation_dataset"], training_df["squared rmsd difference"])
    plt.bar_label(rects, fmt='%0.1f', padding=3)
    plt.ylabel('Average Loss (MSE)')
    plt.savefig(os.path.join('losses', f'{training_dataset}{ns}.png'))
    plt.close()


Unnamed: 0,training_dataset,evaluation_dataset,squared rmsd difference
0,ares,ares test,13.387998
1,ares,ares val,7.025
2,ares,seg1 test,21.18042
3,ares,seg1 val,19.365436
4,ares,seg2 test,35.531486
5,ares,seg2 val,34.786832
6,ares,seg3 test,65.203619
7,ares,seg3 val,60.689485
8,seg1,ares test,16.991443
9,seg1,ares val,12.044998


In [20]:
ns = time_ns()
grouped = df.groupby(['training_dataset', 'evaluation_dataset'])

for (training, evaluation), group in grouped:
    plt.boxplot([group['predicted_rmsd'], group['real_rmsd']], labels=['Predicted', 'Ground Truth'])
    plt.title(f"{evaluation.replace('ares', 'ARES').replace('val', 'validation')} set")
    plt.ylabel('RMSD score')
    plt.savefig(os.path.join('values', f'{training}_{evaluation.replace(" ", "_")}.png'))
    plt.close()