# Import all necessary modules

In [1]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from cubdl.PixelGrid import make_pixel_grid
from cubdl.das_torch import DAS_PW
import h5py
from cubdl.PlaneWaveData import PlaneWaveData
from glob import glob
from scipy.signal import hilbert, convolve
from efficientnet_lite import efficientnet_lite_params, build_efficientnet_lite
import torch.nn as nn

# Dataset Definition

In [19]:
def get_name_and_acq(filename):
    """Split a dataset file name into its origin prefix and acquisition number.

    The name is expected to look like ``<ORIGIN><NNN>.<ext>`` (for example
    ``TSH390.hdf5``): the first three characters of the stem are the origin
    and the last three characters are the acquisition number.

    Args:
        filename: File name, optionally preceded by a directory path.

    Returns:
        Tuple ``(origin, number)``, e.g. ``("TSH", 390)``.

    Raises:
        ValueError: If the last three characters of the stem are not an integer.
    """
    # Strip any directory component first; otherwise a path such as
    # "data/TSH390.hdf5" would yield the origin "dat" instead of "TSH".
    stem, _ext = os.path.splitext(os.path.basename(filename))
    number = int(stem[-3:])
    origin = stem[:3]
    return origin, number

def read_us_images(directory):
    """Return the names of the entries in ``directory`` in sorted order.

    ``os.listdir`` returns entries in arbitrary order, so the result is sorted
    to keep the mapping from dataset index to file reproducible between runs
    and platforms.

    Args:
        directory: Path of the folder to list.

    Returns:
        Sorted list of entry names (not full paths).
    """
    return sorted(os.listdir(directory))

class TSHData(PlaneWaveData):
    """Load plane-wave channel data from Tsinghua University (TSH) files.

    The constructor searches ``database_path`` recursively for a file named
    ``TSH<acq>*.hdf5`` and fills the attributes read from it: ``angles``,
    ``idata``, ``qdata``, ``fc``, ``fs``, ``c``, ``time_zero``, ``fdemod``
    and ``ele_pos``.
    """

    def __init__(self, database_path, acq):
        """Read acquisition number ``acq`` from below ``database_path``.

        Args:
            database_path: Root directory that is walked recursively.
            acq: Acquisition number; formatted as three zero-padded digits.

        Raises:
            AssertionError: If no matching file is found.
        """
        # Make sure the selected dataset is valid
        moniker = "TSH{:03d}".format(acq) + "*.hdf5"
        fname = [
            y for x in os.walk(database_path) for y in glob(os.path.join(x[0], moniker))
        ]
        assert fname, "File not found."

        # Load dataset. Everything is copied into NumPy arrays / Python scalars
        # inside the `with` block so the HDF5 handle is closed right away
        # instead of staying open for the lifetime of the object.
        with h5py.File(fname[0], "r") as f:
            self.angles = np.array(f["angles"])
            self.idata = np.array(f["channel_data"], dtype="float32")
            self.fc = np.array(f["modulation_frequency"]).item()
            self.fs = np.array(f["sampling_frequency"]).item()

        # Arrange the channel data as (angle, element, sample); 128 elements
        # are hard-coded here.
        self.idata = np.reshape(self.idata, (128, len(self.angles), -1))
        self.idata = np.transpose(self.idata, (1, 0, 2))
        # Quadrature component: imaginary part of the analytic signal taken
        # along the last (sample) axis.
        self.qdata = np.imag(hilbert(self.idata, axis=-1))
        # Fixed value; the file's "sound_speed" entry is not read.
        self.c = 1540  # np.array(f["sound_speed"]).item()
        self.time_zero = np.zeros((len(self.angles),), dtype="float32")
        self.fdemod = 0

        # Make the element positions based on L11-4v geometry
        pitch = 0.3e-3
        nelems = self.idata.shape[1]
        xpos = np.arange(nelems) * pitch
        xpos -= np.mean(xpos)  # centre the array on x = 0
        self.ele_pos = np.stack([xpos, 0 * xpos, 0 * xpos], axis=1)

        # For this dataset, time zero is the center point
        for i, a in enumerate(self.angles):
            self.time_zero[i] = self.ele_pos[-1, 0] * np.abs(np.sin(a)) / self.c

        # Validate that all information is properly included
        super().validate()

def USData_1angle(root_dir,file):
    """Beamform one acquisition using only its centre plane-wave angle.

    Args:
        root_dir: Directory searched for the acquisition's data file.
        file: File name such as ``TSH390.hdf5``; its origin prefix and
            acquisition number are parsed with ``get_name_and_acq``.

    Returns:
        A torch tensor with the I and Q beamformed images stacked along a
        new leading axis (size 2).

    Raises:
        ValueError: If the file's origin is not ``'TSH'``, the only origin
            handled here.
    """
    print("1angle : " + str(file))
    origin, acq = get_name_and_acq(file)
    # Fail with a clear message instead of hitting an unbound `P` further down.
    if origin != 'TSH':
        raise ValueError("Unsupported data origin %r for file %r" % (origin, file))

    P = TSHData(root_dir, acq)
    # Lateral extent spans the whole aperture; the depth window is fixed
    # (presumably metres, i.e. 10-45 mm -- confirm against the grid helper).
    xlims = [P.ele_pos[0, 0], P.ele_pos[-1, 0]]
    zlims = [10e-3, 45e-3]

    wvln = P.c / P.fc
    dx = wvln / 2.5
    dz = dx  # Use square pixels
    grid = make_pixel_grid(xlims, zlims, dx, dz)
    fnum = 1

    # make data from 1 angle
    x = (P.idata, P.qdata)
    idx = len(P.angles) // 2  # Choose center angle
    das1 = DAS_PW(P, grid, idx, rxfnum=fnum)
    idas1, qdas1 = das1(x)
    idas1, qdas1 = idas1.detach().cpu().numpy(), qdas1.detach().cpu().numpy()

    us_1angle = np.stack((idas1, qdas1), axis=0)
    us_1angle = torch.from_numpy(us_1angle)

    return us_1angle

def USData_Nangles(root_dir,file):
    """Beamform one acquisition without restricting the plane-wave angle.

    ``DAS_PW`` is built without an angle index, so it uses its default angle
    selection (presumably every transmitted angle -- confirm in ``DAS_PW``).

    Args:
        root_dir: Directory searched for the acquisition's data file.
        file: File name such as ``TSH390.hdf5``; its origin prefix and
            acquisition number are parsed with ``get_name_and_acq``.

    Returns:
        A torch tensor with the I and Q beamformed images stacked along a
        new leading axis (size 2).

    Raises:
        ValueError: If the file's origin is not ``'TSH'``, the only origin
            handled here.
    """
    print("Nangle : " + str(file))
    origin, acq = get_name_and_acq(file)
    # Fail with a clear message instead of hitting an unbound `P` further down.
    if origin != 'TSH':
        raise ValueError("Unsupported data origin %r for file %r" % (origin, file))

    P = TSHData(root_dir, acq)
    # Lateral extent spans the whole aperture; the depth window is fixed
    # (presumably metres, i.e. 10-45 mm -- confirm against the grid helper).
    xlims = [P.ele_pos[0, 0], P.ele_pos[-1, 0]]
    zlims = [10e-3, 45e-3]

    wvln = P.c / P.fc
    dx = wvln / 2.5
    dz = dx  # Use square pixels
    grid = make_pixel_grid(xlims, zlims, dx, dz)
    fnum = 1

    # make data with DAS_PW's default angle selection (no single index given)
    x = (P.idata, P.qdata)
    dasN = DAS_PW(P, grid, rxfnum=fnum)
    idasN, qdasN = dasN(x)
    idasN, qdasN = idasN.detach().cpu().numpy(), qdasN.detach().cpu().numpy()

    us_Nangles = np.stack((idasN, qdasN), axis=0)
    us_Nangles = torch.from_numpy(us_Nangles)

    return us_Nangles

In [20]:
class USDataset(Dataset):
    """Dataset of paired single-angle / multi-angle beamformed images.

    Every entry of ``root_dir`` is one sample. Fetching a sample beamforms the
    file twice and returns both results in a dict.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the US images from TSH.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.transform = transform
        self.root_dir = root_dir
        # One-column frame: row i holds the i-th file name.
        self.us_images = pd.DataFrame(read_us_images(root_dir))

    def __len__(self):
        """Number of files found in ``root_dir``."""
        return len(self.us_images)

    def __getitem__(self, idx):
        """Return ``{'us_1angle': ..., 'us_Nangle': ...}`` for sample ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_name = self.us_images.iloc[idx, 0]
        pair = {
            'us_1angle': USData_1angle(self.root_dir, file_name),
            'us_Nangle': USData_Nangles(self.root_dir, file_name),
        }

        if not self.transform:
            return pair
        return self.transform(pair)

# Model Definition

In [26]:
# EfficientNet-Lite0 with one output per element of a 2 x 356 x 387 target
# image (the shape the training loop compares the prediction against).
model_name = "efficientnet_lite0"
num_outputs = 2 * 356 * 387
device = torch.device('cpu')
model = build_efficientnet_lite(model_name, num_outputs).to(device)

# Loss and Optimizer Definition

In [27]:
import torch.optim as optim

# L1 (absolute-error) loss between the predicted and the target image.
criterion = nn.L1Loss()
# SGD with momentum over every parameter of the model.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training

## Dataset

In [28]:
# Earlier single-folder setup, kept for reference:
# batch_size = 1
# US_dataset = USDataset(root_dir = r'D:\OneDrive\Documents\Maestria_Biomedica\Imagenes_Medicas\model_cubdl\TSH')
# trainloader = torch.utils.data.DataLoader(US_dataset, batch_size=batch_size,shuffle=True, num_workers=0)

# Train and validation splits live in separate folders; both loaders share
# the same batch size and are read in the main process (num_workers=0).
batch_size = 1

US_dataset_train = USDataset(
    root_dir=r'D:\OneDrive\Documents\Maestria_Biomedica\Imagenes_Medicas\model_cubdl\dataset_split\train'
)
trainloader = DataLoader(
    US_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0
)

US_dataset_val = USDataset(
    root_dir=r'D:\OneDrive\Documents\Maestria_Biomedica\Imagenes_Medicas\model_cubdl\dataset_split\val'
)
valloader = DataLoader(
    US_dataset_val, batch_size=batch_size, shuffle=True, num_workers=0
)


In [29]:
import wandb
# Log in to Weights & Biases before the run is created with wandb.init().
wandb.login()



True

In [30]:
wandb.init(project="efficientNet_cubdl")
for epoch in range(2):  # loop over the dataset multiple times
    ckpt = 0
    running_loss = 0.0

    # ---- training pass ----
    # Re-enable training behaviour (dropout / batch-norm updates) because the
    # validation pass below switches the model to eval mode.
    model.train()
    for i, data in enumerate(trainloader, 0):
        # each batch is a dict: single-angle input and multi-angle target
        das1 = data['us_1angle'].to(device)
        dasN = data['us_Nangle'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(das1)
        # Give the flat prediction exactly the target's shape (batch dimension
        # included) so the loss never relies on implicit broadcasting.
        outputs = torch.reshape(outputs, dasN.shape)
        loss = criterion(outputs, dasN)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 10 == 9:    # print every 10 mini-batches
            # Log a plain float rather than a tensor attached to the graph.
            wandb.log({"Train_Loss": loss.item()})
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0

            # PATH = "model_epoch" + str(epoch) + "_ckpt" + str(ckpt) + ".pt"
            # ckpt = ckpt + 1

            # torch.save({
            #             'epoch': epoch,
            #             'model_state_dict': model.state_dict(),
            #             'optimizer_state_dict': optimizer.state_dict(),
            #             }, PATH)

    # ---- validation pass ----
    # Eval mode keeps validation data from updating batch-norm statistics, and
    # no_grad avoids building an autograd graph that is never used.
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(valloader, 0):
            # same dict layout as the training batches
            das1 = data['us_1angle'].to(device)
            dasN = data['us_Nangle'].to(device)

            # forward only
            outputs = model(das1)
            outputs = torch.reshape(outputs, dasN.shape)
            val_loss = criterion(outputs, dasN).item()

            # print statistics
            wandb.log({"Val_Loss": val_loss})
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, val_loss))

print('Finished Training')

1angle : TSH390.hdf5
Nangle : TSH390.hdf5
1angle : TSH300.hdf5
Nangle : TSH300.hdf5
1angle : TSH268.hdf5
Nangle : TSH268.hdf5
1angle : TSH331.hdf5
Nangle : TSH331.hdf5


KeyboardInterrupt: 