In [None]:
import os
import torch
import pandas as pd

import kaleido #required
kaleido.__version__ #0.2.1
import plotly
plotly.__version__ #5.5.0
#now this works:
import plotly.graph_objects as go
import plotly.express as px

import plotly.io as pio   
pio.kaleido.scope.mathjax = None

import re
from monai.metrics import FIDMetric

In [2]:
from dataset.lsun import LSUNDatasetWrapper
from torchvision.utils import save_image

In [3]:
# Instantiate the LSUN wrapper for the "church_outdoor_train" class and write
# the item at index 2 to "church.png".
dataset = LSUNDatasetWrapper(classes=["church_outdoor_train"])
church_sample = dataset[2]  # presumably an image tensor accepted by save_image — TODO confirm
save_image(church_sample, "church.png")

In [None]:
from dataset.cpg0000 import CPG0000
from monai.data import CacheDataset

from PIL import Image

# Build the project dataset for a single plate CSV and wrap it in a MONAI
# CacheDataset that applies the validation transforms.
cpg0000 = CPG0000(path_csv="/folder1/folder2/cpg0000-jump-pilot/csv/BR00116991.csv")
dataset = CacheDataset(cpg0000.csv_ds, transform=cpg0000.val_transforms, cache_rate=0, num_workers=0, copy_cache=False)

# First entry of the "IMAGE" field of sample 0.
# NOTE(review): assumed to be a 5-channel (5, H, W) float tensor scaled to [0, 1] — confirm.
tensor = dataset[0]["IMAGE"][0]

# Define a color mapping for each of the 5 channels to RGB
# Each row corresponds to a channel, and each column to R, G, B
color_map = torch.tensor([
    [1.0, 0.0, 0.0],  # Channel 0 → Red
    [0.0, 1.0, 0.0],  # Channel 1 → Green
    [0.0, 0.0, 1.0],  # Channel 2 → Blue
    [0.5, 0.5, 0.0],  # Channel 3 → Red + Green
    [0.0, 0.5, 0.5],  # Channel 4 → Green + Blue
])

# Take the sizes from the data instead of hard-coding 5 channels and 256x256
# pixels, so images of any spatial size are converted correctly.
num_channels = color_map.shape[0]
height, width = tensor.shape[-2:]

# Flatten spatial dimensions and apply color mapping.
# reshape (unlike view) also accepts non-contiguous inputs.
reshaped_tensor = tensor.reshape(num_channels, -1)  # Shape: (C, H*W)
rgb_tensor = torch.matmul(color_map.T, reshaped_tensor)  # Shape: (3, H*W)

# Reshape back to (3, H, W)
rgb_tensor = rgb_tensor.reshape(3, height, width)

# Clip to [0, 1] (several channels feed each colour, so sums can exceed 1),
# scale to [0, 255] and convert to uint8
rgb_image = rgb_tensor.clamp(0, 1) * 255
rgb_image = rgb_image.byte().permute(1, 2, 0).numpy()  # Shape: (H, W, 3)

# Save the image
image = Image.fromarray(rgb_image)
image.save("converted_image.png")

# Loading features

In [None]:
# Common root of every metric directory.
_metrics_root = "/folder1/folder2"

# (metric key, directory relative to the root), in the order they are loaded.
_metric_subdirs = [
    # CPJUMP1 (cpg0000), including a second run of the MAE computation
    ("kl_cpg", "cpg0000-jump-pilot/metrics/kl"),
    ("mae_cpg", "cpg0000-jump-pilot/metrics/mae"),
    ("mae_cpg_run2", "cpg0000-jump-pilot-run2/metrics/mae"),
    ("ssim_cpg", "cpg0000-jump-pilot/metrics/ssim"),
    ("emd_cpg", "cpg0000-jump-pilot/metrics/emd"),
    ("real_inception_cpg", "cpg0000-jump-pilot/metrics/real_inception"),
    ("fake_inception_cpg", "cpg0000-jump-pilot/metrics/fake_inception"),
    # LSUN classroom
    ("kl_classroom", "sd-vae-lsun-classroom/kl"),
    ("mae_classroom", "sd-vae-lsun-classroom/mae"),
    ("ssim_classroom", "sd-vae-lsun-classroom/ssim"),
    ("emd_classroom", "sd-vae-lsun-classroom/emd"),
    ("real_inception_classroom", "sd-vae-lsun-classroom/real_inception"),
    ("fake_inception_classroom", "sd-vae-lsun-classroom/fake_inception"),
    # LSUN church outdoor
    ("kl_churches", "sd-vae-lsun-church-outdoor/kl"),
    ("mae_churches", "sd-vae-lsun-church-outdoor/mae"),
    ("ssim_churches", "sd-vae-lsun-church-outdoor/ssim"),
    ("emd_churches", "sd-vae-lsun-church-outdoor/emd"),
    ("real_inception_churches", "sd-vae-lsun-church-outdoor/real_inception"),
    ("fake_inception_churches", "sd-vae-lsun-church-outdoor/fake_inception"),
]

# Metric name -> absolute directory containing that metric's .pt files.
dict_metrics = {key: _metrics_root + "/" + subdir for key, subdir in _metric_subdirs}

def load_metrics(dict_metrics):
    """Load and concatenate the ``.pt`` tensors stored in each metric directory.

    Args:
        dict_metrics: Mapping from metric name to a directory path. Every
            ``.pt`` file name in a directory must contain at least one run of
            digits; files are ordered by the first such number, compared
            numerically.

    Returns:
        A new dict with the same keys, in the same order, whose values are the
        tensors of each directory concatenated along dimension 0. The input
        mapping is left untouched, so a failure part-way through never leaves
        the caller with a dict that mixes paths and tensors.

    Raises:
        FileNotFoundError: If a directory contains no ``.pt`` file.
    """
    def _first_number(filename):
        # Numeric sort key: "plate2.pt" must come before "plate10.pt".
        return int(re.findall(r'\d+', filename)[0])

    loaded = {}
    for key, directory in dict_metrics.items():
        filenames = sorted(
            (f for f in os.listdir(directory) if f.endswith('.pt')),
            key=_first_number,
        )
        if not filenames:
            raise FileNotFoundError(
                f"No .pt files found for metric {key!r} in {directory!r}"
            )

        # NOTE: torch.load unpickles the file contents; only point this at
        # directories whose files you trust.
        tensors = [torch.load(os.path.join(directory, filename)) for filename in filenames]
        loaded[key] = torch.cat(tensors, dim=0)
    return loaded

# Rebind dict_metrics from {name: directory} to {name: concatenated tensor}.
# After this line the values are tensors, so re-running this statement without
# first re-running the dict definition at the top of the cell fails inside
# load_metrics (os.listdir is called on each value).
dict_metrics = load_metrics(dict_metrics)

In [None]:
# Folder holding one metadata CSV per plate.
path = "/folder1/folder2/cpg0000-jump-pilot/csv/"

# Plates are listed from the saved .pt files and sorted with the same numeric
# key that load_metrics uses.
# NOTE(review): this assumes the real_openphenom folder holds the same plates as
# the metric folders loaded into dict_metrics — confirm.
list_plates_pt = sorted(
    [f for f in os.listdir("/folder1/folder2/cpg0000-jump-pilot/metrics/real_openphenom") if f.endswith('.pt')],
    key=lambda x: int(re.findall(r'\d+', x)[0]),
)
csv_files = [x.split(".")[0] + ".csv" for x in list_plates_pt]

dfs = [pd.read_csv(os.path.join(path, file)) for file in csv_files]
df = pd.concat(dfs, ignore_index=True)
n_rows = len(df)

# Build the well identifier (e.g. row 1, column 2 -> "A02") used as the platemap join key.
df['row_letter'] = df['row'].apply(lambda x: chr(64 + x))
df['column_str'] = df['column'].apply(lambda x: f"{x:02d}")
df['well'] = df['row_letter'] + df['column_str']

# "\t" is the explicit spelling of the tab separator.
df_platemap = pd.read_csv("cpg0000-jump-pilot/JUMP-Target-1_compound_platemap.txt", sep="\t")
df_exp = pd.read_csv("cpg0000-jump-pilot/experiment-metadata.tsv", sep="\t")
# Wells without a compound are labelled as DMSO.
df_platemap['broad_sample'] = df_platemap['broad_sample'].fillna('DMSO')

# Later cells use df.index as positions into the metric tensors, so both joins
# must keep exactly one output row per input row. validate='many_to_one' rejects
# duplicated keys on the right side; the assertions catch rows dropped by the
# inner join because they have no match.
df = pd.merge(df, df_platemap, left_on=['well'], right_on=['well_position'], how='inner', validate='many_to_one')
assert len(df) == n_rows, "platemap merge changed the number of rows; df no longer lines up with the metric tensors"
df = pd.merge(df, df_exp, on="Assay_Plate_Barcode", validate='many_to_one')
assert len(df) == n_rows, "experiment-metadata merge changed the number of rows; df no longer lines up with the metric tensors"

In [7]:
# Shapes of the three KL tensors, in the order classroom, churches, cpg.
tuple(dict_metrics[key].shape for key in ("kl_classroom", "kl_churches", "kl_cpg"))

(torch.Size([166419]), torch.Size([126200]), torch.Size([66048, 16, 2]))

In [8]:
# Mean difference between the two MAE runs, averaged over the first two axes
# (one value per entry of the last axis).
torch.mean(dict_metrics["mae_cpg"] - dict_metrics["mae_cpg_run2"], dim=(0, 1))

tensor([-1.3583e-09, -7.6916e-10,  3.7223e-11,  9.5823e-10,  5.2113e-10])

In [9]:
# Shape of the concatenated CPG MAE tensor.
dict_metrics["mae_cpg"].size()

torch.Size([66048, 16, 5])

## FID and KL

In [84]:
# Inception features (real / reconstructed) and KL values for the CPG data.
features_real = dict_metrics["real_inception_cpg"]
features_fake = dict_metrics["fake_inception_cpg"]
features_kl = dict_metrics["kl_cpg"]

fid = FIDMetric()

# Parallel lists: one entry per bar of the FID / KLD plots.
dataset_list, fid_list, kl_list, split_list = [], [], [], []

# One group per (cell type, time point); df.index is used as row positions into
# the feature tensors.
for (cell_type, time_point), grouped in df.groupby(["Cell_type","Time"]):
    rows = grouped.index.to_list()
    real_group = features_real[rows]
    fake_group = features_fake[rows]
    kl_group = features_kl[rows]
    group_label = cell_type + "-" + str(time_point) + "h"

    # Index 2 of the feature tensors (last axis of the KL tensor) has two
    # entries; each one is scored separately.
    for split in range(2):
        real_split = real_group[:, :, split, :].view(-1, 2048)
        fake_split = fake_group[:, :, split, :].view(-1, 2048)

        dataset_list.append(group_label)
        split_list.append(split)
        fid_list.append(fid(real_split, fake_split).item())
        kl_list.append(kl_group[:, :, split].mean().item())

# The two LSUN datasets contribute a single "RGB" entry each.
lsun_entries = (
    ("real_inception_classroom", "fake_inception_classroom", "kl_classroom", "   Classroom"),
    ("real_inception_churches", "fake_inception_churches", "kl_churches", "   Church"),
)
for real_key, fake_key, kl_key, lsun_label in lsun_entries:
    fid_list.append(fid(dict_metrics[real_key], dict_metrics[fake_key]).item())
    kl_list.append(dict_metrics[kl_key].mean().item())
    split_list.append("RGB")
    dataset_list.append(lsun_label)

In [85]:
# Column-oriented table of the per-dataset results collected in the lists above.
data = dict(
    Dataset=dataset_list,
    KLD=kl_list,
    FID=fid_list,
    Channels=split_list,
)

In [86]:
# Human-readable legend labels for the values stored in the 'Channels' column.
channel_labels = {0: 'Mito, AGP, RNA', 1: 'RNA, ER, DNA', "RGB": "RGB"}

df_plot = pd.DataFrame(data).assign(
    Channels=lambda frame: frame['Channels'].map(channel_labels)
)

In [40]:
# Grouped bar chart of FID per dataset, one bar per 'Channels' value.
fig_fid = px.bar(df_plot, x='Dataset', y='FID', color='Channels', barmode='group')

# Legend styling; it only takes effect if showlegend is switched back on.
fid_legend = dict(
    title='Channels',
    orientation='v',
    x=1,
    y=1,
    xanchor='right',
    yanchor='top',
    bgcolor='rgba(255,255,255,0.5)',
    bordercolor='black',
    borderwidth=1,
)

fig_fid.update_layout(
    font=dict(size=18, color='black'),
    showlegend=False,
    bargap=0.2,
    bargroupgap=0.1,
    xaxis=dict(showgrid=False),
    margin=dict(l=40, r=40, t=40, b=40),
    legend=fid_legend,
)

# Export the figure as a PDF.
fig_fid.write_image("fid.pdf")

In [39]:
# Grouped bar chart of the mean KLD per dataset, one bar per 'Channels' value.
fig_kl = px.bar(
    df_plot,
    x='Dataset',
    y='KLD',
    color="Channels",
    barmode="group",
    log_x=False,
    log_y=False,
)
# The return value is discarded.
# NOTE(review): presumably kept to initialise the Kaleido export engine before
# write_image — confirm whether it is still needed.
pio.full_figure_for_development(fig_kl, warn=False)

# Layout and styling; the legend sits inside the top-right corner of the plot.
kl_legend = dict(
    title='Channels',
    orientation='v',
    x=1,
    y=1,
    xanchor='right',
    yanchor='top',
    bgcolor='rgba(255,255,255,0.5)',
    bordercolor='black',
    borderwidth=1,
)
fig_kl.update_layout(
    font=dict(size=18, color='black'),
    bargap=0.2,
    bargroupgap=0.1,
    xaxis=dict(showgrid=False),
    margin=dict(l=40, r=40, t=40, b=40),
    legend=kl_legend,
)

# Export the figure as a PDF.
fig_kl.write_image("kl.pdf")

In [9]:
# Standard deviation of the KL values for cpg, classroom and churches.
tuple(dict_metrics[key].std() for key in ("kl_cpg", "kl_classroom", "kl_churches"))

(tensor(19137.9238), tensor(7933.4155), tensor(10359.3467))

## MAE, SSIM and EMD

In [10]:
# Long-format MAE tables (one row per value) with columns Channel, MAE, Dataset.
# Both LSUN subsets share the "LSUN" label.
df_classroom = (
    pd.DataFrame(dict_metrics["mae_classroom"], columns=['R', 'G', 'B'])
    .melt(var_name='Channel', value_name='MAE')
    .assign(Dataset="LSUN")
)

df_church = (
    pd.DataFrame(dict_metrics["mae_churches"], columns=['R', 'G', 'B'])
    .melt(var_name='Channel', value_name='MAE')
    .assign(Dataset="LSUN")
)

# The CPG tensor is averaged over axis 1 before being tabulated per channel.
df_cpg = (
    pd.DataFrame(dict_metrics["mae_cpg"].mean(1), columns=['Mito.', 'AGP', 'RNA', 'ER', 'DNA'])
    .melt(var_name='Channel', value_name='MAE')
    .assign(Dataset="CPJUMP1")
)

df_mae = pd.concat([df_classroom, df_church, df_cpg])

In [25]:
# Box plot of MAE per channel, one colour per dataset, without individual points.
fig_mae = px.box(df_mae, x='Channel', y='MAE', color="Dataset", points=False)

fig_mae.update_layout(
    font=dict(size=18, color='black'),
    showlegend=False,
    # NOTE(review): bargap/bargroupgap are bar-trace settings; box traces are
    # spaced with boxgap/boxgroupgap — confirm whether these were meant.
    bargap=0.2,
    bargroupgap=0.1,
    xaxis=dict(showgrid=False),
    margin=dict(l=40, r=40, t=40, b=40),
    legend=dict(
        # The colour dimension is "Dataset", so the legend is titled to match
        # (it was 'Channels'). The legend is hidden here, so the export is
        # unchanged unless showlegend is switched back on.
        title='Dataset',
        orientation='v',
        x=1,
        y=1,
        xanchor='right',
        yanchor='top',
        bgcolor='rgba(255,255,255,0.5)',
        bordercolor='black',
        borderwidth=1
    )
)

# Export the figure as a PDF.
fig_mae.write_image("mae.pdf")

In [22]:
# Long-format SSIM tables (one row per value) with columns Channel, SSIM, Dataset.
# Both LSUN subsets share the "LSUN" label.
df_classroom = (
    pd.DataFrame(dict_metrics["ssim_classroom"], columns=['R', 'G', 'B'])
    .melt(var_name='Channel', value_name='SSIM')
    .assign(Dataset="LSUN")
)

df_church = (
    pd.DataFrame(dict_metrics["ssim_churches"], columns=['R', 'G', 'B'])
    .melt(var_name='Channel', value_name='SSIM')
    .assign(Dataset="LSUN")
)

# The CPG tensor is averaged over axis 1 before being tabulated per channel.
df_cpg = (
    pd.DataFrame(dict_metrics["ssim_cpg"].mean(1), columns=['Mito.', 'AGP', 'RNA', 'ER', 'DNA'])
    .melt(var_name='Channel', value_name='SSIM')
    .assign(Dataset="CPJUMP1")
)

df_ssim = pd.concat([df_classroom, df_church, df_cpg])

# Box plot of SSIM per channel, one colour per dataset, without individual points.
fig_ssim = px.box(df_ssim, x='Channel', y='SSIM', color="Dataset", points=False)

fig_ssim.update_layout(
    font=dict(size=18, color='black'),
    # NOTE(review): bargap/bargroupgap are bar-trace settings; box traces are
    # spaced with boxgap/boxgroupgap — confirm whether these were meant.
    bargap=0.2,
    bargroupgap=0.1,
    xaxis=dict(showgrid=False),
    margin=dict(l=40, r=40, t=40, b=40),
    legend=dict(
        # The legend entries are datasets (colour="Dataset"), so it is titled
        # 'Dataset'; it previously read 'Channels'.
        title='Dataset',
        orientation='v',
        x=1,
        y=1,
        xanchor='right',
        yanchor='top',
        bgcolor='rgba(255,255,255,0.5)',
        bordercolor='black',
        borderwidth=1
    )
)
# Export the figure as a PDF.
fig_ssim.write_image("ssim.pdf")

In [24]:
# Long-format EMD tables (one row per value) with columns Channel, EMD, Dataset.
# Both LSUN subsets share the "LSUN" label.
df_classroom = (
    pd.DataFrame(dict_metrics["emd_classroom"], columns=['R', 'G', 'B'])
    .melt(var_name='Channel', value_name='EMD')
    .assign(Dataset="LSUN")
)

df_church = (
    pd.DataFrame(dict_metrics["emd_churches"], columns=['R', 'G', 'B'])
    .melt(var_name='Channel', value_name='EMD')
    .assign(Dataset="LSUN")
)

# The CPG tensor is averaged over axis 1 before being tabulated per channel.
df_cpg = (
    pd.DataFrame(dict_metrics["emd_cpg"].mean(1), columns=['Mito.', 'AGP', 'RNA', 'ER', 'DNA'])
    .melt(var_name='Channel', value_name='EMD')
    .assign(Dataset="CPJUMP1")
)

df_emd = pd.concat([df_classroom, df_church, df_cpg])

# Box plot of EMD per channel, one colour per dataset, without individual points.
fig_emd = px.box(df_emd, x='Channel', y='EMD', color="Dataset", points=False)
fig_emd.update_layout(
    font=dict(size=18, color='black'),
    # NOTE(review): bargap/bargroupgap are bar-trace settings; box traces are
    # spaced with boxgap/boxgroupgap — confirm whether these were meant.
    bargap=0.2,
    bargroupgap=0.1,
    xaxis=dict(showgrid=False),
    margin=dict(l=40, r=40, t=40, b=40),
    legend=dict(
        # The legend entries are datasets (colour="Dataset"), so it is titled
        # 'Dataset'; it previously read 'Channels'.
        title='Dataset',
        orientation='v',
        x=1,
        y=1,
        xanchor='right',
        yanchor='top',
        bgcolor='rgba(255,255,255,0.5)',
        bordercolor='black',
        borderwidth=1
    )
)
# Export the figure as a PDF.
fig_emd.write_image("emd.pdf")