In [1]:
import os
import pandas as pd
import requests
from io import BytesIO
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import boto3
from botocore import UNSIGNED
from botocore.config import Config
import numpy as np
from tqdm import tqdm

from PIL import Image

from rdkit import Chem
from rdkit.Chem import AllChem
import pubchempy as pcp

from torchvision.transforms import Normalize


In [2]:
s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))

def download_image_from_s3(image_url, download_path):
    # Extract the bucket name and key from the image_url
    bucket_name = image_url.split('/')[2]
    key = '/'.join(image_url.split('/')[3:])
    
    # Download the image
    s3_client.download_file(bucket_name, key, download_path)
    # print(f"Image downloaded to {download_path}")


In [3]:
def download_image(s3_client, path, filename):
    bucket_name = path.split('/')[2]
    key = '/'.join(path.split('/')[3:]) + filename
    response = s3_client.get_object(Bucket=bucket_name, Key=key)
    return mpimg.imread(BytesIO(response["Body"].read()), format="tiff")

def normalize(channel):
    return (channel - np.min(channel)) / (np.max(channel) - np.min(channel))

def get_image_from_well(s3_client, load_data_well):
    channel_paths = {
        "AGP": load_data_well['PathName_OrigAGP'].values[0],
        "DNA": load_data_well['PathName_OrigDNA'].values[0],
        "ER": load_data_well['PathName_OrigER'].values[0],
        "Mito": load_data_well['PathName_OrigMito'].values[0],
        "RNA": load_data_well['PathName_OrigRNA'].values[0]
    }

    channel_files = {
        "AGP": load_data_well['FileName_OrigAGP'].values[0],
        "DNA": load_data_well['FileName_OrigDNA'].values[0],
        "ER": load_data_well['FileName_OrigER'].values[0],
        "Mito": load_data_well['FileName_OrigMito'].values[0],
        "RNA": load_data_well['FileName_OrigRNA'].values[0]
    }

    # Download and normalize each channel
    images = {}
    for channel, path in channel_paths.items():
        filename = channel_files[channel]
        images[channel] = download_image(s3_client, path, filename)
        # Normalize the image to [0, 1]
        images[channel] = normalize(images[channel])

    # Combine the channels into an RGB image
    combined_image = np.zeros((images["DNA"].shape[0], images["DNA"].shape[1], 3))

    # Map each channel to RGB (example mapping)
    combined_image[..., 0] = images["DNA"]  # Red channel
    combined_image[..., 1] = images["AGP"]  # Green channel
    combined_image[..., 2] = images["ER"]  # Blue channel

    return combined_image


In [4]:
profile_formatter = (
    "s3://cellpainting-gallery/cpg0016-jump/"
    "{Metadata_Source}/workspace/profiles/"
    "{Metadata_Batch}/{Metadata_Plate}/{Metadata_Plate}.parquet"
)

loaddata_formatter = (
    "s3://cellpainting-gallery/cpg0016-jump/"
    "{Metadata_Source}/workspace/load_data_csv/"
    "{Metadata_Batch}/{Metadata_Plate}/load_data_with_illum.parquet"
)

GIT_CLONE_DIR = "./"

plates = pd.read_csv(os.path.join(GIT_CLONE_DIR, "metadata/plate.csv.gz"))
wells = pd.read_csv(os.path.join(GIT_CLONE_DIR, "metadata/well.csv.gz"))
compounds = pd.read_csv(os.path.join(GIT_CLONE_DIR, "metadata/compound.csv.gz"))

In [5]:
# Whether to sample a certain number of wells per chemical
sample_from_group = True
sample_number = 10

# Whether to use all chemicals or specific ones 
all_chemicals = False

# Total number of sites for each chemical
site_number = 10

In [6]:
# Get plates with compounds
compound_plates = plates[plates['Metadata_PlateType'] == 'COMPOUND']
compound_plates_with_wells = compound_plates.merge(wells, on=['Metadata_Source', 'Metadata_Plate'])
wells_with_chems = compound_plates_with_wells.merge(compounds, on="Metadata_JCP2022")

# Get rid of NA 
selected_chemicals = wells_with_chems['Metadata_InChIKey'].dropna().unique()
wells_with_chems = wells_with_chems[wells_with_chems['Metadata_InChIKey'].isin(selected_chemicals)]

# Get chemical counts and dataframe with only > counts
wells_per_chemical = wells_with_chems['Metadata_InChIKey'].value_counts().reset_index()
    # CRITERIA FOR WHICH CHEMICALS TO USE
if all_chemicals:
    valid_chemicals = selected_chemicals
else:
    valid_chemicals = wells_per_chemical[(wells_per_chemical['count'] > 17) & (wells_per_chemical['count'] < 19)]['Metadata_InChIKey']

print("Number of valid chemicals : ", len(valid_chemicals))
valid_wells = wells_with_chems[wells_with_chems['Metadata_InChIKey'].isin(valid_chemicals)]

if sample_from_group:
    valid_wells = valid_wells.groupby(['Metadata_InChIKey'], group_keys=False).apply(lambda x: x.sample(sample_number))

Number of valid chemicals :  20


  valid_wells = valid_wells.groupby(['Metadata_InChIKey'], group_keys=False).apply(lambda x: x.sample(sample_number))


In [7]:
# Get plates we need to load : 
relevant_plates = valid_wells['Metadata_Plate'].unique()
print("Number of plates : ", len(relevant_plates))

# Load data for plates
load_data = []
for plate in tqdm(relevant_plates):
    # print(plate)
    rows = compound_plates[compound_plates['Metadata_Plate'] == plate]
    row = compound_plates[compound_plates['Metadata_Plate'] == plate].iloc[0]
    s3_path = loaddata_formatter.format(**row.to_dict())
    load_data.append(pd.read_parquet(s3_path, storage_options={"anon": True}))
load_data = pd.concat(load_data)

Number of plates :  152


  0%|          | 0/152 [00:00<?, ?it/s]

100%|██████████| 152/152 [06:15<00:00,  2.47s/it]


In [8]:
list_chemicals = valid_chemicals.to_list()

In [31]:
dataset_mean = np.array([1773.33, 1943.87, 2308.27, 1423.18, 565.90])
# [1773.33367106, 1942.86980054, 2308.26890061 1423.18477434  565.90003101]
# dataset_std = np.array([47.1314, 40.8138, 53.7692, 46.2656, 28.7243])
dataset_std = np.array([1773.33, 1943.87, 2308.27, 1423.18, 565.90])/5


def crop_stack(combined_image):
    height, width, channels = combined_image.shape

    start_x = (width - 696) // 2
    end_x = start_x + 696
    start_y = (height - 520) // 2
    end_y = start_y + 520

    cropped_image = combined_image[start_y:end_y, start_x:end_x, :]

    return cropped_image


def get_chemical_info(chemid):
    try: 
        compound = pcp.get_compounds(chemid, 'inchikey')[0]
        smiles = compound.canonical_smiles
        iupac_name = compound.iupac_name if 'iupac_name' in compound.to_dict() else 'No description available'

        return {"INCHIKEY": chemid, 
                "SMILES": smiles, 
                "CPD_NAME": iupac_name}
    except Exception as e:
        return None

def illumination_threshold(arr, perc=0.0028):
    """ Return threshold value to not display a percentage of highest pixels"""
    perc/=100
    total_pixels = arr.size
    n_pixels = int(np.around(total_pixels * perc))
    flat_inds = np.argpartition(arr, -n_pixels, axis=None)[-n_pixels:]
    threshold = arr.flat[flat_inds].min()
    return threshold

def sixteen_to_eight_bit(arr, display_max, display_min=0):
    threshold_image = ((arr.astype(float) - display_min) * (arr > display_min))
    scaled_image = np.clip(threshold_image * (255. / (display_max - display_min)), 0, 255)
    return scaled_image.astype(np.uint8)


def process_image(arr):
    arr=arr.astype(np.float32)
    threshold = illumination_threshold(arr)

    processed_image = sixteen_to_eight_bit(arr, threshold)

    return processed_image

def normalize_image(image, mean, std):
    normalized_image = (image - mean) / std
    return (normalized_image * 255).astype(np.uint8)

def normalize(channel):
    return (channel - np.min(channel)) / (np.max(channel) - np.min(channel))

def process_site(df, chemical_info, num_channels=5):
    channel_paths = {
        "Ph_golgi": df['PathName_OrigAGP'],
        "Hoecsht": df['PathName_OrigDNA'],
        "ERSyto": df['PathName_OrigER'],
        "Mito": df['PathName_OrigMito'],
        "ERSystoBleed": df['PathName_OrigRNA']
    }

    channel_files = {
        "Ph_golgi": df['FileName_OrigAGP'],
        "Hoecsht": df['FileName_OrigDNA'],
        "ERSyto": df['FileName_OrigER'],
        "Mito": df['FileName_OrigMito'],
        "ERSystoBleed": df['FileName_OrigRNA']
    }

    # Download and normalize each channel
    images = {}
    for channel, path in channel_paths.items():
        filename = channel_files[channel]
        images[channel] = download_image(s3_client, path, filename)
        # images[channel] = normalize(images[channel])

    # Mito – Mito
    # AGP – Ph_golgi
    # ER – ERSyto
    # DNA – Hoecsht
    # RNA – ERSystoBleed

    channels = {'0': 'Mito', '1': 'ERSyto', '2': 'ERSytoBleed', '3': 'Ph_golgi', '4': 'Hoechst'}
    combined_image = np.stack([images["Mito"], images["ERSyto"], images["ERSystoBleed"], images["Ph_golgi"], images["Hoecsht"]], axis=-1)
    combined_image = crop_stack(combined_image)

    processed_channels = [process_image(combined_image[:, :, i]) for i in range(combined_image.shape[-1])]

    processed_image = np.stack(processed_channels, axis=-1)

    normalized_channels = [normalize_image(processed_image[:, :, i], dataset_mean[i], dataset_std[i]) for i in range(processed_image.shape[-1])]

    combined_image = np.stack(normalized_channels, axis=-1)

    id =  f"{df['Metadata_Plate']}-{df['Metadata_Well']}-{df['Metadata_Site']}"

    if num_channels == 5:
        
        save_dir = "preprocessed_data/5_channels"
        np_id = f"{id}.npz"
        np_path = os.path.join(save_dir, np_id)
        np.savez(np_path, sample=combined_image, channels=channels)


    if num_channels == 3:
        R = combined_image[:, :, 4] 
        G = combined_image[:, :, 3]  
        B = combined_image[:, :, 1] 

        rgb_image = np.stack([R, G, B], axis=-1)

        save_dir = "preprocessed_data/3_channels"
        
        png_id = f"{id}.png"
        save_path = os.path.join(save_dir, png_id)

        image = Image.fromarray(rgb_image)
        image.save(save_path)


    metadata = {
        "SAMPLE_KEY": id,
        "PLATE_ID": df["Metadata_Plate"],
        "WELL_POSITION": df["Metadata_Well"],
        "SITE": row["Metadata_Site"],
        "SMILES": chemical_info["SMILES"],
        "INCHIKEY": chemical_info["INCHIKEY"],
        "CPD_NAME": chemical_info["CPD_NAME"]
    }

    return metadata


In [32]:
metadata = []
no_data = 0
for chemical in tqdm(valid_chemicals[0:10]):
    chemical_info = get_chemical_info(chemical)
    if not chemical_info:
        no_data += 1
        continue

    print("Chemical : ", chemical_info["CPD_NAME"])

    valid_well = valid_wells[valid_wells['Metadata_InChIKey'] == chemical]

    merged_df = pd.merge(
    valid_well,
    load_data,
    on=['Metadata_Source', 'Metadata_Batch', 'Metadata_Plate', 'Metadata_Well'],
    how='inner'
    ).sample(site_number)

    i = 0

    print(merged_df)

    for iter, row in merged_df.iterrows():
        print(f"Processing row {i}")
        i += 1
        row_metadata = process_site(row, chemical_info, num_channels=3)
        row_metadata = process_site(row, chemical_info, num_channels=5)
        metadata.append(row_metadata)
        break

metadata_df = pd.DataFrame(metadata)
metadata_df.to_csv("preprocessed_data/metadata.csv")

print("Number of compounds that weren't found : ", no_data)

  0%|          | 0/10 [00:00<?, ?it/s]

Chemical :  6-fluoro-1-(4-fluorophenyl)-4-oxo-7-piperazin-1-ylquinoline-3-carboxylic acid
   Metadata_Source                  Metadata_Batch Metadata_Plate  \
62        source_6  p211012CPU2OS48hw384exp033JUMP   110000295611   
77        source_3               CP_36_all_Phenix1       BAY5872b   
42        source_6  p211012CPU2OS48hw384exp033JUMP   110000295616   
60        source_6  p211012CPU2OS48hw384exp033JUMP   110000295611   
33        source_7                   20210803_Run5     CP5-SC1-19   
18        source_8                              J2       A1170461   
40        source_6  p211012CPU2OS48hw384exp033JUMP   110000295616   
68       source_11                          Batch3       EC000166   
28        source_2                20210816_Batch_9     1086292433   
65        source_6  p211012CPU2OS48hw384exp033JUMP   110000295611   

   Metadata_PlateType Metadata_Well Metadata_JCP2022  \
62           COMPOUND           L19   JCP2022_102548   
77           COMPOUND           L19   

 20%|██        | 2/10 [00:27<01:31, 11.39s/it]

Chemical :  2-[4-[2-[(4-chlorobenzoyl)amino]ethyl]phenoxy]-2-methylpropanoic acid
   Metadata_Source                  Metadata_Batch Metadata_Plate  \
9         source_6  p211012CPU2OS48hw384exp033JUMP   110000295605   
74        source_1                 Batch6_20221102       UL000579   
75        source_1                 Batch6_20221102       UL000579   
38        source_7                   20210719_Run1     CP1-SC1-06   
15        source_6  p211012CPU2OS48hw384exp033JUMP   110000295605   
17        source_6  p211012CPU2OS48hw384exp033JUMP   110000295605   
42        source_7                   20210719_Run1     CP1-SC1-06   
37        source_7                   20210719_Run1     CP1-SC1-06   
54        source_8                              J2       A1170455   
7         source_3                            CP60      BR5873d3W   

   Metadata_PlateType Metadata_Well Metadata_JCP2022  \
9            COMPOUND           F07   JCP2022_035200   
74           COMPOUND           R09   JCP2022_

 20%|██        | 2/10 [00:28<01:52, 14.07s/it]


KeyboardInterrupt: 

In [26]:
# Get mean/standard deviation
def means_stds(df):
    channel_sums = {}
    channel_squared_sums = {}
    total_pixels = 0

    channels = ['Mito', 'ERSyto', 'ERSytoBleed', 'Ph_golgi', 'Hoechst']
    for channel in channels:
        channel_sums[channel] = 0
        channel_squared_sums[channel] = 0

    for index, row in df.iterrows():
        # print(f"Row {index} of {len(df)}")
        channel_paths = {
            "Ph_golgi": row['PathName_OrigAGP'],
            "Hoecsht": row['PathName_OrigDNA'],
            "ERSyto": row['PathName_OrigER'],
            "Mito": row['PathName_OrigMito'],
            "ERSystoBleed": row['PathName_OrigRNA']
        }

        channel_files = {
            "Ph_golgi": row['FileName_OrigAGP'],
            "Hoecsht": row['FileName_OrigDNA'],
            "ERSyto": row['FileName_OrigER'],
            "Mito": row['FileName_OrigMito'],
            "ERSystoBleed": row['FileName_OrigRNA']
        }

        images = {}
        for channel, path in channel_paths.items():
            filename = channel_files[channel]
            images[channel] = download_image(s3_client, path, filename)

        combined_image = np.stack(
            [images["Mito"], images["ERSyto"], images["ERSystoBleed"], images["Ph_golgi"], images["Hoecsht"]], axis=-1)
        combined_image = crop_stack(combined_image)

        total_pixels += combined_image.shape[0] * combined_image.shape[1]

        for i, channel in enumerate(channels):
            channel_data = combined_image[:, :, i]
            channel_sums[channel] += np.sum(channel_data)
            channel_squared_sums[channel] += np.sum(channel_data ** 2)

    means = {channel: channel_sums[channel] / total_pixels for channel in channels}
    stds = {
        channel: np.sqrt(channel_squared_sums[channel] / total_pixels - (means[channel] ** 2))
        for channel in channels
    }

    return means, stds

selected_valid_rows = valid_wells.sample(n=15, random_state=42)

# Now get everythign ! 
print("Number of wells : ", len(valid_wells))
merged_df = pd.merge(
    selected_valid_rows,
    load_data,
    on=['Metadata_Source', 'Metadata_Batch', 'Metadata_Plate', 'Metadata_Well'],
    how='inner'
    )

print("Getting means")

means, stds = means_stds(merged_df)

average_means = [means[channel] for channel in means]
average_stds = [stds[channel] for channel in stds]

print("Average means: ", average_means)
print("Average stds: ", average_stds)

# Output dataset mean and std arrays
dataset_mean = np.array(average_means)
dataset_std = np.array(average_stds)

print("Dataset mean: ", dataset_mean)
print("Dataset std: ", dataset_std)

Number of wells :  200
Getting means
Average means:  [1773.3336710639962, 1942.8698005442168, 2308.2689006113487, 1423.184774336256, 565.9000310074664]
Average stds:  [nan, nan, nan, nan, nan]
Dataset mean:  [1773.33367106 1942.86980054 2308.26890061 1423.18477434  565.90003101]
Dataset std:  [nan nan nan nan nan]


  channel: np.sqrt(channel_squared_sums[channel] / total_pixels - (means[channel] ** 2))
