In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import draw_bounding_boxes
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.ops import RoIPool
from torch.optim.lr_scheduler import StepLR
from PIL import Image
import pandas as pd
from torch.utils.data import random_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Synchronous CUDA kernel launches give accurate tracebacks for CUDA errors but
# slow training considerably; only enable while debugging a CUDA failure.
DEBUG_CUDA = False
if DEBUG_CUDA:
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Let the CUDA caching allocator use expandable segments (reduces fragmentation
# per the PyTorch CUDA memory-management docs). Must be set before the first
# CUDA allocation to take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Seed torch's global RNG so the train/val split, shuffling and the randomly
# initialised heads are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)


In [None]:
class FetusDetector(nn.Module):
    """ResNet-50 backbone with per-slot class and box heads.

    The backbone's globally pooled image feature is concatenated with the
    flattened input boxes (``num_boxes * 4`` values) and fed to two linear
    heads:

    * ``classifier`` -> ``num_boxes * num_classes`` raw (unnormalised) logits,
      i.e. one ``num_classes``-way score vector per box slot, suitable for
      ``F.cross_entropy`` after reshaping to ``(-1, num_classes)``;
    * ``bb`` -> ``num_boxes * 4`` box regression outputs.

    The defaults (9 box slots, 9 classes) reproduce the original layer sizes
    (2084 -> 81 and 2084 -> 36), so existing checkpoints still load.

    Args:
        num_boxes: number of box slots per image; ``forward`` requires the
            ``box`` tensor to hold exactly this many boxes per image.
        num_classes: number of structure classes scored per slot.
        weights: torchvision ResNet-50 weights enum, or ``None`` for random
            initialisation (avoids a weight download).
    """

    def __init__(self, num_boxes=9, num_classes=9, weights=ResNet50_Weights.DEFAULT):
        super().__init__()
        self.num_boxes = num_boxes
        self.num_classes = num_classes

        resnet = resnet50(weights=weights)
        # Keep conv1 .. layer4; drop the final avgpool and fc.
        layers = list(resnet.children())[:8]
        self.features1 = nn.Sequential(*layers[:6])  # stem + layer1 + layer2
        self.features2 = nn.Sequential(*layers[6:])  # layer3 + layer4
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Pooled backbone width (2048 for ResNet-50) plus the flattened boxes.
        in_features = resnet.fc.in_features + num_boxes * 4
        # No Sigmoid here: the training loss is F.cross_entropy, which expects
        # unnormalised logits. Squashing them into (0, 1) caps the achievable
        # margin and flattens the gradients. Kept as a Sequential so the
        # state_dict key stays "classifier.0.*".
        self.classifier = nn.Sequential(
            nn.Linear(in_features, num_boxes * num_classes),
        )
        self.bb = nn.Linear(in_features, num_boxes * 4)

    def forward(self, image, box):
        """Return ``(cls_logits, bbox_logits)`` for a batch.

        Args:
            image: image batch passed to the ResNet-50 layers, presumably
                (B, 3, H, W). TODO confirm against the data pipeline.
            box: per-image boxes. Flattened to (B, -1), which must give
                ``num_boxes * 4`` values per image.

        Returns:
            cls_logits: (B, num_boxes * num_classes) raw class logits.
            bbox_logits: (B, num_boxes * 4) box regression outputs.

        Raises:
            ValueError: if ``box`` does not contain ``num_boxes * 4`` values
                per image.
        """
        x = self.features1(image)
        x = self.features2(x)
        x = F.relu(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        box = box.reshape(box.size(0), -1)
        expected = self.num_boxes * 4
        if box.size(1) != expected:
            raise ValueError(
                f"expected {self.num_boxes} boxes ({expected} values) per image, "
                f"got {box.size(1)} values; pad/truncate boxes to num_boxes"
            )
        x = torch.cat((x, box), dim=1)

        cls_logits = self.classifier(x)
        bbox_logits = self.bb(x)
        return cls_logits, bbox_logits


In [None]:
class FetusDataset(Dataset):
    """Fetal ultrasound structure-detection dataset with one sample per image.

    ``labels_file`` is an Excel sheet with one row per annotated structure:
    the first column is the image file name (relative to ``data_dir``), the
    second the structure name (a key of ``label_map``), and columns 3-6 the box
    coordinates in (h_min, w_min, h_max, w_max) order, in original-image pixels.

    Each item is ``(image, bboxes, labels)``:

    * ``image``: output of ``transform``, or a uint8 CHW tensor if no
      transform is given;
    * ``bboxes``: float32 tensor (num_objects, 4) in
      (w_min, h_min, w_max, h_max) order, rescaled to the transformed size;
    * ``labels``: int64 tensor (num_objects,) of ``label_map`` indices.

    NOTE(review): the file-existence filter reads the column named "fname",
    while grouping uses the first column; these are assumed to be the same.
    """

    def __init__(self, data_dir, labels_file, transform=None):
        self.data_dir = data_dir
        self.labels_df = pd.read_excel(labels_file)
        self.transform = transform
        self.label_map = {
            'thalami': 0,
            'nasal bone': 1,
            'palate': 2,
            'nasal skin': 3,
            'nasal tip': 4,
            'midbrain': 5,
            'NT': 6,
            'IT': 7,
            'CM': 8
        }

        self.transform_PIL = transforms.Compose([
            transforms.ToPILImage()
        ])

        self.transform_tensor = transforms.Compose([
            transforms.PILToTensor()
        ])

        # Drop annotation rows whose image file does not exist.
        exists = self.labels_df["fname"].map(
            lambda fname: os.path.exists(os.path.join(self.data_dir, fname))
        )
        self.labels_df = self.labels_df[exists]

        # One sample per image. Indexing per annotation row would return each
        # image once per box, oversampling heavily annotated images and letting
        # the same image land in both train and validation splits. Grouping
        # once here also avoids a full-table filter on every __getitem__.
        # sort=False keeps the images in first-appearance order.
        fname_col = self.labels_df.columns[0]
        self._samples = [
            (
                fname,
                group.iloc[:, 2:6].to_numpy(dtype=float),  # (k, 4) raw coords
                group.iloc[:, 1].tolist(),                 # k structure names
            )
            for fname, group in self.labels_df.groupby(fname_col, sort=False)
        ]

    def __len__(self):
        """Number of distinct images that have at least one annotation."""
        return len(self._samples)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, bboxes, labels)``.

        Raises:
            KeyError: if an annotation's structure name is not in ``label_map``.
        """
        fname, coords, names = self._samples[idx]

        # Close the file handle promptly; convert() returns a loaded copy.
        with Image.open(os.path.join(self.data_dir, fname)) as raw:
            image = raw.convert('RGB')
        og_w, og_h = image.size  # PIL reports (width, height)

        if self.transform:
            image = self.transform(image)
        else:
            # Without a transform, still return a tensor so `.shape` is defined.
            image = self.transform_tensor(image)

        n_h, n_w = image.shape[-2:]
        scale_w = n_w / og_w
        scale_h = n_h / og_h

        # Sheet order is (h_min, w_min, h_max, w_max); emit
        # (w_min, h_min, w_max, h_max) scaled to the transformed image size.
        scaled = coords[:, [1, 0, 3, 2]] * [scale_w, scale_h, scale_w, scale_h]
        bboxes = torch.tensor(scaled, dtype=torch.float32)

        try:
            labels = torch.tensor([self.label_map[name] for name in names], dtype=torch.int64)
        except KeyError as err:
            raise KeyError(f"Unknown structure label {err.args[0]!r} for image {fname!r}") from err

        return image, bboxes, labels


In [None]:
# Paths to the images and the annotation spreadsheet. Build the directory with
# os.path.join: a backslash string literal contains invalid escape sequences
# ("\T", "\S") and only works on Windows.
data_dir = os.path.join('Dataset for Fetus Framework', 'Training', 'Standard')
labels_file = 'ObjectDetection.xlsx'

# Resize to a fixed (H, W) = (470, 650) and normalise with the ImageNet mean/std
# used by the pretrained ResNet-50 backbone. No augmentation is applied.
transform = transforms.Compose([
    transforms.Resize((470, 650)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Load the full dataset; the train/validation split happens in a later cell.
dataset = FetusDataset(data_dir, labels_file, transform=transform)

# Fail early if no annotated image was found on disk.
if len(dataset) == 0:
    raise ValueError("No valid samples found in the dataset")


In [None]:
def custom_collate(batch, pad_to=None, label_pad_value=-100):
    """Stack ``(image, bboxes, labels)`` samples, padding the variable box counts.

    Args:
        batch: list of ``(image, bboxes, labels)`` with ``bboxes`` of shape
            (k, 4) and ``labels`` of shape (k,); ``k`` may differ per sample.
        pad_to: pad every sample to this many boxes. ``None`` pads to the
            largest ``k`` in the batch.
        label_pad_value: label written into padded slots. The default -100 is
            ``F.cross_entropy``'s ``ignore_index``, so padding is not trained
            as a real class. Padded boxes are all zeros.

    Returns:
        ``(images, boxes, labels)`` with shapes (B, ...), (B, n, 4) and (B, n),
        where ``n`` is the padded box count. Labels keep the samples' dtype.

    Raises:
        ValueError: if ``pad_to`` is smaller than a sample's box count.
    """
    max_num_boxes = max(len(bboxes) for _, bboxes, _ in batch)
    if pad_to is not None:
        if pad_to < max_num_boxes:
            raise ValueError(f"pad_to={pad_to} is smaller than a sample's box count ({max_num_boxes})")
        max_num_boxes = pad_to

    images, boxes, labels = [], [], []
    for image, bboxes, lbls in batch:
        # Normalise shapes so an image with no boxes still concatenates cleanly.
        bboxes = bboxes.reshape(-1, 4)
        lbls = lbls.reshape(-1)
        n_pad = max_num_boxes - len(bboxes)
        if n_pad > 0:
            # new_zeros/new_full keep each tensor's own dtype and device, so
            # int64 labels stay int64 instead of being promoted to float.
            bboxes = torch.cat((bboxes, bboxes.new_zeros(n_pad, 4)), dim=0)
            lbls = torch.cat((lbls, lbls.new_full((n_pad,), label_pad_value)), dim=0)
        images.append(image)
        boxes.append(bboxes)
        labels.append(lbls)

    return torch.stack(images, dim=0), torch.stack(boxes, dim=0), torch.stack(labels, dim=0)


In [None]:
# 80/20 train/validation split. An explicitly seeded generator makes the split
# identical on every Restart & Run All, independent of earlier RNG use.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, ground_dataset = random_split(dataset, [0.8, 0.2], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)

# Validation metrics do not depend on sample order, so don't shuffle.
val_loader = DataLoader(ground_dataset, batch_size=16, shuffle=False, collate_fn=custom_collate)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the detector and move it to the target device (once is enough).
model = FetusDetector().to(device)

# NOTE(review): not used by the training loop below, which calls
# F.cross_entropy directly; kept for any later cell that references it.
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

# Multiply the learning rate by 0.1 every 10 scheduler steps; the training
# loop calls scheduler.step() once per epoch.
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
def smooth_l1_loss(prediction, target, beta=1.0):
    """Mean smooth-L1 (Huber-style) loss between predicted and target boxes.

    ``prediction`` is reshaped to ``target.shape`` first, so a flat
    (B, num_boxes * 4) head output can be compared with (B, num_boxes, 4)
    targets. Element-wise, with ``d = |prediction - target|``, the loss is
    ``0.5 * d**2 / beta`` if ``d < beta`` and ``d - 0.5 * beta`` otherwise.
    The result is averaged over all elements.

    Args:
        prediction: tensor with the same number of elements as ``target``.
        target: target tensor.
        beta: transition point between the quadratic and linear regimes;
            ``beta == 0`` gives plain L1 with finite gradients.

    Returns:
        Scalar tensor.
    """
    # reshape (not view) also accepts non-contiguous predictions.
    prediction = prediction.reshape(target.shape)
    # F.smooth_l1_loss implements the same piecewise formula. Unlike a
    # torch.where over both branches, it does not divide by beta when
    # beta == 0, which would otherwise produce NaN gradients.
    return F.smooth_l1_loss(prediction, target, reduction='mean', beta=beta)


In [None]:
NUM_CLASSES = 9  # must match FetusDataset.label_map and FetusDetector's classifier head
num_epochs = 40
best_val_loss = float('inf')
checkpoint_path = 'lmao4.pt'


def compute_batch_metrics(model, image, boxes, labels):
    """Forward one batch and return ``(loss, n_correct, n_labelled)``.

    ``loss`` is the per-slot cross-entropy plus the smooth-L1 box loss. Slots
    labelled -100 (``F.cross_entropy``'s ignore_index, used for padding) are
    ignored by the loss. Slots with any negative label are excluded from the
    accuracy counts.
    """
    # NOTE(review): the ground-truth boxes are both a model *input* and the
    # regression *target*, so the bbox head can learn to copy its input and
    # the box loss says little about localisation. Consider predicting boxes
    # from the image alone.
    cls_logits, bbox_logits = model(image, boxes)

    # [batch, num_boxes * NUM_CLASSES] -> [batch * num_boxes, NUM_CLASSES]
    cls_logits_flat = cls_logits.view(-1, NUM_CLASSES)
    # [batch, num_boxes] -> [batch * num_boxes]; one class label per box slot
    labels_flat = labels.view(-1).long()

    loss_cls = F.cross_entropy(cls_logits_flat, labels_flat, ignore_index=-100)
    loss_bbox = smooth_l1_loss(bbox_logits, boxes)
    loss = loss_cls + loss_bbox

    # Accuracy must use the same per-slot reshaping as the loss. An argmax over
    # the flat [batch, num_boxes * NUM_CLASSES] output would mix slots together.
    predicted = cls_logits_flat.argmax(dim=1)
    labelled = labels_flat >= 0
    n_correct = (predicted[labelled] == labels_flat[labelled]).sum().item()
    return loss, n_correct, int(labelled.sum().item())


for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    running_loss = 0.0
    train_correct, train_labelled = 0, 0

    for image, boxes, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
        optimizer.zero_grad()
        image, boxes, labels = image.to(device), boxes.to(device), labels.to(device)

        loss, n_correct, n_labelled = compute_batch_metrics(model, image, boxes, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_correct += n_correct
        train_labelled += n_labelled

    # Loss is averaged per batch; accuracy is pooled over all labelled slots.
    avg_loss = running_loss / max(len(train_loader), 1)
    avg_acc = train_correct / max(train_labelled, 1)

    # ---- validation ----
    model.eval()
    val_running_loss = 0.0
    val_correct, val_labelled = 0, 0
    with torch.no_grad():
        for image, boxes, labels in val_loader:
            image, boxes, labels = image.to(device), boxes.to(device), labels.to(device)
            loss, n_correct, n_labelled = compute_batch_metrics(model, image, boxes, labels)
            val_running_loss += loss.item()
            val_correct += n_correct
            val_labelled += n_labelled

    avg_val_loss = val_running_loss / max(len(val_loader), 1)
    avg_val_acc = val_correct / max(val_labelled, 1)

    print(f"Epoch {epoch+1}, Avg. Loss: {avg_loss}, Avg. accuracy: {avg_acc}, "
          f"Val. Loss: {avg_val_loss}, Val. accuracy: {avg_val_acc}")

    # Checkpoint only when validation loss improves, so the saved weights are
    # the best seen rather than simply the last epoch's.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), checkpoint_path)

    scheduler.step()
