# Set up LogME for Strategies for Pre-training GNNs (Hu et al., ICLR 2020)

In [1]:
import os
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim

from loader import MoleculeDataset
from model import GNN, GNN_graphpred
from splitters import scaffold_split


seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else torch.device("cpu"))
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

In [2]:
def create_dataframe_save_to_csv(embeddings, labels, dataset_name, model_name, save_path):
#     ## create pandas dataframe to store: example id, embeddings, labels ## 
#     d = {'example_id': [i for i in range(len(embeddings))],
#             'embeddings': embeddings,
#             'labels': labels
#            }
#     df = pd.DataFrame(data=d)
#     # df.head(15)

#     if not os.path.exists(save_path):
#         os.makedirs(save_path)

    filename = '{}_{}.csv'.format(dataset_name, model_name)
#     print("dataset_name: {}".format(dataset_name))
#     df.to_csv(os.path.join(save_path, filename), index=False)
    emb_df = pd.DataFrame(np.array(embeddings))
    emb_df.columns = ['emb' + str(e+1) for e in range(emb_df.shape[1])]
    emb_df['label'] = labels
    print("emb_df: \n{}\n".format(emb_df))
    emb_df.to_csv(os.path.join(save_path, filename), sep='\t', index=False)

## Load Pre-trained GNN models (Hu et al., ICLR 2020)

In [3]:
num_tasks = 1

num_layer = 5 # default
emb_dim = 300 # default
JK = 'last' # default (how the node features across laysers are combined)
dropout_ratio = 0.5 # default
graph_pooling = 'mean' # default
gnn_type = 'gin' # default

##########################
input_model_file = './model_gin/supervised.pth'

gin_supervised_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_model.from_pretrained(input_model_file)

gin_supervised_model.to(device)
gin_supervised_model.eval()

###########################

input_model_file = './model_gin/supervised_infomax.pth'

gin_supervised_infomax_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_infomax_model.from_pretrained(input_model_file)

gin_supervised_infomax_model.to(device)
gin_supervised_infomax_model.eval()

###########################

input_model_file = './model_gin/supervised_edgepred.pth'

gin_supervised_edgepred_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_edgepred_model.from_pretrained(input_model_file)

gin_supervised_edgepred_model.to(device)
gin_supervised_edgepred_model.eval()

###########################

input_model_file = './model_gin/supervised_masking.pth'

gin_supervised_masking_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_masking_model.from_pretrained(input_model_file)

gin_supervised_masking_model.to(device)
gin_supervised_masking_model.eval()

###########################

input_model_file = './model_gin/supervised_contextpred.pth'

gin_supervised_contextpred_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_contextpred_model.from_pretrained(input_model_file)

gin_supervised_contextpred_model.to(device)
gin_supervised_contextpred_model.eval()

###########################

input_model_file = './model_gin/infomax.pth'

gin_infomax_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_infomax_model.from_pretrained(input_model_file)

gin_infomax_model.to(device)
gin_infomax_model.eval()

GNN_graphpred(
  (gnn): GNN(
    (x_embedding1): Embedding(120, 300)
    (x_embedding2): Embedding(3, 300)
    (gnns): ModuleList(
      (0): GINConv(
        (mlp): Sequential(
          (0): Linear(in_features=300, out_features=600, bias=True)
          (1): ReLU()
          (2): Linear(in_features=600, out_features=300, bias=True)
        )
        (edge_embedding1): Embedding(6, 300)
        (edge_embedding2): Embedding(3, 300)
      )
      (1): GINConv(
        (mlp): Sequential(
          (0): Linear(in_features=300, out_features=600, bias=True)
          (1): ReLU()
          (2): Linear(in_features=600, out_features=300, bias=True)
        )
        (edge_embedding1): Embedding(6, 300)
        (edge_embedding2): Embedding(3, 300)
      )
      (2): GINConv(
        (mlp): Sequential(
          (0): Linear(in_features=300, out_features=600, bias=True)
          (1): ReLU()
          (2): Linear(in_features=600, out_features=300, bias=True)
        )
        (edge_embedding1): E

In [4]:
def get_graph_features_labels(loader, model, seed):
    """Extract graph features.
    Args:
        loader: graph dataloader
        model: feature extractor
        seed: integer value for setting random seed
    Returns:
        all_graph_features: list of all graph features from the dataloader 
        all_graph_labels: list of all graph (true) labels from the dataloader
    """
    
    torch.manual_seed(seed)
    np.random.seed(seed)

    all_graph_features = []
    all_graph_labels = []
    for step, batch in enumerate(loader):
        batch = batch.to(device)
        x = batch.x
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr
        y = batch.y
        batch = batch.batch

        # code from GNN_graphpred.forward() #
        node_representation = model.gnn(x, edge_index, edge_attr)
#         print("step {} | node_representation.shape: {}".format(step, node_representation.shape))

        graph_features = model.pool(node_representation, batch)
#         print("graph feature (shape): {}".format(graph_features.shape))
        all_graph_features.extend(graph_features.cpu().detach().numpy())
#         print("len(all_graph_features): {}".format(len(all_graph_features)))

        graph_preds = model.graph_pred_linear(graph_features)
    #     print("graph predictions: {}".format(graph_preds))
    #     print("graph predictions (shape): {}".format(graph_preds.shape))
    #     y_true = y.view(graph_preds.shape)
    #     print("y true: {}".format(y_true))
    #     print("y true (shape): {}".format(y_true.shape))
    #     print("y (shape): {}".format(y.shape))

        all_graph_labels.extend(y.cpu().detach().numpy())
#         print("len(all_graph_labels): {}".format(len(all_graph_labels)))
        
    return all_graph_features, all_graph_labels

## Extract features of BBBP training set using pre-trained GNN & obtain LogME score

In [5]:
dataset_name = "bbbp"
num_tasks = 1
batch_size = 32 # default
num_workers = 4 # default
save_results_to = '/mnt/sdc/course-projects/GRL-course-project/results'

dataset = MoleculeDataset("dataset/" + dataset_name, dataset=dataset_name)
print(dataset)

# scaffold split:
smiles_list = pd.read_csv('dataset/' + dataset_name + '/processed/smiles.csv', header=None)[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
print(train_dataset[0])

## set shuffle to FALSE so we can compare the extracted features using different pre-trained networks 
bbbp_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers = num_workers)
# bbbp_val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers = num_workers)
# bbbp_test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers = num_workers)


from LogME import LogME

## using GIN supervised model ('./model_gin/supervised.pth') ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_model, seed)

create_dataframe_save_to_csv(bbbp_graph_features, bbbp_graph_labels, dataset_name, 
                             model_name='gin_supervised', 
                             save_path=save_results_to)
# assert(False)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised.pth): {}".format(score))

#############################

## using GIN supervised infomax model (supervised_infomax.pth) ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_infomax_model, seed)

create_dataframe_save_to_csv(bbbp_graph_features, bbbp_graph_labels, dataset_name, 
                             model_name='gin_supervised_infomax', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised_infomax.pth): {}".format(score))

#############################

## using GIN supervised edgepred model (supervised_edgepred.pth) ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_edgepred_model, seed)

create_dataframe_save_to_csv(bbbp_graph_features, bbbp_graph_labels, dataset_name, 
                             model_name='gin_supervised_edgepred', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised_edgepred.pth): {}".format(score))

#############################

## using GIN supervised masking model (supervised_masking.pth) ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_masking_model, seed)

create_dataframe_save_to_csv(bbbp_graph_features, bbbp_graph_labels, dataset_name, 
                             model_name='gin_supervised_masking', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised_masking.pth): {}".format(score))

#############################

## using GIN supervised contextpred model (supervised_contextpred.pth) ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_contextpred_model, seed)

create_dataframe_save_to_csv(bbbp_graph_features, bbbp_graph_labels, dataset_name, 
                             model_name='gin_supervised_contextpred', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised_contextpred.pth): {}".format(score))

#############################

## using GIN infomax model (infomax.pth) ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_infomax_model, seed)

create_dataframe_save_to_csv(bbbp_graph_features, bbbp_graph_labels, dataset_name, 
                             model_name='gin_infomax', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN infomax.pth): {}".format(score))

MoleculeDataset(2039)


[17:52:52] Conflicting single bond directions around double bond at index 1.
[17:52:52]   BondStereo set to STEREONONE and single bond directions set to NONE.


Data(edge_attr=[46, 2], edge_index=[2, 46], id=[1], x=[23, 2], y=[1])
emb_df: 
          emb1      emb2      emb3      emb4      emb5      emb6      emb7  \
0     0.187954  0.013389 -0.250219 -0.022584  0.086127 -0.006174 -0.109610   
1    -0.043880  0.003614 -0.058039  0.023666 -0.031128 -0.049535  0.085190   
2     0.065282  0.006262 -0.139585 -0.201065 -0.003127 -0.061996 -0.065256   
3    -0.070854  0.009151 -0.046193  0.016986 -0.002188 -0.034686  0.017967   
4    -0.002685  0.000574 -0.042318  0.043741  0.093193  0.027315 -0.036415   
...        ...       ...       ...       ...       ...       ...       ...   
1626  0.103538  0.009691  0.052273 -0.013520 -0.059032 -0.062254 -0.006430   
1627  0.142405  0.015304  0.013123 -0.097107 -0.086527 -0.128479  0.016773   
1628  0.153705  0.009646 -0.035269  0.018985  0.021211 -0.117989 -0.053596   
1629  0.132085  0.007505 -0.094936 -0.010299  0.063251 -0.106890  0.022236   
1630 -0.088223  0.027770  0.030527  0.013800  0.025591 -0.00188


logme score (GIN supervised_edgepred.pth): 3.970161485661006
emb_df: 
          emb1      emb2      emb3      emb4      emb5      emb6      emb7  \
0    -0.005460  0.051137 -0.114106 -0.029731 -0.142455 -0.043726  0.147545   
1     0.033536 -0.056563  0.036576  0.175933  0.134396  0.189540 -0.235593   
2     0.145353  0.026588  0.356795 -0.211925  0.133437 -0.174142  0.155910   
3     0.074380 -0.044443  0.029541  0.112440 -0.004850  0.177624 -0.220913   
4     0.181920 -0.129171 -0.027031  0.109574 -0.039084  0.092776 -0.116252   
...        ...       ...       ...       ...       ...       ...       ...   
1626 -0.062165 -0.111285  0.019402 -0.068983 -0.020654 -0.092148 -0.132049   
1627  0.017005  0.162836 -0.020215 -0.059957 -0.123588 -0.136370  0.006118   
1628 -0.130397  0.098640 -0.019299 -0.017283  0.127086 -0.068968  0.087288   
1629  0.064985  0.027036  0.053525 -0.010578 -0.052231 -0.007877  0.135360   
1630  0.209731  0.017934 -0.033849 -0.182742 -0.210122  0.126560 -0.138


logme score (GIN infomax.pth): 3.9504379329602375


In [6]:
## test ## 
for step, batch in enumerate(bbbp_train_loader):
    print("step: {}".format(step))
    print("batch: {}".format(batch))
    print("batch.batch: {}".format(batch.batch))
    print("batch.x: {}".format(batch.x))
    print("batch.edge_attr: {}".format(batch.edge_attr))
    break
## end of test ## 


step: 0
batch: Batch(batch=[500], edge_attr=[998, 2], edge_index=[2, 998], id=[32], x=[500, 2], y=[32])
batch.batch: tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,
         4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
   

In [7]:
## check to see if the data is loaded in the same order every time # 
for i, feature in enumerate(bbbp_graph_features):
#     print("{}-th feature: \n{}".format(i, feature))
    print("{}-th feature | mean: {} +- {}".format(i, np.mean(feature), np.std(feature)))
    if i > 50: 
        break

0-th feature | mean: -0.23916535079479218 +- 0.8351194262504578
1-th feature | mean: -0.19498318433761597 +- 0.8193398118019104
2-th feature | mean: -0.2217152863740921 +- 0.8073065876960754
3-th feature | mean: -0.2068459838628769 +- 0.9316534399986267
4-th feature | mean: -0.09914200752973557 +- 0.6641184687614441
5-th feature | mean: -0.22375477850437164 +- 0.7601909637451172
6-th feature | mean: -0.24704651534557343 +- 0.7580361366271973
7-th feature | mean: -0.17532235383987427 +- 0.9688711166381836
8-th feature | mean: -0.2324436455965042 +- 0.9376943707466125
9-th feature | mean: -0.2381032258272171 +- 0.866077721118927
10-th feature | mean: -0.23781614005565643 +- 0.9739698171615601
11-th feature | mean: -0.1704014390707016 +- 0.6410852670669556
12-th feature | mean: -0.24489985406398773 +- 0.9539623260498047
13-th feature | mean: -0.1664772629737854 +- 1.006030559539795
14-th feature | mean: -0.24989280104637146 +- 0.8938428163528442
15-th feature | mean: -0.23559331893920898 

## Extract features of BACE training set using pre-trained GNN & obtain LogME score

In [8]:
dataset_name = "bace"
num_tasks = 1
batch_size = 32 # default
num_workers = 4 # default
save_results_to = '/mnt/sdc/course-projects/GRL-course-project/results'

dataset = MoleculeDataset("dataset/" + dataset_name, dataset=dataset_name)
print(dataset)

# scaffold split:
smiles_list = pd.read_csv('dataset/' + dataset_name + '/processed/smiles.csv', header=None)[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
print(train_dataset[0])

bace_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers)
# bace_val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers = num_workers)
# bace_test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers = num_workers)

###########################

## using GIN supervised model ('./model_gin/supervised.pth') ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_model, seed)

create_dataframe_save_to_csv(bace_graph_features, bace_graph_labels, dataset_name, 
                             model_name='gin_supervised', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised.pth): {}".format(score))

#############################

## using GIN supervised infomax model (supervised_infomax.pth) ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_infomax_model, seed)

create_dataframe_save_to_csv(bace_graph_features, bace_graph_labels, dataset_name, 
                             model_name='gin_supervised_infomax', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised_infomax.pth): {}".format(score))

#############################

## using GIN supervised edgepred model (supervised_edgepred.pth) ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_edgepred_model, seed)

create_dataframe_save_to_csv(bace_graph_features, bace_graph_labels, dataset_name, 
                             model_name='gin_supervised_edgepred', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised_edgepred.pth): {}".format(score))

#############################

## using GIN supervised masking model (supervised_masking.pth) ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_masking_model, seed)

create_dataframe_save_to_csv(bace_graph_features, bace_graph_labels, dataset_name, 
                             model_name='gin_supervised_masking', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised_masking.pth): {}".format(score))

#############################

## using GIN supervised contextpred model (supervised_contextpred.pth) ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_contextpred_model, seed)

create_dataframe_save_to_csv(bace_graph_features, bace_graph_labels, dataset_name, 
                             model_name='gin_supervised_contextpred', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised_contextpred.pth): {}".format(score))

#############################

## using GIN infomax model (infomax.pth) ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_infomax_model, seed)

create_dataframe_save_to_csv(bace_graph_features, bace_graph_labels, dataset_name, 
                             model_name='gin_infomax', 
                             save_path=save_results_to)

logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN infomax.pth): {}".format(score))

MoleculeDataset(1513)
Data(edge_attr=[66, 2], edge_index=[2, 66], fold=[1], id=[1], x=[31, 2], y=[1])
emb_df: 
          emb1      emb2      emb3      emb4      emb5      emb6      emb7  \
0    -0.028116  0.005691 -0.137788  0.029680  0.016036 -0.003200  0.048668   
1     0.060740  0.016581 -0.228648 -0.045908 -0.003714  0.039868 -0.005419   
2     0.069248  0.010116 -0.066165 -0.031711  0.027497 -0.053612  0.022585   
3    -0.039441  0.009469  0.042803  0.003924  0.012666 -0.019713 -0.072263   
4     0.049737  0.019437 -0.179696 -0.120685 -0.054918 -0.076108  0.018942   
...        ...       ...       ...       ...       ...       ...       ...   
1205  0.038832  0.019011 -0.189996 -0.140201 -0.065311 -0.066743 -0.019755   
1206 -0.029755  0.013644 -0.033361  0.014733  0.068999 -0.040916 -0.023484   
1207  0.031380  0.009848 -0.047594  0.032608  0.020400 -0.009646  0.104117   
1208  0.028505  0.012025 -0.130497 -0.025724  0.051418 -0.084800 -0.040346   
1209  0.066476  0.011486 -0.054


logme score (GIN supervised_edgepred.pth): 3.800035703818971
emb_df: 
          emb1      emb2      emb3      emb4      emb5      emb6      emb7  \
0     0.018853 -0.006671  0.062913 -0.080721 -0.027570 -0.046811  0.076543   
1    -0.097652 -0.083461  0.022607 -0.149393 -0.027918  0.059638  0.079482   
2     0.051145 -0.105421  0.080007 -0.029230 -0.063557 -0.022076 -0.047856   
3    -0.066924  0.003461  0.025916 -0.122576 -0.028106 -0.090147 -0.041293   
4    -0.037802  0.054739  0.081782  0.034592  0.006986  0.079678  0.010378   
...        ...       ...       ...       ...       ...       ...       ...   
1205 -0.166366  0.035835  0.119096  0.055848 -0.027947  0.054147 -0.007061   
1206  0.214898 -0.002941  0.000593 -0.020288 -0.014719  0.081850 -0.043649   
1207 -0.213246  0.105717  0.117629 -0.131601 -0.106433 -0.179815  0.111670   
1208  0.018054  0.089764  0.038966 -0.121769  0.055908  0.026721 -0.182887   
1209  0.078399  0.064981 -0.022349 -0.001778  0.074161 -0.222432 -0.049


logme score (GIN infomax.pth): 3.776506281987087
