In [1]:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib inline

import torch
from tqdm import tqdm
from IPython.display import clear_output
from sklearn.metrics import f1_score

sys.path.append('../..')

from seismicpro.batchflow import Pipeline, Dataset, B, V, D, C
from seismicpro.batchflow.models.torch import ResNet18, VGG7
from seismicpro.src import TraceIndex, SeismicDataset, FieldIndex, KNNIndex
from seismicpro.src.seismic_batch import (SeismicBatch,
                                            seismic_plot)
from seismicpro.batchflow.models.torch.layers import ConvBlock
%env CUDA_VISIBLE_DEVICES=2

env: CUDA_VISIBLE_DEVICES=2


In [2]:
from seismicpro.batchflow import action, inbatch_parallel

class InverseBatch_2d(SeismicBatch):
    """SeismicBatch with a polarity-inversion augmentation and hard-negative mining actions."""

    @action
    @inbatch_parallel(init='_init_component')
    def inv_traces(self, index, src, dst, p=.5):
        """Randomly flip the sign of rows of one item and record which rows were flipped.

        Parameters
        ----------
        index
            Item index supplied by ``inbatch_parallel``.
        src : str
            Component the traces are read from.
        dst : sequence of str
            ``dst[0]`` receives the (possibly flipped) traces, ``dst[1]`` receives
            an array with 1 for every flipped row and 0 otherwise.
        p : float
            Probability that an individual row is flipped.
        """
        pos = self.get_pos(None, src, index)
        item = getattr(self, src)[pos]
        # One sign per row: +1 with probability 1 - p, -1 with probability p.
        signs = np.random.choice([1, -1], size=item.shape[0], p=(1-p, p))
        getattr(self, dst[0])[pos] = item * signs[:, np.newaxis]
        # Label is 1 exactly where the sign was -1.
        getattr(self, dst[1])[pos] = (signs < 0).astype(signs.dtype)

    @action
    def update_batch(self, src, from_cont):
        """Append previously stored extra samples to the data and label components.

        Parameters
        ----------
        src : sequence of str
            ``src[0]`` is the data component, ``src[1]`` the label component.
        from_cont : indexable
            ``from_cont[0]`` is either ``None`` (nothing to add) or a sequence whose
            first two items are the extra data and extra labels.

        Returns
        -------
        self
        """
        data_name, label_name = src[0], src[1]
        data = getattr(self, data_name)
        targets = getattr(self, label_name)
        stored = from_cont[0]
        if stored is None:
            return self
        stacked_data = np.vstack((data, stored[0]))
        stacked_targets = np.vstack((targets, stored[1]))
        setattr(self, data_name, stacked_data)
        setattr(self, label_name, stacked_targets)
        return self

    @action
    def HNS(self, src, labels, preds, metric, to, n_worse=50):
        """Store the first ``n_worse`` samples ranked by ``metric`` into ``to[0]``.

        ``metric(labels, preds)`` must return sample indices; the first
        ``n_worse`` of them are kept. ``to[0]`` becomes
        ``[data, labels, sigmoid(preds)]`` restricted to those indices.

        Returns
        -------
        self
        """
        selected = metric(labels, preds)[:n_worse]
        probs = torch.nn.Sigmoid()(torch.Tensor(preds))
        to[0] = [getattr(self, src)[selected], labels[selected], probs[selected]]
        return self

def draw_res(loss, scores, labels, title=' ', last_n=300):
    """Plot the recent loss history next to per-dataset score curves.

    Parameters
    ----------
    loss : sequence of float
        Loss history; only its last ``last_n`` values are drawn.
    scores : sequence of sequences of float
        One score curve per validation set.
    labels : sequence of str
        Legend entry for each curve in ``scores`` (same order, indexed by position).
    title : str
        Title of the loss panel.
    last_n : int or None
        Length of the loss window; ``None`` or ``0`` draws the whole history.
    """
    loss = list(loss)
    start = max(len(loss) - last_n, 0) if last_n else 0
    _, (loss_ax, score_ax) = plt.subplots(1, 2, figsize=(20, 9))
    # Plot against the true step numbers: once the window slides, drawing
    # loss[-last_n:] against 0..last_n would hide how far training has got.
    loss_ax.plot(range(start, len(loss)), loss[start:], label='Loss')
    loss_ax.set_xlabel('step')
    for i, score in enumerate(scores):
        score_ax.plot(score, label=labels[i])
    loss_ax.set_title(title)
    for ax in (loss_ax, score_ax):
        ax.legend()
        ax.grid()
    plt.show()
    
def create_test_ppl(train_ppl, data, mode='w'):
    """Earlier variant of the evaluation-pipeline builder.

    NOTE(review): this definition is shadowed by the second ``create_test_ppl``
    defined later in this same cell, so running the cell never uses it. It also
    calls ``preprocess_component`` / ``preprocess_answers``, which the
    ``InverseBatch_2d`` class in this cell does not define (they may come from
    ``SeismicBatch`` -- TODO confirm, or delete this copy).
    """
    pipeline = Pipeline().load(components='raw', fmt='segy')
    pipeline = pipeline.standardize(src='raw', dst='raw')
    pipeline = pipeline.inv_traces(src='raw', dst=['raw', 'labels'], p=0.0026)
    pipeline = pipeline.import_model('model', train_ppl, 'model')
    pipeline = pipeline.preprocess_component(src='raw', dst='raw')
    pipeline = pipeline.preprocess_answers(src='labels', dst='labels')
    pipeline = pipeline.init_variable('pred', init_on_each_run=list())
    pipeline = pipeline.init_variable('labels', init_on_each_run=list())
    pipeline = pipeline.update_variable('labels', B('labels'), mode=mode)
    pipeline = pipeline.predict_model('model', B('raw'), fetches='predictions',
                                      save_to=V('pred', mode=mode))
    return pipeline << data

def get_results(ppl):
    """Collect thresholded predictions and labels stored by an evaluation pipeline.

    ``ppl.v('pred')`` holds raw logits and ``ppl.v('labels')`` the targets, either
    as a single array or as a sequence of per-batch arrays (e.g. when the
    pipeline updates them with ``mode='a'``). Both are flattened to 1-D in the
    same order.

    Returns
    -------
    preds : numpy.ndarray of int
        1 where ``sigmoid(logit) > 0.5``, else 0.
    labels : numpy.ndarray
        Flattened labels.
    """
    def _flatten(values):
        # Concatenate per-batch chunks with NumPy: building a torch tensor
        # straight from a list of ndarrays is very slow and fails on ragged batches.
        chunks = [np.ravel(chunk) for chunk in values]
        return np.concatenate(chunks) if chunks else np.array([])

    logits = _flatten(ppl.v('pred')).astype(np.float32)
    probs = torch.sigmoid(torch.from_numpy(logits)).numpy()
    preds = (probs > .5).astype(int)
    labels = _flatten(ppl.v('labels'))
    return preds, labels

def create_test_ppl(train_ppl, data, mode='w', p=0.0026, n_samples=751):
    """Build an evaluation pipeline that reuses the model of ``train_ppl``.

    The pipeline loads SEG-Y traces, standardizes them, randomly inverts traces
    (``inv_traces`` writes the 0/1 inversion flags to ``labels``), converts both
    components to float32 arrays, and records labels and model predictions in
    the pipeline variables ``'labels'`` and ``'pred'``.

    Parameters
    ----------
    train_ppl : Pipeline
        Pipeline that owns the model named ``'model'``.
    data : Dataset
        Dataset the returned pipeline is bound to.
    mode : str
        Update mode for the ``'labels'`` / ``'pred'`` variables.
    p : float
        Probability that ``inv_traces`` inverts a trace.
    n_samples : int or None
        Load ``tslice=np.arange(n_samples)`` so evaluation sees the same time
        window as the training pipeline; ``None`` loads full traces.

    Returns
    -------
    Pipeline
        The evaluation pipeline bound to ``data``.
    """
    load_kwargs = dict(components='raw', fmt='segy')
    if n_samples is not None:
        # Training loads tslice=np.arange(751); evaluate on the same window.
        load_kwargs['tslice'] = np.arange(n_samples)
    test_ppl = (Pipeline().load(**load_kwargs)
                .standardize(src='raw', dst='raw')
                .inv_traces(src='raw', dst=['raw', 'labels'], p=p)
                .import_model('model', train_ppl, 'model')
                .apply_transform_all(src='raw', dst='raw',
                                     func=lambda x: np.expand_dims(np.stack(x), axis=1).astype(np.float32))
                .apply_transform_all(src='labels', dst='labels',
                                     func=lambda x: np.stack(x).astype(np.float32))
                # Pass the `list` callable (as train_ppl does) so each run starts
                # with a fresh list; a single `list()` instance may be reused
                # between runs and keep accumulating results under mode='a'.
                .init_variable('pred', init_on_each_run=list)
                .init_variable('labels', init_on_each_run=list)
                .update_variable('labels', B('labels'), mode=mode)
                .predict_model('model', B('raw'), fetches='predictions',
                               save_to=V('pred', mode=mode))) << data
    return test_ppl

In [3]:
N_NEIGH = 2
N_TRACES = 100000  # index entries kept from the start of each file
DATA_DIR = '/data/FB'
pal_path = DATA_DIR + '/dataset_1/Pal_Flatiron_1k.sgy'
wz_path = DATA_DIR + '/dataset_2/WZ_Flatiron_1k.sgy'
vor_path = DATA_DIR + '/dataset_6/3_FBP_input_ffid_raw-500_off-800.sgy'


def make_knn_datasets(path, n_traces=N_TRACES, n_neighbors=N_NEIGH):
    """Build a KNN trace index for a SEG-Y file and split it into train/test datasets.

    Parameters
    ----------
    path : str
        SEG-Y file to index.
    n_traces : int
        Only the first ``n_traces`` entries of the index are kept.
    n_neighbors : int
        Forwarded to ``KNNIndex`` as ``n_neighbors``.

    Returns
    -------
    tuple
        ``(index, train_dataset, test_dataset)``; both datasets use
        ``InverseBatch_2d`` as their batch class.
    """
    index = KNNIndex(name='raw', path=path, extra_headers=['offset'], n_neighbors=n_neighbors)
    index = index.create_subset(index.indices[:n_traces])
    index.split()
    return index, Dataset(index.train, InverseBatch_2d), Dataset(index.test, InverseBatch_2d)


pal_index, pal_data_tr, pal_data_te = make_knn_datasets(pal_path)
wz_index, wz_data_tr, wz_data_te = make_knn_datasets(wz_path)
vor_index, vor_data_tr, vor_data_te = make_knn_datasets(vor_path)

In [4]:
# NumPy arrays rather than lists so the training loop can pick a subset of
# datasets with an integer index array (dsts_tr[ixs]). Order: pal, wz, vor.
train_sets = [pal_data_tr, wz_data_tr, vor_data_tr]
test_sets = [pal_data_te, wz_data_te, vor_data_te]
dsts_tr, dsts_te = np.array(train_sets), np.array(test_sets)

In [5]:
from seismicpro.batchflow import L, I
def inversion_prob(n_iter, n_iter_cfg):
    """Probability of inverting a trace at training step ``n_iter``.

    ``n_iter_cfg`` is the pipeline config value ``C('n_iter')``. The result is
    0.5 at step 0 (for a non-zero ``n_iter_cfg``) and shrinks as
    ``n_iter ** 1.4`` grows relative to ``n_iter_cfg``.
    """
    return n_iter_cfg / (2 * n_iter_cfg + n_iter ** 1.4)


# Shared preprocessing: load traces with tslice=np.arange(751), standardize them,
# invert random traces (with probability given by `inversion_prob`, labels mark
# the inverted ones), count steps, and convert both components to float32 arrays.
prep_ppl = (
    Pipeline()
    .load(components='raw', fmt='segy', tslice=np.arange(751))
    .standardize(src='raw', dst='raw')
    .init_variable('prob', init_on_each_run=float)
    .init_variable('n_iter', init_on_each_run=0)
    .update_variable('prob', L(inversion_prob)(V('n_iter'), C('n_iter')), mode='w')
    .inv_traces(src='raw', dst=['raw', 'labels'], p=V('prob'))
    .update_variable('n_iter', V('n_iter') + 1)
    .apply_transform_all(src='raw', dst='raw',
                         func=lambda x: np.expand_dims(np.stack(x), axis=1).astype(np.float32))
    .apply_transform_all(src='labels', dst='labels',
                         func=lambda x: np.stack(x).astype(np.float32))
)

In [7]:
SIZE = 751  # time samples per trace; matches tslice=np.arange(751) in prep_ppl
inputs_config = {
    'raw': {'shape': (1, N_NEIGH, SIZE)},
    'masks': {'shape': (N_NEIGH, )},
}

config = {
    'loss': 'bce',
    'inputs': inputs_config,
    'initial_block/inputs': 'raw',
    'optimizer': 'Adam',
    # Dense ('f') layer width is set by `units`: one logit per trace of a sample,
    # matching the N_NEIGH binary labels produced by inv_traces.
    'head': dict(layout='Vf', units=N_NEIGH),
    'n_iters': D('size')/B('size'),
    'decay': dict(name='exp', gamma=0.99),
    'device': 'gpu:0',
}

In [8]:
def metric(labels, preds, n_neighbors=None):
    """Rank the samples of a batch from most to least misclassified.

    A trace counts as wrong when ``sigmoid(logit) > 0.5`` disagrees with its label.
    Samples are ordered by how many of their traces are wrong (most first); ties
    keep batch order. Every sample index appears exactly once, so callers that
    take the first ``k`` entries get ``k`` distinct samples.

    Parameters
    ----------
    labels : array-like
        Target flags; presumably shaped ``(batch, n_neighbors)``.
    preds : array-like
        Raw logits with the same number of elements as ``labels``.
    n_neighbors : int, optional
        Traces per sample. Defaults to ``labels.shape[-1]`` for labels with more
        than one dimension, and to the global ``N_NEIGH`` otherwise.

    Returns
    -------
    numpy.ndarray of int
        Sample indices, most misclassified first.
    """
    labels = np.asarray(labels)
    if n_neighbors is None:
        n_neighbors = labels.shape[-1] if labels.ndim > 1 else N_NEIGH
    probs = torch.sigmoid(torch.Tensor(np.asarray(preds))).numpy()
    hard_preds = (probs > .5).astype(int).ravel()
    wrong = hard_preds != labels.ravel()
    # Aggregate per sample instead of ranking flattened traces, which would
    # list a sample twice when several of its traces are wrong.
    errors_per_sample = wrong.reshape(-1, n_neighbors).sum(axis=1)
    return np.argsort(-errors_per_sample, kind='stable').astype(int)

In [9]:
# Options for restoring the model saved under './decrease_prob' instead of
# building a new one; not referenced anywhere in this cell.
load_config = {
    'load': {'path': './decrease_prob'},
    'build': False,
    'device': 'gpu:0',
}

# One-slot container for hard-negative mining: `HNS` writes the selected
# samples of the current batch into batches[0], and `update_batch` appends them
# to the next batch it processes. NOTE(review): presumably an ndarray rather
# than a list so the pipeline hands this exact object to both actions instead
# of an evaluated copy -- confirm before changing the container type.
batches = np.array([None])

training_steps = (
    Pipeline()
    .init_model('dynamic', ResNet18, 'model', config)
    .update_batch(['raw', 'labels'], from_cont=batches)
    .init_variable('loss', init_on_each_run=list)
    .init_variable('pred', init_on_each_run=list)
    .init_variable('labels', init_on_each_run=list)
    .update_variable('labels', B('labels'), mode='w')
    .train_model('model', B('raw'), B('labels'),
                 fetches=['loss', 'predictions'],
                 save_to=[V('loss', mode='a'), V('pred', mode='w')])
    .HNS('raw', V('labels'), V('pred'), n_worse=50, metric=metric, to=batches)
)
train_ppl = prep_ppl + training_steps

model_init


In [10]:
from itertools import combinations

In [11]:
# Every non-empty subset of the three datasets (1-based ids), grouped by subset size.
comb = [list(combinations(range(1, 4), size)) for size in (1, 2, 3)]

In [12]:
# Flatten the size groups and convert each subset to 0-based positions,
# usable as integer index arrays into dsts_tr / dsts_te.
new_comb = [np.array(subset) - 1 for group in comb for subset in group]

### Training on every combination of datasets

In [None]:
N_EPOCHS = 151
N_ITER = 200
B_SIZE = 400
VAL_NAMES = ['pal', 'wz', 'vor']  # must follow the order of dsts_te
ppl_config = dict(n_iter=N_ITER)
global_scores = []

for ixs in tqdm(new_comb):
    train_numb = dsts_tr[ixs]
    if not isinstance(train_numb, np.ndarray):
        train_numb = np.array([train_numb])
    # Forget hard negatives mined during the previous combination; otherwise the
    # first batch of this run is padded with samples from the previous run's
    # datasets, which may not belong to the current combination.
    batches[0] = None
    ppl_train = (train_ppl << ppl_config)
    ppl_train = ppl_train.run_later(B_SIZE, shuffle=True, drop_last=True)
    ppl_train._init_all_variables()  # alternative: reset('vars')
    ppl_val = [create_test_ppl(ppl_train, data, mode='a') for data in dsts_te]
    scores = [[] for _ in range(len(ppl_val))]

    for epoch in tqdm(range(N_EPOCHS)):
        for _ in range(N_ITER):
            for dst in train_numb:
                batch = dst.next_batch(B_SIZE, shuffle=True, drop_last=True)
                ppl_train.execute_for(batch)
        # Exactly one checkpoint per epoch, named by the epoch counter (a
        # separate name is used for the validation loop below so it cannot
        # clobber `epoch`).
        ppl_train.save_model_now('model', './models/{}/model_{}_ep_{}'.format(ixs, ixs, epoch))
        for val_idx, ppl_v in enumerate(ppl_val):
            # NOTE(review): only one batch per validation set is scored each epoch.
            ppl_v.run(B_SIZE, shuffle=True, drop_last=True, n_iters=1, reset=['iter', 'vars'])
            preds, labels = get_results(ppl_v)
            scores[val_idx].append(f1_score(labels, preds))
        clear_output()
        draw_res(ppl_train.v('loss'), scores, VAL_NAMES)
    global_scores.append(scores)

  0%|          | 0/7 [00:00<?, ?it/s]
  0%|          | 0/151 [00:00<?, ?it/s][A