In [46]:
import numpy as np
import pandas as pd
import pickle
import os
import cv2
import imageio
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
import torch
from torch.utils.data import TensorDataset
import cmath
from run_VAE import sites_NOVEMBER, RAW_NOVEMBER, INTERMEDIATE_NOVEMBER
from HiddenStateExtractor.cv2_feature import get_density, get_angle_apr, get_aspect_ratio_no_rotation
from HiddenStateExtractor.vq_vae import VQ_VAE, rescale
from SingleCellPatch.extract_patches import select_window, generate_mask
plt.switch_backend('module://ipykernel.pylab.backend_inline')

In [10]:
DATA_PATH = '/mnt/comp_micro/Projects/learningCellState/microglia/data_processed'
sites = sum([['D%d-Site_%d' % (i, j) for j in range(9)] for i in range(3, 6)], [])
window_size = 256

In [29]:
traj_dats = {}
for site in sites:
    print(site)
    image_stack = np.load(os.path.join(DATA_PATH, '%s.npy' % site))
    segmentation_stack = np.load(os.path.join(DATA_PATH, '%s_NNProbabilities.npy' % site))
    cell_pixel_assignments = pickle.load(open(os.path.join(DATA_PATH, 'D-supps', site, 'cell_pixel_assignments.pkl'), 'rb'))
    cell_positions = pickle.load(open(os.path.join(DATA_PATH, 'D-supps', site, 'cell_positions.pkl'), 'rb'))
    nonmg_traj = pickle.load(open(os.path.join(DATA_PATH, 'D-supps', site, 'non_mg_traj.pkl'), 'rb'))
    
    selected_patches = {t_point: [] for t_point in range(image_stack.shape[0])}
    for i, (traj, traj_position) in enumerate(zip(*nonmg_traj)):
        assert traj.keys() == traj_position.keys()
        for t_point in traj:
            selected_patches[t_point].append((traj[t_point], traj_position[t_point]))
    
    
    selected_patch_dats = {t_point: {} for t_point in range(image_stack.shape[0])}
    for t_point in sorted(selected_patches.keys()):
        print(t_point)
        if len(selected_patches[t_point]) == 0:
            continue
        raw_image = image_stack[t_point]
        cell_segmentation = segmentation_stack[t_point]

        positions, positions_labels = cell_pixel_assignments[t_point]
        mg_cells, non_mg_cells, other_cells = cell_positions[t_point]

        # Define fillings for the masked pixels in this slice
        background_pool = raw_image[np.where(cell_segmentation[:, :, 0] > 0.9)]
        background_pool = np.median(background_pool, 0)
        background_filling = np.ones((window_size, window_size, 1)) * background_pool.reshape((1, 1, -1))

        for cell_id, cell_position in selected_patches[t_point]:
            window = [(cell_position[0]-window_size//2, cell_position[0]+window_size//2),
                      (cell_position[1]-window_size//2, cell_position[1]+window_size//2)]
            window_segmentation = select_window(cell_segmentation, window, padding=-1)
            remove_mask, tm, tm2 = generate_mask(positions, 
                                                 positions_labels, 
                                                 cell_id, 
                                                 window, 
                                                 window_segmentation)

            output_mat = select_window(raw_image, window, padding=0)
            masked_output_mat = output_mat * (1 - remove_mask) + background_filling * remove_mask
            selected_patch_dats[t_point][cell_id] = masked_output_mat
            
    for i, (traj, traj_position) in enumerate(zip(*nonmg_traj)):
        name = site + '/' + str(i) + '_nonmg'
        print(name)
        traj_patches = {}
        for t_point in traj:
            traj_patches[t_point] = selected_patch_dats[t_point][traj[t_point]]
        traj_dats[name] = traj_patches
        
    with open('JUNE_nonmg_trajs.pkl', 'wb') as f:
        pickle.dump(traj_dats, f)

D3-Site_0
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
D3-Site_0/0_nonmg
D3-Site_0/1_nonmg
D3-Site_0/2_nonmg
D3-Site_0/3_nonmg
D3-Site_0/4_nonmg
D3-Site_0/5_nonmg
D3-Site_0/6_nonmg
D3-Site_0/7_nonmg
D3-Site_0/8_nonmg
D3-Site_0/9_nonmg
D3-Site_1
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
D3-Site_1/0_nonmg
D3-Site_1/1_nonmg
D3-Site_1/2_nonmg
D3-Site_1/3_nonmg
D3-Site_1/4_nonmg
D3-Site_1/5_nonmg
D3-Site_1/6_nonmg
D3-Site_2
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
D3-Site_2/0_nonmg
D3-Site_2/1_nonmg
D3-Site_2/2_nonmg
D3-Site_3
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
D3-Site_3/0_nonm

In [37]:
traj_motions = {}
for site in sites: 
    nonmg_traj = pickle.load(open(os.path.join(DATA_PATH, 'D-supps', site, 'non_mg_traj.pkl'), 'rb'))
    for i, (traj, traj_position) in enumerate(zip(*nonmg_traj)):
        name = site + '/' + str(i) + '_nonmg'
        traj_motions[name] = traj_position

In [48]:
tensors = {}
for t, traj_dat in traj_dats.items():
    for t_point in traj_dat:
        _d = traj_dat[t_point]
        stacks = []
        for c in range(2):
            c_slice = cv2.resize(np.array(_d[:, :, c]).astype(float), (128, 128))
            stacks.append(c_slice/65535.)
        tensors[(t, t_point)] = torch.from_numpy(np.stack(stacks, 0)).float()
fs = sorted(tensors.keys())
dataset = TensorDataset(torch.stack([tensors[f_n] for f_n in fs], 0))

phase_slice = dataset.tensors[0][:, 0]
print(phase_slice.mean())
print(phase_slice.std())
phase_slice = ((phase_slice - phase_slice.mean()) / phase_slice.std()) * 0.0257 + 0.4980
retard_slice = dataset.tensors[0][:, 1]
print(retard_slice.mean())
print(retard_slice.std())
retard_slice = retard_slice / retard_slice.mean() * 0.0285
adjusted_dataset = TensorDataset(torch.stack([phase_slice, retard_slice], 1))

adjusted_dataset = rescale(adjusted_dataset)

In [None]:
model = VQ_VAE(alpha=0.0005, gpu=False)
# model = model.cuda()
model.load_state_dict(torch.load('HiddenStateExtractor/save_0005_bkp4.pt', map_location={'cuda:0': 'cpu'}))

z_bs = {}
for i in range(len(dataset)):
    sample = dataset[i:(i+1)][0].cuda()
    z_b = model.enc(sample)
    f_n = fs[i]
    z_bs[f_n] = z_b.cpu().data.numpy()

dats = np.stack([z_bs[f] for f in fs], 0).reshape((len(dataset), -1))

pca = pickle.load(open('HiddenStateExtractor/pca_save.pkl', 'rb'))
dats_ = pca.transform(dats)