# Deep-learning-based identification of protein nanobarcodes

This notebook contains the code necessary for using our trained deep neural network for identification of protein nanobarcodes from confocal images. 

The current setup loads hyperparameters and network weights from the ```../network_params``` folder. The network weights and the training dataset will be downloaded from a remote server the first time the notebook is run.

If you wish, you can use this notebook for training of the network with new data.

This notebook depends on the included Python package ```deep_nanobarcode```. For instructions on how to set up the library, see the ```README.md``` file or visit the GitHub page at https://github.com/noegroup/deep_nanobarcode.

In [None]:
import glob
import json
import os
import pickle
import urllib
import urllib.request  # `import urllib` alone does not guarantee that the `request` submodule is loaded

import matplotlib.pyplot as plt
import numpy as np
import tifffile
import torch
from matplotlib.ticker import PercentFormatter
from tqdm.notebook import tqdm

### Importing the deep_nanobarcode library, which initially does a check on PyTorch installation and CUDA availability:

In [None]:
from deep_nanobarcode import dataset_handler
from deep_nanobarcode import model
from deep_nanobarcode import image_processing
from deep_nanobarcode import network_components

### Setting up IO paths, as well as specifying which protein expression/imaging condition among the available training data is to be used:

In [None]:
# Identifier of the trained network ("series") whose config and weights are used.
series_number = 15

# Which of the available training datasets to use; these three values are
# combined into the dataset/weights file names below.
dataset_expression_time = "24h"
dataset_scaling = "whitened"
dataset_gain = "550"

# Folders, given relative to the notebook's working directory.
network_params_folder = "../network_params"
training_data_folder = "../training_data"
example_folder = "../examples"
output_folder = "../output/series_{}".format(series_number)

# Common suffix shared by the weights file and the dataset file.
ext_str = "{}_{}_gain_{}".format(dataset_expression_time, dataset_scaling, dataset_gain)

dataset_filename = "kpca_input_{}.pkl".format(ext_str)
saved_net_filename = "weights_series_{}_trained_on_{}.pth".format(series_number, ext_str)

# Fetch the trained weights and the training dataset on first use; files that
# already exist locally are left untouched.
for _data_folder, _data_filename in zip([network_params_folder, training_data_folder],
                                        [saved_net_filename, dataset_filename]):

    _local_path = f"{_data_folder}/{_data_filename}"

    if os.path.isfile(_local_path):
        print(f"{_data_filename} is already present!")
        continue

    # urlretrieve does not create missing directories, so make sure the target exists.
    os.makedirs(_data_folder, exist_ok=True)

    print(f"Downloading {_data_filename}...")

    # Download under a temporary name and rename only on success, so that an
    # interrupted transfer is not mistaken for a complete file on the next run.
    _partial_path = f"{_local_path}.part"
    try:
        urllib.request.urlretrieve(f"https://ftp.mi.fu-berlin.de/pub/cmb-data/deep_nanobarcode/{_data_filename}",
                                   _partial_path)
        os.replace(_partial_path, _local_path)
    finally:
        # After a successful rename there is nothing left to clean up; after a
        # failure this removes the truncated download.
        if os.path.exists(_partial_path):
            os.remove(_partial_path)

### The dataset object manages input training data, its shuffling, train/validation/test split, and data augmentation:

In [None]:
# Build the dataset object from the brightness data file in `training_data_folder`.
# NOTE(review): the .pkl file may have been downloaded from a remote server; if
# NanobarcodeDataset unpickles it (presumably -- confirm in dataset_handler), only
# use files obtained from a trusted source.
# The split tuple is presumably (train, validation, test) fractions, judging by
# the parameter name -- confirm in dataset_handler.
dataset = dataset_handler.NanobarcodeDataset(brightness_data_file_name=f"{training_data_folder}/{dataset_filename}",
                                             train_val_test_split_frac=(0.8, 0.1, 0.1),
                                             do_brightness_augmentation=True,
                                             brightness_augmentation_factor=0.9,
                                             verbose=True)

### Getting the names of labeled proteins from the dataset:

In [None]:
# Show the id -> protein name lookup; `protein_ind` in the cells below is used as an index into it.
print(dataset.id_to_protein_name)

### Checking the training data for a specified channel and protein (feel free to change): 

In [None]:
# Channel to inspect: used as the row index into each protein's brightness data
# and into `dataset_handler.channel_wavelength` in the plot below.
channel_ind = 8
# Protein to inspect: used as the index into `dataset.id_to_protein_name`.
protein_ind = 3

In [None]:
# Compare the brightness distributions of all proteins in the selected channel.
plt.figure(figsize=(7, 7))

for _pr in dataset.protein_names:

    data = dataset.brightness_data[_pr]["data"][channel_ind, :]

    # Normalize the bin counts so that each curve sums to one.
    hist, bin_edges = np.histogram(data, bins=100)
    hist = hist / np.sum(hist)

    # Each bin is drawn at its left edge.
    plt.plot(bin_edges[:-1], hist, label=_pr, linewidth=3)

# Raw f-string: "\l" is not a valid Python escape sequence, so the non-raw form
# triggers a DeprecationWarning/SyntaxWarning. The rendered title is unchanged.
plt.title(rf"channel ind = {channel_ind}, $\lambda = $ {dataset_handler.channel_wavelength[channel_ind]} nm")
plt.ylabel('Probability')
plt.xlabel('Brightness')

plt.legend()

In [None]:
# Overlay the brightness distribution of every channel for the selected protein.
protein_name = dataset.id_to_protein_name[protein_ind]

plt.figure(figsize=(10, 7))

for channel in range(dataset_handler.n_channels):

    channel_brightness = dataset.brightness_data[protein_name]["data"][channel]

    # Normalized histogram: bin counts divided by the total count.
    counts, edges = np.histogram(channel_brightness, bins=50)
    probabilities = counts / np.sum(counts)

    # Channels are labeled starting from 1; each bin is drawn at its left edge.
    plt.plot(edges[:-1], probabilities, label=f"channel - {channel + 1}", linewidth=3)

plt.title(protein_name)
plt.xlabel('Brightness')
plt.ylabel('Probability')

plt.legend()

### Boolean flags controlling the behavior of the rest of the notebook:

In [None]:
# Passed to the training loop as `save_checkpoint`; only has an effect when `train_net` is True.
save_net = False
# When True, previously saved weights are loaded into the network from `network_params_folder`.
load_net = True
# When True, the training loop is run; otherwise the training cell is skipped.
train_net = False

### Loading network hyperparameters from file:

In [None]:
# Read the hyperparameters of the selected series from its JSON config file.
config_path = f"{network_params_folder}/model_config_series_{series_number}.json"

with open(config_path, 'r') as config_file:
    network_parameters = json.load(config_file)

print(network_parameters)

### Setting up the network using the ```model_factory```:

In [None]:
# Split the loaded hyperparameters into the positional arguments expected by the
# factory, then build the network, the three data loaders and the optimizer.
factory_args = model.separate_args(network_parameters)

(net,
 train_data_loader,
 val_data_loader,
 test_data_loader,
 optimizer) = model.model_factory(dataset, *factory_args)

print(net)

In [None]:
# Restore previously trained weights, mapped onto the device selected by
# `network_components.nn_device`.
# NOTE(review): torch.load unpickles the file, and this file may have been
# downloaded from a remote server -- only load weights from a trusted source
# (newer PyTorch versions offer `weights_only=True` to restrict unpickling).
if load_net:
    net.load_state_dict(torch.load(f"{network_params_folder}/{saved_net_filename}",
                                   map_location=network_components.nn_device))

### If the flag ```train_net``` is set to True, the following code will result in the retraining of the network

In [None]:
if train_net:
    # Baseline for the best accuracy seen so far. No earlier cell defines a
    # `max_accuracy` variable, so referencing one here raised a NameError as soon
    # as training was enabled. Starting from 0.0 means the first epoch already
    # counts as an improvement.
    # NOTE(review): when fine-tuning loaded weights you may prefer to start from
    # their validation accuracy instead -- confirm the expected scale (fraction
    # vs. percent) in model.training_loop first.
    starting_max_accuracy = 0.0

    # NOTE(review): `save_net_file_name` is a bare file name, whereas the weights
    # are loaded from `network_params_folder` -- confirm where training_loop
    # writes its checkpoints.
    losses, accuracies, epoch = model.training_loop(net, train_data_loader, val_data_loader, optimizer,
                                                    starting_epoch=0, num_epochs=120,
                                                    losses=None, accuracies=None,
                                                    save_checkpoint=save_net,
                                                    starting_max_accuracy=starting_max_accuracy,
                                                    save_net_file_name=saved_net_filename)

### Checking the prediction accuracy on the test set:

In [None]:
# Evaluate the network on the held-out test loader and print a per-protein table.
test_metric = model.calc_metrics(net, test_data_loader)

print("Test accuracy = {:5.2f}%".format(test_metric['overall accuracy'] * 100.0))
print()

# Table header followed by a separator line.
header_template = "{:12s}\t{}\t{}\t\t{}\t{}\t{}"
print(header_template.format("Protein name", "Precision", "Recall", "F1-score",
                             "False positive", "False negative"))
print("-" * 100)

# Keys of the per-protein metrics, in the order of the table columns.
row_template = "{:12s}\t{:5.2f}\t\t{:5.2f}\t\t{:5.2f}\t\t{:5.2f}\t\t{:5.2f}"
column_keys = ("precision", "recall", "F1-score",
               "percent false positive", "percent false negative")

for protein_name in dataset.protein_names:

    # Per-protein metrics are indexed by the protein's ID.
    protein_id = dataset.brightness_data[protein_name]["ID"]
    row_values = [test_metric[key][protein_id] * 100.0 for key in column_keys]

    print(row_template.format(protein_name, *row_values))

### Predicting nanobarcodes from example confocal images:

In [None]:
# Not referenced anywhere in this cell.
example_ind = 0

# number of entropy minimization iterations used by the "ContrastModifierNet"
n_optim_iter = 10

# Base names (without extension) of the .lsm images read from `example_folder`.
example_files = ['GalNact_NL1b_CS2_P5_CH11_Gain550',
                 'ENDO_NL1ab_CS2_P2_CH11_Gain550',
                 'KDEL_NL4_CS1_P3_CH11_Gain550',
                 'TOM70_NRXN3Beta-_CS1_P1_CH11_Gain550']

# Divisor applied to the image stacks before display
# (presumably they hold 16-bit intensities -- confirm).
max_pixel_value = 65535.0

# squeeze=False keeps `ax` two-dimensional even when only a single example file
# is listed, so the ax[i, j] indexing below works for any list length.
fig, ax = plt.subplots(nrows=len(example_files), ncols=2, figsize=(10, 5 * len(example_files)),
                       squeeze=False)

for i, _file in enumerate(example_files):

    image_file_name = f"{example_folder}/{_file}.lsm"

    result = model.predict_from_image_file(file_name=image_file_name, net=net,
                                           dataset=dataset, n_optim_iter=n_optim_iter,
                                           brightness_scaling_method=dataset_scaling)

    # Left column: first entry of the unprocessed false-color stack.
    ax[i, 0].imshow(result["unprocessed false-color stack"][0, :, :] / max_pixel_value)
    ax[i, 0].set_title(f"Unprocessed false color image -- {result['protein name']}", fontsize=12)
    ax[i, 0].axis('off')

    # Right column: first entry of the cell-halo false-color stack.
    ax[i, 1].imshow(result["cell-halo false-color stack"][0, :, :] / max_pixel_value)
    ax[i, 1].set_title(f"False color + cell halos (network prediction) -- {result['protein name']}", fontsize=12)
    ax[i, 1].axis('off')