In [1]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from skimage import io
from multiprocessing import Pool
import matplotlib.pyplot as plt
import os

In [2]:
# Directory holding the Allen field-of-view (FOV) TIFF stacks.
FOV_DIR = Path('/scr/mdoron/allen/FOV/fov_path')


def list_fov_files(fov_dir):
    """Return the non-empty, non-OME ``*.tiff`` files in ``fov_dir`` as sorted path strings.

    Sorting makes positional picks such as ``FOV_files[2]`` reproducible; raw glob
    order depends on the filesystem. Only the file *name* is checked for ``'ome'``,
    so a parent directory whose name contains "ome" (e.g. ``/home/...``) does not
    filter out every file.
    """
    return sorted(
        str(tiff_path)
        for tiff_path in Path(fov_dir).glob('*.tiff')
        if 'ome' not in tiff_path.name and os.path.getsize(tiff_path) > 0
    )


FOV_files = list_fov_files(FOV_DIR)


In [3]:
# Fail fast if nothing was found: later cells index FOV_files directly (e.g. FOV_files[2])
# and would otherwise stop with a bare IndexError far from the real cause.
assert FOV_files, "No usable FOV TIFF stacks found; check the FOV directory path."
len(FOV_files)

4194

In [4]:
# Load one FOV stack (position 2 of FOV_files) as float64; the prediction cell crops it.
img = np.asarray(io.imread(FOV_files[2]), dtype=float)

In [5]:
# for i in range(4):
#     plt.figure(figsize=(20,10))
#     plt.imshow(img[[32 - 14, 32, 32 + 14],:,:,i].transpose(1,2,0), cmap='Greys')

In [8]:
from fnet.cli.predict import parse_model, load_model


In [None]:

# Load a model through fnet's helpers (parse_model / load_model, imported from fnet.cli.predict).
# `path_model_dir` and `args.gpu_ids` appear to come from fnet's command-line script and are
# not defined in any cell of this notebook, so this cell used to die with NameError.
# They are now explicit settings that must be filled in before the cell can run.
# NOTE(review): `model` is not used by any later visible cell; the prediction cell uses `Gen`.
path_model_dir = None  # TODO: set to the trained fnet model directory
gpu_ids = 0            # NOTE(review): stands in for the CLI's --gpu_ids; confirm what to_gpu expects

if path_model_dir is None:
    raise ValueError("path_model_dir is not set; point it at the trained fnet model directory")
model_def = parse_model(path_model_dir)
model = load_model(model_def["path"], no_optim=True)
model.to_gpu(gpu_ids)


In [None]:

# Build the UNet generator and restore its trained weights for inference.
# NOTE(review): `UNet` is not imported or defined in any visible cell; bring it into scope
# (presumably from the training code's model module) before running this cell.

# Settings copied from training. The network-shape values and gpuid are used below;
# ignore_index, batch_size, patch_size, edge_weight, phases and validation_phases are
# not used anywhere in this notebook and are kept only for reference.
ignore_index = 0
gpuid = 0
n_classes = 5
in_channels = 3
padding = True
depth = 6
wf = 5
up_mode = 'upconv'
batch_norm = False
batch_size = 1
patch_size = 256
edge_weight = 1.1
phases = ["train", "val"]
validation_phases = ["val"]

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    print(torch.cuda.get_device_properties(gpuid))
    torch.cuda.set_device(gpuid)
    device = torch.device(f'cuda:{gpuid}')
else:
    device = torch.device('cpu')

# Load only after the device is known: map_location remaps tensors saved on a GPU, so the
# checkpoint also loads on a CPU-only machine instead of failing inside torch.load.
checkpoint = torch.load("cell_epoch_30_GEN.pth", map_location=device)

# Define the network, restore the weights and switch to evaluation mode.
Gen = UNet(n_classes=n_classes, in_channels=in_channels, padding=padding, depth=depth,
           wf=wf, up_mode=up_mode, batch_norm=batch_norm).to(device)
print(f"total params: \t{sum(p.numel() for p in Gen.parameters())}")
Gen.load_state_dict(checkpoint['model_dict'])
Gen.eval()


In [None]:
# Predict the output channels for a 600x600 brightfield crop of `img`: run Gen over a grid
# of overlapping tiles, stitch the tiles back together, and plot each predicted channel.
import cv2  # used here but not imported anywhere else in the notebook

TILE = 256                                # network input/output tile size
STRIDE = TILE // 2                        # neighbouring tiles overlap by half a tile
N_TILES = 7                               # tiles per image axis
CANVAS = (N_TILES - 1) * STRIDE + TILE    # 1024: padded size covered by the tile grid
RESIZED = 998                             # the crop is upsampled to this size, then zero-padded
N_OUT = 5                                 # predicted channels (Gen is built with n_classes=5)

# Planes 18, 32 and 46 along axis 0 of channel 0 become the three network input
# channels -> (600, 600, 3).
subset_img = img[[32 - 14, 32, 32 + 14], :600, :600, 0].transpose(1, 2, 0).astype(float)

# Upsample, move channels first -> (3, RESIZED, RESIZED), and standardize each channel to
# zero mean / unit std. A constant channel is only centred (std treated as 1) instead of
# turning into NaNs.
resized = cv2.resize(subset_img, dsize=(RESIZED, RESIZED), interpolation=cv2.INTER_CUBIC)
resized = resized.transpose(2, 0, 1)
mean = resized.mean(axis=(1, 2), keepdims=True)
std = resized.std(axis=(1, 2), keepdims=True)
std[std == 0] = 1.0

# X: normalized 3-channel input, zero-padded from RESIZED to CANVAS on the bottom/right.
X = np.zeros((3, CANVAS, CANVAS), dtype=np.float32)
X[:, :RESIZED, :RESIZED] = (resized - mean) / std

# 0..255 integer copy of the crop, used only for display.
span = subset_img.max() - subset_img.min()
subset_display = ((subset_img - subset_img.min()) / (span if span > 0 else 1.0) * 255).astype(int)

# Run Gen on every tile of the grid; no autograd graph is needed for inference.
# Each raw tile is written into blank_fl first, and the interior is replaced below.
blank_fl = np.zeros((N_OUT, CANVAS, CANVAS))
tiles = {}
with torch.no_grad():
    for x in range(N_TILES):
        for y in range(N_TILES):
            rows = slice(x * STRIDE, x * STRIDE + TILE)
            cols = slice(y * STRIDE, y * STRIDE + TILE)
            x_in = torch.from_numpy(X[np.newaxis, :, rows, cols]).to(device)
            tiles[x, y] = Gen(x_in)[0].cpu().numpy()
            blank_fl[:, rows, cols] = tiles[x, y]

# Stitch: each interior STRIDE x STRIDE block (block row/col 1 .. N_TILES-1) is covered by
# exactly four tiles; replace it with the pixel-wise median of those four predictions.
for x in range(1, N_TILES):
    for y in range(1, N_TILES):
        overlapping = [
            tiles[x - 1, y - 1][:, STRIDE:, STRIDE:],
            tiles[x - 1, y][:, STRIDE:, :STRIDE],
            tiles[x, y - 1][:, :STRIDE, STRIDE:],
            tiles[x, y][:, :STRIDE, :STRIDE],
        ]
        blank_fl[:, x * STRIDE:(x + 1) * STRIDE, y * STRIDE:(y + 1) * STRIDE] = np.median(overlapping, axis=0)

# One figure per predicted channel, shown next to the middle input plane.
for channel in range(N_OUT):
    im = blank_fl[channel, :RESIZED, :RESIZED]
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(subset_display[:, :, 1], cmap='Greys')
    axes[0].axis('off')
    axes[0].set_title('Brightfield')
    axes[1].imshow(im, cmap='Greys_r')
    axes[1].axis('off')
    # Which structure each output channel represents is not recorded here, so label by index.
    axes[1].set_title(f'Channel {channel} (predicted)')
    fig.tight_layout()

In [None]:
# How many FOV stacks are available to choose from for the preview below.
n_fov_files = len(FOV_files)
n_fov_files

In [None]:
# Preview one field of view: a single plane of each of the four channels in a 2x2 grid.
# The plane goes into its own name instead of overwriting `img`. The prediction cell reads
# `img`, so re-running it after this cell would otherwise silently predict on this FOV.
# Only the displayed plane is converted to float rather than the whole stack.
PREVIEW_FOV = 93
PREVIEW_PLANE = 32
fov_plane = io.imread(FOV_files[PREVIEW_FOV])[PREVIEW_PLANE].astype(float)

# (row, col) of the panel, channel index along the last axis, panel title
preview_panels = [
    ((0, 0), 0, 'brightfield'),
    ((1, 0), 1, 'Structure'),
    ((0, 1), 2, 'Cell membrane'),
    ((1, 1), 3, 'DNA'),
]
fig, axes = plt.subplots(2, 2, figsize=(10 * 2, 7 * 2))
for (row, col), channel, title in preview_panels:
    ax = axes[row][col]
    ax.imshow(fov_plane[:, :, channel], cmap='Greys_r')
    ax.axis('off')
    ax.set_title(title)
fig.tight_layout()