In [1]:
import src.dataloader
import src.loss
import src.transforms as t
import src.functional
from src.models.HCNet import HCNet
from src.models.RDCNet import RDCNet

import torch
import torch.nn
from torch.utils.data import DataLoader

import time
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriter
from scipy.ndimage.morphology import binary_fill_holes, binary_erosion, binary_dilation
import skimage.io as io
from importlib import reload  

In [2]:
# Re-execute src/transforms.py so edits to it are picked up without restarting the kernel
# (transform objects already built into existing pipelines are not rebuilt by this).
reload(src.transforms)

<module 'src.transforms' from '/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/src/transforms.py'>

In [3]:
# Build HCNet, compile it with TorchScript, move it to the GPU and restore the weights
# saved by an earlier run of the training loop ('train_' + checkpoint name).
model = torch.jit.script(HCNet(in_channels=3, out_channels=4, complexity=10)).cuda()
model.train()
model.load_state_dict(torch.load('train_Dec13_1.hcnet'))
print('')  # keeps load_state_dict's return value from being displayed




In [4]:
print('Loading Train...')
# Training augmentation pipeline, handed to the dataset (presumably applied per sample).
# adjust_centroids() runs twice: after crop/deformation and again after flips/affine —
# presumably to keep centroids in sync with the geometrically transformed masks (confirm).
transforms = torchvision.transforms.Compose([
    t.nul_crop(),
    t.random_crop(shape=(256, 256, 16)),
    t.elastic_deformation(grid_shape=(3, 3, 2), scale=1.5),
    t.to_cuda(),    
    t.adjust_centroids(), 
    t.random_h_flip(),
    t.random_v_flip(),
    t.random_affine(shear=(-15, 15)),
    t.adjust_brightness(range_brightness = (-0.2, 0.2)),
    #t.adjust_gamma(),
    t.adjust_centroids(),
])
# NOTE(review): absolute, machine-specific data path; consider a configurable DATA_DIR.
data = src.dataloader.dataset('/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/data/train', transforms=transforms)
# NOTE(review): shuffle=False, so samples are visited in the same order every epoch; confirm intended.
dl = DataLoader(data, batch_size=1, shuffle=False, num_workers=0)
print('Done')



Loading Train...
Done


In [5]:
# Number of samples in the training dataset.
len(data)

15

In [None]:
print('Loading Val...')
# Validation pipeline: random (256, 256, 25) crops, no augmentation.
# Distinct names so the training pipeline's `transforms` global is not overwritten and
# `val` (read by the training loop and later cells) is only ever bound to the DataLoader.
val_transforms = torchvision.transforms.Compose([t.nul_crop(),
                                                 t.random_crop(shape=(256, 256, 25)),
                                                 t.adjust_centroids()])
val_data = src.dataloader.dataset('/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/data/validate', transforms=val_transforms)
val = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=0)
print('Done')

In [None]:
# Optimisation hyperparameters, read by the optimizer/scheduler cells and the training loop.
lr = 2e-3  # Adam learning rate
gamma = 0.993  # ExponentialLR factor; the scheduler is stepped once per epoch
wd = .01  # Adam weight_decay
#sigma= 0.1 
#sigma = lambda e: 0.0231 + 0.075 * 15/(15+e)
iterations=5  # second argument passed to model(image, iterations) during training/validation

In [None]:
# Initialise the epoch counter only on a fresh kernel. When `e` already exists (training
# was interrupted), keep it so the training loop resumes its TensorBoard step numbering.
# An explicit namespace check replaces the bare `except`, which hid every other error.
if 'e' not in globals():
    epochs = 500  # NOTE(review): not read by the training loop, which runs until interrupted
    e = -1

In [None]:
# Adam over all model parameters, plus the Tversky loss used by the training loop
# (which calls it with alpha/beta keyword arguments for the per-instance term).
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay = wd)
loss_fun = src.loss.tversky_loss()

In [None]:
# Push the (possibly re-tuned) lr / weight decay from the config cell into the optimizer.
# PyTorch optimizers read these per parameter group, so assigning attributes on the
# optimizer object itself (optimizer.lr = ...) silently had no effect.
for group in optimizer.param_groups:
    group['lr'] = lr
    group['weight_decay'] = wd

# Fresh exponential decay starting from the lr set above. `verbose` is omitted: False is
# the default, and the argument is deprecated in recent PyTorch releases.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)

In [None]:
# Checkpoint suffix used by the training loop ('train_' + save_file, 'val_' + save_file).
# Defined unconditionally: it used to be set only when `writer` already existed, so the
# first run on a fresh kernel raised NameError at the first checkpoint save.
save_file = 'Dec14_1.hcnet'

# Create the TensorBoard writer and the best-loss trackers once per kernel, so re-running
# this cell keeps logging into the same run and keeps the best losses seen so far.
if 'writer' not in globals():
    writer = SummaryWriter()
    best_train_loss = 2  # initial threshold; a mean loss below it triggers the first save
    best_val_loss = 2


In [None]:
# Switch the model back to training mode before running the training loop.
model.train()
print('')  # keeps model.train()'s return value (the module repr) from being displayed

In [None]:
# Train until interrupted. Each epoch:
# - one pass over `dl`;
# - TensorBoard logging and one scheduler step;
# - a validation pass over `val`;
# - checkpointing of the best-train and best-val weights.
# `e` persists between runs of this cell, so logged step numbers keep increasing.
sigma = torch.tensor([0.02]).cuda()  # fixed bandwidth (was torch.sigmoid(out[:,-3::,...]))

while True:
    e += 1
    epoch_loss = []
    try:
        for data_dict in dl:
            image = data_dict['image']
            image = (image - 0.5) / 0.5
            mask = data_dict['masks'] > 0.5
            centroids = data_dict['centroids']

            # Crops without any centroid are skipped.
            if centroids.shape[1] == 0:
                continue

            optimizer.zero_grad()

            out = model(image.cuda(), iterations)
            out = src.functional.vector_to_embedding(out[:, 0:3:1, ...])
            out = src.functional.embedding_to_probability(out, centroids.cuda(), sigma)

            target = mask.cuda()
            # Loss on the per-channel probabilities, plus a loss on their sum over dim 1
            # against the masks summed the same way.
            loss = loss_fun(out, target, alpha=0.5, beta=0.5) \
                   + loss_fun(out.sum(1).unsqueeze(1), target.sum(1).unsqueeze(1))

            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())

        mean_train_loss = torch.tensor(epoch_loss).mean().item()
        writer.add_scalar('Loss/train', mean_train_loss, e)
        # get_last_lr() is the rate used this epoch. get_lr() outside scheduler.step()
        # is deprecated, and for ExponentialLR it reports lr * gamma.
        writer.add_scalar('Hyperparam/lr', scheduler.get_last_lr()[0], e)
        writer.add_scalar('Hyperparam/weight_decay', wd, e)
        writer.add_scalar('Hyperparam/iter', iterations, e)

        scheduler.step()

    except Exception as err:
        # Best-effort, as before: abandon this epoch and start the next one. Catching
        # Exception (not everything) lets KeyboardInterrupt stop the loop, and the error
        # is reported instead of silently swallowed.
        print(f'epoch {e}: skipped after {type(err).__name__}: {err}')
        continue

    model.eval()
    val_losses = []
    try:
        with torch.no_grad():
            for data_dict in val:
                image = data_dict['image']
                image = (image - 0.5) / 0.5
                mask = data_dict['masks'] > 0.5
                centroids = data_dict['centroids']

                # Same rule as the training pass: crops without centroids are skipped.
                if centroids.shape[1] == 0:
                    continue

                out = model(image.cuda(), iterations)
                out = src.functional.vector_to_embedding(out[:, 0:3:1, ...])
                out = src.functional.embedding_to_probability(out, centroids.cuda(), sigma)

                loss = loss_fun(out, mask.cuda())
                val_losses.append(loss.item())
    finally:
        model.train()  # restore training mode even if validation raises
    val_loss = torch.tensor(val_losses).mean()

    writer.add_scalar('Loss/validate', val_loss.item(), e)

    # The two checkpoints are tracked independently. This used to be `elif`, so a new
    # best validation loss was never saved in an epoch that also improved the train loss.
    if mean_train_loss < best_train_loss:
        torch.save(model.state_dict(), 'train_' + save_file)
        best_train_loss = mean_train_loss

    if val_loss.item() < best_val_loss:
        torch.save(model.state_dict(), 'val_' + save_file)
        best_val_loss = val_loss.item()

    


In [None]:
# Drop references left over from the training loop (e.g. so their GPU memory can be freed).
del  image, mask, out, sigma

In [None]:
import gc

# Debug helper: print type, size and device of every gc-tracked tensor, and of every
# object whose `.data` is a tensor (e.g. to find tensors still holding GPU memory).
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(type(obj), obj.size(), obj.device)
    except Exception:
        # Some tracked objects raise on attribute access or lack .size()/.device; skip
        # them. Unlike a bare `except`, this still lets KeyboardInterrupt stop the scan.
        continue

In [None]:
# Total number of elements in the most recent `centroids` tensor.
centroids.numel()

In [None]:
# Tensor.max(dim) returns (values, indices). The original unpacked it as
# `render, values`, so the plot and the final number showed argmax indices under the
# name `values`. Now `values` is the per-voxel maximum along dim 1 and `render` the
# index attaining it.
values, render = out.max(1)

# Per-voxel maximum at slice 15 of the last axis, then its overall peak.
plt.imshow(values[0,:,:,15].detach().cpu().numpy())
plt.show()

values.max()

In [None]:
# Manually checkpoint the current weights.
# NOTE(review): spelled 'Dec_13_1' while the checkpoint loaded at the top is 'train_Dec13_1';
# confirm the intended name.
save_name = 'Dec_13_1.hcnet'
torch.save(model.state_dict(), save_name)

In [None]:
# Learning rate currently in use, one entry per param group. get_last_lr() is the supported
# accessor; get_lr() outside scheduler.step() is deprecated and, for ExponentialLR, returns
# lr * gamma once any step has run.
scheduler.get_last_lr()

In [None]:
# Qualitative check on the first validation batch:
# - report its loss with a fixed sigma;
# - binarise the probabilities;
# - save the argmax label volume to bigtest.tif.
# Inference runs in eval mode, matching the validation pass of the training loop (in
# train mode, layers such as batch-norm, if present, would update their running stats).
model.eval()
with torch.no_grad():
    for data_dict in val:
        image = data_dict['image']
        image = (image - 0.5) / 0.5
        mask = data_dict['masks'] > 0.5
        centroids = data_dict['centroids']

        out = model(image.cuda(), 5)

        sigma = torch.tensor([0.0261]).cuda()

        out = src.functional.vector_to_embedding(out[:, 0:3:1, ...])
        out = src.functional.embedding_to_probability(out, centroids.cuda(), sigma)
        loss = loss_fun(out, mask.cuda())

        print(loss)
        break  # only the first validation batch is inspected
model.train()  # restore training mode for the training cells
print(' ')

# Binarised copy of the probabilities.
# NOTE(review): dim-1 slices 0-2 are masked by out[0, -1] (the LAST slice along dim 1);
# this looks like a leftover from a layout with a probability-map channel. Confirm.
test = out.clone()
for i in range(3):
    test[0,i,...][out[0,-1,...] < 0.5] = 0
test[test > 0.5] = 1
test[test < 0.5] = 0
print(loss_fun(test, mask.cuda()))

# Label volume: argmax index k becomes label k + 1, reserving 0 for background.
# Without the shift, voxels assigned to index 0 were indistinguishable from background.
value, ind = out.max(1)
ind = ind + 1
ind[value<0.5]=0
#for i in range(render.shape[0]):
#    render[i,:,:,:] = binary_dilation(binary_fill_holes(render[i,:,:,:]))
ind = ind.cpu().detach().squeeze(0).float().numpy().transpose((2,1,0))
io.imsave('bigtest.tif', ind)
ind.min()

In [None]:
# Threshold the probabilities from the previous cell on the CPU and score them against `mask`.
value, ind = out.max(1)  # per-voxel maximum along dim 1 and the index attaining it
out = out.cpu()
out[out.cpu()<0.5]=0  # in place: `out` stays modified for every later cell

print(ind.shape, mask.shape)
# NOTE(review): arguments are (mask, prediction) here but (prediction, mask) in the other
# cells, and the prediction drops its last element along the final axis ([..., 0:-1]);
# confirm both are intended.
loss_fun(mask.cpu(), out.cpu().gt(0).unsqueeze(0)[...,0:-1:1])

In [None]:
# Scatter of embedding channel 0 vs channel 1 at slice 6 of the last axis.
# NOTE(review): `embed` is assigned by cells further down; run one of those first.
embed.shape  # not displayed: only a cell's last expression is shown
plt.figure(figsize=(10,10))
plt.plot(embed[0,0,:,:,6].detach().cpu().numpy(), embed[0,1,:,:,6].detach().cpu().numpy(),'k.',alpha=0.01)
plt.xlim([0,1])
plt.ylim([0,1])
plt.show()

In [None]:
# Show slice 10 (last axis) of batch element 0 of `ind`.
# NOTE(review): `ind` is reassigned by several cells with different types/shapes; this
# expects a 4-D array/tensor.
plt.figure(figsize=(10,10))
plt.imshow(ind[0,:,:,10])
plt.show()

In [None]:
# Per-voxel index of the largest value along dim 1 of `out`.
render = out.argmax(1)

In [None]:
# Max along dim 1: `value` holds the maxima, `ind` the indices attaining them.
value, ind = out.max(1)

In [None]:
# Zero the index wherever the maximum is below 0.5 (in place).
# NOTE(review): index 0 is also a valid argmax result, so those voxels become
# indistinguishable from the zeroed ones.
ind[value<0.5]=0

In [None]:
# Smallest value remaining in `ind`.
ind.cpu().detach().squeeze(0).numpy().min()

In [None]:
# Drop the max/argmax tensors.
del ind, value

In [None]:
# Embedding channel 0 vs 1 at last-axis slice 6, with the estimated centroids `cent` in red.
# NOTE(review): `embed`/`cent` come from later cells, and another cell divides `cent` by 512
# before plotting it against the embedding; confirm the scale here.
plt.plot(embed[0,0,:,:,6].detach().cpu().numpy(), embed[0,1,:,:,6].detach().cpu().numpy(),'k.',alpha=0.002)
plt.plot(cent[0,:,0].cpu(), cent[0,:,1].cpu(), 'ro')
plt.xlim([0,1])
plt.ylim([0,1])
plt.show()



In [None]:

# Inference without ground-truth centroids:
# - estimate centroids from the embedding;
# - turn them into per-instance probabilities;
# - clean up the label volume and save it to naieve_test.tif;
# - plot the embedding histogram.
model.eval()
with torch.no_grad():
    for data_dict in dl:
        image = data_dict['image']
        image = (image - 0.5) / 0.5

        out = model(image.cuda(), 5)
        sigma = torch.tensor([0.0261]) #out[:, -3::, ...]
#         for i in range(3):
#             out[:,i,...][out[:,-1,...] < 0.25] = -10
        out = out[:, 0:3:1, ...]
        embed = src.functional.vector_to_embedding(out)

        cent = src.functional.estimate_centroids(embed, 0.01, 100)  # 0.0081, 160 
        out = src.functional.embedding_to_probability(embed.cpu(), cent.cpu(), sigma)
# NOTE(review): this loops over every training batch, but only the last one is used below.

print(cent.shape[1], ' cells predicted')
print(out.shape)

# Label volume of shape (1, X, Y, Z). Argmax index k becomes label k + 1, so 0 is
# reserved for background; voxels whose best probability is below 0.25 become background.
value, ind = out.max(1)
ind = ind + 1
ind[value<0.25]=0

# Morphological opening of the foreground: voxels removed by erosion-then-dilation are set
# to background. It runs on the 3-D volume ind[0]. On the 4-D array, scipy's default
# structure also spans the size-1 batch axis, where the border counts as background,
# so every voxel would be eroded. The old code passed the `binary_dilation` function
# itself to np.logical_not instead of `correction_matrix`, so the correction was never
# applied.
foreground = ind[0].cpu().numpy() > 0
correction_matrix = binary_dilation(binary_erosion(foreground))
ind[0][torch.from_numpy(np.logical_not(correction_matrix)).to(ind.device)] = 0
io.imsave('naieve_test.tif', ind.cpu().detach().squeeze(0).float().numpy().transpose((2,1,0)))

# 2-D histogram of embedding channels 0/1.
# Red: estimated centroids. Blue: ground-truth centroids. Both are divided by 512.
x = embed.detach().cpu().numpy()[0,0,...].flatten()
y = embed.detach().cpu().numpy()[0,1,...].flatten()
plt.figure(figsize=(10,10))
plt.hist2d(x,y,bins=256, range=((0,.5), (0,0.5)))
plt.plot(cent[0,:,0].div(512).detach().cpu().numpy(), cent[0,:,1].div(512).detach().cpu().numpy(), 'ro')
plt.plot(data_dict['centroids'][0,:,0].cpu()/512, data_dict['centroids'][0,:,1].cpu()/512, 'bo')
plt.show()
cent.shape

In [None]:
# 2-D histogram of embedding channels 0/1, keeping only voxels where sigmoid(sigma)
# exceeds 0.005 in both of its first two channels.
# Red: estimated centroids. Green: ground-truth centroids. Both are divided by 512.
# NOTE(review): indexing sigma[0, 0, ...] / sigma[0, 1, ...] assumes a multi-channel,
# per-voxel sigma (cf. the commented-out out[:, -3::, ...]). Other cells now set sigma to a
# 1-element tensor or to an already-sigmoided single channel; confirm before running.
# NOTE(review): x/y are flattened over the whole volume, but the mask is flattened from
# sigma[0, 0, ...]; their lengths must match.
x = embed.detach().cpu().numpy()[0,0,...].flatten()
y = embed.detach().cpu().numpy()[0,1,...].flatten()
sig_x = torch.sigmoid(sigma).detach().cpu()[0,0,...].flatten()
sig_y = torch.sigmoid(sigma).detach().cpu()[0,1,...].flatten()
# NOTE(review): overwrites the `ind` label volume that other cells read.
ind = torch.logical_and(sig_x>0.005, sig_y>0.005)

plt.figure(figsize=(10,10))
plt.hist2d(x[ind], y[ind], bins=512, range=((0,0.5),(0,0.5)))

plt.plot(cent[0,:,0].div(512).detach().cpu().numpy(), cent[0,:,1].div(512).detach().cpu().numpy(), 'ro')
plt.plot(data_dict['centroids'][0,:,0].cpu()/512, data_dict['centroids'][0,:,1].cpu()/512, 'go')
plt.show()

In [None]:
# Peak raw network output over the first three output channels for the current `image`.
# no_grad: this is a one-off inspection, so no autograd graph needs to be kept on the GPU.
with torch.no_grad():
    peak_vector = model(image.cuda(), 5)[0,0:3:1,...].max()
peak_vector

In [None]:
# 10 x 10 grid over estimate_centroids' 3rd (`c`) and 2nd (`f`) arguments. For each pair,
# print the number of estimated centroids next to the current batch's ground-truth count.
cluster = torch.linspace(20, 100, 10)
factor = torch.linspace(0.01,0.03,10)
for c in cluster:
    for f in factor:
        cent = src.functional.estimate_centroids(embed, f.item(), c.item())
        print(c, f, cent.shape[1], data_dict['centroids'].shape[1])

In [None]:
%timeit
sigma_list = torch.linspace(0.0, 0.05, 40)
losses = []
loss_fun = src.loss.jaccard_loss()

model.eval()
with torch.no_grad():
    for data_dict in val:

        image = data_dict['image']

        image = (image - 0.5) / 0.5
        mask = data_dict['masks'] > 0.5
        centroids = data_dict['centroids']

        out = model(image.cuda(), 5)

        sigma = torch.sigmoid(out[:, -1, ...])
        print(image.shape, out.shape, sigma.shape)
        
        out = src.functional.vector_to_embedding(out[:, 0:3:1, ...])
        embed = out.cpu()
        out = src.functional.embedding_to_probability(out, centroids.cuda(), sigma)
        loss = loss_fun(out, mask.cuda())
        losses.append(loss.item())
        print(s, loss.item())
        break

model.train()

# plt.plot(sigma_list, losses)
# plt.ylabel('Loss')
# plt.xlabel('Sigma')
# plt.show()
# model.train()
# print(' ')

In [None]:
# Drop the evaluation tensors (e.g. to release GPU memory).
del out, loss, embed, sigma

In [None]:
# Give every labelled instance a reproducible pseudo-random RGB colour; label 0 stays black.
# The colour for label k is torch.rand(3) drawn right after seeding with k, i.e. the same
# values the original per-voxel torch.manual_seed(k) loop produced. This version:
# - works per label instead of per voxel (one boolean mask per label, not millions of
#   Python-level tensor accesses);
# - uses a private generator, so the global torch RNG is no longer reseeded as a side effect.
labels = torch.as_tensor(ind)[0].cpu()  # assumes `ind` is (1, X, Y, Z) matching image.shape[2:]
colors = torch.zeros(image.shape)
for label in torch.unique(labels).tolist():
    if label == 0:
        continue
    generator = torch.Generator().manual_seed(int(label))
    # colors[0] is a (C, X, Y, Z) view; the boolean mask selects this label's voxels.
    colors[0][:, labels == label] = torch.rand(3, generator=generator).unsqueeze(1)

In [None]:
# Overlay the instance colours (scaled by 0.8) on the image with channels in order [2, 1, 0],
# take batch element 0 and reorder its axes (C, X, Y, Z) -> (Z, X, Y, C) for saving.
# NOTE(review): `re` shadows the name of the stdlib regex module.
re = (image[:,[2,1,0],...] + (colors * 0.8))[0,...].numpy().transpose((3,1,2,0))

In [None]:
# Shape of the colour volume (same as image.shape).
colors.shape

In [None]:
# Write the colour overlay to re.tif.
io.imsave('re.tif',re)

In [None]:
# Debug: 10 passes over the training loader. After each pass, show the LAST sample's masks
# (squeezed, summed over the first axis and then the last axis).
for _ in range(10):
    for data_dict in dl:
        mask = data_dict['masks'].squeeze().sum(0).sum(-1)
    plt.imshow(mask.cpu().numpy())
    plt.show()
    # True if any nonzero pixel of the projection has a coordinate equal to 100.
    print(torch.any(torch.nonzero(mask) == 100))