In [78]:
import numpy as np
import pandas as pd
import torch
import random
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from tqdm import tqdm
from torchvision.models import resnet50, ResNet50_Weights, vgg16,VGG16_Weights
import pickle 
import os
import cv2
import json
from tqdm import tqdm

### Find the list of frames

In [79]:
data_path = 'data/'
# Unique file stems (name up to the first '.') of everything in droplets/.
# Hidden entries such as '.DS_Store' are skipped: their stem would be '' and
# the dataset would then try to open 'generated/.json'.
# The stems are sorted because iterating a set of strings yields a different
# order on each interpreter run (hash randomization), which would make sample
# indices non-reproducible.
files = sorted(
    {name.split(".")[0] for name in os.listdir(data_path + 'droplets/') if not name.startswith(".")}
)


In [80]:
def load_images(file):
    """Read every page of a multi-page image file (e.g. a TIFF stack) in grayscale.

    Args:
        file: path of the image file passed to ``cv2.imreadmulti``.

    Returns:
        A 1-D ``dtype=object`` numpy array with one element per page. Each
        element is the array OpenCV decoded for that page, with its original dtype.

    Raises:
        OSError: if ``cv2.imreadmulti`` reports that it could not read ``file``.
    """
    ok, pages = cv2.imreadmulti(file, flags=cv2.IMREAD_GRAYSCALE)
    if not ok:
        raise OSError(f"cv2.imreadmulti could not read {file!r}")
    # Fill a pre-sized object array one element at a time. Calling
    # np.array(pages, dtype=object) instead would build an (n, H, W) object array
    # whenever all pages share a shape, turning each frame into an object-dtype
    # array that torch cannot convert to a tensor.
    stack = np.empty(len(pages), dtype=object)
    for index, page in enumerate(pages):
        stack[index] = page
    return stack
def load_json(file):
    """Parse and return the JSON document stored at path ``file``.

    The file is decoded as UTF-8, the encoding JSON text uses, instead of the
    platform's locale encoding. The locale encoding differs on Windows and can
    garble non-ASCII content.
    """
    with open(file, encoding="utf-8") as handle:
        return json.load(handle)

In [110]:
class CellsDataset(Dataset):
    """One sample per kept cell annotation, paired with its page of an image stack.

    For each stem in ``files``, annotations are read from
    ``<data_path>/generated/<stem>.json`` and pages from
    ``<data_path>/droplets/<stem>.tif``. Annotation index ``i`` is paired with
    page ``i``. An annotation is kept only when ``valid_bb[i]`` is non-zero and
    its ``cell`` outline has at most ``max_points`` points. Kept outlines are
    right-padded with ``[-1, -1]`` to exactly ``max_points`` points, so every
    sample's ``cell`` has the same length.

    Args:
        files: file stems (no extension) present in both sub-directories.
        data_path: dataset root holding ``generated/`` and ``droplets/``.
        transform: optional callable applied to the image in ``__getitem__``.
        seed: stored as ``self.seed``; not used by this class.
        test_size: stored as ``self.test_size``; not used by this class.
            NOTE(review): presumably intended for a train/test split — confirm.
        max_points: fixed outline length, and the size above which outlines are
            dropped. Defaults to 200, the value that was previously hard-coded.
    """

    def __init__(self, files, data_path, transform=None, seed=42, test_size=0.2, max_points=200):
        self.files = files
        self.data_path = data_path
        self.transform = transform
        self.seed = seed
        self.test_size = test_size
        self.max_points = max_points
        self.images = []  # one page per kept annotation
        self.data = []    # metadata dicts, index-aligned with self.images
        for file in files:
            self._add_stack(file)

    def _add_stack(self, file):
        """Append every kept annotation of one stem, and its page, to the dataset."""
        annotation = load_json(os.path.join(self.data_path, 'generated', file + '.json'))
        pages = load_images(os.path.join(self.data_path, 'droplets', file + '.tif'))
        for i, valid in enumerate(annotation['valid_bb']):
            if valid == 0:
                continue
            # Read the outline only after the validity check, so invalid
            # entries never index 'cell'.
            outline = annotation['cell'][i]
            if len(outline) > self.max_points:
                continue
            self.images.append(pages[i])
            # Create a fresh [-1, -1] for each slot: `[[-1, -1]] * k` repeats one
            # shared list object, so mutating one padding point would change all.
            padding = [[-1, -1] for _ in range(self.max_points - len(outline))]
            self.data.append({'frame': annotation['file_name'], 'bb': annotation['bb'][i], 'cell': outline + padding})

    def __len__(self):
        """Number of kept cell annotations."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with keys frame, image, bb and cell.

        ``image`` has ``self.transform`` applied when a transform was given.
        """
        image = self.images[idx]
        data = self.data[idx]
        if self.transform:
            image = self.transform(image)

        return {'frame': data['frame'], 'image': image, 'bb': data['bb'], 'cell': data['cell']}

In [111]:
# Build the dataset over every discovered stem; no image transform is applied.
cell = CellsDataset(files=files, data_path=data_path)

200
