In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt5

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, CheckButtons, RadioButtons
from matplotlib.animation import FuncAnimation
from skimage import io
import pickle
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import random

In [3]:
states_to_classes = {
    'all_growing': 0,
    'partly_growing': 1,
    'stopped_growing': 2,
    'stopped_fading': 3,
    'stopped_vanishing': 4,
    'channel_empty': 5
}
class_to_states = {
    0: 'all_growing',
    1: 'partly_growing',
    2: 'stopped_growing',
    3: 'stopped_fading',
    4: 'stopped_vanishing',
    5: 'channel_empty'
}

### Data creator using matplotlib widgets

Mark the state transitions


In [4]:
class StateTransition(object):
    
    def __init__(self, phaseDir, fileformat='.tiff',
                 labels_file='transitions.pickle', frame_rate=1):
        
        self.phaseDir = phaseDir
        self.fileformat = fileformat
        self.labelsFile = labels_file
        self.frameRate = frame_rate
        self.indices = [int(filename.split('.')[0].split('/')[-1]) for filename in 
                       glob.glob(self.phaseDir + "*" + self.fileformat)]
        self.indices.sort()
        
        self.fig, self.ax = plt.subplots(1, 1, num=str(self.phaseDir), figsize=(12, 10))
        plt.subplots_adjust(left=0.35, bottom=0.25)
        self.axcolor = 'lightgoldenrodyellow'
        
        self.currentFrame = 0
        self.imagesPlot = self.ax.imshow(self.__getitem__(self.currentFrame), cmap='gray')
        self.ax.set_title(str(self.currentFrame) + " _ " + str(self.currentFrame + self.frameRate))
        
          
        self.previousax =  plt.axes([0.5, 0.025, 0.1, 0.03])
        self.previousButton = Button(self.previousax, 'Previous', color=self.axcolor, hovercolor='0.975')
        self.previousButton.on_clicked(self.previousDatapoint)
        
        self.nextax =  plt.axes([0.65, 0.025, 0.1, 0.03])
        self.nextButton = Button(self.nextax, 'Next', color=self.axcolor, hovercolor='0.975')
        self.nextButton.on_clicked(self.nextDatapoint)
        
        self.saveax = plt.axes([0.8, 0.025, 0.1, 0.03])
        self.saveButton = Button(self.saveax, 'Save', color=self.axcolor, hovercolor='0.875')
        self.saveButton.on_clicked(self.save)
        
        
        self.rax = plt.axes([0.05, 0.5, 0.20, 0.25], facecolor=self.axcolor)
        self.radioToState = {
            '0 : All Growing': 0,
            '1 : Partly Growing': 1,
            '2 : Stopped Growing': 2,
            '3 : Stopped Fading' : 3,
            '4 : Stopped Vanishing': 4,
            '5 : Channel Empty' : 5
        }
        self.radio = RadioButtons(self.rax, ('0 : All Growing', '1 : Partly Growing',
                                            '2 : Stopped Growing', '3 : Stopped Fading', 
                                            '4 : Stopped Vanishing', '5 : Channel Empty'))
        #self.radio.on_clicked(self.changeState)
        
        
        # the keys will be frame1_frame2 number and value is the state
        self.states = {
            
        }
        
        
        plt.pause(0.01)

    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):

        if idx + self.frameRate < self.__len__() and idx >= 0:
            phaseFilename1 = self.phaseDir + str(self.indices[idx]) + self.fileformat
            phaseFilename2 = self.phaseDir + str(self.indices[idx+self.frameRate]) + self.fileformat
            img1 = io.imread(phaseFilename1)
            img2 = io.imread(phaseFilename2)
            
            full_img = np.concatenate((img1, img2), axis = 1)
            return full_img
        else:
            return None
    
    
    def nextDatapoint(self, buttonPress):
        key = str(self.currentFrame) + "_" + str(self.currentFrame + self.frameRate)
        #print(self.radio.value_selected)
        self.states[key] = self.radioToState[self.radio.value_selected]
        
        print(self.states)
        
        # get next Image if you are in range
        self.currentFrame += 1
        nextImage = self.__getitem__(self.currentFrame)
        #print(nextImage.shape)
        if nextImage is not None:
            self.imagesPlot.set_data(nextImage)
            self.ax.set_title(str(self.currentFrame) + " _ " + str(self.currentFrame + self.frameRate))
            self.fig.canvas.draw()
        else:
            self.currentFrame -= 1
            
    def previousDatapoint(self, buttonPress):
        key = str(self.currentFrame) + "_" + str(self.currentFrame + self.frameRate)
        #print(self.radio.value_selected)
        self.states[key] = self.radioToState[self.radio.value_selected]
        
        print(self.states)
        
        # get next Image if you are in range
        self.currentFrame -= 1
        previousImage = self.__getitem__(self.currentFrame)
        #print(nextImage.shape)
        if previousImage is not None:
            self.imagesPlot.set_data(previousImage)
            self.ax.set_title(str(self.currentFrame) + " _ " + str(self.currentFrame + self.frameRate))
            self.fig.canvas.draw()
        else:
            self.currentFrame += 1
    
    def save(self, save):
        print(self.states)
        saveFilename = self.phaseDir + self.labelsFile
        with open(saveFilename, 'wb') as f:
            pickle.dump(self.states, f, protocol=pickle.HIGHEST_PROTOCOL)
        
        print(f"File save: {saveFilename}")

### Generate training data one stack at a time

In [5]:
#phaseDir = '/home/pk/Documents/trainingData/deadalive1/40/'
#markStates = StateTransition(phaseDir, frame_rate=1)

#### Dataset bundling

In [6]:
class statesDataset(object):
    
    def __init__(self, phaseDirectoriesList, fileformat='.tiff', 
                 transforms=None, class_to_states = {
                        0: 'all_growing', 1: 'partly_growing',
                        2: 'stopped_growing',3: 'stopped_fading',
                        4: 'stopped_vanishing',5: 'channel_empty'}):
        self.phaseDirectoriesList = phaseDirectoriesList
        self.fileformat = fileformat
        self.transforms = transforms
        self.classToStates = class_to_states
        self.nClasses = len(class_to_states)
        
        self.data = {} # each class has it's value as key and the values is a list of 2 filename tuples
        
        self.all_data = []
        
        for directory in self.phaseDirectoriesList:
            statesFilename = directory + 'transitions.pickle'
            with open(statesFilename, 'rb') as file:
                states = pickle.load(file)
            for key, value in states.items():
                file1 = directory + key.split('_')[0] + self.fileformat
                file2 = directory + key.split('_')[1] + self.fileformat
                if value in self.data:
                    self.data[value].append((file1, file2))
                else:
                    self.data[value] = [(file1, file2)]
                
                self.all_data.append((file1, file2, value))
        
        classStats = self.classStatistics()
        self.weights = np.ones(shape=(self.nClasses,))
        for key in classStats:
            self.weights[key] /= classStats[key]
                    
    
    def __len__(self):
        #length = 0
        #for key in self.data:
        #    length += len(self.data[key])
        #return length
        return len(self.all_data)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        datapoint = self.all_data[idx]
        
        img1 = io.imread(datapoint[0])
        img2 = io.imread(datapoint[1])
        label = datapoint[2]
        
        sample = {
            'frame1': img1,
            'frame2': img2,
            'label': label
        }
        
        if self.transforms is not None:
            sample = self.transforms(sample)
        
        return sample
    
    def classStatistics(self):
        statistics = {}
        for key in self.data:
            statistics[key] = len(self.data[key])
            
        return statistics
    
    def calculateWeights(self):
        pass
    
    def plotDatapoint(self, idx):
        
        datapoint = self.__getitem__(idx)
        
        pltImage = np.concatenate((datapoint['frame1'], datapoint['frame2']), axis = 1)
        plt.figure()
        plt.imshow(pltImage, cmap='gray')
        plt.title(f"Class: {datapoint['label']} -- {self.classToStates[datapoint['label']]}")
        plt.show()

In [7]:
class randomTranslations(object):

    def __init__(self, translate=(0.0, 0.10)):
        self.translate = translate
    
    def __call__(self, sample):
        #img_shape = sample['frame1'].shape[-2:]
        #img_size = [img_shape[1], img_shape[0]]
        affine_params = transforms.RandomAffine.get_params(degrees= [0,0], 
                                                           translate=self.translate,
                                                           scale_ranges=None,
                                                           shears=None,
                                                           img_size=[36, 800])
        #print(affine_params)
        frame1Tensor = torch.from_numpy(sample['frame1']).unsqueeze(0)
        frame2Tensor = torch.from_numpy(sample['frame2']).unsqueeze(0)
        
        
        frame1Translated = transforms.functional.affine(frame1Tensor, *affine_params)
        frame2Translated = transforms.functional.affine(frame2Tensor, *affine_params)
        
        sample['frame1'] = frame1Translated
        sample['frame2'] = frame2Translated
        
        return sample

In [8]:
class tensorizeSample(object):
    
    def __call__(self, sample):
        
        sample['frame1'] = torch.from_numpy(sample['frame1']).unsqueeze(0)
        sample['frame2'] = torch.from_numpy(sample['frame2']).unsqueeze(0)
        
        return sample

phaseDirectoriesList = ['/home/pk/Documents/trainingData/deadalive1/' + str(i) + '/' for i in range(0, 41)]

dataset = statesDataset(phaseDirectoriesList, transforms = tensorizeSample())

len(dataset)

dataset.clasStatistics()

dataset[1000]

dataset[0]['frame1'].shape

### Net 

In [9]:
class conv_relu_norm(nn.Module):
    def __init__(self, input_channels, output_channels, kernel=3, stride=1, padding=1):
        super(conv_relu_norm, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channels, output_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.conv(x)
        return x

class stateNet(nn.Module):
    
    def __init__(self, nClasses=6):
        super(stateNet, self).__init__()
        self.conv1 = conv_relu_norm(1, 64)
        self.conv2 = conv_relu_norm(64, 128)
        self.conv3 = conv_relu_norm(128, 256)
        
        self.conv4 = conv_relu_norm(256, 256)
        
        self.fc6 = nn.Linear(76800, 1024)
        self.imageFeatures = nn.Sequential(
                nn.Linear(1024, 1024),
                nn.ReLU(inplace=True)
        )
        
        self.outputLinear = nn.Sequential(
                nn.Linear(2048, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, nClasses)
        )
        
        
    def forward_one(self, img):
        # get to conv feature head for each image
        
        #print(f"Image shape: {img.shape}")
        batch_size = img.shape[0]
        
        conv1 = self.conv1(img)
        #print(f"Conv1 shape: {conv1.shape}")
        pool1 = F.relu(F.max_pool2d(conv1, (2, 2)))
        #print(f"Pool1 shape: {pool1.shape}")
        
        conv2 = self.conv2(pool1)
        #print(f"Conv2 shape: {conv2.shape}")
        pool2 = F.relu(F.max_pool2d(conv2, (2, 2)))
        #print(f"Pool2 shape: {pool2.shape}")
        
        conv3 = self.conv3(pool2)
        #print(f"Conv3 shape: {conv3.shape}")
        pool3 = F.relu(F.max_pool2d(conv3, (2, 3)))
        #print(f"Pool3 shape: {pool3.shape}")
        
        conv4 = self.conv4(pool3)
        #print(f"Conv4 shape: {conv4.shape}")
        conv4_reshaped = conv4.view(batch_size, -1)
        #print(f"Conv4 reshaped: {conv4_reshaped.shape}")
        
        fc6 = F.relu(self.fc6(conv4_reshaped))
        #print(f"FC6 shape: {fc6.shape}")
        
        out = self.imageFeatures(fc6)
        #print(f"Output shape: {out.shape}")
        
        return out
    
    def forward(self, img1, img2):
        # pass the sample image through the net seperately and then
        
        imgFeatures1 = self.forward_one(img1)
        imgFeatures2 = self.forward_one(img2)
        
        # stack the image features
        stackedFeatures = torch.cat((imgFeatures1, imgFeatures2), 1)
        #print(stackedFeatures.shape)
        
        
        netOutput = self.outputLinear(stackedFeatures)
        #print(netOutput.shape)
        
        return netOutput

In [10]:
phaseDirectoriesList = ['/home/pk/Documents/trainingData/deadalive1/' + str(i) + '/' for i in range(0, 41)]
dataset = statesDataset(phaseDirectoriesList, transforms = randomTranslations())
statedataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=6)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [11]:
dataset.classStatistics()

{0: 355, 2: 162, 3: 919, 4: 53, 5: 71, 1: 38}

In [12]:
dataset.weights

array([0.0028169 , 0.02631579, 0.00617284, 0.00108814, 0.01886792,
       0.01408451])

databatch = next(iter(stateDataLoader))

In [13]:
net = stateNet(nClasses=6)
net.to(device)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(dataset.weights,dtype=torch.float))
criterion.to(device)
optimizer = optim.Adam(net.parameters(), lr = 0.00005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma = 0.5)

In [14]:
nEpochs = 50
for epoch in range(nEpochs):
    epoch_loss = []
    for i_batch, data_batch in enumerate(statedataloader, 0):
        frame1_batch, frame2_batch, labels = data_batch['frame1'].to(device), data_batch['frame2'].to(device), data_batch['label'].to(device)

        optimizer.zero_grad()
        output_scores = net(frame1_batch, frame2_batch)
        loss = criterion(output_scores, labels)

        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    scheduler.step()
    print(f"Epoch {epoch+1}: ..  ..Avg.Loss: {np.mean(epoch_loss)}")

Epoch 1: ..  ..Avg.Loss: 1.6113714838027955
Epoch 2: ..  ..Avg.Loss: 1.3203878426551818
Epoch 3: ..  ..Avg.Loss: 1.2294101679325105
Epoch 4: ..  ..Avg.Loss: 1.0341612803936004
Epoch 5: ..  ..Avg.Loss: 0.8394487237930298
Epoch 6: ..  ..Avg.Loss: 0.7646921801567078
Epoch 7: ..  ..Avg.Loss: 0.7429525434970856
Epoch 8: ..  ..Avg.Loss: 0.6822455018758774
Epoch 9: ..  ..Avg.Loss: 0.6473642086982727
Epoch 10: ..  ..Avg.Loss: 0.6441334372758866
Epoch 11: ..  ..Avg.Loss: 0.5443035015463829
Epoch 12: ..  ..Avg.Loss: 0.564657284617424
Epoch 13: ..  ..Avg.Loss: 0.5176343956589698
Epoch 14: ..  ..Avg.Loss: 0.5106999018788337
Epoch 15: ..  ..Avg.Loss: 0.4393028458952904
Epoch 16: ..  ..Avg.Loss: 0.40086173951625825
Epoch 17: ..  ..Avg.Loss: 0.375201066583395
Epoch 18: ..  ..Avg.Loss: 0.4025968986749649
Epoch 19: ..  ..Avg.Loss: 0.36361352652311324
Epoch 20: ..  ..Avg.Loss: 0.5315257807075977
Epoch 21: ..  ..Avg.Loss: 0.3563631580770016
Epoch 22: ..  ..Avg.Loss: 0.27450292721390723
Epoch 23: ..  ..Av

In [16]:
output_scores

tensor([[-1.9984e+01, -3.4243e+00,  7.8116e+00,  8.1063e+00, -3.3214e+00,
          2.5880e+01],
        [ 1.6100e+00, -6.4979e-01,  8.3550e+00,  1.1074e+01,  3.9870e+00,
         -1.1433e+01],
        [ 1.9542e+01, -2.3853e+00,  4.4177e+00,  4.7521e+00, -1.0958e+01,
         -1.4001e+01],
        [-9.7581e+00, -2.1334e+00, -4.4150e+00,  4.1795e+00,  4.0654e+00,
          1.8854e+01],
        [-3.6247e-01, -5.1557e+00,  4.5424e+00,  1.2419e+01,  6.0522e+00,
         -7.1764e+00],
        [-1.9390e+00, -4.5795e+00,  3.1794e+00,  1.3521e+01,  8.6621e+00,
         -6.0092e+00],
        [-8.9870e+00, -1.9495e+00, -4.7082e+00,  3.7203e+00,  2.3232e+00,
          1.8959e+01],
        [-2.9335e+00, -6.2057e+00,  3.0536e+00,  1.5225e+01,  1.0815e+01,
         -6.6388e+00],
        [ 1.3132e+00, -4.0695e+00,  5.4477e+00,  1.0198e+01,  3.2176e+00,
         -7.2507e+00],
        [ 5.4874e+00, -5.1591e+00,  9.3623e+00,  1.3886e+01,  8.2699e-01,
         -1.2124e+01],
        [ 7.9923e+00,  7.2648e

In [17]:
squish = nn.Softmax(dim = 1)

In [18]:
squish(output_scores.cpu())

tensor([[1.2067e-20, 1.8763e-13, 1.4223e-08, 1.9097e-08, 2.0796e-13, 1.0000e+00],
        [7.2729e-05, 7.5913e-06, 6.1808e-02, 9.3733e-01, 7.8352e-04, 1.5750e-10],
        [1.0000e+00, 3.0005e-10, 2.7021e-07, 3.7750e-07, 5.6768e-14, 2.7066e-15],
        [3.7499e-13, 7.6800e-10, 7.8427e-11, 4.2369e-07, 3.7798e-07, 1.0000e+00],
        [2.8059e-06, 2.3249e-08, 3.7865e-04, 9.9790e-01, 1.7136e-03, 3.0818e-09],
        [1.9160e-07, 1.3667e-08, 3.2010e-05, 9.9227e-01, 7.6981e-03, 3.2714e-09],
        [7.3003e-13, 8.3114e-10, 5.2671e-11, 2.4102e-07, 5.9607e-08, 1.0000e+00],
        [1.2841e-08, 4.8696e-10, 5.1137e-06, 9.8799e-01, 1.2007e-02, 3.1579e-10],
        [1.3709e-04, 6.2997e-07, 8.5621e-03, 9.9038e-01, 9.2057e-04, 2.6166e-08],
        [2.2278e-04, 5.2984e-09, 1.0733e-02, 9.8904e-01, 2.1081e-06, 5.0036e-12],
        [2.3622e-01, 1.1412e-01, 6.4676e-01, 2.8958e-03, 2.8114e-07, 3.2550e-11],
        [4.9249e-05, 1.2584e-06, 8.8542e-08, 9.9984e-01, 1.0782e-04, 1.8436e-08],
        [9.9997e

In [19]:
savedModel = {
    'model_state_dict': net.state_dict(),
    'labels': class_to_states
}
torch.save(savedModel, '/home/pk/Documents/models/multistate.pth')

In [20]:
from sklearn.metrics import confusion_matrix

In [21]:
actual_labels = []
predicted_labels = []
squish = nn.Softmax(dim = 1)
with torch.no_grad():
    for i_batch, data_batch in enumerate(statedataloader, 0):
        frame1_batch, frame2_batch, labels = data_batch['frame1'].to(device), data_batch['frame2'].to(device), data_batch['label'].to(device)
        
        output_scores = net(frame1_batch, frame2_batch)
        output_scoresCPU = output_scores.cpu()
        output_labels = squish(output_scoresCPU).numpy().argmax(axis = 1)
        labels_cpu = labels.cpu().numpy()
        
        actual_labels.extend(list(labels_cpu))
        predicted_labels.extend(list(output_labels))


In [22]:
len(actual_labels)

1598

In [23]:
len(predicted_labels)

1598

In [24]:
cm = confusion_matrix(actual_labels, predicted_labels, normalize=None)

In [25]:
fig, ax = plt.subplots()
ax.matshow(cm, cmap=plt.cm.Blues)
ax.set_title('Classification errors')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
#ax.set_xticklabels([''] + labels)
#ax.set_yticklabels([''] + labels)
for (i, j), z in np.ndenumerate(cm):
    ax.text(j, i, z, ha='center', va='center',
            bbox=dict(boxstyle='round', facecolor='white', edgecolor='0.3'))

## Test data

In [26]:
class testData(object):
    
    def __init__(self, phaseDir, fileformat='.tiff'):
        
        self.phaseDir = phaseDir
        self.fileformat = fileformat
        self.indices = [int(filename.split('.')[0].split('/')[-1]) for filename in 
                       glob.glob(self.phaseDir + "*" + self.fileformat)]
        self.indices.sort()
                
    def __len__(self):
        return len(self.indices) - 1
    
    
    def __getitem__(self, idx):
        phaseFilename1 = self.phaseDir + str(self.indices[idx]) + self.fileformat
        phaseFilename2 = self.phaseDir + str(self.indices[idx+1]) + self.fileformat
        img1 = io.imread(phaseFilename1)
        img2 = io.imread(phaseFilename2)
        #print(phaseFilename1)
        #print(phaseFilename2)
        
        return {
            'frame1': torch.from_numpy(img1).unsqueeze(0).unsqueeze(0),
            'frame2': torch.from_numpy(img2).unsqueeze(0).unsqueeze(0)
        }

In [27]:
phaseTestDir = '/home/pk/Documents/trainingData/deadalive1/89/'
testdata = testData(phaseTestDir)

In [28]:
net = stateNet()
savedNet = torch.load('/home/pk/Documents/models/multistate.pth')
net.load_state_dict(savedNet['model_state_dict'])
net.eval()

stateNet(
  (conv1): conv_relu_norm(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU(inplace=True)
    )
  )
  (conv2): conv_relu_norm(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU(inplace=True)
    )
  )
  (conv3): conv_relu_norm(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
   

In [29]:
squish = nn.Softmax(dim = 1)
np.set_printoptions(precision=2, suppress=True)
for i in range(39):
    out = net(testdata[i]['frame1'], testdata[i]['frame2'])
    print(f"{squish(out).detach().numpy().argmax(axis =1)} -- {squish(out).detach().numpy()}")

[0] -- [[0.99 0.   0.   0.   0.   0.  ]]
[0] -- [[1. 0. 0. 0. 0. 0.]]
[0] -- [[1. 0. 0. 0. 0. 0.]]
[0] -- [[1. 0. 0. 0. 0. 0.]]
[0] -- [[1. 0. 0. 0. 0. 0.]]
[0] -- [[0.97 0.03 0.   0.   0.   0.  ]]
[0] -- [[0.75 0.24 0.   0.   0.   0.  ]]
[1] -- [[0. 1. 0. 0. 0. 0.]]
[1] -- [[0.01 0.99 0.   0.   0.   0.  ]]
[0] -- [[0.93 0.06 0.   0.   0.   0.  ]]
[1] -- [[0.04 0.96 0.01 0.   0.   0.  ]]
[1] -- [[0.1 0.9 0.  0.  0.  0. ]]
[1] -- [[0.09 0.9  0.01 0.   0.   0.  ]]
[1] -- [[0.02 0.98 0.01 0.   0.   0.  ]]
[1] -- [[0.25 0.73 0.01 0.   0.   0.  ]]
[1] -- [[0.05 0.94 0.01 0.   0.   0.  ]]
[0] -- [[0.51 0.44 0.04 0.   0.   0.  ]]
[1] -- [[0.23 0.72 0.05 0.   0.   0.  ]]
[1] -- [[0.   0.99 0.01 0.   0.   0.  ]]
[1] -- [[0. 1. 0. 0. 0. 0.]]
[0] -- [[0.84 0.09 0.07 0.   0.   0.  ]]
[2] -- [[0.1  0.31 0.59 0.01 0.   0.  ]]
[0] -- [[0.88 0.07 0.05 0.   0.   0.  ]]
[1] -- [[0.   0.99 0.01 0.   0.   0.  ]]
[1] -- [[0.15 0.77 0.08 0.   0.   0.  ]]
[1] -- [[0.09 0.86 0.05 0.   0.   0.  ]]
[1] -- [[0.2

In [30]:
out

tensor([[ 3.8738,  7.3161,  5.8150, -1.6393, -5.4459, -8.2013]],
       grad_fn=<AddmmBackward>)

In [49]:
squish = nn.Softmax(dim = 1)
squish(out)

tensor([[3.7514e-02, 1.0395e-02, 9.5118e-01, 8.9237e-04, 1.5333e-05, 2.5778e-08]],
       grad_fn=<SoftmaxBackward>)

In [None]:
class testPlot(object):
    
    def __init__(self):
        pass
    
    def __getitem__(self):
        pass
    
    def getStateChange(self):
        pass
    
    def __len__(self):
        pass

In [None]:
def testNet(dataDir, modelPath, fileformat='.tiff', plot=False):
        
    # net intializee and run
    net = ()
    saved_net_parameters = torch.load(modelPath)
    net.load_state_dict(saved_net_parameters['model_state_dict'])
    net.eval()
    
    filenames = [int(filename.split('.')[0].split('/')[-1]) 
             for filename in glob.glob(dataDir + "*"+ fileformat)]
    filenames.sort()
    sortedFilenames = [dataDir + str(filenumber) + fileformat for filenumber in filenames]
    print(sortedFilenames)
    
    imgTransforms = transforms.Compose([
        transforms.ToTensor()
    ])
    
    predicted_states =[]
    with torch.no_grad():
        for i in range(len(sortedFilenames) - 1):
            img1 = io.imread(sortedFilenames[i])
            img2 = io.imread(sortedFilenames[i+1])
            
            img1 = imgTransforms(img1).unsqueeze(0)
            img2 = imgTransforms(img2).unsqueeze(0)
            output = net(img1, img2) > 0.5
            print(output.item())
            predicted_states.append(output.item())
            
    
    
    if plot == True:
        states = np.array(predicted_states).T
        fig, ax = plt.subplots(nrows=1, ncols=2)
        img = io.imread(sortedFilenames[0])

        tdata, movingdata, = [], []
        movingplot, = ax[1].plot([], [], 'g,:', label='Moving')
        #partialdeadplot, = ax[1].plot([], [], 'r,-.', label='Partial Dead')
        #alldeadplot, = ax[1].plot([], [], 'k--', label='All Dead')
        #nocellsplot, = ax[1].plot([], [], 'b*', label='No cells')
        #cellsvanishplot, = ax[1].plot([], [], 'mo', label='Cells vanish')


        imgplot = ax[0].imshow(img, cmap='gray')


        def init():
            img = io.imread(sortedFilenames[0])
            imgplot.set_array(img)
            ax[1].set_xlim([0, len(sortedFilenames)])
            ax[1].set_ylim([-0.5, 2])
            return imgplot, 

        def update(frame):

            img = io.imread(sortedFilenames[frame])
            imgplot.set_array(img)
            tdata.append(frame)
            movingdata.append(int(states[0][frame]))
            movingplot.set_data(tdata, movingdata) 

            partialdeaddata.append(int(states[2][frame]))
            partialdeadplot.set_data(tdata, partialdeaddata)

            alldeaddata.append(int(states[3][frame]))
            alldeadplot.set_data(tdata, alldeaddata)

            nocellsdata.append(int(states[4][frame]))
            nocellsplot.set_data(tdata, nocellsdata)

            cellsvanishdata.append(int(states[5][frame]))
            cellsvanishplot.set_data(tdata, cellsvanishdata)


            return [imgplot, movingplot, partialdeadplot,]

        ani = FuncAnimation(fig, update, frames=range(0, len(sortedFilenames)),
                            init_func=init, blit=False, repeat=False, interval=1000)
        plt.legend(loc='upper right')
        plt.show()
        
    return None, predicted_states

In [52]:
class_to_states = {
    0: 'all_growing',
    1: 'partly_growing',
    2: 'stopped_growing',
    3: 'stopped_fading',
    4: 'stopped_vanishing',
    5: 'channel_empty'
}

In [76]:
for key in class_to_states:
    print(class_to_states[key])

all_growing
partly_growing
stopped_growing
stopped_fading
stopped_vanishing
channel_empty


In [57]:
0 in class_to_states

True

In [54]:
len(class_to_states)

6

In [55]:
'0_1'.split("_")

['0', '1']

### Train loop

### Test plots