In [None]:
%load_ext autoreload
%autoreload 2
from labeling import dataload, show3D
from CO2_identify import *
from mynetwork import CO2mask
from torch.utils.data import DataLoader
from scipy.ndimage import gaussian_filter
from torchvision.transforms.functional import resize

import torchvision
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
from torchstat import stat

In [None]:
# Read the project path definitions. Every line of define_path.txt is expected
# to have the form `name=value`; the entries needed here are picked by their
# (0-based) line index in that file.
fn = '../define_path.txt'
with open(fn) as f:
    lines = f.readlines()
for idx, line in enumerate(lines):
    # Split on the first '=' only, so a value that itself contains '=' stays
    # intact, and strip just the line terminator: slicing off the last
    # character would eat a real character on a final line without a newline.
    if idx == 1:
        dir_co2 = line.split('=', 1)[1].rstrip('\n')
    if idx == 3:
        dir_grid = line.split('=', 1)[1].rstrip('\n')
    if idx == 13:
        root = line.split('=', 1)[1].rstrip('\n')

# NN for CO2 mask identification

## Load datasets

In [None]:
# Dataset information file names: the JSON info file, plus the stem and the
# extension from which the per-split patch table names are assembled.
pmf, pdf, pdfap = 'pm_info.json', 'patch_info', '.csv'

In [None]:
# Build the training dataset from the `<pdf>_train<pdfap>` patch table and
# report how many patches it holds.
train_patch_file = f'{pdf}_train{pdfap}'
train = dataset_patch(root, pmf, train_patch_file)
Ntrain = len(train)
print(f'Training dataset size: {Ntrain}')

In [None]:
# Build the validation dataset from the `<pdf>_valid<pdfap>` patch table and
# report how many patches it holds.
valid_patch_file = f'{pdf}_valid{pdfap}'
valid = dataset_patch(root, pmf, valid_patch_file)
Nvalid = len(valid)
print(f'Validating dataset size: {Nvalid}')

In [None]:
# Checkpoint files: `path_net` receives the weights after every epoch,
# `path_bestnet` the weights with the smallest validation loss so far.
net_dir = '../resources/NNpred2D'
path_net = f'{net_dir}/co2_identify.pt'
path_bestnet = f'{net_dir}/co2_identify_best.pt'

## Display sampled patches in training and validating dataset

In [None]:
# Display settings used while training.
ndis_tr = 5     # number of training patches sampled for display
ndis_va = 3     # number of validating patches sampled for display
epoch_itv = 20  # sampling rate of epoch for display
batch_itv = 10  # sampling rate of batch number for display

# Randomly pick (without repetition) the training patches to display.
train_id_list = np.random.choice(len(train), ndis_tr, replace=False)
pst = patch_show(train, train_id_list)
print(f'train_id_list: {train_id_list}')

# Randomly pick (without repetition) the validating patches to display.
valid_id_list = np.random.choice(len(valid), ndis_va, replace=False)
psv = patch_show(valid, valid_id_list)
print(f'valid_id_list: {valid_id_list}')

# data patch resize shape
rs = valid.nsz

In [None]:
# display the sampled training patches (no prediction is passed here)
pst.view2d()

In [None]:
# display the sampled validating patches (no prediction is passed here)
psv.view2d()

## Training and validating

### Define training parameters

In [None]:
# Training settings handed to RunBuilder.get_runs; every key holds a list of
# candidate values.
params = OrderedDict(
    lr=[.0002],        # learning rate passed to Adam
    batch_size=[30],   # batch size of the training DataLoader
    shuffle=[True],    # shuffle flag of the training DataLoader
    epoch_num=[200],   # number of epochs per run
    adadelta_num=[0],  # epoch index at which the optimizer is replaced by Adam
)
# GPU device ids handed to DataParallel, and the switch that enables GPU use
gpus = [0]
cuda_gpu = True

In [None]:
# initialize the run manager (it is given the GPU switch)
M = RunManager(cuda_gpu)

In [None]:
# define sampler for loading valid_set
#valid_sampler = SubsetSampler(valid_id_list)
# CPU device handle; in the cells shown it is only referenced by the
# disabled checkpoint-reload lines of the training cell
cpu_device = torch.device('cpu')

### Start training and validation

In [None]:
# wall-clock time stamp taken right before training starts
ts = time.time()

In [None]:
# Train one network for every hyper-parameter combination yielded by
# RunBuilder.get_runs(params). Each epoch makes a full pass over the training
# set (with weight updates) followed by a full pass over the validation set.
# The weights are written to `path_net` after every epoch, and to
# `path_bestnet` whenever the validation loss reaches a new minimum
# (epoch 0 is never saved as the best network).
for run in RunBuilder.get_runs(params):
    # initialize network
    nw = 2
    network = CO2mask()

    # inherit from previous train network
    #network.load_state_dict(torch.load(path_bestnet,map_location=cpu_device))
    #network = network.train()

    if cuda_gpu:
        network = torch.nn.DataParallel(network, device_ids=gpus).cuda()
        nw = 0
    # The checkpoints store the bare model's weights. Under DataParallel the
    # model lives in `.module`; without it the network is the model itself, so
    # unwrap conditionally instead of assuming `.module` exists.
    if isinstance(network, torch.nn.DataParallel):
        net_core = network.module
    else:
        net_core = network

    # train_set loader
    loader_train = DataLoader(
        train,
        batch_size=run.batch_size,
        shuffle=run.shuffle,
        num_workers=nw,
        drop_last=False)
    # valid_set loader (load the entire dataset as a single batch)
    loader_valid = DataLoader(
        valid,
        batch_size=Nvalid,
        shuffle=False,
        num_workers=nw,
        drop_last=False)

    # define the initial optimizer as adadelta
    optimizer = optim.Adadelta(network.parameters())

    # initialize training and validation loss lists
    TrLoss_list = []
    VaLoss_list = []
    Bloss = float('inf')  # initial smallest loss
    Bloss_epNo = 0  # epoch No. corresponding to the initial smallest loss

    # begin this training run
    M.begin_run(run, network, loader_train)
    for epoch in range(run.epoch_num):
        print(f'Epoch No.: {epoch}')
        # initialize runner
        M.begin_epoch()
        # The validation pass at the end of every epoch switches the network
        # to eval mode; switch back here, otherwise every epoch after the
        # first one would be trained in eval mode.
        network.train()

        # adjust learning rate
        #adjust_learning_rate(optimizer, epoch, run.lr)

        # initialize sampled prediction arrays
        Trpred = np.zeros((ndis_tr, rs[0], rs[1]))
        Vapred = np.zeros((ndis_va, rs[0], rs[1]))

        if epoch == run.adadelta_num:
            # redefine the optimizer as adam
            optimizer = optim.Adam(network.parameters(), run.lr)

        # loop through different batches in training dataset
        Loss = 0  # accumulated loss
        Np = 0  # number of patches
        C = 0  # batch No.
        for batch in loader_train:
            if C % batch_itv == 0:
                print(f'Batch No. {C}--------------------')
            C += 1
            R0t, Mask, idx = batch
            # find the indices of sampled training patches in current batch for later display
            bs = len(idx)
            Np += bs
            Idx = idx.tolist()
            cp = findtrace(train_id_list, Idx)
            # copy cpu data on GPU
            if cuda_gpu:
                R0t = R0t.cuda()
                Mask = Mask.cuda()
            # forward modeling
            pMask = network(R0t)

            # record the sampled training patches for later display
            for c, p in cp:
                Trpred[c] = pMask[p][0].cpu().detach().numpy()
            # training loss
            loss = F.binary_cross_entropy(pMask, Mask)
            #loss = BBCE(pMask, Mask)
            # backward for gradient
            optimizer.zero_grad()
            loss.backward()
            # update NN
            optimizer.step()
            # track the loss
            Loss = track_loss_out(Loss, loss, bs)

        # record the mean loss for the entire training dataset
        TrLoss_list.append(Loss/Np)
        print(f'Mean training loss for epoch No. {epoch}: {Loss/Np}')
        # display the sampled patch fitting in training dataset
        if (epoch % epoch_itv == 0) or (epoch == run.epoch_num-1):
            print(f'Training patch samples display at epoch No. {epoch}')
            pst.view2d(Trpred)

        # save the current-epoch training network
        torch.save(net_core.state_dict(), path_net)

        # (alternative, for cpu validating: build a fresh CO2mask() and load
        #  the weights just saved to path_net with map_location=cpu_device)

        # validate with the same network, switched to eval mode
        networkvalid = network.eval()
        print(f'Start validating for epoch No. {epoch}')
        # loop through different batches in valid dataset
        Loss = 0  # accumulated loss
        Np = 0  # number of patches
        for batch in loader_valid:
            R0t, Mask, idx = batch
            # find the indices of sampled validating patches in current batch for later display
            bs = len(idx)
            Np += bs
            Idx = idx.tolist()
            cp = findtrace(valid_id_list, Idx)
            # copy cpu data on GPU
            if cuda_gpu:
                R0t = R0t.cuda()
                Mask = Mask.cuda()
            # forward modeling
            with torch.no_grad():
                pMask = networkvalid(R0t)
            # record the sampled validating patches for later display
            for c, p in cp:
                Vapred[c] = pMask[p][0].detach().cpu().numpy()
            # valid loss
            loss = F.binary_cross_entropy(pMask, Mask)
            #loss = BBCE(pMask, Mask)
            # track the loss
            Loss = track_loss_out(Loss, loss, bs)

        # record the mean loss for the entire validating dataset
        VaLoss_list.append(Loss/Np)
        print(f'Mean validating loss for epoch No. {epoch}: {Loss/Np}')
        if epoch > 0:
            if VaLoss_list[-1] < Bloss:
                # save the currently best network (providing smallest validating loss)
                torch.save(net_core.state_dict(), path_bestnet)
                Bloss = VaLoss_list[-1]
                Bloss_epNo = epoch

        # display the validating result
        if (epoch % epoch_itv == 0) or (epoch == run.epoch_num-1):
            print(f'Validating patch samples display at epoch No. {epoch}')
            psv.view2d(Vapred)

In [None]:
# wall-clock time stamp taken right after training ends
te = time.time()

In [None]:
# elapsed wall-clock time between the two time stamps, in seconds
print(f'Training time: {te-ts} s')

In [None]:
# plot the training and validating loss
# The x axes are derived from the recorded loss lists rather than from
# `run.epoch_num`, so the curves still plot when training stopped before the
# last epoch, and the cell no longer depends on the leaked loop variable `run`.
fig, ax = plt.subplots(1, 1)
ax.plot(np.arange(len(TrLoss_list)), TrLoss_list, label='Training')
ax.plot(np.arange(len(VaLoss_list)), VaLoss_list, label='Validating')
# mark the epoch that produced the smallest validating loss
ax.plot(Bloss_epNo, Bloss, 'ro', label='Best Validating result')
ax.legend()
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_ylim(0, 0.5)

In [None]:
# write the per-epoch mean training losses as raw float32 binary
np.array(TrLoss_list,dtype=np.float32).tofile(f'../resources/NNpred2D/train_loss.dat')

In [None]:
# write the per-epoch mean validating losses as raw float32 binary
np.array(VaLoss_list,dtype=np.float32).tofile(f'../resources/NNpred2D/valid_loss.dat')

## Perform testing

In [None]:
# Read the path definitions needed for testing. Every line of define_path.txt
# is expected to have the form `name=value`; entries are picked by their
# (0-based) line index in that file.
fn = '../define_path.txt'
with open(fn) as f:
    lines = f.readlines()
for idx, line in enumerate(lines):
    # Split on the first '=' only and strip just the line terminator, so a
    # value containing '=' or a final line without a newline is read in full.
    if idx == 1:
        dir_co2 = line.split('=', 1)[1].rstrip('\n')
    if idx == 15:
        testpath = line.split('=', 1)[1].rstrip('\n')

### Define the network and basic information for testing

In [None]:
# Testing datasets; each label names a sub-folder of `testpath`.
yearlist = [
    '1999_b01_t01',
    '2001_b01_t01',
    '2004_b01_t07',
    '2006_b01_t07',
    '2008_b01_t08',
    '2010_b01_t10',
    '2010_b01_t11',
    '2010_b10_t10',
    '2010_b10_t11',
]

In [None]:
# Load the reference dataset head and keep its grid dimensions
# (nx, ny, nt) as the tuple DD.
xydfn = f'{dir_co2}/10p10/2010 processing/data/10p10nea.sgy'
Dr = dataload(fn=xydfn)
DD = tuple(getattr(Dr, dim) for dim in ('nx', 'ny', 'nt'))

In [None]:
# GPU settings and the checkpoint evaluated in the test section.
# NOTE(review): despite its name, `path_bestnet` points at co2_identify.pt
# (the file rewritten after every training epoch), not at
# co2_identify_best.pt; confirm this is the intended checkpoint.
cuda_gpu = True
gpus = [0]
path_bestnet = '../resources/NNpred2D/co2_identify.pt'

In [None]:
# Rebuild the network, restore the saved weights on the CPU, then optionally
# wrap it in DataParallel on the GPU; it is put in eval mode before and after.
networktest = CO2mask()
state_cpu = torch.load(path_bestnet, map_location=torch.device('cpu'))
networktest.load_state_dict(state_cpu)
del state_cpu
networktest.eval()
if cuda_gpu:
    networktest = torch.nn.DataParallel(networktest, device_ids=gpus).cuda()
    nw = 0
networktest.eval()

In [None]:
# Fixed dataset information file names (shared by all test years)
pmf, pdf = 'pm_info.json', 'patch_info.csv'
# DataLoader batch size for testing, and number of patches sampled for display
bs = 1000
ndis_ts = 3

In [None]:
# Reference mask dataset for training and validating: read the raw float32
# CO2 mask values and reshape them to the reference grid DD.
mkfn = '../resources/label/masks.dat'
masks = np.fromfile(mkfn, dtype=np.float32).reshape(DD)
# Slice indices for display: along each axis, the index whose slice has the
# largest mask sum.
MI, MX, MT = (np.argmax(masks.sum(axis=others))
              for others in ((1, 2), (0, 2), (0, 1)))
MIXT = (MI, MX, MT)

### Run the test on the datasets of all years

In [None]:
# Predict the CO2 mask for every test dataset in `yearlist`: run the network
# over all patches, combine the patch predictions into one volume, save both
# to the dataset's test folder and display the result. For the 2010 datasets
# the BCE loss against the reference masks is computed and saved as well.
for year in yearlist:
    print(f'Testing year: {year}')
    # load test dataset (only the 2010 datasets are loaded with masks)
    root_test = f'{testpath}/{year}/test'
    maskyear = year[:4] == '2010'
    test = dataset_patch(root_test,pmf,pdf,mask=maskyear)
    Ntest = len(test)
    print(f'Testing dataset size: {Ntest}')
    # define the sampled patches in test dataset for display
    #test_id_list = np.random.choice(len(test),size=ndis_ts,replace=False)
    # evenly spaced interior indices; int64 because int16 overflows as soon as
    # the dataset holds more than 32767 patches
    test_id_list = np.linspace(0,Ntest,ndis_ts+2,dtype=np.int64)[1:-1]
    pss = patch_show(test,test_id_list)
    print(f'test_id_list for {year}: {test_id_list}')

    # display the test patches
    #pss.view2d()

    # patch size
    rs = test.nsz
    # loop through different batches in testing dataset
    loader_test = DataLoader(
        test,
        batch_size=bs,
        drop_last=False)
    # allocate memory for testing batches
    Tepred = np.zeros((ndis_ts,rs[0],rs[1]))
    teMasks = torch.zeros((Ntest,1,rs[0],rs[1]),dtype=torch.float32)
    Np = 0 # current accumulative number of patches
    for batch in loader_test:
        if test.mask:
            R0t, _, idx = batch
        else:
            R0t, idx = batch
        # copy cpu data on GPU
        if cuda_gpu:
            R0t = R0t.cuda()
        # Size of the current batch. Kept in its own name: reusing `bs` would
        # overwrite the configured DataLoader batch size, so the next year
        # would be loaded with the length of this year's final batch.
        nb = len(idx)
        Np += nb
        # find the indices of sampled testing patches in current batch for later display
        Idx = idx.tolist()
        cp = findtrace(test_id_list,Idx)
        # forward modeling
        with torch.no_grad():
            pMask = networktest(R0t)
        # record the sampled testing patches for later display
        for c,p in cp:
            Tepred[c] = pMask[p][0].cpu().detach().numpy()
        # save pMask for final combination
        teMasks[Np-nb:Np] = pMask.detach()

    # display the sampled patch fitting in testing dataset
    pss.view2d(Tepred)
    # combine pMask; drop only the channel axis so the patch axis survives
    # even when the dataset holds a single patch
    teMasks = teMasks.squeeze(1)
    pMask_cb = patch_combine_2D(teMasks,test,ixswitch=8070)
    # save pMask
    pMask_cb.tofile(f'{root_test}/tsMask.dat')
    teMasks.numpy().tofile(f'{root_test}/ts_patchMask.dat')
    if year[:4] == '2010':
        # calculate BCE loss for 2010 data
        # (values above 1 are clipped first; the clipped volume is also the
        #  one displayed below)
        pMask_cb[pMask_cb>1.0] = 1.0
        tmp = F.binary_cross_entropy(torch.tensor(pMask_cb),torch.tensor(masks))
        np.array(tmp).tofile(f'{root_test}/BCE_loss.dat')
        print(f'The prediction BCE loss for {year} is {tmp}!')
    # display the combined prediction in 3D
    fig = plt.figure(figsize=(9,7))
    ax = fig.add_subplot(1,1,1,projection='3d')
    _ = show3D(pMask_cb,ax=ax,xyzi=(test.DD[0]//2,test.DD[1]//2,test.DD[2]//2),
               clim=[0,1],rcstride=(5,5),tl=f'Mask_pred')
    plt.show()
    # display the combined prediction in slices
    # (maximum over the first 300 samples of the last axis)
    print(f'Horizontal slice Artifact above 600 ms for {year}:')
    plt.imshow(np.max(pMask_cb[:,:,:300],axis=2),vmin=0,vmax=1,aspect=1,cmap='gray')
    plt.show()
    # (maximum over the samples from index 600 onward of the last axis)
    print(f'Horizontal slice Artifact below 1200 ms for {year}:')
    plt.imshow(np.max(pMask_cb[:,:,600:],axis=2),vmin=0,vmax=1,aspect=1,cmap='gray')
    plt.show()
    print(f'Inline assemble No. 134 for {year}:')
    plt.imshow(pMask_cb[134,:,:].T,vmin=0,vmax=1,aspect=0.3,cmap='gray')
    plt.show()
    if year[:4] == '2010':
        print(f'Reference Inline assemble No. 134 for {year}:')
        plt.imshow(masks[134,:,:].T,vmin=0,vmax=1,aspect=0.3,cmap='gray')
        plt.show()