# Simple CNN

In this notebook, we will train a simple CNN (LeNet) end-to-end to predict one assay of compound activity.

## 1. Sample Images

For this model, we want to directly use 5-channel images. The images corresponding to one assay come from different plates (different files), so we want to have a nice function to extract those images.

In [41]:
import numpy as np
import pandas as pd
import re
import cv2
from glob import glob
from os.path import join, exists, basename
from json import load, dump
from shutil import copyfile, rmtree

### 1.1. Output Matrix

In the output_matrix, we have `compound_broad_id`. We can use it to map to individual `pid` and `wid`.

The structure is an array of `(pid, wid)` tuples corresponding to `compound_broad_id` in the output_matrix.

In [43]:
output_data = np.load('./resource/output_matrix_convert_collision.npz')
output_matrix = output_data['output_matrix']
compound_inchi = output_data['compound_inchi']
compound_broad_id = output_data['compound_broad_id']
assay = output_data['assay']
cleaned_output_bids = [i[:13] for i in compound_broad_id]

In [44]:
output_matrix.shape

(27241, 212)

In [34]:
df = pd.read_csv('./resource/merged_meta_table_406.csv')
df.head()

Unnamed: 0,pid,wid,bid,cleaned_bid
0,25855,a01,BRD-K14087339-001-01-6,BRD-K14087339
1,25855,a02,BRD-K53903148-001-01-7,BRD-K53903148
2,25855,a03,BRD-K37357048-001-01-8,BRD-K37357048
3,25855,a04,BRD-K25385069-001-01-7,BRD-K25385069
4,25855,a05,BRD-K63140065-001-01-3,BRD-K63140065


In [39]:
print(len(set(cleaned_table_bids)))
print(len(set(cleaned_output_bids).intersection(set(df['cleaned_bid']))))

30413
26939


Among 30413 imaged compounds, there are 26939 overlapping compounds in our 212 assays.

In [48]:
# Build a dictionary cleaned_bid => [(pid, wid)]

meta_bid_maps = {}
for i, r in df.iterrows():
    cur_bid = r['cleaned_bid']
    cur_pid = r['pid']
    cur_wid = r['wid']
    
    if cur_bid in meta_bid_maps:
        meta_bid_maps[cur_bid].append((cur_pid, cur_wid))
    else:
        meta_bid_maps[cur_bid] = [(cur_pid, cur_wid)]

In [52]:
pid_wids = [[] for i in range(output_matrix.shape[0])]

# Iterate through cmpounds in the output matrix
for i in range(output_matrix.shape[0]):
    cur_bid = cleaned_output_bids[i]
    pid_wids[i] = meta_bid_maps[cur_bid]

In [57]:
# Overwrite the old output matrix, so we don't need to extract pid_wids everytime
np.savez('./resource/output_matrix_convert_collision_.npz',
         output_matrix=output_matrix, compound_inchi=compound_inchi,
         compound_broad_id=compound_broad_id, assay=assay,
         cleaned_output_bids=cleaned_output_bids, pid_wids=pid_wids)

## 1.2. Extract Images and Labels

After getting the map from output compound to `(pid, wid)`, we can write a function to extract images and labels for one given assay.

In [2]:
assay = 192

# Load the output matrix and each row's corresponding pid, wid
output_data = np.load('./resource/output_matrix_convert_collision.npz')
output_matrix = output_data['output_matrix']
pid_wids = output_data['pid_wids']

In [6]:
# Find selected compounds in this assay
selected_index = output_matrix[:, assay] != -1
selected_labels = output_matrix[:, assay][selected_index]
selected_pid_wids = np.array(pid_wids)[selected_index]

In [17]:
# Flatten the selected pid_wids and group them by pid
# selected_wells has structure [(wid, pid, label)]
selected_wells = []

for i in range(len(selected_pid_wids)):
    cur_pid_wids = selected_pid_wids[i]
    cur_label = selected_labels[i]
    
    for pid_wid in cur_pid_wids:
        selected_wells.append((pid_wid[0], pid_wid[1], int(cur_label)))

# Group these wells by their pids
selected_well_dict = {}
for well in selected_wells:
    cur_pid, cur_wid, cur_label = well[0], well[1], well[2]
    
    if cur_pid in selected_well_dict:
        selected_well_dict[cur_pid].append((cur_wid, cur_label))
    else:
        selected_well_dict[cur_pid] = [(cur_wid, cur_label)]

In [19]:
raw_channels = ['ERSyto', 'ERSytoBleed', 'Hoechst', 'Mito', 'Ph_golgi']
raw_paths = ['./{}-{}/*.tif'.format('{}', c) for c in raw_channels]

In [39]:
def extract_instance(pid, wid, label, output_dir='./output'):
    
    paths = [p.format(pid) for p in raw_paths]

    # Dynamically count number of sids for this pid-wid
    sid_files = [f for f in glob(paths[0]) if
                        re.search(r'^.*_{}_s\d_.*\.tif$'.format(wid),
                                  basename(f))]
    sid_num = len(sid_files)

    for sid in range(1, sid_num + 1):
        # Each sid generates one instance
        image_names, images = [], []

        for p in paths:
            # Search current pid-wid-sid
            cur_file = [f for f in glob(p) if
                        re.search(r'^.*_{}_s{}_.*\.tif$'.format(wid, sid),
                                  basename(f))]

            # We should only see one result returned from the filter
            if len(cur_file) != 1:
                error = "Found more than one file for {}-{}-{}.".format(
                    pid, wid, sid
                )
                raise ValueError(error)

            image_names.append(cur_file[0])

        # Read 5 channels
        for n in image_names:
            images.append(cv2.imread(n, -1) * 16)

        # Store each image as a 5 channel 3d matrix
        image_instance = np.array(images)

        # Save the instance with its label
        np.savez(join(output_dir, 'img_{}_{}_{}_{}.npz'.format(
            pid, wid, sid, label
        )), img=image_instance)

In [40]:
pid = 24277

output_dir = './temp_1'

for wid_tuple in selected_well_dict[pid]:
    extract_instance(pid, wid_tuple[0], wid_tuple[1], output_dir)

In [38]:
def extract_plate(pid, selected_well_dict, output_dir='./output'):
    
    # Copy 5 zip files from gluster to the current directory
    for c in raw_channels:
        copyfile("/mnt/gluster/zwang688/{}-{}.zip".format(pid, c),
                 "./{}-{}.zip".format(pid, c))

        # Extract the zip file and remove it
        with zipfile.ZipFile("./{}-{}.zip".format(pid, c), 'r') as fp:
            fp.extractall('./')

        os.remove("./{}-{}.zip".format(pid, c))
        
    # Extract all instances from all selected wells in this plate
    for wid_tuple in selected_well_dict[pid]:
        extract_instance(pid, wid_tuple[0], wid_tuple[1], output_dir)
        
    # Clean up directories
    for c in raw_channels:
        rmtree("./{}-{}".format(pid, c))

(5, 520, 696)