In [None]:
# conda install -c pytorch pytorch-cpu torchvision

In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as utils

## Set up sequences of simple cell movements

In [3]:
def create_basic_cell_shift(row):
    '''Create sequence of arrays depicting a cell moving at constant speed
       across a domain, frame by frame'''
    input_arrays = []
    arr = np.zeros((100, 100))
    for i, step in enumerate([10, 15, 20]):
        arr2 = np.copy(arr)
        arr2[row:row+10, step:step+10] = 1
        #plt.pcolormesh(arr2)
        #plt.show()
        #plt.close()
        input_arrays.append(arr2)
        
    input_arrays = np.array(input_arrays)

    return input_arrays

In [4]:
def create_basic_cell_shift_v(column):
    '''Create sequence of arrays depicting a cell moving at constant speed
       across a domain, frame by frame'''
    input_arrays = []
    arr = np.zeros((100, 100))
    for i, step in enumerate([10, 15, 20]):
        arr2 = np.copy(arr)
        arr2[step:step+10, column:column+10] = 1
        #plt.pcolormesh(arr2)
        #plt.show()
        #plt.close()
        input_arrays.append(arr2)
        
    input_arrays = np.array(input_arrays)

    return input_arrays

In [5]:
dataset = []
for row in range(0, 50, 5): #10):
    dataset.append(create_basic_cell_shift(row))
for column in range(0, 50, 5):
    dataset.append(create_basic_cell_shift_v(column))
dataset = np.array(dataset)

Create torch dataset and dataloader

In [6]:
# Convert to torch tensors
tensor = torch.stack([torch.Tensor(i) for i in dataset])
#tensor = tensor.unsqueeze(0) # to add a fake batch dimension as torch.nn only supports inputs that are a mini-batch of samples, not a single sample
print(tensor.shape)

torch.Size([20, 3, 100, 100])


In [7]:
train_loader = utils.DataLoader(tensor, batch_size=1)

## Create validation data

In [8]:
dataset3 = []
for column in range(75, 80):
    dataset3.append(create_basic_cell_shift_v(column))
for row in range(55, 60):
    dataset3.append(create_basic_cell_shift(row))

dataset3 = np.array(dataset3) 
tensor3 = torch.stack([torch.Tensor(i) for i in dataset3])
print(tensor3.shape)

val_loader = utils.DataLoader(tensor3, batch_size=1)

torch.Size([10, 3, 100, 100])


## Create CNN

In [9]:
from torch.autograd import Variable
import torch.nn.functional as F

class CNN(torch.nn.Module):
    
    def __init__(self):
        self.as_super = super(CNN, self)
        self.as_super.__init__()

        self.conv_1 = torch.nn.Conv2d(2, 1, kernel_size=13, stride=1, padding=6)
        self.pool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1, return_indices=True)
        
    def forward(self, x):
        orig = x
                
        x = self.conv_1(x)
        x, inds = self.pool_1(x)
        
        return(x)

In [10]:
import torch.optim as optim

def createLossAndOptimizer(net, learning_rate=0.01):
    
    #Loss function
    loss = torch.nn.MSELoss()
    
    #Optimizer
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    
    return(loss, optimizer)

# Train model

In [11]:
import time

def train_net(net, batch_size, n_epochs, learning_rate):
    
    #Print the hyperparameters of the training:
    print("===== HYPERPARAMETERS =====")
    print("batch_size=", batch_size)
    print("epochs=", n_epochs)
    print("learning_rate=", learning_rate)
    print("=" * 30)
    
    #Get training data
    n_batches = len(train_loader)
    
    #Create the loss and optimizer functions
    loss, optimizer = createLossAndOptimizer(net, learning_rate)
    
    #Time for printing at end how long training takes
    training_start_time = time.time()
    
    #Loop for n_epochs
    for epoch in range(n_epochs):
        
        running_loss = 0.0
        print_every = n_batches // 10
        start_time = time.time()
        total_train_loss = 0
        
        for i, data in enumerate(train_loader, 0):
            
            #Get inputs
            inputs, labels = data[:,:2], data[:,2]
            
            #Wrap them in a Variable object
            inputs, labels = Variable(inputs), Variable(labels)
            
            #Set gradients to zero (as the backward function accumulates gradients so want a fresh one for each mini-batch)
            optimizer.zero_grad()
            
            #Forward pass, backward pass, optimize
            outputs = net(inputs)
            loss_size = loss(outputs[0], labels)
            loss_size.backward()
            optimizer.step()
            
            #Print statistics
            running_loss += loss_size.data.item() #[0]
            total_train_loss += loss_size.data.item() #[0]
            
            #Print every 10th batch of an epoch
            if (i + 1) % (print_every + 1) == 0:
                print("Epoch {}, {:d}% \t train_loss: {:.2f} took: {:.2f}s".format(
                        epoch+1, int(100 * (i+1) / n_batches), running_loss / print_every, 
                        time.time() - start_time))
                #Reset running loss and time
                running_loss = 0.0
                start_time = time.time()
            
        #At the end of the epoch, do a pass on the validation set
        total_val_loss = 0
        for data in val_loader:
            
            #data = data.type('torch.FloatTensor')
            inputs, labels = data[:,:2], data[:,2]
            #Wrap tensors in Variables
            inputs, labels = Variable(inputs), Variable(labels)
            
            #Forward pass
            val_outputs = net(inputs)
            val_loss_size = loss(val_outputs[0], labels)
            total_val_loss += val_loss_size.data.item() #[0]
            
        print("Validation loss = {:.2f}".format(total_val_loss / len(val_loader)))
        
    print("Training finished, took {:.2f}s".format(time.time() - training_start_time))

In [None]:
cnn = CNN()
train_net(cnn, batch_size=2, n_epochs=10, learning_rate=0.001)

===== HYPERPARAMETERS =====
batch_size= 2
epochs= 10
learning_rate= 0.001
Epoch 1, 15% 	 train_loss: 0.02 took: 0.06s
Epoch 1, 30% 	 train_loss: 0.01 took: 0.10s
Epoch 1, 45% 	 train_loss: 0.01 took: 0.10s
Epoch 1, 60% 	 train_loss: 0.01 took: 0.19s
Epoch 1, 75% 	 train_loss: 0.01 took: 0.21s
Epoch 1, 90% 	 train_loss: 0.01 took: 0.10s
Validation loss = 0.01
Epoch 2, 15% 	 train_loss: 0.01 took: 0.20s
Epoch 2, 30% 	 train_loss: 0.01 took: 0.11s
Epoch 2, 45% 	 train_loss: 0.01 took: 0.10s
Epoch 2, 60% 	 train_loss: 0.01 took: 0.09s
Epoch 2, 75% 	 train_loss: 0.00 took: 0.09s
Epoch 2, 90% 	 train_loss: 0.00 took: 0.10s
Validation loss = 0.00
Epoch 3, 15% 	 train_loss: 0.01 took: 0.02s


In [None]:
def show_outputs(net, loader):
    for i, data in enumerate(loader):
        data = data.type('torch.FloatTensor')
        inputs, labels = data[:,:2], data[:,2]
        #Wrap tensors in Variables
        inputs, labels = Variable(inputs), Variable(labels)
        
        #Forward pass
        val_outputs = net(inputs)
        
        fig = plt.figure(figsize=(12,8))

        for i in range(2):
            ax = fig.add_subplot(2, 3, i+1)
            cf = plt.contourf(inputs[0,i], cmap=plt.cm.Greys)
            ax.set_xticks(np.arange(0, 100, 10))
            ax.set_yticks(np.arange(0, 100, 10))
            plt.grid()
            plt.setp(ax.xaxis.get_ticklabels(), visible=False)
            plt.setp(ax.yaxis.get_ticklabels(), visible=False)
            if i == 0:
                plt.title('inputs')
            
        ax = fig.add_subplot(2, 3, 3)
        cf = plt.contourf(labels[0], cmap=plt.cm.Greys)
        ax.set_xticks(np.arange(0, 100, 10))
        ax.set_yticks(np.arange(0, 100, 10))
        plt.grid()
        plt.setp(ax.xaxis.get_ticklabels(), visible=False)
        plt.setp(ax.yaxis.get_ticklabels(), visible=False)
        plt.title('truth')
        ax = fig.add_subplot(2, 3, 6)
        cf = plt.contourf(val_outputs[0, 0].detach().numpy(), cmap=plt.cm.Greys)
        ax.set_xticks(np.arange(0, 100, 10))
        ax.set_yticks(np.arange(0, 100, 10))
        plt.grid()
        plt.setp(ax.xaxis.get_ticklabels(), visible=False)
        plt.setp(ax.yaxis.get_ticklabels(), visible=False)
        plt.title('model')

        plt.tight_layout()

In [None]:
show_outputs(cnn, val_loader)