In [67]:
import numpy as np
import pandas as pd
import torch
import random
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from tqdm import tqdm
from torchvision.models import resnet50, ResNet50_Weights, vgg16,VGG16_Weights
import pickle 
import os
import cv2
import json
from tqdm import tqdm
from geomloss import SamplesLoss

### Find the list of frames

In [2]:
data_path = 'data/'
files = os.listdir(data_path+'droplets/')
files = [f.split(".")[0] for f in files]
files = list(set(files))


In [3]:
def load_images(file):
    return np.array(cv2.imreadmulti(file, flags=cv2.IMREAD_GRAYSCALE)[1], dtype=object)
def load_json(file):
    with open(file) as f:
        return json.load(f)

In [33]:
class CellsDataset(Dataset):
    def __init__(self, files, data_path,train = True , transform=None, seed = 42, test_size = 0.1):
        self.files = files
        self.data_path = data_path
        self.transform = transform
        self.seed = seed
        self.test_size = test_size
        self.images = []
        temp = [load_json(data_path+'generated/'+ f+'.json') for f in files]
        self.data = []
        for t, file in zip(temp, files):
            images = load_images(data_path+'droplets/' + file +'.tif')
            for i in range(len(t['valid_bb'])):
                if t['valid_bb'][i] == 0 or len(t['cell'][i]) > 200:
                    continue
                self.images.append(images[i])
                t['cell'][i] = t['cell'][i] + [[-1,-1]] * (200 - len(t['cell'][i]))
                self.data.append({'frame': t['file_name'], 'bb': t['bb'][i], 'cell': t['cell'][i]})
        random.seed(self.seed)
        test_id = random.sample(range(len(self.images)), int(len(self.images)*self.test_size))


        if  train:
            self.images = [self.images[i] for i in range(len(self.images)) if i not in test_id]
            self.data = [self.data[i] for i in range(len(self.data)) if i not in test_id]
        else:
            self.images = [self.images[i] for i in range(len(self.images)) if i in test_id]
            self.data = [self.data[i] for i in range(len(self.data)) if i in test_id]
        

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        data = self.data[idx]
        image = np.array(image).astype(np.float32)
        if self.transform:
            image = self.transform(image)
    
        bb = data['bb']
        cell = data['cell']
        bb = torch.from_numpy(np.array(bb)).float()
        cell = torch.from_numpy(np.array(cell)).float()
        
        return {'frame': data['frame'], 'image': image, 'bb': bb, 'cell': cell}


In [34]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.ToPILImage(),
    transforms.Resize((448,448)),
    transforms.ToTensor()
])

In [35]:
cell = CellsDataset(files, data_path, transform=transform)

In [7]:
len(cell)

1840

In [36]:
# create a dataloader
cell_loader = DataLoader(cell, batch_size=8, shuffle=True, num_workers=0)

In [37]:
# print the first batch to see what we have
for batch in cell_loader:
    print(batch['image'].shape)
    print(batch['bb'].shape)
    print(batch['cell'].shape)
    break

torch.Size([8, 1, 448, 448])
torch.Size([8, 4])
torch.Size([8, 200, 2])


In [76]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 32, 5, padding=2)
        self.activation = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)
        self.conv4 = nn.Conv2d(64, 64, 5, padding=2)
        self.fc1 = nn.Linear(64 * 112 * 112 + 4, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 200 * 2)

    def forward(self, image, bb):
        bb = bb
        x = image
        x = self.pool(self.activation(self.conv2(self.activation(self.conv1(x)))))
        x = self.pool(self.activation(self.conv4(self.activation(self.conv3(x)))))
        x = x.view(-1, 64 * 112 * 112).squeeze(0)
        x = torch.cat((x, bb), dim=1)
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        return x.view(-1, 200, 2)
    


In [77]:
net = CNN()

In [78]:
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

net.to(device)

cpu


CNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (activation): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv4): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (fc1): Linear(in_features=802820, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=400, bias=True)
)

In [82]:
class CustomLoss(nn.Module):
    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha

    def forward(self, pred, target):
        nb_pred = torch.sum((pred[:, :, 0] >= 0) & (pred[:, :, 1] >= 0))
        nb_target = torch.sum((target[:, :, 0] >= 0) & (target[:, :, 1] >= 0))
        nb_loss = torch.abs(nb_pred - nb_target)

        pred = pred[(pred[:, :, 0] >= 0) & (pred[:, :, 1] >= 0)]
        target = target[(target[:, :, 0] >= 0) & (target[:, :, 1] >= 0)]
        dist_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)(pred, target)
        return (1 - self.alpha) * dist_loss + nb_loss * self.alpha


In [84]:
lr = 0.0001
criterion = CustomLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

# train the network


In [90]:
trainset = CellsDataset(files, data_path, transform=transform)
trainloader = DataLoader(trainset, batch_size=2, shuffle=True, num_workers=0)

testset = CellsDataset(files, data_path, train=False, transform=transform)
testloader = DataLoader(testset, batch_size=2, shuffle=True, num_workers=0)



In [91]:
nb_epoch = 10
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
losses = []
for epoch in range(nb_epoch):
    running_loss = 0.0
    for i, data in enumerate(tqdm(trainloader)):
        # get the inputs; data is a list of [inputs, labels]
        image = data['image'].to(device)
        bb = data['bb'].to(device)
        labels = data['cell'].to(device)
        bb = data['bb'].to(device)
        # zero the parameter gradients
        print("Before zero_grad")
        optimizer.zero_grad()

        outputs = net(image, bb)
        print("After output")
        loss = criterion(outputs, labels)
        print("After loss")
        loss.backward()
        print("After backward")
        optimizer.step()

        running_loss += loss.item()
        if i % 10 == 9:    # print every 100 mini-batches
            losses.append(running_loss / 10)
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0


  0%|          | 0/1035 [00:00<?, ?it/s]

Before zero_grad
After output
After loss
After backward


  0%|          | 1/1035 [01:42<29:25:21, 102.44s/it]

Before zero_grad
After output
After loss
After backward


  0%|          | 1/1035 [03:24<58:47:53, 204.71s/it]


KeyboardInterrupt: 