In [30]:
!pip install -r requirements.txt

Collecting torch-summary (from -r requirements.txt (line 15))
  Downloading torch_summary-1.4.5-py3-none-any.whl.metadata (18 kB)
Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5


In [31]:
import pandas as pd
import os
import glob
from PIL import Image
import torch
import torch.nn as nn
import math
from torchsummary import summary

# Data Preprocessing

In [None]:
# ONLY DO THIS ONCE. It took 78 minutes on my machine


input_directory = 'raw_data/jpeg/'
output_directory = 'raw_data/png16/'

if not os.path.exists(output_directory):
    os.makedirs(output_directory)

for foldername, subfolders, filenames in os.walk(input_directory):
    for filename in filenames:
        if filename.endswith('.jpg') or filename.endswith('.jpeg'):

            img_path = os.path.join(foldername, filename)
            img = Image.open(img_path)
            
            relative_path = os.path.relpath(foldername, input_directory)
            output_folder = os.path.join(output_directory, relative_path)
            if not os.path.exists(output_folder):
                os.makedirs(output_folder)
            
            png16_filename = os.path.splitext(filename)[0] + '.png'
            png16_save_path = os.path.join(output_folder, png16_filename)
            img.save(png16_save_path, format='PNG', bits=16)

print("Conversion completed.")

Conversion completed.


In [26]:
img_base_path = "C:/Users/jbber/.vscode/LightMirai/raw_data/png16/" 
csv_base_path = "C:/Users/jbber/.vscode/LightMirai/raw_data/csv/" 

def process_data(filepath, output_path=None):

    data = pd.read_csv(filepath)
    data[['file_path', 'mask_path']] = data.apply(map_image_paths, axis=1)

    processed_data = pd.DataFrame({
        'patient_id': data['patient_id'],
        'exam_id': data.groupby('patient_id').cumcount(),
        'laterality': data['left or right breast'].astype(str).str.strip().str[0].str.upper(),
        'view': data['image view'], 
        'file_path': data['file_path'], 
        'mask_path': data['mask_path'],
    })

    if output_path is not None:
        processed_data.to_csv(output_path, index=False)
        print(f"Processed data saved to {output_path}")

    return processed_data

def map_image_paths(row):

    folder_name = os.path.basename(os.path.dirname(row['ROI mask file path']))
    folder_path = os.path.join(img_base_path, folder_name)
    
    if not os.path.exists(folder_path):
        print(f"Warning: Folder {folder_path} does not exist.")
        return pd.Series([None, None])
        
    mammogram_file = next(
        (f for f in glob.glob(os.path.join(folder_path, "1-*.png")) if len(os.path.basename(f).split('-')[-1]) == 7),
        None
    )
    mask_file = next(
        (f for f in glob.glob(os.path.join(folder_path, "2-*.png")) if len(os.path.basename(f).split('-')[-1]) == 7),
        None
    )
    
    if not mammogram_file:
        print(f"Warning: No mammogram file (1-*.png) found in {folder_path}.")
    if not mask_file:
        print(f"Warning: No mask file (2-*.png) found in {folder_path}.")
    
    return pd.Series([mammogram_file, mask_file])

def merge_data(dataframe_1, dataframe_2):
    if list(dataframe_1.columns) != list(dataframe_2.columns):
        raise ValueError("DataFrames have different columns and cannot be merged.")

    merged_dataframe = pd.concat([dataframe_1, dataframe_2], ignore_index=True)
    merged_dataframe.drop_duplicates(inplace=True)

    return merged_dataframe

calc_train_data = process_data(csv_base_path + "calc_case_description_train_set.csv")
calc_test_data = process_data(csv_base_path + "calc_case_description_test_set.csv")

mass_train_data = process_data(csv_base_path + "mass_case_description_train_set.csv")
mass_test_data = process_data(csv_base_path + "mass_case_description_test_set.csv")

train_data = merge_data(calc_train_data, mass_train_data)
test_data = merge_data(calc_test_data, mass_test_data)

train_data.to_csv("clean_data/train.csv")
test_data.to_csv("clean_data/test.csv")



In [27]:
print("Training Set:\n")

print(train_data.head())
print(train_data.tail())
print(train_data.info())
print(train_data.describe(include='all'))
print(train_data.columns)
print(train_data.shape)
print(train_data.dtypes)
print(train_data.isnull().sum())
print(f"Number of duplicate rows: {len(train_data[train_data.duplicated()])}")

print("Testing Set:\n")

print(test_data.head())
print(test_data.tail())
print(test_data.info())
print(test_data.describe(include='all'))
print(test_data.columns)
print(test_data.shape)
print(test_data.dtypes)
print(test_data.isnull().sum())
print(f"Number of duplicate rows: {len(test_data[test_data.duplicated()])}")


Training Set:

  patient_id  exam_id laterality view  \
0    P_00005        0          R   CC   
1    P_00005        1          R  MLO   
2    P_00007        0          L   CC   
3    P_00007        1          L  MLO   
4    P_00008        0          L   CC   

                                           file_path  \
0  C:/Users/jbber/.vscode/LightMirai/raw_data/png...   
1  C:/Users/jbber/.vscode/LightMirai/raw_data/png...   
2  C:/Users/jbber/.vscode/LightMirai/raw_data/png...   
3  C:/Users/jbber/.vscode/LightMirai/raw_data/png...   
4  C:/Users/jbber/.vscode/LightMirai/raw_data/png...   

                                           mask_path  
0  C:/Users/jbber/.vscode/LightMirai/raw_data/png...  
1  C:/Users/jbber/.vscode/LightMirai/raw_data/png...  
2  C:/Users/jbber/.vscode/LightMirai/raw_data/png...  
3  C:/Users/jbber/.vscode/LightMirai/raw_data/png...  
4  C:/Users/jbber/.vscode/LightMirai/raw_data/png...  
     patient_id  exam_id laterality view  \
2859    P_02033        1   

In [28]:
if "mask_path" in train_data.columns:
    train_data.drop(columns=["mask_path"], inplace=True)
if "mask_path" in test_data.columns:
    test_data.drop(columns=["mask_path"], inplace=True)

train_data.drop_duplicates(inplace=True)
test_data.drop_duplicates(inplace=True)

dev_size = int(0.2 * len(train_data))

dev_data = train_data.sample(n=dev_size, random_state=42)

train_data = train_data.drop(dev_data.index)

train_data.to_csv("clean_data/train.csv")
test_data.to_csv("clean_data/test.csv")
dev_data.to_csv("clean_data/dev.csv")

# Model

## Image Encoder

In [41]:
# This implementation is based on the Github for the Mirai model: https://github.com/yala/Mirai/blob/master/onconet/models/resnet_base.py

# Here are the credits the author links for his code:
# Deep Residual Learning for Image Recognition: https://arxiv.org/abs/1512.03385
# Implementation based on PyTorch ResNet implementation: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

class ResNet(nn.Module):
    def __init__(self, layers, input_channels=3, num_classes=2):
        super(ResNet, self).__init__()
        
        self.inplanes = 64
        self.downsampler = Downsampler(self.inplanes, input_channels)
        
        self.layers = []
        self.hidden_dim = self.inplanes
        for i, layer_blocks in enumerate(layers):
            stride = 2 if i > 0 else 1
            self.hidden_dim = min(self.hidden_dim * 2, 1024)
            layer = self._make_layer(self.hidden_dim, layer_blocks, stride)
            setattr(self, f"layer{i+1}", layer)
            self.layers.append(layer)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.hidden_dim, num_classes)

    def forward(self, x):
        x = self.downsampler(x)
        for layer in self.layers:
            x = layer(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def _make_layer(self, planes, blocks, stride=1):
        layers = []
        for i, block in enumerate(blocks):
            if i == 0:
                layers.append(block(self.inplanes, planes, stride))
            else:
                layers.append(block(planes, planes))
            self.inplanes = planes
        return nn.Sequential(*layers)


    def _initialize_weights(self):
        """
        Initializes model weights using Kaiming He initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class Downsampler(nn.Module):
    """
    Initial downsampling layer for ResNet. Reduces input dimensions by 4x.
    """
    def __init__(self, inplanes, input_channels):
        super(Downsampler, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x


class ResidualBlock(nn.Module):
    """
    A standard residual block for ResNet.
    """
    def __init__(self, inplanes, planes, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip connection
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )
            if inplanes != planes or stride != 1 else None
        )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

class CustomResnet(nn.Module):
    def __init__(self, block_layout):
        super(CustomResnet, self).__init__()
        layers = self.get_layers(block_layout)
        self._model = ResNet(layers)

    def forward(self, x):
        return self._model(x)

    def get_layers(self, block_layout):
        """
        Gets the layers for a ResNet given the desired layout of blocks.

        Args:
            block_layout (list): A list where each element is a list of lists.
                                Each inner list contains (block_name, num_repeats).

        Returns:
            layers (list): A list of lists of block types conforming to the block layout.
        """
        self.validate_block_layout(block_layout)

        layers = []
        for layer_layout in block_layout:
            layer = []
            for block_name, num_repeats in layer_layout:
                block = self.get_block(block_name)
                layer.extend([block] * num_repeats)
            layers.append(layer)

        return layers
    
    def get_block(self, block_name):
        """
        Maps a block name to the corresponding block class.
        Currently, only 'BasicBlock' is supported.
        """
        if block_name == "BasicBlock":
            return ResidualBlock
        raise ValueError(f"Unsupported block type: {block_name}")
    
    def validate_block_layout(self, block_layout):
        """
        Validates the block layout format.
        """
        if not isinstance(block_layout, list):
            raise ValueError("block_layout must be a list.")
        for layer_layout in block_layout:
            if not isinstance(layer_layout, list):
                raise ValueError("Each layer layout in block_layout must be a list.")
            for block_spec in layer_layout:
                if not (isinstance(block_spec, list) and len(block_spec) == 2):
                    raise ValueError("Each block spec must be a list of length 2 (block_name, num_repeats).")

block_layout = [
    [["BasicBlock", 2]],
    [["BasicBlock", 2]],
    [["BasicBlock", 2]],
    [["BasicBlock", 2]],
]
image_encoder = CustomResnet(block_layout)

input_tensor = torch.randn(1, 3, 224, 224) 
output = image_encoder(input_tensor)
print(output.shape) 

summary(image_encoder, input_size=(3, 224, 224))


torch.Size([1, 2])
Layer (type:depth-idx)                   Param #
├─ResNet: 1-1                            --
|    └─Downsampler: 2-1                  --
|    |    └─Conv2d: 3-1                  9,408
|    |    └─BatchNorm2d: 3-2             128
|    |    └─ReLU: 3-3                    --
|    |    └─MaxPool2d: 3-4               --
|    └─Sequential: 2-2                   --
|    |    └─ResidualBlock: 3-5           230,144
|    |    └─ResidualBlock: 3-6           295,424
|    └─Sequential: 2-3                   --
|    |    └─ResidualBlock: 3-7           919,040
|    |    └─ResidualBlock: 3-8           1,180,672
|    └─Sequential: 2-4                   --
|    |    └─ResidualBlock: 3-9           3,673,088
|    |    └─ResidualBlock: 3-10          4,720,640
|    └─Sequential: 2-5                   --
|    |    └─ResidualBlock: 3-11          14,686,208
|    |    └─ResidualBlock: 3-12          18,878,464
|    └─AdaptiveAvgPool2d: 2-6            --
|    └─Linear: 2-7                      

Layer (type:depth-idx)                   Param #
├─ResNet: 1-1                            --
|    └─Downsampler: 2-1                  --
|    |    └─Conv2d: 3-1                  9,408
|    |    └─BatchNorm2d: 3-2             128
|    |    └─ReLU: 3-3                    --
|    |    └─MaxPool2d: 3-4               --
|    └─Sequential: 2-2                   --
|    |    └─ResidualBlock: 3-5           230,144
|    |    └─ResidualBlock: 3-6           295,424
|    └─Sequential: 2-3                   --
|    |    └─ResidualBlock: 3-7           919,040
|    |    └─ResidualBlock: 3-8           1,180,672
|    └─Sequential: 2-4                   --
|    |    └─ResidualBlock: 3-9           3,673,088
|    |    └─ResidualBlock: 3-10          4,720,640
|    └─Sequential: 2-5                   --
|    |    └─ResidualBlock: 3-11          14,686,208
|    |    └─ResidualBlock: 3-12          18,878,464
|    └─AdaptiveAvgPool2d: 2-6            --
|    └─Linear: 2-7                       2,050
Total params