In [1]:
import os
import torch
import pytorch_lightning as pl
from time import time_ns
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from rnaquanet.network.graph_regression_network import GraphRegressionNetwork
from rnaquanet.network.grn_data_module import GRNDataModule
from rnaquanet.utils.rnaquanet_config import RnaquanetConfig
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
from torch.optim import Adam
from torch.nn import (
    BatchNorm1d,
    Identity,
    ReLU,
    LeakyReLU,
    Linear,
    MSELoss
)
import torch.nn.functional as F
from torch_geometric.nn import (
    GATConv,
    GCNConv,
    Sequential,
    global_mean_pool,
    BatchNorm,
    TransformerConv
)
from torch_geometric.nn.models import (
    GAT
)
from torch_geometric.loader import DataLoader
import numpy as np
from rnaquanet.data.preprocessing.hdf5_utils import load_data_from_hdf5
from IPython.display import clear_output
from tqdm import tqdm
import matplotlib.pyplot as plt
from rnaquanet.network.h5_graph_dataset import H5GraphDataset
import math
import pandas as pd

In [2]:
def get_empty_model(key: str) -> Sequential:
    conv = None
    first_dropout = None
    if key == 'ares' or key == 'seg1':
        conv = GATConv
        first_dropout = 0.5
    elif key == 'seg2' or key == 'transfer_seg2_ares':
        conv = TransformerConv
        first_dropout = 0.5
    elif key == 'seg3':
        conv = TransformerConv
        first_dropout = 0.8
    else:
        raise 'Unknown model key'


    return Sequential('x, edge_index, edge_attr, batch', [
        (conv(in_channels=99, out_channels=256, heads=4, dropout=first_dropout), f'x, edge_index{", edge_attr" if conv == GATConv else ""} -> x'),
        (BatchNorm(in_channels=256*4), 'x -> x'),
        (ReLU(), 'x -> x'),
        
        (conv(in_channels=256*4, out_channels=256, heads=8, dropout=0.5), f'x, edge_index{", edge_attr" if conv == GATConv else ""} -> x'),
        (BatchNorm(in_channels=256*8), 'x -> x'),
        (ReLU(), 'x -> x'),

        (GCNConv(in_channels=256*8, out_channels=256), 'x, edge_index -> x'),
        (global_mean_pool, 'x, batch -> x'),

        (Linear(in_features=256, out_features=64), 'x -> x'),
        (ReLU(), 'x -> x'),
        (Linear(in_features=64, out_features=1), 'x -> x'),
    ])

def get_model(key: str) -> Sequential:
    model = get_empty_model(key)
    model.load_state_dict(torch.load(os.path.join('..', 'results', 'model', f'{key}.pt')))

    if key == 'transfer_seg2_ares':
        # freeze all layers except MLP
        for child in model.children():
            if type(child) != Linear:
                for param in child.parameters():
                    param.requires_grad = False
                    
    return model

In [3]:
ares_test_data = H5GraphDataset('/app/data/ares/test.h5', return_key=True).__enter__()
rnaquadataset_test_data = H5GraphDataset('/app/data/rnaquadataset/test.h5', return_key=True).__enter__()
seg1_test_data = H5GraphDataset('/app/data/segments_1_coords/test.h5', return_key=True).__enter__()
seg2_test_data = H5GraphDataset('/app/data/segments_2_coords/test.h5', return_key=True).__enter__()
seg3_test_data = H5GraphDataset('/app/data/segments_3_coords/test.h5', return_key=True).__enter__()

dfs1_data = H5GraphDataset('/app/data/dfs/dfs1.h5', return_key=True).__enter__()
dfs2_data = H5GraphDataset('/app/data/dfs/dfs2.h5', return_key=True).__enter__()
dfs3_data = H5GraphDataset('/app/data/dfs/dfs3.h5', return_key=True).__enter__()

In [4]:
torch.set_float32_matmul_precision('high')
device = torch.device('cuda:0')

In [5]:
df = pd.DataFrame({
    'training_dataset': [], # zbiór na którym był trenowany model
    'evaluation_dataset': [], # zbiór na którym weryfikowane są wyniki
    'structure': [], # klucz z pliku H5, nazwa struktury ze zbioru ewaluacyjnego
    'real_rmsd': [], # prawdziwe RMSD
    'predicted_rmsd': [], # przewidziane RMSD przez model
})

batch_size = 64

for key in ['ares', 'seg1', 'seg2', 'seg3', 'transfer_seg2_ares']:
    print(f'Training dataset: {key}')
    model = get_model(key)
    model = model.to(device)
    with torch.no_grad():
        for label, test_data in [
            # ('ares test', ares_test_data),
            # ('rnaquadataset test', rnaquadataset_test_data),
            # ('seg1 test', seg1_test_data),
            # ('seg2 test', seg2_test_data),
            # ('seg3 test', seg3_test_data),
            ('dfs1', dfs1_data),
            ('dfs2', dfs2_data),
            ('dfs3', dfs3_data),
        ]:
            print(f'Evaluation dataset: {label}')
            for keys, item in DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4):
                item = item.to(device)
                y_pred = model(x=item.x, edge_index=item.edge_index, edge_attr=item.edge_attr, batch=item.batch).view(-1)
                df = pd.concat([df, pd.DataFrame({
                    'training_dataset': [key]*len(keys),
                    'evaluation_dataset': [label]*len(keys),
                    'structure': keys,
                    'real_rmsd': item.y.cpu().tolist(),
                    'predicted_rmsd': y_pred.cpu().tolist()
                })])

filename = f'eval{time_ns()}.csv'
df.to_csv(filename)
df['abs rmsd difference'] = np.abs(df['real_rmsd'] - df['predicted_rmsd'])
df['squared rmsd difference'] = df['abs rmsd difference'] ** 2
df.to_csv(filename)

display(df)


Training dataset: ares
Evaluation dataset: dfs1
Evaluation dataset: dfs2
Evaluation dataset: dfs3
Training dataset: seg1
Evaluation dataset: dfs1
Evaluation dataset: dfs2
Evaluation dataset: dfs3
Training dataset: seg2
Evaluation dataset: dfs1
Evaluation dataset: dfs2
Evaluation dataset: dfs3
Training dataset: seg3
Evaluation dataset: dfs1
Evaluation dataset: dfs2
Evaluation dataset: dfs3
Training dataset: transfer_seg2_ares
Evaluation dataset: dfs1
Evaluation dataset: dfs2
Evaluation dataset: dfs3


Unnamed: 0,training_dataset,evaluation_dataset,structure,real_rmsd,predicted_rmsd,abs rmsd difference,squared rmsd difference
0,ares,dfs1,1A4D_1_A_B_A_85_G_1a4d_S_000163_minimize_001,6.649,8.385386,1.736386,3.015037
1,ares,dfs1,1A4D_1_A_B_A_85_G_1a4d_S_000262_minimize_009,4.669,7.660735,2.991735,8.950475
2,ares,dfs1,1A4D_1_A_B_A_85_G_1a4d_S_000319_minimize_005,5.080,6.307195,1.227195,1.506008
3,ares,dfs1,1A4D_1_A_B_A_85_G_1a4d_S_002845_minimize_003,3.167,8.741061,5.574061,31.070158
4,ares,dfs1,1A4D_1_A_B_A_86_G_1a4d_S_000174_minimize_004,3.677,7.288726,3.611726,13.044563
...,...,...,...,...,...,...,...
6,transfer_seg2_ares,dfs3,1I9X_1_A_B_B_6_A_1i9x_S_003526_minimize_009,10.076,7.939280,2.136720,4.565573
7,transfer_seg2_ares,dfs3,1I9X_1_A_B_B_6_A_1i9x_S_003543_minimize_001,8.317,8.757564,0.440563,0.194096
8,transfer_seg2_ares,dfs3,1I9X_1_A_B_B_6_A_1i9x_S_003842_minimize_002,9.094,8.284382,0.809618,0.655481
9,transfer_seg2_ares,dfs3,1I9X_1_A_B_B_6_A_1i9x_S_003863_minimize_004,8.446,7.710791,0.735209,0.540533


In [6]:
# df = pd.read_csv('')

output = pd.DataFrame({
    'training_dataset': [],
    'evaluation_dataset': [],
    'mse': [],
    'mae': [],
})

t = time_ns()
for training_dataset in df['training_dataset'].unique():
    training_df = df[df['training_dataset'] == training_dataset]
    for evaluation_dataset in training_df['evaluation_dataset'].unique():
        evaluation_df = training_df[training_df['evaluation_dataset'] == evaluation_dataset]
        title = f'RMSD difference on {evaluation_dataset} - trained on {training_dataset}'
        plt.hist(evaluation_df["abs rmsd difference"])
        plt.title(title)
        plt.xlabel('RMSD')
        plt.savefig(os.path.join('histograms', f'{title.lower().replace(" ", "_")}{t}.png'))
        plt.close()

        output.loc[len(output)] = {
            'training_dataset': training_dataset,
            'evaluation_dataset': evaluation_dataset,
            'mse': evaluation_df['squared rmsd difference'].mean(),
            'mae': evaluation_df['abs rmsd difference'].mean()
        }
     
display(output)
output.to_csv(f'mse_mae{t}.csv')

Unnamed: 0,training_dataset,evaluation_dataset,mse,mae
0,ares,dfs1,13.596899,2.977006
1,ares,dfs2,7.398915,2.114517
2,ares,dfs3,7.625786,2.295293
3,seg1,dfs1,3.273859,1.471078
4,seg1,dfs2,15.56874,3.076201
5,seg1,dfs3,13.719471,3.136061
6,seg2,dfs1,296.287489,6.953614
7,seg2,dfs2,197.740394,10.269148
8,seg2,dfs3,128.396517,11.099243
9,seg3,dfs1,260.326765,7.415658
