Prepare the Environment
First, import the libraries you'll need. You can also add more imports later as you find you need them.

Task 5.7.1: Import the libraries you'll need. You can update this cell and re-run it as you discover more things later.

In [1]:
# Import the libraries that you need
from pathlib import Path

import medigan
import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
from tqdm.notebook import tqdm



If a GPU is available on your machine, make sure you place your tensors (and your model) on the proper device.

Task 5.7.2: Check the availability of GPUs on this machine and determine the correct device name. Store the device name in the variable device.

In [2]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device.")


Using cuda device.
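The device string only matters once models and tensors are actually moved there. The sketch below uses a throwaway, hypothetical `torch.nn.Linear` layer to show the pattern: the module and its input must be on the same device before the forward pass.

```python
import torch

# Same selection order as the cell above: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Hypothetical toy layer and input, used only to illustrate device placement
layer = torch.nn.Linear(3, 2).to(device)  # Module.to moves the parameters in place
x = torch.ones(5, 3).to(device)           # Tensor.to returns a tensor on the device

y = layer(x)  # works because both live on the same device
print(y.shape, y.device)
```

Forgetting one of the `.to(device)` calls typically raises an error about tensors being on different devices.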


Getting Images
We don't have a ready-made collection of images to work with, so we'll have Medigan generate them for us. Let's find a generator that does what we need.

Task 5.7.3: Find a GAN in Medigan that produces mammogram images with the ROI ("region of interest") marked, and save its ID to model_id. You'll also need to create the entry point for Medigan.

In [12]:
# Create the connection to the Medigan generators
generators = medigan.Generators()

# Find the models that match what we want
values = ["mammogram", "roi"]
models = generators.find_matching_models_by_values(values)
model_id = models[0].model_id

print(model_id)


00004_PIX2PIX_MMG_MASSES_W_MASKS


You should have gotten only one model back. Let's make sure it does what we want.

Task 5.7.4: Get the configuration for the model you found using the model_id. Save the model configuration to the model_config variable.

In [13]:
model_config = generators.get_config_by_id(model_id=model_id)

print(f"Model keys: {model_config.keys()}")


Model keys: dict_keys(['execution', 'selection', 'description'])


As we saw before, the configurations contain a lot of information. Let's look only at the parts we're interested in.

Task 5.7.5: Select the generates, tags, height, and width keys from the selection key in the model configuration. Save the result to model_info, a dictionary that maps the selected keys to their values.

In [14]:
vals = ["generates", "tags", "height", "width"]
model_info = {x: model_config["selection"][x] for x in vals}

model_info

{'generates': ['regions of interest',
  'ROI',
  'mammograms',
  'patches',
  'full-field digital mammograms'],
 'tags': ['Mammogram',
  'Mammography',
  'Digital Mammography',
  'Full field Mammography',
  'Full-field Mammography',
  'pix2pix',
  'Pix2Pix',
  'Mass segmentation',
  'Breast lesion'],
 'height': 256,
 'width': 256}
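The dictionary comprehension above raises a `KeyError` if a configuration lacks one of the keys. A slightly more defensive variant, sketched here against a hypothetical, trimmed-down config dict that mirrors the nesting printed above, uses `dict.get` to fill missing entries with `None`.

```python
# Hypothetical, trimmed-down configuration with the same nesting as model_config
model_config = {
    "selection": {
        "generates": ["regions of interest", "ROI", "mammograms"],
        "tags": ["Mammogram", "pix2pix"],
        "height": 256,
        # "width" is deliberately missing to show the fallback
    }
}

vals = ["generates", "tags", "height", "width"]
selection = model_config["selection"]

# dict.get returns None instead of raising KeyError for absent keys
model_info = {key: selection.get(key) for key in vals}
print(model_info)
```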

Our next step is to generate the data, so we need to define a place to store it.

Task 5.7.6: Create a path to the directory output/sample_mammogram using pathlib.

In [15]:
output_dir = Path("output")
sample_dir = output_dir / "sample_mammogram"

# Create the directory with mkdir
sample_dir.mkdir(parents=True, exist_ok=True)

print(sample_dir)


output\sample_mammogram
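Because `mkdir` is called with `parents=True` and `exist_ok=True`, the cell is safe to re-run: missing parent directories are created, and an existing directory is not an error. A quick check, done in a temporary directory so nothing in the project is touched:

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # The / operator joins path components with the OS separator (\ on Windows)
    sample_dir = Path(tmp) / "output" / "sample_mammogram"

    # Creates "output" too, because parents=True
    sample_dir.mkdir(parents=True, exist_ok=True)
    # A second call is a no-op instead of raising FileExistsError
    sample_dir.mkdir(parents=True, exist_ok=True)

    created = sample_dir.is_dir()
    parts = sample_dir.parts[-2:]

print(created, parts)
```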


If something goes wrong, you can delete the whole directory and start over:

In [9]:
import shutil

# Cross-platform alternative to `rm -Rf output/` (the shell command fails on Windows)
shutil.rmtree(Path("output"), ignore_errors=True)
# Re-run the cell that creates sample_dir afterwards


What do mammogram images look like? Let's generate some.
Task 5.7.7: Generate 4 images of mammograms with region of interest using our selected GAN. Save them to sample_dir.

In [16]:
generators.generate(
    model_id=model_id,
    num_samples=4,
    output_path=sample_dir
)


  0%|          | 0/1 [00:00<?, ?it/s]ERROR:root:Error while trying to initialize pix2pix: No module named ':'
ERROR:root:Error while trying to generate 4 images with model models/00004_PIX2PIX_MMG_MASSES_W_MASKS/pix2pix_mask_to_mass_model.pth: No module named ':'
  0%|          | 0/1 [00:00<?, ?it/s]
ERROR:root:00004_PIX2PIX_MMG_MASSES_W_MASKS: Error while trying to generate images with model models/00004_PIX2PIX_MMG_MASSES_W_MASKS/pix2pix_mask_to_mass_model.pth: No module named ':'


ModuleNotFoundError: No module named ':'

Task 5.7.8: Finish up the missing parts of the function view_images. Invoke the function and store the resulting image in the variable sample_images.

In [None]:
def view_images(directory, num_images=4, glob_rule="*.jpg"):
    """Displays a sample of images in the given directory.
    They'll display in rows of 4 images.
    - directory: which directory to look for images
    - num_images: how many images to display (default 4, for one row)
    - glob_rule: argument to glob to filter images (default "*.jpg" selects JPEG files)"""

    # Collect the matching files and cap the count at num_images
    image_list = list(directory.glob(glob_rule))
    num_samples = min(num_images, len(image_list))
    # Read each file as a [C, H, W] uint8 tensor
    images = [read_image(str(f)) for f in sorted(image_list)[:num_samples]]
    # Tile the images 4 per row, with white padding between them
    grid = make_grid(images, nrow=4, pad_value=255.0)
    return torchvision.transforms.ToPILImage()(grid)


sample_images = view_images(sample_dir)
sample_images

If we're going to use these generated images to train our model, we'll need a data loader. Medigan can provide one for us directly. We'll make both a training set and a validation set.

Task 5.7.9: Use Medigan to create a data loader of training data. Make 50 images, in batches of 4, with shuffling turned on. Don't forget to set prefetch_factor=None. This may take a few minutes to run.

In [None]:
train_dataloader = generators.get_as_torch_dataloader(
    model_id=model_id, num_samples=50, batch_size=4, shuffle=True, prefetch_factor=None
)

sample_batch = next(iter(train_dataloader))
print(f"Training data loader with keys: {sample_batch.keys()}")
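The batch supports `keys()` because each dataset item is a dictionary, and PyTorch's default collate function turns a list of dictionaries into one dictionary of stacked tensors. The sketch below reproduces this with a hypothetical `ToyMaskDataset`; the assumption that each item holds `[H, W]` tensors under the keys `sample` and `mask` comes from how `convert_to_torch_image` treats the batch later on, not from Medigan's source.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyMaskDataset(Dataset):
    """Hypothetical stand-in for a Medigan dataset: each item is a dict with a
    grayscale 'sample' image and a binary 'mask', both shaped [H, W]."""

    def __init__(self, num_samples, size=16):
        self.num_samples = num_samples
        self.size = size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        sample = torch.randint(0, 256, (self.size, self.size), dtype=torch.uint8)
        mask = (sample > 127).to(torch.uint8)
        return {"sample": sample, "mask": mask}


loader = DataLoader(ToyMaskDataset(10), batch_size=4, shuffle=True)

batch = next(iter(loader))
# The default collate stacks each key separately: [batch_size, H, W]
print(batch.keys(), batch["sample"].shape, batch["mask"].dtype)

# 10 items in batches of 4 gives batch sizes 4, 4, 2
print([b["sample"].shape[0] for b in loader])
```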


Task 5.7.10: Use Medigan to create a data loader of validation data. This time only create 30 images and don't shuffle. All other settings should be the same.

In [None]:
val_dataloader = generators.get_as_torch_dataloader(
    model_id=model_id, num_samples=30, batch_size=4, shuffle=False, prefetch_factor=None
)

val_batch = next(iter(val_dataloader))
shape = val_batch["sample"].shape
dtype = val_batch["sample"].dtype
print(f"Validation image with data shape {shape} and type {dtype}")


This data isn't quite what we need for PyTorch. This is particularly apparent with the mask.

Task 5.7.11: Get the shape and type for the mask component of the val_batch.

In [None]:
shape = val_batch["mask"].shape
dtype = val_batch["mask"].dtype

print(f"Validation mask with data shape {shape} and type {dtype}")


We'll need to fix this. Each image needs to be [3, 256, 256] and each mask [1, 256, 256] (so a batch is [batch_size, 3, 256, 256] or [batch_size, 1, 256, 256]), and both need type float32. The function below converts the type and adds the channel dimension.

In [None]:
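def convert_to_torch_image(tensor, color=False):
    # Convert to float32, the type the model and loss function expect
    tensor_float = tensor.type(torch.float32)
    # Insert a channel dimension: [batch, H, W] -> [batch, 1, H, W]
    grayscale = tensor_float.unsqueeze(1)
    if color:
        # Copy the single channel three times: [batch, 3, H, W]
        return grayscale.repeat(1, 3, 1, 1)
    else:
        return grayscale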
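To see the shape changes concretely, here is the same function applied to a hypothetical batch of four blank 256x256 uint8 images. `repeat` makes real copies of the channel; `expand(-1, 3, -1, -1)` would give the same shape as a view without copying, but the result is non-contiguous, which is arguably why `repeat` is the simpler choice here.

```python
import torch


def convert_to_torch_image(tensor, color=False):
    # Same logic as the notebook function: cast, add a channel axis, optionally copy it
    tensor_float = tensor.type(torch.float32)
    grayscale = tensor_float.unsqueeze(1)
    if color:
        return grayscale.repeat(1, 3, 1, 1)
    return grayscale


# Hypothetical batch shaped like a Medigan batch: [batch_size, H, W], uint8
batch = torch.zeros(4, 256, 256, dtype=torch.uint8)

mask_like = convert_to_torch_image(batch)
rgb_like = convert_to_torch_image(batch, color=True)
print(mask_like.shape, mask_like.dtype)
print(rgb_like.shape, rgb_like.dtype)

# expand gives the same shape without copying, but as a non-contiguous view
view = mask_like.expand(-1, 3, -1, -1)
print(view.shape == rgb_like.shape, view.is_contiguous(), rgb_like.is_contiguous())
```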

Task 5.7.12: Run this function on the mask component of val_batch and get the new shape and type.

In [None]:
mask_converted = convert_to_torch_image(val_batch["mask"])

shape = mask_converted.shape
dtype = mask_converted.dtype

print(f"Validation mask with data shape {shape} and type {dtype}")

Task 5.7.13: Run this function on the sample component of val_batch and get the new shape and type. You'll need to specify color=True to get RGB images.

In [None]:
sample_converted = convert_to_torch_image(val_batch["sample"], color=True)

shape = sample_converted.shape
dtype = sample_converted.dtype

print(f"Validation sample with data shape {shape} and type {dtype}")


Creating a Model
Now that we have our data, we'll want to train a model. This is a segmentation problem, and we found a good pre-trained model for that in one of the lessons. Let's use that one.

Task 5.7.14: Load the pre-trained deeplabv3_resnet50 model. Use the COCO_WITH_VOC_LABELS_V1 weights.

In [None]:
pretrained_weights = (
    torchvision.models.segmentation.DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
)

model = torchvision.models.segmentation.deeplabv3_resnet50(weights=pretrained_weights)

print("Model components:")
for name, part in model.named_children():
    print("\t" + name)

As before, we'll need to replace the final layer with one that does what we need, but first let's see what the model produces. It returns a dictionary with two outputs, out and aux; we only want out.

Task 5.7.15: Run the model on the sample part of our sample_batch, and get the shape of the out part of the result. You'll need to convert the data to the correct format.

In [None]:
sample_converted = convert_to_torch_image(sample_batch["sample"], True)
model_result = model(sample_converted)
model_out = model_result["out"]
out_shape = model_out.shape

out_shape


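This doesn't match our masks: the height and width are right, but the number of channels is wrong (the pre-trained head predicts 21 channels, one per VOC class, while our mask has 1). The last layer of the classifier is a convolution, but not the one we need, so we'll replace it.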

In [None]:
model.classifier[-1]

Task 5.7.16: Replace the last layer in the classifier with a convolution that gives the correct output shape to match our mask.

In [None]:
new_final_layer = torch.nn.Conv2d(256, 1, kernel_size=(1, 1))
model.classifier[-1] = new_final_layer

new_out = model(sample_converted)["out"]
print(f"New model output shape: {new_out.shape}")
print(f"Mask shape: {mask_converted.shape}")
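The same replacement works on any `nn.Sequential` head. In this sketch a hypothetical three-layer head stands in for `model.classifier`; reading `in_channels` from the old layer, instead of hard-coding 256, keeps the new layer correct if the head's width ever differs.

```python
import torch

# Hypothetical stand-in for model.classifier: the final 1x1 convolution maps
# 256 feature channels to 21 class scores, like the pre-trained DeepLabV3 head
head = torch.nn.Sequential(
    torch.nn.Conv2d(8, 256, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(256, 21, kernel_size=1),
)
features = torch.randn(2, 8, 32, 32)
before = head(features).shape

# Replace the last layer with a 1x1 convolution that outputs a single channel
old_final = head[-1]
head[-1] = torch.nn.Conv2d(old_final.in_channels, 1, kernel_size=1)
after = head(features).shape

print(before, after)
```

Only the new layer starts from random weights; every other layer keeps its pre-trained values.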


We're also going to need a loss function and an optimizer for when we train. We'll use the same BCEWithLogitsLoss we used in the lesson, and an Adam optimizer.
Task 5.7.17: Create the loss function and the optimizer. Save them to loss_fun and opt respectively.

In [None]:
loss_fun = torch.nn.BCEWithLogitsLoss()
opt = torch.optim.Adam(params=model.parameters())


opt
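`BCEWithLogitsLoss` expects raw scores (logits) and applies the sigmoid internally. It computes the same value as `BCELoss` on `torch.sigmoid(logits)`, but in a numerically more stable way. The sketch checks this on hypothetical random tensors, and also shows that Adam's default learning rate, which `opt` above uses since no `lr` is passed, is 0.001.

```python
import torch

torch.manual_seed(0)

# Hypothetical logits and binary targets shaped like [batch, 1, H, W]
logits = torch.randn(2, 1, 4, 4)
target = (torch.rand(2, 1, 4, 4) > 0.5).float()

with_logits = torch.nn.BCEWithLogitsLoss()(logits, target)
manual = torch.nn.BCELoss()(torch.sigmoid(logits), target)
print(with_logits.item(), manual.item())

# Adam stores its default hyperparameters on the optimizer
params = [torch.nn.Parameter(torch.zeros(3))]
opt = torch.optim.Adam(params=params)
print(opt.defaults["lr"])
```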

Training
We'll need to train the model we just created. The process is similar to training we've done before, with two adjustments: the images and masks need to be converted to the right shapes, and we need to take the out part of the model's output.

We'll build up a few functions to put this together. First, we'll deal with calculating the loss. The function below outlines this, but the details are missing.

Task 5.7.18: Fill in the missing parts of the compute_loss function.

In [None]:
def compute_loss(batch, model, loss_fun):
    # Extract the sample and mask from the batch
    sample = batch["sample"]
    mask = batch["mask"]

    # Convert the sample and mask to the correct shape and type
    sample_correct = convert_to_torch_image(sample, color=True)
    mask_correct = convert_to_torch_image(mask)

    # move the sample and mask to the GPU
    sample_gpu = sample_correct.to(device)
    mask_gpu = mask_correct.to(device)

    # Run the model on the sample and select the classifier (out key)
    output = model(sample_gpu)["out"]

    # Compute the loss
    loss = loss_fun(output, mask_gpu)

    return loss


We'll run this on the sample batch to make sure it works. Note that the model needs to be on the GPU as well.

In [None]:
model.to(device)

compute_loss(sample_batch, model, loss_fun)

With that working, we can build the training for one epoch. We'll loop over the data loader and step our model for each batch. We'll also compute the validation loss.

Task 5.7.19: Fill in the missing parts of the function.

In [None]:
def train_epoch(model, train_dataloader, val_dataloader, loss_fun, opt):
    # Training part
    model.train()
    train_loss = 0.0
    train_count = 0
    for batch in tqdm(train_dataloader):
        # zero the gradients on the optimizer
        opt.zero_grad()

        # compute the loss for the batch
        loss = compute_loss(batch, model, loss_fun)

        # Compute the backward part of the loss and step the optimizer
        loss.backward()
        opt.step()

        train_loss += loss.item()
        train_count += 1

    # Validation part: eval mode (BatchNorm uses running stats, dropout is off)
    # and no gradient tracking
    model.eval()
    val_loss = 0.0
    val_count = 0
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            # compute the loss for each batch
            loss = compute_loss(batch, model, loss_fun)

            val_loss += loss.item()
            val_count += 1

    return train_loss / train_count, val_loss / val_count


Let's check that this works by running one epoch. It returns the training and validation losses.

In [None]:
train_epoch(model, train_dataloader, val_dataloader, loss_fun, opt)

We're ready to go. We can call this in a loop to train our model.
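A minimal sketch of that loop, assuming a hypothetical toy setup (random 8x8 "images", a single 1x1 convolution as the model, and a stripped-down `run_epoch` with the same train/validate structure as `train_epoch`) so it runs in seconds on a CPU. It records both losses per epoch and keeps a copy of the weights with the lowest validation loss.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 1-channel 8x8 "images" whose mask marks the bright pixels
images = torch.rand(32, 1, 8, 8)
masks = (images > 0.5).float()
train_loader = DataLoader(TensorDataset(images[:24], masks[:24]), batch_size=4, shuffle=True)
val_loader = DataLoader(TensorDataset(images[24:], masks[24:]), batch_size=4)

# Hypothetical toy model: one 1x1 convolution that outputs a logit per pixel
model = torch.nn.Conv2d(1, 1, kernel_size=1)
loss_fun = torch.nn.BCEWithLogitsLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.05)


def run_epoch(model, train_loader, val_loader, loss_fun, opt):
    """Same train/validate structure as train_epoch, for (image, mask) tuples."""
    model.train()
    train_total = 0.0
    for x, y in train_loader:
        opt.zero_grad()
        loss = loss_fun(model(x), y)
        loss.backward()
        opt.step()
        train_total += loss.item()

    model.eval()
    val_total = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            val_total += loss_fun(model(x), y).item()
    return train_total / len(train_loader), val_total / len(val_loader)


history = []
best_val = float("inf")
best_state = None
for epoch in range(20):
    train_loss, val_loss = run_epoch(model, train_loader, val_loader, loss_fun, opt)
    history.append((train_loss, val_loss))
    # Keep a copy of the weights from the epoch with the lowest validation loss
    if val_loss < best_val:
        best_val = val_loss
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
    print(f"epoch {epoch + 1}: train {train_loss:.4f}, val {val_loss:.4f}")

# Restore the best weights before evaluating or saving
model.load_state_dict(best_state)
```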
Task 5.7.20: Load the trained model.

We've trained the model for 9 more epochs. Now it's your turn: load the model we saved in the file model_trained.pth.

In [None]:
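# The file stores the whole pickled model, so recent PyTorch versions need
# weights_only=False; only load files like this from sources you trust
model = torch.load('model_trained.pth', map_location=device, weights_only=False).to(device)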
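Pickling the whole model, as model_trained.pth appears to do, ties the file to the exact class definitions and requires `weights_only=False`. The approach recommended in PyTorch's serialization documentation is to save only the `state_dict` and load it into a freshly built model. A sketch with a hypothetical `build_model` helper and a temporary file:

```python
import os
import tempfile

import torch

torch.manual_seed(0)


def build_model():
    # Hypothetical small model standing in for the segmentation network
    return torch.nn.Sequential(
        torch.nn.Conv2d(3, 4, kernel_size=3, padding=1),
        torch.nn.Conv2d(4, 1, kernel_size=1),
    )


trained = build_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pth")
    # Save only the parameters and buffers, not the pickled module
    torch.save(trained.state_dict(), path)

    # Rebuild the architecture, then load the saved tensors into it
    restored = build_model()
    state = torch.load(path, map_location="cpu", weights_only=True)
    restored.load_state_dict(state)

same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same)
```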


Testing the Model
Let's see how well we did. First, we'll need some new data to test on.

Task 5.7.21: Use Medigan to create a data loader of test data. This time only create 8 images and don't shuffle. All other settings should be the same as our earlier loaders.

In [None]:
test_dataloader = generators.get_as_torch_dataloader(
    model_id=model_id, num_samples=8, batch_size=4, shuffle=False, prefetch_factor=None
)

test_batch = next(iter(test_dataloader))

print(f"Data loader images in batches of {test_batch['sample'].size(0)}")


Task 5.7.22: Run the model on the test_batch, and save the out to test_result. You'll need to first convert the sample to the correct shape and type, and move it to the GPU.

In [None]:
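corrected_sample = convert_to_torch_image(test_batch["sample"], color=True)
corrected_sample = corrected_sample.to(device)

# Switch to evaluation mode and skip gradient tracking for inference
model.eval()
with torch.no_grad():
    test_result = model(corrected_sample)["out"]

test_result.shape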


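One more thing: our result was never passed through an activation function. The model outputs raw scores (logits), because BCEWithLogitsLoss applied the sigmoid internally during training. Applying the sigmoid maps each pixel to a value between 0 and 1, which makes the result image-like.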

Task 5.7.23: Apply the sigmoid function to the test_result. Save the output to test_mask_model.

In [None]:
test_mask_model = torch.sigmoid(test_result)
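The sigmoid output is a per-pixel probability; to get a binary mask you still need a threshold. 0.5 is a common default (an assumption here, not something the lesson prescribes). Because sigmoid(0) = 0.5, thresholding the probabilities at 0.5 gives the same mask as checking whether the logits are positive:

```python
import torch

# Hypothetical logits for five pixels
logits = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])

probs = torch.sigmoid(logits)
mask_from_probs = probs > 0.5
mask_from_logits = logits > 0

print(probs)
print(mask_from_probs, mask_from_logits)
```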


And finally, we can plot our results and see how we did. This function expects tensors that are already in the right shape [batch_size, channels, height, width].



In [None]:
def plot_images_from_tensor(tensor):
    grid = make_grid(tensor, nrow=4, pad_value=1.0)
    return torchvision.transforms.ToPILImage()(grid)

Task 5.7.24: Call the plot function to look at the input sample, the mask, and our final result. How did we do?

In [None]:
from IPython.display import display

# Only the last expression in a cell is shown automatically, so use display()

# Plot the sample part of the test_batch
sample_test_batch_plot = plot_images_from_tensor(convert_to_torch_image(test_batch["sample"]))
display(sample_test_batch_plot)

# Plot the mask part of the test_batch
mask_test_batch_plot = plot_images_from_tensor(convert_to_torch_image(test_batch["mask"]))
display(mask_test_batch_plot)

# Plot the result of the model running
model_result_plot = plot_images_from_tensor(test_mask_model)
display(model_result_plot)


How did your model do? You can try to train longer to see if you get better results, or adjust the learning rate in Adam, or pull a new test batch to see a broader range of results.
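Looking at the plots is a good start, but a number makes it easier to compare runs. Common choices for segmentation are the Dice coefficient and intersection over union (IoU). Here is a minimal sketch, not part of the lesson, that thresholds the sigmoid output at an assumed 0.5 and averages the scores per image; `dice_and_iou` is a hypothetical helper and the tiny tensors are made up.

```python
import torch


def dice_and_iou(probs, target, threshold=0.5, eps=1e-7):
    """Mean Dice and IoU over a batch of [batch, 1, H, W] masks.
    eps avoids division by zero when both masks are empty."""
    pred = (probs > threshold).float()
    target = target.float()
    dims = tuple(range(1, pred.dim()))  # sum over everything except the batch axis
    intersection = (pred * target).sum(dim=dims)
    total = pred.sum(dim=dims) + target.sum(dim=dims)
    dice = (2 * intersection + eps) / (total + eps)
    iou = (intersection + eps) / (total - intersection + eps)
    return dice.mean().item(), iou.mean().item()


# Hypothetical 1x1x2x2 prediction and ground truth: each mask has two
# positive pixels and they overlap in one
probs = torch.tensor([[[[0.9, 0.8], [0.1, 0.2]]]])
target = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])

dice, iou = dice_and_iou(probs, target)
print(f"Dice {dice:.3f}, IoU {iou:.3f}")  # Dice 0.500, IoU 0.333
```

With the notebook's tensors this could be called as `dice_and_iou(test_mask_model.cpu(), convert_to_torch_image(test_batch["mask"]))`, assuming the mask holds 0/1 values (divide by 255 first if it holds 0/255).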