In [1]:
import src.dataloader
import src.loss
import src.transforms as t
import src.functional
from src.models.HCNet import HCNet

import torch
import torch.nn
from torch.utils.data import DataLoader

import time
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriter
from scipy.ndimage.morphology import binary_fill_holes, binary_erosion, binary_dilation
import skimage.io as io

In [2]:
# Build HCNet, compile it with TorchScript and move it to the GPU.
# Downstream cells read output channels 0-2 as the vector field and the last channel as sigma.
model = torch.jit.script(HCNet(in_channels=3, out_channels=4, complexity=30)).cuda()
model.train()
# Optional warm start from a previously saved state dict (currently disabled).
#model.load_state_dict(torch.load('/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/Dec1_2020_3.hcnet'))
# Presumably here so the cell ends on a statement that produces no rich output.
print('')




In [3]:
print('Loading Train...')
# Augmentation pipeline applied to every training sample, in the order listed.
# NOTE(review): the transforms after t.to_cuda() presumably operate on GPU tensors - confirm in src.transforms.
transforms = torchvision.transforms.Compose([
    t.nul_crop(),
    t.random_crop(shape=(256, 256, 23)),
    t.elastic_deformation(grid_shape=(3, 3, 2), scale=1.5),
    t.to_cuda(),
    t.random_h_flip(),
    t.random_v_flip(),
    t.random_affine(shear=(-15, 15)),
    t.adjust_brightness(range_brightness = (-0.2, 0.2)),
    t.adjust_gamma(),
    t.adjust_centroids(),
])
# NOTE(review): the dataset root is an absolute, machine-specific path.
data = src.dataloader.dataset('/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/data/train', transforms=transforms)
# One sample per batch, loaded in the main process; shuffle=False keeps the same sample order every epoch.
dl = DataLoader(data, batch_size=1, shuffle=False, num_workers=0)
print('Done')



Loading Train...


KeyboardInterrupt: 

In [None]:
print('Loading Val...')
# Validation samples only get the centroid adjustment; none of the random training augmentations are applied.
# NOTE: this rebinds `transforms`, the same name the training-data cell uses for its pipeline.
transforms = torchvision.transforms.Compose([t.adjust_centroids()])
val = src.dataloader.dataset('/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/data/validate', transforms=transforms)
# `val` is rebound from the dataset to the DataLoader that wraps it; later cells iterate over the loader.
val = DataLoader(val, batch_size=1, shuffle=False, num_workers=0)
print('Done')

In [None]:
# Hyperparameters read by the optimizer/scheduler cell and by the training loop.
# lr: learning rate passed to Adam.
lr = 1e-2
# gamma: multiplicative decay factor passed to ExponentialLR.
gamma = 0.99
# wd: weight decay passed to Adam.
wd = 0
# Earlier fixed / scheduled sigma settings, kept for reference; sigma is now taken from the model output.
#sigma= 0.1 
#sigma = lambda e: 0.0231 + 0.075 * 15/(15+e)
# iterations: second positional argument of every model(...) call in the training loop.
iterations=5

In [None]:
# Epoch budget for the run.
epochs = 500
# Epoch counter; starts at -1 because the training loop increments it before the first epoch,
# so the first logged epoch is 0. Re-running this cell resets the counter.
e = -1

In [None]:
# Adam over all model parameters, using the learning rate and weight decay from the hyperparameter cell.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
# Exponential learning-rate decay by `gamma` on every scheduler.step(); the remaining
# arguments are left at their defaults (last_epoch=-1, no verbose printing).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
# Loss object called as loss_fun(prediction, mask) by the training and evaluation cells.
loss_fun = src.loss.jaccard_loss()

In [None]:
# Push the current `lr` / `wd` values into the live optimizer.
# The optimizer reads its hyperparameters from `param_groups`; setting attributes such as
# `optimizer.lr` only creates unused attributes on the object and changes nothing.
for param_group in optimizer.param_groups:
    param_group['weight_decay'] = wd
    param_group['lr'] = lr

In [None]:
# TensorBoard writer used by the training loop to log losses and hyperparameters.
# Called without arguments, so the log directory is SummaryWriter's default.
writer = SummaryWriter()

In [None]:
# Make sure the model is in training mode before the training loop starts.
model.train()
# Presumably here so the cell ends on a statement that produces no rich output.
print('')

In [None]:
# Train / validate loop.
#
# Runs until `epochs` epochs have been completed (the loop was previously unbounded and
# ignored `epochs`). `e` lives in its own cell, so the loop can be interrupted and resumed
# without restarting the epoch count.
#
# Per epoch:
#   1. one optimisation pass over the training loader `dl`;
#   2. log mean training loss and hyperparameters to TensorBoard;
#   3. step the learning-rate scheduler;
#   4. one gradient-free pass over the validation loader `val`, logged as 'Loss/validate'.
while e + 1 < epochs:
    e += 1
    model.train()
    epoch_loss = []
    for data_dict in dl:
        image = data_dict['image']
        image = (image - 0.5) / 0.5          # rescale: 0 -> -1, 1 -> 1
        mask = data_dict['masks'] > 0.5      # binarise the instance masks
        centroids = data_dict['centroids']

        optimizer.zero_grad()

        out = model(image.cuda(), iterations)
        # The last output channel is squashed to (0, 1) and used as sigma;
        # channels 0-2 are the vector field that is turned into an embedding.
        sigma = torch.sigmoid(out[:, -1, ...])
        out = src.functional.vector_to_embedding(out[:, 0:3:1, ...])
        out = src.functional.embedding_to_probability(out, centroids.cuda(), sigma)

        loss = loss_fun(out, mask.cuda())

        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())

    writer.add_scalar('Loss/train', torch.mean(torch.tensor(epoch_loss)).item(), e)
    # get_last_lr() reports the rate the optimizer actually used; calling get_lr() outside
    # of scheduler.step() returns a value that has already been multiplied by gamma.
    writer.add_scalar('Hyperparam/lr', scheduler.get_last_lr()[0], e)
    writer.add_scalar('Hyperparam/weight_decay', wd, e)
    writer.add_scalar('Hyperparam/iter', iterations, e)

    scheduler.step()

    # Validate in eval mode so train-only layer behaviour (if the model has any) neither
    # affects the validation loss nor gets updated from validation data.
    model.eval()
    with torch.no_grad():
        val_loss = []
        for data_dict in val:
            image = data_dict['image']
            image = (image - 0.5) / 0.5
            mask = data_dict['masks'] > 0.5
            centroids = data_dict['centroids']

            out = model(image.cuda(), iterations)
            sigma = torch.sigmoid(out[:, -1, ...])
            out = src.functional.vector_to_embedding(out[:, 0:3:1, ...])
            out = src.functional.embedding_to_probability(out, centroids.cuda(), sigma)

            loss = loss_fun(out, mask.cuda())
            val_loss.append(loss.item())
        val_loss = torch.tensor(val_loss).mean()
    model.train()
    writer.add_scalar('Loss/validate', val_loss.item(), e)

    # Drop references to this epoch's tensors before the next epoch starts.
    del out, loss, image, mask, val_loss
    


In [None]:
# Quick look at the most recent `out` tensor left behind by an earlier cell.
#####plt.imshow(out.detach().cpu().numpy()[0,[3],:,:,6].transpose((1,2,0)))
# NOTE(review): Tensor.max(dim) returns (values, indices), so with this unpacking `render`
# holds the maxima over dim 1 and `values` holds the argmax indices - the names look swapped; confirm intent.
render, values = out.max(1)

# Shows slice 15 of the last axis of the argmax-index tensor.
plt.imshow(values[0,:,:,15].detach().cpu().numpy())
plt.show()

# Cell output: the largest argmax index present.
values.max()

In [None]:
# Save the current weights (state dict only) to a file in the working directory.
save_name = 'Dec3_big_sigma_1.hcnet'
torch.save(model.state_dict(), save_name)

In [None]:
# Show the learning rate most recently set by the scheduler. get_last_lr() is the supported
# accessor; get_lr() called outside scheduler.step() warns and returns an already-decayed value.
scheduler.get_last_lr()

In [None]:
# Evaluate the first validation sample only (note the `break`), print its loss,
# then export the thresholded argmax label volume as a TIFF.
model.eval()
with torch.no_grad():
    for data_dict in val:
        image = data_dict['image']
        image = (image - 0.5) / 0.5
        mask = data_dict['masks'] > 0.5
        centroids = data_dict['centroids']

        # The iteration count is hard-coded to 5 here rather than read from `iterations`.
        out = model(image.cuda(), 5)

        # Last channel -> sigma, channels 0-2 -> embedding, same as in the training loop.
        sigma = torch.sigmoid(out[:,-1,...])
        
        out = src.functional.vector_to_embedding(out[:, 0:3:1, ...])
        out = src.functional.embedding_to_probability(out, centroids.cuda(), sigma)
        loss = loss_fun(out, mask.cuda())
        
        print(loss)
        break
model.train()
print(' ')
# Per position: the maximum over dim 1 and the index at which it occurs.
value, ind = out.max(1)
# Positions whose best value is below 0.5 are set to index 0.
# NOTE(review): index 0 is also a valid argmax result, so such positions become
# indistinguishable from "dim-1 entry 0 won" - confirm this is intended.
ind[value<0.5]=0
#for i in range(render.shape[0]):
#    render[i,:,:,:] = binary_dilation(binary_fill_holes(render[i,:,:,:]))
# Drop the leading (batch) axis, convert to a float NumPy array and reverse the remaining three axes for saving.
ind = ind.cpu().detach().squeeze(0).float().numpy().transpose((2,1,0))
io.imsave('bigtest.tif', ind)
ind.min()

In [None]:
# Scratch check: threshold `out` on the CPU and score it against the mask.
value, ind = out.max(1)
out = out.cpu()
# `out` is already a CPU tensor after the line above, so the extra .cpu() calls below change nothing.
out[out.cpu()<0.5]=0

print(ind.shape, mask.shape)
# NOTE(review): arguments are passed as (mask, prediction), the reverse of the
# loss_fun(out, mask) order used by the other cells; the trailing slice drops the final
# index of the last axis. Confirm both are intended.
loss_fun(mask.cpu(), out.cpu().gt(0).unsqueeze(0)[...,0:-1:1])

In [None]:
# Scatter of embedding channel 0 against channel 1 for index 6 of the last axis.
# NOTE: in the visible notebook `embed` is only assigned by cells further down, so this
# cell depends on those having been run first.
# This bare expression is not the last statement of the cell, so nothing is displayed for it.
embed.shape
plt.figure(figsize=(10,10))
plt.plot(embed[0,0,:,:,6].detach().cpu().numpy(), embed[0,1,:,:,6].detach().cpu().numpy(),'k.',alpha=0.01)
plt.xlim([0,1])
plt.ylim([0,1])
plt.show()

In [None]:
# Show index 10 along the last axis of the label tensor `ind`.
plt.figure(figsize=(10,10))
# `ind` comes straight from out.max(1) and may live on the GPU; matplotlib cannot convert
# CUDA tensors, so move the slice to a NumPy array first.
plt.imshow(ind[0,:,:,10].detach().cpu().numpy())
plt.show()

In [None]:
# Index of the largest entry along dim 1 at every position of `out`.
render = out.argmax(1)

In [None]:
# Maximum over dim 1 (`value`) and the index at which it occurs (`ind`).
value, ind = out.max(1)

In [None]:
# Set the index to 0 wherever the best value is below 0.5. Mutates `ind` in place.
ind[value<0.5]=0

In [None]:
# Smallest index remaining in `ind` after thresholding.
ind.cpu().detach().squeeze(0).numpy().min()

In [None]:
# Drop the names so the tensors they reference can be freed.
# Cells that read `ind` or `value` must recompute them after this runs.
del ind, value

In [None]:
# Embedding scatter (channel 0 vs channel 1, index 6 of the last axis) with the estimated centroids in red.
# NOTE: `embed` and `cent` are assigned by the inference cell further down the notebook.
plt.plot(embed[0,0,:,:,6].detach().cpu().numpy(), embed[0,1,:,:,6].detach().cpu().numpy(),'k.',alpha=0.002)
# NOTE(review): the inference cell multiplies `cent` by the image extent, while the axes here
# are limited to [0, 1] - confirm the centroids are still on the embedding's scale.
plt.plot(cent[0,:,0].cpu(), cent[0,:,1].cpu(), 'ro')
plt.xlim([0,1])
plt.ylim([0,1])
plt.show()



In [None]:
%%time
model.eval()
with torch.no_grad():
    for data_dict in val:
        image = data_dict['image']
        image = (image - 0.5) / 0.5

        out = model(image.cuda(), 5)
        embed = src.functional.vector_to_embedding(out)
        cent = src.functional.estimate_centroids(embed, 0.01, 100)
        cent = cent.unsqueeze(0)
        cent[:,:,0] *= image.shape[2]
        cent[:,:,1] *= image.shape[3]
        cent[:,:,2] *= 40
        out = src.functional.embedding_to_probability(embed.cpu(), cent.cpu(), torch.tensor([0.0231]))
        break
        
value, ind = out.max(1)
ind[value<0.5]=0
correction_matrix = binary_dilation(binary_erosion(ind))
ind[np.logical_not(binary_dilation)] = 0
io.imsave('naieve_test.tif', ind.cpu().detach().squeeze(0).float().numpy().transpose((2,1,0)))
ind.min()
ind.shape

In [None]:
# Shape of the current label tensor/array.
ind.shape

In [None]:
%timeit
sigma_list = torch.linspace(0.0, 0.05, 40)
losses = []
loss_fun = src.loss.jaccard_loss()

model.eval()
for i, s in enumerate(sigma_list):
    with torch.no_grad():
        for data_dict in val:
            image = data_dict['image']
            image = (image - 0.5) / 0.5
            mask = data_dict['masks'] > 0.5
            centroids = data_dict['centroids']

            out = model(image.cuda(), 5)
            out = src.functional.vector_to_embedding(out)
            embed = out.cpu()
            out = src.functional.embedding_to_probability(out, centroids.cuda(), s)
            loss = loss_fun(out, mask.cuda())
            losses.append(loss.item())
            print(s, loss.item())
            break
        
model.train()

# plt.plot(sigma_list, losses)
# plt.ylabel('Loss')
# plt.xlabel('Sigma')
# plt.show()
# model.train()
# print(' ')

In [None]:
# Build a colour volume shaped like `image` in which every non-zero label of `ind` gets a
# pseudo-random RGB colour. The colour depends only on the label value: the RNG is seeded
# with the label before drawing.
#
# The seed/draw is done once per label and written through a boolean mask, instead of
# reseeding the global RNG for every voxel in a triple Python loop; the resulting colours
# are the same.
colors = torch.zeros(image.shape)
# Restrict the label volume to the image extent, i.e. the voxels the per-voxel loops visited.
labels = torch.as_tensor(ind)[0, :image.shape[2], :image.shape[3], :image.shape[4]].cpu()
for label in labels.unique().tolist():
    if label == 0:
        # Label 0 stays black.
        continue
    torch.manual_seed(int(label))
    # Broadcast the 3-vector over every voxel carrying this label.
    colors[0][:, labels == label] = torch.rand(3).unsqueeze(1)

In [None]:
# Overlay: image with channels 0 and 2 swapped, plus the label colours at 80% strength;
# take the first batch entry and reorder its axes (0,1,2,3) -> (3,1,2,0) for saving.
# NOTE(review): the other exports in this notebook reverse the spatial axes with
# transpose((2,1,0)); here axes 1 and 2 keep their relative order - confirm this is intended.
# NOTE: the name `re` would shadow the standard-library `re` module if that were ever imported in this notebook.
re = (image[:,[2,1,0],...] + (colors * 0.8))[0,...].numpy().transpose((3,1,2,0))

In [None]:
# Shape of the colour volume; it was created with image.shape, so it should match the image.
colors.shape

In [None]:
# Write the colour overlay to a TIFF in the working directory.
io.imsave('re.tif',re)