# Overview (Schizo Edition)

Find the head circumference given an ultrasound image.

> How do we do this? 

In the `data/hc18` directory, we have the dataset with images, annotated images, pixel size (to show how many mm each pixel occupies) and the actual head circumference. 

> How do we find the head circumference?

We predict a segmentation mask for the ultrasound image using a deep learning model, then fit an ellipse to that mask using OpenCV, and finally compute the head circumference from the circumference of the fitted ellipse, converted from pixels to millimetres using the given pixel size. 

> How do we draw a segmentation mask on the ultrasound image?

We will use PyTorch and train on the data in `training_set`: the original ultrasound image is the model's input, and the annotated binary mask is its target output. 

> Cool, then how do we train the model in question using PyTorch?  

reference: https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/

We will use the U-Net architecture.

1. Prepare the data!
   * We don't need to do much, since the images here are already single channel (gray-scale)

# Dataset Class

In [1]:
from my_class import SegmentationDataset

# U-Net model

## Config

In [2]:
import torch 
import os

# Root of the HC18 dataset, relative to this notebook.
DATASET_PATH = os.path.join("../data", "hc18")

# Ultrasound images and their annotation masks live in the same directory;
# they are separated later by file name, so no dedicated mask path is needed.
IMAGE_DATASET_PATH = os.path.join(DATASET_PATH, "training_set")
# Fraction of the data held out for testing.
TEST_SPLIT = 0.2

# Use the Apple-silicon Metal backend when it is available, otherwise the CPU.
# torch.backends.mps.is_available() is the documented availability check;
# torch.mps.is_available() does not exist in older PyTorch releases.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

# No pin memory. Didn't work.
# ref: https://github.com/pytorch/pytorch/issues/86060

# Input channels (gray-scale), output classes, and U-Net depth.
NUM_CHANNELS = 1
NUM_CLASSES = 1
NUM_LEVELS = 3

# Optimisation hyper-parameters.
INIT_LR = 0.001
NUM_EPOCHS = 40
BATCH_SIZE = 8

# Size every image/mask is resized to before being fed to the network.
INPUT_IMAGE_WIDTH = 800
INPUT_IMAGE_HEIGHT = 540

# Threshold constant; not referenced by the cells shown in this section.
THRESHOLD = 0.5

# Directory and file names for everything this notebook writes.
BASE_OUTPUT = "output"

MODEL_PATH = os.path.join(BASE_OUTPUT, "unet_hc.pth")
PLOT_PATH = os.path.join(BASE_OUTPUT, "plot.png")
TEST_PATHS = os.path.join(BASE_OUTPUT, "test_paths.txt")

## The model

### imports

In [3]:
from torch.nn import ConvTranspose2d 
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F 

### Block module

In [4]:
class Block(Module):
    """Basic U-Net building block: 3x3 conv -> ReLU -> 3x3 conv.

    No padding argument is passed to either convolution, and no activation
    is applied after the second one.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        kernelSize = 3
        self.conv1 = Conv2d(inChannels, outChannels, kernelSize)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, kernelSize)

    def forward(self, x):
        """Run `x` through conv1, the ReLU, then conv2, and return the result."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)

### Encoder class

In [5]:
class Encoder(Module):
    """Contracting path: a stack of `Block`s, each followed by 2x2 max-pooling.

    `channels` lists the channel count at each stage boundary; consecutive
    pairs become the (in, out) channels of one `Block`.
    """

    def __init__(self, channels=(1, 16, 32, 64)):
        super().__init__()
        stages = [
            Block(inCh, outCh) for inCh, outCh in zip(channels[:-1], channels[1:])
        ]
        self.encBlocks = ModuleList(stages)
        self.pool = MaxPool2d(2)

    def forward(self, x):
        """Return the list of every block's output, captured before pooling."""
        features = []
        current = x

        for stage in self.encBlocks:
            current = stage(current)
            # Keep the un-pooled activation; the decoder uses it as a skip input.
            features.append(current)
            current = self.pool(current)

        return features

### Decoder class

In [6]:
class Decoder(Module):
    """Expanding path: transposed-conv upsampling, skip concatenation, `Block`.

    `channels` lists the channel count at each stage boundary; consecutive
    pairs configure one up-convolution and one `Block`.
    """

    def __init__(self, channels=(64, 32, 16)):
        super().__init__()

        self.channels = channels
        stagePairs = list(zip(channels[:-1], channels[1:]))
        self.upconvs = ModuleList(
            [ConvTranspose2d(inCh, outCh, 2, 2) for inCh, outCh in stagePairs]
        )
        self.dec_blocks = ModuleList(
            [Block(inCh, outCh) for inCh, outCh in stagePairs]
        )

    def forward(self, x, encFeatures):
        """Decode `x`, merging `encFeatures[i]` into stage `i` as a skip input."""
        for i, (upconv, decBlock) in enumerate(zip(self.upconvs, self.dec_blocks)):
            x = upconv(x)

            # Crop the encoder feature map to the upsampled size, then stack
            # the two along the channel dimension before the conv block.
            skip = self.crop(encFeatures[i], x)
            x = decBlock(torch.cat([x, skip], dim=1))

        return x

    def crop(self, encFeatures, x):
        """Center-crop `encFeatures` to the height and width of `x`."""
        _batch, _chans, height, width = x.shape
        cropper = CenterCrop([height, width])

        return cropper(encFeatures)

### UNet class

In [7]:
class UNet(Module):
    """Encoder/decoder segmentation network with a 1x1 convolution head.

    Args:
        encChannels: channel counts handed to `Encoder`.
        decChannels: channel counts handed to `Decoder`.
        nbClasses: number of output channels produced by the head.
        retainDim: when true, the output is interpolated to `outSize`.
        outSize: (height, width) used for that interpolation.
    """

    def __init__(
        self,
        encChannels=(1, 16, 32, 64),
        decChannels=(64, 32, 16),
        nbClasses=1,
        retainDim=True,
        outSize=(INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH)
    ):
        super().__init__()
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)

        # 1x1 convolution mapping the last decoder channels to class channels.
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        """Return the head's output for `x` (resized to `outSize` if `retainDim`)."""
        # Deepest encoder output first; the remaining ones serve as skip inputs.
        deepestFirst = self.encoder(x)[::-1]
        bottleneck = deepestFirst[0]
        skips = deepestFirst[1:]

        mapper = self.head(self.decoder(bottleneck, skips))

        if not self.retainDim:
            return mapper

        return F.interpolate(mapper, self.outSize)

## Train

In [22]:
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam 
from torch.utils.data import DataLoader
from torchvision import transforms 
from imutils import paths 
from tqdm import tqdm
import matplotlib.pyplot as plt 
import time 
from sklearn.model_selection import train_test_split


# Every image file under the training directory, in a stable sorted order.
allPaths = sorted(paths.list_images(IMAGE_DATASET_PATH))
allPathSet = set(allPaths)

# Ultrasound images are the files whose *name* does not contain "Annotation".
# A list is used (not a set difference) so the order stays deterministic.
imagePaths = [p for p in allPaths if "Annotation" not in os.path.basename(p)]

# Build the mask list in exactly the same order as the image list:
# train_test_split pairs the two lists by index, so entry i of maskPaths
# must be the mask for entry i of imagePaths.
# NOTE(review): assumes each mask is named "<image stem>_Annotation<ext>"
# (e.g. 000_HC.png -> 000_HC_Annotation.png) - confirm against the dataset.
maskPaths = []
for imagePath in imagePaths:
    stem, ext = os.path.splitext(imagePath)
    maskPath = stem + "_Annotation" + ext
    if maskPath not in allPathSet:
        raise FileNotFoundError(f"no annotation mask found for {imagePath}")
    maskPaths.append(maskPath)

# print(imagePaths)
# print(maskPaths)

In [23]:
# Split image/mask pairs into train and test sets; the fixed random_state
# makes the split reproducible across runs.
(trainImages, testImages, trainMasks, testMasks) = train_test_split(
    imagePaths, maskPaths, test_size=TEST_SPLIT, random_state=67
)

print("[INFO] saving testing image paths...")
# Create the output directory first so a fresh checkout does not fail here,
# and use a context manager so the file is closed even if the write raises.
os.makedirs(os.path.dirname(TEST_PATHS) or ".", exist_ok=True)
with open(TEST_PATHS, "w") as f:
    f.write("\n".join(testImages))

[INFO] saving testing image paths...


In [26]:
# Import the module under an alias inside this cell: the pipeline below is
# bound to the name `transforms`, which shadows the torchvision module of the
# same name. With the alias the cell no longer depends on `transforms` still
# being the module, so it can be re-run without an AttributeError.
from torchvision import transforms as T

# Preprocessing pipeline handed to both datasets:
# PIL conversion -> resize -> gray-scale -> tensor.
transforms = T.Compose([
    T.ToPILImage(),
    T.Resize((INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH)),
    T.Grayscale(),
    T.ToTensor()
])

trainDataset = SegmentationDataset(
    imagePaths=trainImages,
    maskPaths=trainMasks,
    transforms=transforms
)
testDataset = SegmentationDataset(
    imagePaths=testImages,
    maskPaths=testMasks,
    transforms=transforms
)

print(f"[INFO] found {len(trainDataset)} examples in the training set...")
print(f"[INFO] found {len(testDataset)} examples in the test set...")

# os.cpu_count() may return None; fall back to loading in the main process.
numWorkers = os.cpu_count() or 0

trainLoader = DataLoader(
    trainDataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=numWorkers
)
testLoader = DataLoader(
    testDataset,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=numWorkers
)



[INFO] found 799 examples in the training set...
[INFO] found 200 examples in the test set...


### Init UNet

In [27]:
# Model with default channel configuration, moved to the selected device.
unet = UNet().to(DEVICE)

# The network outputs raw logits, so the loss applies the sigmoid itself.
lossFunc = BCEWithLogitsLoss()
opt = Adam(unet.parameters(), lr=INIT_LR)

# Number of batches per epoch. len(loader) counts the final partial batch,
# whereas len(dataset) // BATCH_SIZE drops it and would skew the average
# loss whenever the dataset size is not a multiple of BATCH_SIZE.
trainSteps = len(trainLoader)
testSteps = len(testLoader)

# Per-epoch loss history.
H = {"train_loss": [], "test_loss": []}

### Training loop

In [None]:
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(NUM_EPOCHS)):
    # --- training pass ---
    unet.train()

    # Accumulate plain Python floats. Adding the loss tensor itself would
    # keep every batch's autograd graph alive until the end of the epoch.
    totalTrainLoss = 0.0
    totalTestLoss = 0.0

    for (x, y) in trainLoader:
        (x, y) = (x.to(DEVICE), y.to(DEVICE))

        pred = unet(x)
        loss = lossFunc(pred, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        totalTrainLoss += loss.item()

    # --- evaluation pass (no gradients) ---
    unet.eval()
    with torch.no_grad():
        for (x, y) in testLoader:
            (x, y) = (x.to(DEVICE), y.to(DEVICE))

            pred = unet(x)
            totalTestLoss += lossFunc(pred, y).item()

    # Mean loss per batch for this epoch.
    avgTrainLoss = totalTrainLoss / trainSteps
    avgTestLoss = totalTestLoss / testSteps

    H["train_loss"].append(avgTrainLoss)
    H["test_loss"].append(avgTestLoss)

    print("[INFO] EPOCH: {}/{}".format(e+1, NUM_EPOCHS))
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(avgTrainLoss, avgTestLoss))

endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(endTime - startTime))

[INFO] training the network...


  0%|                                                                                                                                                | 0/40 [00:00<?, ?it/s]