In [1]:
# autoreload mode 2: re-import edited modules before each cell runs, so changes
# to the project packages are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# notexbook extension and its %texify magic (presumably notebook styling only;
# nothing below appears to depend on it -- TODO confirm).
%reload_ext notexbook
%texify

In [2]:
import ncolor

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.measure import label, regionprops
from sklearn.metrics import confusion_matrix
from skimage import io
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch_geometric.nn as geom_nn
from torch_geometric.data import Data as geom_Data
from torch_geometric.loader import DataLoader as geom_DataLoader
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import add_self_loops, degree, to_networkx
from torch_geometric.nn import MessagePassing, GATConv, GraphConv, GATv2Conv
import networkx as nx
import fastremap
import random
import pickle
from pathlib import Path
from pprint import pprint

from narsil2.tracking.cell_simulator.cell import RodShapedCell, CoccoidCell, RodShapedCellStochastic
from narsil2.tracking.cell_simulator.channel import ChannelState

from narsil2.tracking.losses import TrackerLoss

%matplotlib qt

In [4]:
class channelStackGraph(geom_Data):
    """A stack of channel images together with frame-to-frame link matrices.

    The stack is either simulated (``datapoint_number is None``) or read from
    disk (``datapoint_number`` given). For every pair of consecutive frames
    the connected regions found in each frame are turned into node feature
    vectors; ``__getitem__`` then assembles a bipartite graph between the
    regions of frame ``t`` and those of frame ``t + 1``.

    Args:
        time_points: Number of steps to simulate. Overwritten with
            ``number_of_images - 1`` when the stack is loaded from disk.
        edge_features_size: Width of the (all-zero) edge attribute matrix
            returned by ``__getitem__``.
        images_dirpath: Directory holding one sub-directory of ``*.tiff``
            frames per datapoint (``str`` or ``Path``).
        links_dirpath: Directory holding one ``<datapoint_number>.pickle``
            file per datapoint (``str`` or ``Path``).
        datapoint_number: Which stored datapoint to load; ``None`` generates
            a simulated stack instead.

    Raises:
        ValueError: If ``datapoint_number`` is given without both directories.
    """

    # Length of the vector produced by ``_construct_node_features``.
    _N_NODE_FEATURES = 6

    def __init__(self, time_points=1, edge_features_size=64,
                images_dirpath=None, links_dirpath=None, datapoint_number=None):
        super().__init__()
        self.time_points = time_points
        self.edge_features_size = edge_features_size
        # Accept both ``str`` and ``Path``: the previous ``type(...) == str``
        # check silently dropped Path arguments, which only surfaced later as
        # an AttributeError while loading.
        if images_dirpath is not None:
            self.images_dirpath = Path(images_dirpath)
        if links_dirpath is not None:
            self.links_dirpath = Path(links_dirpath)

        self.datapoint_number = datapoint_number

        self.data = {}
        if datapoint_number is None:
            self._generate_data()
        else:
            if images_dirpath is None or links_dirpath is None:
                raise ValueError(
                    "images_dirpath and links_dirpath are both required when "
                    "datapoint_number is given"
                )
            self._load_data_from_file()

        self.nodes = {} # one key for each time step, 0 to t-1, for t timesteps
        self.edges = {} # one key for each time step, 0 to t-1 for t timesteps
        self.edge_attributes = {}
        # for each timestep
        self.edge_labels = {}

        self._init_graph()

    def _generate_data(self):
        """Simulate one growing cell in a channel and store images and links."""
        ecoli_cell = RodShapedCellStochastic(
            length=46, width=30, position=[40, 40],
            division_size=84, elongation_rate=4,
            img_size=(2048, 80))
        channel = ChannelState([ecoli_cell], img_size=(2048, 80),
                               top_boundary=240, bottom_boundary=1600)
        images, links = channel.get_stack(time_points=self.time_points, plot=False)

        self.data['images'] = images
        self.data['links'] = links

    def _load_data_from_file(self):
        """Load the frames and link matrices of ``self.datapoint_number``.

        Links come from ``<links_dirpath>/<datapoint_number>.pickle``; frames
        from ``<images_dirpath>/<datapoint_number>/*.tiff``, ordered by the
        integer value of the file stem.
        """
        links_filepath = self.links_dirpath / (str(self.datapoint_number) + '.pickle')
        # NOTE: pickle.load executes arbitrary code embedded in the file --
        # only point this at link files you generated yourself.
        with open(links_filepath, 'rb') as f:
            links = pickle.load(f)

        images_dirname = self.images_dirpath / str(self.datapoint_number)
        images_filepaths = sorted(images_dirname.glob('*.tiff'),
                                  key=lambda path: int(path.stem))
        images_list = [io.imread(filepath) for filepath in images_filepaths]

        # N frames give N - 1 transitions between consecutive frames.
        self.time_points = len(images_list) - 1
        self.data['images'] = np.asarray(images_list)
        self.data['links'] = links

    def __len__(self):
        """Number of frame-to-frame transitions (one link matrix each)."""
        return len(self.data['links'])

    def _construct_node_features(self, props):
        """Return ``[row, column, height, width, area, eccentricity]`` for one region.

        ``row``/``column`` are the first two ``bbox`` entries; height and width
        are the differences between the matching ``bbox`` entries.
        """
        bbox = props['bbox']
        row, column = bbox[0], bbox[1]
        height, width = bbox[2] - bbox[0], bbox[3] - bbox[1]
        return [row, column, height, width, props['area'], props['eccentricity']]

    def _frame_features(self, regions):
        """Stack the node features of all regions of one frame into ``(n, 6)``."""
        features = [self._construct_node_features(region) for region in regions]
        # The explicit reshape keeps a frame without any region at shape
        # (0, 6) instead of (0,), so np.vstack in __getitem__ still works.
        return np.asarray(features, dtype=np.float64).reshape(-1, self._N_NODE_FEATURES)

    def _init_graph(self):
        """Fill ``self.nodes[i]`` with the node features of frames ``i`` and ``i + 1``."""
        frames = self.__len__() + 1
        # Iterate over the frames once; each frame's features are reused as the
        # "t + 1" side of one transition and the "t" side of the next.
        frame_features = [
            self._frame_features(regionprops(label(self.data['images'][i])))
            for i in range(frames)
        ]

        for i in range(frames - 1):
            self.nodes[i] = [frame_features[i], frame_features[i + 1]]

    def __getitem__(self, idx):
        """Build the bipartite graph for the transition ``idx -> idx + 1``.

        Returns a dict with:
            ``images``: the two frames of the transition,
            ``links``: integer link matrix cropped to (objects at t, objects at t+1),
            ``x``: float32 node features, frame-t nodes first,
            ``edge_index``: 2 x E array connecting every node at t to every
                node at t+1, ordered row-major over (t, t+1),
            ``edge_attr``: zero matrix of shape (E, edge_features_size),
            ``node_t`` / ``node_t1``: number of objects in each frame.
        """
        nodes_t, nodes_t1 = self.nodes[idx]
        n_objects_1 = nodes_t.shape[0]
        n_objects_2 = nodes_t1.shape[0]

        # Fully connect frame t to frame t+1. The ordering (t index slowest)
        # matches what a nested ``for i: for j`` loop would produce, which the
        # consumer relies on when reshaping scores to (n_objects_1, n_objects_2).
        # Explicit int64 keeps the index usable even when there are no edges.
        src = np.repeat(np.arange(n_objects_1, dtype=np.int64), n_objects_2)
        dst = np.tile(np.arange(n_objects_2, dtype=np.int64), n_objects_1) + n_objects_1
        edge_index = np.stack((src, dst))

        num_edges = edge_index.shape[1]
        edge_attr = np.zeros((num_edges, self.edge_features_size))

        x = np.vstack((nodes_t, nodes_t1))

        return {
            'images': self.data['images'][idx:idx+2],
            'links': torch.from_numpy(self.data['links'][idx][:n_objects_1, :n_objects_2].astype('int')),
            'x': torch.from_numpy(x.astype('float32')),
            'edge_index': torch.from_numpy(edge_index),
            'edge_attr': torch.from_numpy(edge_attr),
            'node_t': n_objects_1,
            'node_t1': n_objects_2
        }

    def plot_images(self, cmap='viridis', colors = {1: 'r', 2: 'g', 3: 'm'}):
        """Show all frames side by side and draw the links between centroids.

        Args:
            cmap: Colormap passed to ``plt.imshow``.
            colors: Maps a link value to the matplotlib format string used for
                its line (the default dict is only read, never modified).
        """
        num_images, height, width = self.data['images'].shape
        full_img = np.zeros((height, num_images * width))
        properties = []
        for i in range(num_images):
            image = self.data['images'][i]
            full_img[:, i*width:(i+1)*width] = image
            # Renumber a copy: ``image`` is a view into the stored stack, so
            # in-place renumbering would rewrite self.data['images'] as a
            # side effect of plotting.
            relabelled, _ = fastremap.renumber(image, in_place=False)
            properties.append(regionprops(relabelled))

        links = self.data['links']

        plt.figure(figsize=(10, 6))
        plt.imshow(full_img, cmap=cmap)

        for i, frame_links in enumerate(links, 0):
            n_objects_1 = len(properties[i])
            n_objects_2 = len(properties[i+1])

            rows, columns = np.nonzero(frame_links)
            for r, c in zip(rows, columns):
                # Link matrices may be larger than the number of detected
                # objects; skip entries that do not map to a region.
                if r < n_objects_1 and c < n_objects_2:
                    centroid_t_x, centroid_t_y = properties[i][r]['centroid']
                    centroid_t1_x, centroid_t1_y = properties[i+1][c]['centroid']
                    # Shift each frame's x coordinate by its horizontal offset
                    # in the concatenated image.
                    plt.plot([centroid_t_y + i * width, centroid_t1_y + (i + 1) * width],
                             [centroid_t_x, centroid_t1_x], colors[frame_links[r, c]])

        plt.title(f"Timepoints : {self.time_points+1} (starting at t=0)")
        plt.show()

In [5]:
class trackerNet(nn.Module):
    """Scores every (object at t, object at t+1) pair with an edge class.

    Node features are embedded by a linear layer + ReLU; each edge is then
    classified by an MLP applied to the difference of its endpoint embeddings.

    Args:
        input_node_size: Length of a raw node feature vector.
        hidden_size: Width of the node embedding and of the edge MLP.
        edge_classes: Number of output classes per edge.
    """

    def __init__(self, input_node_size=6, hidden_size=64, edge_classes=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.edge_classes = edge_classes
        self.input_node_size = input_node_size

        # The edge MLP is created before the node embedding so that parameter
        # registration (and hence seeded initialisation) order is unchanged.
        edge_layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, edge_classes),
        ]
        self.edge_mlp = nn.Sequential(*edge_layers)

        # The attribute name (including its spelling) is part of the
        # state_dict keys of saved checkpoints, so it must not be renamed.
        node_layers = [
            nn.Linear(input_node_size, hidden_size),
            nn.ReLU(inplace=True),
        ]
        self.inital_node_transform = nn.Sequential(*node_layers)

    def forward(self, one_step_data):
        """Embed the nodes and score every edge of one transition.

        Args:
            one_step_data: Mapping with keys ``'x'``, ``'edge_index'``,
                ``'edge_attr'``, ``'node_t'`` and ``'node_t1'``.

        Returns:
            Tuple ``(embedded_nodes, scores)`` where ``scores`` is the edge MLP
            output viewed as ``(node_t, node_t1, edge_classes)``.
        """
        raw_nodes = one_step_data['x']
        edge_index = one_step_data['edge_index']
        _unused_edge_attr = one_step_data['edge_attr']  # looked up but not used by the model
        rows = one_step_data['node_t']
        cols = one_step_data['node_t1']

        embedded = self.inital_node_transform(raw_nodes)

        # One score vector per edge, computed from the embedding difference
        # between the edge's source and destination node.
        senders, receivers = edge_index
        edge_scores = self.edge_mlp(embedded[senders] - embedded[receivers])

        return embedded, edge_scores.view(rows, cols, self.edge_classes)

#### Loss function

```
d = channelStackGraph(time_points=100)
d.plot_images()
timepoint = 9
net = trackerNet()
x, affinity_matrix_scores = net(d[timepoint])
t = TrackerLoss()
t(affinity_matrix_scores, x, d[timepoint]['links'])
print(d[timepoint]['links'])
```

torch.argmax(torch.softmax(affinity_matrix_scores, dim=2), dim=-1)

In [6]:
# Dataset roots, relative to the notebook's working directory. The same two
# assignments are repeated at the top of the training cell further down.
images_dirpath = '../../data/tracking_data/images'
links_dirpath = '../../data/tracking_data/links'

In [7]:
# Sanity check: load one stored stack (datapoint 350) and plot its frames with
# the ground-truth links. time_points is passed as None because the loader sets
# it from the number of images found on disk.
i = 350
data = channelStackGraph(time_points=None, images_dirpath=images_dirpath, 
                            links_dirpath=links_dirpath, datapoint_number=i)
data.plot_images()

### Train loops for the final model.

In [8]:
nEpochs = 15

images_dirpath = '../../data/tracking_data/images'
links_dirpath = '../../data/tracking_data/links'

net = trackerNet()
criterion = TrackerLoss(weight=0.1)

optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

# 1000 stored stacks, split 800 / 200 into train / validation after shuffling.
# NOTE(review): no random seed is set, so the split differs between runs.
stack_numbers = list(range(1000))
random.shuffle(stack_numbers)
train_stack_numbers = stack_numbers[:800]
validation_stack_numbers = stack_numbers[800:]

for i in range(nEpochs):

    # shuffle to pick graphs in random order
    random.shuffle(train_stack_numbers)
    epoch_loss = 0.0
    for l in train_stack_numbers:
        avg_loss = 0.0
        data = channelStackGraph(time_points=None, images_dirpath=images_dirpath, links_dirpath=links_dirpath,
                                 datapoint_number=l)
        time_points = data.time_points
        for t in range(time_points):
            # Build the sample once; data[t] reassembles the graph on every access.
            sample = data[t]
            optimizer.zero_grad()

            x, affinity_matrix_scores = net(sample)
            loss = criterion(affinity_matrix_scores, x, sample['links'])
            loss.backward()
            avg_loss = avg_loss + loss.item()
            optimizer.step()
        # max(..., 1) guards against a stack that has no transitions at all.
        epoch_loss += avg_loss / max(time_points, 1)

    net.eval()
    epoch_val_loss = 0.0
    # Validation only needs loss values: skip building the autograd graph.
    with torch.no_grad():
        for vl in validation_stack_numbers:
            avg_val_loss = 0.0
            data = channelStackGraph(time_points=None, images_dirpath= images_dirpath,
                                    links_dirpath=links_dirpath, datapoint_number= vl)
            time_points = data.time_points
            for t in range(time_points):
                sample = data[t]
                x, affinity_matrix_scores = net(sample)
                loss = criterion(affinity_matrix_scores, x, sample['links'])
                avg_val_loss = avg_val_loss + loss.item()

            epoch_val_loss += avg_val_loss / max(time_points, 1)

    print(f"Epoch : {i} -- Train Loss: {epoch_loss / len(train_stack_numbers)} -- Val loss: {epoch_val_loss/ len(validation_stack_numbers)}")
    net.train()

Epoch : 0 -- Train Loss: 0.09374988769517714 -- Val loss: 0.18421345861115834
Epoch : 1 -- Train Loss: 0.0519958519029129 -- Val loss: 0.10635627234626584
Epoch : 2 -- Train Loss: 0.12556063278620613 -- Val loss: 0.08314883443468425
Epoch : 3 -- Train Loss: 0.047009166325504725 -- Val loss: 0.08146179043623636
Epoch : 4 -- Train Loss: 0.11782356821846672 -- Val loss: 0.07552939836215372
Epoch : 5 -- Train Loss: 0.04507315826831821 -- Val loss: 0.08200289809316891
Epoch : 6 -- Train Loss: 0.15122830306750848 -- Val loss: 0.2327696075037123
Epoch : 7 -- Train Loss: 0.11121811810428646 -- Val loss: 0.10105753671944843
Epoch : 8 -- Train Loss: 0.04634379602791509 -- Val loss: 0.09633520317727226
Epoch : 9 -- Train Loss: 0.10709799053683262 -- Val loss: 0.10979798237128995
Epoch : 10 -- Train Loss: 0.04718833884786267 -- Val loss: 0.0927583662180805
Epoch : 11 -- Train Loss: 0.046254167704110324 -- Val loss: 0.0935666065640428
Epoch : 12 -- Train Loss: 0.04604282746367319 -- Val loss: 0.090

In [9]:
# NOTE(review): this writes 'track_model.pth', while the loading cell further
# down reads 'tracker_model.pth' -- confirm which checkpoint is intended.
save_path = Path('../../saved_models/track_model.pth')
# torch.save does not create missing directories.
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(net.state_dict(), save_path)

In [14]:
# NOTE(review): executed out of order (In [14] sits between In [9] and In [10]).
# Read top to bottom, `data` here is whatever stack the training cell loaded last.
data.plot_images()

In [10]:
# Restore trained weights into a fresh network and switch it to eval mode.
# NOTE(review): this reads 'tracker_model.pth', whereas the save cell above
# writes 'track_model.pth' -- verify the intended checkpoint is loaded.
net_path = '../../saved_models/tracker_model.pth'
net = trackerNet()
net.load_state_dict(torch.load(net_path))
net.eval()

trackerNet(
  (edge_mlp): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=3, bias=True)
  )
  (inital_node_transform): Sequential(
    (0): Linear(in_features=6, out_features=64, bias=True)
    (1): ReLU(inplace=True)
  )
)

In [11]:
# Cost of simulating a 30-step stack and extracting its node features
# (no datapoint_number, so the data is generated rather than loaded).
%timeit channelStackGraph(time_points=30)

523 ms ± 181 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
# Fresh simulated stack with 100 steps, used as test data by the cells below;
# `time_points` is also read as a global by those cells.
time_points=100
data = channelStackGraph(time_points=time_points, datapoint_number=None)
data.plot_images()

#### Test printouts

In [13]:
def run_net():
    """Run the network over every transition of the global test stack.

    Reads the notebook globals ``net``, ``data`` and ``time_points``.

    Returns:
        Tuple ``(affinity_accumulated, true_accumulated)``: the raw score
        tensor and the ground-truth link matrix of each transition.
    """
    affinity_accumulated = []
    true_accumulated = []

    with torch.no_grad():
        for i in range(time_points):
            # Build the sample once; data[i] reassembles the graph on each access.
            sample = data[i]
            x, affinity_matrix_scores = net(sample)
            affinity_accumulated.append(affinity_matrix_scores)
            true_accumulated.append(sample['links'])

    # Previously the lists were discarded; returning them lets callers reuse
    # the results (callers that ignore the return value are unaffected).
    return affinity_accumulated, true_accumulated

In [14]:
# Time a full pass over the stack; this includes assembling each data[i]
# sample, not just the forward pass.
%timeit run_net()

12.8 ms ± 189 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
# Collect raw scores and ground-truth links for every transition, printing the
# true link matrix next to the predicted class and its probability.
affinity_accumulated = []
true_accumulated = []

with torch.no_grad():
    for i in range(time_points):
        # Build the sample and the class probabilities once per step instead
        # of once per use.
        sample = data[i]
        x, affinity_matrix_scores = net(sample)
        class_probs = torch.softmax(affinity_matrix_scores, dim=2)
        affinity_accumulated.append(affinity_matrix_scores)
        true_accumulated.append(sample['links'])
        print("Links: ", sample['links'])
        print("Pred_prob: ", torch.max(class_probs, dim=-1)[0])
        print("Pred:", torch.argmax(class_probs, dim=-1))
        print("*********")

Links:  tensor([[1]])
Pred_prob:  tensor([[0.9983]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9996]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9983]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9983]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9983]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9996]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9998]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9983]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9996]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9996]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9996]])
Pred: tensor([[1]])
*********
Links:  tensor([[1]])
Pred_prob:  tensor([[0.9996]])
Pred: tensor([[1]])
*********
Link

In [16]:
# One entry per transition is expected in both lists.
len(affinity_accumulated), len(true_accumulated)

(100, 100)

In [17]:
# Class probabilities for transition 30: softmax over the last (class) axis.
aff = torch.softmax(affinity_accumulated[30], dim=2)

In [18]:
# Inspect the probabilities computed above.
aff

tensor([[[8.4666e-04, 9.9831e-01, 8.4518e-04]]])

In [19]:
# Layout is (objects at t, objects at t+1, edge classes).
aff.shape

torch.Size([1, 1, 3])

In [20]:
# Element [0] of torch.max(..., dim) holds the maximum values: the probability
# of the most likely class for each object pair.
torch.max(aff, dim=-1)[0]

tensor([[0.9983]])

In [21]:
# Element [1] holds the indices of those maxima: the predicted class per pair.
torch.max(aff, dim=-1)[1]

tensor([[1]])

In [22]:
# Re-plot the simulated test stack for visual comparison with the predictions.
data.plot_images()

### Metrics


In [23]:
def metrics_set_of_links(affinity_matrices, true_links):
    """Collect true and predicted classes at the positions of real links only.

    Args:
        affinity_matrices: Sequence of score tensors, one per transition, each
            indexed as (object at t, object at t+1, class).
        true_links: Sequence of integer link matrices, one per transition;
            a zero entry means "no link".

    Returns:
        Tuple ``(y_true, y_pred)`` of flat Python lists: for every non-zero
        entry of every link matrix, its value and the class predicted at the
        same position.
    """
    assert len(affinity_matrices) == len(true_links), "True and predicted don't match in shapes"
    y_true, y_pred = [], []
    for scores, step_links in zip(affinity_matrices, true_links):
        # Positions (as an index tuple) where a ground-truth link exists.
        linked = torch.nonzero(step_links, as_tuple=True)
        y_true += step_links[linked].tolist()

        # Most probable class for every pair, then keep the linked pairs only.
        predicted_classes = torch.softmax(scores, dim=2).argmax(dim=-1)
        y_pred += predicted_classes[linked].tolist()

    return y_true, y_pred

In [24]:
# Labels restricted to positions where a ground-truth link exists.
y_true, y_pred = metrics_set_of_links(affinity_accumulated, true_accumulated)

In [25]:
# Which classes occur among predictions and ground truth.
np.unique(y_pred), np.unique(y_true)

(array([1, 2]), array([1, 2]))

In [26]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# plot_confusion_matrix was removed in scikit-learn 1.2 (ConfusionMatrixDisplay
# replaces it) and is not used below; import it only where it still exists so
# this cell keeps working on current releases.
try:
    from sklearn.metrics import plot_confusion_matrix
except ImportError:
    plot_confusion_matrix = None

In [27]:
# Confusion matrix over true links only (rows: true class, columns: predicted).
confusion_matrix(y_true, y_pred)

array([[158,   0],
       [  0,   4]])

In [28]:
def metrics_all_possible_combinations(affinity_matrices, true_links):
    """Collect true and predicted classes for every candidate object pair.

    Unlike a links-only comparison, pairs without a ground-truth link (value
    zero) are included as well.

    Args:
        affinity_matrices: Sequence of score tensors, one per transition, each
            indexed as (object at t, object at t+1, class).
        true_links: Sequence of integer link matrices, one per transition.

    Returns:
        Tuple ``(y_true, y_pred)`` of flat Python lists, flattened row-major
        and concatenated over all transitions.
    """
    assert len(affinity_matrices) == len(true_links), "True and predicted don't match in shapes"
    y_true, y_pred = [], []
    for scores, step_links in zip(affinity_matrices, true_links):
        # Flatten both in the same order so entries stay aligned pair by pair.
        y_true += step_links.flatten().tolist()
        predicted_classes = torch.softmax(scores, dim=2).argmax(dim=-1)
        y_pred += predicted_classes.flatten().tolist()

    return y_true, y_pred

In [29]:
# Labels for every candidate pair, including pairs with no link (class 0).
y_true_all, y_pred_all = metrics_all_possible_combinations(affinity_accumulated, true_accumulated)

In [30]:
# Which classes occur among predictions and ground truth over all pairs.
np.unique(y_pred_all), np.unique(y_true_all)

(array([0, 1, 2]), array([0, 1, 2]))

In [31]:
# Confusion matrix over all candidate pairs (rows: true class, columns: predicted).
confusion_matrix(y_true_all, y_pred_all)

array([[ 95,  52,   0],
       [  0, 158,   0],
       [  0,   0,   4]])

In [32]:
# Keep the all-pairs confusion matrix for the plot below.
cm = confusion_matrix(y_true_all, y_pred_all)

In [33]:
# NOTE(review): `disp` is created but never plotted in the cells shown here;
# the heatmap further down is drawn with seaborn instead.
disp = ConfusionMatrixDisplay(cm)

In [34]:
import seaborn as sns

In [35]:

# Start a new figure: plt.subplot() would reuse the current figure, which with
# the Qt backend can be a plot_images window still open from an earlier cell.
fig, ax = plt.subplots()
# annot=True writes the count into each cell; fmt='g' disables scientific notation.
sns.heatmap(cm, annot=True, fmt='g', ax=ax, cmap='Blues')

# labels, title and ticks
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
# Tick order follows the class indices 0, 1, 2 of the confusion matrix.
ax.xaxis.set_ticklabels(['No-link', 'Movement', 'Division'])
ax.yaxis.set_ticklabels(['No-link', 'Movement', 'Division']);
