### Dependencies

In [17]:
import torch
import json
import glob
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from torch.utils.data import DataLoader
from torchsummary import summary
from common.classes import (
    SeismicDataset
)
from common.utils import (
    get_sliding_wnd_params,
    get_sliding_wnd_patches,
    window_2d,
    recover_img_from_patches,
)

### Dataset related parameters

In [19]:
stage = 'test'
dataset_root = 'D:/220617_seismic'
dataset_proc = f'{dataset_root}/thebe_processed'

### Patch related parameters

In [20]:
with open(f'{dataset_proc}/patch_params.json', 'r+') as f:
    patch_params = json.load(f)

patch_sz = patch_params['patch_sz']
step = patch_params['step']
overlap_sz = patch_sz - step
print(f'Patch size: {patch_sz}')
print(f'Step: {step}')

Patch size: 96
Step: 48


### Load an inference model

In [21]:
alg = "unet"
merge_method = "smooth"  # smooth average crop
if alg == "unet":
    from model_zoo.UNET import Unet
    model = Unet()
    print("use model Unet")
    filename = "unet_96_48_900200_seed"
    model_path = './checkpoints/noaugmodelsseed/{}.model'.format(
        filename)
    save_path = './overlap_clean_predictions/{}_{}'.format(
        filename, merge_method)
elif alg == "deeplab":
    from model_zoo.DEEPLAB.deeplab import DeepLab
    model = DeepLab(backbone='mobilenet', num_classes=1, output_stride=16)
    print("use model DeepLab")
    filename = "mobilenet_96_48_900200_seed"
    model_path = './checkpoints/noaugmodelsseed/{}.model'.format(
        filename)
    save_path = './overlap_clean_predictions/{}_{}'.format(
        filename, merge_method)
elif alg == "hed":
    from model_zoo.HED import HED
    model = HED()
    print("use model HED")
    filename = "hed_96_48_900200_seed3"
    model_path = './checkpoints/noaugmodelsseed/{}.model'.format(
        filename)
    save_path = './overlap_clean_predictions/{}_{}'.format(
        filename, merge_method)
elif alg == "rcf":
    from model_zoo.RCF import RCF
    model = RCF()
    print("use model RCF")
    filename = "rcf_96_48_900200_seed"
    model_path = './checkpoints/noaugmodelsseed/{}.model'.format(
        filename)
    save_path = './overlap_clean_predictions/{}_{}'.format(
        filename, merge_method)
else:
    print("please enter a valid model")

print(model_path)
print(save_path)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    devcie = torch.device('cpu')

model.load_state_dict(torch.load(model_path))
model.to(device)

summary(model, (1, patch_sz, patch_sz))

use model Unet
./checkpoints/noaugmodelsseed/unet_96_48_900200_seed.model
./overlap_clean_predictions/unet_96_48_900200_seed_smooth
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 96, 96]             320
       BatchNorm2d-2           [-1, 32, 96, 96]              64
              ReLU-3           [-1, 32, 96, 96]               0
         Dropout2d-4           [-1, 32, 96, 96]               0
            Conv2d-5           [-1, 32, 96, 96]           9,248
       BatchNorm2d-6           [-1, 32, 96, 96]              64
              ReLU-7           [-1, 32, 96, 96]               0
         Dropout2d-8           [-1, 32, 96, 96]               0
       double_conv-9           [-1, 32, 96, 96]               0
        MaxPool2d-10           [-1, 32, 48, 48]               0
           Conv2d-11           [-1, 64, 48, 48]          18,496
      BatchNorm2d-12           [-1,

### Processing test dataset

In [22]:
seismic_list = glob.glob(f'{dataset_proc}/{stage}/seismic/*.npy')
fault_list = glob.glob(f'{dataset_proc}/{stage}/annotation/*.npy')

print(f'Testing for {len(fault_list)} samples')

for seismic_path, fault_path in zip(seismic_list, fault_list):
    seismic = np.load(seismic_path)
    fault = np.load(fault_path)

    Z, XL = fault.shape
    padding, cnt = get_sliding_wnd_params((XL, Z), patch_sz, step)
    num_patches = cnt[0] * cnt[1]

    patches = get_sliding_wnd_patches(seismic, padding, patch_sz, step)
    patches = patches.astype(np.float32)
    patches = np.expand_dims(patches, 1)
    
    batch_size = 64
    dataset = SeismicDataset(patches)
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False
    )

    preds = []
    recovered_imgs = []
    for imgs in data_loader:
        outputs = model(imgs.cuda())
        preds.extend(outputs)
        if len(preds) >= num_patches:
            stacked = torch.stack(preds).numpy()[:num_patches]
            preds = preds[num_patches:]

            if merge_method == "smooth":
                weights = window_2d(wnd_sz=patch_sz, power=2)
                stacked = np.moveaxis(stacked, -3, -1)
                stacked = np.array([patch * weights for patch in stacked])
                stacked = stacked.reshape((cnt[1], cnt[0], patch_sz, patch_sz, 1))
                recovered = recover_img_from_patches(stacked, (XL, Z, 1), padding, overlap_sz)
                recovered_imgs.extend(recovered)
            else:
                print('invalid merge method')

Testing for 100 samples


OutOfMemoryError: CUDA out of memory. Tried to allocate 50.00 MiB (GPU 0; 8.00 GiB total capacity; 7.09 GiB already allocated; 0 bytes free; 7.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF