In [1]:
import os
import torch
import pytorch_lightning as pl
from time import time_ns
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from rnaquanet.network.graph_regression_network import GraphRegressionNetwork
from rnaquanet.network.grn_data_module import GRNDataModule
from rnaquanet.utils.rnaquanet_config import RnaquanetConfig
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
from torch.optim import Adam
from torch.nn import (
    BatchNorm1d,
    Identity,
    ReLU,
    LeakyReLU,
    Linear,
    MSELoss
)
import torch.nn.functional as F
from torch_geometric.nn import (
    GATConv,
    GCNConv,
    Sequential,
    global_mean_pool,
    BatchNorm,
    TransformerConv
)
from torch_geometric.nn.models import (
    GAT
)
from torch_geometric.loader import DataLoader
import numpy as np
from rnaquanet.data.preprocessing.hdf5_utils import load_data_from_hdf5
from IPython.display import clear_output
from tqdm import tqdm
import matplotlib.pyplot as plt
from rnaquanet.network.h5_graph_dataset import H5GraphDataset
import math
import pandas as pd

In [2]:
def get_empty_model(key: str) -> Sequential:
    conv = None
    first_dropout = None
    if key == 'ares' or key == 'seg1':
        conv = GATConv
        first_dropout = 0.5
    elif key == 'seg2' or key == 'transfer_seg2_ares':
        conv = TransformerConv
        first_dropout = 0.5
    elif key == 'seg3':
        conv = TransformerConv
        first_dropout = 0.8
    else:
        raise 'Unknown model key'


    return Sequential('x, edge_index, edge_attr, batch', [
        (conv(in_channels=99, out_channels=256, heads=4, dropout=first_dropout), f'x, edge_index{", edge_attr" if conv == GATConv else ""} -> x'),
        (BatchNorm(in_channels=256*4), 'x -> x'),
        (ReLU(), 'x -> x'),
        
        (conv(in_channels=256*4, out_channels=256, heads=8, dropout=0.5), f'x, edge_index{", edge_attr" if conv == GATConv else ""} -> x'),
        (BatchNorm(in_channels=256*8), 'x -> x'),
        (ReLU(), 'x -> x'),

        (GCNConv(in_channels=256*8, out_channels=256), 'x, edge_index -> x'),
        (global_mean_pool, 'x, batch -> x'),

        (Linear(in_features=256, out_features=64), 'x -> x'),
        (ReLU(), 'x -> x'),
        (Linear(in_features=64, out_features=1), 'x -> x'),
    ])

def get_model(key: str) -> Sequential:
    model = get_empty_model(key)
    model.load_state_dict(torch.load(os.path.join('..', 'results', 'model', f'{key}.pt')))

    if key == 'transfer_seg2_ares':
        # freeze all layers except MLP
        for child in model.children():
            if type(child) != Linear:
                for param in child.parameters():
                    param.requires_grad = False
                    
    return model

In [3]:
ares_test_data = H5GraphDataset('/app/data/ares/test.h5', return_key=True).__enter__()
rnaquadataset_test_data = H5GraphDataset('/app/data/rnaquadataset/test.h5', return_key=True).__enter__()

In [4]:
torch.set_float32_matmul_precision('high')
device = torch.device('cuda:0')

In [5]:
df = pd.DataFrame({
    'training_dataset': [], # zbiór na którym był trenowany model
    'evaluation_dataset': [], # zbiór na którym weryfikowane są wyniki
    'structure': [], # klucz z pliku H5, nazwa struktury ze zbioru ewaluacyjnego
    'real_rmsd': [], # prawdziwe RMSD
    'predicted_rmsd': [], # przewidziane RMSD przez model
})

batch_size = 64

for key in ['ares', 'seg1', 'seg2', 'seg3', 'transfer_seg2_ares']:
    print(f'Training dataset: {key}')
    model = get_model(key)
    model = model.to(device)
    with torch.no_grad():
        print('Evaluation dataset: ares test')
        for keys, item in DataLoader(ares_test_data, batch_size=batch_size, shuffle=False, num_workers=4):
            item = item.to(device)
            y_pred = model(x=item.x, edge_index=item.edge_index, edge_attr=item.edge_attr, batch=item.batch).view(-1)
            df = pd.concat([df, pd.DataFrame({
                'training_dataset': [key]*len(keys),
                'evaluation_dataset': ['ares test']*len(keys),
                'structure': keys,
                'real_rmsd': item.y.cpu().tolist(),
                'predicted_rmsd': y_pred.cpu().tolist()
            })])

        print('Evaluation dataset: rnaquadataset test')
        for keys, item in DataLoader(rnaquadataset_test_data, batch_size=batch_size, shuffle=False, num_workers=4):
            item = item.to(device)
            y_pred = model(x=item.x, edge_index=item.edge_index, edge_attr=item.edge_attr, batch=item.batch).view(-1)
            df = pd.concat([df, pd.DataFrame({
                'training_dataset': [key]*len(keys),
                'evaluation_dataset': ['rnaquadataset test']*len(keys),
                'structure': keys,
                'real_rmsd': item.y.cpu().tolist(),
                'predicted_rmsd': y_pred.cpu().tolist()
            })])

filename = f'eval{time_ns()}.csv'
df.to_csv(filename)
df['abs rmsd difference'] = np.abs(df['real_rmsd'] - df['predicted_rmsd'])
df['squared rmsd difference'] = df['abs rmsd difference'] ** 2
df.to_csv(filename)

display(df)


Training dataset: ares
Evaluation dataset: ares test


Evaluation dataset: rnaquadataset test
Training dataset: seg1
Evaluation dataset: ares test
Evaluation dataset: rnaquadataset test
Training dataset: seg2
Evaluation dataset: ares test
Evaluation dataset: rnaquadataset test
Training dataset: seg3
Evaluation dataset: ares test
Evaluation dataset: rnaquadataset test
Training dataset: transfer_seg2_ares
Evaluation dataset: ares test
Evaluation dataset: rnaquadataset test


Unnamed: 0,training_dataset,evaluation_dataset,structure,real_rmsd,predicted_rmsd,abs rmsd difference,squared rmsd difference
0,ares,ares test,1a4d_S_000002_minimize_001,7.213000,6.060639,1.152361,1.327936
1,ares,ares test,1a4d_S_000006_minimize_004,9.470000,6.819134,2.650866,7.027091
2,ares,ares test,1a4d_S_000009_minimize_005,14.677000,8.947495,5.729506,32.827234
3,ares,ares test,1a4d_S_000012_minimize_009,7.914000,6.040934,1.873066,3.508376
4,ares,ares test,1a4d_S_000019_minimize_004,6.239000,5.817907,0.421093,0.177319
...,...,...,...,...,...,...,...
17,transfer_seg2_ares,rnaquadataset test,8JOZ_1_B_4,16.907000,6.833779,10.073221,101.469776
18,transfer_seg2_ares,rnaquadataset test,8JOZ_1_B_5,22.042999,11.434573,10.608426,112.538704
19,transfer_seg2_ares,rnaquadataset test,8JOZ_1_B_6,18.382999,9.916877,8.466123,71.675232
20,transfer_seg2_ares,rnaquadataset test,8JOZ_1_B_7,22.982000,11.328849,11.653152,135.795940


In [6]:
# df = pd.read_csv('')

output = pd.DataFrame({
    'training_dataset': [],
    'evaluation_dataset': [],
    'mse': [],
    'mae': [],
})

t = time_ns()
for training_dataset in df['training_dataset'].unique():
    training_df = df[df['training_dataset'] == training_dataset]
    for evaluation_dataset in training_df['evaluation_dataset'].unique():
        evaluation_df = training_df[training_df['evaluation_dataset'] == evaluation_dataset]
        title = f'RMSD difference on {evaluation_dataset} - trained on {training_dataset}'
        plt.hist(evaluation_df["abs rmsd difference"])
        plt.title(title)
        plt.xlabel('RMSD')
        plt.savefig(os.path.join('histograms', f'{title.lower().replace(" ", "_")}{t}.png'))
        plt.close()

        output.loc[len(output)] = {
            'training_dataset': training_dataset,
            'evaluation_dataset': evaluation_dataset,
            'mse': evaluation_df['squared rmsd difference'].mean(),
            'mae': evaluation_df['abs rmsd difference'].mean()
        }
     
display(output)
output.to_csv(f'mse_mae{t}.csv')

Unnamed: 0,training_dataset,evaluation_dataset,mse,mae
0,ares,ares test,13.436916,2.689374
1,ares,rnaquadataset test,126.996648,8.592518
2,seg1,ares test,17.213387,2.921575
3,seg1,rnaquadataset test,135.947124,8.56409
4,seg2,ares test,130.343101,10.694094
5,seg2,rnaquadataset test,102.824252,7.385871
6,seg3,ares test,136.377121,10.933395
7,seg3,rnaquadataset test,56.948647,5.638476
8,transfer_seg2_ares,ares test,12.410662,2.786567
9,transfer_seg2_ares,rnaquadataset test,131.562791,9.55911
