In [None]:
! pip install tqdm kagglehub python-dotenv opencv-python opencv-contrib-python torch torchvision scikit-learn scikit-image pandas huggingface-hub

In [None]:
import os
import requests
import zipfile
from tqdm import tqdm
import logging
from dotenv import load_dotenv
import shutil
import sys

# Load environment variables from .env file
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.info("Logger initialized")


def download_file(url, filename):
    logger.info(f"Downloading {filename} from {url}")
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024
    t = tqdm(total=total_size, unit='B', unit_scale=True)
    with open(filename, 'wb') as f:
        for data in response.iter_content(block_size):
            t.update(len(data))
            f.write(data)
    t.close()

def unzip_file(zip_path, extract_to):
    logger.info(f"Unzipping {zip_path} to {extract_to}")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    logger.info(f"Unzipped {zip_path} to {extract_to}")

def fetch_bus_images_dataset(output_path):
    # URL of the BUSI dataset
    url = 'https://scholar.cu.edu.eg/Dataset_BUSI.zip'
    filename = os.path.join(output_path, 'BUSI.zip')
    download_file(url, filename)
    unzip_file(filename, output_path)
    os.remove(filename)

def fetch_bus_dataset(output_path):
    url = 'https://data.mendeley.com/public-files/datasets/wmy84gzngw/files/b63daee9-78de-4122-8475-9b3aa22ffd64/file_downloaded'
    filename = os.path.join(output_path, 'BUS.zip')
    headers = {'User-Agent': 'Mozilla/5.0'}
    response = requests.get(url, headers=headers, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024
    t = tqdm(total=total_size, unit='B', unit_scale=True)
    with open(filename, 'wb') as f:
        for data in response.iter_content(block_size):
            t.update(len(data))
            f.write(data)
    t.close()
    unzip_file(filename, output_path)
    os.remove(filename)


def fetch_vindr_metadata(output_path):
    import kagglehub
    # os.environ["KAGGLE_CONFIG"] = os.path.join(os.path.dirname(__file__), "./kaggle/kaggle.json")
    # Download latest version
    path = kagglehub.dataset_download("truthisneverlinear/vindr-mammo-annotations")
    logger.info(f"Downloaded dataset to {path}")
    move_dataset(path, output_path)


def move_dataset(path, output_path):
    # Move the dataset to the output path
    os.makedirs(output_path, exist_ok=True)
    for item in os.listdir(path):
        src = os.path.join(path, item)
        dst = os.path.join(output_path, item)
        if os.path.exists(dst):
            if os.path.isdir(dst):
                shutil.rmtree(dst)
            else:
                os.remove(dst)
        shutil.move(src, dst)
    os.rmdir(path)
    logger.info(f"Moved dataset to {output_path}")

def fetch_vindr_dataset(output_path):
    import kagglehub
    # os.environ["KAGGLE_CONFIG"] = os.path.join(os.path.dirname(__file__), "./kaggle/kaggle.json")
    # Download latest version
    path = kagglehub.dataset_download("shantanughosh/vindr-mammogram-dataset-dicom-to-png")
    logger.info(f"Downloaded dataset to {path}")
    move_dataset(path, output_path)


def main():
    output_path = 'data'
    os.makedirs(output_path, exist_ok=True)

    # fetch_bus_dataset(output_path)
    # fetch_bus_images_dataset(output_path)
    # fetch_vindr_dataset(output_path)
    # fetch_vindr_metadata(output_path="data")


if __name__ == "__main__":
    logger.info("Script started")
    main()
    logger.info("Script finished")
    logger.info("All datasets fetched successfully")


In [None]:
import shutil

shutil.copytree('/kaggle/working/data/images_png', './data/mammo/images')

!rm -rf /kaggle/working/data/images_png

In [None]:
# prompt: check the content of a directory code and show all format

import os

def list_files_with_formats(directory):
  """Lists all files in a directory, showing their formats."""
  try:
    for filename in os.listdir(directory):
      filepath = os.path.join(directory, filename)
      if os.path.isfile(filepath):
        name, ext = os.path.splitext(filename)
        print(f"File: {filename}, Format: {ext}")
      elif os.path.isdir(filepath):
        print(f"Directory: {filename}")
        list_files_with_formats(filepath) # Recursive call for subdirectories
  except FileNotFoundError:
      print(f"Directory '{directory}' not found.")
  except Exception as e:
      print(f"An error occurred: {e}")


# Example usage:
list_files_with_formats("./data/mammo/images") # Replace with the actual path


In [None]:
import os

shutil.copy('/kaggle/input/vindr-mammo-annotations/breast-level_annotations.csv', './data/mammo/breast-level_annotations.csv')

shutil.copy('/kaggle/input/vindr-mammo-annotations/finding_annotations.csv', './data/mammo/finding_annotations.csv')

shutil.copy('/kaggle/input/vindr-mammo-annotations/metadata.csv', './data/mammo/metadata.csv')

shutil.copy('/kaggle/input/vindr-mammo-annotations/SHA256SUMS.txt', './data/mammo/SHA256SUMS.txt')


In [None]:
# prompt: code to move the content of ./data/Dataset_BUSI_with_GT/ into ./data/ultrasound/masks and ./data/ultrasound/images, but checking each directory within the main directory for the appropriate mask images to move into ./data/ultrasound/masks and normal images to move into ./data/ultrasound/images. Mask images contain "_mask"; make sure it's just moving the necessary files from the dataset Busi into ultrasound. make sure the sub-directory under the /data/Dataset_BUSI_with_GT/  is used as same subdirectory under the ultrasound/images and ultrasound/masks

import os
import shutil

def organize_images_and_masks(source_dir, images_dir, masks_dir):
    """Recursively move .png images and masks to separate directories."""
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    for root, _, files in os.walk(source_dir):
        for filename in files:
            if filename.endswith('.png'):
                src_path = os.path.join(root, filename)
                rel_path = os.path.relpath(root, source_dir)

                # Create corresponding subdirectories in images_dir and masks_dir
                image_subdir = os.path.join(images_dir, rel_path)
                mask_subdir = os.path.join(masks_dir, rel_path)
                os.makedirs(image_subdir, exist_ok=True)
                os.makedirs(mask_subdir, exist_ok=True)

                if "_mask" in filename:
                    dst_path = os.path.join(mask_subdir, filename)
                else:
                    dst_path = os.path.join(image_subdir, filename)

                shutil.move(src_path, dst_path)

if __name__ == "__main__":
    source_dir = "./data/Dataset_BUSI_with_GT"
    images_dir = './data/ultrasound/images'
    masks_dir = './data/ultrasound/masks'
    organize_images_and_masks(source_dir, images_dir, masks_dir)

In [None]:
!rm -rf /kaggle/working/data/Dataset_BUSI_with_GT

In [None]:
import os
import shutil
import pandas as pd
from sklearn.model_selection import train_test_split

# --- Configuration ---
# Set the base path for your downloaded BUSI dataset
# Assuming the structure is ./data/ultrasound/images/... and ./data/ultrasound/masks/...
BUSI_ORIGINAL_PATH = './data/ultrasound' # Adjust this path if needed
BUSI_IMAGES_PATH = os.path.join(BUSI_ORIGINAL_PATH, 'images')
BUSI_MASKS_PATH = os.path.join(BUSI_ORIGINAL_PATH, 'masks')


# Set the base path for the new organized dataset structure
ORGANIZED_DATASET_BASE_PATH = './mmibc/busi' # Specific path for BUSI
ORGANIZED_METADATA_FILE = os.path.join(ORGANIZED_DATASET_BASE_PATH, 'busi_metadata.csv') # Output metadata file

# Define split ratios (train, validation, test) - should sum to 1.0
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1
TEST_RATIO = 0.1

# --- Function to Get Image Labels and Paths for BUSI ---
def get_image_labels_and_paths_busi(images_base_path, masks_base_path):
    """
    Reads BUSI dataset structure (assuming images/masks subfolders with
    benign, malignant, normal subfolders inside) and assigns binary labels
    (0 for benign/normal, 1 for malignant).
    Correctly finds mask files based on '_mask' naming convention.
    Returns a list of dictionaries: [{'image_id': '...', 'label': 0/1,
                                       'image_file_path': '...', 'mask_file_path': '...',
                                       'original_class': '...'}]
    """
    print("Processing BUSI dataset structure...")

    image_info_list = [] # List to store info for each image

    # Define the original class directories and their target binary labels
    class_directories = {
        'benign': 0,    # Map benign to binary label 0
        'normal': 0,    # Map normal to binary label 0
        'malignant': 1  # Map malignant to binary label 1
    }

    for class_name, binary_label in class_directories.items():
        class_images_dir_path = os.path.join(images_base_path, class_name)
        class_masks_dir_path = os.path.join(masks_base_path, class_name)


        if not os.path.exists(class_images_dir_path):
            print(f"Warning: Image directory not found: {class_images_dir_path}. Skipping this class.")
            continue
        # Note: We don't skip the class if mask directory is missing, just note it.
        if not os.path.exists(class_masks_dir_path):
             print(f"Warning: Mask directory not found: {class_masks_dir_path}. Segmentation masks might be missing for this class.")


        print(f"Processing class: {class_name} (Binary Label: {binary_label})")

        # List files in the class image directory
        try:
            image_files = [f for f in os.listdir(class_images_dir_path) if not f.startswith('.') and os.path.isfile(os.path.join(class_images_dir_path, f))]
        except Exception as e:
            print(f"Error listing files in {class_images_dir_path}: {e}")
            continue

        for file_name in image_files:
            image_file_path = os.path.join(class_images_dir_path, file_name)

            # Construct the expected mask filename by appending '_mask' before the extension
            name, ext = os.path.splitext(file_name)
            mask_file_name = f"{name}_mask{ext}"
            mask_file_path = os.path.join(class_masks_dir_path, mask_file_name)

            # Check if the constructed mask file path actually exists
            if not os.path.exists(mask_file_path):
                 # Handle cases like 'benign (4)_mask_1.png' if they exist alongside 'benign (4)_mask.png'
                 # This might require more sophisticated logic if there are multiple masks per image.
                 # For now, we'll just check for the primary '_mask' version.
                 # If the primary mask isn't found, check for '_mask_1', '_mask_2', etc.
                 # This is a common pattern in BUSI for multiple masks per image.
                 # Let's refine this to find *all* masks associated with an image ID.

                 associated_masks = []
                 if os.path.exists(class_masks_dir_path):
                     mask_files_in_dir = os.listdir(class_masks_dir_path)
                     # Find all files in the mask directory that start with the image name
                     # This is a more robust way to find associated masks
                     base_name = os.path.splitext(file_name)[0]
                     associated_mask_files = [
                         f for f in mask_files_in_dir
                         if f.startswith(base_name) and '_mask' in f and os.path.isfile(os.path.join(class_masks_dir_path, f))
                     ]
                     associated_masks = [os.path.join(class_masks_dir_path, f) for f in associated_mask_files]


                 if not associated_masks:
                     print(f"Warning: No mask file found for image {file_name} in {class_masks_dir_path}. Mask path will be None.")
                     mask_file_path_to_record = None # No mask found
                 elif len(associated_masks) > 1:
                      print(f"Warning: Found multiple masks for image {file_name}: {associated_masks}. Recording the first one found.")
                      mask_file_path_to_record = associated_masks[0] # Or decide how to handle multiple masks
                 else:
                      mask_file_path_to_record = associated_masks[0] # Only one mask found

            else:
                 # The primary '_mask' file was found
                 mask_file_path_to_record = mask_file_path


            # Use file_name as a simple image_id for BUSI
            image_id = file_name

            image_info_list.append({
                'image_id': image_id,
                'label': binary_label,
                'image_file_path': image_file_path,
                'mask_file_path': mask_file_path_to_record, # Include the found mask path (or None)
                'original_class': class_name # Keep original class name for metadata
            })

    print(f"Finished processing BUSI structure. Found {len(image_info_list)} relevant images with binary labels.")
    return image_info_list

# --- Function to Split Data ---
def split_busi_data(image_info_list, train_ratio, val_ratio, test_ratio):
    """
    Splits the image info list into train, validation, and test sets,
    stratifying by the binary label.
    Returns a dictionary mapping split names to lists of dictionaries.
    """
    print("Splitting BUSI data...")

    if not image_info_list:
        print("No image info found to split.")
        return {'train': [], 'validation': [], 'test': []}

    # Extract binary labels for splitting
    all_labels = [item['label'] for item in image_info_list]

    # Ensure there are enough samples and classes for stratification
    if len(set(all_labels)) < 2 or len(all_labels) < 2:
         print("Warning: Not enough samples or classes for stratification. Splitting without stratify.")
         train_info, temp_info = train_test_split(
            image_info_list, test_size=(val_ratio + test_ratio), random_state=42
        )
    else:
        # Perform the split using the list of dictionaries directly
        train_info, temp_info = train_test_split(
            image_info_list, test_size=(val_ratio + test_ratio), stratify=all_labels, random_state=42
        )

    # Calculate test_size for the second split relative to the temporary set
    test_size_temp = val_ratio + test_ratio
    test_size_final = test_ratio / test_size_temp if test_size_temp > 0 else 0

    # Extract labels from the temporary split for stratification
    temp_labels = [item['label'] for item in temp_info]

    if test_size_final > 0 and (len(set(temp_labels)) < 2 or len(temp_labels) < 2):
         print("Warning: Not enough samples or classes in temporary split for second stratification. Splitting without stratify.")
         val_info, test_info = train_test_split(
            temp_info, test_size=test_size_final, random_state=42
        )
    else:
        val_info, test_info = train_test_split(
            temp_info, test_size=test_size_final, stratify=temp_labels, random_state=42
        )

    splits = {
        'train': train_info,
        'validation': val_info,
        'test': test_info
    }

    print(f"Splitting complete. Train: {len(splits['train'])}, Val: {len(splits['validation'])}, Test: {len(splits['test'])}")
    return splits

# --- Function to Organize Split Files and Generate Metadata ---
def organize_split_files_and_generate_metadata_busi(split_data, organized_base_path, metadata_output_path):
    """
    Organizes the split data (list of dictionaries) into the target directory structure
    and generates a metadata CSV file for the organized data for BUSI.
    Copies both images and masks.
    """
    print(f"Organizing split files into: {organized_base_path}")

    all_organized_metadata = []

    # Create base organized directories
    organized_images_base = os.path.join(organized_base_path, 'images')
    organized_masks_base = os.path.join(organized_base_path, 'masks')
    os.makedirs(organized_images_base, exist_ok=True)
    os.makedirs(organized_masks_base, exist_ok=True)


    for split_name, data_list in split_data.items():
        # Create split directories within images and masks
        split_images_dir = os.path.join(organized_images_base, split_name)
        split_masks_dir = os.path.join(organized_masks_base, split_name)

        # Create class directories within each split directory
        benign_images_dir = os.path.join(split_images_dir, 'benign')
        malignant_images_dir = os.path.join(split_images_dir, 'malignant')
        benign_masks_dir = os.path.join(split_masks_dir, 'benign')
        malignant_masks_dir = os.path.join(split_masks_dir, 'malignant')


        os.makedirs(benign_images_dir, exist_ok=True)
        os.makedirs(malignant_images_dir, exist_ok=True)
        os.makedirs(benign_masks_dir, exist_ok=True)
        os.makedirs(malignant_masks_dir, exist_ok=True)


        print(f"Copying files for {split_name}...")
        for image_info in data_list:
            original_image_file_path = image_info['image_file_path']
            original_mask_file_path = image_info['mask_file_path'] # This is the path found in get_image_labels_and_paths_busi
            label = image_info['label']
            image_id = image_info['image_id'] # Use original file name as image_id
            original_class = image_info['original_class'] # Keep original class name

            # Determine the target directory for image and mask
            target_images_dir = benign_images_dir if label == 0 else malignant_images_dir
            target_masks_dir = benign_masks_dir if label == 0 else malignant_masks_dir

            # Use original image_id as the new file name for the image
            new_image_file_name = image_id
            organized_image_file_path = os.path.join(target_images_dir, new_image_file_name)

            # For the mask, use the original mask file name if available, otherwise use the image_id + _mask
            if original_mask_file_path:
                 new_mask_file_name = os.path.basename(original_mask_file_path)
            else:
                 # If no mask was found, we won't copy one, but we could define a potential name
                 # for consistency in metadata, though it won't exist on disk.
                 # Let's just record None for the organized path if original was None.
                 new_mask_file_name = None

            organized_mask_file_path = os.path.join(target_masks_dir, new_mask_file_name) if new_mask_file_name else None


            try:
                # Copy the image
                shutil.copy(original_image_file_path, organized_image_file_path)

                # Copy the mask if it exists
                if original_mask_file_path and os.path.exists(original_mask_file_path):
                     shutil.copy(original_mask_file_path, organized_mask_file_path)
                else:
                     organized_mask_file_path = None # Ensure organized path is None if mask wasn't copied


                # Add information to the metadata list
                metadata_entry = {
                    'image_id': image_id,
                    'label': label, # Binary label (0 or 1)
                    'original_class': original_class, # Original class (benign, malignant, normal)
                    'split': split_name, # This is our new split (train, val, test)
                    'organized_image_file_path': organized_image_file_path,
                    'original_image_file_path': original_image_file_path,
                    'organized_mask_file_path': organized_mask_file_path, # Path to the copied mask (or None)
                    'original_mask_file_path': original_mask_file_path # Path to the original mask (or None)
                }
                all_organized_metadata.append(metadata_entry)

            except FileNotFoundError:
                 print(f"Error: Source file not found during copy: {original_image_file_path} or {original_mask_file_path}")
            except Exception as e:
                 print(f"Error copying file {original_image_file_path} or its mask: {e}")


        print(f"Finished copying for {split_name}. Total files: {len(data_list)}")

    # Generate the metadata CSV
    if all_organized_metadata:
        metadata_df = pd.DataFrame(all_organized_metadata)
        # Ensure the output directory exists
        os.makedirs(os.path.dirname(metadata_output_path), exist_ok=True)
        metadata_df.to_csv(metadata_output_path, index=False)
        print(f"\nGenerated organized dataset metadata file: {metadata_output_path}")
    else:
        print("\nNo metadata entries to write. Metadata file not generated.")


# --- Main Execution ---
if __name__ == "__main__":
    # 1. Process BUSI dataset structure and Get Image Info
    busi_image_info_list = get_image_labels_and_paths_busi(
        BUSI_IMAGES_PATH, # Pass the images subfolder path
        BUSI_MASKS_PATH   # Pass the masks subfolder path
    )

    if busi_image_info_list:
        # 2. Split Data
        busi_splits = split_busi_data(
            busi_image_info_list,
            TRAIN_RATIO,
            VAL_RATIO,
            TEST_RATIO
        )

        # 3. Organize Split Files and Generate Metadata
        organize_split_files_and_generate_metadata_busi(
            busi_splits,
            ORGANIZED_DATASET_BASE_PATH,
            ORGANIZED_METADATA_FILE
        )
        print("\nBUSI data preparation complete.")
    else:
        print("\nBUSI data preparation failed due to issues processing dataset structure or finding images.")

    # Note: This script only handles BUSI.
    # You would run the VinDr-Mammo script separately,
    # handle KAU-BCMD, and then use a separate script to combine/upload to Hugging Face.

In [None]:
!rm -rf /kaggle/working/data/originals

In [None]:
import os
import shutil
import pandas as pd
from sklearn.model_selection import train_test_split

# --- Configuration ---
# Set the base path for your downloaded VinDr-Mammo dataset
VINDR_MAMMO_ORIGINAL_PATH = './data/mammo' # Adjust this path if needed

# Paths to the VinDr-Mammo metadata files
BREAST_LEVEL_METADATA_PATH = os.path.join(VINDR_MAMMO_ORIGINAL_PATH, 'breast-level_annotations.csv')
FINDING_METADATA_PATH = os.path.join(VINDR_MAMMO_ORIGINAL_PATH, 'finding_annotations.csv')
# METADATA_CSV_PATH = os.path.join(VINDR_MAMMO_ORIGINAL_PATH, 'metadata.csv') # Optional: for image details

# Set the base path for the new organized dataset structure
ORGANIZED_DATASET_BASE_PATH = './mmibc/vindr_mammo' # Specific path for VinDr-Mammo
ORGANIZED_METADATA_FILE = os.path.join(ORGANIZED_DATASET_BASE_PATH, 'vindr_mammo_metadata.csv') # Output metadata file

# Define split ratios for splitting the *original training data* into our train/validation sets
# The original 'test' split will be used as our test set.
TRAIN_RATIO_ORIGINAL_TRAINING = 0.8 # Ratio of original training data for our training set
VAL_RATIO_ORIGINAL_TRAINING = 0.2  # Ratio of original training data for our validation set
# Note: Original test data will be used entirely for our test set.

# --- Function to Process Metadata and Get Image Labels ---
def get_image_labels_and_paths(original_data_path, breast_metadata_path, finding_metadata_path):
    """
    Reads VinDr-Mammo metadata to determine binary labels (benign/malignant)
    and maps image IDs to file paths. Includes original split and study_id.
    Returns a list of dictionaries: [{'image_id': '...', 'label': 0/1, 'original_split': '...', 'file_path': '...', 'study_id': '...'}]
    """
    print("Processing VinDr-Mammo metadata...")

    image_info_list = [] # List to store info for each image

    # Load metadata files
    finding_df = None
    if os.path.exists(finding_metadata_path):
        finding_df = pd.read_csv(finding_metadata_path)
        # Map finding BI-RADS to binary malignancy (4, 5, 6 are malignant)
        finding_df['is_malignant_finding'] = finding_df['finding_birads'].isin(['BI-RADS 4', 'BI-RADS 5', 'BI-RADS 6']).astype(int)
    else:
        print(f"Warning: Finding metadata file not found at {finding_metadata_path}. Malignant cases might be missed.")

    breast_df = None
    if os.path.exists(breast_metadata_path):
        breast_df = pd.read_csv(breast_metadata_path)
    else:
        print(f"Warning: Breast-level metadata file not found at {breast_metadata_path}. Benign/normal cases might be missed.")

    # Combine relevant info from both dataframes
    if finding_df is not None and breast_df is not None:
        # Use breast_df as the base, as it lists all images (including normals)
        # Merge finding info onto breast info
        merged_df = pd.merge(breast_df, finding_df[['image_id', 'is_malignant_finding']], on='image_id', how='left')
        # Fill NaN malignant findings with 0 (meaning no malignant finding in this image)
        merged_df['is_malignant_finding'] = merged_df['is_malignant_finding'].fillna(0)

        # Determine final binary label (1 if any malignant finding, 0 otherwise)
        merged_df['final_label'] = merged_df['is_malignant_finding'].clip(upper=1) # Ensure label is 0 or 1

        # Now, iterate through the merged dataframe to collect image info
        for index, row in merged_df.iterrows():
            image_id = row['image_id']
            study_id = row['study_id']
            original_split = row['split']
            label = int(row['final_label']) # Ensure label is integer 0 or 1

            # Construct the full path to the image file
            # Assuming .png extension based on SHA256SUMS, adjust if necessary (e.g., .dicom)
            image_file_path = os.path.join(original_data_path, 'images', study_id, f'{image_id}.png')


            if os.path.exists(image_file_path):
                image_info_list.append({
                    'image_id': image_id,
                    'study_id': study_id,
                    'label': label,
                    'original_split': original_split,
                    'file_path': image_file_path,
                    # Add other relevant metadata if needed, e.g., 'breast_birads', 'finding_categories'
                    'breast_birads': row.get('breast_birads'), # Use .get to avoid error if column is missing
                    'breast_density': row.get('breast_density'),
                    # Note: finding_categories is a list in CSV, might need careful handling
                })
            else:
                 print(f"Warning: Image file not found for image {image_id} at {image_file_path}. Skipping.")

    elif breast_df is not None:
         print("Proceeding with breast-level metadata only. Malignant cases from findings might be missed.")
         for index, row in breast_df.iterrows():
            image_id = row['image_id']
            study_id = row['study_id']
            original_split = row['split']
            # Assume benign/normal if only breast-level metadata is available and BI-RADS is 1, 2, or 3
            label = 0 if row['breast_birads'] in ['BI-RADS 1', 'BI-RADS 2', 'BI-RADS 3'] else -1 # Use -1 for uncertain

            if label != -1:
                 image_file_path = os.path.join(original_data_path, 'images', study_id, f'{image_id}.png')
                 if os.path.exists(image_file_path):
                     image_info_list.append({
                         'image_id': image_id,
                         'study_id': study_id,
                         'label': label,
                         'original_split': original_split,
                         'file_path': image_file_path,
                         'breast_birads': row.get('breast_birads'),
                         'breast_density': row.get('breast_density'),
                     })
                 else:
                     print(f"Warning: Image file not found for image {image_id} at {image_file_path}. Skipping.")
            else:
                 print(f"Warning: Uncertain label for image {image_id} based on breast-level BI-RADS {row['breast_birads']}. Skipping.")

    elif finding_df is not None:
         print("Proceeding with finding-level metadata only. Benign/normal cases might be missed.")
         # In this case, we only have info about images with findings
         for index, row in finding_df.iterrows():
             image_id = row['image_id']
             study_id = row['study_id']
             original_split = row['split']
             label = int(row['is_malignant_finding']) # Label is 1 if malignant finding, 0 otherwise

             image_file_path = os.path.join(original_data_path, 'images', study_id, f'{image_id}.png')
             if os.path.exists(image_file_path):
                 image_info_list.append({
                     'image_id': image_id,
                     'study_id': study_id,
                     'label': label,
                     'original_split': original_split,
                     'file_path': image_file_path,
                     # Breast-level info is not available here
                 })
             else:
                 print(f"Warning: Image file not found for image {image_id} at {image_file_path}. Skipping.")


    else:
        print("Error: Neither breast-level nor finding-level metadata files were found.")


    print(f"Finished metadata processing. Found {len(image_info_list)} relevant images with labels.")
    return image_info_list

# --- Function to Split Data Based on Original Split and Ratios ---
def split_vindr_mammo_data(image_info_list, train_ratio_orig_train, val_ratio_orig_train):
    """
    Splits the image info list into train, validation, and test sets,
    respecting the original 'split' column and applying ratios to the original training data.
    Returns a dictionary mapping split names to lists of dictionaries.
    """
    print("Splitting VinDr-Mammo data...")

    # Separate based on original split
    original_training_data = [item for item in image_info_list if item['original_split'] == 'training']
    original_test_data = [item for item in image_info_list if item['original_split'] == 'test']

    print(f"Original training data count: {len(original_training_data)}")
    print(f"Original test data count: {len(original_test_data)}")

    train_split_info = []
    val_split_info = []

    # Split original training data into our train and validation sets
    if original_training_data:
        # Extract file paths and labels for splitting
        orig_train_files = [item['file_path'] for item in original_training_data]
        orig_train_labels = [item['label'] for item in original_training_data]

        # Perform the split
        # Ensure there are enough samples for stratification and splitting
        if len(set(orig_train_labels)) < 2 or len(orig_train_labels) < 2:
             print("Warning: Not enough samples or classes in original training data for stratification. Splitting without stratify.")
             files_train, files_val, labels_train, labels_val = train_test_split(
                orig_train_files, orig_train_labels,
                test_size=(val_ratio_orig_train / (train_ratio_orig_train + val_ratio_orig_train)),
                random_state=42
            )
        else:
            files_train, files_val, labels_train, labels_val = train_test_split(
                orig_train_files, orig_train_labels,
                test_size=(val_ratio_orig_train / (train_ratio_orig_train + val_ratio_orig_train)), # Calculate test_size relative to the subset
                stratify=orig_train_labels,
                random_state=42
            )

        # Reconstruct the list of dictionaries for the splits
        # Need to look up the original info based on file_path
        original_info_map = {item['file_path']: item for item in original_training_data}

        train_split_info = [original_info_map[f] for f in files_train]
        val_split_info = [original_info_map[f] for f in files_val]

    else:
        print("Warning: No original training data found. Train and validation sets will be empty.")


    # The original test data becomes our test set
    # We only need file_path and label for the organized structure
    test_split_info = [{'file_path': item['file_path'], 'label': item['label'], 'image_id': item['image_id'], 'study_id': item['study_id']} for item in original_test_data]


    splits = {
        'train': train_split_info,
        'validation': val_split_info,
        'test': test_split_info
    }

    print(f"Splitting complete. Train: {len(splits['train'])}, Val: {len(splits['validation'])}, Test: {len(splits['test'])}")
    return splits

# --- Function to Organize Split Files and Generate Metadata ---
def organize_split_files_and_generate_metadata(split_data, organized_base_path, metadata_output_path):
    """
    Organizes the split data (list of dictionaries) into the target directory structure
    and generates a metadata CSV file for the organized data.
    """
    print(f"Organizing split files into: {organized_base_path}")

    all_organized_metadata = []

    for split_name, data_list in split_data.items():
        split_dir = os.path.join(organized_base_path, split_name)
        benign_split_dir = os.path.join(split_dir, 'benign')
        malignant_split_dir = os.path.join(split_dir, 'malignant')

        os.makedirs(benign_split_dir, exist_ok=True)
        os.makedirs(malignant_split_dir, exist_ok=True)

        print(f"Copying files for {split_name}...")
        for image_info in data_list:
            original_file_path = image_info['file_path']
            label = image_info['label']
            image_id = image_info['image_id']
            study_id = image_info['study_id'] # Keep study_id for metadata

            # Determine the target directory
            target_dir = benign_split_dir if label == 0 else malignant_split_dir
            # Use original image_id as the new file name
            new_file_name = f'{image_id}.png' # Assuming .png, adjust if needed
            organized_file_path = os.path.join(target_dir, new_file_name)

            try:
                shutil.copy(original_file_path, organized_file_path)

                # Add information to the metadata list
                metadata_entry = {
                    'image_id': image_id,
                    'study_id': study_id,
                    'label': label,
                    'split': split_name, # This is our new split (train, val, test)
                    'organized_file_path': organized_file_path,
                    'original_file_path': original_file_path,
                    # Include other relevant info from original metadata if available
                    'breast_birads': image_info.get('breast_birads'),
                    'breast_density': image_info.get('breast_density'),
                    # Add path to segmentation mask if available and organized separately
                    # 'segmentation_mask_path': 'path/to/mask.png' # TODO: Add logic for mask paths
                }
                all_organized_metadata.append(metadata_entry)

            except FileNotFoundError:
                 print(f"Error: Source file not found during copy: {original_file_path}")
            except Exception as e:
                 print(f"Error copying file {original_file_path} to {organized_file_path}: {e}")


        print(f"Finished copying for {split_name}. Total files: {len(data_list)}")

    # Generate the metadata CSV
    if all_organized_metadata:
        metadata_df = pd.DataFrame(all_organized_metadata)
        # Ensure the output directory exists
        os.makedirs(os.path.dirname(metadata_output_path), exist_ok=True)
        metadata_df.to_csv(metadata_output_path, index=False)
        print(f"\nGenerated organized dataset metadata file: {metadata_output_path}")
    else:
        print("\nNo metadata entries to write. Metadata file not generated.")


# --- Main Execution ---
if __name__ == "__main__":
    # 1. Process Metadata and Get Image Info
    vindr_mammo_image_info_list = get_image_labels_and_paths(
        VINDR_MAMMO_ORIGINAL_PATH,
        BREAST_LEVEL_METADATA_PATH,
        FINDING_METADATA_PATH
    )

    if vindr_mammo_image_info_list:
        # 2. Split Data
        vindr_mammo_splits = split_vindr_mammo_data(
            vindr_mammo_image_info_list,
            TRAIN_RATIO_ORIGINAL_TRAINING,
            VAL_RATIO_ORIGINAL_TRAINING
        )

        # 3. Organize Split Files and Generate Metadata
        organize_split_files_and_generate_metadata(
            vindr_mammo_splits,
            ORGANIZED_DATASET_BASE_PATH,
            ORGANIZED_METADATA_FILE
        )
        print("\nVinDr-Mammo data preparation complete.")
    else:
        print("\nVinDr-Mammo data preparation failed due to issues processing metadata or finding images.")


In [None]:
!rm -rf /kaggle/working/data/mammo/images

# Huggingface Upload



## Multimodal Breast Cancer Imaging Dataset

This dataset combines publicly available mammography and ultrasound imaging data for the development and evaluation of multimodal deep learning models for breast cancer diagnosis, with a focus on interpretability and applicability in resource-limited settings.

## Dataset Sources

This dataset is compiled from the following publicly available sources:

* **VinDr-Mammo:** A large-scale benchmark dataset for computer-aided detection and diagnosis in full-field digital mammography.
    * **Source:** [https://doi.org/10.13026/br2v-7517](https://doi.org/10.13026/br2v-7517)
    * **Reference:** Pham, H. H., Nguyen-Trung, H., & Nguyen, H. Q. (2022). VinDr-Mammo: A large-scale benchmark dataset for computer-aided detection and diagnosis in full-field digital mammography (version 1.0.0). *PhysioNet*.

* **BUSI (Breast Ultrasound Images) Dataset:** A well-curated collection of breast ultrasound images.
    * **Source:** [https://doi.org/10.1016/j.dib.2019.104863](https://doi.org/10.1016/j.dib.2019.104863)
    * **Reference:** Al-Dhabyani, W., Gomaa, M., Khaled, H., & Fahmy, A. (2020). Dataset of breast ultrasound images. *Data in Brief*, 28, 104863.

* **KAU-BCMD (King Abdulaziz University Breast Cancer Mammogram Dataset):** Includes paired mammography and ultrasound images for a subset of cases, valuable for multimodal evaluation.
    * **Source:** [https://doi.org/10.3390/data6110111](https://doi.3390/data6110111)
    * **Reference:** Alsolami, A. S., Shalash, W., Alsaggaf, W., Ashoor, S., Refaat, H., & Elmogy, M. (2021). King Abdulaziz University Breast Cancer Mammogram Dataset (KAU-BCMD). *Data*, 6(11), 111.

## Dataset Structure

The dataset is organized into folders based on modality and data split. The proposed structure is as follows:
```
.mmibc/
├── mammo/
│   ├── train/
│   │   ├── benign/
│   │   └── malignant/
│   ├── validation/
│   │   ├── benign/
│   │   └── malignant/
│   └── test/
│       ├── benign/
│       └── malignant/
├── busi/
│   ├──images/
│   │   |── train/
│   │   │       ├── benign/
│   │   │       └── malignant/
│   │   ├── validation/
│   │   │       ├── benign/
│   │   │       └── malignant/
│   │   └── test/
│   │           ├── benign/
│   │           └── malignant/
│   ├──masks/
│   │   |── train/
│   │   │       ├── benign/
│   │   │       └── malignant/
│   │   ├── validation/
│   │   │       ├── benign/
│   │   │       └── malignant/
│   │   └── test/
│   │           ├── benign/
│   │           └── malignant/
│
└── kau_bcmd/
│    ├── train/        (Paired mammo and US images, potentially linked by ID)
│    ├── validation/   (Paired mammo and US images)
│    └── test/         (Paired mammo and US images)
```

Within each split (`train`, `validation`, `test`), images are further categorized by diagnosis (`benign`, `malignant`). The KAU-BCMD dataset, being used for multimodal evaluation, will contain paired images. The structure for KAU-BCMD might need to include a mapping or consistent naming convention to link the mammogram and ultrasound images for the same patient/case.

## Dataset Contents

* **Images:** DICOM or common image formats (PNG, JPG) for mammography and ultrasound images.
* **Annotations:** Original annotations provided with the datasets, which may include bounding boxes, lesion types, and BI-RADS assessments.
* **Labels:** Binary labels indicating the diagnosis (benign or malignant).

## Usage

This dataset is intended for training and evaluating deep learning models for breast cancer diagnosis using multimodal imaging. The structured format facilitates easy loading and processing using standard deep learning libraries and Hugging Face's `datasets` library.

## License

Please refer to the original licenses of the VinDr-Mammo, BUSI, and KAU-BCMD datasets for specific terms of use.

## Citation

Please cite the original dataset sources when using this compiled dataset in your research.

```bibtex
@article{pham2022vindr,
  title={VinDr-Mammo: A large-scale benchmark dataset for computer-aided detection and diagnosis in full-field digital mammography},
  author={Pham, Hai H and Nguyen-Trung, Hieu and Nguyen, Hoang Q},
  journal={PhysioNet},
  year={2022}
}

@article{al2020dataset,
  title={Dataset of breast ultrasound images},
  author={Al-Dhabyani, Waleed and Gomaa, Mohamed and Khaled, Hossam and Fahmy, Amr},
  journal={Data in Brief},
  volume={28},
  pages={104863},
  year={2020},
  publisher={Elsevier}
}

@article{alsolami2021king,
  title={King Abdulaziz University Breast Cancer Mammogram Dataset (KAU-BCMD)},
  author={Alsolami, Abdulrahman S and Shalash, Walid and Alsaggaf, Walid and Ashoor, Sara and Refaat, Hesham and Elmogy, Mohammed},
  journal={Data},
  volume={6},
  number={11},
  pages={111},
  year={2021},
  publisher={MDPI}
}
```

In [None]:
# # prompt: code to add the markdown content to the README.md file. write the code in full, i would add the markdown content in a different section

# from google.colab import userdata
# import os


# def add_markdown_to_readme(readme_path, markdown_content):
#     """Adds markdown content to the README.md file."""
#     try:
#         with open(readme_path, 'a') as f:  # Open in append mode
#             f.write("\n")  # Add a newline for separation
#             f.write(markdown_content)
#         print(f"Markdown content successfully added to {readme_path}")
#     except FileNotFoundError:
#         print(f"README.md file not found at {readme_path}")
#     except Exception as e:
#         print(f"An error occurred: {e}")


# if __name__ == "__main__":
#     # Markdown content to add (replace with your actual content)
#     markdown_content = """
# # Multimodal Breast Cancer Imaging Dataset

# This dataset combines publicly available mammography and ultrasound imaging data for the development and evaluation of multimodal deep learning models for breast cancer diagnosis, with a focus on interpretability and applicability in resource-limited settings.

# ## Dataset Sources

# This dataset is compiled from the following publicly available sources:

# * **VinDr-Mammo:** A large-scale benchmark dataset for computer-aided detection and diagnosis in full-field digital mammography.
#     * **Source:** [https://doi.org/10.13026/br2v-7517](https://doi.org/10.13026/br2v-7517)
#     * **Reference:** Pham, H. H., Nguyen-Trung, H., & Nguyen, H. Q. (2022). VinDr-Mammo: A large-scale benchmark dataset for computer-aided detection and diagnosis in full-field digital mammography (version 1.0.0). *PhysioNet*.

# * **BUSI (Breast Ultrasound Images) Dataset:** A well-curated collection of breast ultrasound images.
#     * **Source:** [https://doi.org/10.1016/j.dib.2019.104863](https://doi.org/10.1016/j.dib.2019.104863)
#     * **Reference:** Al-Dhabyani, W., Gomaa, M., Khaled, H., & Fahmy, A. (2020). Dataset of breast ultrasound images. *Data in Brief*, 28, 104863.

# * **KAU-BCMD (King Abdulaziz University Breast Cancer Mammogram Dataset):** Includes paired mammography and ultrasound images for a subset of cases, valuable for multimodal evaluation.
#     * **Source:** [https://doi.org/10.3390/data6110111](https://doi.3390/data6110111)
#     * **Reference:** Alsolami, A. S., Shalash, W., Alsaggaf, W., Ashoor, S., Refaat, H., & Elmogy, M. (2021). King Abdulaziz University Breast Cancer Mammogram Dataset (KAU-BCMD). *Data*, 6(11), 111.

# ## Dataset Structure

# The dataset is organized into folders based on modality and data split. The proposed structure is as follows:
# ```
# .mmibc/
# ├── mammo/
# │   ├── train/
# │   │   ├── benign/
# │   │   └── malignant/
# │   ├── validation/
# │   │   ├── benign/
# │   │   └── malignant/
# │   └── test/
# │       ├── benign/
# │       └── malignant/
# ├── busi/
# │   ├──images/
# │   │   |── train/
# │   │   │       ├── benign/
# │   │   │       └── malignant/
# │   │   ├── validation/
# │   │   │       ├── benign/
# │   │   │       └── malignant/
# │   │   └── test/
# │   │           ├── benign/
# │   │           └── malignant/
# │   ├──masks/
# │   │   |── train/
# │   │   │       ├── benign/
# │   │   │       └── malignant/
# │   │   ├── validation/
# │   │   │       ├── benign/
# │   │   │       └── malignant/
# │   │   └── test/
# │   │           ├── benign/
# │   │           └── malignant/
# │
# └── kau_bcmd/
# │    ├── train/        (Paired mammo and US images, potentially linked by ID)
# │    ├── validation/   (Paired mammo and US images)
# │    └── test/         (Paired mammo and US images)
# ```

# Within each split (`train`, `validation`, `test`), images are further categorized by diagnosis (`benign`, `malignant`). The KAU-BCMD dataset, being used for multimodal evaluation, will contain paired images. The structure for KAU-BCMD might need to include a mapping or consistent naming convention to link the mammogram and ultrasound images for the same patient/case.

# ## Dataset Contents

# * **Images:** DICOM or common image formats (PNG, JPG) for mammography and ultrasound images.
# * **Annotations:** Original annotations provided with the datasets, which may include bounding boxes, lesion types, and BI-RADS assessments.
# * **Labels:** Binary labels indicating the diagnosis (benign or malignant).

# ## Usage

# This dataset is intended for training and evaluating deep learning models for breast cancer diagnosis using multimodal imaging. The structured format facilitates easy loading and processing using standard deep learning libraries and Hugging Face's `datasets` library.

# ## License

# Please refer to the original licenses of the VinDr-Mammo, BUSI, and KAU-BCMD datasets for specific terms of use.

# ## Citation

# Please cite the original dataset sources when using this compiled dataset in your research.

# ```bibtex
# @article{pham2022vindr,
#   title={VinDr-Mammo: A large-scale benchmark dataset for computer-aided detection and diagnosis in full-field digital mammography},
#   author={Pham, Hai H and Nguyen-Trung, Hieu and Nguyen, Hoang Q},
#   journal={PhysioNet},
#   year={2022}
# }

# @article{al2020dataset,
#   title={Dataset of breast ultrasound images},
#   author={Al-Dhabyani, Waleed and Gomaa, Mohamed and Khaled, Hossam and Fahmy, Amr},
#   journal={Data in Brief},
#   volume={28},
#   pages={104863},
#   year={2020},
#   publisher={Elsevier}
# }

# @article{alsolami2021king,
#   title={King Abdulaziz University Breast Cancer Mammogram Dataset (KAU-BCMD)},
#   author={Alsolami, Abdulrahman S and Shalash, Walid and Alsaggaf, Walid and Ashoor, Sara and Refaat, Hesham and Elmogy, Mohammed},
#   journal={Data},
#   volume={6},
#   number={11},
#   pages={111},
#   year={2021},
#   publisher={MDPI}
# }
# ```
#     """

#     # Path to the README.md file in your cloned repository
#     readme_path = "/content/mmibc/README.md"
#     add_markdown_to_readme(readme_path, markdown_content)


In [None]:
# !python /content/MMIBC/src/data_handling/huggingface_upload.py

# Data Preparation Procedure

In [1]:
import os
import pandas as pd
import logging
from pathlib import Path

# Set up basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration - Update these paths and details ---
# Base directory where your organized data is located
base_organized_dir = './mmibc' # Update this path

# Directories for preprocessed images within the organized structure
# Based on your provided structure: ./mmibc/busi/images/split/class/
busi_images_base_dir = os.path.join(base_organized_dir, 'busi', 'images')
# Based on your provided structure: ./mmibc/vindr_mammo/split/class/
vindr_mammo_base_dir = os.path.join(base_organized_dir, 'vindr_mammo')


# Output path for the combined unimodal metadata CSV
combined_unimodal_metadata_output_path = os.path.join(base_organized_dir, 'combined_unimodal_metadata.csv')

# Define the splits you have in your organized data
splits = ['train', 'validation', 'test']
# Define the classes you have in your organized data
classes = ['benign', 'malignant'] # Assuming binary classification

# --- Generate Metadata for Unimodal Datasets ---
unimodal_metadata_list = []

logger.info("Generating unimodal metadata from organized BUSI and VinDr-Mammo directories...")

# Process BUSI images
modality = 'ultrasound' # Or 'BUSI'
for split in splits:
    for class_name in classes:
        class_dir = os.path.join(busi_images_base_dir, split, class_name)
        if not os.path.exists(class_dir):
            logger.warning(f"Directory not found: {class_dir}. Skipping.")
            continue

        for filename in os.listdir(class_dir):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
                image_path = os.path.join(class_dir, filename)
                label = 1 if class_name == 'malignant' else 0 # Assuming 'malignant' is class 1

                unimodal_metadata_list.append({
                    'image_path': image_path,
                    'label': label,
                    'split': split,
                    'modality': modality,
                    'class_name': class_name,
                    'filename': filename
                    # Add patient_id if you can extract it from the filename or structure
                    # 'patient_id': extract_patient_id(filename) # You need to implement this
                })
logger.info(f"Collected {len(unimodal_metadata_list)} entries from BUSI.")


# Process VinDr-Mammo images
modality = 'mammography' # Or 'VinDr-Mammo'
for split in splits:
    for class_name in classes:
        class_dir = os.path.join(vindr_mammo_base_dir, split, class_name)
        if not os.path.exists(class_dir):
            logger.warning(f"Directory not found: {class_dir}. Skipping.")
            continue

        for filename in os.listdir(class_dir):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
                image_path = os.path.join(class_dir, filename)
                label = 1 if class_name == 'malignant' else 0 # Assuming 'malignant' is class 1

                unimodal_metadata_list.append({
                    'image_path': image_path,
                    'label': label,
                    'split': split,
                    'modality': modality,
                    'class_name': class_name,
                    'filename': filename
                    # Add patient_id if you can extract it from the filename or structure
                    # 'patient_id': extract_patient_id(filename) # You need to implement this
                })
logger.info(f"Collected {len(unimodal_metadata_list) - len([e for e in unimodal_metadata_list if e['modality'] == 'ultrasound'])} entries from VinDr-Mammo.")


# Create DataFrame
unimodal_metadata_df = pd.DataFrame(unimodal_metadata_list)

# --- Save the Combined Unimodal Metadata CSV ---
Path(os.path.dirname(combined_unimodal_metadata_output_path)).mkdir(parents=True, exist_ok=True) # Ensure output directory exists
unimodal_metadata_df.to_csv(combined_unimodal_metadata_output_path, index=False)

logger.info(f"Combined unimodal metadata CSV generated successfully at: {combined_unimodal_metadata_output_path}")
logger.info(f"Total unimodal samples found: {len(unimodal_metadata_df)}")

# --- Note on Multimodal Training ---
logger.warning("\n--- IMPORTANT NOTE FOR MULTIMODAL TRAINING ---")
logger.warning("This generated metadata lists UNIMODAL images (either mammography or ultrasound).")
logger.warning("The MultiModalBreastCancerDataset requires PAIRED mammogram and ultrasound images for the same patient/case in each row.")
logger.warning("To train a multimodal model, you need a dataset (like KAU-BCMD) that provides paired images, or you need to create pairs from unimodal datasets if possible (which is generally not feasible without specific pairing information).")
logger.warning("You will need a different metadata structure and potentially a different Dataset class if you intend to train unimodal models or use this data for unimodal pre-training.")



In [2]:
df = pd.read_csv("/kaggle/working/mmibc/combined_unimodal_metadata.csv")

In [3]:
df.head()

Unnamed: 0,image_path,label,split,modality,class_name,filename
0,./mmibc/busi/images/train/benign/benign (246).png,0,train,ultrasound,benign,benign (246).png
1,./mmibc/busi/images/train/benign/benign (127).png,0,train,ultrasound,benign,benign (127).png
2,./mmibc/busi/images/train/benign/benign (135).png,0,train,ultrasound,benign,benign (135).png
3,./mmibc/busi/images/train/benign/benign (313).png,0,train,ultrasound,benign,benign (313).png
4,./mmibc/busi/images/train/benign/benign (337).png,0,train,ultrasound,benign,benign (337).png


In [4]:
# Import necessary libraries
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import cv2
import logging
import random
import numpy as np # Needed for DataAugmentation and potential numpy operations
from PIL import Image

# Set up basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Assume DataAugmentation Class is Available ---
# Copy the DataAugmentation class from model/dino_nn.py into a cell above this,
# or ensure model/dino_nn.py is in your Python path and import it.
# from model.dino_nn import DataAugmentation

class DataAugmentation:
    """
    Data augmentation for DINOv2-like training.
    Creates multiple crops of different sizes.
    Can be adapted for modality-specific transforms if needed.
    """
    def __init__(self, global_crops_scale=(0.5, 1.0), local_crops_scale=(0.2, 0.5),
                 local_crops_number=8, global_size=224, local_size=96, in_chans=1):
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_size = global_size
        self.local_size = local_size
        self.in_chans = in_chans

        if self.in_chans == 1:
             mean = [0.485]
             std = [0.229]
        else:
             mean = [0.485, 0.456, 0.406]
             std = [0.229, 0.224, 0.225]

        # Define base transforms that are common across modalities
        # Modality-specific adjustments would happen within the __call__ method
        self.base_global_transform = transforms.Compose([
            transforms.RandomResizedCrop(global_size, scale=global_crops_scale,
                                         interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1), # Optional, potentially less suitable for medical
        ])

        self.base_local_transform = transforms.Compose([
            transforms.RandomResizedCrop(local_size, scale=local_crops_scale,
                                         interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
             # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1), # Optional, potentially less suitable for medical
        ])

        self.to_tensor_normalize = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=mean, std=std)
        ])


    def __call__(self, image: Image.Image, modality: str):
        """
        Apply different transformations to create multiple views (crops) of the same image,
        potentially with modality-specific adjustments.

        Args:
            image (PIL.Image): The input image (expected as PIL Image).
            modality (str): The modality of the image ('mammography' or 'ultrasound').

        Returns:
            list: A list of transformed image tensors (crops).
        """
        crops = []

        # --- Apply Modality-Specific Adjustments (Example) ---
        # You can add modality-specific augmentations here before the base transforms
        # if needed. For example:
        if modality == 'mammography':
            # Apply specific augmentations for mammography
            image = transforms.functional.adjust_contrast(image, contrast_factor=random.uniform(0.8, 1.2))
        elif modality == 'ultrasound':
            # Apply specific augmentations for ultrasound
            image = transforms.functional.adjust_brightness(image, brightness_factor=random.uniform(0.8, 1.2))
        # --- End Modality-Specific Adjustments ---


        # Apply global transforms (2 crops)
        for _ in range(2):
            crop = self.base_global_transform(image)
            crops.append(self.to_tensor_normalize(crop)) # Apply ToTensor and Normalize after augmentation

        # Apply local transforms (local_crops_number crops)
        for _ in range(self.local_crops_number):
            crop = self.base_local_transform(image)
            crops.append(self.to_tensor_normalize(crop)) # Apply ToTensor and Normalize after augmentation

        return crops




In [5]:
class UnimodalBCDataset(Dataset):
    """
    Custom Dataset for loading unimodal medical images for pre-training.
    Loads data based on a metadata CSV file and applies multi-crop augmentation.
    Includes modality information.
    """
    def __init__(self, metadata_path: str, transform: DataAugmentation):
        """
        Args:
            metadata_path: Path to the CSV file containing unimodal image metadata.
                           Expected columns: 'image_path', 'modality'
            transform: The DataAugmentation transform for multi-cropping.
        """
        logger.info(f"Loading unimodal dataset from metadata: {metadata_path}")
        if not os.path.exists(metadata_path):
            logger.error(f"Metadata file not found at {metadata_path}")
            raise FileNotFoundError(f"Metadata file not found at {metadata_path}")

        self.metadata_df = pd.read_csv(metadata_path)
        self.transform = transform

        # Validate required columns
        required_cols = ['image_path', 'modality']
        if not all(col in self.metadata_df.columns for col in required_cols):
            logger.error(f"Metadata CSV must contain columns: {required_cols}")
            raise ValueError(f"Metadata CSV must contain columns: {required_cols}")

        # Filter out rows with missing file paths or modality info
        initial_samples = len(self.metadata_df)
        self.metadata_df.dropna(subset=required_cols, inplace=True)
        if len(self.metadata_df) < initial_samples:
            logger.warning(f"Removed {initial_samples - len(self.metadata_df)} samples due to missing data in metadata.")


        logger.info(f"Unimodal dataset loaded with {len(self.metadata_df)} samples.")

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample_info = self.metadata_df.iloc[idx]
        img_path = sample_info['image_path']
        modality = sample_info['modality'] # Get modality information
        # Labels are not used in DINO pre-training, but Dataset requires returning something
        dummy_label = 0

        try:
            # Load image (assuming grayscale based on your model config)
            # Use cv2.IMREAD_GRAYSCALE to ensure 1 channel
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

            if img is None:
                 logger.error(f"Could not load image: {img_path}. Returning random sample.")
                 return self.__getitem__(random.randint(0, len(self) - 1)) # Return random sample on error

            # Convert numpy array to PIL Image for torchvision transforms
            img_pil = Image.fromarray(img, mode='L') # 'L' for grayscale


        except Exception as e:
            logger.error(f"Error loading image for index {idx} ({img_path}): {e}. Returning random sample.")
            return self.__getitem__(random.randint(0, len(self) - 1))


        # Apply the DataAugmentation transform (multi-cropping)
        # Pass the modality information to the transform
        crops = self.transform(img_pil, modality)

        # Return the list of crops and a dummy label
        return crops, torch.tensor(dummy_label)


In [6]:
unimodal_metadata_csv_path = './mmibc/combined_unimodal_metadata.csv' # Update this path

# Define parameters for Data Augmentation (should match your model/pre-training config)
global_size = 224 # Base size of global crops (should match ViT input size)
local_size = 96   # Size of local crops
local_crops_number = 8 # Number of local crops
in_chans = 1 # Grayscale


In [7]:
# Create the DataAugmentation transform instance
dino_transform = DataAugmentation(
    global_size=global_size,
    local_size=local_size,
    local_crops_number=local_crops_number,
    in_chans=in_chans
)

In [8]:
# Create the unimodal dataset instance
unimodal_dataset = UnimodalBCDataset(
    metadata_path=unimodal_metadata_csv_path,
    transform=dino_transform
)

In [9]:
# Define DataLoader parameters
batch_size = 64 # Batch size for pre-training (can be larger)
num_workers = 8 # Adjust based on your system's capabilities

# Create the DataLoader instance
unimodal_dataloader = DataLoader(
    unimodal_dataset,
    batch_size=batch_size,
    shuffle=True, # Shuffle data for training
    num_workers=num_workers,
    pin_memory=True, # Pin memory for faster GPU transfer
    drop_last=True # Drop the last incomplete batch (common in pre-training)
)

# --- You can now iterate through this dataloader for unimodal pre-training ---
print(f"Number of unimodal samples in dataset: {len(unimodal_dataset)}")
print(f"Number of batches per epoch for pre-training: {len(unimodal_dataloader)}")

# Example of getting one batch
try:
    crops_batch, dummy_labels = next(iter(unimodal_dataloader))
    print(f"\nExample batch:")
    print(f"Number of crops per image: {len(crops_batch)}")
    print(f"Shape of first global crop batch: {crops_batch[0].shape}")
    print(f"Shape of second global crop batch: {crops_batch[1].shape}")
    print(f"Shape of first local crop batch: {crops_batch[2].shape}")
    print(f"Dummy labels shape: {dummy_labels.shape}")

except StopIteration:
    print("\nDataLoader is empty.")

Number of unimodal samples in dataset: 20906
Number of batches per epoch for pre-training: 326





Example batch:
Number of crops per image: 10
Shape of first global crop batch: torch.Size([64, 1, 224, 224])
Shape of second global crop batch: torch.Size([64, 1, 224, 224])
Shape of first local crop batch: torch.Size([64, 1, 96, 96])
Dummy labels shape: torch.Size([64])


In [10]:
# prompt: show tensorboard within colab.

%load_ext tensorboard
%tensorboard --logdir /content/dino_unimodal_pretraining_output


<IPython.core.display.Javascript object>

In [11]:
# --- Necessary Imports ---
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch import optim
import torch.distributed as dist
import numpy as np
import argparse
import os
import shutil # For save_checkpoint
import time # For timing
from torch.utils.tensorboard import SummaryWriter # For TensorBoard logging
import datetime # For MetricLogger eta
import pandas as pd # For metadata handling
import cv2 # For image loading
import logging # For logging
import random # For random sample loading on error
from PIL import Image # Import PIL for image handling
from pathlib import Path # For directory creation

# Set up logging (adjust level as needed)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# --- Utility Functions (Copied from utils.py) ---

class MetricLogger:
    """
    Utility class for logging metrics during training
    """
    def __init__(self, delimiter="\t"):
        self.meters = {}
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if k not in self.meters:
                self.meters[k] = SmoothedValue()
            if isinstance(v, torch.Tensor):
                v = v.item()
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        # Default behavior
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {meter}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        # For multi-GPU training - Placeholder
        pass

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if header is not None:
            print(header)

        start_time = time.time()
        end = time.time()
        for obj in iterable:
            data_time = time.time() - end
            yield obj
            batch_time = time.time() - end
            end = time.time()
            if i % print_freq == 0:
                eta_seconds = (len(iterable) - i) * (batch_time + data_time) / 2
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    f"{header} [{i}/{len(iterable)}]\t"
                    f"eta: {eta_string}\t"
                    f"time: {batch_time:.4f}\t"
                    f"data: {data_time:.4f}\t"
                    f"{self}"
                )
            i += 1
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f'{header} Total time: {total_time_str} ({total_time / len(iterable):.4f}s / it)')


2025-05-19 10:58:44.797079: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747652324.819559     531 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747652324.826486     531 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [12]:
class SmoothedValue:
    """
    Track a series of values and provide access to smoothed values over a window
    """
    def __init__(self, window_size=20):
        self.window_size = window_size
        self.reset()

    def reset(self):
        self.values = []
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.values.append(value)
        self.total += value
        self.count += 1
        if len(self.values) > self.window_size:
            self.total -= self.values.pop(0)

    @property
    def median(self):
        return np.median(self.values)

    @property
    def avg(self):
        return np.mean(self.values)

    @property
    def global_avg(self):
        return self.total / self.count

    def __str__(self):
        return f"{self.global_avg:.4f} ({self.avg:.4f})"

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """
    Cosine scheduler with warmup for updating parameters like teacher momentum
    """
    warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_epochs * niter_per_ep)

    iters = np.arange(epochs * niter_per_ep - warmup_epochs * niter_per_ep)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule

def save_checkpoint(state, is_best, filename='checkpoint.pth', save_dir='.'):
    """
    Saves model checkpoint.
    """
    filepath = os.path.join(save_dir, filename)
    torch.save(state, filepath)
    logger.info(f"Checkpoint saved to {filepath}")
    if is_best:
        best_filepath = os.path.join(save_dir, 'model_best.pth')
        shutil.copyfile(filepath, best_filepath)
        logger.info(f"Best model saved to {best_filepath}")

In [13]:
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """
    def __init__(self, drop_prob: float = 0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor
        return output


class PatchEmbedding(nn.Module):
    """
    2D Image to Patch Embedding with positional encoding
    Handles variable input image sizes by interpolating positional embeddings.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Calculate the number of patches for the *expected* input size
        self.num_patches = (img_size // patch_size) ** 2

        # Linear projection using Conv2d
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Learnable CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learnable positional embedding for the *expected* number of patches + CLS token
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

        # Initialize parameters
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input image tensor (B, C, H, W)

        Returns:
            torch.Tensor: Patch embeddings with CLS token and positional encoding (B, actual_num_patches + 1, embed_dim)
        """
        B, C, H, W = x.shape

        # Project patches
        x = self.proj(x)  # (B, embed_dim, H_out, W_out)
        _, _, H_out, W_out = x.shape
        actual_num_patches = H_out * W_out

        # Flatten spatial dimensions and transpose
        x = x.flatten(2).transpose(1, 2) # (B, actual_num_patches, embed_dim)

        # Interpolate positional embedding if the number of patches doesn't match the expected size
        # This handles cases where input H or W are not perfectly divisible by patch_size,
        # or when using crops of different sizes (global vs local).
        if actual_num_patches != self.num_patches:
            # Positional embedding excluding the CLS token
            pos_embed_patches = self.pos_embed[:, 1:] # (1, self.num_patches, embed_dim)
            # Reshape to spatial dimensions for interpolation
            pos_embed_spatial = pos_embed_patches.reshape(1, int(math.sqrt(self.num_patches)), int(math.sqrt(self.num_patches)), -1).permute(0, 3, 1, 2) # (1, embed_dim, sqrt(num_patches), sqrt(num_patches))
            # Interpolate to match the actual spatial dimensions after convolution
            pos_embed_interpolated = F.interpolate(pos_embed_spatial, size=(H_out, W_out), mode='bicubic', align_corners=False)
            # Reshape back to sequence format
            pos_embed_interpolated = pos_embed_interpolated.flatten(2).transpose(1, 2) # (1, actual_num_patches, embed_dim)

            # Add the interpolated positional embedding to the patches
            x = x + pos_embed_interpolated

            # Add CLS token and its corresponding positional embedding
            cls_token = self.cls_token.expand(B, -1, -1)
            # The positional embedding for the CLS token is the first element of self.pos_embed
            pos_embed_cls = self.pos_embed[:, :1] # (1, 1, embed_dim)
            cls_token = cls_token + pos_embed_cls # Add positional embedding to CLS token
            x = torch.cat((cls_token, x), dim=1) # Concatenate CLS token with patches

        else:
            # If the number of patches matches the expected size, use the original positional embedding
            # Add CLS token
            cls_token = self.cls_token.expand(B, -1, -1)
            # Concatenate CLS token with patch embeddings
            x = torch.cat((cls_token, x), dim=1) # (B, num_patches + 1, embed_dim)

            # Add positional embedding (including the CLS token's position)
            x = x + self.pos_embed


        return x


class MultiHeadAttention(nn.Module):
    """
    Multi-head Self-Attention module
    """
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5 # Scaling factor for attention scores

        # Linear projection for Query, Key, Value
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop) # Dropout for attention scores
        self.proj = nn.Linear(dim, dim) # Output projection
        self.proj_drop = nn.Dropout(proj_drop) # Dropout for output

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor (B, N, C) where N is sequence length, C is dimension

        Returns:
            tuple: Output tensor (B, N, C) and attention weights (B, num_heads, N, N)
        """
        B, N, C = x.shape

        # Project and reshape Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # Separate Q, K, V
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # Compute attention scores
        # (B, num_heads, N, head_dim) @ (B, num_heads, head_dim, N) -> (B, num_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Apply softmax to get attention probabilities
        attn = attn.softmax(dim=-1)
        # Apply attention dropout
        attn = self.attn_drop(attn)

        # Apply attention to values
        # (B, num_heads, N, N) @ (B, num_heads, N, head_dim) -> (B, num_heads, N, head_dim)
        x = (attn @ v).transpose(1, 2) # Transpose heads and sequence length
        # Reshape back to original dimension
        x = x.reshape(B, N, C)

        # Apply output projection and dropout
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, attn

class MLP(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer, etc.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # Two linear layers with activation and dropout
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor (B, N, in_features)

        Returns:
            torch.Tensor: Output tensor (B, N, out_features)
        """
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class LayerScale(nn.Module):
    """
    Layer scale from CaiT and for DINOv2
    """
    def __init__(self, dim, init_values=1e-5):
        super().__init__()
        # Learnable parameter gamma initialized to small values
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor (B, N, dim)

        Returns:
            torch.Tensor: Scaled output tensor
        """
        return self.gamma * x

class ViTBlock(nn.Module):
    """
    Vision Transformer Block with LayerScale and DropPath
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, layer_scale_init_value=1e-5):
        super().__init__()
        # First normalization layer
        self.norm1 = norm_layer(dim)
        # Multi-Head Self-Attention module
        self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                    attn_drop=attn_drop, proj_drop=drop)
        # LayerScale after attention (optional)
        self.ls1 = LayerScale(dim, init_values=layer_scale_init_value) if layer_scale_init_value > 0 else nn.Identity()

        # Second normalization layer
        self.norm2 = norm_layer(dim)
        # MLP (Feed-Forward) module
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=layer_scale_init_value) if layer_scale_init_value > 0 else nn.Identity()

        # Stochastic depth (DropPath)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor (B, N, dim)

        Returns:
            tuple: Output tensor (B, N, dim) and attention weights (B, num_heads, N, N) from the attention block
        """
        # Attention block with residual connection, normalization, LayerScale, and DropPath
        res = x
        x_norm = self.norm1(x)
        x_attn, attn = self.attn(x_norm)
        x_attn = self.ls1(x_attn)
        x = res + self.drop_path(x_attn)

        # MLP block with residual connection, normalization, LayerScale, and DropPath
        res = x
        x_norm = self.norm2(x)
        x_mlp = self.mlp(x_norm)
        x_mlp = self.ls2(x_mlp)
        x = res + self.drop_path(x_mlp)

        return x, attn

class ViTEncoder(nn.Module):
    """
    Vision Transformer Encoder (Backbone)
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=1, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0.1, norm_layer=nn.LayerNorm, layer_scale_init_value=1e-5):
        super().__init__()

        # Patch embedding layer
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches

        # Transformer blocks (encoder layers)
        # Stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            ViTBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, layer_scale_init_value=layer_scale_init_value
            )
            for i in range(depth)
        ])

        # Final normalization layer
        self.norm = norm_layer(embed_dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input image tensor (B, C, H, W)

        Returns:
            tuple: CLS token representation (B, embed_dim), all tokens (B, num_patches + 1, embed_dim),
                   list of attention maps from each block
        """
        # Apply patch embedding
        x = self.patch_embed(x)

        # Store attention maps from each block
        attn_maps = []

        # Apply transformer blocks sequentially
        for block in self.blocks:
            x, attn = block(x)
            attn_maps.append(attn)

        # Apply final normalization
        x = self.norm(x)

        # Extract the CLS token (the first token in the sequence)
        cls_token = x[:, 0]

        # Return CLS token, all tokens, and attention maps
        return cls_token, x, attn_maps


In [14]:
class DINOLoss(nn.Module):
    """
    DINO loss from the paper "Emerging Properties in Self-Supervised Vision Transformers"
    with improvements from DINOv2
    """
    def __init__(self, out_dim, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        # Register buffer for the center, initialized to zeros
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_output, teacher_output, current_teacher_temp):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.

        Args:
            student_output (list): List of tensors, student outputs for each crop.
            teacher_output (list): List of tensors, teacher outputs for the global crops.
            current_teacher_temp (float): The current temperature for the teacher.

        Returns:
            torch.Tensor: The computed DINO loss.
        """
        # Student outputs for all crops
        student_out = [s / self.student_temp for s in student_output]

        # Teacher outputs for global crops, with current temperature and sharpening
        teacher_out = [t / current_teacher_temp for t in teacher_output]
        teacher_out = [F.softmax(t, dim=-1).detach() for t in teacher_out]

        # Center the teacher outputs
        teacher_out = [t - self.center for t in teacher_out]

        # Compute loss between global crops teacher and all crops student
        total_loss = 0
        n_crops = len(student_output)
        n_global_crops = len(teacher_output) # Should be 2 for DINO

        for s_idx in range(n_crops):
            for t_idx in range(n_global_crops):
                # Loss is cross-entropy between student crop s_idx and teacher crop t_idx
                loss = -torch.sum(teacher_out[t_idx] * F.log_softmax(student_out[s_idx], dim=-1), dim=-1).mean()
                total_loss += loss

        # Average loss over all pairs of student and teacher global crops
        total_loss /= (n_crops * n_global_crops)

        # Update center for teacher output (using only global crops)
        with torch.no_grad():
            self.update_center(torch.cat(teacher_output, dim=0)) # Update center using concatenated global crops

        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        # Calculate batch center
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)

        # All-reduce across processes for distributed training
        if dist.is_initialized():
             dist.all_reduce(batch_center)
             # Divide by total number of samples across all processes
             batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
        else:
             # If not in distributed mode, just divide by batch size
             batch_center = batch_center / len(teacher_output)


        # Update center using momentum
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


class DINOHead(nn.Module):
    """
    Projection head used for DINO/DINOv2
    """
    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True):
        super().__init__()

        # MLP projection layers
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, bottleneck_dim),
        )

        # Last FC layer mapping to output dimension
        self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)

        # Option to normalize the last layer weights
        self.norm_last_layer = norm_last_layer

        # Apply weight initialization
        self.apply(self._init_weights)

        # Normalize last layer weights if requested
        if norm_last_layer:
            nn.init.constant_(self.last_layer.weight, 0) # Initialize to zero for normalization

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor (B, in_dim) - typically the CLS token representation.

        Returns:
            torch.Tensor: Output tensor (B, out_dim) after projection and normalization.
        """
        # Pass through MLP
        x = self.mlp(x)
        # Normalize the output of the bottleneck layer
        x = F.normalize(x, dim=-1, p=2)

        # Apply last layer
        if self.norm_last_layer:
            # Normalize weights before applying linear transformation
            w = self.last_layer.weight.clone()
            w = F.normalize(w, dim=1, p=2)
            x = F.linear(x, w)
        else:
            x = self.last_layer(x)

        return x

class MultiCropWrapper(nn.Module):
    """
    Wrapper for processing multiple crops through the backbone and head
    """
    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        # Backbone network (e.g., ViTEncoder)
        self.backbone = backbone
        # Projection head (e.g., DINOHead)
        self.head = head

    def forward(self, x: list[torch.Tensor]):
        """
        Args:
            x (list): A list of image tensors, where each tensor is a different crop.

        Returns:
            list: A list of tensors, where each tensor is the output of the head for a crop.
        """
        # Process each crop through the backbone and head
        outputs = []
        for crop in x:
            # Get the CLS token from the backbone output
            # Assuming backbone returns (cls_token, all_tokens, attention_maps)
            cls_token, _, _ = self.backbone(crop)
            # Pass the CLS token through the head
            outputs.append(self.head(cls_token))

        return outputs

In [15]:
def train_one_epoch(student: nn.Module, teacher: nn.Module, train_loader: DataLoader, dino_loss: DINOLoss,
                   optimizer: optim.Optimizer, epoch: int, total_epochs: int, writer: SummaryWriter,
                   warmup_teacher_temp_epochs: int = 5, teacher_temp: float = 0.04,
                   momentum_schedule: np.ndarray = None, clip_grad: float = 0.):
    """
    One epoch of DINOv2-like training.
    """
    student.train() # Set student to training mode
    teacher.eval()  # Teacher is in evaluation mode (no gradient updates)

    metric_logger = MetricLogger()
    header = f'Epoch: [{epoch}/{total_epochs}]'

    # Adjust teacher temperature schedule
    # This schedule is applied per epoch
    teacher_temp_schedule = np.concatenate((
        np.linspace(0.07, teacher_temp, warmup_teacher_temp_epochs),
        np.ones(total_epochs - warmup_teacher_temp_epochs) * teacher_temp
    ))
    curr_teacher_temp = teacher_temp_schedule[epoch]


    for it, (images, _) in enumerate(metric_logger.log_every(train_loader, 10, header)):
        # Update weight decay and learning rate (if using a scheduler)
        # This typically happens per iteration in DINO
        it_global = len(train_loader) * epoch + it  # global iteration

        # Example of updating LR (requires an LR scheduler)
        # for i, param_group in enumerate(optimizer.param_groups):
        #     param_group["lr"] = lr_schedule[it_global] # Assuming lr_schedule is defined and accessible


        # Move images (list of crops) to gpu
        # Check if args.gpu is available and not -1 (for CPU)
        device = next(student.parameters()).device # Get device from model parameters
        images = [im.to(device, non_blocking=True) for im in images]


        # Teacher and student forward passes
        # Teacher only processes the two global views
        teacher_output = teacher(images[:2])
        # Student processes all views (global and local)
        student_output = student(images)

        # Loss computation
        loss = dino_loss(student_output, teacher_output, curr_teacher_temp)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients (optional, but common in DINO)
        if clip_grad > 0:
             torch.nn.utils.clip_grad_norm_(student.parameters(), clip_grad)

        optimizer.step()

        # EMA update of the teacher
        # Momentum is typically scheduled per iteration
        with torch.no_grad():
            m = momentum_schedule[it_global]  # momentum parameter
            # Access the underlying module if using DDP
            student_m = student.module if hasattr(student, 'module') else student
            teacher_m = teacher.module if hasattr(teacher, 'module') else teacher
            for param_q, param_k in zip(student_m.parameters(), teacher_m.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        # Log metrics to MetricLogger
        metric_logger.update(loss=loss.item())
        metric_logger.update(teacher_temp=curr_teacher_temp)
        metric_logger.update(momentum=m)
        # Log learning rate if using a scheduler
        # metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # Log metrics to TensorBoard per step (optional, can be noisy)
        # if writer and (it + 1) % 10 == 0: # Log every 10 batches
        #     step = epoch * len(train_loader) + it
        #     writer.add_scalar('DINO/Loss/Train_Step', loss.item(), step)
        #     writer.add_scalar('DINO/Teacher_Temp/Train_Step', curr_teacher_temp, step)
        #     writer.add_scalar('DINO/Momentum/Train_Step', m, step)
            # if using LR scheduler: writer.add_scalar('DINO/LR/Train_Step', optimizer.param_groups[0]["lr"], step)


    # Log epoch metrics to TensorBoard
    if writer:
        writer.add_scalar('DINO/Loss/Train_Epoch', metric_logger.loss.global_avg, epoch)
        writer.add_scalar('DINO/Teacher_Temp/Train_Epoch', curr_teacher_temp, epoch)
        writer.add_scalar('DINO/Momentum/Train_Epoch', momentum_schedule[it_global], epoch) # Log final momentum of the epoch
        # if using LR scheduler: writer.add_scalar('DINO/LR/Train_Epoch', optimizer.param_groups[0]["lr"], epoch)


    # Return the metric values averaged over the epoch
    metric_logger.synchronize_between_processes()
    # print("Averaged stats:", metric_logger) # MetricLogger prints at the end of log_every
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}



In [16]:
metadata_csv_path = './mmibc/combined_unimodal_metadata.csv' # UPDATE THIS PATH
output_dir = './dino_unimodal_pretraining_output' # Output directory for logs and checkpoints
epochs = 20
batch_size = 16 # Batch size per GPU (adjust based on your hardware)
global_crops_scale = (0.5, 1.0)
local_crops_scale = (0.2, 0.5)
local_crops_number = 8
global_size = 224 # Image size for global crops (should match ViT input)
local_size = 96 # Image size for local crops
in_chans = 1 # Number of input channels (1 for grayscale)
embed_dim = 768 # ViT embedding dimension
depth = 12 # Number of ViT blocks
num_heads = 12 # Number of attention heads
mlp_ratio = 4.
qkv_bias = True
drop_rate = 0.
attn_drop_rate = 0.
drop_path_rate = 0.1
layer_scale_init_value = 1e-5
teacher_temp = 0.04
warmup_teacher_temp_epochs = 5
student_temp = 0.1
center_momentum = 0.9
learning_rate = 5e-4
weight_decay = 0.04
warmup_epochs = 10
clip_grad = 0. # Gradient clipping value (0 for no clipping)
num_workers = 8 # Number of data loading workers (adjust based on your system)
seed = 42
save_freq = 10 # Checkpoint saving frequency (epochs)

In [17]:
# --- Distributed Training Setup (Simplified for Notebook) ---
# This section allows the code to run in both single GPU/CPU and distributed environments.
# In a notebook, you typically run on a single GPU/CPU.
# If running with torch.distributed.launch, the environment variables will be set.

args = argparse.Namespace( # Create a namespace object to mimic argparse args
    metadata_path=metadata_csv_path,
    output_dir=output_dir,
    epochs=epochs,
    batch_size=batch_size,
    global_crops_scale=global_crops_scale,
    local_crops_scale=local_crops_scale,
    local_crops_number=local_crops_number,
    global_size=global_size,
    local_size=local_size,
    in_chans=in_chans,
    embed_dim=embed_dim,
    depth=depth,
    num_heads=num_heads,
    mlp_ratio=mlp_ratio,
    qkv_bias=qkv_bias,
    drop_rate=drop_rate,
    attn_drop_rate=attn_drop_rate,
    drop_path_rate=drop_path_rate,
    layer_scale_init_value=layer_scale_init_value,
    teacher_temp=teacher_temp,
    warmup_teacher_temp_epochs=warmup_teacher_temp_epochs,
    student_temp=student_temp,
    center_momentum=center_momentum,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    warmup_epochs=warmup_epochs,
    clip_grad=clip_grad,
    num_workers=num_workers,
    seed=seed,
    save_freq=save_freq,
    local_rank=0 # Default for single process
)

# Check if distributed environment variables are set
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
    args.rank = int(os.environ["RANK"])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.gpu = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(args.gpu)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=args.world_size, rank=args.rank)
    logger.info(f"| distributed init: rank {args.rank}, world {args.world_size}, gpu {args.gpu}")
    torch.distributed.barrier() # Wait for all processes to synchronize
else:
    logger.info("Not using distributed training. Running on a single GPU or CPU.")
    args.rank = 0
    args.world_size = 1
    args.gpu = 0 # Assuming GPU 0 if available
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
    else:
        args.gpu = -1 # Indicate CPU usage


# Ensure output directory exists (only on rank 0)
if args.rank == 0:
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Setup TensorBoard writer (only on rank 0)
    writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 'runs'))
    logger.info(f"TensorBoard logs will be saved to: {os.path.join(args.output_dir, 'runs')}")
else:
    writer = None # Only rank 0 writes to TensorBoard



In [18]:
# --- Model Initialization ---
# Create student and teacher backbones (ViTEncoders)
# These encoders will be trained to handle unimodal data
student_backbone = ViTEncoder(
    img_size=args.global_size, # Use global crop size for backbone input size
    in_chans=args.in_chans,
    embed_dim=args.embed_dim,
    depth=args.depth,
    num_heads=args.num_heads,
    mlp_ratio=args.mlp_ratio,
    qkv_bias=args.qkv_bias,
    drop_rate=args.drop_rate,
    attn_drop_rate=args.attn_drop_rate,
    drop_path_rate=args.drop_path_rate,
    layer_scale_init_value=args.layer_scale_init_value
)
# Teacher is a copy of the student initially
teacher_backbone = copy.deepcopy(student_backbone)

# Create DINO heads for student and teacher
# Output dimension (out_dim) is typically large (e.g., 65536)
dino_out_dim = 65536 # Example output dimension
student_head = DINOHead(in_dim=args.embed_dim, out_dim=dino_out_dim)
teacher_head = DINOHead(in_dim=args.embed_dim, out_dim=dino_out_dim)

# Wrap backbones and heads in MultiCropWrapper
student = MultiCropWrapper(student_backbone, student_head)
teacher = MultiCropWrapper(teacher_backbone, teacher_head)

# Move models to GPU
if args.gpu != -1:
    student = student.cuda(args.gpu)
    teacher = teacher.cuda(args.gpu)
else:
    # Move to CPU if no GPU available
    student = student.cpu()
    teacher = teacher.cpu()


# Freeze teacher parameters
for p in teacher.parameters():
    p.requires_grad = False

# Wrap student model with DistributedDataParallel if using distributed training
if dist.is_initialized():
    student = torch.nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    # Teacher does not need DDP as it's updated via EMA


In [19]:
# Sampler for distributed training
# Shuffle is handled by the sampler in distributed mode
sampler = torch.utils.data.distributed.DistributedSampler(unimodal_dataset, shuffle=True) if dist.is_initialized() else None

# DataLoader
data_loader = DataLoader(
    unimodal_dataset,
    sampler=sampler, # Use sampler in distributed mode
    batch_size=args.batch_size,
    shuffle=(sampler is None), # Shuffle only if not using distributed sampler
    num_workers=args.num_workers,
    pin_memory=(args.gpu != -1), # Pin memory only if using GPU
    drop_last=True, # Drop the last incomplete batch
)
logger.info(f"DataLoader created with batch size {args.batch_size} and {len(data_loader)} batches per epoch.")




In [20]:
# --- Loss Function and Optimizer ---
dino_loss = DINOLoss(
    out_dim=dino_out_dim,
    teacher_temp=args.teacher_temp,
    student_temp=args.student_temp,
    center_momentum=args.center_momentum
)
if args.gpu != -1:
    dino_loss = dino_loss.cuda(args.gpu) # Move loss to GPU
else:
    dino_loss = dino_loss.cpu() # Move loss to CPU


# Optimizer
# Parameters to optimize: student backbone and student head
optimizer = optim.AdamW(student.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

# --- Schedulers ---
# Learning rate scheduler
lr_schedule = cosine_scheduler(
    base_value=args.learning_rate * args.world_size * args.batch_size / 256., # Scale LR linearly with batch size (common practice)
    final_value=args.learning_rate * 1e-6, # Example: decay to 1e-6 of base
    epochs=args.epochs,
    niter_per_ep=len(data_loader),
    warmup_epochs=args.warmup_epochs,
)

# Momentum scheduler for teacher EMA update
momentum_schedule = cosine_scheduler(
    base_value=0.996, # Start momentum
    final_value=1.0,  # End momentum
    epochs=args.epochs,
    niter_per_ep=len(data_loader),
)


In [21]:
# prompt: #print used GPU Memory

# Check if CUDA is available
if torch.cuda.is_available():
    print("CUDA is available. Using GPU:")
    # Get the current GPU device
    device = torch.cuda.current_device()
    print(f"Device Name: {torch.cuda.get_device_name(device)}")
    print(f"Total Memory: {torch.cuda.get_device_properties(device).total_memory / 1024**3:.2f} GB")

    # Get used and free memory
    # This requires synchronizing
    torch.cuda.empty_cache() # Clear cache to get more accurate free memory
    allocated_memory = torch.cuda.memory_allocated(device)
    cached_memory = torch.cuda.memory_reserved(device)
    print(f"Allocated Memory: {allocated_memory / 1024**3:.2f} GB")
    print(f"Cached Memory:    {cached_memory / 1024**3:.2f} GB")

    # This part is more complex to get *only* memory used by your *current* process
    # without external tools or more specific CUDA APIs.
    # The allocated_memory gives memory currently allocated by PyTorch.

else:
    print("CUDA is not available. Running on CPU.")



CUDA is available. Using GPU:
Device Name: Tesla T4
Total Memory: 14.74 GB
Allocated Memory: 0.81 GB
Cached Memory:    0.88 GB


In [22]:
# prompt: clear allocated memory completely to 0GB

import gc
import torch

gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

print("Memory cleared.")
if torch.cuda.is_available():
  device = torch.cuda.current_device()
  allocated_memory = torch.cuda.memory_allocated(device)
  cached_memory = torch.cuda.memory_reserved(device)
  print(f"Allocated Memory: {allocated_memory / 1024**3:.2f} GB")
  print(f"Cached Memory:    {cached_memory / 1024**3:.2f} GB")


Memory cleared.
Allocated Memory: 0.81 GB
Cached Memory:    0.88 GB


In [None]:
# --- Training Loop ---
logger.info(f"Starting DINO unimodal pre-training for {args.epochs} epochs...")

for epoch in range(args.epochs):
    # Set sampler epoch for distributed training
    if dist.is_initialized():
        data_loader.sampler.set_epoch(epoch)

    # Train one epoch
    # Pass the current learning rate from the schedule to the optimizer
    for i, param_group in enumerate(optimizer.param_groups):
         param_group["lr"] = lr_schedule[epoch * len(data_loader) + i]


    train_stats = train_one_epoch(
        student=student,
        teacher=teacher,
        train_loader=data_loader,
        dino_loss=dino_loss,
        optimizer=optimizer,
        epoch=epoch,
        total_epochs=args.epochs,
        writer=writer,
        warmup_teacher_temp_epochs=args.warmup_teacher_temp_epochs,
        teacher_temp=args.teacher_temp, # Pass base teacher temp
        momentum_schedule=momentum_schedule,
        clip_grad=args.clip_grad
    )

    # Save checkpoint (only on rank 0)
    if args.rank == 0 and (epoch + 1) % args.save_freq == 0:
         # Save student backbone state dict
         # If using DDP, access the underlying module
         student_backbone_to_save = student.module.backbone if hasattr(student, 'module') else student.backbone
         save_checkpoint({
             'epoch': epoch + 1,
             'student_backbone_state_dict': student_backbone_to_save.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'args': args, # Save training arguments
             'train_stats': train_stats # Save epoch stats
         }, is_best=False, filename=f'dino_checkpoint_epoch_{epoch+1}.pth', save_dir=args.output_dir)


# Save final student backbone weights (only on rank 0)
if args.rank == 0:
    student_backbone_to_save = student.module.backbone if hasattr(student, 'module') else student.backbone
    torch.save(student_backbone_to_save.state_dict(), os.path.join(args.output_dir, 'dinov2_unimodal_backbone.pth'))
    logger.info(f"Final pretrained unimodal student backbone saved to {os.path.join(args.output_dir, 'dinov2_unimodal_backbone.pth')}")


# Close TensorBoard writer (only on rank 0)
if args.rank == 0 and writer:
    writer.close()

# Clean up distributed training (important when running in script mode, less critical in notebook unless explicitly using distributed)
# if dist.is_initialized():
#     dist.destroy_process_group()



Epoch: [0/20]
Epoch: [0/20] [0/1306]	eta: 1:33:59	time: 6.6577	data: 1.9790	loss: 11.0904 (11.0904)	teacher_temp: 0.0700 (0.0700)	momentum: 0.9960 (0.9960)
Epoch: [0/20] [10/1306]	eta: 0:24:46	time: 2.2941	data: 0.0003	loss: 11.0904 (11.0904)	teacher_temp: 0.0700 (0.0700)	momentum: 0.9960 (0.9960)
Epoch: [0/20] [20/1306]	eta: 0:23:29	time: 2.1917	data: 0.0002	loss: 10.5622 (11.0904)	teacher_temp: 0.0667 (0.0700)	momentum: 0.9486 (0.9960)
Epoch: [0/20] [30/1306]	eta: 0:22:49	time: 2.1467	data: 0.0002	loss: 7.1551 (11.0904)	teacher_temp: 0.0452 (0.0700)	momentum: 0.6426 (0.9960)
Epoch: [0/20] [40/1306]	eta: 0:22:52	time: 2.1680	data: 0.0002	loss: 5.4099 (11.0904)	teacher_temp: 0.0341 (0.0700)	momentum: 0.4859 (0.9960)
Epoch: [0/20] [50/1306]	eta: 0:22:59	time: 2.1958	data: 0.0002	loss: 4.3492 (11.0904)	teacher_temp: 0.0275 (0.0700)	momentum: 0.3906 (0.9960)
Epoch: [0/20] [60/1306]	eta: 0:22:40	time: 2.1843	data: 0.0002	loss: 3.6362 (11.0904)	teacher_temp: 0.0230 (0.0700)	momentum: 0.3266