In [1]:
import cv2
import numpy as np
import torch
import torch.nn as nn
from skimage import exposure
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
class CNN(nn.Module):
    """Two-convolution classifier for per-cell image crops.

    Input shape is ``(batch, 51, 64, 64)``. Both convolutions preserve the
    spatial size (3x3 kernel, stride 1, padding 1), and the single 2x2
    max-pool halves it. ``fc1`` is sized for the resulting 64 x 32 x 32
    feature map.

    Output is raw logits of shape ``(batch, 18)``. No softmax is applied;
    callers take the argmax (or apply softmax) themselves.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # Attribute names form the state_dict keys; keep them stable so saved
        # weights continue to load.
        self.conv1 = nn.Conv2d(51, 32, kernel_size=(3, 3), stride=1, padding=1)
        self.act1 = nn.ReLU()
        self.drop = nn.Dropout(0.1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=1, padding=1)
        self.act2 = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

        self.fc1 = nn.Linear(64 * 32 * 32, 256)
        self.act3 = nn.ReLU()
        self.fc2 = nn.Linear(256, 18)

    def forward(self, x):
        """Map a ``(batch, 51, 64, 64)`` tensor to ``(batch, 18)`` logits."""
        x = self.act1(self.conv1(x))
        x = self.drop(x)
        x = self.act2(self.conv2(x))
        x = self.pool(x)
        # Flatten per sample while keeping the batch axis. The previous
        # reshape(-1, 64*32*32) silently turned an input of the wrong spatial
        # size into extra "samples". This version fails loudly at fc1 instead.
        x = torch.flatten(x, start_dim=1)
        x = self.act3(self.fc1(x))
        return self.fc2(x)

In [3]:
class CellData(Dataset):
    """Map-style dataset that pairs each cell image with its label id.

    Args:
        cells: indexable collection of cell images.
        cell_inds: label ids, aligned position-by-position with ``cells``.

    Each item is a ``(label_id, image)`` tuple. The dataset length is taken
    from ``cell_inds``.
    """

    def __init__(self, cells, cell_inds):
        self.cells, self.cell_inds = cells, cell_inds

    def __len__(self):
        return len(self.cell_inds)

    def __getitem__(self, index):
        image = self.cells[index]
        return self.cell_inds[index], image

In [4]:
def normalize(X):
    """Contrast-normalize every channel of a multichannel image.

    Each slice ``X[..., i]`` is stretched to the range (-1, 1) with
    ``exposure.rescale_intensity``. It is then passed through
    ``exposure.equalize_adapthist`` (CLAHE).

    Args:
        X: array whose last axis indexes channels.

    Returns:
        A new array with the same shape as ``X``.

        Floating-point inputs keep their dtype. Any other dtype is promoted
        to float64 so the fractional results are not truncated, which
        happened when they were written back into an integer array. ``X``
        itself is never modified.
    """
    X = np.asarray(X)
    out_dtype = X.dtype if np.issubdtype(X.dtype, np.floating) else np.float64
    normalized = np.empty(X.shape, dtype=out_dtype)
    for i in tqdm(range(X.shape[-1])):
        channel = exposure.rescale_intensity(X[..., i], out_range=(-1, 1))
        normalized[..., i] = exposure.equalize_adapthist(channel)
    return normalized

In [5]:
def segment_cells(X, y, cell_size=64):
    """Crop, mask and resize every labelled region of a multichannel image.

    Args:
        X: image indexed as ``X[row, col, channel]``. It is passed through
            ``normalize`` before cropping.
        y: 2-D label mask aligned with the first two axes of ``X``. Every
            distinct value in ``y`` is treated as one region.
        cell_size: side length of the square output crop. The default of 64
            matches the 64x64 spatial input the CNN's ``fc1`` layer is sized
            for.

    Returns:
        dict mapping each label value of ``y`` to its bounding-box crop of the
        normalized image. Pixels belonging to other labels are zeroed, and the
        crop is resized with ``cv2.resize`` to ``cell_size`` x ``cell_size``.
    """
    X = normalize(X)
    y = np.asarray(y)
    cell_dict = {}
    # NOTE(review): np.unique includes any background value (commonly 0), so
    # the background is also returned as a "cell". Confirm whether it should
    # be skipped.
    # Labels come from np.unique(y), so every label has at least one pixel.
    for label in tqdm(np.unique(y)):
        rows, cols = np.where(y == label)
        # Bounding box: .max() is inclusive, hence the +1 for slicing.
        r0, r1 = rows.min(), rows.max() + 1
        c0, c1 = cols.min(), cols.max() + 1

        # Keep only this label's pixels inside the box. The mask broadcasts
        # over however many channels X has.
        cell_mask = y[r0:r1, c0:c1] == label
        cell_image = X[r0:r1, c0:c1] * cell_mask[..., np.newaxis]

        cell_dict[label] = cv2.resize(cell_image, (cell_size, cell_size))
    return cell_dict

In [7]:
def run_inference(X, y, weights_path="model_weights.pt", batch_size=32):
    """Classify every segmented cell of an image with the trained CNN.

    Args:
        X: multichannel image, forwarded to ``segment_cells``.
        y: label mask aligned with ``X``, forwarded to ``segment_cells``.
        weights_path: path of the saved ``CNN`` state_dict.
        batch_size: number of cells per forward pass.

    Returns:
        dict mapping each cell label (int) to its predicted class index (int),
        i.e. the argmax of the CNN logits.
    """
    cell_dict = segment_cells(X, y)
    cell_dataset = CellData(list(cell_dict.values()), list(cell_dict.keys()))
    # Each prediction is paired with its own label, so order is irrelevant
    # and shuffling buys nothing at inference time.
    cell_loader = DataLoader(cell_dataset, batch_size=batch_size, shuffle=False)

    model = CNN()
    model.double()
    # The model stays on the CPU, so map stored tensors there too. Without
    # this, weights saved from a GPU fail to load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()

    pred_dict = {}
    with torch.no_grad():
        for indices, inputs in tqdm(cell_loader):
            # Channels-last crops -> (batch, C, H, W). Cast to float64 to match
            # model.double() regardless of the crops' dtype.
            inputs = inputs.permute(0, 3, 1, 2).to(torch.float64)
            predicted = model(inputs).argmax(dim=1)
            pred_dict.update(zip(indices.tolist(), predicted.tolist()))
    return pred_dict
    