In [None]:
import torch
import re
import pandas as pd
import cv2
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import io
import os
import copy

from google.colab import drive
# Mount Google Drive at /content/drive; force_remount=True remounts even when a
# mount is already present (e.g. when this cell is re-run in the same session).
drive.mount(mountpoint='/content/drive', force_remount=True)

# Work from the shared experiment folder so the relative CSV paths used later resolve.
os.chdir('/content/drive/Shareddrives/Strawberries/Text experiment/')

In [None]:
# Per-seed test accuracies: conventional pipeline (rrBLUP on extracted traits) vs. end-to-end model.
conventional = pd.read_csv("rrBLUPextractedTraitCorrelations.csv")
end2end = pd.read_csv("end2end_correlations.csv")

# The end-to-end file is wide (one column per trait); melt it to long form so each row is
# keyed by a (seed, trait) pair, matching the conventional table.
TRAIT_COLUMNS = ['Extracted_Length', 'Extracted_Redness']
end2end = pd.melt(end2end,
                  id_vars=[col for col in end2end.columns if col not in TRAIT_COLUMNS],
                  value_vars=TRAIT_COLUMNS,
                  var_name='trait',
                  value_name='end2end_test_r2')
# Keep only rows whose Type is "Original" and align the seed column name with `conventional`.
end2end = end2end[end2end['Type'] == "Original"]
end2end = end2end.rename(columns={'Random_Seed': 'seed'})

# validate='one_to_one' raises a MergeError if either table repeats a (seed, trait) key.
# A silent many-to-many join would cross-pair unrelated runs and corrupt the per-seed
# paired comparison computed from this frame.
merged_df = pd.merge(conventional, end2end, on=['seed', 'trait'], validate='one_to_one')
# Optional summary table:
# merged_df.groupby('trait')[['end2end_test_r2', 'test_r2', 'h2']].agg(['mean', 'std']).round(3)

# Paired comparison per (seed, trait): end-to-end test accuracy minus conventional (rrBLUP)
# test accuracy on the same seed; positive values favour the end-to-end model.
# The named columns are subtracted directly (instead of differencing positional rows after a
# melt), so the sign cannot flip if the column order of the source CSVs ever changes.
# Rows are sorted by (seed, trait) so the trait order on the plot's x-axis is deterministic.
correlation_differences = (
    merged_df
    .assign(**{'r2[predicted, known]': merged_df['end2end_test_r2'] - merged_df['test_r2']})
    [['seed', 'trait', 'r2[predicted, known]']]
    .sort_values(['seed', 'trait'])
    .reset_index(drop=True)
)
correlation_differences['trait'] = correlation_differences['trait'].replace(
    {'Extracted_Length': 'Length', 'Extracted_Redness': 'Redness'})

# Long-format table (one row per seed x trait x method) with bookkeeping columns dropped;
# `merged_df` keeps this shape for any later cells that use it.
merged_df = (
    merged_df
    .drop(columns=['Unnamed: 0', 'h2', 'train_r2', 'Type'])
    .melt(id_vars=['seed', 'trait'], var_name='method', value_name='r2[predicted, known]')
)

# Box plot of the per-seed differences (end-to-end minus conventional) for each trait;
# values above the red zero line are seeds where the end-to-end model scored higher.
# The hover label names the plotted quantity as a difference, not a raw accuracy.
fig = px.box(correlation_differences, x="trait", y="r2[predicted, known]",
             labels={"trait": "Trait",
                     "r2[predicted, known]": "End-to-end r² - conventional r²"})

# Zero-difference reference line. add_hline spans the full plot width in paper coordinates,
# unlike a data-coordinate shape with hard-coded x0/x1, which must track the number of trait
# categories and also widens the autoranged x-axis.
fig.add_hline(y=0, line_color='red', line_width=2)

# NOTE(review): the y-range is fixed to [-1, 1]; any difference outside it is drawn off-axis.
fig.update_layout(
    yaxis=dict(range=[-1, 1]),
    yaxis_title="End-to-end r² - conventional genomic prediction r²",
    xaxis_title="Trait",
)

fig.show()