# Motifs analysis - Part 1: Motifs extraction

Notebook to perform the discriminative motifs analysis. It requires a trained model, but it is independent of both the feature-space analysis and the prototype analysis.

Motifs are extracted on the basis of Class-Activation Maps (CAMs), which display the saliency of a class in a given input according to a model. CAMs towards any class can be computed regardless of the actual class of the input. This means that one can look for discriminative motifs of class B in an input of class A. However, for the sake of motif extraction, we don't use this feature of CAMs. Instead, we produce CAMs towards the actual class of the input.
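
As a rough illustration of what a CAM is, here is a minimal sketch. It assumes a network whose last convolutional feature maps are globally averaged and fed to a single linear layer, which is the usual CAM setting. The real computation in this notebook is done by `create_cam` from `pattern_utils` and is not reproduced here; the function name, array shapes and the clipping of negative values below are illustrative assumptions.

```python
import numpy as np

def class_activation_map(feature_maps, class_weights, target_class):
    """Toy CAM: weight each feature map by the linear weight of the target class.

    feature_maps:  array (n_filters, n_timepoints), output of the last conv layer.
    class_weights: array (n_classes, n_filters), weights of the final linear layer.
    """
    cam = class_weights[target_class] @ feature_maps  # shape (n_timepoints,)
    # Assumption of this sketch: negative evidence is set to zero
    return np.maximum(cam, 0.0)

rng = np.random.default_rng(7)
feature_maps = rng.random((4, 10))        # hypothetical: 4 filters, 10 time points
class_weights = rng.normal(size=(3, 4))   # hypothetical: 3 classes

# The same input can be explained towards any class, not only its own
cam_a = class_activation_map(feature_maps, class_weights, target_class=0)
cam_b = class_activation_map(feature_maps, class_weights, target_class=1)
```

For motif extraction, only the map towards the input's own class would be kept.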

The motif extraction procedure is as follows:
1. Select trajectories from which to extract motifs.
2. Compute CAM for each trajectory (saliency towards its own class).
3. Binarize each time point into 'relevant' and 'non-relevant' for recognizing the input class.
4. Optional but recommended: extend the 'relevant' regions to capture more context around the motifs and to connect smaller adjacent motifs into a bigger one. Also filter by motif length.
5. Extract the longest 'relevant' stretches of time-points. These are the final motifs.
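
Steps 3 to 5 can be sketched on a toy CAM as follows. This is only an illustration under stated assumptions: the helpers `binarize`, `extend_runs` and `longest_run` are hypothetical stand-ins for the functions in `pattern_utils`, and a fixed threshold of 0.5 replaces the Li threshold used later in the notebook.

```python
import numpy as np

def binarize(cam, thresh):
    """1 where the CAM reaches the threshold ('relevant'), 0 elsewhere."""
    return (np.asarray(cam) >= thresh).astype(int)

def extend_runs(binary, ext):
    """Grow every run of 1s by `ext` points in both time directions."""
    out = binary.copy()
    for i in np.flatnonzero(binary):
        out[max(0, i - ext): i + ext + 1] = 1
    return out

def longest_run(binary):
    """Return (start, end) of the longest run of 1s, end exclusive, or None."""
    padded = np.concatenate(([0], binary, [0]))
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)
    if len(starts) == 0:
        return None
    k = np.argmax(ends - starts)
    return int(starts[k]), int(ends[k])

series = np.arange(10.0)  # toy trajectory
cam = np.array([0.1, 0.9, 0.8, 0.2, 0.7, 0.9, 0.1, 0.0, 0.0, 0.6])

binary = binarize(cam, thresh=0.5)      # [0 1 1 0 1 1 0 0 0 1]
extended = extend_runs(binary, ext=1)   # [1 1 1 1 1 1 1 0 1 1]
start, end = longest_run(extended)      # (0, 7)
motif = series[start:end]               # the motif is cut from the input, not from the CAM
```

Note how the extension merges the two fragments separated by a single 'non-relevant' point into one longer motif.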

In order to visualize these motifs, we propose to cluster them afterwards as follows:
1. Build a distance matrix between the motifs with dynamic time warping (DTW).
2. Cluster with hierarchical clustering.
3. Visualize dynamics captured by each cluster.
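
To give an idea of the clustering step, here is a small sketch that runs hierarchical clustering on a precomputed distance matrix with SciPy. The distance values are invented for illustration, and the average linkage and the number of clusters are arbitrary choices, not necessarily the ones used in the next notebook.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

# Hypothetical symmetric distance matrix between 4 motifs (e.g. DTW distances)
dist = np.array([[0.0, 0.2, 3.0, 3.1],
                 [0.2, 0.0, 2.9, 3.0],
                 [3.0, 2.9, 0.0, 0.3],
                 [3.1, 3.0, 0.3, 0.0]])

condensed = squareform(dist)                  # linkage expects the condensed form
tree = linkage(condensed, method='average')   # agglomerative clustering
labels = fcluster(tree, t=2, criterion='maxclust')  # cut the tree into at most 2 clusters
```

Here motifs 0 and 1 end up in one cluster and motifs 2 and 3 in the other.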

This clustering can be run in 2 modes: either patterns from every class are pooled together, or a separate clustering is run independently for each class. In the first case, the clustering reflects the diversity of patterns at the dataset level and can reveal overlap in dynamics between classes. In the second case, the emphasis is on the diversity of dynamics induced by each class.


This notebook covers only the motif extraction part. It ends with the export of the motifs to a csv file. Go to the next one for computing DTW and clustering!


## Import libraries

In [1]:
# Standard libraries
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import pandas as pd
from skimage.filters import threshold_li, threshold_mean
import os
from itertools import chain
from tqdm import tqdm
import sys

# Custom functions/classes
path_to_module = '../source'  # Path where all the .py files are, relative to the notebook folder
sys.path.append(path_to_module)
from load_data import DataProcesser
from results_model import top_confidence_perclass, least_correlated_set
from pattern_utils import extend_segments, create_cam, longest_segments, extract_pattern
from class_dataset import myDataset, ToTensor, RandomCrop

# For reproducibility
myseed = 7
torch.manual_seed(myseed)
torch.cuda.manual_seed(myseed)
np.random.seed(myseed)

cuda_available = torch.cuda.is_available()

## Parameters

Parameters for the motifs extraction:
- selected_set: str, one of ['all', 'training', 'validation', 'test']. From which set of trajectories should motifs be extracted? For this purpose, extracting from the training data also makes sense.
- n_series_perclass: int, maximum number of series, per class, on which motif extraction is attempted.
- n_pattern_perseries: int, maximum number of motifs to extract out of a single trajectory.
- mode_series_selection: str one of ['top_confidence', 'least_correlated']. Mode to select the trajectories from which to extract the motifs (see Prototype analysis). If top confidence, the motifs might be heavily biased towards a representative subpopulation of the class. Hence, the output might not reflect the whole diversity of motifs induced by the class.
- extend_patt: int, by how many points to extend motifs? After binarization into 'relevant' and 'non-relevant' time points, the motifs are usually fragmented because a few points in their middle are improperly classified as 'non-relevant'. This parameter extends each fragment by a number of time points (in both time directions) before extracting the actual patterns.
- min_len_patt/max_len_patt: int, set minimum/maximum size of a motif. **/!\ The size is given in number of time-points. This means that if the input has more than one channel, the actual length of the motifs will be divided across them.** For example, a motif that spans over 2 channels for 10 time points will be considered of length 20.
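
The length convention for multi-channel inputs can be made concrete with a tiny sketch. The two helper names are hypothetical and only restate the arithmetic described above.

```python
def total_motif_length(n_timepoints, n_channels):
    """Length as compared against min_len_patt/max_len_patt: time points summed over channels."""
    return n_timepoints * n_channels

def passes_length_filter(n_timepoints, n_channels, min_len, max_len):
    """True if a motif of this duration is kept by the length filter."""
    length = total_motif_length(n_timepoints, n_channels)
    return min_len <= length <= max_len

# A motif spanning 2 channels for 10 time points counts as length 20
length_two_channels = total_motif_length(10, 2)

# With max_len_patt = 200 and 2 channels, motifs of up to 100 time points are kept
kept = passes_length_filter(100, 2, min_len=0, max_len=200)
dropped = passes_length_filter(101, 2, min_len=0, max_len=200)
```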

Parameters for the groups of motifs:
- export_perClass: bool, whether to run the motif clustering class per class.
- export_allPooled: bool, whether to pool all motifs across classes for clustering.

In [2]:
selected_set = 'all'
n_series_perclass = 75
n_pattern_perseries = 1
mode_series_selection = 'top_confidence'
# mode_series_selection = 'least_correlated'
thresh_confidence = 0.5  # used in least_correlated mode to choose set of series with minimal classification confidence
extend_patt = 5
min_len_patt = 0
max_len_patt = 200 # length to divide by nchannel

export_perClass = False
export_allPooled = True

assert selected_set in ['all', 'training', 'validation', 'test']
assert mode_series_selection in ['top_confidence', 'least_correlated']

## Load model and data

- Pay attention to the order of 'meas_var': it should be the same as when training the model!
- Pay attention to the preprocessing of the trajectories.
- Set batch_size as high as memory allows for speed up.

In [3]:
data_file = '../../../p53MCF10AstimuliCNN_Dhana/data_forCNN/p53fc_24h_26classes_imba.zip'
model_file = '../../../p53MCF10AstimuliCNN_Dhana/models/2021-02-08-15__35__59_p53fc_24h_26classes_imba.pytorch'

out_dir = 'auto'  # If 'auto' will automatically create a directory to save motifs tables

meas_var = None  # Set to None for auto detection
start_time = None  # Set to None for auto detection
end_time = None  # Set to None for auto detection

batch_size = 32  # Set as high as memory allows for speed up
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')
model = torch.load(model_file) if cuda_available else torch.load(model_file, map_location='cpu')
model.eval()
model.double()
model.batch_size = batch_size
model = model.to(device)

Note that **data.process() already centers the data**, so do not center it a second time when loading the data in the DataLoader. The **random crop** should be performed before passing the trajectories to the model, to ensure that the same crop is used as model input and for extracting the patterns.

In [4]:
# Transformations to perform when loading data into the model
ls_transforms = transforms.Compose([RandomCrop(output_size=model.length, ignore_na_tails=True),
                                    ToTensor()])
# Loading and PREPROCESSING
data = DataProcesser(data_file)
meas_var = data.detect_groups_times()['groups'] if meas_var is None else meas_var
start_time = data.detect_groups_times()['times'][0] if start_time is None else start_time
end_time = data.detect_groups_times()['times'][1] if end_time is None else end_time
# Path where to export tables with motifs
if out_dir == 'auto':
    out_dir = 'output/' + '_'.join(meas_var) + '/local_motifs/'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

data.subset(sel_groups=meas_var, start_time=start_time, end_time=end_time)
cols_to_check=data.dataset.columns.values[data.dataset.columns.str.startswith('FGF')]
cols_dict={k:'float64' for k in cols_to_check}
data.dataset=data.dataset.astype(cols_dict)
data.get_stats()
data.process(method='center_train', independent_groups=True)  # do here and not in loader so can use in df
data.crop_random(model.length, ignore_na_tails=True)
data.split_sets(which='dataset')
classes = tuple(data.classes[data.col_classname])
dict_classes = data.classes[data.col_classname]

# Random crop before to keep the same in df as the ones passed in the model
if selected_set == 'validation':
    selected_data = myDataset(dataset=data.validation_set, transform=ls_transforms)
    df = data.validation_set
elif selected_set == 'training':
    selected_data = myDataset(dataset=data.train_set, transform=ls_transforms)
    df = data.train_set
elif selected_set == 'test':
    selected_data = myDataset(dataset=data.test_set, transform=ls_transforms)
    df = data.test_set  # must match the set passed to the model
elif selected_set == 'all':
    try:
        selected_data = myDataset(dataset=data.dataset_cropped, transform=ls_transforms)
        df = data.dataset_cropped
    except AttributeError:  # no cropped copy available, fall back to the full dataset
        selected_data = myDataset(dataset=data.dataset, transform=ls_transforms)
        df = data.dataset

if batch_size > len(selected_data):
    raise ValueError('Batch size ({}) must be smaller than the number of trajectories in the selected set ({}).'.format(batch_size, len(selected_data)))        

data_loader = DataLoader(dataset=selected_data,
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=4)
# Dataframe used for retrieving trajectories. wide_to_long() instead of melt() because it can melt per group of columns
df = pd.wide_to_long(df, stubnames=meas_var, i=[data.col_id, data.col_class], j='Time', sep='_', suffix=r'\d+')
df = df.reset_index()  # wide_to_long creates a multi-level Index, reset index to retrieve indexes in columns
df.rename(columns={data.col_id: 'ID', data.col_class: 'Class'}, inplace=True)
df['ID'] = df['ID'].astype('U32')
del data  # free memory
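
The wide-to-long reshaping at the end of the cell above can be tried on a toy table. The column names `ERK` and `AKT` and the values are made up for this sketch; the call mirrors the one used on the real data.

```python
import pandas as pd

# Hypothetical wide table: one row per trajectory, one column per measurement and time point
wide = pd.DataFrame({
    'ID': ['a', 'b'],
    'Class': [0, 1],
    'ERK_0': [0.1, 0.2], 'ERK_1': [0.3, 0.4],
    'AKT_0': [1.0, 2.0], 'AKT_1': [3.0, 4.0],
})

# One stub per measurement: every '<stub>_<time>' column is melted into a '<stub>' column
long_df = pd.wide_to_long(wide, stubnames=['ERK', 'AKT'], i=['ID', 'Class'],
                          j='Time', sep='_', suffix=r'\d+')
long_df = long_df.reset_index()  # bring ID, Class and Time back as regular columns
```

The result has one row per trajectory and time point, with one column per measurement, which makes it easy to pull out a single trajectory by its ID.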



## Select trajectories from which to extract patterns

In [5]:
if mode_series_selection == 'least_correlated':
    set_trajectories = least_correlated_set(model, data_loader, threshold_confidence=thresh_confidence, device=device,
                                            n=n_series_perclass, labels_classes=dict_classes)
elif mode_series_selection == 'top_confidence':
    set_trajectories = top_confidence_perclass(model, data_loader, device=device, n=n_series_perclass,
                                               labels_classes=dict_classes)

# free some memory by keeping only relevant series
selected_trajectories = set_trajectories['ID']
df = df[df['ID'].isin(selected_trajectories)]
# Make sure that class is an integer (especially when 0 or 1, could be read as boolean)
df['Class'] = df['Class'].astype('int32')


## Extract patterns

### Extract, extend and filter patterns

Prints a report of how many patterns were filtered out by length.

In [6]:
# Initialize dict to store the patterns and set progress bar
store_patts = {i:[] for i in classes}
model.batch_size = 1  # Leave it to 1!
report_filter = {'Total number of patterns': 0,
                 'Number of patterns above maximum length': 0,
                 'Number of patterns below minimum length': 0}
pbar = tqdm(total=len(selected_trajectories))

for id_trajectory in selected_trajectories:
    # Read and format the trajectories to numpy
    series_numpy = np.array(df.loc[df['ID'] == id_trajectory][meas_var]).astype('float').squeeze()
    # Row: measurement; Col: time
    if len(meas_var) >= 2:
        series_numpy = series_numpy.transpose()
    series_tensor = torch.tensor(series_numpy)
    class_trajectory = df.loc[df['ID']==id_trajectory]['Class'].iloc[0]  # repeated value through all series
    class_label = classes[class_trajectory]
    
    # Create and process the CAM for the trajectory
    cam = create_cam(model, array_series=series_tensor, feature_layer='features',
                         device=device, clip=0, target_class=class_trajectory)
    thresh = threshold_li(cam)
    bincam = np.where(cam >= thresh, 1, 0)
    bincam_ext = extend_segments(array=bincam, max_ext=extend_patt)
    patterns = longest_segments(array=bincam_ext, k=n_pattern_perseries)
    
    # Filter short/long patterns
    report_filter['Total number of patterns'] += len(patterns)
    report_filter['Number of patterns above maximum length'] += len([k for k in patterns.keys() if patterns[k] > max_len_patt])
    report_filter['Number of patterns below minimum length'] += len([k for k in patterns.keys() if patterns[k] < min_len_patt])
    patterns = {k: patterns[k] for k in patterns.keys() if (patterns[k] >= min_len_patt and
                                                            patterns[k] <= max_len_patt)}
    if len(patterns) > 0:
        for pattern_position in list(patterns.keys()):
            store_patts[class_label].append(extract_pattern(series_numpy, pattern_position, NA_fill=False))
    pbar.update(1)
pbar.close()  # release the progress bar once all trajectories are processed

print(report_filter)


{'Total number of patterns': 1854, 'Number of patterns above maximum length': 0, 'Number of patterns below minimum length': 0}


### Dump patterns into csv

In [7]:
if export_allPooled:
    concat_patts_allPooled = np.full((sum(map(len, store_patts.values())), len(meas_var) * max_len_patt), np.nan)
    irow = 0
for classe in classes:
    concat_patts = np.full((len(store_patts[classe]), len(meas_var) * max_len_patt), np.nan)
    for i, patt in enumerate(store_patts[classe]):
        if len(meas_var) == 1:
            len_patt = len(patt)
            concat_patts[i, 0:len_patt] = patt
        if len(meas_var) >= 2:
            len_patt = patt.shape[1]
            for j in range(len(meas_var)):
                offset = j*max_len_patt
                concat_patts[i, (0+offset):(len_patt+offset)] = patt[j, :]
    if len(meas_var) == 1:
        headers = ','.join([meas_var[0] + '_' + str(k) for k in range(max_len_patt)])
        fout_patt = out_dir + 'motif_{}.csv.gz'.format(classe)
        if export_perClass:
            np.savetxt(fout_patt, concat_patts,
                       delimiter=',', header=headers, comments='')
    elif len(meas_var) >= 2:
        headers = ','.join([meas + '_' + str(k) for meas in meas_var for k in range(max_len_patt)])
        fout_patt = out_dir + 'motif_{}.csv.gz'.format(classe)
        if export_perClass:
            np.savetxt(fout_patt, concat_patts,
                       delimiter=',', header=headers, comments='')
    if export_allPooled:
        concat_patts_allPooled[irow:(irow+concat_patts.shape[0]), :] = concat_patts
        irow += concat_patts.shape[0]

if export_allPooled:
    concat_patts_allPooled = pd.DataFrame(concat_patts_allPooled)
    concat_patts_allPooled.columns = headers.split(',')
    pattID_col = [[classe] * len(store_patts[classe]) for classe in classes]
    concat_patts_allPooled['pattID'] = [j+'_'+str(i) for i,j in enumerate(list(chain.from_iterable(pattID_col)))]
    concat_patts_allPooled.set_index('pattID', inplace = True)
    fout_patt = out_dir + 'motif_allPooled.csv.gz'
    concat_patts_allPooled.to_csv(fout_patt, header=True, index=True, compression='gzip')
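
The export cell above stores motifs of different lengths in one rectangular table by padding the unused columns with NaN. A stripped-down, single-channel sketch of that idea, with a hypothetical `pad_motifs` helper, is:

```python
import numpy as np

def pad_motifs(motifs, max_len):
    """Stack 1-D motifs of different lengths into one NaN-padded 2-D array."""
    table = np.full((len(motifs), max_len), np.nan)
    for i, motif in enumerate(motifs):
        table[i, :len(motif)] = motif
    return table

motifs = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0])]
table = pad_motifs(motifs, max_len=4)

# The original lengths can be recovered by counting non-NaN entries per row
# (this assumes the motifs themselves contain no NaN)
lengths = (~np.isnan(table)).sum(axis=1)
```

When reading the exported file back, the trailing NaNs therefore have to be stripped before computing distances between motifs.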



### Build distance matrix between patterns with DTW

This is done in R with the implementation of the *parallelDist* package. It is very efficient and has support for multivariate cases.
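
For readers who want to see what such a distance computes, here is a plain Python sketch of textbook univariate DTW with an absolute-difference local cost. It is not a reimplementation of *parallelDist*: step patterns, normalization and multivariate handling in that package may give different values, and the function names below are hypothetical.

```python
import numpy as np

def dtw_distance(a, b):
    """Textbook DTW between two 1-D sequences (absolute difference as local cost)."""
    n, m = len(a), len(b)
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = abs(a[i - 1] - b[j - 1])
            # Cheapest way to reach (i, j): insertion, deletion or match
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    return float(cost[n, m])

def dtw_matrix(motifs):
    """Symmetric matrix of pairwise DTW distances between motifs."""
    k = len(motifs)
    dist = np.zeros((k, k))
    for i in range(k):
        for j in range(i + 1, k):
            dist[i, j] = dist[j, i] = dtw_distance(motifs[i], motifs[j])
    return dist

motifs = [
    [0.0, 1.0, 2.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 2.0, 1.0, 0.0],  # same shape as the first one, with a delayed onset
    [2.0, 2.0, 2.0],
]
dist = dtw_matrix(motifs)
```

The first two motifs differ only by a time shift, so their DTW distance is zero even though their lengths differ, which is why DTW suits motifs of unequal length.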

See the next notebook for the DTW and clustering steps.