In [4]:
import os
import natsort
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.utils.data as data
from torchvision import transforms
from torch.utils.data import Dataset

from squeakout import SqueakOut_autoencoder as SqueakOut

%load_ext autoreload
%autoreload 2

In [5]:
class CustomDataSet(Dataset):
    """Dataset that serves every image file in one directory as a grayscale 512x512 image.

    File names are kept in natural sort order in ``self.total_imgs`` so that
    index ``i`` of the dataset always corresponds to ``self.total_imgs[i]``.

    Args:
        main_dir: Directory containing the image files.
        transform: Callable applied to the resized PIL image; its result is
            what ``__getitem__`` returns.
    """

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        # Keep regular files only: sub-directories (e.g. ``.ipynb_checkpoints``)
        # cannot be opened as images and would crash ``__getitem__``.
        all_imgs = [
            name
            for name in os.listdir(main_dir)
            if os.path.isfile(os.path.join(main_dir, name))
        ]
        self.total_imgs = natsort.natsorted(all_imgs)

    def __len__(self):
        """Return the number of image files found in ``main_dir``."""
        return len(self.total_imgs)

    def __getitem__(self, idx: int):
        """Load image ``idx``, convert it to mode "L", resize to 512x512 and transform it."""
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        # The context manager closes the underlying file handle deterministically;
        # ``convert`` returns an independent in-memory image.
        with Image.open(img_loc) as raw:
            # ``Image.ANTIALIAS`` was removed in Pillow 10; ``Image.LANCZOS`` is
            # the same resampling filter under its current name.
            image = raw.convert("L").resize((512, 512), Image.LANCZOS)
        tensor_image = self.transform(image)
        return tensor_image
    
# Preprocessing applied to every PIL image: resize to 512x512, then convert to a tensor.
transToTensor = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ]
)

In [6]:
# Build the network and restore the trained weights from the checkpoint file.
model = SqueakOut()
ckpt_path = "./squeakout_weights.ckpt"
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])

# This notebook only runs inference, so put the model in evaluation mode.
# This matters if the network contains dropout / batch-norm style layers
# (not visible from here); it is a no-op otherwise.
model.eval()

# In-memory footprint: bytes used by all parameters plus all registered buffers.
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))


num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_trainable_params}")


# Move to the GPU when one is available; otherwise the model stays on the CPU.
if torch.cuda.is_available():
    model.to("cuda")

model size: 18.029MB
Number of trainable parameters: 4683600


In [10]:
# Directory containing the source spectrogram images to segment.
src_data = "./dataset/test/"

# Root directory where the segmentation masks are written.
# A sub-directory named after the source directory is created inside it.
save_root = "./outputs/segmentation/"

# Root directory for the montage images. Each montage places the original
# spectrogram, the segmentation mask, and a 50/50 blend of the two side by
# side, for visual inspection only.
montage_root = "./outputs/montages/"

In [None]:
# For each recording directory:
#     open all spectrograms as a dataset
#     pass the dataset through the network
#     save one binary mask and one montage image per spectrogram

# Sigmoid probability above which a pixel is marked as foreground in the mask.
MASK_THRESHOLD = 0.51

# Run on whatever device the model currently lives on, so the cell works on
# both CPU-only and CUDA machines instead of assuming "cuda".
device = next(model.parameters()).device
# Inference only: make sure the model is in evaluation mode.
model.eval()

for data_dir in [src_data]:
    print(f"running {data_dir}")
    # Last path component of the source directory, with or without a trailing slash.
    out_dir_group = os.path.basename(os.path.normpath(data_dir))

    # create directories to save masks and montages
    segm_out_dir = os.path.join(save_root, out_dir_group)
    os.makedirs(segm_out_dir, exist_ok=True)
    mont_out_dir = os.path.join(montage_root, out_dir_group)
    os.makedirs(mont_out_dir, exist_ok=True)

    # create spectrogram dataset; shuffle=False keeps batches in the same
    # order as my_dataset.total_imgs, which the file naming below relies on
    my_dataset = CustomDataSet(data_dir, transform=transToTensor)
    print(f"number of spectrograms {len(my_dataset)}")
    train_loader = data.DataLoader(my_dataset, batch_size=8, shuffle=False, num_workers=4, drop_last=False)
    segmentations = []
    spectrograms = []

    # iterate over batches and collect the raw network outputs;
    # no_grad avoids building an autograd graph, and moving each output to
    # the CPU right away keeps GPU memory usage bounded to one batch
    with torch.no_grad():
        for batch in train_loader:
            out = model(batch.to(device))
            segmentations.append(out.cpu())
            spectrograms.append(batch)

    # Walk the batches one by one rather than stacking them into a single
    # array: the last batch can be shorter than the others, which makes the
    # stacked array ragged.
    a = 0
    for spec_batch, segm_batch in zip(spectrograms, segmentations):
        # Channel 0 of every item, scaled to 8-bit grayscale.
        # assumes the transform yields values in [0, 1] (ToTensor does); clip guards the cast
        spec_u8 = np.clip(spec_batch[:, 0].numpy() * 255, 0, 255).astype(np.uint8)
        # Channel 0 of the network output is treated as a logit map.
        probs = torch.sigmoid(segm_batch[:, 0]).numpy()
        mask_u8 = np.where(probs > MASK_THRESHOLD, 255, 0).astype(np.uint8)

        for spc, mask in zip(spec_u8, mask_u8):
            file_name = my_dataset.total_imgs[a]

            # The mask is saved under the same file name as its source spectrogram.
            mask_output_path = os.path.join(segm_out_dir, file_name)
            Image.fromarray(mask).save(mask_output_path)

            # Montage: spectrogram | mask | 50/50 blend of the two.
            overlay = (0.5 * spc + 0.5 * mask).astype(np.uint8)
            montage = np.hstack([spc, mask, overlay])
            # splitext handles extensions of any length (".png", ".jpeg", ...).
            stem = os.path.splitext(file_name)[0]
            img_output_path = os.path.join(mont_out_dir, stem + "_montage.png")
            Image.fromarray(montage).save(img_output_path)
            a += 1

    print(f"{a} masks saved\n\n")

running ./images/
number of spectrograms 849


  del sys.path[0]
  del sys.path[0]
  del sys.path[0]
  del sys.path[0]
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
