In [1]:
# Square image side length handed to DataTransform when the transform is built.
IMAGE_SIZE: int = 512

In [2]:
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from tqdm import tqdm

from PIL import Image, ImageOps

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T

import torchvision.transforms.functional as TF
import torchvision.models as models

import albumentations as A

In [None]:
# Root of the data copy. Set the XRAY_DATA_DIR environment variable to point the
# notebook elsewhere instead of editing this cell; the default is the hardcoded
# location this notebook was originally written against.
DATA_DIR = os.environ.get("XRAY_DATA_DIR", "/Volumes/SSD970/")
IMAGES_DIR = os.path.join(DATA_DIR, "xray_images")

In [None]:
bbox_df = pd.read_csv(os.path.join(DATA_DIR, 'segmentation_sagittal_bbox.csv')).set_index('UID')

# Fail fast if the CSV lacks a column that SagittalBoundaryDataset reads, instead of
# hitting a KeyError later inside __getitem__ (possibly in a DataLoader worker).
_missing_bbox_cols = {'sagittal_index', 'xmin', 'ymin', 'xmax', 'ymax'} - set(bbox_df.columns)
assert not _missing_bbox_cols, f"bbox CSV is missing columns: {sorted(_missing_bbox_cols)}"

bbox_df.head()

In [None]:
class SagittalBoundaryDataset(Dataset):
    """Map-style dataset yielding one sagittal JPEG slice and its bounding box per row.

    Each row of ``df`` is one sample: the row's index label is the UID, used as a
    sub-directory of ``image_dir``; ``sagittal_index`` selects the
    ``<sagittal_index>.jpeg`` file inside it; ``xmin``/``ymin``/``xmax``/``ymax``
    are returned as the label.

    Args:
        df: Annotation frame indexed by UID with the columns listed above.
        image_dir: Root directory containing one folder per UID.
        transform: Optional callable ``transform(img, label) -> (img, label)``.
    """

    def __init__(self, df, image_dir, transform=None):
        super().__init__()

        self.df = df
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _image_path(self, uid, sagittal_index):
        """Build the JPEG path for one sample.

        ``uid`` is stringified so non-string (e.g. integer) index labels still
        work with ``os.path.join``.
        """
        return os.path.join(self.image_dir, str(uid), f"{int(sagittal_index)}.jpeg")

    def __getitem__(self, idx):
        """Return ``(img, label)`` for the row at position ``idx`` (``iloc`` lookup).

        Without a transform, ``img`` is a PIL image and ``label`` is the row's
        ``[xmin, ymin, xmax, ymax]`` slice as a pandas Series (named by the UID).
        """
        row = self.df.iloc[idx]
        path = self._image_path(row.name, row.sagittal_index)

        # PIL opens lazily; copying inside ``with`` releases the file handle right
        # away instead of leaving it open until garbage collection.
        with Image.open(path) as opened:
            img = opened.copy()

        label = row[['xmin', 'ymin', 'xmax', 'ymax']]

        # ``is not None`` so a transform object that happens to be falsy is still applied.
        if self.transform is not None:
            img, label = self.transform(img, label)

        return img, label

dataset = SagittalBoundaryDataset(bbox_df, IMAGES_DIR)
sample_idx = 1
img, label = dataset[sample_idx]
print(label)

# Overlay the box on the slice to eyeball the annotation. cmap="gray" avoids the
# default false-colour map for single-channel images and is ignored for RGB data.
xmin, ymin, xmax, ymax = (float(v) for v in label)
fig, ax = plt.subplots()
ax.imshow(img, cmap="gray")
ax.add_patch(
    patches.Rectangle(
        (xmin, ymin), xmax - xmin, ymax - ymin,
        fill=False, edgecolor="red", linewidth=2,
    )
)
ax.set_title(f"{bbox_df.index[sample_idx]} (sample {sample_idx})")
ax.axis("off");

In [None]:
class DataTransform(nn.Module):
    """Joint image + bounding-box augmentation for one (PIL image, box) sample.

    Steps: pad the image to a square with ``TF.center_crop`` (crop size = longer
    side, so it only pads), shift the box by the same padding offset, resize to
    ``image_size`` x ``image_size``, then apply random shift/scale/rotate and
    brightness/contrast. Albumentations keeps the box in sync with every
    geometric step.

    Args:
        image_size: Side length, in pixels, of the square output image.
    """

    def __init__(self, image_size):
        super().__init__()
        self.image_size = image_size
        self.transform = A.Compose(
            [
                # ``image_size`` used to be accepted but ignored; resize so every
                # sample comes out at the configured resolution.
                A.Resize(image_size, image_size),
                A.ShiftScaleRotate(p=0.5),
                A.RandomBrightnessContrast(p=0.3),
            ],
            bbox_params=A.BboxParams(format='pascal_voc'),
        )

    @staticmethod
    def _square_pad_offset(width, height):
        """Return the ``(left, top)`` padding that ``TF.center_crop`` inserts when
        cropping a ``width`` x ``height`` image to a square of ``max(width, height)``."""
        side = max(width, height)
        return (side - width) // 2, (side - height) // 2

    @staticmethod
    def _shift_bbox(bbox, dx, dy):
        """Translate a pascal_voc ``(xmin, ymin, xmax, ymax)`` box by ``(dx, dy)``."""
        xmin, ymin, xmax, ymax = (float(v) for v in bbox)
        return (xmin + dx, ymin + dy, xmax + dx, ymax + dy)

    def forward(self, x, label):
        """Augment one sample.

        Args:
            x: PIL image (``.width``/``.height`` are read).
            label: A single box as four values ``(xmin, ymin, xmax, ymax)`` in
                ``x``'s pixel coordinates, e.g. the Series returned by
                SagittalBoundaryDataset.

        Returns:
            ``(image, bboxes)``: the augmented image as a numpy array and the list
            of transformed boxes (albumentations may drop the box if augmentation
            moves it out of frame).
        """
        # The padding shifts image content, so the box must be shifted identically.
        dx, dy = self._square_pad_offset(x.width, x.height)
        x = TF.center_crop(x, max(x.width, x.height))
        bbox = self._shift_bbox(label, dx, dy)

        # Albumentations operates on numpy arrays (np.array gives a writable copy)
        # and expects ``bboxes`` to be a list of boxes, not one flat box.
        transformed = self.transform(image=np.array(x), bboxes=[bbox])

        return transformed["image"], transformed["bboxes"]
transform = DataTransform(image_size=IMAGE_SIZE)