In [42]:

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_fireworks import ChatFireworks
from langchain_openai import ChatOpenAI

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a deep learning paper-writing assistant tasked with writing an excellent, comprehensive, and detailed description of a deep learning project."
            " Generate the best paper possible given the user's code, which is a deep learning project containing the training code, the model architecture code, and the data preparation code."
            " To generate the best paper, consider the following points: parameter use; model architecture, such as what each layer does and what parameters are used; data preparation, such as input data size and any preprocessing; training, such as the loss function and normalization; evaluation, such as which metrics are used; and inference, such as what post-processing is done."
            " You must make it clear what the code does, how it works, and why it is important. The summary should be detailed and comprehensive, but also concise and easy to understand. The summary should be written in a professional and engaging tone, and should be free of grammatical errors. Pay particular attention to parameter use in the model."
            " If the user provides critique, respond with a revised version of your previous attempts."
            f" The overall structure can follow {demo}. Do not use any facts from the demo, such as data, that are not provided by the user in the code."
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
llm = ChatOpenAI(
    model="gpt-4o",
)
generate = prompt | llm
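
The system prompt above interpolates `demo` with an f-string before `ChatPromptTemplate` parses the result, so any literal `{` or `}` inside `demo` would be read as a template variable. A minimal sketch of one way to guard against that, assuming the default f-string template format; `escape_braces` is a hypothetical helper (not part of LangChain), and plain `str.format` stands in for the template engine here:

```python
def escape_braces(text: str) -> str:
    """Double every brace so format-style templating treats it as literal text."""
    return text.replace("{", "{{").replace("}", "}}")


# Hypothetical reference text that happens to contain braces.
demo_text = "Loss is logged as {loss:.4f} per epoch."
template = "Follow this structure: " + escape_braces(demo_text) + " Topic: {topic}"

# str.format stands in for the prompt template's variable substitution.
rendered = template.format(topic="segmentation")
print(rendered)
```

With the braces doubled, only `{topic}` is treated as a variable and the reference text passes through unchanged.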

In [43]:
summary = ""
request = HumanMessage(
    content=f"Write a 3-page paper on the code provided, which is a segmentation model used to detect centerlines of dendrites under a microscope\n {code}",
)
for chunk in generate.stream({"messages": [request]}):
    print(chunk.content, end="")
    summary += chunk.content

# Centerline Detection of Neuronal Dendrites: Deep Learning Model Workflow

## 1. Introduction

The objective of this project is to develop a robust deep learning model to detect the centerlines of neuronal dendrite membranes from grayscale images. The main goal is to accurately identify the 'backbone' structure within the images, ensuring continuity even in the presence of noise. This task is categorized under binary segmentation, focusing on isolating central lines within membrane structures. The provided dataset consists of 16-bit grayscale images of size 65x65 pixels. The membrane structures have intensity values ranging from approximately 150 to 1000, while the background intensity values range from 110 to 150. Accurate segmentation is challenging due to noise, which sometimes has pixel values similar to those of the membranes.

## 2. Data Preparation

### 2.1 Data Collection
The dataset comprises 28,634 samples stored as 65x65 16-bit PNG images. Each image is grayscale, with corr

RemoteProtocolError: peer closed connection without sending complete message body (incomplete chunked read)
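
The run above ended with a `RemoteProtocolError` partway through the stream, which leaves `summary` truncated. One way to handle this is sketched below under stated assumptions: `stream_with_retry` and `stream_fn` are hypothetical names, `stream_fn` is any zero-argument callable returning an iterator of text chunks (for example, a wrapper around `generate.stream(...)` that yields `chunk.content`), and the exception types worth retrying are supplied by the caller.

```python
import time
from typing import Callable, Iterable, Tuple, Type


def stream_with_retry(
    stream_fn: Callable[[], Iterable[str]],
    retry_on: Tuple[Type[BaseException], ...] = (ConnectionError,),
    max_attempts: int = 3,
    backoff_seconds: float = 0.0,
) -> str:
    """Collect a streamed response, restarting from scratch if the stream drops."""
    for attempt in range(1, max_attempts + 1):
        parts = []  # partial text from a failed attempt is discarded
        try:
            for piece in stream_fn():
                parts.append(piece)
            return "".join(parts)
        except retry_on:
            if attempt == max_attempts:
                raise
            time.sleep(backoff_seconds * attempt)
    raise ValueError("max_attempts must be at least 1")


# Usage with a fake stream that fails once, then succeeds.
calls = {"n": 0}


def flaky_stream():
    calls["n"] += 1
    yield "Centerline "
    if calls["n"] == 1:
        raise ConnectionError("peer closed connection")
    yield "detection"


result = stream_with_retry(flaky_stream)
print(result)
```

Restarting the whole request, rather than appending to the partial text, avoids stitching together two different generations.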

In [35]:
reflection_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a teacher grading a paper about a deep learning project. Generate critique and recommendations for the user's submission."
            f" The overall paper structure can follow {demo}."
            " Provide detailed recommendations, including requests for length, level of detail, clarity, and format."
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
reflect = reflection_prompt | llm

In [36]:
reflection = ""
for chunk in reflect.stream({"messages": [request, HumanMessage(content=summary)]}):
    print(chunk.content, end="")
    reflection += chunk.content

Title: Deep Learning Model for Centerline Detection of Neuronal Dendrites

Authors: [Your Name(s)]

Abstract:
This paper presents a deep learning model for the detection of centerlines in neuronal dendrite images. The model employs a customized Gabor U-Net architecture, designed to capture intricate structural details in high-depth grayscale images. The model is trained and evaluated on a dataset consisting of 28,634 16-bit grayscale images, with corresponding segmentation masks. We explore various data augmentation techniques, model architectures, and loss functions to improve the model's performance. The best-performing model achieves a high F1 score of 0.9532, demonstrating its effectiveness in accurately identifying and segmenting the central lines of membrane structures.

1. Introduction
The automatic detection of centerlines in neuronal dendrite images is a crucial task in understanding the morphology and connectivity of neuronal networks. This paper introduces a deep learning mo

In [37]:
updated_paper = ""
for chunk in generate.stream(
    {"messages": [request, AIMessage(content=summary), HumanMessage(content=reflection)]}
):
    print(chunk.content, end="")
    updated_paper += chunk.content

Title: Deep Learning Model for Centerline Detection of Neuronal Dendrites

Authors: [Your Name(s)]

Abstract:
This paper presents a deep learning model for the detection of centerlines in neuronal dendrite images. The model employs a customized Gabor U-Net architecture, designed to capture intricate structural details in high-depth grayscale images. The model is trained and evaluated on a dataset consisting of 28,634 16-bit grayscale images, with corresponding segmentation masks. We explore various data augmentation techniques, model architectures, and loss functions to improve the model's performance. The best-performing model achieves a high F1 score of 0.9532, demonstrating its effectiveness in accurately identifying and segmenting the central lines of membrane structures.

1. Introduction
The automatic detection of centerlines in neuronal dendrite images is a crucial task in understanding the morphology and connectivity of neuronal networks. This paper introduces a deep learning mo
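
Cells 35 to 37 run one round of generate, critique, and revise by hand. A minimal sketch of the same loop as a function, assuming hypothetical callables `generate_fn` and `reflect_fn` that take a list of `(role, text)` pairs and return text; in this notebook they would wrap `generate` and `reflect`, with the draft passed to the reflector as a human message, as in cell 36:

```python
from typing import Callable, List, Tuple

Message = Tuple[str, str]  # (role, text); roles are "human" or "ai"


def reflect_and_revise(
    generate_fn: Callable[[List[Message]], str],
    reflect_fn: Callable[[List[Message]], str],
    request: str,
    rounds: int = 1,
) -> str:
    """Draft once, then alternate critique and revision for `rounds` rounds."""
    draft = generate_fn([("human", request)])
    for _ in range(rounds):
        # The reflector sees the draft as if the user had submitted it.
        critique = reflect_fn([("human", request), ("human", draft)])
        # The generator sees its own draft followed by the critique.
        draft = generate_fn(
            [("human", request), ("ai", draft), ("human", critique)]
        )
    return draft


# Usage with stand-in functions instead of real model calls.
def fake_generate(messages):
    return f"draft v{len(messages)}"


def fake_reflect(messages):
    return "add more detail"


final = reflect_and_revise(fake_generate, fake_reflect, "Write a paper", rounds=2)
print(final)
```

Each round here passes only the latest draft and critique, which keeps the context short; carrying the full history forward is an alternative design.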

In [25]:
print(f"Write a summary on the code provided, which is a segmentation model used to detect centerlines of dendrites under a microscope\n {code}")

Write a summary on the code provided, which is a segmentation model used to detect centerlines of dendrites under a microscope
 

[Python code for training the model]

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from data import CenterlineDataset
from gabor_unet_nfc import GaborUNet
import argparse
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torchvision import transforms
import matplotlib.pyplot as plt
from torch import nn
import os
import datetime
import uuid
import math
from scipy.ndimage import distance_transform_edt
import numpy as np
from unet_model import UNet
import torch
from torchvision.transforms import functional as TF




# class WeightedBCELoss(nn.Module):
#     '''Less weight is given to pixels in the dilated area but not in the original target.'''
#     def __init__(self, custom_weight=0.5, loss_function='bce'):
#         super(WeightedBCELoss, self).__init__()
#         self.custom_weight = custom

In [41]:
demo = """
Centerline Detection of Neuronal Dendrites: Deep Learning Model Workflow
Chainathan Santhanam Sudhakar 24 May 2024
1 Introduction
The objective of this project is to develop a model for centerline detection of Neuronal Dendrite membranes. The goal is to predict the ’backbone’ of the structure with outputs that are as continuous as possible. This task falls under the category of binary segmentation, focusing on identifying the central lines within membrane structures in grayscale images.
The provided dataset consists of 16-bit grayscale images of size 65x65 pixels. The membrane structures have intensity values ranging from approximately 150 to 1000, whereas the background intensity ranges from approximately 110 to 150. However, it is noted that in some areas, the noise pixel values are similar to those of the membranes, which presents an additional challenge for accurate segmentation.
The model needs to handle these high-depth images and accurately segment the central lines of the membranes despite the presence of noise.
2 Data Preparation
2.1 Data Collection
The data used for this task was provided by evaluators, consisting of 28,634 samples stored as 65x65 16-bit PNG images. Each image is grayscale, with corresponding segmentation masks also stored as 65x65 PNGs. This ensures a direct mapping between images and their labels.
2.2 Data Preprocessing
Effective data preprocessing is critical for preparing the data for model training. The following steps were taken:
1. Intensity Range Adjustment:
• From the data analysis, the intensity range for the background was adjusted from [110, 150] to [100, 150], and for the membrane from [150, 1000] to [150, 1010]. This adjustment ensures a higher percentage of samples fall within these ranges, thereby improving the robustness of the model.
2. Handling Empty Samples:
• It was observed that 1,000 samples (3.5% of the total dataset) contained no centerline information. These empty samples were reduced to 100, as retaining a small number of empty samples (0.35% of the total dataset) helps the model learn to identify the absence of centerlines, which is a valid scenario. However, having too many could bias the model towards predicting no centerline, which would be detrimental to its performance on samples with actual centerlines.
3. Normalization and Clipping:
• Images were clipped to the specified ranges and normalized to the [0, 1] range using MINMAX Normalization. This step is crucial for ensuring that the pixel values are within a consistent range for model training, enhancing convergence during the training process.
4. Data Augmentation:
• Applied transformations to increase the diversity of the training data and improve the model’s ability to generalize to new, unseen data:
– Elastic Deformation: These deformations simulate realistic distortions that could occur in the membrane images, helping the model become robust to slight variations and distortions. Elastic deformation involves randomly distorting the image and mask together, maintaining their alignment.
– Horizontal and Vertical Flips: These augmentations enhance the model’s ability to recognize membranes irrespective of their orientation.
Additionally, an attempt was made to use a Patch and Reconstruct approach:
• Patch Construction: This approach involved randomly selecting four images, constructing larger patches, and using these to train the model. The images were then either downscaled to 65x65 or directly fed into the model as 110x110 images. However, the results were suboptimal. Downscaling led to significant information loss due to the small resolution, while the larger patch size might have led to difficulty in learning due to the increase in input size, potentially requiring modifications in the model architecture and hyperparameters to handle effectively.
5. Dataset and DataLoader Construction:
• Custom PyTorch Dataset classes (CenterlineDataset and PatchConstructDataset) were implemented to handle the preprocessing steps and facilitate loading of data during training. The DataLoader objects were created for batching and shuffling the data during training and evaluation phases.
2.3 Data Splitting
The dataset was split into training, validation, and test sets with a ratio of 70% training, 15% validation, and 15% testing. This ensures that the model is trained on a substantial portion of the data while reserving enough data for unbiased validation and testing.
3 Model Development
3.1 Model Selection
For the task of centerline detection of membranes, the U-Net architecture was selected due to its effectiveness in biomedical image segmentation tasks. U-Net is particularly well-suited for this task as it captures both the spatial context and fine-grained details, which are essential for accurate segmentation.
Additionally, several variations of U-Net were experimented with to enhance performance:
1. Standard U-Net: A custom baseline model for our task based on Original U-Net architecture.
2. Mini U-Net: A smaller version to test the trade-off between model complexity and performance.
3. Extended U-Net: An extended version with additional layers to improve noise handling and segmentation accuracy.
4. U-Net with Self-Attention (U-Net SA): Incorporates self-attention mechanisms to better capture long-range dependencies and improve segmentation quality.
5. Mini U-Net with Self-Attention: A smaller version of U-Net SA.
3.2 Model Architecture
1. Standard U-Net:
• Size: 1931201
• Encoding Path: Consists of three encoding blocks, each followed by a max-pooling layer. Each block includes two convolutional layers with ReLU activations and batch normalization, designed to progressively capture more complex features.
• Bottleneck: A convolutional block at the lowest resolution to capture the most abstract features of the input image.
• Decoding Path: Symmetrically structured like the encoding path, but with transposed convolutions (upconvolutions) for upsampling. Each decoding block combines features from the corresponding encoding block through concatenation, followed by convolutional layers.
• Output Layer: A series of convolutional layers with ReLU activations and batch normalization, ending with a sigmoid activation to produce the final segmentation map.
2. Mini U-Net:
• Size: 471361
• Encoding Path: Similar to the standard U-Net but with only two encoding blocks, reducing the model’s complexity and computational requirements.
• Bottleneck and Decoding Path: Follows the same principles as the standard U-Net but with fewer layers.
• Output Layer: Matches the structure of the standard U-Net’s output layer.
3. Extended U-Net:
• Size: 1933577
• Encoding and Decoding Paths: Same structure as the standard U-Net but with additional layers in the output segment to enhance feature extraction and noise reduction.
• Output Layer: Contains more convolutional layers with batch normalization and ReLU activations, aiming to refine the segmentation output and mitigate the influence of noise.
4. U-Net with Self-Attention (U-Net SA):
• Size: 1958364
• Encoding Path: Same structure as the standard U-Net but includes self-attention modules in the skip connection block to capture global contextual information.
• Decoding Path: Similar to the standard U-Net but incorporates the enhanced feature maps from the self-attention modules.
• Output Layer: Structured like the standard U-Net, ending with a sigmoid activation for the final output.
5. U-Net Mini with Self-Attention (U-Net Mini SA):
• Size: 477883
• Encoding Path: Similar to the standard U-Net with Self Attention but with only two encoding blocks, reducing the model’s complexity and computational requirements.
3.3 Training
3.3.1 Training Setup
1. Hyperparameters:
• Learning Rate: Initially set to 1e-3 and then adjusted towards 1e-4 based on the convergence behavior.
• Batch Size: Depending on the available computational resources (the models were trained on Kaggle or a local GPU based on availability), a batch size of 128 (64 on the local GPU) was chosen to balance training speed and stability.
• Number of Epochs: Set to 50 epochs, or until convergence.
2. Loss Function:
• Binary Cross-Entropy Loss: Given the binary nature of the segmentation task, BCE is used as it measures the difference between the predicted probability distribution and the actual distribution. It is particularly suitable for pixel-wise classification tasks like segmentation.
• Combined: Dice Loss + BCE Loss. Dice Loss is additionally used to handle class imbalance and ensure that the overlap between the predicted and actual segments is maximized. The combined loss was used to leverage the strengths of both.
3. Optimization Algorithm:
• Adam Optimizer: Selected for its efficiency and adaptive learning rate capabilities. Adam combines the advantages of AdaGrad and RMSProp, making it well-suited for dealing with sparse gradients and noisy data.
4. Training Loop:
• Forward Pass: Compute the predicted segmentation mask for the input batch.
• Loss Calculation: Calculate the BCE or Combined Loss for the batch.
• Backward Pass: Perform backpropagation to compute the gradients.
• Weight Update: Update the model weights using the Adam optimizer.
• Monitoring: Track training and validation loss and metrics at the end of each epoch to monitor progress.
3.4 Testing
3.4.1 Model Evaluation
1. Performance Metrics:
• Intersection over Union (IoU): Measures the overlap between the predicted segmentation and the ground truth, divided by the union of both. It is a critical metric for segmentation tasks, especially for evaluating the accuracy of the predicted regions.
• Dice Coefficient: Similar to IoU, it measures the overlap between the predicted and actual segments but is more sensitive to small segmentations, making it useful for medical image segmentation where precision is crucial.
• Precision and Recall: Precision measures the ratio of true positive predictions to all positive predictions, while recall measures the ratio of true positive predictions to all actual positives. These metrics are important for understanding the balance between false positives and false negatives.
• F1-Score: The harmonic mean of precision and recall, providing a single metric that balances both aspects. (Primary Focus)
2. Evaluation Procedure:
• Loading the Test Data: Use the test set created during the data preparation phase. Ensure that the data is not used during training to maintain an unbiased evaluation.
• Model Inference: Run the trained model on the test set to obtain predicted segmentation masks.
• Metric Calculation: Compute the performance metrics for the predicted masks against the ground truth masks.
3. Visualization:
• Segmentation Results: Visualize a few test images with their predicted and ground truth segmentation masks to qualitatively assess the model’s performance.
• Metric Visualization: Plot metrics for the test set to provide a clear understanding of the model’s performance.
4 Inference
4.1 Loading the Trained Model
1. Model Loading:
• Load the trained model weights from the saved checkpoint so that the model can be used for inference on new data. This involves initializing the model architecture, loading the state dictionary with the trained weights, and moving the model to the appropriate device.
4.2 Running Inference on New Data
1. Data Preparation:
• Loading Images: Load the new 16-bit grayscale images.
• Clipping Intensity Values: Clip the intensity values of the images to the range [100, 1010].
• Normalization: Normalize the images to the [0, 1] range.
• Adding Channel Dimension: Add a channel dimension to the images to match the input requirements of the model.
2. Inference Process:
• Pass the new data through the loaded model to obtain the predicted segmentation masks.
• Ensure the model is in evaluation mode to disable batch normalization layers from updating during inference.
3. Post-processing:
• Apply post-processing steps to refine the predicted masks. This includes thresholding the output probabilities to obtain binary masks and applying morphological operations like skeletonization depending on the task requirement.
4.3 Visualization of Inferred Segmentation Masks
1. Displaying Results:
• Visualize the original images alongside the predicted segmentation masks to qualitatively assess the model’s performance on new data. This helps in understanding how well the model generalizes to unseen samples.
2. Saving Results:
• Save the predicted masks for further analysis or for use in downstream applications. This can be done by converting the predicted tensors to images and saving them in the desired format.
5 Results
The models were evaluated based on several metrics, with a particular focus on the F1 score due to its balance between precision and recall. Below are the results of the different model architectures and configurations, listed in order of their F1 scores:
Model            Epoch  Loss      F1 Score  Dice    IoU     Precision  Recall
Normal U-Net     42     BCE       0.9532    0.9532  0.9106  0.9687     0.9381
Mini U-Net       43     BCE       0.9409    0.9409  0.8884  0.9644     0.9185
Mini SA U-Net    19     BCE       0.9334    0.9334  0.8754  0.9552     0.9128
Extended U-Net   55     Combined  0.9425    0.9425  0.8914  0.9235     0.9624
SA U-Net         40     BCE       0.9249    0.9249  0.8606  0.9487     0.9025
Table 1: Best Performance metrics of different U-Net models.
Detailed Results and Analysis
• Normal U-Net with BCE Loss: Achieved the highest F1 score, indicating the best overall performance in balancing precision and recall. This model is particularly effective in accurately identifying both true positive and false negative rates.
• Mini U-Net with BCE Loss: While being a smaller model, it still provided a high F1 score, demonstrating that reducing the model complexity did not significantly compromise its ability to accurately segment the images.
• Mini Self-Attention U-Net with BCE Loss: The inclusion of self-attention mechanisms in the mini U-Net improved its ability to understand spatial dependencies, leading to a robust segmentation performance with a high F1 score with the least amount of training.
• Extended U-Net with Combined Loss: This model had a high recall, making it very effective in identifying true positive segments. The use of Dice Loss helped in balancing class imbalances but resulted in slightly lower precision.
• Self-Attention U-Net with BCE Loss: Despite incorporating self-attention, this model had a lower recall compared to the other models, which impacted its overall F1 score. However, it still showed good precision.
The analysis of these results indicates that the Normal U-Net with BCE Loss is the best performing model for this task, achieving the highest F1 score and thus the best balance between precision and recall. Other models also performed well, each with its strengths, suggesting various trade-offs between complexity, precision, and recall.
"""

In [38]:
code = """

[Python code for training the model]

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from data import CenterlineDataset
from gabor_unet_nfc import GaborUNet
import argparse
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torchvision import transforms
import matplotlib.pyplot as plt
from torch import nn
import os
import datetime
import uuid
import math
from scipy.ndimage import distance_transform_edt
import numpy as np
from unet_model import UNet
import torch
from torchvision.transforms import functional as TF


class To16BitTensor:
    def __call__(self, image):
        image_tensor = TF.to_tensor(image)
        image_tensor = image_tensor / 65535.0
        return image_tensor

class Normalize16BitRange(transforms.ToTensor):
    def __call__(self, pic):
        '''
            Input: PIL Image or numpy array
            Output: Tensor Normalized to [0, 1] based on the actual data range
        '''
        img = torch.from_numpy(np.array(pic, dtype=np.float32, copy=True))
        
        # mimic ToTensor having shape [C, H, W]
        if img.dim() == 2:
            img = img.unsqueeze(0)

        # Normalization based on the actual data range
        # min_val = torch.min(img)
        # max_val = torch.max(img)
        # min_val = 108
        max_val = 2000
        # pixels greater than max_val are clipped to max_val
        img = torch.clamp(img, max=max_val)
        # pixels less than 90 fixed at 90
        # img = torch.where(img < min_val, torch.tensor(min_val), img)
        img = img / max_val  # Normalize to [0, 1]

        return img


       
class WeightedBCELoss(nn.Module):
    def __init__(self, weight=None):
        super().__init__()
        self.custom_weight = weight
    def forward(self, prediction, label):
        
        # Apply the general weighted binary cross entropy loss function, assigning a higher penalty to false negatives.
        # The purpose here is to handle imbalanced datasets, where the number of negative pixels is much higher than the number of positive pixels. 
        
        weight = torch.ones_like(label)
        # Handle class imbalance
        # GET THE WEIGHT
        if self.custom_weight is not None:
            weight[(label == 1)] = self.custom_weight
        else:        
            # calculate the number of all pixels in the label 
            total_pixels = label.numel()
            # calculate the number of positive pixels in the label
            positive_pixels = label.sum().item()
            # calculate the number of negative pixels in the label
            negative_pixels = total_pixels - positive_pixels
            # calculate the weights for the positive pixels and negative pixels
            weight_positive_prime = max(math.log(2 * total_pixels / (positive_pixels + 1e-6)), 1.0)
            weight_negative_prime = max(math.log(2 * total_pixels / (negative_pixels+ 1e-6)), 1.0)
            weight_positive = weight_positive_prime / (weight_positive_prime + weight_negative_prime)
            weight_negative = weight_negative_prime / (weight_positive_prime + weight_negative_prime)
            # assign the weights to the pixels
            weight[label == 1] = weight_positive
            weight[label == 0] = weight_negative
        
            # Assign significance per pixel based on their location
            # find the distance from each pixel to the nearest centerline pixel
            # label_np = label.numpy()
            # binary_mask = (label_np == 0).astype(np.int) # 1 for background, 0 for centerline
            # distance_map_np = distance_transform_edt(binary_mask)
            # distance_map = torch.from_numpy(distance_map_np)
            # GPU compatible version
            label_np = label.cpu().numpy()
            binary_mask = (label_np == 0).astype(int) # 1 for background, 0 for centerline
            distance_map_np = distance_transform_edt(binary_mask)
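            # distance_transform_edt returns float64; cast to the weight dtype so the loss inputs share one dtype
            distance_map = torch.from_numpy(distance_map_np).to(device=label.device, dtype=weight.dtype)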
            weight = weight * (1 - 0.01) ** distance_map
        
        # print("The data type of prediction is: ", prediction.dtype)
        # print("The data type of label is: ", label.dtype)

        loss = F.binary_cross_entropy(prediction, label, weight=weight) 
        return loss


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    # state here is a dictionary containing the model's state_dict, the optimizer's state_dict, and the epoch number
    torch.save(state, filename)
    print(f"Checkpoint saved to {filename}")


def load_checkpoint(checkpoint_dir, model, optimizer, device):
    latest_checkpoint = None
    max_epoch = -1
    for file in os.listdir(checkpoint_dir):
        if file.startswith('checkpoint') and file.endswith('.pth.tar'):
            epoch_num = int(file.split('_')[-1].split('.')[0])
            if epoch_num > max_epoch:
                max_epoch = epoch_num
                latest_checkpoint = file
    
    if latest_checkpoint is not None:
        checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print(f"Checkpoint loaded. Starting from epoch {start_epoch}")
    else:
        print("No checkpoint found. Starting from scratch.")
        start_epoch = 1
    
    return model, optimizer, start_epoch


def train(args, model, device, train_loader, optimizer, epoch, custom_loss=None, log_name='Analysis/logs/log.txt', checkpoint_dir = 'checkpoints'):
    log_messages = []
    model.train()
    train_loss = 0
    for batch_idx, data_tuple in enumerate(train_loader):
        # Unpack the data_tuple and check for None values
        if any(x is None for x in data_tuple):
            print(f"Skipping batch {batch_idx} due to None values")
            log_messages.append(f"Skipping batch {batch_idx} due to None values")
            continue  # Skip this batch
        
        data, target = (x.to(device) for x in data_tuple)
        optimizer.zero_grad()
        output = model(data)
        # print(target.shape)
        # print(output.shape)
        # print(target.dtype)
        # print(output.dtype)
        # print("The max value of the target is: ", torch.max(target).item())
        # print("The min value of the target is: ", torch.min(target).item())
        # print("The max value of the output is: ", torch.max(output).item())
        # print("The min value of the output is: ", torch.min(output).item())
        if not custom_loss:
            loss = F.binary_cross_entropy(output, target)
        else:
            # loss = custom_loss(output, target, dilated_target)
            loss = custom_loss(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if batch_idx % args.log_interval == 0:
            message = f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss (Batch): {loss.item():.6f}'
            print(message)
            # add to log 
            log_messages.append(message)

    with open(log_name, 'a') as log_file:
        # write the epoch header once, then every logged message
        log_file.write(f"{'*'*20} Training Epoch: {epoch} {'*'*20}\n")
        for message in log_messages:
            log_file.write(message + '\n')

    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth.tar")
    save_checkpoint({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, filename=checkpoint_path)

    average_loss = train_loss / len(train_loader)
    return average_loss


def test(args, model, device, test_loader, epoch, custom_loss=None, log_name='Analysis/logs/log.txt', loss_function='bce'):
    log_messages = []
    model.eval()
    test_loss = 0
    total_correct_pixels = 0
    total_centerline_pixels = 0
    test_f1_score = 0
    with torch.no_grad():
        for batch_idx, data_tuple in enumerate(test_loader):
            if any(x is None for x in data_tuple):
                print(f"Skipping batch {batch_idx} due to None values")
                log_messages.append(f"Skipping batch {batch_idx} due to None values")
                continue 
            
            data, target = (x.to(device) for x in data_tuple)
            output = model(data)
            if not custom_loss:
                loss = F.binary_cross_entropy(output, target, reduction='sum').item()
            else:
                # loss = custom_loss(output, target, dilated_target).item()
                loss = custom_loss(output, target).item()
            test_loss += loss
            # pred = output > 0.5
            # correct_pixels = pred.eq(target.view_as(pred)).sum().item()
            # total_correct_pixels += correct_pixels
            # total_pixels += target.numel()
            # accuracy = 100. * correct_pixels / target.numel()

            # accuracy = the number of correctly predicted centerline pixels divided by the number of centerline pixels in the target
            pred_binary = (output > 0.5).type(torch.bool)
            target_binary = target.type(torch.bool)

            # centerline_pixels = target_binary.sum().item()
            # total_centerline_pixels += centerline_pixels
            # correct_pixels = (pred_binary & target_binary).sum().item()
            # total_correct_pixels += correct_pixels

            # if centerline_pixels > 0:
            #     accuracy = 100. * correct_pixels / centerline_pixels
            # else:
            #     accuracy = 100

            TP = (pred_binary & target_binary).sum().item()
            FP = (pred_binary & ~target_binary).sum().item()
            FN = (~pred_binary & target_binary).sum().item()

            # Calculate recall and precision (as percentages)
            if (TP + FN) > 0:
                recall = 100. * TP / (TP + FN)
            else:
                recall = 0
            
            if (TP + FP) > 0:
                precision = 100. * TP / (TP + FP)
            else:
                precision = 0
            
            # Calculate F1 Score
            if (precision + recall) > 0:
                f1_score = 2 * (precision * recall) / (precision + recall)
            else:
                f1_score = 0
                
            test_f1_score += f1_score
            total_centerline_pixels += target_binary.sum().item()
            total_correct_pixels += TP

            if batch_idx % args.log_interval == 0:
                message = f'Test Epoch: {epoch} [{batch_idx * len(data)}/{len(test_loader.dataset)} ({100. * batch_idx / len(test_loader):.0f}%)]\tLoss (Batch): {loss:.6f} \t Accuracy (Batch): {recall:.2f}% \t F1 Score (Batch): {f1_score:.2f}%'
                print(message)
                # add to log
                log_messages.append(message)

    with open(log_name, 'a') as log_file:
        # Write the epoch header once, followed by every message collected during this epoch
        log_file.write(f"\n{'*'*20} Testing Epoch: {epoch} {'*'*20}\n")
        for message in log_messages:
            log_file.write(message + '\n')

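    num_batches = max(len(test_loader), 1)  # Guard against an empty test loader
    average_loss = test_loss / num_batches  # Calculate average loss per batch
    # Recall over all centerline pixels (reported as "accuracy"); 0 if the test set has no centerline pixels
    average_accuracy = 100. * total_correct_pixels / total_centerline_pixels if total_centerline_pixels > 0 else 0.0
    average_f1_score = test_f1_score / num_batches  # Calculate average F1 score per batch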
    return average_loss, average_accuracy, average_f1_score
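

# Illustrative sketch (not part of the original script): the per-batch
# precision / recall / F1 arithmetic from test(), isolated as a pure function
# of the pixel counts. The name `f1_from_counts` is hypothetical. Note that
# test() averages the per-batch F1 values, which in general differs from the
# F1 computed on counts pooled over the whole test set.
def f1_from_counts(tp, fp, fn):
    """Return (precision, recall, f1) in percent; each is 0 when undefined."""
    recall = 100.0 * tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = 100.0 * tp / (tp + fp) if (tp + fp) > 0 else 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1


# Usage (assumed counts, for illustration only):
# precision, recall, f1 = f1_from_counts(tp=8, fp=2, fn=2)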


def main():
    
    # COMMAND LINE ARGUMENTS
    parser = argparse.ArgumentParser(description='PyTorch GaborUNet Training')
    parser.add_argument('--batch-size', type=int, default=16, metavar='N',
                        help='input batch size for training (default: 16)')
    parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='learning rate step gamma (default: 0.7)')
    parser.add_argument('--step-size', type=int, default=1, metavar='N',
                        help='learning rate step size in epochs (default: 1)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    # parser.add_argument('--save-dilated', action='store_true', default=False, 
    #                     help='save dilated labels for debugging')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--custom-weight', action='store_true', default=False,
                        help='less penalty for pixels in the dilated area but not in the original target')
    parser.add_argument('--weight', type=float, default=8, help='Custom weight for the BCE or NLL loss for pixels in the dilated area but not in the original target')
    parser.add_argument('--weight-decay', type=float, default=1e-5, help='Weight decay for optimizer')
    # parser.add_argument('--add-poor-quality-training', action='store_true', default=False, help='Add poor quality training data to the dataset')
    parser.add_argument('--turn-off-analysis', action='store_true', default=False, help='Turn off analysis of the dataset (saving the loss image)')
    parser.add_argument('--gabor-kernel-size', type=int, default=19, help='Size of the Gabor kernel')
    parser.add_argument('--message', type=str, default='', help='Additional message to add to the log')
    # parser.add_argument('--dilation-iterations', type=int, default=1, help='Number of dilation iterations for the dilated label')
    # parser.add_argument('--loss-function' , type=str, default='bce', help='Loss function to use (bce or nll)')
    parser.add_argument('--train-from-checkpoint', type=str, default=None, help='Log id of the checkpoint to resume training from')
    args = parser.parse_args()

    log_messages = []
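    # generate a unique id, or reuse the checkpoint's log id when resuming
    log_id = uuid.uuid4() if args.train_from_checkpoint is None else args.train_from_checkpoint
    os.makedirs('Analysis/logs', exist_ok=True)  # the log file is opened before any other code creates this directory
    log_name = f'Analysis/logs/log-{datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}-{log_id}.txt'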

    # add to log
    log_messages.append(f'Log ID: {log_id}; Message: {args.message}')
    log_messages.append(f'Starting time: {datetime.datetime.now()}')
    log_messages.append(f"{'*'*20}")
    log_messages.append(f"Batch size: {args.batch_size} (input batch size for training, default: 16)")
    log_messages.append(f"Number of epochs: {args.epochs} (number of epochs to train, default: 50)")
    log_messages.append(f"Learning rate: {args.lr} (default: 0.01)")
    log_messages.append(f"Learning rate step gamma: {args.gamma} (default: 0.7)")
    log_messages.append(f"Learning rate step size: {args.step_size} (default: 1)")
    # log_messages.append(f"Save dilated labels: {args.save_dilated}, help='save dilated labels for debugging")
    log_messages.append(f"Log interval: {args.log_interval} (how many batches to wait before logging training status)")
    log_messages.append(f"Custom weight: {args.custom_weight} (less penalty for pixels in the dilated area but not in the original target)")
    log_messages.append(f"Weight: {args.weight} (custom weight for the BCE or NLL loss for pixels in the dilated area but not in the original target)")
    log_messages.append(f"Weight decay: {args.weight_decay} (weight decay for optimizer)")
    # log_messages.append(f"Add poor quality training: {args.add_poor_quality_training}, help='Add poor quality training data to the dataset")
    log_messages.append(f"Turn off analysis: {args.turn_off_analysis} (skip saving the loss image)")
    log_messages.append(f"Gabor kernel size: {args.gabor_kernel_size} (size of the Gabor kernel)")
    # log_messages.append(f"Dilation iterations: {args.dilation_iterations}, help='Number of dilation iterations for the dilated label")
    # log_messages.append(f"Loss function: {args.loss_function}, help='Loss function to use (bce or nll)")
    log_messages.append(f"Train from checkpoint: {args.train_from_checkpoint} (log id of the checkpoint to resume training from)")
    log_messages.append(f"{'*'*20}")


    if args.custom_weight:
        print(f"Using custom loss function with weight: {args.weight}")
        log_messages.append(f"Using custom loss function with weight: {args.weight}")



    transform = transforms.Compose([
        # transforms.RandomHorizontalFlip(),  
        # transforms.RandomRotation(15),  
        # transforms.ToTensor(),  
        Normalize16BitRange(),
        # Normalize16Bit(),
        # transforms.ToTensor()
    ])

    # hand-labelled simple_structure_dataset_dilated
    data_dir_simple_structure_dilated = 'simple_structure_dataset_dilated'
    img_dir_simple_structure_dilated = f"{data_dir_simple_structure_dilated}/samples"
    label_dir_simple_structure_dilated = f"{data_dir_simple_structure_dilated}/labels"
    simple_structure_dilated_dataset = CenterlineDataset(img_dir_simple_structure_dilated, label_dir_simple_structure_dilated, transform, augmentation=False)
    flipped_simple_structure_dilated_dataset = CenterlineDataset(img_dir_simple_structure_dilated, label_dir_simple_structure_dilated, transform, augmentation="flip")


    # hand-labelled simple_structure_dataset_dilated_masked
    data_dir_simple_structure_dilated_masked = 'simple_structure_dataset_dilated_masked'
    img_dir_simple_structure_dilated_masked = f"{data_dir_simple_structure_dilated_masked}/samples"
    label_dir_simple_structure_dilated_masked = f"{data_dir_simple_structure_dilated_masked}/labels"
    simple_structure_dilated_masked_dataset = CenterlineDataset(img_dir_simple_structure_dilated_masked, label_dir_simple_structure_dilated_masked, transform, augmentation=False)
    flipped_simple_structure_dilated_masked_dataset = CenterlineDataset(img_dir_simple_structure_dilated_masked, label_dir_simple_structure_dilated_masked, transform, augmentation="flip")

    # hand-labelled branch_point_dataset_dilated
    data_dir_branch_point_dilated = 'branch_point_dataset_dilated'
    img_dir_branch_point_dilated = f"{data_dir_branch_point_dilated}/samples"
    label_dir_branch_point_dilated = f"{data_dir_branch_point_dilated}/labels"
    branch_point_dilated_dataset = CenterlineDataset(img_dir_branch_point_dilated, label_dir_branch_point_dilated, transform, augmentation=False)
    flipped_branch_point_dilated_dataset = CenterlineDataset(img_dir_branch_point_dilated, label_dir_branch_point_dilated, transform, augmentation="flip")

    # hand-labelled branch_point_dataset_dilated_masked
    data_dir_branch_point_dilated_masked = 'branch_point_dataset_dilated_masked'
    img_dir_branch_point_dilated_masked = f"{data_dir_branch_point_dilated_masked}/samples"
    label_dir_branch_point_dilated_masked = f"{data_dir_branch_point_dilated_masked}/labels"
    branch_point_dilated_masked_dataset = CenterlineDataset(img_dir_branch_point_dilated_masked, label_dir_branch_point_dilated_masked, transform, augmentation=False)
    flipped_branch_point_dilated_masked_dataset = CenterlineDataset(img_dir_branch_point_dilated_masked, label_dir_branch_point_dilated_masked, transform, augmentation="flip")

    # hand labelled noisy_patch_dataset
    data_dir_noisy_patch = 'noisy_patch_dataset'
    img_dir_noisy_patch = f"{data_dir_noisy_patch}/samples"
    label_dir_noisy_patch = f"{data_dir_noisy_patch}/labels"
    noisy_patch_dataset = CenterlineDataset(img_dir_noisy_patch, label_dir_noisy_patch, transform, augmentation=False)


    # combined_dataset = ConcatDataset([original_dataset, flipped_dataset, masked_dataset, masked_dataset_flipped, bg_dataset, branch_dataset, branch_dataset_flipped])
    combined_dataset = ConcatDataset([simple_structure_dilated_dataset, 
                                      flipped_simple_structure_dilated_dataset, 
                                      simple_structure_dilated_masked_dataset, 
                                      flipped_simple_structure_dilated_masked_dataset, 
                                      branch_point_dilated_dataset, 
                                      flipped_branch_point_dilated_dataset, 
                                      branch_point_dilated_masked_dataset, 
                                      flipped_branch_point_dilated_masked_dataset, 
                                      noisy_patch_dataset])

    # Determine lengths for train and test sets
    total_size = len(combined_dataset)
    train_size = int(0.98 * total_size)
    test_size = total_size - train_size

    # Split the dataset
    torch.manual_seed(42)
    train_dataset, test_dataset = random_split(combined_dataset, [train_size, test_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    # Use the dedicated --test-batch-size argument, which was parsed but otherwise unused
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False)

    print(f"Dataset Preparation Complete! Train loader size: {len(train_loader)}, Test loader size: {len(test_loader)}")
    log_messages.append(f"Dataset Preparation Complete! Train loader size: {len(train_loader)}, Test loader size: {len(test_loader)}")


    # TRAINING
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # custom_loss = WeightedBCELoss(args.custom_weight, args.loss_function) if args.custom_loss else None
    # custom_weight = args.custom_weight if args.custom_weight else None
    # custom_loss = WeightedBCELoss(custom_weight=custom_weight) if args.custom_weight else WeightedBCELoss()
    custom_loss = WeightedBCELoss(weight=args.weight) if args.custom_weight else WeightedBCELoss()

    model = GaborUNet(kernel_size=args.gabor_kernel_size, in_channels=1, out_channels=1, num_orientations=8, num_scales=5).to(device)
    # model = UNet(in_channels=1, out_channels=1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    if args.train_from_checkpoint:
        print(f"Resuming training from checkpoint {args.train_from_checkpoint}")
        log_messages.append(f"Resuming training from checkpoint {args.train_from_checkpoint}")
    else:
        print("Training Started!")
        log_messages.append("Training Started!")

    checkpoint_dir = f'checkpoints-{log_id}'
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Load model from checkpoint if available
    model, optimizer, start_epoch = load_checkpoint(checkpoint_dir, model, optimizer, device)

    train_losses = []
    test_losses = []
    test_accuracies = []
    f1_scores = []

    # dealing with checkpoint, append logs to the log file if loading from checkpoint
    log_file_mode = 'a' if os.path.exists(log_name) and args.train_from_checkpoint else 'w'
    with open(log_name, log_file_mode) as log_file:
        for message in log_messages:
            log_file.write(message + '\n')

    log_messages = []

    for epoch in range(start_epoch, args.epochs + 1):
        train_loss = train(args, model, device, train_loader, optimizer, epoch, custom_loss=custom_loss, log_name=log_name, checkpoint_dir=checkpoint_dir)
        test_loss, test_accuracy, test_f1_score = test(args, model, device, test_loader, epoch, custom_loss=custom_loss, log_name=log_name)
        
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        f1_scores.append(test_f1_score)
        
        scheduler.step()

    torch.save(model.state_dict(), f'gabor_unet_model_state_dict-{log_id}.pth')
    print(f"Model's state_dict saved to gabor_unet_model_state_dict-{log_id}.pth")
    log_messages.append(f"Model's state_dict saved to gabor_unet_model_state_dict-{log_id}.pth")

    torch.save(model, f'gabor_unet_model_complete-{log_id}.pth')
    print(f"Entire model saved to gabor_unet_model_complete-{log_id}.pth")
    log_messages.append(f"Entire model saved to gabor_unet_model_complete-{log_id}.pth")

    
    # save the image into the folder 'Analysis'
    if not args.turn_off_analysis:
        os.makedirs("Analysis", exist_ok=True)
        import matplotlib.pyplot as plt

        # x-axis: the epochs actually run in this session, so the curves stay aligned
        # with the metric lists when training resumes from a checkpoint
        epochs_run = range(start_epoch, start_epoch + len(train_losses))

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 3, 1)
        plt.plot(epochs_run, train_losses, label='Train Loss')
        plt.plot(epochs_run, test_losses, label='Test Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Loss Curve')
        plt.legend()

        plt.subplot(1, 3, 2)
        plt.plot(epochs_run, test_accuracies, color='red', label='Test Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.title('Accuracy Curve')
        plt.legend()

        plt.subplot(1, 3, 3)
        plt.plot(epochs_run, f1_scores, color='green', label='F1 Score')
        plt.xlabel('Epoch')
        plt.ylabel('F1 Score')
        plt.title('F1 Score Curve')
        plt.legend()

        # print hyperparameters in the image 
        plt.text(0.5, 0.5, 
         ('Batch size: {}\n'
          'Epochs: {}\n'
          'Learning rate: {}\n'
          'Gamma: {}\n'
          'Step size: {}\n'
          'Custom weight: {}\n'
          'Weight: {}\n'
          'Weight decay: {}\n'
          'Gabor kernel size: {}').format(args.batch_size, args.epochs, args.lr, args.gamma, 
                                          args.step_size, args.custom_weight, args.weight, 
                                          args.weight_decay, 
                                          args.gabor_kernel_size), 
         horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes)

        


        plt.tight_layout()
        t = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        # Build the output path once so the saved file, the console message, and the log agree
        curve_path = f'Analysis/loss_accuracy_curve-{t}-{log_id}.png'
        plt.savefig(curve_path)
        print(f"Loss and accuracy curve saved to {curve_path}")
        log_messages.append(f"Loss and accuracy curve saved to {curve_path}")
    

    print("Training Complete!")
    log_messages.append("Training Complete!")
    log_messages.append(f'Ending time: {datetime.datetime.now()}')

    with open(log_name, 'a') as log_file:
        for message in log_messages:
            log_file.write(message + '\n')


if __name__ == '__main__':
    main()
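

# Illustrative sketch (not part of the original script): the learning rate
# that a StepLR schedule yields in closed form,
#     lr_k = base_lr * gamma ** (k // step_size),
# where k is the number of scheduler.step() calls made so far. The helper name
# `step_lr_schedule` is hypothetical. With the defaults above (lr=0.01,
# gamma=0.7, step_size=1) the rate shrinks by 30% every epoch and falls below
# 1e-9 by the 50th epoch, which may be worth keeping in mind when choosing
# --gamma and --step-size.
def step_lr_schedule(base_lr, gamma, step_size, num_epochs):
    """Return the learning rate used in each of `num_epochs` epochs."""
    return [base_lr * gamma ** (k // step_size) for k in range(num_epochs)]


# Usage:
# rates = step_lr_schedule(base_lr=0.01, gamma=0.7, step_size=1, num_epochs=50)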


    
[Python code for the model architecture]

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class GaborUNet(nn.Module):
    def __init__(self, kernel_size, in_channels=1, out_channels=1, num_orientations=8, num_scales=5):
        super(GaborUNet, self).__init__()
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_orientations = num_orientations
        self.num_scales = num_scales

        # Encoder (Contracting path)
        self.enc_conv1 = self.doubleGaborConv(in_channels, kernel_size, num_orientations, num_scales)
        self.enc_conv2 = self.doubleConv3x3(2 * num_orientations * num_scales, 32)
        self.enc_conv3 = self.doubleConv3x3(32, 64)
        self.enc_conv4 = self.doubleConv3x3(64, 128)
        self.enc_conv5 = self.doubleConv3x3(128, 256)  # New encoder layer

        self.up_conv1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)  # Adjusted for new layer
        self.dec_conv1 = self.doubleConv3x3(256, 128)  # Adjusted for new layer
        self.up_conv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec_conv2 = self.doubleConv3x3(128, 64)
        self.up_conv3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec_conv3 = self.doubleConv3x3(64, 32)
        self.up_conv4 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)  
        self.dec_conv4 = self.doubleConv3x3(2 * num_orientations * num_scales + 16, 16) # 80 + 16

        self.out_conv = nn.Conv2d(16, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        enc1 = self.enc_conv1(x) 
        enc2 = self.enc_conv2(F.max_pool2d(enc1, kernel_size=2, stride=2))
        enc3 = self.enc_conv3(F.max_pool2d(enc2, kernel_size=2, stride=2))
        enc4 = self.enc_conv4(F.max_pool2d(enc3, kernel_size=2, stride=2))
        enc5 = self.enc_conv5(F.max_pool2d(enc4, kernel_size=2, stride=2))

        dec1 = self.up_conv1(enc5)
        dec1 = torch.cat((dec1, enc4), dim=1)
        dec1 = self.dec_conv1(dec1)

        dec2 = self.up_conv2(dec1)
        dec2 = torch.cat((dec2, enc3), dim=1)
        dec2 = self.dec_conv2(dec2)

        dec3 = self.up_conv3(dec2)
        dec3 = torch.cat((dec3, enc2), dim=1)
        dec3 = self.dec_conv3(dec3)

        dec4 = self.up_conv4(dec3) 
        dec4 = torch.cat((dec4, enc1), dim=1)  # Concatenation with the first encoder layer
        dec4 = self.dec_conv4(dec4)

        out = self.out_conv(dec4)

        return torch.sigmoid(out)

    def doubleGaborConv(self, in_channels, kernel_size, num_orientations, num_scales):
        return nn.Sequential(
            GaborConv2d(in_channels, kernel_size, num_orientations, num_scales),
            nn.BatchNorm2d(2 * num_orientations * num_scales), 
            nn.ReLU(inplace=True),
            GaborConv2d(2 * num_orientations * num_scales, kernel_size, num_orientations, num_scales),
            nn.BatchNorm2d(2 * num_orientations * num_scales), 
            nn.ReLU(inplace=True)
        )
    
    def doubleConv3x3(self, in_channels, out_channels, dropout_rate=0.5):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            # nn.Dropout2d(p=dropout_rate),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
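

# Illustrative sketch (not part of the original model code): shape bookkeeping
# for GaborUNet.forward. The helper name `gabor_unet_shapes` is hypothetical.
# It lists (channels, spatial size) for the five encoder stages and the input
# width of dec_conv4. Because the encoder halves the resolution four times and
# each transposed convolution doubles it, the skip concatenations only line up
# when the input side length is divisible by 16 (64 in this project).
def gabor_unet_shapes(input_size=64, num_orientations=8, num_scales=5):
    if input_size % 16 != 0:
        raise ValueError("input_size must be divisible by 16 for the skip connections to match")
    gabor_channels = 2 * num_orientations * num_scales  # cosine + sine responses
    encoder_channels = [gabor_channels, 32, 64, 128, 256]
    sizes = [input_size // (2 ** level) for level in range(5)]
    dec_conv4_in = gabor_channels + 16  # enc1 skip concatenated with up_conv4 output
    return list(zip(encoder_channels, sizes)), dec_conv4_in


# Usage:
# stages, dec_conv4_in = gabor_unet_shapes()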


class GaborConv2d(nn.Module):
    def __init__(self, in_channels, kernel_size, num_orientations, num_scales):
        super(GaborConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = 2 * num_orientations * num_scales # 80
        self.kernel_size = kernel_size
        self.num_orientations = num_orientations
        self.num_scales = num_scales
        self.padding = kernel_size // 2
        

        # Generate Gabor filter parameters
        self.sigma, self.theta, self.Lambda, self.psi, self.gamma, self.bias = self.generate_parameters(self.out_channels // 2) # 40
        self.filter_cos = self.whole_filter(True)
        self.filter_sin = self.whole_filter(False)

    def forward(self, x):
        x_cos = F.conv2d(x, self.filter_cos, padding = self.padding, bias=self.bias)
        x_sin = F.conv2d(x, self.filter_sin, padding = self.padding, bias=self.bias)
        return torch.cat((x_cos, x_sin), 1)

    def generate_parameters(self, dim_out):
        torch.manual_seed(1)
        # Adjusted to initialize parameters more appropriately for Gabor filters
        sigma = nn.Parameter(torch.rand(dim_out, 1) * 2.0 + 0.5) # Random values between 0.5 and 2.5
        theta = nn.Parameter(torch.rand(dim_out, 1) * np.pi) # Random values between 0 and π
        Lambda = nn.Parameter(torch.rand(dim_out, 1) * 3.0 + 1.0) # Random values between 1.0 and 4.0 (TODO: check whether this wavelength range suits centerline detection)
        psi = nn.Parameter(torch.rand(dim_out, 1) * 2 * np.pi) # Random values between 0 and 2π
        gamma = nn.Parameter(torch.rand(dim_out, 1) * 2.0 + 0.5) # Random values between 0.5 and 2.5
        bias = nn.Parameter(torch.randn(dim_out)) # Per-filter bias drawn from a standard normal; shared by the cosine and sine convolutions
        return sigma, theta, Lambda, psi, gamma, bias


    def whole_filter(self, cos=True):
        # Creating a tensor to hold the Gabor filters for all orientations and scales
        result = torch.zeros(self.num_orientations*self.num_scales, self.in_channels, self.kernel_size, self.kernel_size)
        for i in range(self.num_orientations):
            for j in range(self.num_scales):
                index = i * self.num_scales + j
                # Adjusting parameters for scale and orientation
                sigma = self.sigma[index] * (2.1 ** j) # Adjusting sigma for scale
                theta = self.theta[index] + i * 2 * np.pi / self.num_orientations # Adjusting theta for orientation
                Lambda = self.Lambda[index] # Keeping Lambda constant
                psi = self.psi[index] # Keeping psi constant
                gamma = self.gamma[index] # Keeping gamma constant
                # Generating the Gabor filter for each channel
                for k in range(self.in_channels):
                    result[index, k] = self.gabor_fn(sigma, theta, Lambda, psi, gamma, self.kernel_size, cos)
        return nn.Parameter(result)

    def gabor_fn(self, sigma, theta, Lambda, psi, gamma, kernel_size, cos=True):
        n = kernel_size // 2
        y, x = np.ogrid[-n:n+1, -n:n+1]
        y = torch.FloatTensor(y)
        x = torch.FloatTensor(x)

        x_theta = x * torch.cos(theta) + y * torch.sin(theta)
        y_theta = -x * torch.sin(theta) + y * torch.cos(theta)

        if cos:
            gb = torch.exp(-.5 * (x_theta ** 2 / sigma ** 2 + y_theta ** 2 / sigma ** 2 / gamma ** 2)) * torch.cos(2 * np.pi / Lambda * x_theta + psi)
        else:
            gb = torch.exp(-.5 * (x_theta ** 2 / sigma ** 2 + y_theta ** 2 / sigma ** 2 / gamma ** 2)) * torch.sin(2 * np.pi / Lambda * x_theta + psi)
        
        return gb
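

# Illustrative sketch (not part of the original model code): the same formula
# as gabor_fn above, written with the standard library only so that a single
# kernel can be inspected without PyTorch. The name `gabor_kernel` and the
# parameter name `lam` (for Lambda) are hypothetical.
import math


def gabor_kernel(sigma, theta, lam, psi, gamma, kernel_size, cos=True):
    """Return a kernel_size x kernel_size Gabor kernel as a list of rows."""
    n = kernel_size // 2
    wave = math.cos if cos else math.sin
    kernel = []
    for y in range(-n, n + 1):
        row = []
        for x in range(-n, n + 1):
            # Rotate the coordinates by theta
            x_theta = x * math.cos(theta) + y * math.sin(theta)
            y_theta = -x * math.sin(theta) + y * math.cos(theta)
            # Gaussian envelope multiplied by a sinusoidal carrier along x_theta
            envelope = math.exp(-0.5 * (x_theta ** 2 / sigma ** 2
                                        + y_theta ** 2 / (sigma ** 2 * gamma ** 2)))
            row.append(envelope * wave(2 * math.pi / lam * x_theta + psi))
        kernel.append(row)
    return kernel


# Usage (assumed parameter values, for illustration only):
# k = gabor_kernel(sigma=2.0, theta=0.0, lam=4.0, psi=0.0, gamma=1.0, kernel_size=19)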


[Python code for data preparation]

from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision.transforms import functional as TF
from torchvision import transforms
import random
random.seed(42)
# from scipy.ndimage import binary_dilation
import numpy as np
import torch


class CenterlineDataset(Dataset):
    def __init__(self, img_dir, label_dir, transform=None, augmentation=False):
        self.img_dir = img_dir
        self.label_dir = label_dir
        # self.dilated_label_dir = dilated_label_dir
        self.transform = transform
        # self.save_dilated = save_dilated
        self.augmentation = augmentation
        # self.dilation_iterations = dilation_iterations

        self.images = sorted([img for img in os.listdir(img_dir) if img.endswith('.png')])
        self.labels = sorted([label for label in os.listdir(label_dir) if label.endswith('.png')])
        print(f"Found {len(self.images)} images and {len(self.labels)} labels in the dataset.")

        # if add_poor_quality_training:
        #     self.poor_quality_img_dir = "data_poor_quality/samples"
        #     self.poor_quality_label_dir = "data_poor_quality/labels"

        #     self.poor_quality_images = sorted([img for img in os.listdir(self.poor_quality_img_dir) if img.endswith('.png')])
        #     self.poor_quality_labels = sorted([label for label in os.listdir(self.poor_quality_label_dir) if label.endswith('.png')])

        #     self.images += self.poor_quality_images
        #     self.labels += self.poor_quality_labels

        #     print(f"Added {len(self.poor_quality_images)} poor quality images and labels to the dataset.")
        #     print(f"Total images and labels in the dataset: {len(self.images)}")

        assert len(self.images) == len(self.labels), "The number of images and labels do not match!"

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index): # I/O Intensive
        img_path = os.path.join(self.img_dir, self.images[index])
        label_path = os.path.join(self.label_dir, self.labels[index])

        try:
            image = Image.open(img_path)
            # print((np.array(image) > 255).any())
            # image = image.convert("L")  # Grayscale conversion for images. It's like we do autoscaling here.
            # print the pixel intensity
            # print(np.array(image))
            # print(image.mode)

            # binarize the label: every pixel greater than 0 becomes 255 (8-bit)
            label = Image.open(label_path)
            # label_np = np.array(label)
            # label_np[label_np > 0] = 65535
            # label = Image.fromarray(label_np.astype(np.uint16))
            # # assert all pixels are in range 0-65535
            # assert (label_np <= 65535).all(), f"Some pixels are greater than 65535, {label_path}"
            # assert (label_np >= 0).all(), f"Some pixels are less than 0, {label_path}"
            
            label_np = np.array(label)
            label_np = (label_np > 0).astype(np.uint8) * 255  # Ensure binary labels, but converts to L mode
            # label_np = label_np > 0
            # # print("The data type of the label np is*********: ", label_np.dtype)
            label = Image.fromarray(label_np)
            # print(label.mode)


            # label = label.convert("L")  # Grayscale conversion for labels
            # print(np.array(label))


            # Crop the first row and column off, adjusting the size to 64x64
            # Assuming the original size is 65x65, the crop box (left, upper, right, lower) = (1, 1, 65, 65) keeps a 64x64 region
            image = image.crop((1, 1, 65, 65))
            label = label.crop((1, 1, 65, 65))

            # Processing the label to be binary
            # print(np.array(label))
            # label_np = np.array(label) > 0  # Ensuring binary values, shape: (64, 64), True or False
            # label = Image.fromarray(label_np.astype(np.uint8))  # Convert back to image
            # print(np.array(label))

            # Dilate the label 
            # label_np = np.array(label) > 0 
            # dilated_label_np = binary_dilation(label_np, iterations=self.dilation_iterations).astype(np.uint8)
            # dilated_label = Image.fromarray(dilated_label_np) 

            # save the dilated label for debugging
            # if self.save_dilated:
            #     if not os.path.exists(self.dilated_label_dir):
            #         os.makedirs(self.dilated_label_dir)
            #     dilated_label_path = os.path.join(self.dilated_label_dir, f"dilated_label_{index}.png")
            #     dilated_label.save(dilated_label_path)

        except Exception as e:
            print(f"Error opening image or label at index {index}: {e}")
            return None, None  # match the (image, label) tuple returned on success

        # Always apply basic transformations if provided, transformed to a tensor
        if self.transform is not None:
            image = self.transform(image)
            label = transforms.ToTensor()(label) # ToTensor here works well for 8 bit
            # print(torch.max(label), torch.min(label))
            # label = self.transform(label)
            # dilated_label = self.transform(dilated_label)
            
        # image = torch.from_numpy(np.array(image, dtype=np.float32))
        # label = torch.from_numpy(np.array(label, dtype=np.float32))
        
        # if image.dim() == 2:
        #     image = image.unsqueeze(0)
        # if label.dim() == 2:
        #     label = label.unsqueeze(0)

        # image_min = torch.min(image)
        # # print("The minimum pixel intensity of the image is: ", image_min)
        # image_max = torch.max(image)
        # # print("The maximum pixel intensity of the image is: ", image_max)
        # image = (image - image_min) / (image_max - image_min)
        # # print("The data type of the image is: ", image.dtype)
        
        # label_min = torch.min(label)
        # # print("The minimum pixel intensity of the label is: ", label_min)
        # label_max = torch.max(label)
        # # print("The maximum pixel intensity of the label is: ", label_max)
        # label = (label - label_min) / (label_max - label_min)
        # # print("The data type of the label is: ", label.dtype)
                
        
                
        

        if self.augmentation:
            if self.augmentation == "rotate":
                angle = random.uniform(-180, 180)
                image = TF.rotate(image, angle)
                label = TF.rotate(label, angle)
                # dilated_label = TF.rotate(dilated_label, angle)
            if self.augmentation == "flip":
                if random.random() > 0.5: # Horizontal flip
                    image = TF.hflip(image)
                    label = TF.hflip(label)
                    # dilated_label = TF.hflip(dilated_label)
                # if random.random() > 0.5: # Vertical flip
                else:
                    image = TF.vflip(image)
                    label = TF.vflip(label)
                    # dilated_label = TF.vflip(dilated_label)
        

        return image, label


[Python code for inference and image reconstruction: the large image is cut into small patches, inference runs on each patch, and the predictions are stitched back into a full image]



class Normalize16BitRange(transforms.ToTensor):
    def __call__(self, pic):
        '''
            Input: PIL Image or numpy array
            Output: float32 Tensor clipped at a fixed ceiling (max_val) and scaled to [0, 1];
                    2-D input gains a leading channel dimension
        '''
        img = torch.from_numpy(np.array(pic, dtype=np.float32, copy=True))
        
        # mimic ToTensor having shape [C, H, W]
        if img.dim() == 2:
            img = img.unsqueeze(0)

        # Normalization with a fixed ceiling rather than the per-image range
        # min_val = torch.min(img)
        # max_val = torch.max(img)
        # min_val = 108
        max_val = 2000
        # Pixels greater than max_val are clipped to max_val
        img = torch.clamp(img, max=max_val)
        # Clipping at a lower bound (min_val) is currently disabled
        # img = torch.where(img < min_val, torch.tensor(min_val), img)
        img = img / max_val  # Normalize to [0, 1]

        return img


# Patch The Input Image


def image_to_patches(image, patch_size=64, overlap=32):

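    #- Splits the image into patches, with optional overlap.
    #- Patches that extend past the image border are zero-padded to patch_size.
    #- Input: PIL Image, uint16
    #- Output: Numpy array of patches, plus a list of (index, (center_x, center_y), patch) tuples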

    patch_tuples = []
    patches = []
    img_np = np.array(image)
    stride = patch_size - overlap
    index = 0
    for y in range(0, img_np.shape[0], stride):
        for x in range(0, img_np.shape[1], stride):
            center_x = x + patch_size // 2
            center_y = y + patch_size // 2
            
            if y + patch_size <= img_np.shape[0] and x + patch_size <= img_np.shape[1]:
                patch = img_np[y:y+patch_size, x:x+patch_size]
                patches.append(patch)
                patch_tuples.append((index, (center_x, center_y), patch))
            else:  # Handle edges by padding
                patch = np.pad(img_np[y:min(y+patch_size, img_np.shape[0]), x:min(x+patch_size, img_np.shape[1])],
                               ((0, patch_size - (min(y + patch_size, img_np.shape[0]) - y)),
                                (0, patch_size - (min(x + patch_size, img_np.shape[1]) - x))),
                               mode='constant', constant_values=0)
                patches.append(patch)
                patch_tuples.append((index, (center_x, center_y), patch))
            index += 1
    return np.array(patches), patch_tuples


def model_predict_patches_raw(patches, model):
    model.eval()
    predictions = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    for patch in patches:
        patch_tensor = Normalize16BitRange()(patch).unsqueeze(0)
        patch_tensor = patch_tensor.to(device)
        with torch.no_grad():
            pred = model(patch_tensor)
        pred = pred.cpu().numpy().squeeze(0).squeeze(0)
        predictions.append(pred)
    return np.array(predictions)
 

def patches_to_image_raw(patches, image_size, patch_size=64, overlap=32):
    stride = patch_size - overlap
    reconstructed_image = np.zeros(image_size)
    count_matrix = np.zeros(image_size)
    patch_idx = 0

    for y in range(0, image_size[0], stride):
        for x in range(0, image_size[1], stride):
            # Calculate boundaries for patch application
            end_y = min(y + patch_size, image_size[0])
            end_x = min(x + patch_size, image_size[1])
            # Calculate the actual height and width to be used from the patch
            patch_height = end_y - y
            patch_width = end_x - x
            reconstructed_image[y:end_y, x:end_x] += patches[patch_idx][:patch_height, :patch_width]
            count_matrix[y:end_y, x:end_x] += 1
            patch_idx += 1
    # Avoid division by zero
    count_matrix[count_matrix == 0] = 1
    reconstructed_image /= count_matrix

    reconstructed_image[reconstructed_image > 0.5] = 255
    reconstructed_image[reconstructed_image <= 0.5] = 0
    show_image(reconstructed_image)
    return reconstructed_image


"""

In [24]:
print(summary)

Centerline Detection of Neuronal Dendrites: Deep Learning Model Workflow

Authors: [Your Names]
Date: [Current Date]

1. Introduction
This project develops a deep learning model for centerline detection of neuronal dendrite membranes. The primary goal is to predict the 'backbone' of each structure with outputs that are as continuous as possible. The task is a binary segmentation problem: identifying the central lines within membrane structures in grayscale images.

The dataset consists of 16-bit grayscale images of size 65x65 pixels. The membrane structures have intensity values ranging from approximately 150 to 1000, whereas the background intensity ranges from approximately 110 to 150. In some areas, however, noise pixels have intensities similar to those of the membranes, which makes accurate segmentation more difficult.
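
To make these intensity ranges concrete, the following is a minimal sketch of how a 16-bit patch could be clipped and scaled to [0, 1]. It assumes a fixed ceiling of 2000, the value used in the inference code's normalization step; the helper name normalize_16bit, the synthetic patch, and the specific pixel values are illustrative assumptions and are not taken from the project's data.

```python
import numpy as np


def normalize_16bit(patch, max_val=2000):
    """Clip a 16-bit patch at max_val and scale it to [0, 1].

    max_val=2000 is an assumed ceiling; normalize_16bit is a hypothetical helper.
    """
    patch = np.asarray(patch, dtype=np.float32)
    return np.clip(patch, 0, max_val) / max_val


# Synthetic 65x65 patch (illustrative only): background drawn from 110..150.
rng = np.random.default_rng(0)
patch = rng.integers(110, 151, size=(65, 65)).astype(np.uint16)
patch[32, :] = 800   # hypothetical membrane intensity along one row
patch[0, 0] = 5000   # hypothetical outlier above the clipping ceiling

normalized = normalize_16bit(patch)
print(normalized.shape, normalized.dtype)
print(float(normalized.min()), float(normalized.max()))
```

Under this assumed scaling, background pixels (110 to 150) land at roughly 0.055 to 0.075 and membrane pixels (150 to 1000) at roughly 0.075 to 0.5, so the two classes are separated by a narrow margin near the lower end of the range.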

2. Data Preparation
2.1 Data Collection
The data used for this task was provided by e