In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt5

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Button, CheckButtons, RadioButtons
import pickle
from skimage import io
import glob
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import random

### Dataset creation

In [3]:
class markStateTransition(object):
    
    def __init__(self, phaseDir=None, fileformat='.tiff', saveFilename='states.pickle'):
        self.phaseDir = phaseDir
        self.fileformat = fileformat
        self.saveFilename = saveFilename
        self.indices = [int(filename.split('.')[0].split('/')[-1]) for filename in 
                       glob.glob(self.phaseDir + "*" + self.fileformat)]
        self.indices.sort()
        self.states = {}  # it is a dictionary
        # with keys as directory name, file index 1, file index 2 and states
        self.states['dirName'] = phaseDir
        self.states['data'] = [] # use keys frame1, frame2 and keys for each of the state
        
        
        self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, num=str(self.phaseDir))
        plt.subplots_adjust(left=0.25, bottom=0.25)
        self.axcolor = 'lightgoldenrodyellow'
        
        #
        self.frame1 = 0
        self.pltFrame1 = self.ax1.imshow(self.__getitem__(self.frame1), cmap='gray')
        self.ax1.set_title(str(self.frame1) + self.fileformat)
        self.frame2 = 1
        self.pltFrame2 = self.ax2.imshow(self.__getitem__(self.frame2), cmap='gray')
        self.ax2.set_title(str(self.frame2) + self.fileformat)
        
        # buttons
        # next and save
        self.previousax = plt.axes([0.5, 0.025, 0.1, 0.03])
        self.nextax = plt.axes([0.65, 0.025, 0.1, 0.03])
        self.saveax = plt.axes([0.8, 0.025, 0.1, 0.03])
        self.nextButton = Button(self.nextax, 'Next', color=self.axcolor, hovercolor='0.975')
        self.saveButton = Button(self.saveax, 'Save', color=self.axcolor, hovercolor='0.975')
        self.previousButton = Button(self.previousax, 'Previous', color=self.axcolor, hovercolor='0.975')
        self.nextButton.on_clicked(self.nextImage)
        self.previousButton.on_clicked(self.previousImage)
        self.saveButton.on_clicked(self.save)
        
        # Radio buttons for each category transition
        
        self.rax = plt.axes([0.05, 0.7, 0.15, 0.15], facecolor=self.axcolor)
        self.radio = RadioButtons(self.rax, ('moved', 'not moved'))
        self.moved = True
        self.radio.on_clicked(self.changeMovedState)
        
        
        # more radiobuttons for other states
        
        
        self.reachedEnd = False
        
        
        plt.pause(0.01)
        
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        if idx < self.__len__():
            phaseFilename = self.phaseDir + str(self.indices[idx]) + self.fileformat
            #print(phaseFilename)
            return io.imread(phaseFilename)
        else:
            return None
    
    
    def changeMovedState(self, label):
        moved_dict = {'moved': True, 'not moved': False}
        self.moved = moved_dict[label]
    
    def nextImage(self, buttonPress):
        # store the current state in the data
        if self.frame2 <= len(self.indices) - 1 and self.reachedEnd == False:
            self.states['data'].append({
                'frame1': self.phaseDir + str(self.frame1) + self.fileformat,
                'frame2': self.phaseDir + str(self.frame2) + self.fileformat,
                'moved': self.moved
                })
            print(self.states['data'][-1])
            
            if self.frame2 == len(self.indices) - 1:
                self.reachedEnd = True
                print('Last frame reached')
                return
            
            
            if self.frame2 < len(self.indices) - 1:
                self.frame1 += 1
                self.frame2 += 1
                self.pltFrame1.set_data(self.__getitem__(self.frame1))
                self.ax1.set_title(str(self.frame1) + self.fileformat)
                self.pltFrame2.set_data(self.__getitem__(self.frame2))
                self.ax2.set_title(str(self.frame2) + self.fileformat)
                self.fig.canvas.draw()
            
            
    def previousImage(self, buttonPress):
        
        if self.frame1 == 0:
            print('First frame reached')
            return
        else:
            if len(self.states['data']) != 0:
                del self.states['data'][-1]
            
            self.reachedEnd = False
                
            self.frame1 -= 1
            self.frame2 -= 1
            self.pltFrame1.set_data(self.__getitem__(self.frame1))
            self.ax1.set_title(str(self.frame1) + self.fileformat)
            self.pltFrame2.set_data(self.__getitem__(self.frame2))
            self.ax2.set_title(str(self.frame2) + self.fileformat)
            self.fig.canvas.draw()
    
    def save(self, save):
        print(self.states)
        saveFilename = self.phaseDir + self.saveFilename
        with open(saveFilename, 'wb') as f:
            pickle.dump(self.states, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"File saved {saveFilename}")

In [5]:
phaseDir = '/home/pk/Documents/trainingData/deadalive1/19/'
genData = markStateTransition(phaseDir)

{'frame1': '/home/pk/Documents/trainingData/deadalive1/18/0.tiff', 'frame2': '/home/pk/Documents/trainingData/deadalive1/18/1.tiff', 'moved': True}
{'frame1': '/home/pk/Documents/trainingData/deadalive1/18/1.tiff', 'frame2': '/home/pk/Documents/trainingData/deadalive1/18/2.tiff', 'moved': True}
{'frame1': '/home/pk/Documents/trainingData/deadalive1/18/2.tiff', 'frame2': '/home/pk/Documents/trainingData/deadalive1/18/3.tiff', 'moved': True}
{'frame1': '/home/pk/Documents/trainingData/deadalive1/18/3.tiff', 'frame2': '/home/pk/Documents/trainingData/deadalive1/18/4.tiff', 'moved': True}
{'frame1': '/home/pk/Documents/trainingData/deadalive1/18/4.tiff', 'frame2': '/home/pk/Documents/trainingData/deadalive1/18/5.tiff', 'moved': True}
{'frame1': '/home/pk/Documents/trainingData/deadalive1/18/5.tiff', 'frame2': '/home/pk/Documents/trainingData/deadalive1/18/6.tiff', 'moved': True}
{'frame1': '/home/pk/Documents/trainingData/deadalive1/18/6.tiff', 'frame2': '/home/pk/Documents/trainingData/de

### Data loader and parameters

In [4]:
class cellsMovingDataset(object):
    
    def __init__(self, phaseDirectoriesList, fileformat='.tiff',transforms=None):
        
        self.phaseDirectoriesList = phaseDirectoriesList
        self.data = []
        self.transforms = transforms
        
        for directory in self.phaseDirectoriesList:
            # read the states file
            statesFilename = directory + 'states.pickle'
            with open(statesFilename, 'rb') as file:
                states = pickle.load(file)
            data = states['data']
            
            self.data.extend(data)
    
            
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        
        datapoint = self.data[idx]
        
        frame1 = io.imread(datapoint['frame1'])
        frame2 = io.imread(datapoint['frame2'])
        cellsmoving = datapoint['moved']
        
        sample ={
            'frame1': frame1,
            'frame2': frame2,
            'cellsmoving': cellsmoving,
            'frame1file': datapoint['frame1'],
            'frame2file': datapoint['frame2']
        }
    
        
        if self.transforms != None:
            sample = self.transforms(sample)
        
        return sample
    
    def statistics(self):
        stats = {'moved': 0, 'notmoved': 0}
        for datapoint in self.data:
            if datapoint['moved']:
                stats['moved'] += 1
            else:
                stats['notmoved'] += 1
        
        return stats
    
    def plotDatapoint(self, idx):
        
        datapoint = self.__getitem__(idx)       
        fig, ax = plt.subplots(1, 2)
        plt.title(datapoint['frame1file'].split('/')[0])
        ax[0].imshow(datapoint['frame1'], cmap='gray')
        ax[1].imshow(datapoint['frame2'], cmap='gray')
        
        ax[0].set_title(f"Moved : {datapoint['cellsmoving']}")
        ax[1].set_title(f"Moved: {datapoint['cellsmoving']}")
        plt.show(block=False)

### Data augmentations

In [5]:
class randomTranslations(object):

    def __init__(self, translate=(0.0, 0.10)):
        self.translate = translate
    
    def __call__(self, sample):
        #img_shape = sample['frame1'].shape[-2:]
        #img_size = [img_shape[1], img_shape[0]]
        affine_params = transforms.RandomAffine.get_params(degrees= [0,0], 
                                                           translate=self.translate,
                                                           scale_ranges=None,
                                                           shears=None,
                                                           img_size=[36, 800])
        #print(affine_params)
        frame1Tensor = torch.from_numpy(sample['frame1']).unsqueeze(0)
        frame2Tensor = torch.from_numpy(sample['frame2']).unsqueeze(0)
        
        
        frame1Translated = transforms.functional.affine(frame1Tensor, *affine_params)
        frame2Translated = transforms.functional.affine(frame2Tensor, *affine_params)
        
        return {
            'frame1': frame1Translated,
            'frame2': frame2Translated,
            'cellsmoving': torch.tensor(sample['cellsmoving'],dtype=torch.float).unsqueeze(0),
            'frame1file': sample['frame1file'],
            'frame2file': sample['frame2file']
        }

In [6]:
class randomVerticalFlips(object):
    
    def __init__(self, p=0.5):
        self.probability = p
        
    def __call__(self, sample):
        if random.random() < self.probability:
            # then flip both images
            sample['frame1'] = transforms.functional.vflip(sample['frame1'])
            sample['frame2'] = transforms.functional.vflip(sample['frame2'])
            
        return sample

In [6]:
phaseDirectoriesList = ['/home/pk/Documents/trainingData/deadalive1/0/',
                        '/home/pk/Documents/trainingData/deadalive1/1/',
                       '/home/pk/Documents/trainingData/deadalive1/2/',
                       '/home/pk/Documents/trainingData/deadalive1/3/',
                       '/home/pk/Documents/trainingData/deadalive1/4/']

In [12]:
data = cellsMovingDataset(phaseDirectoriesList, transforms = randomTranslations())

In [13]:
#data.plotDatapoint(21)

In [14]:
data.statistics()

{'moved': 51, 'notmoved': 160}

datapoint = data[0]
#t_sample = f(t_sample)
fig, ax = plt.subplots(1, 4)
plt.title(datapoint['frame1file'].split('/')[0])
ax[0].imshow(t_sample['frame1'].numpy().squeeze(0), cmap='gray')
ax[1].imshow(datapoint['frame1'], cmap='gray')
ax[2].imshow(t_sample['frame2'].numpy().squeeze(0), cmap='gray')
ax[3].imshow(datapoint['frame2'], cmap='gray')

ax[0].set_title(f"Moved : {t_sample['cellsmoving']}")
ax[1].set_title(f"Moved: {datapoint['cellsmoving']}")
ax[2].set_title(f"Moved: {t_sample['cellsmoving']}")
ax[3].set_title(f"Moved: {datapoint['cellsmoving']}")

In [15]:
data[0]

{'frame1': tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [-0.4336, -0.4748, -0.4768,  ..., -0.4591, -0.4061, -0.4375],
          [-0.4571, -0.4512, -0.4395,  ..., -0.4709, -0.4061, -0.4512],
          [-0.4316, -0.4257, -0.5042,  ..., -0.3923, -0.3923, -0.4355]]]),
 'frame2': tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [-0.4629, -0.4688, -0.4156,  ..., -0.4452, -0.4688, -0.4057],
          [-0.4570, -0.4550, -0.4116,  ..., -0.4491, -0.4116, -0.4373],
          [-0.3919, -0.3683, -0.4432,  ..., -0.4333, -0.4491, -0.4688]]]),
 'cellsmoving': tensor([1.]),
 'frame1file': '/home/pk/Documents/trainingData/de

### Net architecture

In [7]:
class conv_relu_norm(nn.Module):
    def __init__(self, input_channels, output_channels, kernel=3, stride=1, padding=1):
        super(conv_relu_norm, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channels, output_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.conv(x)
        return x

class movingNet(nn.Module):
    
    def __init__(self):
        super(movingNet, self).__init__()
        self.conv1 = conv_relu_norm(1, 64)
        self.conv2 = conv_relu_norm(64, 128)
        self.conv3 = conv_relu_norm(128, 256)
        
        self.conv4 = conv_relu_norm(256, 256)
        
        self.fc6 = nn.Linear(76800, 1024)
        self.imageFeatures = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(inplace=True)
        )
        
        self.outputLinear = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, 1)
        )
        
        
    def forward_one(self, img):
        # get to conv feature head for each image
        
        #print(f"Image shape: {img.shape}")
        batch_size = img.shape[0]
        
        conv1 = self.conv1(img)
        #print(f"Conv1 shape: {conv1.shape}")
        pool1 = F.relu(F.max_pool2d(conv1, (2, 2)))
        #print(f"Pool1 shape: {pool1.shape}")
        
        conv2 = self.conv2(pool1)
        #print(f"Conv2 shape: {conv2.shape}")
        pool2 = F.relu(F.max_pool2d(conv2, (2, 2)))
        #print(f"Pool2 shape: {pool2.shape}")
        
        conv3 = self.conv3(pool2)
        #print(f"Conv3 shape: {conv3.shape}")
        pool3 = F.relu(F.max_pool2d(conv3, (2, 3)))
        #print(f"Pool3 shape: {pool3.shape}")
        
        conv4 = self.conv4(pool3)
        #print(f"Conv4 shape: {conv4.shape}")
        conv4_reshaped = conv4.view(batch_size, -1)
        #print(f"Conv4 reshaped: {conv4_reshaped.shape}")
        
        fc6 = F.relu(self.fc6(conv4_reshaped))
        #print(f"FC6 shape: {fc6.shape}")
        
        out = self.imageFeatures(fc6)
        #print(f"Output shape: {out.shape}")
        
        return out
    
    def forward(self, img1, img2):
        # pass the sample image through the net seperately and then
        
        imgFeatures1 = self.forward_one(img1)
        imgFeatures2 = self.forward_one(img2)
        
        # stack the image features
        stackedFeatures = torch.cat((imgFeatures1, imgFeatures2), 1)
        #print(stackedFeatures.shape)
        
        
        netOutput = self.outputLinear(stackedFeatures)
        #print(netOutput.shape)
        
        return torch.sigmoid(netOutput)
        

In [9]:
net = movingNet()

In [22]:
movingDataLoader = DataLoader(data, batch_size=16)

In [23]:
sample = next(iter(movingDataLoader))

In [24]:
sample['frame1'].shape

torch.Size([16, 1, 800, 36])

In [27]:
net(sample['frame1'], sample['frame2'])

tensor([[0.4907],
        [0.4922],
        [0.4925],
        [0.4918],
        [0.4910],
        [0.4913],
        [0.4912],
        [0.4912],
        [0.4925],
        [0.4914],
        [0.4905],
        [0.4925],
        [0.4919],
        [0.4916],
        [0.4907],
        [0.4923]], grad_fn=<SigmoidBackward>)

In [28]:
sample['cellsmoving']

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])

In [8]:
import math


In [9]:
math.log(0.5)

-0.6931471805599453

In [8]:
def weighted_binary_cross_entropy(output, target, weights=None):
        
    if weights is not None:
        assert len(weights) == 2
        
        loss = weights[1] * (target * torch.log(output)) + \
               weights[0] * ((1 - target) * torch.log(1 - output))
    else:
        loss = target * torch.log(output) + (1 - target) * torch.log(1 - output)
        
    #print(loss)
    return torch.neg(torch.mean(loss))

### Train and validation loop

In [9]:
batch_size = 48
nEpochs = 50
phaseDirectoriesList = ['/home/pk/Documents/trainingData/deadalive1/' + str(i) + '/' for i in range(0, 19)]
data = cellsMovingDataset(phaseDirectoriesList, transforms = randomTranslations())
movingDataLoader = DataLoader(data, batch_size=batch_size, num_workers=6, shuffle=True)
net = movingNet()
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
net.to(device)

# loss function
criterion = nn.BCELoss()

# optimizer
optimizer = optim.Adam(net.parameters(), lr=5e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(nEpochs):
    
    epochLoss = []
    
    for i_batch, data in enumerate(movingDataLoader, 0):
        optimizer.zero_grad()
        img1, img2 = data['frame1'].to(device), data['frame2'].to(device)
        output = net(img1, img2)
        target = data['cellsmoving'].to(device)
        #print(output)
        #print(target)
        loss = weighted_binary_cross_entropy(output, target,weights=[771/588, 771/193])
        #print(loss)
        epochLoss.append(loss.item())
        loss.backward()
        optimizer.step()
    scheduler.step()
        
    print(f"Epoch: {epoch + 1} avg loss: {np.mean(epochLoss)}")

Epoch: 1 avg loss: 1.2029531142290901
Epoch: 2 avg loss: 0.7996154427528381
Epoch: 3 avg loss: 0.6629438365206999
Epoch: 4 avg loss: 0.6002627646221834
Epoch: 5 avg loss: 0.5528471592594596
Epoch: 6 avg loss: 0.468210328151198
Epoch: 7 avg loss: 0.40848353329826803
Epoch: 8 avg loss: 0.49995168517617616
Epoch: 9 avg loss: 0.3724890903514974
Epoch: 10 avg loss: 0.36458728330976825
Epoch: 11 avg loss: 0.3115284881171058
Epoch: 12 avg loss: 0.30505984846283407
Epoch: 13 avg loss: 0.2989852726459503
Epoch: 14 avg loss: 0.2760763069724335
Epoch: 15 avg loss: 0.30264050601159825
Epoch: 16 avg loss: 0.29055747915716734
Epoch: 17 avg loss: 0.2633880683604409
Epoch: 18 avg loss: 0.22958366923472462
Epoch: 19 avg loss: 0.25274987168171825
Epoch: 20 avg loss: 0.20838351854506662
Epoch: 21 avg loss: 0.20352498989771395
Epoch: 22 avg loss: 0.22168895865187926
Epoch: 23 avg loss: 0.19294397773988106
Epoch: 24 avg loss: 0.21211347707054196
Epoch: 25 avg loss: 0.17526479722822413
Epoch: 26 avg loss: 0

In [10]:
phaseDirectoriesList = ['/home/pk/Documents/trainingData/deadalive1/' + str(i) + '/' for i in range(0, 19)]
data = cellsMovingDataset(phaseDirectoriesList, transforms = randomTranslations())

In [10]:
data.statistics()

{'moved': 193, 'notmoved': 588}

In [11]:
588 + 183

771

In [11]:
savedModel = {
    'model_state_dict': net.state_dict()
}
savedPath = '/home/pk/Documents/models/moving.pth'
torch.save(savedModel, savedPath)

In [12]:
import numpy as np
import glob
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from skimage import io

In [13]:
def testNet(dataDir, modelPath, fileformat='.tiff', plot=False):
        
    # net intializee and run
    net = movingNet()
    saved_net_parameters = torch.load(modelPath)
    net.load_state_dict(saved_net_parameters['model_state_dict'])
    net.eval()
    
    filenames = [int(filename.split('.')[0].split('/')[-1]) 
             for filename in glob.glob(dataDir + "*"+ fileformat)]
    filenames.sort()
    sortedFilenames = [dataDir + str(filenumber) + fileformat for filenumber in filenames]
    print(sortedFilenames)
    
    imgTransforms = transforms.Compose([
        transforms.ToTensor()
    ])
    
    predicted_states =[]
    with torch.no_grad():
        for i in range(len(sortedFilenames) - 1):
            img1 = io.imread(sortedFilenames[i])
            img2 = io.imread(sortedFilenames[i+1])
            
            img1 = imgTransforms(img1).unsqueeze(0)
            img2 = imgTransforms(img2).unsqueeze(0)
            output = net(img1, img2) > 0.5
            print(output.item())
            predicted_states.append(output.item())
            
    
    
    if plot == True:
        states = np.array(predicted_states).T
        fig, ax = plt.subplots(nrows=1, ncols=2)
        img = io.imread(sortedFilenames[0])

        tdata, movingdata, = [], []
        movingplot, = ax[1].plot([], [], 'g,:', label='Moving')
        #partialdeadplot, = ax[1].plot([], [], 'r,-.', label='Partial Dead')
        #alldeadplot, = ax[1].plot([], [], 'k--', label='All Dead')
        #nocellsplot, = ax[1].plot([], [], 'b*', label='No cells')
        #cellsvanishplot, = ax[1].plot([], [], 'mo', label='Cells vanish')


        imgplot = ax[0].imshow(img, cmap='gray')


        def init():
            img = io.imread(sortedFilenames[0])
            imgplot.set_array(img)
            ax[1].set_xlim([0, len(sortedFilenames)])
            ax[1].set_ylim([-0.5, 2])
            return imgplot, 

        def update(frame):

            img = io.imread(sortedFilenames[frame])
            imgplot.set_array(img)
            tdata.append(frame)
            movingdata.append(int(states[0][frame]))
            movingplot.set_data(tdata, movingdata) 

            partialdeaddata.append(int(states[2][frame]))
            partialdeadplot.set_data(tdata, partialdeaddata)

            alldeaddata.append(int(states[3][frame]))
            alldeadplot.set_data(tdata, alldeaddata)

            nocellsdata.append(int(states[4][frame]))
            nocellsplot.set_data(tdata, nocellsdata)

            cellsvanishdata.append(int(states[5][frame]))
            cellsvanishplot.set_data(tdata, cellsvanishdata)


            return [imgplot, movingplot, partialdeadplot,]

        ani = FuncAnimation(fig, update, frames=range(0, len(sortedFilenames)),
                            init_func=init, blit=False, repeat=False, interval=1000)
        plt.legend(loc='upper right')
        plt.show()
        
    return None, predicted_states

In [23]:
phaseDir = '/home/pk/Documents/trainingData/deadalive1/28/'
modelPath = '/home/pk/Documents/models/moving.pth'
ani, predicted_states = testNet(phaseDir, modelPath, plot=False)

['/home/pk/Documents/trainingData/deadalive1/28/0.tiff', '/home/pk/Documents/trainingData/deadalive1/28/1.tiff', '/home/pk/Documents/trainingData/deadalive1/28/2.tiff', '/home/pk/Documents/trainingData/deadalive1/28/3.tiff', '/home/pk/Documents/trainingData/deadalive1/28/4.tiff', '/home/pk/Documents/trainingData/deadalive1/28/5.tiff', '/home/pk/Documents/trainingData/deadalive1/28/6.tiff', '/home/pk/Documents/trainingData/deadalive1/28/7.tiff', '/home/pk/Documents/trainingData/deadalive1/28/8.tiff', '/home/pk/Documents/trainingData/deadalive1/28/9.tiff', '/home/pk/Documents/trainingData/deadalive1/28/10.tiff', '/home/pk/Documents/trainingData/deadalive1/28/11.tiff', '/home/pk/Documents/trainingData/deadalive1/28/12.tiff', '/home/pk/Documents/trainingData/deadalive1/28/13.tiff', '/home/pk/Documents/trainingData/deadalive1/28/14.tiff', '/home/pk/Documents/trainingData/deadalive1/28/15.tiff', '/home/pk/Documents/trainingData/deadalive1/28/16.tiff', '/home/pk/Documents/trainingData/deadali

In [35]:
output > 0.5

tensor([[ True],
        [ True],
        [ True],
        [False],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False]], device='cuda:1')

In [34]:
target

tensor([[1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], device='cuda:1')

### Test loop