In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt5

In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
from narsil.deadalive.datasets import channelStackTrain
from narsil.deadalive.modelDev import trainDeadAliveNet
from narsil.deadalive.network import CaffeLSTMCell

In [4]:
phaseDirectoriesList = ['/home/pk/Documents/trainingData/deadalive/0/','/home/pk/Documents/trainingData/deadalive/1/' ]

In [5]:
phaseDirectoriesList

['/home/pk/Documents/trainingData/deadalive/0/',
 '/home/pk/Documents/trainingData/deadalive/1/']

In [6]:
dataset = channelStackTrain(phaseDirectoriesList, numUnrolls=2, fileformat='.tif')

In [7]:
len(dataset)

78

In [8]:
dataset[0]['imageSequence'].shape

(1, 1024, 40, 2)

In [9]:
class deadAliveNetBase(nn.Module):

    def __init__(self, device, args=None):
        super(deadAliveNetBase, self).__init__()
        self.device = device
        self.args = args
        self.learning_rate =  None
        self.optimizer = None
        self.outputs = None
        self.loss_function = nn.BCELoss()

    def loss(self, outputs, labels):
        return self.loss_function(outputs, labels)

    def setup_optimizer(self, learning_rate):
        self.learning_rate = learning_rate
        self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=0.0005)

    def update_learning_rate(self, lr_new):
        if self.learning_rate != lr_new:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = lr_new
            self.learning_rate = lr_new

    def step(self, inputs, labels):
        self.optimizer.zero_grad()
        self.outputs = self(inputs)
        loss = self.loss(self.outputs, labels)
        loss.backward()
        self.optimizer.step()
        return loss.data.cpu().numpy()[0]
       

In [10]:
img_tensor = torch.zeros(size=(2, 1, 800, 36))

In [11]:
img_tensor.shape

torch.Size([2, 1, 800, 36])

In [18]:
class deadAliveNet(deadAliveNetBase):
    
    def __init__(self, device, lstm_size=1024, args=None):
        super(deadAliveNet, self).__init__(device, args)
        
        self.device = device
        self.lstm_size = lstm_size
        self.lstm_state = None
        
        self.conv = nn.ModuleList([
            nn.Conv2d(1, 8, kernel_size=(3, 3), stride=1, padding =1),
            nn.Conv2d(8, 16, kernel_size=(3, 3), stride=1, padding=1),
            nn.Conv2d(16, 32, kernel_size=(3, 3), stride=1, padding=1),
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=1, padding=1),          
        ])
        
        self.lrn = nn.ModuleList([
            nn.LocalResponseNorm(size=4, alpha=0.0001, beta=0.75),
            nn.LocalResponseNorm(size=4, alpha=0.0001, beta=0.75)
        ])
        
        self.conv_skip = nn.ModuleList([
            nn.Conv2d(8, 2, 1),
            nn.Conv2d(16, 4, 1),
            nn.Conv2d(32, 8, 1)
        ])
        
        self.prelu_skip = nn.ModuleList([
            nn.PReLU(2),
            nn.PReLU(4),
            nn.PReLU(8)
        ])
        
        self.fc6 = nn.Linear(48800, 1024)
        self.lstm1 = CaffeLSTMCell(1024, self.lstm_size)
        self.lstm2 = CaffeLSTMCell(1024 + self.lstm_size, self.lstm_size)
        
        self.fc_out = nn.Linear(self.lstm_size, 6)
        
    
    def forward(self, input, lstm_state=None):
        batch_size = input.shape[0]
        
        
        conv1 = self.conv[0](input)
        print(f"Conv1 shape: {conv1.shape}")
        pool1 = F.relu(F.max_pool2d(conv1, (2, 2)))
        print(f"Pool1 shape: {pool1.shape}")
        lrn1 = self.lrn[0](pool1)
        print(f"Lrn1 shape: {lrn1.shape}")
        
        # get the pool features into a vector for the final lstm at this spatial 
        # scale
        
        conv1_skip = self.prelu_skip[0](self.conv_skip[0](lrn1))
        print(f"Conv1_skip shape: {conv1_skip.shape}")
        # flatten to pool later
        conv1_skip_flatten = conv1_skip.view(batch_size, -1)
        print(f"Conv1_skip_flatten shape: {conv1_skip_flatten.shape}")
        
        
        conv2 = self.conv[1](lrn1)
        print(f"Conv2 shape: {conv2.shape}")
        pool2 = F.relu(F.max_pool2d(conv2, (2, 2)))
        print(f"Pool2 shape: {pool2.shape}")
        lrn2 = self.lrn[1](pool2)
        print(f"Lrn2 shape: {lrn2.shape}")
        
        
        conv2_skip = self.prelu_skip[1](self.conv_skip[1](lrn2))
        print(f"Conv2_skip shape: {conv2_skip.shape}")
        # flatten to pool later
        conv2_skip_flatten = conv2_skip.view(batch_size, -1)
        print(f"Conv2_skip_flatten shape: {conv2_skip_flatten.shape}")
        
        conv3 = F.relu(self.conv[2](lrn2))
        print(f"Conv3 shape: {conv3.shape}")
        conv4 = F.relu(self.conv[3](conv3))
        print(f"Conv4 shape: {conv4.shape}")
        
        conv4_skip = self.prelu_skip[2](self.conv_skip[2](conv4))
        print(f"Conv4_skip shape: {conv4_skip.shape}")
        conv4_skip_flatten = conv4_skip.view(batch_size, -1)
        print(f"Conv4_skip_flatten shape: {conv4_skip_flatten.shape}")
        
        pool4 = F.relu(F.max_pool2d(conv4, (2, 2)))
        print(f"Pool4 shape: {pool4.shape}")
        
        pool4_flat = pool4.view(batch_size, -1)
        print(f"Pool4_flat shape: {pool4_flat.shape}")
        
        skip_concat = torch.cat([conv1_skip_flatten, conv2_skip_flatten, conv4_skip_flatten, pool4_flat], 1)
        print(f"Skip concat shape: {skip_concat.shape}")
        
        fc6 = F.relu(self.fc6(skip_concat))
        print(f"FC6 shape: {fc6.shape}")
        
        
        if lstm_state is None:
            outputs1, state1 = self.lstm1(fc6)
            outputs2, state2 = self.lstm2(torch.cat((fc6, outputs1), 1))
        else:
            outputs1, state1, outputs2, state2 = lstm_state
            outputs1, state1 = self.lstm1(fc6, (outputs1, state1))
            outputs2, state2 = self.lstm2(torch.cat((fc6, outputs1), 1), (outputs2, state2))

        self.lstm_state  = (outputs1, state1, outputs2, state2)
        
        fc_out = self.fc_out(outputs2)
        
        print(f"FC_out shape: {fc_out.shape}")
        return torch.sigmoid(fc_out)
        

In [19]:
net = deadAliveNet(device="cpu")

In [21]:
net(img_tensor)

Conv1 shape: torch.Size([2, 8, 800, 36])
Pool1 shape: torch.Size([2, 8, 400, 18])
Lrn1 shape: torch.Size([2, 8, 400, 18])
Conv1_skip shape: torch.Size([2, 2, 400, 18])
Conv1_skip_flatten shape: torch.Size([2, 14400])
Conv2 shape: torch.Size([2, 16, 400, 18])
Pool2 shape: torch.Size([2, 16, 200, 9])
Lrn2 shape: torch.Size([2, 16, 200, 9])
Conv2_skip shape: torch.Size([2, 4, 200, 9])
Conv2_skip_flatten shape: torch.Size([2, 7200])
Conv3 shape: torch.Size([2, 32, 200, 9])
Conv4 shape: torch.Size([2, 32, 200, 9])
Conv4_skip shape: torch.Size([2, 8, 200, 9])
Conv4_skip_flatten shape: torch.Size([2, 14400])
Pool4 shape: torch.Size([2, 32, 100, 4])
Pool4_flat shape: torch.Size([2, 12800])
Skip concat shape: torch.Size([2, 48800])
FC6 shape: torch.Size([2, 1024])
FC_out shape: torch.Size([2, 6])


tensor([[0.4985, 0.5016, 0.4990, 0.4997, 0.5022, 0.5068],
        [0.4985, 0.5016, 0.4990, 0.4997, 0.5022, 0.5068]],
       grad_fn=<SigmoidBackward>)

In [22]:
from torchinfo import summary

In [23]:
summary(net, input_size=(2, 1, 800, 36))

Conv1 shape: torch.Size([2, 8, 800, 36])
Pool1 shape: torch.Size([2, 8, 400, 18])
Lrn1 shape: torch.Size([2, 8, 400, 18])
Conv1_skip shape: torch.Size([2, 2, 400, 18])
Conv1_skip_flatten shape: torch.Size([2, 14400])
Conv2 shape: torch.Size([2, 16, 400, 18])
Pool2 shape: torch.Size([2, 16, 200, 9])
Lrn2 shape: torch.Size([2, 16, 200, 9])
Conv2_skip shape: torch.Size([2, 4, 200, 9])
Conv2_skip_flatten shape: torch.Size([2, 7200])
Conv3 shape: torch.Size([2, 32, 200, 9])
Conv4 shape: torch.Size([2, 32, 200, 9])
Conv4_skip shape: torch.Size([2, 8, 200, 9])
Conv4_skip_flatten shape: torch.Size([2, 14400])
Pool4 shape: torch.Size([2, 32, 100, 4])
Pool4_flat shape: torch.Size([2, 12800])
Skip concat shape: torch.Size([2, 48800])
FC6 shape: torch.Size([2, 1024])
FC_out shape: torch.Size([2, 6])


Layer (type:depth-idx)                   Output Shape              Param #
deadAliveNet                             --                        --
├─ModuleList: 1-1                        --                        --
├─ModuleList: 1-2                        --                        --
├─ModuleList: 1-3                        --                        --
├─ModuleList: 1-4                        --                        --
├─ModuleList: 1-1                        --                        --
│    └─Conv2d: 2-1                       [2, 8, 800, 36]           80
├─ModuleList: 1-2                        --                        --
│    └─LocalResponseNorm: 2-2            [2, 8, 400, 18]           --
├─ModuleList: 1-3                        --                        --
│    └─Conv2d: 2-3                       [2, 2, 400, 18]           18
├─ModuleList: 1-4                        --                        --
│    └─PReLU: 2-4                        [2, 2, 400, 18]           2
├─ModuleList: 1-

In [87]:
400 *36

14400

### Plotting dead-alive probabilites

In [29]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

fig, ax = plt.subplots()
xdata, ydata = [], []
ln, = plt.plot([], [], 'ro')

def init():
    ax.set_xlim(0, 2*np.pi)
    ax.set_ylim(-1, 1)
    return ln,

def update(frame):
    xdata.append(frame)
    ydata.append(np.sin(frame))
    ln.set_data(xdata, ydata)
    return ln,

ani = FuncAnimation(fig, update, frames=np.linspace(0, 2*np.pi, 128),
                    init_func=init, blit=True)
plt.show()


In [8]:
modelParameters = {
    'device': "cuda:1"
}
optimizationParameters = {
    'learning_rate': 0.5e-5,
    'nEpochs': 100
}

In [9]:
net = trainDeadAliveNet(phaseDirectoriesList, modelParameters, optimizationParameters, fileformat='.tif')

In [13]:
net.train()

Epoch 0 -- started
Epoch average loss: 0.15325021594762803
Epoch 1 -- started
Epoch average loss: 0.1096721351146698
Epoch 2 -- started
Epoch average loss: 0.11148410886526108
Epoch 3 -- started
Epoch average loss: 0.10040972828865051
Epoch 4 -- started
Epoch average loss: 0.09695472568273544
Epoch 5 -- started
Epoch average loss: 0.09909523427486419
Epoch 6 -- started
Epoch average loss: 0.09453219920396805
Epoch 7 -- started
Epoch average loss: 0.09195844233036041
Epoch 8 -- started
Epoch average loss: 0.09463360160589218
Epoch 9 -- started
Epoch average loss: 0.09367366284132003
Epoch 10 -- started
Epoch average loss: 0.0925093412399292
Epoch 11 -- started
Epoch average loss: 0.09434753209352494
Epoch 12 -- started
Epoch average loss: 0.09306800067424774
Epoch 13 -- started
Epoch average loss: 0.09562535881996155
Epoch 14 -- started
Epoch average loss: 0.10037241280078887
Epoch 15 -- started
Epoch average loss: 0.10584187507629395
Epoch 16 -- started
Epoch average loss: 0.0962564885