# Structure Classification

Annotations for nine key structures (thalami, midbrain, palate, 4th ventricle, cisterna magna, nuchal translucency (NT), nasal tip, nasal skin, and nasal bone) can be found in the ObjectDetection.xlsx file; the 4th ventricle and cisterna magna are labelled `IT` and `CM` there. Its columns are the image file name, the structure name, and the minimum and maximum x/y coordinates of the structure's bounding box in the image.


In [None]:
# !pip install torch torchvision pandas tqdm matplotlib opencv-python openpyxl scikit-learn

In [None]:
# import os 
# os.chdir ("./Jaik's Model")


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim.lr_scheduler import StepLR
from torchvision.models import resnet34, ResNet34_Weights
from torchvision.ops import box_iou
from torchvision.transforms import transforms, v2
from torchvision.io import read_image, ImageReadMode
from torchvision.utils import draw_bounding_boxes
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch.cuda as cuda
import cv2
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import random
import math
from datetime import datetime
from collections import Counter
from PIL import Image
from pathlib import Path
from matplotlib.patches import Rectangle
from torchvision import models

## Dataset Preparation
Before running the code cell below, download the zipped dataset folders from [Google Drive](https://drive.google.com/file/d/1-ppPA9UHw9ZTBxyGmbWEyCgRNKTECC_6/view?usp=drive_link) and extract them into the root folder of this directory.
   

In [None]:
# Annotation spreadsheet and the folder holding the images it refers to.
labels_file = Path("ObjectDetection.xlsx")
data_dir = Path("Fetus_Health_Dataset") / "allStandard"

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def read_dataset(data_dir, labels_file):
    """Build a per-annotation table of image paths, image sizes, classes and boxes.

    Args:
        data_dir: Folder holding the images named in the spreadsheet's
            ``fname`` column.
        labels_file: Annotation spreadsheet with ``fname`` and ``structure``
            columns, followed (column positions 2-5) by the four box coordinates.

    Returns:
        DataFrame with one row per annotation whose image exists under
        ``data_dir`` and columns ``filepath``, ``width``, ``height``,
        ``class``, ``xmin``, ``ymin``, ``xmax``, ``ymax``.

    Raises:
        KeyError: if a ``structure`` value is missing from the label map.
        ValueError: if an existing image file cannot be decoded by OpenCV.
    """
    labels_df = pd.read_excel(labels_file)
    data_dir = Path(data_dir)

    # Map the structure names to integer class ids for training.
    label_map = {
        'thalami': 0,
        'nasal bone': 1,
        'palate': 2,
        'nasal skin': 3,
        'nasal tip': 4,
        'midbrain': 5,
        'NT': 6,
        'IT': 7,
        'CM': 8,
    }

    # Not every annotated image is present on disk; keep only rows whose file
    # exists. A NumPy boolean mask (unlike DataFrame.apply) stays valid even
    # when the sheet has no rows.
    exists = np.array(
        [(data_dir / fname).exists() for fname in labels_df["fname"]], dtype=bool
    )
    labels_df = labels_df.loc[exists]

    # Image paths and integer classes, one entry per remaining annotation.
    image_paths = [data_dir / fname for fname in labels_df["fname"]]
    labels = [label_map[name] for name in labels_df["structure"]]

    # An image carries several annotations; decode each file only once.
    sizes = {}
    for image_path in tqdm(image_paths, desc="Loading Images", total=len(image_paths)):
        if image_path in sizes:
            continue
        img = cv2.imread(str(image_path))
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise ValueError(f"OpenCV could not decode image {image_path}")
        height, width = img.shape[:2]
        sizes[image_path] = (width, height)
    og_shapes = [sizes[path] for path in image_paths]

    # NOTE(review): the mapping below treats the first two coordinate columns
    # as (ymin, xmin) and the last two as (ymax, xmax) -- confirm against the
    # spreadsheet header.
    bboxes = labels_df.iloc[:, 2:6].values.astype(float)

    return pd.DataFrame({
        'filepath': image_paths,
        'width': [w for w, _ in og_shapes],
        'height': [h for _, h in og_shapes],
        'class': labels,
        'xmin': bboxes[:, 1],
        'ymin': bboxes[:, 0],
        'xmax': bboxes[:, 3],
        'ymax': bboxes[:, 2],
    })


In [None]:
df_train = read_dataset(data_dir, labels_file)

In [None]:
def read_image(path):
    """Read an image from disk with OpenCV (BGR channel order).

    NOTE: this deliberately shadows ``torchvision.io.read_image`` imported at
    the top of the notebook; resize_image_bb relies on this NumPy-returning
    OpenCV version.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. cv2.imread signals
            failure by returning None, which would otherwise surface later as
            an obscure error inside cv2.resize.
    """
    im = cv2.imread(str(path))
    if im is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return im


In [None]:
def create_mask(bb, x):
    """Return a float 0/1 mask with the same height and width as image ``x``.

    ``bb`` is a NumPy array ``[row_min, col_min, row_max, col_max]``. It is
    truncated to int32, and the mask is 1 on ``x[row_min:row_max, col_min:col_max]``
    (the max bounds are exclusive) and 0 elsewhere.
    """
    height, width = x.shape[:2]
    box = bb.astype(np.int32)
    mask = np.zeros((height, width))
    mask[box[0]:box[2], box[1]:box[3]] = 1.0
    return mask

def mask_to_bb(Y):
    """Return the tight box around the non-zero pixels of the 2-D mask ``Y``.

    The result is float32 ``[row_min, col_min, row_max, col_max]`` with
    inclusive max indices. This is the same row-first order that create_mask
    consumes. An all-zero mask yields ``[0, 0, 0, 0]``.
    """
    # np.nonzero returns (row indices, column indices) for a 2-D array.
    row_idx, col_idx = np.nonzero(Y)
    if row_idx.size == 0:
        return np.zeros(4, dtype=np.float32)
    return np.array(
        [row_idx.min(), col_idx.min(), row_idx.max(), col_idx.max()],
        dtype=np.float32,
    )

def create_bb_array(x):
    """Build a ``[ymin, xmin, ymax, xmax]`` array from a ``df_train`` row.

    ``x`` is the row's positional values (``row.values``). Positions 4-7 are
    taken as xmin, ymin, xmax, ymax, which is the column order read_dataset
    produces. The result is in the row-first order create_mask expects.
    """
    ymin, xmin, ymax, xmax = x[5], x[4], x[7], x[6]
    return np.array([ymin, xmin, ymax, xmax])

In [None]:
def resize_image_bb(read_path, write_path, bb, sz):
    """Resize an image (cached on disk) and rescale its bounding box to match.

    The resized copy is written into ``write_path`` under the original file
    name, so the source dataset is left intact. If that file already exists,
    it is reused rather than rewritten. The output is ``sz`` rows by
    ``int(1.49 * sz)`` columns.

    Args:
        read_path: Path to the original image.
        write_path: Directory for resized images (created if missing).
        bb: Box ``[row_min, col_min, row_max, col_max]`` in original pixels.
        sz: Target height in pixels.

    Returns:
        ``(new_path, new_bb)``. ``new_bb`` is the box in resized pixels
        (float32, same order as ``bb``).

    Raises:
        OSError: if the resized image cannot be written.
    """
    read_path = Path(read_path)
    write_path = Path(write_path)
    # cv2.resize takes (width, height); every image uses a fixed 1.49 width/height ratio.
    target_size = (int(1.49 * sz), sz)

    # The original image is always needed: its shape defines the box mask below.
    im = read_image(read_path)
    new_path = write_path / read_path.name
    if not new_path.exists():
        write_path.mkdir(parents=True, exist_ok=True)
        # cv2.imwrite reports failure by returning False instead of raising.
        if not cv2.imwrite(str(new_path), cv2.resize(im, target_size)):
            raise OSError(f"Could not write resized image to {new_path}")

    # Rescale the box by resizing its mask to the same size and re-extracting it.
    Y_resized = cv2.resize(create_mask(bb, im), target_size)
    return new_path, mask_to_bb(Y_resized)

In [None]:
# Resize every image to a fixed height of 300 px and rescale its box to match.
# The resized copies go to ``image_resized/``, so the original dataset is untouched.
new_paths = []
new_bbs = []
train_path_resized = Path('image_resized')

rows = tqdm(df_train.iterrows(), desc="Resizing Images", total=df_train.shape[0])
for index, row in rows:
    corners = create_bb_array(row.values)
    new_path, new_bb = resize_image_bb(row['filepath'], train_path_resized, corners, 300)
    new_paths.append(new_path)
    new_bbs.append(new_bb)

# Attach the resized image locations and boxes as new columns.
df_train['new_path'] = new_paths
df_train['new_bb'] = new_bbs

In [None]:
def crop(im, r, c, target_r, target_c):
    """Return the ``target_r`` x ``target_c`` window of ``im`` whose top-left corner is (``r``, ``c``)."""
    row_end = r + target_r
    col_end = c + target_c
    return im[r:row_end, c:col_end]

def random_crop(x, r_pix=8):
    """Randomly crop ``x``, removing ``2*r_pix`` rows and a proportional number of columns.

    The column margin is ``round(r_pix * cols / rows)``, so the crop roughly
    keeps the aspect ratio. The output is smaller than the input, not the
    original size. Offsets are drawn with the ``random`` module.
    """
    n_rows, n_cols, *_ = x.shape
    c_pix = round(r_pix * n_cols / n_rows)
    # Row offset is drawn first, then the column offset.
    u_row = random.uniform(0, 1)
    u_col = random.uniform(0, 1)
    row0 = np.floor(2 * u_row * r_pix).astype(int)
    col0 = np.floor(2 * u_col * c_pix).astype(int)
    return crop(x, row0, col0, n_rows - 2 * r_pix, n_cols - 2 * c_pix)

def center_crop(x, r_pix=8):
    """Trim ``r_pix`` rows from the top and bottom of ``x`` and a proportional margin from each side."""
    n_rows, n_cols, *_ = x.shape
    c_pix = round(r_pix * n_cols / n_rows)
    out_rows = n_rows - 2 * r_pix
    out_cols = n_cols - 2 * c_pix
    return crop(x, r_pix, c_pix, out_rows, out_cols)

In [None]:
def rotate_cv(im, deg, y=False, mode=cv2.BORDER_REFLECT, interpolation=cv2.INTER_AREA):
    """Rotate ``im`` by ``deg`` degrees about its centre, keeping its size.

    For images (``y=False``), the border uses ``mode`` and the warp uses
    ``interpolation`` combined with ``cv2.WARP_FILL_OUTLIERS``. For masks
    (``y=True``), the border is constant and OpenCV's default warp flags are used.
    """
    n_rows, n_cols, *_ = im.shape
    M = cv2.getRotationMatrix2D((n_cols / 2, n_rows / 2), deg, 1)
    out_size = (n_cols, n_rows)
    if not y:
        return cv2.warpAffine(im, M, out_size, borderMode=mode,
                              flags=cv2.WARP_FILL_OUTLIERS + interpolation)
    return cv2.warpAffine(im, M, out_size, borderMode=cv2.BORDER_CONSTANT)

def random_cropXY(x, Y, r_pix=8):
    """Apply one random crop to image ``x`` and mask ``Y`` so they stay aligned.

    Both are cut to ``rows - 2*r_pix`` by ``cols - 2*c_pix``, where ``c_pix``
    is scaled by the image aspect ratio. The offset is drawn from the
    ``random`` module.
    """
    n_rows, n_cols, *_ = x.shape
    c_pix = round(r_pix * n_cols / n_rows)
    u_row = random.uniform(0, 1)
    u_col = random.uniform(0, 1)
    row0 = np.floor(2 * u_row * r_pix).astype(int)
    col0 = np.floor(2 * u_col * c_pix).astype(int)
    out_rows, out_cols = n_rows - 2 * r_pix, n_cols - 2 * c_pix
    return tuple(crop(arr, row0, col0, out_rows, out_cols) for arr in (x, Y))

def transformsXY(path, bb, transforms):
    """Load an image, optionally augment it, and return it with its adjusted box.

    The image is read as RGB float32 scaled by 1/255. The box is rasterised
    to a mask, so every geometric transform hits image and box identically.
    When ``transforms`` is truthy, the augmentation is a random rotation in
    [-10, 10) degrees, a 50% horizontal flip and a random crop. Otherwise a
    deterministic centre crop is applied.

    Returns:
        ``(x, bb)``: the processed image and the box recovered from the
        transformed mask by mask_to_bb.
    """
    img = cv2.cvtColor(cv2.imread(str(path)).astype(np.float32), cv2.COLOR_BGR2RGB) / 255
    mask = create_mask(bb, img)

    if not transforms:
        return center_crop(img), mask_to_bb(center_crop(mask))

    # Same random angle for image and mask; the mask uses a constant border.
    angle = (np.random.random() - .50) * 20
    img = rotate_cv(img, angle)
    mask = rotate_cv(mask, angle, y=True)
    if np.random.random() > 0.5:
        img, mask = np.fliplr(img).copy(), np.fliplr(mask).copy()
    img, mask = random_cropXY(img, mask)
    return img, mask_to_bb(mask)

In [None]:
def create_corner_rect(bb, color='red'):
    """Return an unfilled matplotlib rectangle for a ``[row_min, col_min, row_max, col_max]`` box.

    matplotlib anchors rectangles at ``(x, y)`` = (column, row), so the
    row-first box is reordered here.
    """
    r0, c0, r1, c1 = np.array(bb, dtype=np.float32)[:4]
    return plt.Rectangle((c0, r0), c1 - c0, r1 - r0, color=color,
                         fill=False, lw=3)

def show_corner_bb(im, bb):
    """Show ``im`` and outline the row-first box ``bb`` in red on the current axes."""
    plt.imshow(im)
    axes = plt.gca()
    axes.add_patch(create_corner_rect(bb))

In [None]:
# Move the current index into an ordinary 'index' column and renumber the rows.
# NOTE(review): read_dataset already builds a default RangeIndex, so this mostly
# just adds a redundant 'index' column.
df_train = df_train.reset_index(drop=False)

In [None]:
# Create training / validation / test sets (70% / 21% / 9% of the images).
# The spreadsheet has one row per annotated structure, so each image appears in
# several rows. Splitting rows directly would put the same image into more than
# one subset (data leakage). Split on the unique resized image paths instead,
# so all annotations of an image stay together.
_image_ids = df_train['new_path'].astype(str)
_train_imgs, _rest_imgs = train_test_split(_image_ids.unique(), test_size=0.3, random_state=42)
_val_imgs, _test_imgs = train_test_split(_rest_imgs, test_size=0.3, random_state=42)

_in_train = _image_ids.isin(_train_imgs)
_in_val = _image_ids.isin(_val_imgs)
_in_test = _image_ids.isin(_test_imgs)

X_train, y_train = df_train.loc[_in_train, ['new_path', 'new_bb']], df_train.loc[_in_train, 'class']
X_val, y_val = df_train.loc[_in_val, ['new_path', 'new_bb']], df_train.loc[_in_val, 'class']
X_test, y_test = df_train.loc[_in_test, ['new_path', 'new_bb']], df_train.loc[_in_test, 'class']

In [None]:
# ImageNet per-channel RGB mean and standard deviation.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
_IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize(im):
    """Standardise a channels-last RGB image with the ImageNet mean and std."""
    return (im - _IMAGENET_MEAN) / _IMAGENET_STD

def unnormalize(im):
    """Invert normalize (for display); the stats broadcast with shape (1, 1, 3)."""
    std = _IMAGENET_STD.reshape(1, 1, 3)
    mean = _IMAGENET_MEAN.reshape(1, 1, 3)
    return im * std + mean

In [None]:
class FetusDataset(Dataset):
    """Map-style dataset yielding ``(image, class_id, box)`` triples.

    Args:
        paths: image paths (anything exposing ``.values``, e.g. a pandas Series).
        bb: boxes in the row-first order produced by mask_to_bb.
        y: integer class ids.
        transforms: if truthy, transformsXY applies random augmentation.

    Each item is ``(x, y_class, y_bb)``. ``x`` is the ImageNet-normalised image
    with its channel axis moved to the front.
    """

    def __init__(self, paths, bb, y, transforms=False):
        # Store plain arrays so __getitem__ indexes by position, not by the
        # (shuffled) index labels of the split DataFrames.
        self.paths, self.bb, self.y = paths.values, bb.values, y.values
        self.transforms = transforms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image, box = transformsXY(self.paths[idx], self.bb[idx], self.transforms)
        chw = np.moveaxis(normalize(image), 2, 0)
        return chw, self.y[idx], box

In [None]:
# Only the training set is augmented; validation and test get a deterministic centre crop.
train_ds = FetusDataset(X_train['new_path'], X_train['new_bb'], y_train, transforms=True)
valid_ds = FetusDataset(X_val['new_path'], X_val['new_bb'], y_val, transforms=False)
test_ds = FetusDataset(X_test['new_path'], X_test['new_bb'], y_test, transforms=False)

In [None]:
batch_size = 32

# Pinned host memory only speeds up host-to-GPU copies. On a CPU-only machine,
# recent PyTorch versions warn and ignore it, so enable it only when CUDA is present.
pin_memory = torch.cuda.is_available()

# BB_model normalises features with BatchNorm1d, which raises in training mode
# on a batch of a single sample. Drop the final training batch only when it
# would contain exactly one sample.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=pin_memory,
                      drop_last=len(train_ds) % batch_size == 1)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, pin_memory=pin_memory)
test_dl = DataLoader(test_ds, batch_size=batch_size, pin_memory=pin_memory)

In [None]:
class BB_model(nn.Module):
    """ResNet-34 backbone with one bounding-box regression head per class.

    Every head maps the pooled 512-d backbone features to 4 box coordinates.
    ``forward(x)`` returns all heads concatenated, shape ``(batch, 4 * num_classes)``.
    ``forward(x, y_class)`` returns, for each sample, only the box predicted by
    the head of that sample's class, shape ``(batch, 4)``.

    Args:
        num_classes: number of structure classes / regression heads. The
            default of 9 matches the label map used to build the dataset.
    """

    def __init__(self, num_classes=9):
        super().__init__()
        # ``weights=`` replaces the deprecated ``pretrained=True``. IMAGENET1K_V1
        # is the checkpoint that ``pretrained=True`` loaded.
        resnet = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)

        # Keep the first 8 children (conv1 ... layer4) and drop ResNet's own
        # avgpool and fc. Nothing is frozen: the whole backbone is fine-tuned.
        layers = list(resnet.children())[:8]
        self.features1 = nn.Sequential(*layers[:6])
        self.features2 = nn.Sequential(*layers[6:])

        self.num_classes = num_classes
        self.bb_norm = nn.BatchNorm1d(512)
        self.bb = nn.ModuleList([nn.Linear(512, 4) for _ in range(num_classes)])

    def forward(self, x, y_class=None):
        """Predict boxes for a batch of images ``x``.

        Args:
            x: image batch, assumed (batch, 3, H, W) -- TODO confirm for non-RGB input.
            y_class: optional sequence/tensor of class ids, one per sample.
                If None, the boxes of all heads are returned concatenated.
        """
        x = self.features2(self.features1(x))
        x = F.leaky_relu(x)
        x = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
        x = self.bb_norm(x)
        if y_class is None:
            return torch.cat([head(x) for head in self.bb], dim=-1)

        # Compare ids on the same device as the features (loaders yield CPU tensors).
        y_class = torch.as_tensor(y_class, device=x.device)
        # Rows whose class id is outside [0, num_classes) stay uninitialised.
        out = torch.empty(x.shape[0], 4, device=x.device, dtype=x.dtype)
        for cls, head in enumerate(self.bb):
            rows = y_class == cls
            if rows.any():
                out[rows] = head(x[rows])
        return out

In [None]:
def update_optimizer(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``."""
    for group in optimizer.param_groups:
        group["lr"] = lr

In [None]:
# def train_epocs(model, optimizer, train_dl, val_dl, epochs=10,C=1000):
#     # with torch.no_grad():
#     #     val_loss, val_oui, val_dist = val_metrics(model, valid_dl)
#     for i in range(epochs):
#         model.train()
#         total = 0
#         sum_loss = 0
#         for x, y_class, y_bb in tqdm(train_dl, desc=f"Epoch {i+1}/{epochs}"):
#             batch = y_class.shape[0]
#             x = x.to(device).float()
#             #y_class = y_class.to(device)
#             y_bb = y_bb.to(device).float()
#             boxes = model(x, y_class)
#             # loss_class = F.cross_entropy(out_class, y_class, reduction="sum")
#             # print(boxes.shape, y_bb.shape)
#             # pred_y_bb = torch.gather(boxes, -1, y_class[:, None, None].repeat(1, 4, 1))
#             loss_bb = F.l1_loss(boxes, y_bb, reduction="none").sum(1)
#             loss_bb = loss_bb.sum()
#             loss = loss_bb
#             sum_loss +=  loss.detach().cpu().item()
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             total += batch
#         train_loss = sum_loss/total
#         with torch.no_grad():
#             val_loss, val_oui, val_dist = val_metrics(model, valid_dl)
#         torch.save(model.state_dict(), 'model_v5.pth')
#         print("train_loss %.3f val_loss %.3f val_iou %.3f val_dist %.3f" % (train_loss, val_loss, val_oui, val_dist))
#     return sum_loss/total

In [None]:
import matplotlib.pyplot as plt

def train_epocs(model, optimizer, train_dl, val_dl, epochs=10, C=1000, save_path='model_v5.pth'):
    """Train ``model`` for ``epochs`` epochs, validating and checkpointing after each.

    The training loss is the L1 distance between predicted and target boxes,
    summed over the 4 coordinates and over the batch. After every epoch the
    model is evaluated on ``val_dl`` with val_metrics, and its state dict is
    saved to ``save_path``. Loss, IoU and centre-distance curves are plotted
    at the end.

    Args:
        C: unused; kept so existing calls keep working.
        save_path: checkpoint file, overwritten every epoch.

    Returns:
        ``(train_losses, val_losses)``: per-epoch loss per sample.
    """
    train_losses, val_losses, val_ious, val_dists = [], [], [], []

    for epoch in range(epochs):
        model.train()
        total = 0
        sum_loss = 0.0

        for x, y_class, y_bb in tqdm(train_dl, desc=f"Epoch {epoch+1}/{epochs}"):
            x = x.to(device).float()
            y_bb = y_bb.to(device).float()
            boxes = model(x, y_class)
            loss = F.l1_loss(boxes, y_bb, reduction="none").sum(1).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            total += y_class.shape[0]

        train_loss = sum_loss / total
        train_losses.append(train_loss)

        # Evaluate on the ``val_dl`` argument, not on a global loader.
        with torch.no_grad():
            val_loss, val_iou, val_dist = val_metrics(model, val_dl)
        # val_metrics may return 0-d (possibly CUDA) tensors. Store plain floats
        # so the history holds no device memory and can be plotted directly.
        val_loss, val_iou, val_dist = float(val_loss), float(val_iou), float(val_dist)
        val_losses.append(val_loss)
        val_ious.append(val_iou)
        val_dists.append(val_dist)

        torch.save(model.state_dict(), save_path)

        print("Epoch %d/%d: train_loss %.3f val_loss %.3f val_iou %.3f val_dist %.3f" % (
            epoch + 1, epochs, train_loss, val_loss, val_iou, val_dist))

    # Training and validation loss curves.
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()

    # Validation centre distance and IoU curves.
    plt.subplot(1, 2, 2)
    plt.plot(val_dists, label='Validation Distance')
    plt.plot(val_ious, label='Validation Intersection')
    plt.xlabel('Epoch')
    plt.ylabel('Metric Value')
    plt.title('Validation Distance and Intersection over Union (IoU)')
    plt.legend()

    plt.tight_layout()
    plt.show()

    return train_losses, val_losses


In [None]:
# def val_metrics(model, valid_dl, C=1000):
#     model.eval()
#     total = 0
#     sum_loss = 0
#     correct = 0
#     for x, y_class, y_bb in valid_dl:
#         batch = y_class.shape[0]
#         x = x.to(device).float()
#         y_class = y_class.to(device)
#         y_bb = y_bb.to(device).float()
#         out_class, out_bb = model(x)
#         loss_class = F.cross_entropy(out_class, y_class, reduction="sum")
#         loss_bb = F.l1_loss(out_bb, y_bb, reduction="none").sum(1)
#         loss_bb = loss_bb.sum()
#         loss = loss_class + loss_bb/C
#         _, pred = torch.max(out_class, 1)
#         correct += pred.eq(y_class).sum().item()
#         sum_loss += loss.item()
#         total += batch
#     return sum_loss/total, correct/total

In [None]:
import torch.nn.functional as F
from torchvision.ops import box_iou, distance_box_iou

def center_xy(bboxes):
    """Return the midpoint of each box in an ``(N, 4)`` corner-format tensor.

    Each box is ``[a0, b0, a1, b1]``. The result has shape ``(N, 2)`` and holds
    ``((a0 + a1) / 2, (b0 + b1) / 2)``.
    """
    mins = bboxes[:, :2]
    maxs = bboxes[:, 2:]
    return (mins + maxs) / 2

def bbox_center_distance(bbox1, bbox2):
    """Euclidean distance between the centres of paired boxes (row ``i`` with row ``i``); shape ``(N,)``."""
    offsets = center_xy(bbox1) - center_xy(bbox2)
    return torch.norm(offsets, dim=1)

def val_metrics(model, valid_dl):
    """Evaluate box regression on ``valid_dl``; every metric is averaged per sample.

    Returns:
        ``(loss, iou, dist)``:
        - ``loss`` (float): L1 loss summed over the 4 coordinates.
        - ``iou``: IoU between each predicted box and its own target.
        - ``dist``: distance between the two box centres.
        ``iou`` and ``dist`` are 0-d tensors on the model's output device.

    Raises:
        ZeroDivisionError: if ``valid_dl`` yields no samples.
    """
    model.eval()
    total = 0
    sum_loss = 0.0
    iou_total = 0
    dist_total = 0
    with torch.no_grad():
        for x, y_class, y_bb in valid_dl:
            x = x.to(device).float()
            y_bb = y_bb.to(device).float()
            bboxes = model(x, y_class)

            sum_loss += F.l1_loss(bboxes, y_bb, reduction="sum").item()
            # box_iou is pairwise (N x N); only the diagonal pairs a prediction
            # with its own target. Taking the diagonal, rather than multiplying
            # by an identity mask, keeps NaN IoUs of unrelated pairs out of the sum.
            iou_total += box_iou(bboxes, y_bb).diagonal().sum()
            dist_total += bbox_center_distance(bboxes, y_bb).sum()
            total += y_class.shape[0]
    # Average the distance accumulated over all batches, not just the last one.
    return sum_loss / total, iou_total / total, dist_total / total

In [None]:
model = BB_model().to(device)
# Optimise only parameters that require gradients (BB_model freezes none, so this is all of them).
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=0.001)

In [None]:
# Train for 10 epochs, validating on the validation loader after each epoch.
train_epocs(model, optimizer, train_dl, val_dl=valid_dl, epochs=10)

Example outputs of the function below, with the true (green) and predicted (red) bounding boxes drawn, are saved as `EXAMPLE_x_test_model.png` files in the `Example_Output` folder.

In [None]:
def _box_patch(bb, color):
    """Unfilled rectangle for a row-first box ``[row_min, col_min, row_max, col_max]``."""
    row_min, col_min, row_max, col_max = bb
    # matplotlib anchors at (x, y) = (column, row).
    return plt.Rectangle((col_min, row_min), col_max - col_min, row_max - row_min,
                         linewidth=2, edgecolor=color, facecolor='none')

def test_model(model, test_dl, device, max_images=None):
    """Plot ground-truth (green) and predicted (red) boxes for images in ``test_dl``.

    For each sample, the prediction comes from the regression head of its
    true class. Boxes are row-first (``[row_min, col_min, row_max, col_max]``,
    as produced by mask_to_bb) and are converted to matplotlib's (x, y) order
    for drawing.

    Args:
        max_images: stop after plotting this many images (default: all).
    """
    model.eval()
    shown = 0
    with torch.no_grad():
        for x, y_class, y_bb in test_dl:
            x = x.to(device).float()
            # All heads at once, reshaped to (batch, num_heads, 4).
            all_boxes = model(x).reshape(x.shape[0], -1, 4).cpu()
            for i in range(x.shape[0]):
                if max_images is not None and shown >= max_images:
                    return
                cls = int(y_class[i])
                # Undo ImageNet normalisation; clip so imshow gets valid [0, 1] floats.
                img = np.clip(unnormalize(x[i].permute(1, 2, 0).cpu().numpy()), 0, 1)

                plt.imshow(img)
                ax = plt.gca()
                ax.add_patch(_box_patch(y_bb[i].float().numpy(), 'g'))
                ax.add_patch(_box_patch(all_boxes[i, cls].numpy(), 'r'))
                plt.title(str(cls))
                plt.show()
                shown += 1

In [None]:
# Visualise predictions on the held-out test split rather than the augmented training data.
test_model(model, test_dl, device)