In [9]:
%%capture
import sys, os
import torch, math, os
import sys
sys.path.append("..")

# Colab is detected by checking whether its package is already loaded in
# this interpreter; only then is Google Drive mounted.
IN_COLAB = 'google.colab' in sys.modules
if not IN_COLAB:
    print("Not running in Colab.")
else:
    print("Running in Colab!")
    from google.colab import drive

    # Mount point checked later by resolve_path_gdrive.
    drive.mount('/content/drive', force_remount=False)

def resolve_path_gdrive(relativePath):
    """Map a project-relative path to a location usable in this environment.

    When ``/content/drive`` exists, the path is placed under the project
    folder on the mounted Google Drive. Otherwise it is placed two levels
    above the directory returned by ``utils.get_project_root()``.

    The base folder and ``relativePath`` are always joined with exactly one
    ``/``, so ``"/checkpoints/x"``, ``"checkpoints/x"`` and ``"."`` all
    produce well-formed paths in both environments.

    Args:
        relativePath: Path relative to the project folder; a leading
            ``/`` or ``\\`` is optional and ignored.

    Returns:
        str: The resolved path (not normalised; ``..`` segments are kept).
    """
    if os.path.exists('/content/drive'):
        base = '/content/drive/MyDrive/work/gdrive-workspaces/git/nn_catalyst'
    else:
        from utils import get_project_root
        # str() so the concatenation works whether the helper hands back a
        # plain string or a path-like object.
        base = str(get_project_root()) + "/../.."

    # Drop leading separators so the join below never doubles the slash
    # and never glues the suffix onto the trailing "..".
    suffix = str(relativePath).lstrip('/\\')
    if not suffix:
        return base
    return base + '/' + suffix

# Checkpoint folder for this experiment, resolved for the current environment.
# (Local alternative that was used before: f'd:/temp{CHECKPOINTS_FOLDER_BASE}')
CHECKPOINTS_FOLDER_BASE = "/checkpoints/stn_r3_f849_tlast29/stack=False-scaleY=True"
CHECKPOINTS_FOLDER = resolve_path_gdrive(CHECKPOINTS_FOLDER_BASE)
print(f"Root project folder is at {resolve_path_gdrive('.')}")

# Apply the same fixed seed through each of torch's seeding entry points.
seed = 42
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)
torch.set_float32_matmul_precision("medium")  # to make lightning happy

In [10]:

import pandas as pd
import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
from torch import nn
import torch.nn.functional as F

# Load the data
# Number of leading feature columns in the CSV; every column after them is
# treated as a prediction target.
N_FEATURES = 849

df = pd.read_csv('few_merged_data_f849_tlast29_reordered_byR2.csv')
X = df.iloc[:, :N_FEATURES]  # feature columns
y = df.iloc[:, N_FEATURES:]  # target columns (everything after the features)
from pl.model_impl import *

In [11]:
import joblib
from sklearn.preprocessing import StandardScaler

# The feature scaler is stored next to the checkpoints.
pickle_file_path = f'{CHECKPOINTS_FOLDER}/scaler_X.pkl'

# NOTE: joblib.load unpickles the file, so only load scaler files that come
# from a trusted source.
with open(pickle_file_path, 'rb') as scaler_file:
    scaler = joblib.load(scaler_file)

# Apply the loaded scaler to the features, then convert the result to a
# float32 tensor for the models.
X_scaled = scaler.transform(X)
X_tensor = torch.tensor(X_scaled, dtype=torch.float32)


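**Security note:** the scaler above is restored with `joblib.load`, which unpickles the file and can therefore execute arbitrary code — only load scaler files from a trusted source. See https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations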


In [14]:
from pathlib import Path

def eval_model(x_data, target_num):
    """Predict a single target using the checkpoint stored for it.

    Searches ``{CHECKPOINTS_FOLDER}/{target_num}`` recursively for ``*.ckpt``
    files, loads one as a ``SingleTargetNet``, switches it to eval mode on
    the CPU and runs it on ``x_data`` without gradient tracking.

    If several checkpoints are present, the first one in sorted path order
    is used, so the choice is the same on every run.

    Args:
        x_data: Input passed unchanged to the model's forward call.
        target_num: Name of the checkpoint sub-folder for the target
            (the prediction loop passes 1-based target numbers).

    Returns:
        numpy.ndarray: The model output converted to a NumPy array.

    Raises:
        FileNotFoundError: If no ``.ckpt`` file exists under the folder.
    """
    checkpoint_dir = Path(f'{CHECKPOINTS_FOLDER}/{target_num}')
    # Path.glob yields files in arbitrary order; sort for a stable pick.
    checkpoint_paths = sorted(checkpoint_dir.glob('**/*.ckpt'))
    if not checkpoint_paths:
        raise FileNotFoundError(f"No .ckpt file found under {checkpoint_dir}")

    # The loader is given a plain string rather than a Path object.
    model = SingleTargetNet.load_from_checkpoint(str(checkpoint_paths[0]))
    model.eval()
    model.cpu()
    with torch.no_grad():
        y_pred = model(x_data)
    return y_pred.detach().numpy()
        
# Load models and make predictions
# One result per target column; targets are numbered from 1 when passed to
# eval_model. The intermediate list has its own name so `predictions` only
# ever refers to the final stacked array.
per_target_predictions = []
for target_index in range(df.shape[1] - 849):
    print(f"Predicting target {target_index + 1}...")
    target_prediction = eval_model(X_tensor, target_index + 1)
    if target_prediction is None:
        # Stop here with the target number instead of letting np.hstack
        # fail later with an unrelated-looking error.
        raise FileNotFoundError(
            f"No prediction produced for target {target_index + 1}; "
            f"check that a checkpoint exists for it"
        )
    per_target_predictions.append(target_prediction)

# Stack predictions into array
predictions = np.hstack(per_target_predictions)

# Create a DataFrame for predictions
#predictions_df = pd.DataFrame(predictions, columns=[f'Prediction_{i}' for i in range(predictions.shape[1])])

# Save predictions to a CSV file
#predictions_df.to_csv('predictions.csv', index=False)

print(f"Predictions shape: {predictions.shape}")

Predicting target 1...
Predicting target 2...
Predicting target 3...
Predicting target 4...
Predicting target 5...
Predicting target 6...
Predicting target 7...
Predicting target 8...
Predicting target 9...
Predicting target 10...
Predicting target 11...
Predicting target 12...
Predicting target 13...
Predicting target 14...
Predicting target 15...
Predicting target 16...
Predicting target 17...
Predicting target 18...
Predicting target 19...
Predicting target 20...
Predicting target 21...
Predicting target 22...
Predicting target 23...
Predicting target 24...
Predicting target 25...
Predicting target 26...
Predicting target 27...
Predicting target 28...
Predicting target 29...
Predictions shape: (8, 29)


In [15]:
# Display the stacked prediction array built by the prediction loop.
# NOTE(review): presumably one row per input sample and one column per
# target — confirm against the model's output shape.
predictions

array([[ 0.38041747,  0.37435395,  0.38301158,  0.32179022,  0.40278733,
         0.35184574, -1.5712007 , -1.3984663 ,  1.5064232 ,  2.4844203 ,
         1.7189146 ,  2.7789712 ,  1.8699226 ,  1.3264762 , -1.1213297 ,
        -1.8684686 ,  1.3493798 ,  1.7378752 , -0.16221514, -1.998511  ,
        -1.7195978 , -0.73782766,  1.1627886 , -1.9192525 ,  2.836595  ,
         0.8197746 ,  0.3535876 , -0.40010872, -0.70075166],
       [ 0.25580525,  0.2529648 ,  0.2503891 ,  0.28139862,  0.30164066,
         0.28221887, -0.6387827 , -0.69388795,  0.50233024,  0.11119656,
        -0.03374834,  2.4412274 ,  2.7798343 ,  0.27203688, -0.5255129 ,
        -0.8286061 ,  0.28920034,  2.1548796 ,  0.62466604, -1.1175022 ,
        -1.1320281 ,  0.4210174 , -0.75317067, -1.1652302 ,  0.01175366,
         1.3295193 , -0.41899648,  0.39197096,  1.3291682 ],
       [ 0.81967187,  0.8394152 ,  0.8660329 ,  0.8326798 ,  0.8342287 ,
         0.82582533, -1.8965763 , -2.0378969 ,  1.4113448 ,  0.43807143,
  