# Simple CNN

In this notebook, we will train a simple CNN (LeNet) end-to-end to predict one assay of compound activity.

## 1. Sample Images

For this model, we want to directly use 5-channel images. The images corresponding to one assay come from different plates (different files), so we want to have a nice function to extract those images.

In [78]:
import numpy as np
import pandas as pd
import re
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from torch.utils import data
from collections import OrderedDict
from glob import glob
from os.path import join, exists, basename
from json import load, dump
from shutil import copyfile, rmtree
from sklearn.utils import shuffle
from sklearn import metrics

### 1.1. Output Matrix

In the output_matrix, we have `compound_broad_id`. We can use it to map to individual `pid` and `wid`.

The structure is an array of `(pid, wid)` tuples corresponding to `compound_broad_id` in the output_matrix.

In [43]:
output_data = np.load('./resource/output_matrix_convert_collision.npz')
output_matrix = output_data['output_matrix']
compound_inchi = output_data['compound_inchi']
compound_broad_id = output_data['compound_broad_id']
assay = output_data['assay']
cleaned_output_bids = [i[:13] for i in compound_broad_id]

In [44]:
output_matrix.shape

(27241, 212)

In [34]:
df = pd.read_csv('./resource/merged_meta_table_406.csv')
df.head()

Unnamed: 0,pid,wid,bid,cleaned_bid
0,25855,a01,BRD-K14087339-001-01-6,BRD-K14087339
1,25855,a02,BRD-K53903148-001-01-7,BRD-K53903148
2,25855,a03,BRD-K37357048-001-01-8,BRD-K37357048
3,25855,a04,BRD-K25385069-001-01-7,BRD-K25385069
4,25855,a05,BRD-K63140065-001-01-3,BRD-K63140065


In [39]:
print(len(set(cleaned_table_bids)))
print(len(set(cleaned_output_bids).intersection(set(df['cleaned_bid']))))

30413
26939


Among 30413 imaged compounds, there are 26939 overlapping compounds in our 212 assays.

In [48]:
# Build a dictionary cleaned_bid => [(pid, wid)]

meta_bid_maps = {}
for i, r in df.iterrows():
    cur_bid = r['cleaned_bid']
    cur_pid = r['pid']
    cur_wid = r['wid']
    
    if cur_bid in meta_bid_maps:
        meta_bid_maps[cur_bid].append((cur_pid, cur_wid))
    else:
        meta_bid_maps[cur_bid] = [(cur_pid, cur_wid)]

In [52]:
pid_wids = [[] for i in range(output_matrix.shape[0])]

# Iterate through cmpounds in the output matrix
for i in range(output_matrix.shape[0]):
    cur_bid = cleaned_output_bids[i]
    pid_wids[i] = meta_bid_maps[cur_bid]

In [57]:
# Overwrite the old output matrix, so we don't need to extract pid_wids everytime
np.savez('./resource/output_matrix_convert_collision_.npz',
         output_matrix=output_matrix, compound_inchi=compound_inchi,
         compound_broad_id=compound_broad_id, assay=assay,
         cleaned_output_bids=cleaned_output_bids, pid_wids=pid_wids)

## 1.2. Extract Images and Labels

After getting the map from output compound to `(pid, wid)`, we can write a function to extract images and labels for one given assay.

In [19]:
assay = 192

# Load the output matrix and each row's corresponding pid, wid
output_data = np.load('./resource/output_matrix_convert_collision.npz')
output_matrix = output_data['output_matrix']
pid_wids = output_data['pid_wids']

In [10]:
# Find selected compounds in this assay
selected_index = output_matrix[:, assay] != -1
selected_labels = output_matrix[:, assay][selected_index]
selected_pid_wids = np.array(pid_wids)[selected_index]

In [17]:
# Flatten the selected pid_wids and group them by pid
# selected_wells has structure [(wid, pid, label)]
selected_wells = []

for i in range(len(selected_pid_wids)):
    cur_pid_wids = selected_pid_wids[i]
    cur_label = selected_labels[i]
    
    for pid_wid in cur_pid_wids:
        selected_wells.append((pid_wid[0], pid_wid[1], int(cur_label)))

# Group these wells by their pids
selected_well_dict = {}
for well in selected_wells:
    cur_pid, cur_wid, cur_label = well[0], well[1], well[2]
    
    if cur_pid in selected_well_dict:
        selected_well_dict[cur_pid].append((cur_wid, cur_label))
    else:
        selected_well_dict[cur_pid] = [(cur_wid, cur_label)]

In [19]:
raw_channels = ['ERSyto', 'ERSytoBleed', 'Hoechst', 'Mito', 'Ph_golgi']
raw_paths = ['./{}-{}/*.tif'.format('{}', c) for c in raw_channels]

In [39]:
def extract_instance(pid, wid, label, output_dir='./output'):
    
    paths = [p.format(pid) for p in raw_paths]

    # Dynamically count number of sids for this pid-wid
    sid_files = [f for f in glob(paths[0]) if
                        re.search(r'^.*_{}_s\d_.*\.tif$'.format(wid),
                                  basename(f))]
    sid_num = len(sid_files)

    for sid in range(1, sid_num + 1):
        # Each sid generates one instance
        image_names, images = [], []

        for p in paths:
            # Search current pid-wid-sid
            cur_file = [f for f in glob(p) if
                        re.search(r'^.*_{}_s{}_.*\.tif$'.format(wid, sid),
                                  basename(f))]

            # We should only see one result returned from the filter
            if len(cur_file) != 1:
                error = "Found more than one file for {}-{}-{}.".format(
                    pid, wid, sid
                )
                raise ValueError(error)

            image_names.append(cur_file[0])

        # Read 5 channels
        for n in image_names:
            images.append(cv2.imread(n, -1) * 16)

        # Store each image as a 5 channel 3d matrix
        image_instance = np.array(images)

        # Save the instance with its label
        np.savez(join(output_dir, 'img_{}_{}_{}_{}.npz'.format(
            pid, wid, sid, label
        )), img=image_instance)

In [40]:
pid = 24277

output_dir = './temp_1'

for wid_tuple in selected_well_dict[pid]:
    extract_instance(pid, wid_tuple[0], wid_tuple[1], output_dir)

In [38]:
def extract_plate(pid, selected_well_dict, output_dir='./output'):
    
    # Copy 5 zip files from gluster to the current directory
    for c in raw_channels:
        copyfile("/mnt/gluster/zwang688/{}-{}.zip".format(pid, c),
                 "./{}-{}.zip".format(pid, c))

        # Extract the zip file and remove it
        with zipfile.ZipFile("./{}-{}.zip".format(pid, c), 'r') as fp:
            fp.extractall('./')

        os.remove("./{}-{}.zip".format(pid, c))
        
    # Extract all instances from all selected wells in this plate
    for wid_tuple in selected_well_dict[pid]:
        extract_instance(pid, wid_tuple[0], wid_tuple[1], output_dir)
        
    # Clean up directories
    for c in raw_channels:
        rmtree("./{}-{}".format(pid, c))

(5, 520, 696)

## 2. LeNet

After having a nice function to extract 5-channel images from one given assay, we can start to implement LeNet using PyTorch.

### 2.1. DataLoader

Since we have a lot images, we don't want to load all of them into memory. Similarly to the `DataGenerator` in Keras, torch supports a runtime data loading mechanism. 

In [5]:
class Dataset(data.Dataset):
    """
    Define a dataset class so we can load data in the runtime.
    Trainning dataset, vali dataset and test dataset can all use this class.
    """
    
    def __init__(self, img_names):
        """
        Args:
            img_names([string]): a list of image names in this dataset. The
                name should be a relative path to a single image.
        """
        
        self.img_names = img_names
        
    def __len__(self):
        """
        Tell pytorch how many instances are in this dataset.
        """
        
        return len(self.img_names)
    
    def __getitem__(self, index):
        """
        Generate one image instance based on the given index.
        
        Args:
            index(int): the index of the current item
        
        Return:
            x(tensor): 5-channel 3d tensor encoding one cell image
            y(int): 0 - negative assay, 1 - positive assay
        """
        
        # Read the image matrix and convert to torch tensor
        cur_img_name = self.img_names[index]
        mat = np.load(cur_img_name)['img'].astype(dtype=np.float32)
        x = torch.from_numpy(mat)
        
        # Get the image label from its filename
        y = int(re.sub(r'img_\d+_.+_\d_(\d)\.npz', r'\1',
                       basename(cur_img_name)))
        
        return x, y
        

In [3]:
params = {
    'batch_size': 32,
    'shuffle': True,
    'num_workers': os.cpu_count()
}

training_dataset = Dataset(glob('./temp_2/*.npz'))
training_generator = data.DataLoader(training_dataset, **params)

In [43]:
for local_batch, local_labels in training_generator:
    print(local_batch.shape, local_labels.shape)

torch.Size([32, 5, 696, 696]) torch.Size([32])
torch.Size([32, 5, 696, 696]) torch.Size([32])
torch.Size([8, 5, 696, 696]) torch.Size([8])


### 2.2. LeNet

In this section, we use torch to implement a modified LeNet architecture which supports 5-channel inputs.

![](https://i.imgur.com/OZKLCxm.png)

In [120]:
class LeNet(nn.Module):
    """
    Modified LeNet architecture.
    """

    def __init__(self):
        """
        Create layers for the LeNet network.
        """
        
        super(LeNet, self).__init__()
        
        # C1: 5 channel -> 6 filters (5x5)
        self.conv1 = nn.Conv2d(5, 6, 5)
        # C2: 6 filters -> 16 filters (5x5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # FC1: CP2 -> 120
        self.fc1 = nn.Linear(171*171*16, 120)
        # FC2: FC1 -> 84
        self.fc2 = nn.Linear(120, 84)
        # Output: FC2 -> 2 (activated or not)
        self.output = nn.Linear(84, 2)

    def forward(self, x):
        """
        Pytorch forward() method for autogradient.
        """
        
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        
        # Flatten this layer to connect to FC lyaers
        # size(0) is the batch size
        out = out.view(out.size(0), -1)
        
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        # The original lenet5 doesnt use softmax.
        out = F.softmax(self.output(out), dim=1)
        
        return out

In [122]:
def train_one_epoch(model, device, training_generator, vali_generator,
                    optimizer, epoch, early_stopping=None):
    
    # Set lenet to training mode
    model.train()
    
    train_losses, y_predict_prob, y_true = [], [], []
    for i, (cur_batch, cur_labels) in enumerate(training_generator):
        
        # Transfer tensor to GPU if available
        cur_batch, cur_labels = cur_batch.to(device), cur_labels.to(device)

        # Clean the gradient
        optimizer.zero_grad()

        # Run the network forward
        output = model(cur_batch)

        # Compute the loss
        loss = criterion(output, cur_labels)
        train_losses.append(loss.detach().item())
        y_predict_prob.extend(output.detach().numpy())
        y_true.extend(cur_labels.numpy())

        if epoch % 5 == 0:
            print("Epoch {} - batch {}: loss = {}".format(epoch, i, loss))

        # Backpropogation and update weights
        loss.backward()
        optimizer.step()
    
    # Convert tensor to numpy array so we can use sklearn's metrics
    y_predict_prob = np.stack(y_predict_prob)
    y_predict = np.argmax(y_predict_prob, axis=1)
    y_true = np.array(y_true)
    
    # Average losses over different batches. Each loss corresponds to the mean
    # loss within that batch (reduction="mean").
    train_loss = np.mean(train_losses)
    train_acc = metrics.accuracy_score(y_true, y_predict)
        
    # After training for this epoch, we evaluate this current model on the
    # validation set
    model.eval()
    vali_losses = []
    
    with torch.no_grad():
        for cur_batch, cur_labels in vali_generator:
            cur_batch, cur_labels = cur_batch.to(device), cur_labels.to(device)
            output = model(cur_batch)
            
            loss = criterion(output, cur_labels)
            vali_losses.append(loss.detach().item())
    
    # Average losses over different batches. Each loss corresponds to the mean
    # loss within that batch (reduction="mean").
    vali_loss = np.mean(vali_losses)

    # Early stopping (the real stopping is outside of this function)
    if early_stopping:
        if vali_loss < early_stopping['best_loss']:
            early_stopping['best_loss'] = vali_loss
            early_stopping['wait'] = 0
        else:
            early_stopping['wait'] += 1
            
    return train_loss, train_acc, vali_loss

In [116]:
def test(model, device, test_generator):
    
    # Set model to evaluation mode
    model.eval()
    
    test_losses, y_predict_prob, y_true = [], [], []

    with torch.no_grad():
        for cur_batch, cur_labels in test_generator:
            # Even there is only forward() in testing phase, it is still faster
            # to do it on GPU
            cur_batch, cur_labels = cur_batch.to(device), cur_labels.to(device)
            
            output = model(cur_batch)
            loss = criterion(output, cur_labels)
            
            # Track the loss and prediction for each batch
            test_losses.append(loss.detach().item())
            y_predict_prob.extend(output.detach().numpy())
            y_true.extend(cur_labels.numpy())

    # Convert tensor to numpy array so we can use sklearn's metrics
    # sklearn loves 1d proba array of the activated class
    y_predict_prob = np.stack(y_predict_prob)[:, 1]
    y_predict = [1 if i >= 0.5 else 0 for i in y_predict_prob]
    y_true = np.array(y_true)
    
    # Take the average of batch loss means
    test_loss = np.mean(test_losses)
    test_acc = metrics.accuracy_score(y_true, y_predict)
    
    print("Testing on {} instances, the accuracy is {:.2f}.".format(
        len(test_generator), test_acc
    ))
    
    return test_loss, test_acc, y_predict_prob, y_true

Prepare for training a LeNet. We need to create data generators and an early stopping tracking dictionary.

In [121]:
params = {
    'batch_size': 32,
    'shuffle': True,
    'num_workers': os.cpu_count()
}

img_names = glob('./temp_1/*.npz')

# Randomly split img_names into three sets
img_names = shuffle(img_names)
quintile_len = len(img_names) // 5
vali_names = img_names[: quintile_len]
test_names = img_names[quintile_len: quintile_len * 2]
train_names = img_names[quintile_len * 2: ]

print("There are {} training samples, {} validation samples, and {} test samples.".\
      format(len(train_names), len(vali_names), len(test_names)))

# Create data generators
training_dataset = Dataset(train_names)
training_generator = data.DataLoader(training_dataset, **params)

vali_dataset = Dataset(vali_names)
vali_generator = data.DataLoader(vali_dataset, **params)

test_dataset = Dataset(test_names)
test_generator = data.DataLoader(test_dataset, **params)

There are 6 training samples, 2 validation samples, and 2 test samples.


In [123]:
# Run on GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Use cross-entropy as our loss funciton
lenet = LeNet()
lenet.to(device)

criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = optim.Adam(lenet.parameters(), lr=0.001)

# Init early stopping
early_stopping_dict = {
    'best_loss': np.inf,
    'wait': 0,
    'patience': 20
}

for e in range(1):
    train_loss, train_acc, vali_loss = train_one_epoch(
        lenet, device, training_generator, vali_generator,
        optimizer, e, early_stopping=early_stopping_dict
    )
    
    if early_stopping_dict['wait'] > early_stopping_dict['patience']:
        break

Epoch 0 - batch 0: loss = 1.31326162815094


In [117]:
test_loss, test_acc, y_predict_prob, y_true = test(lenet, device, test_generator)

[2, 6, 692, 692]
[2, 6, 346, 346]
[2, 16, 342, 342]
[2, 16, 171, 171]
[2, 467856]
[2, 2]
Testing on 1 instances, the accuracy is 0.50.


In [107]:
y_true

array([0, 0])