In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
from torchvision import models

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Import Models
Load the VGG-16, ResNet-18 and MobileNetV2 models, which were initialised with ImageNet weights and then fine-tuned on VinDr-Mammo.

In [None]:
# Attach Google Drive at /content/drive (presumably where the trained .pth checkpoints are stored).
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
# Locations of the state-dict (.pth) files produced by the VinDr-Mammo training runs.
# These are placeholders: replace them with the real checkpoint paths.
vgg16_model_path, resnet18_model_path, mobilenetv2_model_path = (
    'path/to/vgg16_model.pth',
    'path/to/resnet18_model.pth',
    'path/to/mobilenetv2_model.pth',
)


In [None]:
# Build the architectures without any pretrained weights; the trained weights are
# loaded from the saved state dicts afterwards. Each network's final Linear layer is
# replaced with a binary (benign / malignant) head. in_features is read from the
# layer being replaced instead of hard-coding 4096 / 1280, so every head is sized
# the same way.
NUM_CLASSES = 2

# VGG-16: classifier[6] is the last Linear layer of the classifier head
vgg16_model = models.vgg16(weights=None)
vgg16_model.classifier[6] = nn.Linear(vgg16_model.classifier[6].in_features, NUM_CLASSES)

# ResNet-18: the head is the single `fc` Linear layer
resnet18_model = models.resnet18(weights=None)
resnet18_model.fc = nn.Linear(resnet18_model.fc.in_features, NUM_CLASSES)


def _build_mobilenetv2_binary():
    """Return an untrained MobileNetV2 whose classifier[1] Linear outputs NUM_CLASSES logits."""
    model = models.mobilenet_v2(weights=None)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
    return model


# MobileNetV2
mobilenetv2_model = _build_mobilenetv2_binary()

# MobileNetV2 "fix BI-RADS" variant: same architecture, separate instance
# (originally labelled "FixedWeights" — TODO confirm what distinguishes its training)
mobilenetv2_fix_birads_model = _build_mobilenetv2_binary()

In [None]:
# Load the trained weights from the state-dict files. map_location=device remaps the
# saved tensors onto the current device, so GPU-saved checkpoints also load on a
# CPU-only runtime. weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state_dict needs (safer for files obtained elsewhere).
vgg16_model.load_state_dict(torch.load(vgg16_model_path, map_location=device, weights_only=True))
resnet18_model.load_state_dict(torch.load(resnet18_model_path, map_location=device, weights_only=True))
mobilenetv2_model.load_state_dict(torch.load(mobilenetv2_model_path, map_location=device, weights_only=True))

# The fix-BI-RADS MobileNetV2 is the model evaluated below, so it needs its own
# checkpoint as well; without this step it would be evaluated with random weights.
mobilenetv2_fix_birads_model_path = 'path/to/mobilenetv2_fix_birads_model.pth'  # Replace with your checkpoint path
mobilenetv2_fix_birads_model.load_state_dict(
    torch.load(mobilenetv2_fix_birads_model_path, map_location=device, weights_only=True)
)

In [None]:
# Move every model, including the fix-BI-RADS MobileNetV2 that is evaluated below,
# onto `device`. Module.to() modifies the module in place, so the return value is
# discarded; this also stops the cell from echoing a full model repr as its output.
for _model in (vgg16_model, resnet18_model, mobilenetv2_model, mobilenetv2_fix_birads_model):
    _model.to(device)

In [None]:
# Report the device of each model's first parameter to confirm the move above
# (expect e.g. cuda:0 on a GPU runtime).
for _name, _model in [
    ("vgg16_model", vgg16_model),
    ("resnet18_model", resnet18_model),
    ("mobilenetv2_model", mobilenetv2_model),
    ("mobilenetv2_fix_birads_model", mobilenetv2_fix_birads_model),
]:
    print(f"{_name}: {next(_model.parameters()).device}")

## Create the Dataset and DataLoader for the test data
Before putting the models in evaluation mode, we need a Dataset and a DataLoader that feed the VinDr-Mammo test images to them.

In [None]:
# Install google cloud storage package if you haven't already
!pip install google-cloud-storage==2.0.0

In [None]:
# Upload the GCS service-account JSON key to the Colab runtime.
# files.upload() returns a dict mapping each uploaded filename to its contents.
from google.colab import files
uploaded = files.upload()
# The credentials cell hard-codes the key path, so show the name(s) actually uploaded
# to make a mismatch obvious.
print("Uploaded files:", list(uploaded))

In [None]:
# Point Google client libraries at a service-account key for long-lived credentials.
import os

GCS_KEY_PATH = "/content/my-gcs-key.json"  # Replace with your service account key path

# Fail here with a clear message rather than later inside storage.Client().
if not os.path.isfile(GCS_KEY_PATH):
    raise FileNotFoundError(
        f"Service account key not found at {GCS_KEY_PATH}; upload it or update GCS_KEY_PATH."
    )
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = GCS_KEY_PATH

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from google.cloud import storage
from PIL import Image
import pandas as pd
from io import BytesIO

# Open the bucket that holds the VinDr-Mammo images and annotation CSVs,
# using the credentials configured in the previous cell.
bucket_name = 'vindr-mammo-dataset'  # Replace with your bucket name
client = storage.Client()
bucket = client.bucket(bucket_name)

# Define the custom dataset class
class VindrMammoDataset(Dataset):
    """Map-style dataset that downloads VinDr-Mammo images from a GCS bucket on demand.

    Each item is ``(image, label)``. ``label`` is 0 (benign) for BI-RADS 1-3 and
    1 (malignant) for BI-RADS 4-5, parsed from the row's BI-RADS rating string.

    Args:
        bucket: Google Cloud Storage bucket containing ``images/<image_id>.png``.
        dataframe: DataFrame whose first column is the image id and whose second
            column is the BI-RADS rating string (e.g. ``"BI-RADS 4"``).
        transform: Optional callable applied to the loaded RGB PIL image
            (e.g. resizing and tensor conversion).

    NOTE(review): with ``DataLoader(num_workers > 0)`` the bucket (and its client)
    is shared with the worker processes. If downloads misbehave, create one
    client per worker.
    """

    # BI-RADS categories mapped to each binary class.
    BENIGN_BIRADS = frozenset({1, 2, 3})
    MALIGNANT_BIRADS = frozenset({4, 5})

    def __init__(self, bucket, dataframe, transform=None):
        self.bucket = bucket        # Google Cloud Storage bucket
        self.dataframe = dataframe  # columns: image id, BI-RADS rating
        self.transform = transform  # preprocessing applied to each loaded image

    def __len__(self):
        return len(self.dataframe)

    @classmethod
    def _birads_to_label(cls, breast_birads):
        """Convert a rating such as ``"BI-RADS 4"`` to 0 (benign) or 1 (malignant).

        Raises:
            ValueError: if the rating is outside BI-RADS 1-5 (e.g. 0 or 6). The
                error replaces the previous silent mapping of such values to malignant.
        """
        birads_value = int(str(breast_birads).split()[-1])
        if birads_value in cls.BENIGN_BIRADS:
            return 0
        if birads_value in cls.MALIGNANT_BIRADS:
            return 1
        raise ValueError(f"Unsupported BI-RADS rating: {breast_birads!r}")

    def __getitem__(self, idx):
        image_id = self.dataframe.iloc[idx, 0]
        label = self._birads_to_label(self.dataframe.iloc[idx, 1])

        # Images are stored in the bucket as images/<image_id>.png
        blob = self.bucket.blob(f"images/{image_id}.png")
        image_data = blob.download_as_bytes()
        # Force 3 channels: the VGG/ResNet/MobileNet backbones expect RGB input
        image = Image.open(BytesIO(image_data)).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label


# Test-time preprocessing applied to each image as it is loaded:
# resize to a fixed 256x256, then convert the PIL image to a float tensor.
# NOTE(review): there is no Normalize() step; confirm this matches the
# preprocessing used when the models were trained.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)

# Load the finding_annotations.csv file from GCS
csv_blob = bucket.blob("finding_annotations.csv")
csv_data = csv_blob.download_as_text()
annotations_df = pd.read_csv(BytesIO(csv_data.encode()))

# Fail early if the CSV does not have the columns used below.
_required_columns = {'image_id', 'breast_birads', 'split'}
assert _required_columns <= set(annotations_df.columns), (
    f"finding_annotations.csv is missing columns: {_required_columns - set(annotations_df.columns)}"
)

# Keep only the test split and the two columns the Dataset reads (id, BI-RADS),
# with one row per image. A finding-level CSV can list an image once per finding;
# duplicates would make that image be evaluated, and counted in the confusion
# matrix, several times.
# TODO(review): confirm breast_birads is identical across an image's rows.
test_df = (
    annotations_df.loc[annotations_df['split'] == 'test', ['image_id', 'breast_birads']]
    .drop_duplicates(subset='image_id')
    .reset_index(drop=True)
)

# Dataset over the test split; images are downloaded from GCS as items are requested.
test_dataset = VindrMammoDataset(bucket=bucket, dataframe=test_df, transform=transform)

# DataLoader over the *test* split. The name valid_loader is kept because the
# evaluation cell iterates over it.
valid_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,                         # fixed order for evaluation
    num_workers=4,                         # parallel per-item GCS downloads
    pin_memory=torch.cuda.is_available(),  # pinned memory only helps host-to-GPU copies
)

## Evaluate MobileNetV2
As a first test, we evaluate MobileNetV2, because it is the smallest of the three architectures and the fastest to run.

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from tqdm import tqdm  # Import tqdm for progress bar

# Evaluate the fix-BI-RADS MobileNetV2 on the test loader and plot its confusion matrix.
# NOTE(review): confirm trained weights were loaded into this model first; a freshly
# constructed model has random weights.
CLASS_NAMES = ["Benign", "Malignant"]  # index == label produced by VindrMammoDataset

# The model must be on the same device as the input batches. Move it here so this
# cell works regardless of earlier cells (Module.to is in place and idempotent).
mobilenetv2_fix_birads_model.to(device)
mobilenetv2_fix_birads_model.eval()

# Predictions and true labels accumulated over all batches
all_preds = []
all_labels = []

# Inference only: no gradients needed
with torch.no_grad():
    for images, labels in tqdm(valid_loader, desc="Evaluating", total=len(valid_loader)):
        outputs = mobilenetv2_fix_birads_model(images.to(device))
        preds = outputs.argmax(dim=1)  # predicted class index per image

        all_preds.extend(preds.cpu().numpy())
        # sklearn only needs the labels on the CPU, so they never go to the GPU
        all_labels.extend(labels.numpy())

# Pass labels=[0, 1] so the matrix is always 2x2 and lines up with CLASS_NAMES,
# even if one class never appears in the labels or the predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(CLASS_NAMES))))

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=CLASS_NAMES,
    yticklabels=CLASS_NAMES,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix for mobilenetv2_fix_birads_model')
plt.show()
