In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt5

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, CheckButtons, RadioButtons
from matplotlib.animation import FuncAnimation
from skimage import io
import pickle
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import random

In [3]:
states_to_classes = {
    'all_growing': 0,
    'partly_growing': 1,
    'stopped_growing': 2,
    'stopped_fading': 3,
    'stopped_vanishing': 4,
    'channel_empty': 5
}
class_to_states = {
    0: 'all_growing',
    1: 'partly_growing',
    2: 'stopped_growing',
    3: 'stopped_fading',
    4: 'stopped_vanishing',
    5: 'channel_empty'
}

### Data creator using matplotlib widgets

Mark the state transitions


In [4]:
class StateTransition(object):
    
    def __init__(self, phaseDir, fileformat='.tiff',
                 labels_file='transitions.pickle', frame_rate=1):
        
        self.phaseDir = phaseDir
        self.fileformat = fileformat
        self.labelsFile = labels_file
        self.frameRate = frame_rate
        self.indices = [int(filename.split('.')[0].split('/')[-1]) for filename in 
                       glob.glob(self.phaseDir + "*" + self.fileformat)]
        self.indices.sort()
        
        self.fig, self.ax = plt.subplots(1, 1, num=str(self.phaseDir), figsize=(12, 10))
        plt.subplots_adjust(left=0.35, bottom=0.25)
        self.axcolor = 'lightgoldenrodyellow'
        
        self.currentFrame = 0
        self.imagesPlot = self.ax.imshow(self.__getitem__(self.currentFrame), cmap='gray')
        self.ax.set_title(str(self.currentFrame) + " _ " + str(self.currentFrame + self.frameRate))
        
          
        self.previousax =  plt.axes([0.5, 0.025, 0.1, 0.03])
        self.previousButton = Button(self.previousax, 'Previous', color=self.axcolor, hovercolor='0.975')
        self.previousButton.on_clicked(self.previousDatapoint)
        
        self.nextax =  plt.axes([0.65, 0.025, 0.1, 0.03])
        self.nextButton = Button(self.nextax, 'Next', color=self.axcolor, hovercolor='0.975')
        self.nextButton.on_clicked(self.nextDatapoint)
        
        self.saveax = plt.axes([0.8, 0.025, 0.1, 0.03])
        self.saveButton = Button(self.saveax, 'Save', color=self.axcolor, hovercolor='0.875')
        self.saveButton.on_clicked(self.save)
        
        
        self.rax = plt.axes([0.05, 0.5, 0.20, 0.25], facecolor=self.axcolor)
        self.radioToState = {
            '0 : All Growing': 0,
            '1 : Partly Growing': 1,
            '2 : Stopped Growing': 2,
            '3 : Stopped Fading' : 3,
            '4 : Stopped Vanishing': 4,
            '5 : Channel Empty' : 5
        }
        self.radio = RadioButtons(self.rax, ('0 : All Growing', '1 : Partly Growing',
                                            '2 : Stopped Growing', '3 : Stopped Fading', 
                                            '4 : Stopped Vanishing', '5 : Channel Empty'))
        #self.radio.on_clicked(self.changeState)
        
        
        # the keys will be frame1_frame2 number and value is the state
        self.states = {
            
        }
        
        
        plt.pause(0.01)

    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):

        if idx + self.frameRate < self.__len__() and idx >= 0:
            phaseFilename1 = self.phaseDir + str(self.indices[idx]) + self.fileformat
            phaseFilename2 = self.phaseDir + str(self.indices[idx+self.frameRate]) + self.fileformat
            img1 = io.imread(phaseFilename1)
            img2 = io.imread(phaseFilename2)
            
            full_img = np.concatenate((img1, img2), axis = 1)
            return full_img
        else:
            return None
    
    
    def nextDatapoint(self, buttonPress):
        key = str(self.currentFrame) + "_" + str(self.currentFrame + self.frameRate)
        #print(self.radio.value_selected)
        self.states[key] = self.radioToState[self.radio.value_selected]
        
        print(self.states)
        
        # get next Image if you are in range
        self.currentFrame += 1
        nextImage = self.__getitem__(self.currentFrame)
        #print(nextImage.shape)
        if nextImage is not None:
            self.imagesPlot.set_data(nextImage)
            self.ax.set_title(str(self.currentFrame) + " _ " + str(self.currentFrame + self.frameRate))
            self.fig.canvas.draw()
        else:
            self.currentFrame -= 1
            
    def previousDatapoint(self, buttonPress):
        key = str(self.currentFrame) + "_" + str(self.currentFrame + self.frameRate)
        #print(self.radio.value_selected)
        self.states[key] = self.radioToState[self.radio.value_selected]
        
        print(self.states)
        
        # get next Image if you are in range
        self.currentFrame -= 1
        previousImage = self.__getitem__(self.currentFrame)
        #print(nextImage.shape)
        if previousImage is not None:
            self.imagesPlot.set_data(previousImage)
            self.ax.set_title(str(self.currentFrame) + " _ " + str(self.currentFrame + self.frameRate))
            self.fig.canvas.draw()
        else:
            self.currentFrame += 1
    
    def save(self, save):
        print(self.states)
        saveFilename = self.phaseDir + self.labelsFile
        with open(saveFilename, 'wb') as f:
            pickle.dump(self.states, f, protocol=pickle.HIGHEST_PROTOCOL)
        
        print(f"File save: {saveFilename}")

### Generate training data one stack at a time

In [5]:
#phaseDir = '/home/pk/Documents/trainingData/deadalive1/40/'
#markStates = StateTransition(phaseDir, frame_rate=1)

#### Dataset bundling

In [6]:
class statesDataset(object):
    
    def __init__(self, phaseDirectoriesList, fileformat='.tiff', 
                 transforms=None, class_to_states = {
                        0: 'all_growing', 1: 'partly_growing',
                        2: 'stopped_growing',3: 'stopped_fading',
                        4: 'stopped_vanishing',5: 'channel_empty'}):
        self.phaseDirectoriesList = phaseDirectoriesList
        self.fileformat = fileformat
        self.transforms = transforms
        self.classToStates = class_to_states
        self.nClasses = len(class_to_states)
        
        self.data = {} # each class has it's value as key and the values is a list of 2 filename tuples
        
        self.all_data = []
        
        for directory in self.phaseDirectoriesList:
            statesFilename = directory + 'transitions.pickle'
            with open(statesFilename, 'rb') as file:
                states = pickle.load(file)
            for key, value in states.items():
                file1 = directory + key.split('_')[0] + self.fileformat
                file2 = directory + key.split('_')[1] + self.fileformat
                if value in self.data:
                    self.data[value].append((file1, file2))
                else:
                    self.data[value] = [(file1, file2)]
                
                self.all_data.append((file1, file2, value))
        
        classStats = self.classStatistics()
        self.weights = np.ones(shape=(self.nClasses,))
        for key in classStats:
            self.weights[key] /= classStats[key]
                    
    
    def __len__(self):
        #length = 0
        #for key in self.data:
        #    length += len(self.data[key])
        #return length
        return len(self.all_data)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        datapoint = self.all_data[idx]
        
        img1 = io.imread(datapoint[0])
        img2 = io.imread(datapoint[1])
        label = datapoint[2]
        
        sample = {
            'frame1': img1,
            'frame2': img2,
            'label': label
        }
        
        if self.transforms is not None:
            sample = self.transforms(sample)
        
        return sample
    
    def classStatistics(self):
        statistics = {}
        for key in self.data:
            statistics[key] = len(self.data[key])
            
        return statistics
    
    def calculateWeights(self):
        pass
    
    def plotDatapoint(self, idx):
        
        datapoint = self.__getitem__(idx)
        
        pltImage = np.concatenate((datapoint['frame1'], datapoint['frame2']), axis = 1)
        plt.figure()
        plt.imshow(pltImage, cmap='gray')
        plt.title(f"Class: {datapoint['label']} -- {self.classToStates[datapoint['label']]}")
        plt.show()

In [7]:
class tensorizeSample(object):
    
    def __call__(self, sample):
        
        sample['frame1'] = torch.from_numpy(sample['frame1']).unsqueeze(0)
        sample['frame2'] = torch.from_numpy(sample['frame2']).unsqueeze(0)
        
        return sample

phaseDirectoriesList = ['/home/pk/Documents/trainingData/deadalive1/' + str(i) + '/' for i in range(0, 41)]

dataset = statesDataset(phaseDirectoriesList, transforms = tensorizeSample())

len(dataset)

dataset.clasStatistics()

dataset[1000]

dataset[0]['frame1'].shape

### Net 

In [8]:
class conv_relu_norm(nn.Module):
    def __init__(self, input_channels, output_channels, kernel=3, stride=1, padding=1):
        super(conv_relu_norm, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channels, output_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.conv(x)
        return x

class stateNet(nn.Module):
    
    def __init__(self, nClasses=6):
        super(stateNet, self).__init__()
        self.conv1 = conv_relu_norm(1, 64)
        self.conv2 = conv_relu_norm(64, 128)
        self.conv3 = conv_relu_norm(128, 256)
        
        self.conv4 = conv_relu_norm(256, 256)
        
        self.fc6 = nn.Linear(76800, 1024)
        self.imageFeatures = nn.Sequential(
                nn.Linear(1024, 1024),
                nn.ReLU(inplace=True)
        )
        
        self.outputLinear = nn.Sequential(
                nn.Linear(2048, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, nClasses)
        )
        
        
    def forward_one(self, img):
        # get to conv feature head for each image
        
        #print(f"Image shape: {img.shape}")
        batch_size = img.shape[0]
        
        conv1 = self.conv1(img)
        #print(f"Conv1 shape: {conv1.shape}")
        pool1 = F.relu(F.max_pool2d(conv1, (2, 2)))
        #print(f"Pool1 shape: {pool1.shape}")
        
        conv2 = self.conv2(pool1)
        #print(f"Conv2 shape: {conv2.shape}")
        pool2 = F.relu(F.max_pool2d(conv2, (2, 2)))
        #print(f"Pool2 shape: {pool2.shape}")
        
        conv3 = self.conv3(pool2)
        #print(f"Conv3 shape: {conv3.shape}")
        pool3 = F.relu(F.max_pool2d(conv3, (2, 3)))
        #print(f"Pool3 shape: {pool3.shape}")
        
        conv4 = self.conv4(pool3)
        #print(f"Conv4 shape: {conv4.shape}")
        conv4_reshaped = conv4.view(batch_size, -1)
        #print(f"Conv4 reshaped: {conv4_reshaped.shape}")
        
        fc6 = F.relu(self.fc6(conv4_reshaped))
        #print(f"FC6 shape: {fc6.shape}")
        
        out = self.imageFeatures(fc6)
        #print(f"Output shape: {out.shape}")
        
        return out
    
    def forward(self, img1, img2):
        # pass the sample image through the net seperately and then
        
        imgFeatures1 = self.forward_one(img1)
        imgFeatures2 = self.forward_one(img2)
        
        # stack the image features
        stackedFeatures = torch.cat((imgFeatures1, imgFeatures2), 1)
        #print(stackedFeatures.shape)
        
        
        netOutput = self.outputLinear(stackedFeatures)
        #print(netOutput.shape)
        
        return netOutput

In [9]:
phaseDirectoriesList = ['/home/pk/Documents/trainingData/deadalive1/' + str(i) + '/' for i in range(0, 41)]
dataset = statesDataset(phaseDirectoriesList, transforms = tensorizeSample())
statedataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=6)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [10]:
dataset.classStatistics()

{0: 355, 2: 162, 3: 919, 4: 53, 5: 71, 1: 38}

In [11]:
dataset.weights

array([0.0028169 , 0.02631579, 0.00617284, 0.00108814, 0.01886792,
       0.01408451])

databatch = next(iter(stateDataLoader))

In [12]:
net = stateNet(nClasses=6)
net.to(device)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(dataset.weights,dtype=torch.float))
criterion.to(device)
optimizer = optim.Adam(net.parameters(), lr = 0.00005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma = 0.5)

In [13]:
nEpochs = 10
for epoch in range(nEpochs):
    epoch_loss = []
    for i_batch, data_batch in enumerate(statedataloader, 0):
        frame1_batch, frame2_batch, labels = data_batch['frame1'].to(device), data_batch['frame2'].to(device), data_batch['label'].to(device)

        optimizer.zero_grad()
        output_scores = net(frame1_batch, frame2_batch)
        loss = criterion(output_scores, labels)

        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    scheduler.step()
    print(f"Epoch {epoch+1}: ..  ..Avg.Loss: {np.mean(epoch_loss)}")

Epoch 1: ..  ..Avg.Loss: 1.5330058133602142
Epoch 2: ..  ..Avg.Loss: 1.094512906074524
Epoch 3: ..  ..Avg.Loss: 0.9336384582519531
Epoch 4: ..  ..Avg.Loss: 0.7711231178045272
Epoch 5: ..  ..Avg.Loss: 0.6599619603157043
Epoch 6: ..  ..Avg.Loss: 0.6090902316570282
Epoch 7: ..  ..Avg.Loss: 0.5522158116102218
Epoch 8: ..  ..Avg.Loss: 0.3693456320464611
Epoch 9: ..  ..Avg.Loss: 0.38225493013858797
Epoch 10: ..  ..Avg.Loss: 0.3333915063738823


In [17]:
output_scores

tensor([[-0.9125, -3.5484,  0.4680,  6.9903,  4.7184, -5.3253],
        [ 3.4226, -0.5304,  2.9040,  2.2176,  0.2189, -5.4145],
        [-0.0680, -4.9235, -0.8241,  5.5935,  3.8342, -2.7889],
        [ 5.5589,  1.6421,  3.7758, -0.3839, -0.9987, -5.2551],
        [ 0.7648, -2.8116,  3.7682,  7.5868,  1.8685, -7.4572],
        [ 7.8441,  0.8090,  3.5499,  0.3334, -3.4195, -5.1678],
        [-1.5301, -6.5270,  0.2636,  7.1894,  5.8312, -5.0786],
        [ 3.5874, -2.8173,  3.2082,  2.5729,  0.8387, -5.0557],
        [ 7.7557,  1.3497,  2.2764,  0.1000, -2.5069, -4.8078],
        [-1.7252, -5.2836,  1.1182,  7.9258,  4.2648, -5.5359],
        [ 4.0144,  0.6745,  4.1240,  0.7771,  0.1702, -5.4624],
        [-1.8071, -6.2825, -1.1059,  7.4153,  5.0598, -3.5935],
        [-0.7318, -5.7279, -0.6559,  7.1132,  4.2798, -3.3341],
        [ 8.2844,  3.0437,  1.0200,  1.2183, -2.9872, -5.0463],
        [ 8.0044,  0.6238,  0.6490,  1.8685, -2.4512, -5.0594],
        [-0.6129, -5.2056,  1.7339, 10.2

In [18]:
squish = nn.Softmax(dim = 1)

In [19]:
squish(output_scores.cpu())

tensor([[3.3455e-04, 2.3974e-05, 1.3305e-03, 9.0500e-01, 9.3311e-02, 4.0554e-06],
        [5.1151e-01, 9.8193e-03, 3.0454e-01, 1.5329e-01, 2.0774e-02, 7.4289e-05],
        [2.9529e-03, 2.2991e-05, 1.3864e-03, 8.4924e-01, 1.4620e-01, 1.9435e-04],
        [8.3887e-01, 1.6698e-02, 1.4103e-01, 2.2017e-03, 1.1906e-03, 1.6874e-05],
        [1.0615e-03, 2.9698e-05, 2.1393e-02, 9.7431e-01, 3.2008e-03, 2.8521e-07],
        [9.8513e-01, 8.6733e-04, 1.3445e-02, 5.3907e-04, 1.2641e-05, 2.2003e-06],
        [1.2984e-04, 8.7752e-07, 7.8055e-04, 7.9474e-01, 2.0435e-01, 3.7351e-06],
        [4.7329e-01, 7.8271e-04, 3.2394e-01, 1.7161e-01, 3.0295e-02, 8.3463e-05],
        [9.9370e-01, 1.6411e-03, 4.1458e-03, 4.7036e-04, 3.4693e-05, 3.4751e-06],
        [6.2672e-05, 1.7852e-06, 1.0764e-03, 9.7382e-01, 2.5034e-02, 1.3871e-06],
        [4.5208e-01, 1.6022e-02, 5.0444e-01, 1.7752e-02, 9.6761e-03, 3.4632e-05],
        [9.0221e-05, 1.0272e-06, 1.8189e-04, 9.1311e-01, 8.6606e-02, 1.5117e-05],
        [3.6962e

In [14]:
savedModel = {
    'model_state_dict': net.state_dict(),
    'labels': class_to_states
}
torch.save(savedModel, '/home/pk/Documents/models/multistate.pth')

In [22]:
from sklearn.metrics import confusion_matrix

In [27]:
actual_labels = []
predicted_labels = []
squish = nn.Softmax(dim = 1)
with torch.no_grad():
    for i_batch, data_batch in enumerate(statedataloader, 0):
        frame1_batch, frame2_batch, labels = data_batch['frame1'].to(device), data_batch['frame2'].to(device), data_batch['label'].to(device)
        
        output_scores = net(frame1_batch, frame2_batch)
        output_scoresCPU = output_scores.cpu()
        output_labels = squish(output_scoresCPU).numpy().argmax(axis = 1)
        labels_cpu = labels.cpu().numpy()
        
        actual_labels.extend(list(labels_cpu))
        predicted_labels.extend(list(output_labels))


In [28]:
len(actual_labels)

1598

In [29]:
len(predicted_labels)

1598

In [31]:
cm = confusion_matrix(actual_labels, predicted_labels, normalize=None)

In [34]:
fig, ax = plt.subplots()
ax.matshow(cm, cmap=plt.cm.Blues)
ax.set_title('Classification errors')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
#ax.set_xticklabels([''] + labels)
#ax.set_yticklabels([''] + labels)
for (i, j), z in np.ndenumerate(cm):
    ax.text(j, i, z, ha='center', va='center',
            bbox=dict(boxstyle='round', facecolor='white', edgecolor='0.3'))

## Test data

In [25]:
class testData(object):
    
    def __init__(self, phaseDir, fileformat='.tiff'):
        
        self.phaseDir = phaseDir
        self.fileformat = fileformat
        self.indices = [int(filename.split('.')[0].split('/')[-1]) for filename in 
                       glob.glob(self.phaseDir + "*" + self.fileformat)]
        self.indices.sort()
                
    def __len__(self):
        return len(self.indices) - 1
    
    
    def __getitem__(self, idx):
        phaseFilename1 = self.phaseDir + str(self.indices[idx]) + self.fileformat
        phaseFilename2 = self.phaseDir + str(self.indices[idx+1]) + self.fileformat
        img1 = io.imread(phaseFilename1)
        img2 = io.imread(phaseFilename2)
        print(phaseFilename1)
        print(phaseFilename2)
        
        return {
            'frame1': torch.from_numpy(img1).unsqueeze(0).unsqueeze(0),
            'frame2': torch.from_numpy(img2).unsqueeze(0).unsqueeze(0)
        }

In [26]:
phaseTestDir = '/home/pk/Documents/trainingData/deadalive1/81/'
testdata = testData(phaseTestDir)

In [33]:
net = stateNet()
net.eval()

stateNet(
  (conv1): conv_relu_norm(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU(inplace=True)
    )
  )
  (conv2): conv_relu_norm(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU(inplace=True)
    )
  )
  (conv3): conv_relu_norm(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
   

In [20]:
out = net(testdata[0]['frame1'], testdata[0]['frame2'])

NameError: name 'testdata' is not defined

In [19]:
out

NameError: name 'out' is not defined

In [36]:
squish = nn.Softmax(dim = 1)
squish(out)

tensor([[0.1753, 0.1596, 0.1644, 0.1739, 0.1548, 0.1719]],
       grad_fn=<SoftmaxBackward>)

In [52]:
class_to_states = {
    0: 'all_growing',
    1: 'partly_growing',
    2: 'stopped_growing',
    3: 'stopped_fading',
    4: 'stopped_vanishing',
    5: 'channel_empty'
}

In [76]:
for key in class_to_states:
    print(class_to_states[key])

all_growing
partly_growing
stopped_growing
stopped_fading
stopped_vanishing
channel_empty


In [57]:
0 in class_to_states

True

In [54]:
len(class_to_states)

6

In [55]:
'0_1'.split("_")

['0', '1']

### Train loop

### Test plots