# Import

In [None]:
import os
from itertools import product
from time import time_ns

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from rnaquanet.data.preprocessing.hdf5_utils import load_data_from_hdf5
from rnaquanet.network.graph_regression_network import GraphRegressionNetwork
from rnaquanet.network.grn_data_module import GRNDataModule
from rnaquanet.utils.rnaquanet_config import RnaquanetConfig

# Initial configuration

In [None]:
# Let PyTorch use faster, slightly less precise float32 matrix multiplication
# (see the torch.set_float32_matmul_precision documentation for details).
torch.set_float32_matmul_precision('high')

# Configuration values can be changed on the fly while constructing the config:
# here the `num_workers` entry of the `network` section is overridden.
config = RnaquanetConfig('/app/configs/config_ares1.yml', override = {
    'network': {
        'num_workers': 4
    }
})
# Upper bound on training epochs; read later by train_model() when it builds
# the Trainer.
config.network.max_epochs = 50

# Data module shared by every training run in this notebook.
data = GRNDataModule(config)
data.prepare_data()

# Experiment

In [None]:
def train_model(config, data, filename):
    """Train a new GraphRegressionNetwork and return its best checkpoint.

    Args:
        config: configuration passed to the network; `config.network.max_epochs`
            caps the number of training epochs.
        data: data module handed to `Trainer.fit`.
        filename: checkpoint file name (without extension) used under
            `/app/models`.

    Returns:
        The network re-loaded from the checkpoint with the lowest `val_loss`.
    """
    network = GraphRegressionNetwork(config)

    # Keep only the single checkpoint with the smallest validation loss.
    best_checkpoint = ModelCheckpoint(
        dirpath='/app/models',
        filename=filename,
        monitor='val_loss',
        mode='min',
        save_top_k=1,
    )
    # Stop once `val_loss` has not improved for 4 consecutive validation checks.
    early_stopping = EarlyStopping('val_loss', patience=4)

    trainer = pl.Trainer(
        max_epochs=config.network.max_epochs,
        log_every_n_steps=1,
        callbacks=[early_stopping, best_checkpoint],
    )
    trainer.fit(network, data)

    return GraphRegressionNetwork.load_from_checkpoint(
        best_checkpoint.best_model_path,
        config=config,
    )

In [None]:
# Device used for evaluation: the first CUDA GPU when one is available,
# otherwise the CPU, so the notebook also runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def get_scores(config, model):
    """Compute per-structure MSE losses of `model` on all dataset splits.

    Reads `train.h5`, `val.h5` and `test.h5` from
    `<config.data.path>/<config.name>/` and evaluates every HDF5 group that
    carries a target `y`. Groups are aggregated under the first four
    characters of their name, so several groups may contribute to one entry.

    Args:
        config: configuration object; `config.data.path` and `config.name`
            are read to locate the HDF5 files.
        model: network called as `model(x, edge_index, edge_attr, batch)`.
            It is moved to `device` and switched to evaluation mode.

    Returns:
        A dict mapping each four-character structure key to a dict with:
            'loss': list of MSE losses, one per evaluated group,
            'dataset': the split ('train' | 'val' | 'test') in which the key
                was first seen,
            'nucleotides': number of rows of `x` in the first group seen
                for that key.
    """
    scores = {}

    # Make sure weights and inputs live on the same device, and disable
    # training-only behaviour once instead of on every iteration.
    model = model.to(device)
    model.eval()

    # Pure inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for dataset in ['train', 'val', 'test']:
            file_path = os.path.join(config.data.path, config.name, f'{dataset}.h5')
            with h5py.File(file_path, 'r') as file:
                for name, group in file.items():  # name is the structure name
                    if not isinstance(group, h5py.Group):
                        continue
                    if group.get('y') is None:
                        # Without a target there is no loss to compute.
                        continue

                    # Read the node features once; reused for the node count.
                    x = torch.tensor(group['x'][()])
                    key = name[0:4]
                    if key not in scores:
                        scores[key] = {
                            'loss': [],
                            'dataset': dataset,
                            'nucleotides': x.shape[0]
                        }

                    graph = Data(
                        x=x,
                        edge_index=torch.tensor(group['edge_index'][()]),
                        edge_attr=torch.tensor(group['edge_attr'][()]),
                        y=torch.tensor(group['y'][()])
                    )
                    # The loader supplies the `batch` vector the model expects.
                    for sample in DataLoader([graph], batch_size=1):
                        prediction = model(
                            sample.x.to(device),
                            sample.edge_index.to(device),
                            sample.edge_attr.to(device),
                            sample.batch.to(device)
                        )
                        loss = F.mse_loss(prediction.cpu(), sample.y.cpu().view(-1, 1))
                        scores[key]['loss'].append(loss.item())
    return scores

In [None]:
def plot_scores(scores, title='Mean loss'):
    """Draw, save and show a bar chart of the mean loss per structure.

    Args:
        scores: dict as returned by `get_scores`; each value provides the keys
            'loss', 'dataset' and 'nucleotides'.
        title: plot title; also used (lower-cased, spaces replaced by
            underscores) as the name of the saved PNG file.
    """
    colors = {
        'train': 'pink',
        'val': 'lightblue',
        'test': 'lightgreen'
    }

    # One bar per structure, labelled with its nucleotide count and coloured
    # by the dataset split it belongs to.
    bar_labels = [f'{name} ({entry["nucleotides"]})' for name, entry in scores.items()]
    bar_heights = [np.mean(entry['loss']) for entry in scores.values()]
    bar_colors = [colors[entry['dataset']] for entry in scores.values()]
    plt.bar(x=bar_labels, height=bar_heights, color=bar_colors)

    plt.title(title)
    plt.xticks(rotation=90)

    # Legend with one coloured patch per dataset split.
    split_names = list(colors)
    split_patches = [plt.Rectangle((0, 0), 1, 1, color=shade) for shade in colors.values()]
    plt.legend(split_patches, split_names)

    plt.xlabel('Structure')
    plt.ylabel('Mean loss')

    image_name = title.lower().replace(' ', '_') + '.png'
    plt.savefig(image_name, bbox_inches='tight')
    plt.show()

In [None]:
# Hyperparameter grid: every combination of these values is trained once.
grid = {
    'batch_size': [10, 100, 1000],
    'hidden_dim': [64, 128, 256],
    'layer_type': [1, 3],
    'batch_norm': [True, False],
    'num_of_layers': [4, 16],
}

# output[hyperparameter][value] collects the mean validation loss of every
# successful run that used that value.
output = {
    name: {value: [] for value in values}
    for name, values in grid.items()
}

# product() iterates in the same order as nested loops over the grid
# (the last hyperparameter varies fastest).
for batch_size, hidden_dim, layer_type, batch_norm, num_of_layers in product(*grid.values()):
    data.batch_size = batch_size
    # NOTE(review): these are set directly on `config`, while other network
    # settings in this notebook live under `config.network` (max_epochs,
    # num_workers). Confirm the model actually reads them from here.
    config.hidden_dim = hidden_dim
    config.layer_type = layer_type
    config.batch_norm = batch_norm
    config.num_of_layers = num_of_layers
    # Same seed for every run so that runs differ only in hyperparameters.
    torch.manual_seed(2137)
    try:
        model = train_model(config, data, f'{batch_size}_{hidden_dim}_{layer_type}_{batch_norm}_{num_of_layers}')
        scores = get_scores(config, model)
        plot_scores(scores, title=f'{batch_size}, {hidden_dim}, {layer_type}, {batch_norm}, {num_of_layers}')
        val_mean = np.mean(np.concatenate([score['loss'] for score in scores.values() if score['dataset'] == 'val']))
    except Exception as error:
        # Best effort: one failing combination must not abort the whole sweep,
        # but the reason is reported and Ctrl+C / SystemExit still propagate.
        print(f'Failed at {batch_size}, {hidden_dim}, {layer_type}, {batch_norm}, {num_of_layers}')
        print(f'  {type(error).__name__}: {error}')
    else:
        output['batch_size'][batch_size].append(val_mean)
        output['hidden_dim'][hidden_dim].append(val_mean)
        output['layer_type'][layer_type].append(val_mean)
        output['batch_norm'][batch_norm].append(val_mean)
        output['num_of_layers'][num_of_layers].append(val_mean)

# One summary chart per hyperparameter: the mean validation loss of all runs
# that used each of its values.
for key, value in output.items():
    setting_labels = [str(setting) for setting in value]
    setting_means = [np.mean(losses) for losses in value.values()]
    plt.bar(x=setting_labels, height=setting_means)

    plt.title(f'Mean val loss {key}')
    plt.xlabel(key)
    plt.ylabel('Mean loss')
    plt.savefig(f'val_{key}.png', bbox_inches='tight')
    plt.show()