In [25]:
import importlib
import numpy as np
import h5py
import scipy.ndimage as sim
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import albumentations as A
import scipy.interpolate as sintp
import scipy.spatial as spat
import skimage.morphology as skmorph
num_neurons=84  # neuron slots per frame; xlocs below has num_neurons+1 entries (slot 0 is skipped there)
shape=(256,160)  # pixel grid used by get_highlight; presumably the size of a max-projected frame — TODO confirm

In [26]:
# Directory of per-frame arrays, read below as frdir+"frame_<index>.npy".
frdir="../../../Simple-Annotation-GUI/data/data_temp/zmdir/frames/"

In [27]:
# Load the annotated point tracks. The file is only read here, so open it
# read-only ("r+" needlessly requested write access); the context manager
# closes it even if the read raises.
with h5py.File("../../../Simple-Annotation-GUI/data/184.h5","r") as h5:
    pointdat=np.array(h5["pointdat"])

In [28]:
%matplotlib notebook
fig=plt.figure(figsize=(6,4))
plt.imshow(np.max(np.load(frdir+"frame_1484.npy")[0],axis=2).T)
plt.scatter(*pointdat[1484][:,:2].T,c="r",s=1)

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7f1a8ffde5e0>

In [29]:
# Per-neuron reference x-position taken from frame 1484 and min-max normalised
# to [0, 1]; get_highlight paints these values onto every frame.
ref_x=np.asarray(pointdat[1484],dtype=float)[:,0]
xlocs=np.zeros(num_neurons+1)
xlocs[1:len(ref_x)]=ref_x[1:]  # slot 0 is skipped, as in the original loop
# Normalise over the real neuron slots only. Including the never-assigned
# slot 0 (value 0) pulled nanmin down to 0 whenever all real positions are
# positive, so the values were x/max instead of spanning [0, 1].
# Slot 0 keeps the value 0 it had before.
real_x=xlocs[1:]
xlocs[1:]=(real_x-np.nanmin(real_x))/(np.nanmax(real_x)-np.nanmin(real_x))

In [45]:
def get_highlight(pts, values=None, out_shape=None, radius=8, sigma=1, min_points=5):
    """Rasterise per-point scalar values into a 2-channel highlight map.

    Each usable point ``pts[i, :2]`` carries ``values[i]`` (by default the
    global ``xlocs``). The values are interpolated over the whole pixel grid
    (linear inside the convex hull, nearest outside) and Gaussian-smoothed.
    The map is then restricted to pixels within ``radius`` of a point, with
    enclosed holes filled and the region dilated by a 5x5 square.

    Parameters
    ----------
    pts : array, shape (N, >=2)
        Point coordinates for one frame, in the same axis order as the output
        grid. Rows whose first coordinate is NaN are treated as missing.
    values : array, shape (N,), optional
        Value attached to each row of ``pts``. Defaults to the global ``xlocs``.
        Rows with a NaN value are ignored.
    out_shape : tuple of int, optional
        Output grid shape. Defaults to the global ``shape``.
    radius : float
        Pixels at distance >= ``radius`` from every point are masked out.
    sigma : float
        Gaussian smoothing width in pixels.
    min_points : int
        With fewer usable points than this, an all-zero map is returned.

    Returns
    -------
    ndarray of float, shape (2, *out_shape)
        Channel 0 is the interpolated value and channel 1 is ``1 - value``.
        Both channels are 0 outside the mask.
    """
    if values is None:
        values = xlocs
    if out_shape is None:
        out_shape = shape
    out_shape = tuple(int(s) for s in out_shape)
    pts = np.asarray(pts, dtype=float)
    values = np.asarray(values, dtype=float)

    # Usable = annotated in this frame AND has a finite value. A NaN value
    # (e.g. a neuron missing from the reference frame behind xlocs) would
    # poison the interpolation, and the blur would spread it into holes.
    usable = ~np.isnan(pts[:, 0]) & ~np.isnan(values)
    if np.count_nonzero(usable) < min_points:
        # Float zeros, so both branches return the same dtype.
        return np.zeros((2,) + out_shape, dtype=float)

    known_xy = pts[usable, :2]
    known_vals = values[usable]
    # Every pixel coordinate as an (H*W, 2) array in "ij" order, built once.
    grid_xy = np.indices(out_shape).reshape(2, -1).T

    linear = sintp.griddata(known_xy, known_vals, grid_xy, method="linear")
    # Linear interpolation is NaN outside the convex hull; fill from nearest.
    nearest = sintp.griddata(known_xy, known_vals, grid_xy, method="nearest")
    value_map = np.where(np.isnan(linear), nearest, linear).reshape(out_shape)
    value_map = sim.gaussian_filter(value_map, sigma)

    # Keep only the region around the annotations.
    dist, _ = spat.cKDTree(known_xy).query(grid_xy, k=1)
    mask = (dist < radius).reshape(out_shape)
    mask = sim.binary_fill_holes(mask)
    mask = sim.binary_dilation(mask, structure=np.ones((5, 5)))

    red = np.where(mask, value_map, 0.0)
    green = np.where(mask, 1.0 - value_map, 0.0)
    return np.stack((red, green), axis=0)

In [46]:
# Visual check of get_highlight on a random frame (channel 0 only).
plt.figure()
ii=np.random.randint(0,3002)  # upper bound matches the 3002 frames used by DynDataset
res=get_highlight(pointdat[ii])
plt.imshow(res[0].T)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f1a901df580>

In [100]:
# Project-local network definition (the UNet2d module is not shown in this notebook).
import UNet2d
dnet=UNet2d.Net(n_channels=1,num_classes=2)  # 1 input channel (the max projection) -> 2 output channels, matching get_highlight's target

In [101]:
sum(param.numel() for param in dnet.parameters())  # total number of network parameters

6819682

In [138]:
class DynDataset(torch.utils.data.Dataset):
    """Augmented pairs of max-projected frames and highlight targets.

    Item ``i`` does the following:
    - loads ``frdir + "frame_<idx>.npy"``;
    - takes the max over the last axis of its first entry and scales it by 1/255;
    - builds the target with ``get_highlight(pointdat[idx])``.

    Geometric augmentations are applied identically to the frame and the
    target. Brightness/contrast jitter is applied to the frame only.

    Parameters
    ----------
    num_frames : int, optional
        Frames ``0 .. num_frames - 1`` form the dataset (default 3002).
    """

    def __init__(self, num_frames=3002):
        super().__init__()
        self.indlist = np.arange(num_frames)
        self.num_frames_tot = len(self.indlist)
        # Joint frame + target pipeline. The target is registered as an extra
        # "image" target, so it receives exactly the same random warp.
        # NOTE(review): IAAPerspective, alpha_affine and value/mask_value are
        # from older albumentations releases; pin the version this ran with.
        # NOTE(review): MotionBlur is an image transform, so it blurs the
        # target as well — confirm that is intended.
        self.trf = A.Compose(
            [
                A.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.15, rotate_limit=180, interpolation=1, border_mode=0, value=0, p=1),
                A.OpticalDistortion(border_mode=0, p=0.7),
                A.MotionBlur(blur_limit=10, p=0.7),
                A.IAAPerspective(scale=(0.05, 0.1), keep_size=True, p=0.5),
                A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, interpolation=1, border_mode=0, value=0, mask_value=0, p=0.5),
                A.GridDistortion(num_steps=5, distort_limit=0.3, interpolation=1, border_mode=0, value=0, mask_value=0, p=0.5),
            ],
            additional_targets={"high": "image"},
        )
        # Input-only intensity jitter. In the joint pipeline it would also
        # rescale and shift the target values (brightness adds an offset to
        # the zero background), corrupting the regression labels.
        self.photo_trf = A.Compose(
            [A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.7)]
        )

    def get_trf(self, fr, high):
        """Augment one sample.

        ``fr`` is a 2-D frame. ``high`` is channel-first ``(C, H, W)``; it is
        moved to channel-last for albumentations and back afterwards.
        Returns ``(augmented_frame, augmented_high)``.
        """
        res = self.trf(image=fr, high=high.transpose(1, 2, 0))
        frame = self.photo_trf(image=res["image"])["image"]
        return frame, res["high"].transpose(2, 0, 1)

    def __getitem__(self, i):
        """Return ``[frame (1, H, W), target (2, H, W)]`` as float tensors."""
        assert 0 <= i < self.num_frames_tot
        idx = self.indlist[i]
        vol = np.load(frdir + "frame_" + str(idx) + ".npy")
        fr = (np.max(vol[0], axis=2) / 255).astype(np.float32)
        high = get_highlight(pointdat[idx])
        fr, high = self.get_trf(fr, high.astype(np.float32))
        return [torch.Tensor(fr).unsqueeze(0), torch.Tensor(high)]

    def __len__(self):
        return self.num_frames_tot

allset=DynDataset()  # training set over every frame in DynDataset.indlist
totnum=len(allset)  # number of training samples

In [139]:
batch_size=32  # samples per DataLoader batch

In [140]:
# Shuffled loader over the full training set; 20 worker processes load and
# augment samples in parallel.
traindataloader=torch.utils.data.DataLoader(allset,batch_size=batch_size,shuffle=True,num_workers=20)
num_trains=len(traindataloader)  # batches per epoch
digits=len(str(num_trains))


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    print("Cuda successful")
# The trailing ';' suppresses display of the returned module; the previous
# bare "" line displayed '' instead.
dnet.to(device=device);

Cuda successful


''

In [141]:
torch.cuda.empty_cache()  # release unused cached GPU memory held by PyTorch's allocator

In [145]:
# Live training monitor: frame / prediction / target panels, a loss curve and
# a per-pixel loss map, all refreshed by update_plot() during training.
# This cell also (re)initialises the global `losses` and `inds` histories.
%matplotlib notebook
losses=[]
inds=[]  # one flag per entry of `losses`; the training loops append 0
fig=plt.figure(figsize=(8,6))
fraxes=[]
predaxes=[]
highaxes=[]
frims=[]
predims=[]
highims=[]
# Panel 1: input frame (frames are scaled by 1/255, hence the fixed 0..1 colour range).
ax=fig.add_subplot(2,3,1)
ax.axis("off")
ax.set_title("Frame proj",fontsize=8)
im=ax.imshow(np.zeros(shape).T,vmin=0,vmax=1)
fraxes.append(ax)
frims.append(im)

# Panel 2: network prediction, later filled with an RGB image from make3dTransp.
ax=fig.add_subplot(2,3,2)
ax.axis("off")
ax.set_title("Pred proj",fontsize=8)
im=ax.imshow(np.zeros(shape).T)
predaxes.append(ax)
predims.append(im)

# Panel 3: target highlight, later filled with an RGB image from make3dTransp.
ax=fig.add_subplot(2,3,3)
ax.axis("off")
ax.set_title("high proj",fontsize=8)
im=ax.imshow(np.zeros(shape).T)
highaxes.append(ax)
highims.append(im)
    
# Panel 4: loss history on a log scale.
# NOTE(review): the lines are labelled, but lossax.legend() is never called.
lossax=fig.add_subplot(2,3,4)
lossax.set_yscale("log")
lossax.set_title("Loss",fontsize=8)
lossplots=[]
lossplots.append(lossax.plot([],[],label="Loss")[0])
lossplots.append(lossax.plot([],[],label="Val Loss")[0])

# NOTE(review): panel 5 is created but nothing is ever drawn on it.
iouax=fig.add_subplot(2,3,5)

# Panel 6: per-pixel loss map of the first sample in the batch.
lmapax=fig.add_subplot(2,3,6)
lmapim=lmapax.imshow(np.zeros(shape).T,vmin=0,vmax=1)
lmapax.axis("off")
lmapax.set_title("Loss Map",fontsize=8)
def make3dTransp(im):
    """Convert a 2-channel map ``(2, H, W)`` into a ``(W, H, 3)`` RGB image.

    Red and blue each get half of channel 0 and green gets channel 1. The
    spatial axes are swapped, matching the ``.T`` used for the other panels.
    """
    half = im[0] / 2
    rgb = np.stack((half, im[1], half), axis=2)
    return rgb.swapaxes(0, 1)
def update_plot():
    """Redraw the live training figure from the latest batch.

    Reads the globals set by the training loop: ``fr``, ``pred``, ``high`` and
    ``lossmap`` (the first sample of the batch is shown), plus the ``losses``
    and ``inds`` histories. Only the most recent ``last`` losses are plotted.
    """
    last = 300
    frims[0].set_array(fr[0, 0].cpu().detach().numpy().T)
    predims[0].set_array(make3dTransp(pred[0].cpu().detach().numpy()))
    highims[0].set_array(make3dTransp(high[0].cpu().detach().numpy()))
    lmapim.set_array(lossmap[0].cpu().detach().numpy().T)

    recent = np.asarray(losses[-last:], dtype=float)
    steps = np.arange(1, len(recent) + 1)
    # `inds` holds one flag per loss entry. Use it as a boolean mask that
    # selects validation entries. Integer fancy indexing with the flag values
    # (all 0) plotted column 0 over and over as a bogus "Val Loss" point.
    is_val = np.asarray(inds[-last:], dtype=bool)
    lossplots[0].set_data(steps, recent)
    lossplots[1].set_data(steps[is_val], recent[is_val])

    lossax.set_ylim(np.nanmin(recent) * 0.9, np.nanmax(recent) * 1.1)
    lossax.set_xlim(0, len(recent) + 1)

    fig.canvas.draw()

<IPython.core.display.Javascript object>

In [146]:
optimizer=torch.optim.Adam(dnet.parameters(),lr=0.003)  # Adam over all dnet parameters with a fixed learning rate of 3e-3

In [147]:
# Train for up to 500 epochs: per-pixel MSE between the sigmoid output and the
# highlight target, refreshing the monitor figure after every batch and
# checkpointing the weights after every epoch.
mse=torch.nn.MSELoss(reduction="none")  # element-wise squared error, built once
for epoch in range(500):
    print("Epoch:",epoch)
    dnet.train()
    n_batches=len(traindataloader)
    for i,(fr,high) in enumerate(traindataloader):
        print(f"\r{i}/{n_batches}",end="")
        fr=fr.to(device=device,dtype=torch.float32)
        high=high.to(device=device,dtype=torch.float32)
        pred=torch.sigmoid(dnet(fr))
        lossmap=mse(pred,high).mean(dim=1)  # average over channels -> per-pixel map
        losses_batch=lossmap.mean(dim=(1,2))  # per-sample loss
        loss=losses_batch.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        inds.append(0)
        update_plot()  # reads the globals fr, pred, high, lossmap bound above
    torch.save(dnet.state_dict(),"male_highlighter.pth")

Epoch: 0
93/94Epoch: 1
93/94Epoch: 2
93/94Epoch: 3
93/94Epoch: 4
93/94Epoch: 5
93/94Epoch: 6
93/94Epoch: 7
93/94Epoch: 8
93/94Epoch: 9
93/94Epoch: 10
93/94Epoch: 11
93/94Epoch: 12
93/94Epoch: 13
93/94Epoch: 14
93/94Epoch: 15
93/94Epoch: 16
93/94Epoch: 17
93/94Epoch: 18
93/94Epoch: 19
93/94Epoch: 20
93/94Epoch: 21
93/94Epoch: 22
93/94Epoch: 23
93/94Epoch: 24
93/94Epoch: 25
93/94Epoch: 26
93/94Epoch: 27
93/94Epoch: 28
93/94Epoch: 29
93/94Epoch: 30
93/94Epoch: 31
93/94Epoch: 32
93/94Epoch: 33
93/94Epoch: 34
93/94Epoch: 35
93/94Epoch: 36
93/94Epoch: 37
93/94Epoch: 38
93/94Epoch: 39
93/94Epoch: 40
93/94Epoch: 41
93/94Epoch: 42
93/94Epoch: 43
93/94Epoch: 44
93/94Epoch: 45
93/94Epoch: 46
93/94Epoch: 47
93/94Epoch: 48
93/94Epoch: 49
93/94Epoch: 50
93/94Epoch: 51
93/94Epoch: 52
93/94Epoch: 53
93/94Epoch: 54
93/94Epoch: 55
93/94Epoch: 56
93/94Epoch: 57
93/94Epoch: 58
93/94Epoch: 59
93/94Epoch: 60
93/94Epoch: 61
93/94Epoch: 62
93/94Epoch: 63
93/94Epoch: 64
93/94Epoch: 65
93/94Epoch: 66
93/94Epoch

KeyboardInterrupt: 

In [134]:
# Frame-to-frame distance matrix (indexed [frame_i, frame_j]); the next cell uses it to pick frames to add.
distmat=np.load("./distmat_zmtest.npy")

In [152]:
# Active-learning loop. Repeatedly:
#   1. train on the current set;
#   2. propose the unlabelled frames closest (under distmat) to labelled ones;
#   3. save model predictions for the best half of the proposals to
#      ./data/predhighs/;
#   4. rebuild the dataset.
# This repeats until the dataset holds 1941 frames. Presumably the rebuilt
# dataset picks up the saved predictions as new masks — confirm.
# NOTE(review): this cell does not match the visible notebook state:
#   - DynDataset.__getitem__ returns 2 items, but 3 are unpacked below;
#   - `criterion` and `verbose` are not defined in any visible cell;
#   - DynDataset.__init__ accepts no "./data/" argument;
#   - update_plot() reads `pred`/`high`, while this loop binds `preds`/`mask`.
# It looks like a leftover from an earlier version of the notebook.
# NOTE(review): `!=` never terminates if the set size skips past 1941.
while len(allset)!=1941:
    torch.cuda.empty_cache()
    print("training with",len(allset),"masks for ",max((10*200)//len(allset),2),"epoch")
    # Epoch budget shrinks as the set grows: 2000//len(allset), at least 2.
    for epoch in range(max((10*200)//len(allset),2)):
        print("Epoch:",epoch)
        dnet.train()
        print()
        for i,(fr,mask,trinds) in enumerate(traindataloader):
            print("\r"+str(i)+"/"+str(len(traindataloader)),end="")
            fr = fr.to(device=device, dtype=torch.float32)
            mask= mask.to(device=device, dtype=torch.float32)
            preds=dnet(fr)
            lossmap=criterion(preds,mask)# NOTE(review): `criterion` is undefined in visible cells; original comment: channel 0
            losses_batch=torch.mean(lossmap,dim=(1,2))
            loss=torch.mean(losses_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            inds.append(0)
            update_plot()
        print()
    print("train done")
    # props: unlabelled frame -> smallest distance to any labelled frame seen so far.
    props={}
    done={}
    for i in allset.indlist:
        done[i]=True
    # For each labelled frame, walk the other frames in order of increasing
    # distance, stopping after 5 proposals that are new or improve a distance.
    for i in allset.indlist:
        adlist=np.argsort(distmat[i])
        c=0
        for el in adlist:
            if el not in done.keys():
                if el in props.keys():
                    if distmat[i,el]<props[el]:
                        props[el]=distmat[i,el]
                        c+=1
                else:
                    props[el]=distmat[i,el]
                    c+=1
            if c==5:
                break
    # Order proposed frames by distance and keep the closest half.
    keys=[]
    vals=[]
    for key,val in props.items():
        keys.append(key)
        vals.append(val)
    keys=np.array(keys)[np.argsort(vals)]
    print("adding",len(keys)//2,"new masks:",keys[:len(keys)//2])
    for i in keys[:len(keys)//2]:
        # Max projection over axis 2 of a cropped frame ([4:-4, 7:-7]), scaled by 1/255.
        fr=np.max(np.load("./data/frames/frame_"+str(i)+".npy")[0][4:-4,7:-7],axis=2)/255
        with torch.no_grad():
            fr = torch.tensor([[fr]]).to(device=device, dtype=torch.float32)
            preds=dnet(fr,verbose=verbose).cpu().detach().numpy()
        np.save("./data/predhighs/predhigh_"+str(i)+".npy",preds[0])
    allset=DynDataset("./data/")
    totnum=len(allset)
    traindataloader=torch.utils.data.DataLoader(allset,batch_size=batch_size,shuffle=True,num_workers=20)

training with 1859 masks for  2 epoch
Epoch: 0

464/465
Epoch: 1

464/465
train done
adding 41 new masks: [ 252  248  250    6   11    1  470    2    4 1742  249   12  214  237
  276   24  273 1744   23  350   22 1743  477 1741  227   13  207  351
  213  184 1740   19  275  278  209  208  211  274   14    5  185]
training with 1900 masks for  2 epoch
Epoch: 0

474/475
Epoch: 1

474/475
train done
adding 20 new masks: [ 18  20 235  17 210 212  16  21 236  15 186 228 277 234 358 357 355 359
 229 188]
training with 1920 masks for  2 epoch
Epoch: 0

479/480
Epoch: 1

479/480
train done
adding 10 new masks: [356 230 232 187 233 231 485 479 484 478]
training with 1930 masks for  2 epoch
Epoch: 0

482/483
Epoch: 1

482/483
train done
adding 5 new masks: [ 480  483  481  482 1939]
training with 1935 masks for  2 epoch
Epoch: 0

483/484
Epoch: 1



KeyboardInterrupt: 

In [153]:
# Frame indices already in the training set (dict used as an ordered set).
done={frame_idx:True for frame_idx in allset.indlist}
len(done)

1935

In [154]:
# Save model predictions for every frame index below 1941 that is not yet in the training set.
# NOTE(review): stale cell — `verbose` is not defined in any visible cell, and
# the visible dnet is otherwise called as dnet(fr). The 2-channel training
# loop applies torch.sigmoid to dnet's output, but this saves the raw output —
# confirm which is intended.
for i in range(1941):
    if i in done:
        continue
    print(i)
    fr=np.max(np.load("./data/frames/frame_"+str(i)+".npy")[0][4:-4,7:-7],axis=2)/255
    with torch.no_grad():
        fr = torch.tensor([[fr]]).to(device=device, dtype=torch.float32)
        preds=dnet(fr,verbose=verbose).cpu().detach().numpy()
    np.save("./data/predhighs/predhigh_"+str(i)+".npy",preds[0])

1934
1935
1936
1937
1938
1940


In [155]:
# NOTE(review): `DenseNetName` and `runname` are not defined in any visible cell,
# so this raises NameError on a fresh kernel. The 2-channel training loop
# checkpoints to "male_highlighter.pth" instead.
torch.save(dnet.state_dict(),"net_"+DenseNetName+"_recent_"+runname+".pth")

In [123]:
# Number of frames added in the most recent active-learning round (closest half of `keys`).
len(keys[:len(keys)//2])

21

In [101]:
# NOTE(review): duplicate of the earlier distmat-loading cell; one of the two can be removed.
distmat=np.load("./distmat_zmtest.npy")

In [104]:
# Save predictions for the 9 frames nearest to frame 1540 under distmat. Position 0
# of the argsort is skipped — presumably frame 1540 itself (distance 0); confirm.
# NOTE(review): stale cell — `verbose` is not defined in any visible cell.
for i in np.argsort(distmat[1540])[1:10]:
    fr=np.max(np.load("./data/frames/frame_"+str(i)+".npy")[0][4:-4,7:-7],axis=2)/255
    with torch.no_grad():
        fr = torch.tensor([[fr]]).to(device=device, dtype=torch.float32)
        preds=dnet(fr,verbose=verbose).cpu().detach().numpy()
    np.save("./data/predhighs/predhigh_"+str(i)+".npy",preds[0])
    

In [85]:
# Single-frame inference check on frame 1539.
# NOTE(review): stale cell — `verbose` is not defined in any visible cell.
# update_plot() reads `pred`, `high` and `lossmap` from the training loop,
# not the `preds` computed here, so only the frame panel reflects this cell.
# This also runs without torch.no_grad().
fr=np.max(np.load("./data/frames/frame_1539.npy")[0][4:-4,7:-7],axis=2)/255
fr = torch.tensor([[fr]]).to(device=device, dtype=torch.float32)
preds=dnet(fr,verbose=verbose)
update_plot()

In [70]:
# Full training-loss history on a log scale.
loss_fig,loss_ax=plt.subplots()
loss_ax.plot(losses)
loss_ax.set_yscale("log")

<IPython.core.display.Javascript object>