In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt5

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Button, CheckButtons, RadioButtons
import pickle
from skimage import io
import glob
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

### Dataset creation

In [3]:
class markStateTransition(object):
    
    def __init__(self, phaseDir=None, fileformat='.tiff', saveFilename='states.pickle'):
        self.phaseDir = phaseDir
        self.fileformat = fileformat
        self.saveFilename = saveFilename
        self.indices = [int(filename.split('.')[0].split('/')[-1]) for filename in 
                       glob.glob(self.phaseDir + "*" + self.fileformat)]
        self.indices.sort()
        self.states = {}  # it is a dictionary
        # with keys as directory name, file index 1, file index 2 and states
        self.states['dirName'] = phaseDir
        self.states['data'] = [] # use keys frame1, frame2 and keys for each of the state
        
        
        self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, num=str(self.phaseDir))
        plt.subplots_adjust(left=0.25, bottom=0.25)
        self.axcolor = 'lightgoldenrodyellow'
        
        #
        self.frame1 = 0
        self.pltFrame1 = self.ax1.imshow(self.__getitem__(self.frame1), cmap='gray')
        self.ax1.set_title(str(self.frame1) + self.fileformat)
        self.frame2 = 1
        self.pltFrame2 = self.ax2.imshow(self.__getitem__(self.frame2), cmap='gray')
        self.ax2.set_title(str(self.frame2) + self.fileformat)
        
        # buttons
        # next and save
        self.previousax = plt.axes([0.5, 0.025, 0.1, 0.03])
        self.nextax = plt.axes([0.65, 0.025, 0.1, 0.03])
        self.saveax = plt.axes([0.8, 0.025, 0.1, 0.03])
        self.nextButton = Button(self.nextax, 'Next', color=self.axcolor, hovercolor='0.975')
        self.saveButton = Button(self.saveax, 'Save', color=self.axcolor, hovercolor='0.975')
        self.previousButton = Button(self.previousax, 'Previous', color=self.axcolor, hovercolor='0.975')
        self.nextButton.on_clicked(self.nextImage)
        self.previousButton.on_clicked(self.previousImage)
        self.saveButton.on_clicked(self.save)
        
        # Radio buttons for each category transition
        
        self.rax = plt.axes([0.05, 0.7, 0.15, 0.15], facecolor=self.axcolor)
        self.radio = RadioButtons(self.rax, ('moved', 'not moved'))
        self.moved = True
        self.radio.on_clicked(self.changeMovedState)
        
        
        # more radiobuttons for other states
        
        
        plt.pause(0.01)
        
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        if idx < self.__len__():
            phaseFilename = self.phaseDir + str(self.indices[idx]) + self.fileformat
            #print(phaseFilename)
            return io.imread(phaseFilename)
        else:
            return None
    
    
    def changeMovedState(self, label):
        moved_dict = {'moved': True, 'not moved': False}
        self.moved = moved_dict[label]
    
    def nextImage(self, buttonPress):
        
        # store the current state in the data
        self.states['data'].append({
            'frame1': self.phaseDir + str(self.frame1) + self.fileformat,
            'frame2': self.phaseDir + str(self.frame2) + self.fileformat,
            'moved': self.moved
            })
        print(self.states['data'][-1])
        # update figure
        if self.frame2 == self.__len__() - 1:
            print('Last frame reached')
            return
        else:
            self.frame1 += 1
            self.frame2 += 1
            self.pltFrame1.set_data(self.__getitem__(self.frame1))
            self.ax1.set_title(str(self.frame1) + self.fileformat)
            self.pltFrame2.set_data(self.__getitem__(self.frame2))
            self.ax2.set_title(str(self.frame2) + self.fileformat)
            self.fig.canvas.draw()
            
    def previousImage(self, buttonPress):
        
        if self.frame1 == 0:
            print('First frame reached')
            return
        else:
            if len(self.states['data']) != 0:
                del self.states['data'][-1]
                
            self.frame1 -= 1
            self.frame2 -= 1
            self.pltFrame1.set_data(self.__getitem__(self.frame1))
            self.ax1.set_title(str(self.frame1) + self.fileformat)
            self.pltFrame2.set_data(self.__getitem__(self.frame2))
            self.ax2.set_title(str(self.frame2) + self.fileformat)
            self.fig.canvas.draw()
    
    def save(self, save):
        print(self.states)
        saveFilename = self.phaseDir + self.saveFilename
        with open(saveFilename, 'wb') as f:
            pickle.dump(self.states, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"File saved {saveFilename}")

In [4]:
phaseDir = '/home/pk/Documents/trainingData/deadalive1/4/'
genData = markStateTransition(phaseDir)

### Data loader and parameters

In [4]:
class cellsMovingDataset(object):
    
    def __init__(self, phaseDirectoriesList, fileformat='.tiff'):
        
        self.phaseDirectoriesList = phaseDirectoriesList
        self.data = []
        
        for directory in self.phaseDirectoriesList:
            # read the states file
            statesFilename = directory + 'states.pickle'
            with open(statesFilename, 'rb') as file:
                states = pickle.load(file)
            data = states['data']
            
            self.data.extend(data)
    
            
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        
        datapoint = self.data[idx]
        
        frame1 = io.imread(datapoint['frame1'])
        frame2 = io.imread(datapoint['frame2'])
        cellsmoving = datapoint['moved']
        
        return {
            'frame1': frame1,
            'frame2': frame2,
            'cellsmoving': cellsmoving,
            'frame1file': datapoint['frame1'],
            'frame2file': datapoint['frame2']
        }
    
    def statistics(self):
        stats = {'moved': 0, 'notmoved': 0}
        for datapoint in self.data:
            if datapoint['moved']:
                stats['moved'] += 1
            else:
                stats['notmoved'] += 1
        
        return stats
    
    def plotDatapoint(self, idx):
        
        datapoint = self.__getitem__(idx)       
        fig, ax = plt.subplots(1, 2)
        plt.title(datapoint['frame1file'].split('/')[0])
        ax[0].imshow(datapoint['frame1'], cmap='gray')
        ax[1].imshow(datapoint['frame2'], cmap='gray')
        
        ax[0].set_title(f"Moved : {datapoint['cellsmoving']}")
        ax[1].set_title(f"Moved: {datapoint['cellsmoving']}")
        plt.show(block=False)

In [5]:
phaseDirectoriesList = ['/home/pk/Documents/trainingData/deadalive1/0/',
                        '/home/pk/Documents/trainingData/deadalive1/1/',
                       '/home/pk/Documents/trainingData/deadalive1/2/',
                       '/home/pk/Documents/trainingData/deadalive1/3/',
                       '/home/pk/Documents/trainingData/deadalive1/4/']

In [6]:
data = cellsMovingDataset(phaseDirectoriesList)

In [7]:
len(data)

211

In [8]:
data[0]

{'frame1': array([[ 1.6432905 ,  1.5902892 ,  1.6138453 , ...,  1.6177714 ,
          1.6079563 ,  1.519621  ],
        [ 1.8199612 ,  1.7178848 ,  1.7061068 , ...,  1.5549551 ,
          1.6550685 ,  1.6354384 ],
        [ 1.8651104 ,  1.7846271 ,  1.7453669 , ...,  1.7375149 ,
          1.8179982 ,  1.8454803 ],
        ...,
        [-0.4806841 , -0.48853615, -0.44142395, ..., -0.455165  ,
         -0.48264712, -0.41590485],
        [-0.455165  , -0.4767581 , -0.453202  , ..., -0.48264712,
         -0.42179388, -0.43357193],
        [-0.455165  , -0.4276829 , -0.4021638 , ..., -0.46694306,
         -0.43946093, -0.4276829 ]], dtype=float32),
 'frame2': array([[ 1.7492168 ,  1.5520586 ,  1.6624671 , ...,  1.7373873 ,
          1.6545808 ,  1.6742966 ],
        [ 1.6407797 ,  1.6762682 ,  1.6151491 , ...,  1.4988258 ,
          1.6013482 ,  1.6565524 ],
        [ 1.776819  ,  1.7728758 ,  1.7649895 , ...,  1.621064  ,
          1.6407797 ,  1.6565524 ],
        ...,
        [-0.4313536

In [9]:
data.plotDatapoint(21)

In [10]:
data.statistics()

{'moved': 51, 'notmoved': 160}

### Data augmentations

In [11]:
class randomTranslations(object):

    def __init__(self, translate=(0.0, 0.25)):
        self.translate = translate
    
    def __call__(self, sample):
        affine_params = transforms.RandomAffine.get_params(degrees= [0,0], 
                                                           translate=self.translate,
                                                           scale_ranges=None,
                                                           shears=None,
                                                           img_size=[36, 800])
        print(affine_params)
        frame1Tensor = torch.from_numpy(sample['frame1']).unsqueeze(0)
        frame2Tensor = torch.from_numpy(sample['frame2']).unsqueeze(0)
        
        
        frame1Translated = transforms.functional.affine(frame1Tensor, *affine_params)
        frame2Translated = transforms.functional.affine(frame2Tensor, *affine_params)
        
        return {
            'frame1': frame1Translated,
            'frame2': frame2Translated,
            'cellsmoving': sample['cellsmoving'],
            'frame1file': sample['frame1file'],
            'frame2file': sample['frame2file']
        }

In [17]:
class randomVerticalFlips(object):
    
    def __call__(self, sample):
        
        return sample

In [12]:
data[0]

{'frame1': array([[ 1.6432905 ,  1.5902892 ,  1.6138453 , ...,  1.6177714 ,
          1.6079563 ,  1.519621  ],
        [ 1.8199612 ,  1.7178848 ,  1.7061068 , ...,  1.5549551 ,
          1.6550685 ,  1.6354384 ],
        [ 1.8651104 ,  1.7846271 ,  1.7453669 , ...,  1.7375149 ,
          1.8179982 ,  1.8454803 ],
        ...,
        [-0.4806841 , -0.48853615, -0.44142395, ..., -0.455165  ,
         -0.48264712, -0.41590485],
        [-0.455165  , -0.4767581 , -0.453202  , ..., -0.48264712,
         -0.42179388, -0.43357193],
        [-0.455165  , -0.4276829 , -0.4021638 , ..., -0.46694306,
         -0.43946093, -0.4276829 ]], dtype=float32),
 'frame2': array([[ 1.7492168 ,  1.5520586 ,  1.6624671 , ...,  1.7373873 ,
          1.6545808 ,  1.6742966 ],
        [ 1.6407797 ,  1.6762682 ,  1.6151491 , ...,  1.4988258 ,
          1.6013482 ,  1.6565524 ],
        [ 1.776819  ,  1.7728758 ,  1.7649895 , ...,  1.621064  ,
          1.6407797 ,  1.6565524 ],
        ...,
        [-0.4313536

In [13]:
t = randomTranslations()

In [14]:
t_sample = t(data[0])

(0.0, (0, 45), 1.0, (0.0, 0.0))


In [15]:
t_sample

{'frame1': tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [-0.4473, -0.4591, -0.4571,  ..., -0.4689, -0.3982, -0.5121],
          [-0.4198, -0.4807, -0.4552,  ..., -0.4532, -0.4375, -0.4277],
          [-0.4355, -0.4002, -0.4453,  ..., -0.4532, -0.4316, -0.4591]]]),
 'frame2': tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [-0.4530, -0.4629, -0.4925,  ..., -0.4038, -0.4570, -0.4471],
          [-0.4984, -0.4452, -0.4314,  ..., -0.4471, -0.4432, -0.4590],
          [-0.4570, -0.4668, -0.4550,  ..., -0.4846, -0.4590, -0.4806]]]),
 'cellsmoving': True,
 'frame1file': '/home/pk/Documents/trainingData/deadalive1

In [16]:
datapoint = data[0]
t = randomTranslations()
t_sample = t(datapoint)
fig, ax = plt.subplots(1, 4)
plt.title(datapoint['frame1file'].split('/')[0])
ax[0].imshow(t_sample['frame1'].numpy().squeeze(0), cmap='gray')
ax[1].imshow(datapoint['frame1'], cmap='gray')
ax[2].imshow(t_sample['frame2'].numpy().squeeze(0), cmap='gray')
ax[3].imshow(datapoint['frame2'], cmap='gray')

ax[0].set_title(f"Moved : {t_sample['cellsmoving']}")
ax[1].set_title(f"Moved: {datapoint['cellsmoving']}")
ax[2].set_title(f"Moved: {t_sample['cellsmoving']}")
ax[3].set_title(f"Moved: {datapoint['cellsmoving']}")

(0.0, (0, 94), 1.0, (0.0, 0.0))


Text(0.5, 1.0, 'Moved: True')

### Net architecture

In [18]:
class conv_relu_norm(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(conv_relu_norm, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channels, output_channels, 3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.conv(x)

class movingNet(nn.Module):
    
    def __init__(self):
        
        self.conv1 = conv_relu_norm(1, 64)
             
        
    
    def forward_one(self, img):
        # get to conv feature head for each image
        
        return img
    
    def forward(self, sample):
        # pass the sample image through the net seperately and then
        
        imgFeatures1 = self.forward_one(sample['frame1'])
        imgFeatures2 = self.forward_one(sample['frame2'])
        
        # stack the image features
        
        return sample
        

### Train and validation loop

### Test loop