In [None]:
from statsmodels.nonparametric.smoothers_lowess import lowess

import io
from PIL import Image, ImageChops
import pickle
import numpy as np
import math
import re

import panel as pn
from PIL import Image, ImageOps
import holoviews as hv
import altair as alt
from sklearn.decomposition import PCA
alt.data_transformers.disable_max_rows()
hv.extension("plotly")
from PIL import Image, ImageDraw, ImageFont
pn.extension("plotly")
pn.config.theme = 'dark'
import plotly.io as pio
hv.renderer('plotly').theme = 'dark'
import plotly.express as px
from torch.utils.tensorboard import SummaryWriter
from matplotlib.ticker import FuncFormatter
from datetime import datetime
from functools import partial, reduce

import matplotlib.pyplot as plt
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import numpy as np
import os
import pandas as pd
import seaborn as sns
sns.set_style('darkgrid')
sns.set_palette('muted')
sns.set_context("notebook", font_scale=1.5,
                rc={"lines.linewidth": 2.5})
import torch
import functools
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.functional as F
import gc
from tqdm import tqdm, trange
import lpips
import lpips
import os
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision.transforms.v2 import PILToTensor,Compose
import torchvision
import sys

sys.path.append(os.path.abspath(os.path.join('..')))
from src.models.denoiser import DenoiserModelPipeline
from src.md import generate_gausian_data
from src.preprocessing import ChromoDataContainer, ChromoData


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# This notebook analyses a trained denoiser model.
# It generates the plots and visualizations used to assess model performance and to illustrate the paper.

# Name of the trained model to analyse (passed to DenoiserModelPipeline.load)
trained_model_name = "full_labels_psize_4"

# Load the pipeline, move its model to the selected device and switch it to inference mode
pipeline = DenoiserModelPipeline.load(trained_model_name)
pipeline.model.to(device)
pipeline.model.eval()

# Container with the chromosome data; later cells read its `.gausians` / `.not_gausians` attributes
chromo_data = ChromoDataContainer.load_from_pkl()

In [None]:
# VISUALIZATION OF NOISE STEPS
# paper plot
# scatterplots of SD and noise at different steps

# Release unreferenced Python objects and cached CUDA memory before sampling
gc.collect()
torch.cuda.empty_cache()

# Modeling

# Number of diffusion timesteps
T = 1000
# Forward diffusion calculation parameters
betas = torch.linspace(0.0001, 0.02, T)  # (T,)
alphas = 1 - betas  # (T,)
alphas_cumprod = torch.cumprod(alphas, dim=-1)  # Cumulative product of alpha_t (T,) [a1, a2, a3, ....] -> [a1, a1*a2, a1*a2*a3, .....]
alphas_cumprod_prev = torch.cat((torch.tensor([1.0]), alphas_cumprod[:-1]), dim=-1)  # Cumulative product of alpha_t-1 (T,), [1, a1, a1*a2, a1*a2*a3, .....]
variance = (1 - alphas) * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)  # Variance used for denoising (T,)

# Model configuration values read from the loaded pipeline
max_label = pipeline.config.model.parameters.label_num // 2
img_size = pipeline.config.model.parameters.img_size
# NOTE(review): the two flags below are not read anywhere in the visible part of the notebook — confirm they are still needed
perform_ablation = True
single_sample = True

def backward_denoise(model, x, y, variance, clamp=False, perform_ablation=False, start_step=0):
    """Run the reverse diffusion loop from ``start_step`` down to 0 and record every stage.

    Args:
        model: callable invoked as ``model(x, t, y, do_ablation)``; must return
            a ``(noise_prediction, ablation)`` pair.
        x: starting sample (tensor or NumPy array); indexed as a 4-D batch.
        y: conditioning labels (tensor or NumPy array).
        variance: per-timestep variance used for the stochastic part of each step.
        clamp: when True, clamp the sample to [-1, 1] after every step.
        perform_ablation: forwarded to the model for timesteps in ``range(0, 1000)``.
        start_step: first (highest) timestep that is processed.

    Returns:
        Tuple ``(steps, steps_noized, noize_preds, denoized_unclamp,
        denoized_clamp, ablation_data)`` of lists with one entry per processed
        timestep (``ablation_data`` only for iterations where ablation ran).

    Side effects:
        Re-binds the module-level ``alphas`` and ``alphas_cumprod`` to their
        copies on ``device`` and puts ``model`` in eval mode.
    """
    global alphas, alphas_cumprod

    def _promote(value):
        # NumPy inputs become tensors on the active device; anything else passes through.
        if isinstance(value, np.ndarray):
            return torch.tensor(value).to(device)
        return value

    x, y, variance = _promote(x), _promote(y), _promote(variance)

    inputs_history = []       # sample fed to the model at each step
    outputs_history = []      # sample carried to the next step (not cloned)
    noise_history = []        # model noise predictions
    raw_history = []          # step result before optional clamping
    clamped_history = []      # step result after optional clamping
    ablation_history = []
    ablation_window = range(0, 1000)

    x = x.to(device)
    alphas = alphas.to(device)
    alphas_cumprod = alphas_cumprod.to(device)
    variance = variance.to(device)
    y = y.to(device)

    model.eval()
    with torch.no_grad():
        for timestep in reversed(range(start_step + 1)):
            inputs_history.append(x.clone())

            batch = x.size(0)
            t = torch.full((batch,), timestep).to(device)

            run_ablation = perform_ablation if timestep in ablation_window else False

            noise, ablation = model(x, t, y, run_ablation)
            noise_history.append(noise.clone())
            if run_ablation:
                ablation_history.append(ablation)

            # Broadcast the per-sample schedule values over (C, H, W).
            bshape = (batch, 1, 1, 1)
            alpha_t = alphas[t].view(*bshape)
            alpha_bar_t = alphas_cumprod[t].view(*bshape)
            noise_scale = (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)
            mean = 1 / torch.sqrt(alpha_t) * (x - noise_scale * noise)

            if timestep == 0:
                # Final step is deterministic: no fresh noise is added.
                x = mean
            else:
                x = mean + torch.randn_like(x) * torch.sqrt(variance[t].view(*bshape))
            raw_history.append(x.clone())

            x = torch.clamp(x, -1.0, 1.0).detach() if clamp else x.detach()

            clamped_history.append(x.clone())
            outputs_history.append(x)

    return outputs_history, inputs_history, noise_history, raw_history, clamped_history, ablation_history

def generate_sample(score_model, y, init_frame=None, strength=1., clamp=False, img_size=28, perform_ablation=False, start_step_arg=T-1):
    """Generate one sample for label ``y`` by running the reverse diffusion loop.

    Args:
        score_model: denoiser passed through to ``backward_denoise``.
        y: conditioning label; wrapped into a one-element tensor.
        init_frame: optional starting image with values in [0, 1]. When given,
            it is rescaled to [-1, 1], noised to the starting timestep and
            denoised from there. When None, sampling starts from Gaussian
            noise of shape ``(1, 1, img_size, img_size)``.
        strength: fraction of ``start_step_arg`` used as the starting timestep
            when ``init_frame`` is given.
        clamp: forwarded to ``backward_denoise``.
        img_size: side length of the generated image when starting from noise.
        perform_ablation: forwarded to ``backward_denoise``.
        start_step_arg: number of timesteps to run (the first processed
            timestep is ``start_step_arg - 1``).

    Returns:
        ``(final_img, steps, steps_noized, noize_preds, denoized_unclamp,
        denoized_clamp, ablation_data)`` where ``final_img`` is the first
        element of the last step moved to CPU and mapped from [-1, 1] to [0, 1].
    """
    variance_arg = variance.clone()
    if init_frame is None:
        x = torch.randn(size=(1, 1, img_size, img_size))
        start_step = start_step_arg - 1
    else:
        x = init_frame * 2 - 1  # Scale to [-1, 1]
        x = x.to(device)

        # Starting timestep derived from strength. It is floored at 0 so a small
        # strength neither yields an empty denoising loop nor a negative index
        # that wraps around to the noisiest end of the schedule.
        start_step = max(int(start_step_arg * strength) - 1, 0)

        # Noise the frame to exactly the timestep the reverse loop starts from,
        # i.e. x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps with t = start_step.
        noise = torch.randn_like(x)
        alpha_cumprod = alphas_cumprod[start_step]
        x = torch.sqrt(alpha_cumprod) * x + torch.sqrt(1 - alpha_cumprod) * noise

        # NOTE(review): this clamps (truncates) the noised frame to [-1, 1]; it
        # does not rescale it. Kept from the original — confirm it is intended.
        x = torch.clamp(x, -1, 1)

    y = torch.tensor([y])
    steps, steps_noized, noize_preds, denoized_unclamp, denoized_clamp, ablation_data = backward_denoise(
        score_model,
        x,
        y,
        variance_arg,
        clamp=clamp,
        perform_ablation=perform_ablation,
        start_step=start_step
    )

    # Map the final sample from [-1, 1] back to [0, 1]
    final_img = (steps[-1][0].to('cpu') + 1) / 2

    return final_img, steps, steps_noized, noize_preds, denoized_unclamp, denoized_clamp, ablation_data


def recover_df_from_tensor(tensor):
    """Convert a batch tensor into one flat DataFrame per batch element.

    Axis 1 of ``tensor`` is squeezed, the result is moved to CPU and converted
    to NumPy, and every entry along the first axis is flattened into a frame
    with columns ``POS_bin`` (0..n-1) and ``Gxx_ratio`` (the flattened values).

    Returns an empty list when ``tensor`` is None.
    """
    if tensor is None:
        return []

    values = np.squeeze(tensor, axis=1).cpu().detach().numpy()

    def _to_frame(flat_values):
        # One row per flattened position.
        return pd.DataFrame({
            'POS_bin': np.arange(flat_values.shape[0]),
            'Gxx_ratio': flat_values,
        })

    return [_to_frame(values[k].flatten()) for k in range(values.shape[0])]

def generate_full_sequence(score_model, y_max=24, init_frame=None, strength=0.5, clamp=False, img_size=28, perform_ablation=False, single_label=False, hide_tqdm=False, start_step_arg=T):
    """Generate samples for a set of labels and collect all intermediate traces.

    With ``single_label=True`` only the label ``y_max`` is generated; otherwise
    labels ``0 .. y_max - 1`` are generated in order. Every other argument is
    forwarded to ``generate_sample``.

    Returns:
        ``(df_gxxs, all_steps, all_steps_noized, all_noize_preds,
        all_denoized_unclamp, all_denoized_clamp, all_ablation_data)`` where
        ``df_gxxs`` is the concatenation (with a fresh index) of the frames
        recovered from every final image, and each remaining item is a list
        with one entry per generated label.
    """
    labels = [y_max] if single_label else range(y_max)

    # One bucket per value returned by generate_sample, in return order:
    # final image, steps, noised steps, noise predictions, unclamped, clamped, ablation.
    buckets = tuple([] for _ in range(7))
    for label in tqdm(labels, disable=hide_tqdm):
        outputs = generate_sample(
            score_model,
            y=label,
            init_frame=init_frame,
            strength=strength,
            clamp=clamp,
            img_size=img_size,
            perform_ablation=perform_ablation,
            start_step_arg=start_step_arg,
        )
        for bucket, item in zip(buckets, outputs):
            bucket.append(item)

    final_images, *traces = buckets

    # Flatten every final image into its per-sample frames, then stack them.
    frames = [frame for image in final_images for frame in recover_df_from_tensor(image)]
    df_gxxs = pd.concat(frames).reset_index(drop=True)

    return (df_gxxs, *traces)
      

# Generate one sequence per label (0, 28, ..., 756) and keep each resulting DataFrame.
# The trace variables (all_steps, all_noized_steps, ...) are overwritten on every
# iteration, so after the loop they hold the traces of the LAST label only.
chromo_dfs = []
for i in tqdm(range(0, 768, 28)):
    df_gxxs, all_steps, all_noized_steps, all_noize_preds, all_denoized_unclamp, all_denoized_clamp, all_ablation_data = generate_full_sequence(pipeline.model, y_max=i, clamp=False, img_size=img_size, single_label=True, hide_tqdm=True)
    chromo_dfs.append(df_gxxs)

# Plot styling for the paper figure
sns.set_theme(style='white', font_scale=1.5)
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams.update({'font.size': 10})

# Indexes into the per-label trace lists; each generation above used a single
# label, so only index 0 exists.
indexes_to_sample = [0]
# steps_to_show = [700, 750, 800, 850, 900, 950, 999]
# steps_to_show = [400, 600, 800, 900, 925, 950, 975, 985, 999]
# Positions in the recorded trace lists (loop iteration number, not diffusion timestep)
steps_to_show = [400, 800, 925, 975, 999]

# select the traces of the sampled chunks
pred_noized_vals = [all_noized_steps[i] for i in indexes_to_sample]
pred_noize_vals = [all_noize_preds[i] for i in indexes_to_sample]
pred_unc_vals = [all_denoized_unclamp[i] for i in indexes_to_sample]
pred_clamp_vals = [all_denoized_clamp[i] for i in indexes_to_sample]

# keep only the selected steps of each chunk
pred_noized_vals = [[pred_noized_vals[i][j] for j in steps_to_show] for i in range(len(pred_noized_vals))]
pred_noize_vals = [[pred_noize_vals[i][j] for j in steps_to_show] for i in range(len(pred_noize_vals))]
pred_unc_vals = [[pred_unc_vals[i][j] for j in steps_to_show] for i in range(len(pred_unc_vals))]
pred_clamp_vals = [[pred_clamp_vals[i][j] for j in steps_to_show] for i in range(len(pred_clamp_vals))]

# convert the tensor of each step into a list of DataFrames
pred_noized_vals = [[recover_df_from_tensor(pred_noized_vals[i][j]) for j in range(len(pred_noized_vals[i]))] for i in range(len(pred_noized_vals))]
pred_noize_vals = [[recover_df_from_tensor(pred_noize_vals[i][j]) for j in range(len(pred_noize_vals[i]))] for i in range(len(pred_noize_vals))]
pred_unc_vals = [[recover_df_from_tensor(pred_unc_vals[i][j]) for j in range(len(pred_unc_vals[i]))] for i in range(len(pred_unc_vals))]
pred_clamp_vals = [[recover_df_from_tensor(pred_clamp_vals[i][j]) for j in range(len(pred_clamp_vals[i]))] for i in range(len(pred_clamp_vals))]

# zip per chunk into (noised, noise, unclamped, clamped)
pred_vals = list(zip(
    pred_noized_vals, 
    pred_noize_vals, 
    pred_unc_vals,
    pred_clamp_vals
))
# NOTE(review): the first assignment is dead — it is overwritten on the next line,
# so only the first two series (noised input and predicted noise) are plotted.
len_charts = len(pred_vals[0])
len_charts = 2

# plot: one row per (chunk, series), one column per selected step
fig, axes = plt.subplots(len(pred_vals) * len_charts, len(steps_to_show), figsize=(15, 6))
type_name = ['Noised', 'Noise', 'Unclamp', 'Clamp']
type_colors = ['red', 'blue', 'green', 'purple']
# Letters used to mark the rows
row_letters = ['A)', 'B)', 'C)', 'D)']

for i in range(len(pred_vals)):
    for j in range(len_charts):
        for k in range(len(steps_to_show)):
            dfs = pred_vals[i][j][k]
            # recover_df_from_tensor returns [] for a missing tensor; leave that axis empty
            if len (dfs) == 0:
                continue
            df = dfs[0]
            ax=axes[i*len_charts + j, k]
            
            # Add the row letter next to the first column
            if k == 0:
                ax.text(-0.2, 0.91, row_letters[i*len_charts + j], 
                        transform=ax.transAxes,
                        fontsize=24,
                        fontweight='bold',
                        fontfamily='Times New Roman')
            
            
            # if k == 0:
            # ax.set_title(f"Chunk {indexes_to_sample[i]} Image {k} {type_name[j]}")
            ax.set_title(f"")
            sns.scatterplot(data=df, x='POS_bin', y='Gxx_ratio', color=type_colors[j], ax=ax, s=7)
            # hide y axis on all but first and last
            # if k>0 or k < len(steps_to_show):
                # ax.yaxis.set_visible(False)
            # else:
            # Hide both axes and all ticks; the x axis is re-enabled below for the bottom row
            ax.yaxis.set_visible(False)
            ax.xaxis.set_visible(False)
            # ax.set_yticks([-2, -1, 0, 1, 2])
            ax.set_yticks([],[])
            ax.set_xticks([],[])
            if j == len_charts - 1:
                # Bottom row of the chunk: show the step number as the x label
                ax.xaxis.set_visible(True)
                ax.set_xlabel(f"Step {steps_to_show[k]}")
                ax.set_ylabel(f"")
            else:
                ax.set_xticks([],[])
                ax.set_xlabel("")
                ax.set_ylabel(f"") 

# Tighten the layout and render the figure
plt.tight_layout()
plt.show()

        

In [None]:
# VISUALISATION OF THE DENOISING PROCESS
# plots of SD and noise at different steps, with a grayscale image visualisation of the SD data
# scaled to 50% of the original paper size to fit inline here

# Switch Panel back to the default theme and HoloViews to the bokeh backend for this figure
pn.config.theme = 'default'
pn.extension()
# NOTE(review): the returned palette is discarded, so this line does not change the active palette — confirm whether sns.set_palette was intended
sns.color_palette("tab10")
hv.extension("bokeh")
def get_paper_denoising_vis(label, index):
    """Build the three-part Panel figure illustrating denoising for one label.

    Part A is a scatter of the generated sequence, part B shows a few of its
    values as gray squares (flat row, then regrouped into rows of three), and
    part C shows the unclamped sample at selected recorded steps as grayscale
    images. A left-pointing arrow is placed at a height chosen from ``label``.

    Args:
        label: conditioning label passed to ``generate_full_sequence`` with
            ``single_label=True``.
        index: suffix appended to the part letters (e.g. ``A1)``).

    Returns:
        A Panel object wrapping the assembled row.
    """
    # Only the final frame and the unclamped trace are used below.
    df_gxxs, _, _, _, all_denoized_unclamp, _, _ = generate_full_sequence(
        pipeline.model,
        y_max=label,
        clamp=False,
        img_size=img_size,
        single_label=True,
        hide_tqdm=True)

    # Part A: static scatter plot (no toolbar, axes, grid or frame)
    scatter = hv.Scatter(df_gxxs, kdims=['POS_bin'], vdims=['Gxx_ratio'])
    scatter.opts(
        width=150,
        height=150,
        size=2,
        show_grid=False,
        show_legend=False,
        toolbar=None,
        xlabel='',  # hide x label
        ylabel='',  # hide y label
        xaxis=None, # hide x axis
        yaxis=None, # hide y axis
        bgcolor='white',
        show_frame=False,  # Hide frame
        tools=[],
        active_tools=[]
    )

    # Part B: take the first 3, middle 3 and last 3 values of the sequence
    el_to_take = 3
    n_values = len(df_gxxs)
    indexes_to_take = (
        list(range(0, el_to_take))
        + list(range(n_values // 2 - el_to_take, n_values // 2))
        + list(range(n_values - el_to_take, n_values))
    )
    values_of_indexes = df_gxxs.iloc[indexes_to_take]

    def _gray_css(val):
        # The sequence is generated with clamp=False, so values may fall outside
        # [0, 1]; clip them so every rgb() component stays within 0..255.
        level = int(255 * min(max(float(val), 0.0), 1.0))
        return f'rgb({level},{level},{level})'

    # One small square per value, shaded by the value
    squares = [pn.pane.HTML(width=15, height=15, styles={
        'background': _gray_css(val),
        'border': '1px solid black',
    }) for val in values_of_indexes['Gxx_ratio']]
    # The same squares regrouped into rows of el_to_take
    squares_flatten = [pn.Row(*squares[i:i+el_to_take]) for i in range(0, len(squares), el_to_take)]

    # Arrow between the flat row and the regrouped rows
    arrow_down = pn.widgets.ButtonIcon(icon='arrow-down', size='4em')

    # Part C: positions in the recorded trace shown as images
    steps_to_take = [0, 900, 999]
    plot_images = []
    for step in steps_to_take:
        frame = all_denoized_unclamp[0][step].cpu().detach().numpy()
        # Take the image size from the tensor itself instead of assuming 28x28,
        # so the figure follows whatever img_size the model was configured with.
        img_height, img_width = frame.shape[-2:]
        img = np.array(frame).reshape(img_height, img_width)
        image_plot = hv.Image(img, bounds=(0, 0, img_width, img_height))
        image_plot.opts(
            width=150,
            height=150,
            cmap='gray',
            toolbar=None,
            xticks=None,  # Remove x-axis ticks
            yticks=None,  # Remove y-axis ticks
            xaxis=None,   # Remove x-axis line and labels
            yaxis=None    # Remove y-axis line and labels
        )
        text_image_step = pn.pane.Markdown(
             f"<span style='font-family: \"Times New Roman\"; font-size: 32px; width: 100%; text-align: center; display: block;'>Step {step}</span>",
            align='center')
        plot_images.append(pn.Column(
            image_plot,
            text_image_step
        ))

    # Left-pointing arrow whose vertical offset depends on the label:
    # label 0 -> top, labels below 783 -> middle, otherwise bottom.
    # NOTE(review): 784 is hard-coded here and in the caller — presumably the
    # number of labels; confirm against the model configuration.
    arrow_left = pn.widgets.ButtonIcon(icon='arrow-left', size='4em')
    if label == 0:
        arrow_left_height = -25 # top
    elif label < 784 - 1:
        arrow_left_height = 25 # middle
    else:
        arrow_left_height = 75 # bottom
    arrow_left_height = arrow_left_height + 75 # respect margin

    label_style = {
                'position': 'absolute',
                'top': '-120px',
                'left': '-16px',
                'z-index': '1000',
                'color': 'black',
                'font-size': '48px',
                'font-family': 'Times New Roman',
                'font-weight': 'bold',
    }

    divider_style = {
        'width': '1px',
        'background-color': 'black',
        'position': 'fixed',
    }

    def _part_label(letter):
        # Bold corner label such as "A1)" for each part of the figure
        return pn.pane.Markdown(f"**{letter}{index})**", styles=label_style)

    label_a = _part_label("A")
    label_b = _part_label("B")
    label_c = _part_label("C")

    panel_vis = pn.panel(
        pn.Row(
            pn.Column(
                label_a,
                scatter,
                margin=(50,50,50,50)
            ),
            # vertical divider
            pn.Column(
                pn.pane.HTML(styles=divider_style, width=1, sizing_mode="stretch_height")
            ),
            pn.Column(
                label_b,
                pn.Spacer(height=50),
                pn.Row(*squares),
                pn.Row(arrow_down),
                pn.Column(*squares_flatten),
                pn.Spacer(height=50),
                width=350,
                height=350,
                margin=(50,50,50,50)
            ),
            # vertical divider
            pn.Column(
                pn.pane.HTML(styles=divider_style, width=1, sizing_mode="stretch_height")
            ),
            pn.Column(
                label_c,
                pn.Row(*plot_images),
                margin=(50,50,50,50)
            ),
            pn.Column(
                pn.Spacer(height=arrow_left_height),
                arrow_left,
            )
        )
    )
    # Returned so the caller can display it in the notebook
    return panel_vis
# Stack the figure for the first label (0), a middle label (784 // 2) and the last label (784 - 1);
# the second argument is the numeric suffix of the part letters (A1, A2, A3, ...).
panel_all = pn.Column(
    get_paper_denoising_vis(0,1),
    get_paper_denoising_vis(784//2,2),
    get_paper_denoising_vis(784-1,3)
)
# Last expression of the cell: renders the panel inline
panel_all



In [None]:
# plot two chromos, him-4(u924) chromosome with experimentally detected misrepresented token adn sparse synthetic generated gausian chromosomical data
%matplotlib inline

gausians = chromo_data.gausians
not_gausians = chromo_data.not_gausians


sns.set_theme(style='white', font_scale=4.5)
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams.update({'font.size': 10})

# create 2 axes figure
fig, ax = plt.subplots(1, 2, figsize=(40, 10))
ax_g54 = ax[0]
ax_sd = ax[1]

#G54_x
g54_x = [g for g in gausians if g.g_number == 'G54' and g.chrom_label == 'X'][0]
chromo_data = g54_x.array
mu = g54_x.m_index
# set plot size
sns.scatterplot(data=chromo_data, x='POS', y='Gxx_ratio', color="dodgerblue", ax = ax_g54)
# put red line on index mu
ax_g54.axvline(x=chromo_data.iloc[mu]['POS'], color='red', linestyle='--', ymin=0.7, ymax=0.95, dashes=(5, 5), linewidth=4)
ax_g54.set_title('$\it{him-4(u924)}$')
ax_g54.set_xlabel('Physical Position (Mb)')
ax_g54.set_ylabel('Hawaiian/Total Ratio')
ax_g54.set_yticks([0, 0.25, 0.5, 0.75, 1])

# set bold A) on a first plot

ax_g54.text(-0.15, 1.1, 'A)',
            transform=ax_g54.transAxes,
            fontsize=80,
            fontweight='bold',
            fontfamily='Times New Roman')

# Set x-ticks quantization
def quantize_ticks(x, pos):
    """Tick formatter: floor-divide the tick value by 1e6 and render it without decimals.

    ``pos`` is part of the ``FuncFormatter`` callback signature and is not used.
    """
    millions = x // 1e6
    return format(millions, '.0f')

# Show the x ticks of the first plot through quantize_ticks
ax_g54.xaxis.set_major_formatter(FuncFormatter(quantize_ticks))

# SD: generated sequence taken from just before the middle of chromo_dfs
sd_df= chromo_dfs[len(chromo_dfs)//2-1]
# clip the values to [0, 1] into a new column
# (this adds the column to the frame stored in chromo_dfs, since sd_df is not a copy)
sd_df['Gxx_ratio_clamp'] = sd_df['Gxx_ratio'].clip(0, 1)
sns.scatterplot(data=sd_df, x=sd_df.index, y='Gxx_ratio_clamp', color='dodgerblue', ax=ax_sd)
ax_sd.set_xlabel('Element Position')
ax_sd.set_ylabel('')
ax_sd.set_yticks([0, 0.25, 0.5, 0.75, 1])

# bold "B)" marker on the second plot

ax_sd.text(-0.15, 1.1, 'B)',
            transform=ax_sd.transAxes,
            fontsize=80,
            fontweight='bold',
            fontfamily='Times New Roman')
    
plt.show()

In [None]:
# Switch HoloViews back to the plotly backend for the 3D figures
hv.extension('plotly')
# NOTE(review): not read in the visible part of the notebook — presumably gates saving figures further down; confirm
save_to_paper = False

# G54 entries: chromosome X from the `gausians` list and chromosome V from the `not_gausians` list
g_54_x = [g for g in gausians if g.g_number == 'G54' and g.chrom_label == 'X'][0]
g_54_v = [g for g in not_gausians if g.g_number == 'G54' and g.chrom_label == 'V'][0]
# Data frame of the chromosome V entry
chromo_data_points = g_54_v.array


def get_brooklyn_layer(grid_size, patch_count, scale_factor, height, rescale_side=100, debug=False):
    """Build one layer of grid points whose patches are contracted toward their centers.

    A ``grid_size x grid_size`` integer grid is split into
    ``patch_count x patch_count`` square patches of side
    ``grid_size // patch_count``. Within each patch every point is moved toward
    the patch center by ``scale_factor`` (1 keeps it in place, 0 collapses the
    patch onto its center). Points beyond ``patch_count * patch_size`` (when
    ``grid_size`` is not divisible by ``patch_count``) keep their grid
    coordinates.

    Args:
        grid_size: number of grid points per side.
        patch_count: number of patches per side.
        scale_factor: contraction factor applied around each patch center.
        height: value of the z coordinate for every point.
        rescale_side: x and y are linearly mapped from ``[0, grid_size]`` to
            ``[0, rescale_side]``.
        debug: print patch geometry when True.

    Returns:
        ``(x_transformed, y_transformed, z_transformed, patches)`` where x and
        y are 2-D arrays of shape ``(grid_size, grid_size)``, z is a flat array
        with ``grid_size ** 2`` entries, and ``patches`` lists the
        ``(center_x, center_y)`` of every patch.
    """
    # Side length of one patch, in grid points
    patch_size = grid_size // patch_count

    if debug:
        print(f"Grid size: {grid_size}, Patch count: {patch_count}, Patch size: {patch_size}")

    # Integer grid; astype(float) already returns a fresh array to write into
    x, y = np.meshgrid(np.arange(grid_size), np.arange(grid_size))
    x_transformed = x.astype(float)
    y_transformed = y.astype(float)
    patches = []

    for i in range(patch_count):
        x_start = i * patch_size
        x_end = (i + 1) * patch_size
        center_x = (x_start + x_end) / 2
        for j in range(patch_count):
            y_start = j * patch_size
            y_end = (j + 1) * patch_size
            center_y = (y_start + y_end) / 2

            if debug:
                print(f"Patch {i}, {j}")
                print(f"x_start: {x_start}, x_end: {x_end}, y_start: {y_start}, y_end: {y_end}")
                print(f"Center x: {center_x}, Center y: {center_y}")

            # Rows are indexed by y and columns by x
            rows = slice(y_start, y_end)
            cols = slice(x_start, x_end)

            # Contract the patch points toward the patch center
            x_transformed[rows, cols] = center_x + (x[rows, cols] - center_x) * scale_factor
            y_transformed[rows, cols] = center_y + (y[rows, cols] - center_y) * scale_factor

            patches.append((center_x, center_y))

    # Constant z for the whole layer. Computed once after the loops: it does not
    # depend on the patch, and previously it was only bound as a side effect of
    # the outer loop body.
    z_transformed = np.ones_like(x.flatten()) * height

    # Rescale x and y from [0, grid_size] to [0, rescale_side]
    x_transformed = np.interp(x_transformed, (0, grid_size), (0, rescale_side))
    y_transformed = np.interp(y_transformed, (0, grid_size), (0, rescale_side))

    return x_transformed, y_transformed, z_transformed, patches

def get_brooklyn_3d_plot(
        chromo_data_points,  
        dot_from_x = 106,
        dot_from_y = 6,
        dot_from_z = 0.35,
        show_text_plane = False,
        misrepresented_token_position = None,
        show_bulbs = True,
    ):
    """Build a 3D plotly scatter of chromosome data as two stacked point layers.

    The "floor" layer has ``side_max ** 2`` points sampled evenly from
    ``chromo_data_points['Gxx_ratio']``; the "bulb" layer has ``side_min ** 2``
    points, one per point cloud. Marker color and size are derived from the
    sampled ratio values.

    Args:
        chromo_data_points: frame with at least ``'POS'`` and ``'Gxx_ratio'`` columns.
        dot_from_x, dot_from_y, dot_from_z: origin of the four connector lines
            drawn when ``show_text_plane`` is True; overwritten with the
            highlighted point when ``misrepresented_token_position`` is given.
        show_text_plane: add a 32x32 plane of random A/T/G/C letters plus four
            lines connecting it to the origin point.
        misrepresented_token_position: optional ``POS`` value; the floor point
            closest to it is highlighted and everything outside its cell is
            moved out of the visible z range.
        show_bulbs: when False the bulb layer is pushed out of the visible z range.

    Returns:
        The plotly figure.
    """
    # Number of measurements grouped into one point cloud
    const_measurement_per_point_cloud = 200
    len_chromo_data = len(chromo_data_points['POS'])
    num_of_point_clouds = len_chromo_data // const_measurement_per_point_cloud
    # side_min: point clouds per side; side_max: floor points per side
    side_min = math.ceil(math.sqrt(num_of_point_clouds))
    side_max = math.ceil(math.sqrt(len_chromo_data))
    points_in_section = (side_max//side_min)**2

    # adjust side_max so it is divisible by side_min
    side_max = side_min * (side_max // side_min)
    print(f"Side min: {side_min}, Side max: {side_max}, Num of point clouds: {num_of_point_clouds}, chromo data len: {len_chromo_data}, Points in section: {points_in_section}")

    layers = []
    # floor level: one patch per point cloud, each contracted to 80% around its center
    floor = get_brooklyn_layer(grid_size=side_max, patch_count=side_min, scale_factor=0.8, height=1.0, rescale_side=side_max)
    # floor power: side_max**2 ratio values sampled at evenly spaced row positions
    floor_power_idx = np.linspace(0, len(chromo_data_points)-1, side_max**2, dtype=int)
    floor_power = chromo_data_points['Gxx_ratio'].iloc[floor_power_idx].to_numpy()
    floor_color = floor_power.copy()
    floor_color = floor_color / 2.2 # compress the color values (divide by 2.2)
    # z of each floor point grows with its ratio value
    floor_z = floor[2] * floor_power * 0.3 + 0.2
    print(f"Floor z: {floor_z}")
    floor_x = floor[0]
    floor_y = floor[1]
    
    if misrepresented_token_position is not None:
        # get the positional index of the POS closest to the misrepresented token
        closest_idx = np.argmin(np.abs(chromo_data_points['POS'] - misrepresented_token_position))
        # get its relative position
        relative_pos = closest_idx/len_chromo_data
        print(f"Closest idx: {closest_idx}, Relative pos: {relative_pos}")
        # detect in which point cloud (cell) the misrepresented token is
        # NOTE(review): for a small relative_pos this becomes -1, which makes cell_y negative below — confirm the -1 offset
        cell_idx = int(relative_pos * side_min**2) - 1 # not sure why -1
        point_cloud_idx = int(relative_pos * len(floor_z))  
        print(f"Closest idx: {closest_idx}, Cell idx: {cell_idx}, Point cloud idx: {point_cloud_idx}")
        cell_x = cell_idx % side_min
        cell_y = cell_idx // side_min
        print(f"Cell x: {cell_x}, Cell y: {cell_y}")
        # bounds of the cell in floor-grid coordinates
        cell_x_start = cell_x * (side_max // side_min)
        cell_y_start = cell_y * (side_max // side_min)
        cell_x_end = cell_x_start + (side_max // side_min)
        cell_y_end = cell_y_start + (side_max // side_min)
        # position of the misrepresented token inside the point cloud (only printed below)
        misrepresented_token_position_x = point_cloud_idx % points_in_section
        misrepresented_token_position_y = point_cloud_idx // points_in_section
        # force the size source of the highlighted point to 2
        floor_power[point_cloud_idx] = 2
        # set color to max
        floor_color[point_cloud_idx] = 1
        # the connector lines start from the highlighted point
        dot_from_x = floor_x.flatten()[point_cloud_idx]
        dot_from_y = floor_y.flatten()[point_cloud_idx]
        dot_from_z = floor_z[point_cloud_idx]
        # scale x,y,z around dot_from_x, dot_from_y, dot_from_z
        # NOTE(review): floor_x / floor_y are 2-D (indexing by i scales a whole row), while floor_z is
        # flat, so only its first len(floor_x) entries are rescaled here — confirm this is intended
        scale_Factor= 2
        for i in range(len(floor_x)):
            floor_x[i] = dot_from_x + (floor_x[i] - dot_from_x) * scale_Factor
            floor_y[i] = dot_from_y + (floor_y[i] - dot_from_y) * scale_Factor
            floor_z[i] = dot_from_z + (floor_z[i] - dot_from_z) * scale_Factor
        # raise the highlighted point by 0.05
        floor_z[point_cloud_idx] = floor_z[point_cloud_idx] + 0.05
        print(f"Misrepresented token position x: {misrepresented_token_position_x}, Misrepresented token position y: {misrepresented_token_position_y}")
        # reshape floor_z to [side_max, side_max] so the cell can be addressed by row/column
        floor_z = floor_z.reshape(side_max, side_max)
        
        ## visible neighborhood 
        # z_non_mtp_delta = np.max(floor_z) / 4
        # floor_z[0:cell_y_start] = floor_z[0:cell_y_start] - z_non_mtp_delta
        # floor_z[cell_y_end:] = floor_z[cell_y_end:] - z_non_mtp_delta
        
        # floor_z[:, 0:cell_x_start] = floor_z[:, 0:cell_x_start] - z_non_mtp_delta
        # floor_z[:, cell_x_end:] =  floor_z[:, cell_x_end:] - z_non_mtp_delta
        
        # not visible neighborhood: everything outside the cell gets z = 1000,
        # far above the z-axis range [0, 2] set in the layout below
        z_non_mtp = 1000
        floor_z[0:cell_y_start] = z_non_mtp
        floor_z[cell_y_end:] = z_non_mtp
        
        floor_z[:, 0:cell_x_start] = z_non_mtp
        floor_z[:, cell_x_end:] = z_non_mtp
        
        print(f"Cell x start: {cell_x_start}, Cell x end: {cell_x_end}, Cell y start: {cell_y_start}, Cell y end: {cell_y_end}")
        floor_z = floor_z.flatten()
    
    # layer tuple layout: (x, y, z, color, size)
    layers.append((floor_x, floor_y, floor_z, floor_color, floor_power * 0.5))

    # bulb level: one point per point cloud (scale_factor=0 collapses each patch onto its center)
    bulbs = get_brooklyn_layer(grid_size=side_min, patch_count=side_min, scale_factor=0, height=2.0, rescale_side=side_max)
    bulbs_power_idx = np.linspace(0, len(chromo_data_points)-1, side_min**2, dtype=int)
    bulbs_power = chromo_data_points['Gxx_ratio'].iloc[bulbs_power_idx].to_numpy()
    bulbs_color = bulbs_power.copy()
    # the color is shifted into the upper half of the color scale further below

    bulbs_z = bulbs[2] * bulbs_power * 0.5 + 0.1
    bulbs_x = bulbs[0]
    bulbs_y = bulbs[1]
    bulbs_power = bulbs_power * 2
    bulbs_color = bulbs_color / 2 + 0.5 # halve and shift by 0.5

    if not show_bulbs:
        # push the bulbs far outside the visible z range instead of removing them
        bulbs_z = bulbs_z * 10000
        
    layers.append((bulbs_x, bulbs_y, bulbs_z, bulbs_color, bulbs_power * 2))

    # Flatten the per-layer arrays and concatenate both layers into single columns
    x = [x.flatten() for x, _, _, _, _ in layers]
    y = [y.flatten() for _, y, _ , _, _ in layers]
    z = [z for _, _, z, _, _ in layers]
    c = [c for _, _, _, c, _ in layers]
    s = [s for _, _, _, _, s in layers]

    x = np.concatenate(x)
    y = np.concatenate(y)
    z = np.concatenate(z)
    c = np.concatenate(c)
    s = np.concatenate(s)
    
    print(f"X: {x.shape}, Y: {y.shape}, Z: {z.shape}, C: {c.shape}, S: {s.shape}")
    df = pd.DataFrame({'x': x, 'y': y, 'z': z, 'color': c, 'size': s})
    print(f"Dataframe shapes: {df.shape}")
    # Create 3D scatter plot    
    fig = px.scatter_3d(
        df, 
        x='x', 
        y='y', 
        z='z', 
        color='color', 
        size='size',
    )
    # marker outline width 1, fully opaque markers
    fig.update_traces(
        marker=dict(
            line=dict(width = 1), 
            opacity = 1.0,
        )
    )

    if show_text_plane:
        # create a 32x32 array of random A/T/G/C letters
        atgc_squared_size = 32
        letters = np.random.choice(['A', 'T', 'G', 'C'], (atgc_squared_size, atgc_squared_size))
        # 'M' marks the highlighted letter: rendered as a bold red 'A'
        char_mat = {
            'A': 'A',
            'T': 'T',
            'G': 'G',
            'C': 'C',
            'M': '<b>A</b>'
        }
        colors_map = {
            'A': 'black',
            'T': 'black',
            'G': 'black',
            'C': 'black',
            'M': 'red'
        }
        size_map = {
            'A': 7,
            'T': 7,
            'G': 7,
            'C': 7,
            'M': 7
        }
        print(f"Letters shape: {letters.shape}")
        for i in range(len(letters)):
            row = letters[i]
            # mark the center letter of the middle row when a token position is given
            if misrepresented_token_position is not None and i == len(letters) // 2:
                row[len(row)//2] = 'M'
                
        # Create grid of points for text

        # the letter plane covers the range from side_max / 2 to side_max - 1 on both axes
        x_start_letters = side_max / 2
        y_start_letters = side_max / 2
        x_end_letters = side_max - 1
        y_end_letters = side_max - 1

        x_points = np.linspace(x_start_letters, x_end_letters, atgc_squared_size)
        y_points = np.linspace(y_start_letters, y_end_letters, atgc_squared_size)
        X, Y = np.meshgrid(x_points, y_points)

        text_x = X.flatten()
        text_y = Y.flatten()

        # four lines connecting the corners of the letter plane to the origin point
        line_width = 2
        line_1 = go.Scatter3d(
            x=[dot_from_x, x_start_letters],
            y=[dot_from_y, y_start_letters],
            z=[dot_from_z, 0.1],
            mode='lines',
            line=dict(color='black', width=line_width),
            showlegend=False
        )
        line_2 = go.Scatter3d(
            x=[dot_from_x, x_end_letters],
            y=[dot_from_y, y_start_letters],
            z=[dot_from_z, 0.1],
            mode='lines',
            line=dict(color='black', width=line_width),
            showlegend=False
        )
        line_3 = go.Scatter3d(
            x=[dot_from_x, x_start_letters],
            y=[dot_from_y, y_end_letters],
            z=[dot_from_z, 0.1],
            mode='lines',
            line=dict(color='black', width=line_width),
            showlegend=False
        )
        line_4 = go.Scatter3d(
            x=[dot_from_x, x_end_letters],
            y=[dot_from_y, y_end_letters],
            z=[dot_from_z, 0.1],
            mode='lines',
            line=dict(color='black', width=line_width),
            showlegend=False
        )


        # letters drawn as a text-only trace at z = 0.01
        text_trace = go.Scatter3d(
            x=text_x,
            y=text_y,
            z=np.full(X.size, 0.01),  # Slightly above surface
            mode='text',
            text=[char_mat[letter] for row in letters for letter in row],
            textfont=dict(
                size=[size_map[letter] for row in letters for letter in row],
                family='Courier New, monospace',
                color=[colors_map[letter] for row in letters for letter in row]
            ),
            showlegend=False
        )

        # fig.add_trace(surface_plane)
        fig.add_trace(text_trace)

        # Add lines
        fig.add_trace(line_1)
        fig.add_trace(line_2)
        fig.add_trace(line_3)
        fig.add_trace(line_4)
        

    # fixed camera position
    fig.update_layout(
    scene=dict(
        camera=dict(
            eye=dict(x=1.85, y=0.99, z=1.1),
        ),
        # Ensure axes ranges are appropriate
    )
    )

    # top color of the scale: red when a token is highlighted, orange otherwise
    last_color = 'orange'
    if misrepresented_token_position is not None:
        last_color = 'red'

    # hide grids, backgrounds, tick labels, titles, legend and color bar; fix the axis ranges
    fig.update_layout(
        scene=dict(
            zaxis=dict(range=[0, 2], showgrid=False, showbackground=False, showticklabels=False, title=''),
            xaxis=dict(range=[0, side_max], showgrid=False, showbackground=False, showticklabels=False, title=''),
            yaxis=dict(range=[0, side_max], showgrid=False, showbackground=False, showticklabels=False, title='')
        ),
        width=800, 
        height=800,
        coloraxis=dict(colorscale=[[0, 'black'], [0.25, 'black'],  [0.5, 'yellow'], [0.55, 'purple'], [0.85, 'purple'], [1, last_color]]),
        showlegend=False,
        coloraxis_showscale=False
    )
    return fig

def save_brooklyn_plot(fig):
    """Render ``fig`` to a PNG with orca and return it cropped to its content.

    The figure is written to a temporary file (``./brooklyn_plot.png``) at
    800x800 with ``scale=4``, loaded back with PIL, and cropped to the bounding
    box of the pixels that differ noticeably from the top-left (background)
    pixel. The temporary file is always removed, even if rendering or
    cropping fails.

    Args:
        fig: Object exposing ``write_image`` (e.g. a plotly figure).

    Returns:
        A PIL image detached from the temporary file: the cropped image when a
        content bounding box is found, otherwise an uncropped copy.
    """
    img_path = "./brooklyn_plot.png"
    pio.orca.config.use_xvfb = True
    try:
        fig.write_image(img_path, width=800, height=800, scale=4, format='png', engine="orca")
        # The context manager closes the file handle before the file is deleted.
        with Image.open(img_path) as im:
            im_rgb = im.convert('RGB')
            # Sample the background colour from the RGB copy so that the pixel
            # value always matches the mode of the canvas being created.
            bg = Image.new(im_rgb.mode, im_rgb.size, im_rgb.getpixel((0, 0)))
            diff = ImageChops.difference(im_rgb, bg)
            # (diff + diff) / 2.0 - 100 == diff - 100 (clipped at 0): channel
            # differences of 100 or less are treated as background.
            diff = ImageChops.add(diff, diff, 2.0, -100)
            bbox = diff.getbbox()
            if bbox:
                print(f"Cropping image to bbox: {bbox}")
                img = im.crop(bbox)
            else:
                print("No bbox found")
                # Nothing to crop: hand back the full image instead of
                # failing on an unassigned variable.
                img = im.copy()
    finally:
        # remove the temporary image
        if os.path.exists(img_path):
            os.remove(img_path)
    return img
    
# Add custom colorbar legend with parametrized dimensions
def add_colorbar_legend(fig, 
                       # Position parameters
                       x_pos=0,           # x position of the colorbar
                       y_min=0.05,        # bottom of colorbar
                       y_max=0.25,        # top of colorbar
                       width=0.015,       # width of the colorbar
                       label_offset=0.02,  # distance of labels from colorbar
                       # Style parameters
                       num_steps=1000,    # number of rectangles for gradient
                       border_width=1,    # width of the border
                       font_size=12):     # size of labels
    """Draw a vertical black-to-yellow gradient bar with '0.0'/'1.0' labels.

    The bar is built from ``num_steps`` borderless rectangles stacked between
    ``y_min`` and ``y_max`` (paper coordinates), followed by one transparent
    rectangle that draws the border. The shapes and the two label annotations
    are passed to ``fig.update_layout``.

    Args:
        fig: Object exposing ``update_layout`` (e.g. a plotly figure).
        x_pos: Left edge of the bar.
        y_min: Bottom edge of the bar.
        y_max: Top edge of the bar.
        width: Width of the bar.
        label_offset: Horizontal gap between the bar and its labels.
        num_steps: Number of gradient rectangles; must be at least 1.
        border_width: Line width of the border rectangle.
        font_size: Font size of the labels.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")

    # num_steps + 1 edges give num_steps contiguous slices that tile
    # [y_min, y_max] exactly, so the gradient never pokes out of the border.
    y_edges = np.linspace(y_min, y_max, num_steps + 1)

    # Helper function to get color for each position in the gradient
    def get_color(pos):
        # Interpolate from black (pos=0) to yellow (pos=1)
        level = int(255 * pos)
        return f'rgb({level},{level},0)'

    # Create gradient rectangles
    shapes = []
    for i in range(num_steps):
        # A single-step bar has no gradient; avoid dividing by zero.
        pos = i / (num_steps - 1) if num_steps > 1 else 0.0
        shapes.append(dict(
            type='rect',
            x0=x_pos,
            x1=x_pos + width,
            y0=float(y_edges[i]),
            y1=float(y_edges[i + 1]),
            xref='paper',
            yref='paper',
            fillcolor=get_color(pos),
            line=dict(width=0),
        ))

    # Add border rectangle
    shapes.append(dict(
        type='rect',
        x0=x_pos,
        x1=x_pos + width,
        y0=y_min,
        y1=y_max,
        xref='paper',
        yref='paper',
        fillcolor='rgba(0,0,0,0)',
        line=dict(color='black', width=border_width),
    ))

    def make_label(y, text):
        # Labels use the same paper coordinate system as the shapes they mark.
        return dict(
            x=x_pos + width + label_offset,
            y=y,
            xref='paper',
            yref='paper',
            xanchor='left',
            yanchor='bottom',
            text=text,
            showarrow=False,
            font=dict(size=font_size, color='black')
        )

    # Update the layout with the new shapes and annotations
    fig.update_layout(
        annotations=[make_label(y_max, '1.0'), make_label(y_min, '0.0')],
        shapes=shapes
    )


# Build the three G54 views: V, X, and X focused on a single token position.
fig_g_54_v = get_brooklyn_3d_plot(g_54_v.array, show_text_plane=False)
fig_g_54_x = get_brooklyn_3d_plot(g_54_x.array, show_text_plane=False)
fig_g_54_x_focused = get_brooklyn_3d_plot(g_54_x.array, misrepresented_token_position=9742980, show_bulbs=False, show_text_plane=True)

# Only the plain X view carries the colorbar legend.
print("Adding colorbar legend to G54 X")
add_colorbar_legend(
        fig=fig_g_54_x,
        x_pos=0,
        y_min=0.05,
        y_max=0.45,
        width=0.025,
        label_offset=0.02,
        num_steps=10000,
        border_width=1,
        font_size=12
)
print("Added colorbar legend to G54 X")

# Render each figure to a cropped PIL image (left-to-right order of the grid).
plots = [
    save_brooklyn_plot(fig_g_54_x),
    save_brooklyn_plot(fig_g_54_x_focused),
    save_brooklyn_plot(fig_g_54_v)
]

# Shrink every image in place so it fits a square of the smallest dimension
# found across all crops (thumbnail keeps the aspect ratio).
min_height = min(img.height for img in plots)
min_width = min(img.width for img in plots)
min_dim = min(min_height, min_width)
for img in plots:
    img.thumbnail((min_dim, min_dim))

# Concatenate images side by side. Each image gets its own min_dim-wide cell,
# matching the canvas width, so crops with different aspect ratios can neither
# overlap nor run past the right edge.
cell_width = min_dim
image_grid = Image.new('RGB', (cell_width * len(plots), min_dim), color='white')
for i, img in enumerate(plots):
    image_grid.paste(img, (i * cell_width, 0))
# add 50 px margin
image_grid = ImageOps.expand(image_grid, border=50, fill='white')
image_grid