In [1]:
# Re-import edited modules (e.g. the narsil.deadalive package) before each cell executes.
%reload_ext autoreload
%autoreload 2
# Interactive Qt5 GUI backend: the FuncAnimation cells further down need a live
# window to animate (requires a Qt5 binding to be installed).
%matplotlib qt5

In [2]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

In [3]:
from narsil.deadalive.datasets import channelStackTrain
from narsil.deadalive.modelDev import trainDeadAliveNet
from narsil.deadalive.network import CaffeLSTMCell, deadAliveNet80036

In [4]:
# The eight labelled training directories deadalive1/0/ ... deadalive1/7/.
# Keep the trailing slash: later cells build glob patterns by plain string
# concatenation (directory + "*" + fileformat).
phaseDirectoriesList = [
    '/home/pk/Documents/trainingData/deadalive1/' + str(index) + '/'
    for index in range(8)
]

In [5]:
# Echo the directory list to confirm the paths before building the dataset.
phaseDirectoriesList

['/home/pk/Documents/trainingData/deadalive1/0/',
 '/home/pk/Documents/trainingData/deadalive1/1/',
 '/home/pk/Documents/trainingData/deadalive1/2/',
 '/home/pk/Documents/trainingData/deadalive1/3/',
 '/home/pk/Documents/trainingData/deadalive1/4/',
 '/home/pk/Documents/trainingData/deadalive1/5/',
 '/home/pk/Documents/trainingData/deadalive1/6/',
 '/home/pk/Documents/trainingData/deadalive1/7/']

In [6]:
# Build the training dataset from the labelled directories (.tiff frames).
# numUnrolls is forwarded to channelStackTrain; its exact meaning is defined in
# narsil.deadalive.datasets.
dataset = channelStackTrain(phaseDirectoriesList, numUnrolls=2, fileformat='.tiff')

In [7]:
# Number of samples the dataset exposes.
len(dataset)

312

In [8]:
# Shape of one sample's image stack. The trailing axis presumably equals
# numUnrolls=2 — TODO confirm against channelStackTrain.
dataset[0]['imageSequence'].shape

(1, 800, 36, 2)

In [9]:
# Device the model is trained on (CUDA device index 1).
modelParameters = dict(device="cuda:1")
# Optimiser settings passed to trainDeadAliveNet.
optimizationParameters = dict(
    learning_rate=1.0e-5,
    nEpochs=100,
)

In [10]:
# Build the trainer over the same directories and file format used for `dataset`.
net = trainDeadAliveNet(phaseDirectoriesList, modelParameters, optimizationParameters, fileformat='.tiff')

In [11]:
# Run training; the per-epoch average loss is printed. The captured output
# below stops at epoch 16 of nEpochs=100.
net.train()

Epoch 0 -- started
Epoch average loss: 0.6727258533239364
Epoch 1 -- started
Epoch average loss: 0.5738601653199447
Epoch 2 -- started
Epoch average loss: 0.46030636210190623
Epoch 3 -- started
Epoch average loss: 0.436689682696995
Epoch 4 -- started
Epoch average loss: 0.43194451928138733
Epoch 5 -- started
Epoch average loss: 0.43085976964549016
Epoch 6 -- started
Epoch average loss: 0.4274795384783494
Epoch 7 -- started
Epoch average loss: 0.4270545997117695
Epoch 8 -- started
Epoch average loss: 0.4268042288328472
Epoch 9 -- started
Epoch average loss: 0.4251725391337746
Epoch 10 -- started
Epoch average loss: 0.42586643131155716
Epoch 11 -- started
Epoch average loss: 0.41638360917568207
Epoch 12 -- started
Epoch average loss: 0.4152754495541255
Epoch 13 -- started
Epoch average loss: 0.4193878422180812
Epoch 14 -- started
Epoch average loss: 0.41360117495059967
Epoch 15 -- started
Epoch average loss: 0.41484078930483925
Epoch 16 -- started
Epoch average loss: 0.4161769830518299
E

In [19]:
# Persist the trained network. runNet later reads this file and expects a
# 'model_state_dict' entry, so net.save presumably writes one — verify.
# NOTE(review): hardcoded absolute path; consider a configurable model directory.
modelPath = '/home/pk/Documents/models/newdeadalive.pth'
net.save(modelPath)

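### Plot labelled states and test net performance

Note: the directories used below (`deadalive1/7/` and `deadalive1/0/`) are also in `phaseDirectoriesList`, so these checks run on training data, not held-out test data.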

In [10]:
import numpy as np
import glob
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from skimage import io

# One labelled training directory whose ground-truth states are animated below.
phaseDir = '/home/pk/Documents/trainingData/deadalive1/7/'
fileformat = '.tiff'


def _frameNumber(path):
    """Return the integer frame index encoded in an image file name ('.../12.tiff' -> 12)."""
    # basename/splitext instead of split('.'), which broke on directories containing '.'.
    return int(os.path.splitext(os.path.basename(path))[0])


# Sort the matched paths themselves numerically (so 10 follows 9). Rebuilding the
# paths from str(int) would drop zero-padding such as '007.tiff'. os.path.join
# also tolerates a phaseDir without a trailing slash.
sortedFilenames = sorted(glob.glob(os.path.join(phaseDir, '*' + fileformat)), key=_frameNumber)
# Sorted integer frame numbers, as this cell exposed before.
filenames = [_frameNumber(path) for path in sortedFilenames]

# Label matrix indexed as states[k][frame] by the animation cell (row k = one state).
states = np.load(os.path.join(phaseDir, 'states.npy'))

In [54]:
# Animate the ground-truth labels of `phaseDir` next to its frames.
fig, ax = plt.subplots(nrows=1, ncols=2)
img = io.imread(sortedFilenames[0])

tdata, movingdata, partialdeaddata, alldeaddata, nocellsdata, cellsvanishdata = [], [], [], [], [], []
movingplot, = ax[1].plot([], [], 'g,:', label='Moving')
partialdeadplot, = ax[1].plot([], [], 'r,-.', label='Partial Dead')
alldeadplot, = ax[1].plot([], [], 'k--', label='All Dead')
nocellsplot, = ax[1].plot([], [], 'b*', label='No cells')
cellsvanishplot, = ax[1].plot([], [], 'mo', label='Cells vanish')

imgplot = ax[0].imshow(img, cmap='gray')


def init():
    """Show the first frame, clear any accumulated traces and fix the axis limits."""
    imgplot.set_array(io.imread(sortedFilenames[0]))
    # FuncAnimation calls init_func again when it restarts (e.g. ani.save()); without
    # this reset every frame would be appended to the traces a second time.
    for series in (tdata, movingdata, partialdeaddata, alldeaddata, nocellsdata, cellsvanishdata):
        series.clear()
    for line in (movingplot, partialdeadplot, alldeadplot, nocellsplot, cellsvanishplot):
        line.set_data([], [])
    # The x range follows the number of frames rather than a hard-coded 40.
    ax[1].set_xlim([0, len(sortedFilenames)])
    ax[1].set_ylim([-0.5, 2])
    return imgplot,


def update(frame):
    """Draw image `frame` and extend each plotted state trace by its label at that frame."""
    imgplot.set_array(io.imread(sortedFilenames[frame]))
    tdata.append(frame)
    # states[k][frame] is state k's label at this frame; row 1 is not plotted.
    movingdata.append(int(states[0][frame]))
    movingplot.set_data(tdata, movingdata)

    partialdeaddata.append(int(states[2][frame]))
    partialdeadplot.set_data(tdata, partialdeaddata)

    alldeaddata.append(int(states[3][frame]))
    alldeadplot.set_data(tdata, alldeaddata)

    nocellsdata.append(int(states[4][frame]))
    nocellsplot.set_data(tdata, nocellsdata)

    cellsvanishdata.append(int(states[5][frame]))
    cellsvanishplot.set_data(tdata, cellsvanishdata)

    # Every artist that changed (only consulted when blit=True, but kept complete).
    return [imgplot, movingplot, partialdeadplot, alldeadplot, nocellsplot, cellsvanishplot]


ani = FuncAnimation(fig, update, frames=range(0, len(sortedFilenames)),
                    init_func=init, blit=False, repeat=False, interval=500)
# Attach the legend to the state axes explicitly rather than via pyplot's current axes.
ax[1].legend(loc='upper right')
plt.show()

In [16]:
def runNet(dataDir, modelPath, fileformat='.tiff', plot=False, verbose=True):
    """Run the trained dead/alive network over every frame in one directory.

    Frames are read in numeric file-name order and fed to the network one at a
    time. The network's `lstm_state` from each frame is passed into the next.

    Args:
        dataDir: directory holding frames named '<frame number><fileformat>'.
        modelPath: checkpoint file containing a 'model_state_dict' entry.
        fileformat: image file extension to match (default '.tiff').
        plot: if True, animate each frame beside the predicted states.
        verbose: if True (default, the original behaviour), print each frame's
            prediction as it is computed.

    Returns:
        (ani, predicted_states). `ani` is the FuncAnimation when `plot` is True and
        None otherwise; keep a reference to it, or the animation is garbage-collected
        and stops. `predicted_states` holds one list of booleans per frame: the
        network output thresholded at 0.5.

    Raises:
        FileNotFoundError: if no file in `dataDir` matches `fileformat`.
    """
    def frameNumber(path):
        # '/dir/12.tiff' -> 12; basename/splitext tolerate '.' in directory names.
        return int(os.path.splitext(os.path.basename(path))[0])

    # Glob the directory that was passed in (not a global). Sort the matched paths
    # themselves numerically: rebuilding them from str(int) would drop zero-padding
    # such as '007.tiff'.
    sortedFilenames = sorted(glob.glob(os.path.join(dataDir, '*' + fileformat)),
                             key=frameNumber)
    if not sortedFilenames:
        raise FileNotFoundError('no *{} files found in {}'.format(fileformat, dataDir))

    # Initialise the network on the CPU and load the trained weights. map_location
    # lets a checkpoint saved during GPU training (device 'cuda:1') load without that GPU.
    net = deadAliveNet80036(device="cpu")
    saved_net_parameters = torch.load(modelPath, map_location="cpu")
    net.load_state_dict(saved_net_parameters['model_state_dict'])
    net.eval()

    imgTransforms = transforms.Compose([
        transforms.ToTensor()
    ])

    predicted_states = []
    with torch.no_grad():
        lstm_state = None  # no recurrent state before the first frame
        for filename in sortedFilenames:
            image = io.imread(filename)
            # Add a leading axis (presumably the batch axis expected by the net).
            imageToNet = imgTransforms(image).unsqueeze(0)
            predictions = net(imageToNet, lstm_state) > 0.5
            # Carry the network's recurrent state into the next frame.
            lstm_state = net.lstm_state
            framePrediction = list(predictions.numpy()[0])
            if verbose:
                print(framePrediction)
            predicted_states.append(framePrediction)

    ani = None
    if plot:
        states = np.array(predicted_states).T  # row k = state k, column t = frame t
        fig, ax = plt.subplots(nrows=1, ncols=2)
        imgplot = ax[0].imshow(io.imread(sortedFilenames[0]), cmap='gray')

        # (row of `states`, label, line style) for each plotted state; row 1 is not plotted.
        traceSpecs = [
            (0, 'Moving', 'g,:'),
            (2, 'Partial Dead', 'r,-.'),
            (3, 'All Dead', 'k--'),
            (4, 'No cells', 'b*'),
            (5, 'Cells vanish', 'mo'),
        ]
        # (state row, accumulated y-values, line artist) per plotted state.
        traces = [(row, [], ax[1].plot([], [], style, label=label)[0])
                  for row, label, style in traceSpecs]
        tdata = []

        def init():
            imgplot.set_array(io.imread(sortedFilenames[0]))
            # Clear accumulated data so a re-initialised animation (e.g. ani.save())
            # does not append every frame a second time.
            tdata.clear()
            for _, series, line in traces:
                series.clear()
                line.set_data([], [])
            ax[1].set_xlim([0, len(sortedFilenames)])
            ax[1].set_ylim([-0.5, 2])
            return imgplot,

        def update(frame):
            imgplot.set_array(io.imread(sortedFilenames[frame]))
            tdata.append(frame)
            for row, series, line in traces:
                series.append(int(states[row][frame]))
                line.set_data(tdata, series)
            return [imgplot] + [line for _, _, line in traces]

        ani = FuncAnimation(fig, update, frames=range(len(sortedFilenames)),
                            init_func=init, blit=False, repeat=False, interval=1000)
        ax[1].legend(loc='upper right')
        plt.show()

    return ani, predicted_states

In [30]:
# Run the saved model over training directory 0 and animate its predictions.
# The global phaseDir is (re)bound here as well, because other cells read it.
phaseDir = '/home/pk/Documents/trainingData/deadalive1/0/'
modelPath = '/home/pk/Documents/models/newdeadalive.pth'
ani, predicted_states = runNet(dataDir=phaseDir, modelPath=modelPath, plot=True)

[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[True, False, False, False, False, False]
[False, True, False, False, False,

In [29]:
# Transpose to one row per state and one column per frame. This is the same
# states[k][frame] layout the plotting code indexes.
np.array(predicted_states).T

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False],
       [False, False, False, False, False, False, False, False, False,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True],
       [False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False],
       [False, False, False, False, F

## Create training data

In [12]:
from narsil.deadalive.utils import createTrainingData

In [27]:
# Source directory for a new training example.
# NOTE: this rebinds the global `phaseDir` that earlier cells also use.
phaseDir = '/home/pk/Documents/realtimeData/analysisData/3/oneMMChannelPhase/7/'

In [28]:
# Label the frames in phaseDir to produce training data.
# NOTE(review): the same directory is passed as both segmentedDir and phaseDir;
# confirm createTrainingData is meant to read both inputs from one place.
createTrainingData(segmentedDir=phaseDir, phaseDir=phaseDir)

<narsil.deadalive.utils.createTrainingData at 0x7f949b7fa2e0>

[[ True  True  True  True  True  True  True  True  True  True False False
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False]
 [False False False False False False False False False False  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True]
 [False False False False False False False False False False False False
  False  True  True  True  True  True  True  True  True  True  True False
  False False False False False False False False False False False False
  False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False False False False  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True]
 [Fa

### Plotting dead-alive probabilities

In [29]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


def _sineDemo():
    """Animated sine-wave placeholder (the stock matplotlib FuncAnimation example).

    TODO: replace with a real plot of the network's dead/alive probabilities.
    Everything is local to this function so the demo no longer overwrites the
    notebook-level `fig`, `ax`, `init`, `update` and `ani` from the label-animation
    cell. Overwriting `ani` let that animation be garbage-collected mid-run.
    """
    fig, ax = plt.subplots()
    xdata, ydata = [], []
    ln, = ax.plot([], [], 'ro')

    def init():
        ax.set_xlim(0, 2*np.pi)
        ax.set_ylim(-1, 1)
        return ln,

    def update(frame):
        xdata.append(frame)
        ydata.append(np.sin(frame))
        ln.set_data(xdata, ydata)
        return ln,

    return FuncAnimation(fig, update, frames=np.linspace(0, 2*np.pi, 128),
                         init_func=init, blit=True)


# Keep a reference: an unreferenced FuncAnimation is garbage-collected and stops.
sineAni = _sineDemo()
plt.show()


In [8]:
# Device for the second training run (CUDA device index 1).
modelParameters = dict(device="cuda:1")
# Same schedule as the first run but half its learning rate (1.0e-5 -> 0.5e-5).
optimizationParameters = dict(
    learning_rate=0.5e-5,
    nEpochs=100,
)

In [9]:
# Build a trainer for the second run.
# NOTE(review): fileformat='.tif' here, while the earlier cells use '.tiff' for
# phaseDirectoriesList. If the dataset matches files by extension, '.tif' will not
# pick up '.tiff' files. The very different losses below suggest
# phaseDirectoriesList may have been redefined in a cell not shown — confirm.
net = trainDeadAliveNet(phaseDirectoriesList, modelParameters, optimizationParameters, fileformat='.tif')

In [13]:
# Run the second training loop; the captured output below stops at epoch 16.
net.train()

Epoch 0 -- started
Epoch average loss: 0.15325021594762803
Epoch 1 -- started
Epoch average loss: 0.1096721351146698
Epoch 2 -- started
Epoch average loss: 0.11148410886526108
Epoch 3 -- started
Epoch average loss: 0.10040972828865051
Epoch 4 -- started
Epoch average loss: 0.09695472568273544
Epoch 5 -- started
Epoch average loss: 0.09909523427486419
Epoch 6 -- started
Epoch average loss: 0.09453219920396805
Epoch 7 -- started
Epoch average loss: 0.09195844233036041
Epoch 8 -- started
Epoch average loss: 0.09463360160589218
Epoch 9 -- started
Epoch average loss: 0.09367366284132003
Epoch 10 -- started
Epoch average loss: 0.0925093412399292
Epoch 11 -- started
Epoch average loss: 0.09434753209352494
Epoch 12 -- started
Epoch average loss: 0.09306800067424774
Epoch 13 -- started
Epoch average loss: 0.09562535881996155
Epoch 14 -- started
Epoch average loss: 0.10037241280078887
Epoch 15 -- started
Epoch average loss: 0.10584187507629395
Epoch 16 -- started
Epoch average loss: 0.0962564885