In [2]:
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_squared_error
import random

# Read both input tables from extracted_data/. index_col=0 turns the first
# CSV column into the row index; the separator is pandas' default comma.
all_genes = pd.read_csv('extracted_data/GxG_filled.csv', index_col=0)
ExE = pd.read_csv('extracted_data/ExE_imputed.csv', index_col=0)

def evaluate_model(target_column, X, y, test_size=0.2, random_state=42):
    """Fit a linear regression on a train split and score it on the held-out split.

    Args:
        target_column: Name of the column being predicted. It is not used in
            the computation; the parameter is kept so existing callers keep
            working.
        X: Feature table passed to ``train_test_split`` and
            ``LinearRegression``.
        y: Target values aligned row-for-row with ``X``.
        test_size: Fraction (or count) of rows held out for testing, forwarded
            to ``train_test_split``. Defaults to 0.2.
        random_state: Seed forwarded to ``train_test_split`` so the split is
            repeatable. Defaults to 42.

    Returns:
        A ``(r2, rmse)`` tuple computed on the test split.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    lm = LinearRegression().fit(X_train, y_train)

    # Predict once and derive both metrics from the same predictions;
    # lm.score() would run predict() on X_test a second time.
    y_pred = lm.predict(X_test)
    r2 = r2_score(y_test, y_pred)
    rmse = np.sqrt(mean_squared_error(y_test, y_pred))
    return r2, rmse

def iterate_over_proportion(ExE, proportion, seed=None):
    """Average the hold-out R² and RMSE over a random subset of target columns.

    A random sample of ``int(n_columns * proportion)`` columns is drawn. Each
    sampled column is used in turn as the regression target, with all other
    columns of ``ExE`` as features, and scored by ``evaluate_model``.

    Args:
        ExE: DataFrame whose columns serve as both targets and features.
        proportion: Fraction of the columns to sample as targets.
        seed: Optional seed for the column sampling. With ``None`` (default)
            the module-level ``random`` generator is used, exactly as before,
            so a prior ``random.seed(...)`` call is still honoured.

    Returns:
        A ``(avg_r2, avg_rmse)`` tuple. Both values are NaN when the
        proportion is too small to select any column.

    Raises:
        ValueError: Propagated from ``random.sample`` when the requested
            number of columns is negative or exceeds the number available.
    """
    num_columns = int(ExE.shape[1] * proportion)
    # A dedicated generator keeps a seeded run from disturbing global state.
    rng = random if seed is None else random.Random(seed)
    selected_columns = rng.sample(ExE.columns.tolist(), num_columns)

    r2_scores = []
    rmse_scores = []

    for target_column in selected_columns:
        X = ExE.drop(columns=[target_column])
        y = ExE[target_column]

        r2, rmse = evaluate_model(target_column, X, y)
        r2_scores.append(r2)
        rmse_scores.append(rmse)

    print(f"R² scores: {r2_scores}")

    if not r2_scores:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        return float("nan"), float("nan")

    avg_r2 = np.mean(r2_scores)
    avg_rmse = np.mean(rmse_scores)

    return avg_r2, avg_rmse

# Use a random 10% of the columns as regression targets, then report the
# metrics averaged over those targets.
proportion = 0.1
avg_r2, avg_rmse = iterate_over_proportion(ExE, proportion)

print(f"Average R²: {avg_r2}", f"Average RMSE: {avg_rmse}", sep="\n")

R² scores: [-0.8479733487794889, -0.5843830413514235, -1.3516852720204193, -0.5347690483351277, -0.7702936693500977, -0.5751593217681035, -0.9678069380200158, -1.9484382777619147, -1.3163466511092397, -0.7415800713805007, -1.4540609473842, -3.7556222755178768, -1.316980018511197, -0.8231267001390432, -2.0256479876195916, -1.0631204431525427, -1.2184442471318704, -0.8741815423696295, -0.8558397368904149, -1.0591729225283797, -1.2375175716484428, -2.8036856838154094, -2.2497577282733086, -3.522648831133914, -2.337417698096883, -1.1231746981321473, -0.7718632689941574, -0.7661205306027903, -0.5543285999091201, -3.2197355416321853, -3.13993335028006, 0.17800744782295075, -1.8663527109998865, 0.04629311649956769, -0.6836056843931768, -3.159406030271507, -1.7486792752331666, -3.61663521169604, -1.4811140727228214, -0.9572646383073453, -1.2924053506236035, -2.5982787987520117, -1.6166605770046845, -0.4699204974708886, -3.795497542051681, -0.5333409933231414, -0.6667548561896885, -2.3334180856