In [None]:
import math
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td
from skimage import io

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)
os.getcwd()  # cell output: the directory that the relative data/model paths below resolve against

In [None]:
# define model class
# Default channel count of cm2net's `stack` input (presumably the number of views).
numViews = 9
# Divisor applied after every two-tensor sum (residual connections, branch fusion).
# NOTE: a truncated approximation of sqrt(2); changing it would shift the numerics
# of checkpoints already trained with this value.
sqrt2 = 1.414

class resblock(nn.Module):
    """Residual branch: conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.

    Only the branch is computed here; the skip connection is added by the
    caller (cm2netblock.forward). Output has the same channel count and
    spatial size as the input (3x3 kernels with padding 1).

    Args:
        channels: number of input and output feature channels (default 48).
    """

    def __init__(self, channels=48):
        super().__init__()
        self.channels = channels

        # Convs are bias-free because each one feeds straight into BatchNorm.
        # Modules are created in the order conv1, bn1, conv2, bn2 so seeded
        # initialisation and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
        self.bn1 = nn.BatchNorm2d(channels)

        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x1):
        hidden = F.relu(self.bn1(self.conv1(x1)), inplace=True)
        return self.bn2(self.conv2(hidden))

class cm2netblock(nn.Module):
    """One CM2Net branch: input conv -> chain of residual blocks -> output conv.

    Every residual block is wrapped in a short skip connection, and the whole
    chain in a long one; each two-tensor sum is divided by `sqrt2`.

    Args:
        inchannels: channel count of the branch input.
        numblocks: number of `resblock`s in the chain.
        outchannels: feature width of the branch. Used for the input conv's
            output, every residual block, and the output conv (default 48).
    """

    def __init__(self, inchannels, numblocks, outchannels=48):
        super(cm2netblock, self).__init__()
        self.inchannels = inchannels
        self.outchannels = outchannels
        self.numblocks = numblocks

        self.conv1 = nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1)
        torch.nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
        # Residual blocks are built at the branch width. Relying on resblock's
        # default of 48 would break any outchannels != 48.
        self.resblocks = nn.ModuleList([resblock(outchannels) for _ in range(numblocks)])
        self.conv2 = nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1)
        torch.nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')

    def forward(self, x):
        x0 = self.conv1(x)
        # No clone needed: every step below builds a new tensor, so x0 is never
        # modified in place and remains valid for the long skip connection.
        x1 = x0
        for block in self.resblocks:
            x1 = (block(x1) + x1) / sqrt2  # short residual connection
        x1 = (x1 + x0) / sqrt2  # long residual connection
        return self.conv2(x1)
            
class cm2net(nn.Module):
    """Two-branch network fusing the `stack` and `rfv` inputs.

    Each input passes through its own `cm2netblock`. The two branch outputs
    are summed, divided by `sqrt2`, and mapped to `outchannels` by a final
    3x3 conv. The output is raw logits: the training loop applies
    binary_cross_entropy_with_logits to it.

    Args:
        numBlocks: residual blocks per branch.
        stackchannels: channels of the `stack` input (default numViews).
        rfvchannels: channels of the `rfv` input (default 24).
        outchannels: channels of the output (default 24).
        hiddenchannels: feature width of both branches (default 48, the
            width previously hard-wired through cm2netblock's default).
    """

    def __init__(self, numBlocks, stackchannels=numViews, rfvchannels=24, outchannels=24,
                 hiddenchannels=48):
        super(cm2net, self).__init__()

        self.stackpath = cm2netblock(stackchannels, numblocks=numBlocks, outchannels=hiddenchannels)
        self.rfvpath = cm2netblock(rfvchannels, numblocks=numBlocks, outchannels=hiddenchannels)
        # The branches are summed, not concatenated, so the fused tensor has
        # hiddenchannels channels. Using outchannels*2 here only coincides
        # with the branch width when outchannels == 24.
        self.endconv = nn.Conv2d(hiddenchannels, outchannels, kernel_size=3, padding=1)
        torch.nn.init.kaiming_normal_(self.endconv.weight, nonlinearity='relu')

    def forward(self, stack, rfv):
        fused = (self.stackpath(stack) + self.rfvpath(rfv)) / sqrt2  # branch fusion
        return self.endconv(fused)

In [None]:
# set up dataloader
class Dataset(td.Dataset):
    """(stack, rfv, gt) training samples stored as TIFF files under one folder.

    Expected layout under ``folder``::

        stackbg/sim_meas_<i>.tif     network ``stack`` input
        rfvbg/sim_meas_<i>.tif       refocused volume with background (``rfv`` input)
        gt/sim_gt_vol_<i>.tif        ground-truth volume (training target)

    Sample ``i`` is read from files numbered ``i``, so files must be numbered
    contiguously from 0. The dataset length is the number of regular files in
    ``rfvbg``. Every image is min-max normalised to [0, 1] as float32.
    """

    def __init__(self, folder):
        super(Dataset, self).__init__()
        # Root folder. Paths are built with os.path.join, so a trailing slash
        # is optional.
        self.directory = folder

    def __len__(self):
        rfv_dir = os.path.join(self.directory, "rfvbg")
        return sum(1 for name in os.listdir(rfv_dir)
                   if os.path.isfile(os.path.join(rfv_dir, name)))

    @staticmethod
    def _normalize(img):
        """Min-max scale an image to [0, 1] and return it as a float32 tensor.

        The arithmetic stays in float32. Casting the range to an integer type
        (e.g. int16) would wrap for full-range uint16 images (65535 -> -1) and
        truncate sub-unit float ranges to 0. A constant image maps to zeros
        instead of dividing by zero.
        """
        img = np.asarray(img, dtype=np.float32)
        lo = img.min()
        span = img.max() - lo
        if span > 0:
            return torch.from_numpy((img - lo) / span)
        return torch.zeros(img.shape, dtype=torch.float32)

    def __getitem__(self, index):
        """Load sample ``index`` as a (stack, rfv, gt) tuple of float32 tensors in [0, 1]."""
        stack = self._normalize(
            io.imread(os.path.join(self.directory, "stackbg", f"sim_meas_{index}.tif")))
        rfv = self._normalize(
            io.imread(os.path.join(self.directory, "rfvbg", f"sim_meas_{index}.tif")))
        gt = self._normalize(
            io.imread(os.path.join(self.directory, "gt", f"sim_gt_vol_{index}.tif")))
        return stack, rfv, gt

class my_subset(Dataset):
    """Wrap a (stack, rfv, gt) dataset and add synthetic Poisson-Gaussian noise.

    Validation samples are returned at full size. Training samples are random
    ``patch_size`` x ``patch_size`` crops taken at the same location in all
    three tensors. Noise is added to ``stack`` and, at 1/3 of the amplitude,
    to ``rfv``; ``gt`` is left clean.

    Args:
        dataset: indexable dataset returning (stack, rfv, gt) tensors. Crops
            slice dims 1 and 2, so presumably (C, H, W) -- confirm.
        isVal: True for full-frame validation samples, False for training crops.
        patch_size: side length of the training crop (default 224).
    """

    # Noise model: variance = aa * signal + bb, with aa and bb drawn per sample
    # from normal distributions whose (mean, std) were calibrated from experiment.
    A_MEAN, A_STD = 1.49e-4, 5.7092e-5
    B_MEAN, B_STD = 5.41e-6, 2.7754e-6

    def __init__(self, dataset, isVal, patch_size=224):
        self.dataset = dataset
        self.isVal = isVal
        self.patch_size = patch_size

    @staticmethod
    def _add_noise(img, aa, bb, scale=1):
        """Add noise with std sqrt(aa*img + bb) / scale to ``img`` in place; return it.

        The variance is clamped at 0. aa and bb come from normal draws and can
        be negative (bb in roughly 2.5% of samples), and the sqrt of a negative
        variance would put NaNs into the network input.
        """
        variance = torch.clamp(aa * img + bb, min=0)
        img += torch.sqrt(variance) * torch.randn(img.shape) / scale
        return img

    def __getitem__(self, idx):
        # Per-sample noise parameters, drawn before the sample is loaded.
        aa = torch.randn(1) * self.A_STD + self.A_MEAN
        bb = torch.randn(1) * self.B_STD + self.B_MEAN
        stack, rfv, gt = self.dataset[idx]

        if not self.isVal:
            p = self.patch_size
            height, width = stack.shape[1], stack.shape[2]
            if height < p or width < p:
                raise ValueError(
                    f"sample {idx} is {height}x{width}, smaller than patch_size={p}")
            # randint's upper bound is exclusive. The +1 makes the bottom/right-most
            # crop reachable and allows height == p. Ints (not tensors) are used
            # as slice bounds.
            a = int(torch.randint(0, height - p + 1, (1,)))
            b = int(torch.randint(0, width - p + 1, (1,)))
            stack = stack[:, a:a + p, b:b + p]
            rfv = rfv[:, a:a + p, b:b + p]
            gt = gt[:, a:a + p, b:b + p]

        stack = self._add_noise(stack, aa, bb)
        rfv = self._add_noise(rfv, aa, bb, scale=3)
        return stack, rfv, gt

    def __len__(self):
        return len(self.dataset)

d_traindata = './training data/dataset11/'

# Fixed seed so the train/val partition is identical on every run. An unseeded
# split reshuffles each time, so resuming from a checkpoint would move
# validation samples into the training set.
SPLIT_SEED = 0

completedataset = Dataset(d_traindata)
train_size = int(0.8 * len(completedataset))
# Validation gets the remainder so the lengths always sum to the dataset size.
# int(0.8*n) + int(0.2*n) can fall short (e.g. n=11 -> 8+2), which random_split rejects.
test_size = len(completedataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    completedataset, [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED))
val_dataset = my_subset(val_dataset, True)
train_dataset = my_subset(train_dataset, False)

In [None]:
# train
# ---- hyperparameters ----
layers = 20              # residual blocks per cm2net branch (numBlocks)
batchNumber = 12         # mini-batch size used by the data loaders
numEpoch = 10000
lr = 1e-3                # initial Adam learning rate
notes = 'sbrnet_'        # tag inserted into the checkpoint file name
bestValLoss = math.inf   # best validation loss so far; a lower value triggers a checkpoint

net = cm2net(numBlocks=layers).to(device)
def count_parameters(model):
    """Return the total number of scalar parameters in `model` with requires_grad set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
numparam = count_parameters(net)
notessave = 'cosine annealing scheduler start at 1e-3 and relu act function'

trainloader = td.DataLoader(train_dataset, batch_size=batchNumber, shuffle=True, pin_memory=True)
# The validation loss is averaged over the whole set, so shuffling only adds
# run-to-run jitter (through the composition of the last, smaller batch).
valloader = td.DataLoader(val_dataset, batch_size=batchNumber, shuffle=False, pin_memory=True)

modelname = 'cm2net_' + notes + 'layers_' + str(layers)
d_trainedmodels = './models/'
# torch.save does not create missing directories.
os.makedirs(d_trainedmodels, exist_ok=True)

losshistory = []   # per-batch training loss
trainhistory = []  # per-epoch mean training loss
valhistory = []    # per-epoch mean validation loss
optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))

torch.backends.cudnn.benchmark = True

# T_max = 30 epochs. CosineAnnealingLR keeps following the cosine past T_max,
# so over numEpoch the LR swings between lr and 0 with a 60-epoch period.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 30)
# Loss scaling is only used with CUDA mixed precision; it is disabled explicitly on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

#### uncomment to continue training from a saved checkpoint ####
#path = os.path.join(d_trainedmodels, modelname)
#checkpoint2 = torch.load(path)
#net.load_state_dict(checkpoint2['net_state_dict'])
#optimizer.load_state_dict(checkpoint2['optimizer_state_dict'])
#if 'scheduler_state_dict' in checkpoint2:
#    scheduler.load_state_dict(checkpoint2['scheduler_state_dict'])
#if 'scaler_state_dict' in checkpoint2:
#    scaler.load_state_dict(checkpoint2['scaler_state_dict'])
#trainhistory = checkpoint2['trainloss']
#valhistory = checkpoint2['valloss']
#bestValLoss = float(min(valhistory))  # avoid overwriting a better checkpoint on the first epoch
starto = time.time()

# Mixed precision only on CUDA. Autocast already runs convolutions in float16
# and binary_cross_entropy_with_logits in float32, so inputs are not cast to
# half by hand. Manual .half() casts would also break the CPU path, where
# autocast is inactive.
use_amp = device == 'cuda'

startepoch = 0  # when resuming from a checkpoint, set to checkpoint['epoch'] + 1
for epoch in range(startepoch, numEpoch):
    s = time.time()

    # ---- training ----
    net.train()
    trainloss = 0.0
    numtrainloader = 0  # number of batches with a finite loss
    for stack, rfv, gt in trainloader:
        optimizer.zero_grad()

        # Cast explicitly to float32: the dataset can yield float64 tensors,
        # which autocast does not downcast.
        stack = stack.to(device, dtype=torch.float32)
        rfv = rfv.to(device, dtype=torch.float32)
        gt = gt.to(device, dtype=torch.float32)
        with torch.cuda.amp.autocast(enabled=use_amp):
            out = net(stack, rfv)
            loss = F.binary_cross_entropy_with_logits(out, gt)

        # With AMP enabled, GradScaler skips the optimizer step on inf/NaN gradients.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batchloss = loss.item()
        if math.isfinite(batchloss):  # non-finite batches are excluded from the mean
            losshistory.append(batchloss)
            trainloss += batchloss
            numtrainloader += 1

    if numtrainloader == 0:
        trainloss = float("nan")
    else:
        trainloss /= numtrainloader
        trainhistory.append(trainloss)

    # ---- validation ----
    net.eval()
    valloss = 0.0
    numvalloader = 0
    with torch.no_grad():
        for stack, rfv, gt in valloader:
            stack = stack.to(device, dtype=torch.float32)
            rfv = rfv.to(device, dtype=torch.float32)
            gt = gt.to(device, dtype=torch.float32)
            with torch.cuda.amp.autocast(enabled=use_amp):
                out = net(stack, rfv)
                loss = F.binary_cross_entropy_with_logits(out, gt)

            batchloss = loss.item()
            if math.isfinite(batchloss):
                valloss += batchloss
                numvalloader += 1

    if numvalloader == 0:
        valloss = float("nan")
    else:
        valloss /= numvalloader
        valhistory.append(valloss)

    t = time.time() - s
    print(f"epoch {epoch}: train {trainloss:.5f}  val {valloss:.5f}  "
          f"lr {optimizer.param_groups[0]['lr']:.2e}  {t:.1f}s")

    # Checkpoint on a new best validation loss. NaN never compares smaller, so
    # an epoch without any finite validation batch is simply not saved.
    if valloss < bestValLoss:
        tt = time.time() - starto
        bestValLoss = valloss
        torch.save({
                'epoch': epoch,
                'net_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'valloss': valhistory,
                'trainloss': trainhistory,
                'lr': lr,
                'notes': notessave,
                'dir': d_traindata,
                'time': tt,
                'batchsize': batchNumber,
                'num_network_param': numparam,
                }, os.path.join(d_trainedmodels, modelname))

    # Advance the LR schedule every epoch, including epochs with no finite validation loss.
    scheduler.step()