In [None]:
# run these in the command line before executing this notebook
# ! pip install fastai; 
# ! pip install ipykernel torch torchaudio torchvision

# NOTE: this notebook follows the walkwithfastai binary-segmentation tutorial; read it first for background:
# https://walkwithfastai.com/Binary_Segmentation

In [None]:
from fastai import *
from fastai.vision.all import *
from IPython.display import clear_output, DisplayHandle

# A bare `torch.cuda.is_available()` here would never be displayed (only a cell's last
# expression is echoed), so print the result explicitly.
print(f"CUDA available: {torch.cuda.is_available()}")


def update_patch(self, obj, **kwargs):
    """Replacement for ``IPython.display.DisplayHandle.update``.

    Clears the cell output (``wait=True`` defers the clear until new output arrives,
    which avoids flicker) and re-displays ``obj`` through this handle. Extra keyword
    arguments are forwarded to ``DisplayHandle.display`` so callers that pass them,
    as the original ``update`` allows, keep working.
    """
    clear_output(wait=True)
    self.display(obj, **kwargs)


# Monkey-patch every DisplayHandle; presumably so progress output redraws in place
# instead of accumulating — confirm against the progress bar in use.
DisplayHandle.update = update_patch

In [None]:
# The project root is assumed to sit one directory above the notebook's working directory.
project_directory = Path.cwd() / '..'

# Train / test splits, each holding parallel `images` and `masks` folders.
train_path = project_directory / 'data' / 'mitochondria_data' / 'training'
test_path = project_directory / 'data' / 'mitochondria_data' / 'testing'

init_image_paths, init_mask_paths = train_path / "images", train_path / "masks"
test_image_paths, test_mask_paths = test_path / "images", test_path / "masks"

def get_union_of_directories(dir1, dir2, *more_dirs, pattern='*.tif'):
    """Collect the files matching ``pattern`` from several directories, deterministically.

    Directories are visited in argument order; within each directory the matches are
    sorted by path. ``Path.glob`` yields entries in filesystem order, which can differ
    between machines and runs, and that would make any seeded random split over the
    returned list non-reproducible.

    Args:
        dir1, dir2: directories (``Path`` or ``str``) to search.
        *more_dirs: optional additional directories, searched after ``dir1`` and ``dir2``.
        pattern: glob pattern applied in each directory (default ``'*.tif'``).

    Returns:
        list of ``Path`` objects for every matching file.
    """
    directories = (dir1, dir2, *more_dirs)
    return [file
            for directory in directories
            for file in sorted(Path(directory).glob(pattern))]

image_files = get_union_of_directories(test_image_paths, init_image_paths)
mask_files = get_union_of_directories(test_mask_paths, init_mask_paths)

# Fail fast with a readable message if the data paths are wrong; otherwise the first
# symptom is an opaque IndexError on `image_files[0]` in a later cell.
assert image_files, f"No .tif images found in {test_image_paths} or {init_image_paths}"
assert mask_files, f"No .tif masks found in {test_mask_paths} or {init_mask_paths}"

print(f"{len(image_files)} images, {len(mask_files)} masks")


In [None]:
# Open images in `with` blocks so PIL closes the files: Image.open is lazy and keeps
# the file handle open until the image is closed.
with Image.open(image_files[0]) as first_image:
    input_image_size = first_image.size
print(f'Size of an image: {input_image_size}')
# First element of a PIL size is the width.
square_size = input_image_size[0]

# Raw pixel values of one image and one mask. NOTE: the two lists are built
# independently, so index 1 of each is not guaranteed to be a matching pair.
with Image.open(image_files[1]) as sample_image, Image.open(mask_files[1]) as sample_mask:
    print(np.unique(sample_image))
    print(np.unique(sample_mask))

print(image_files[1])
print(mask_files[1])


In [None]:
# Build the class-index -> pixel-value mapping (`p2c`) used to decode the masks
import matplotlib.pyplot as plt
# Our masks store raw pixel values (e.g. 0 and 255) rather than the consecutive class
# indices (0, 1, ...) that fastai expects, so they must be remapped.
# `n_codes` scans (a sample of) the masks, collects the set of unique pixel values, and
# builds a dictionary {class_index: pixel_value} used to replace the pixels when a mask is loaded.
def n_codes(fnames, is_partial=True, n_sample=10, open_mask=None):
  """Map class index -> raw pixel value for the values found in a set of masks.

  Args:
    fnames: full paths of the mask files. The caller's sequence is not modified.
    is_partial: if True, only inspect a random sample of at most `n_sample` files
      (faster, but a value that appears only in unsampled masks is missed).
    n_sample: sample size used when `is_partial` is True (default 10).
    open_mask: callable turning one file name into an array; defaults to loading
      it with fastai's `PILMask.create`.

  Returns:
    dict {class_index: pixel_value}; pixel values are assigned indices in
    ascending order, so e.g. masks stored as 0/255 give {0: 0, 1: 255}.
  """
  if open_mask is None:
    def open_mask(fname):
      return np.array(PILMask.create(fname))

  fnames = list(fnames)
  if is_partial:
    # random.sample draws without replacement and, unlike random.shuffle,
    # leaves the caller's list untouched.
    fnames = random.sample(fnames, min(n_sample, len(fnames)))

  vals = set()
  for fname in fnames:
    vals.update(np.unique(open_mask(fname)).tolist())
  # Sort so the index given to each pixel value is deterministic (set order is not).
  return dict(enumerate(sorted(vals)))

# Build the class-index -> pixel-value lookup once; `get_y` reads it as a global.
p2c = n_codes(fnames=mask_files)
print(p2c)


In [None]:
# Exploratory check: three `.parent` hops from an image file give the directory that
# `get_mask_file` uses as its base path (the cell echoes it as its last expression).
image_file = image_files[0]
image_file.parent.parent.parent

In [None]:
def get_mask_file(image_file, p2c):
    """Load the ground-truth mask for ``image_file`` with pixels remapped to class indices.

    The mask is looked up at
    ``<base>/<split>/masks/<split>_groundtruth_<a>_<b>_<c><suffix>``, where ``<base>`` is
    three directories above ``image_file``, ``<split>`` is the ``training``/``testing``
    prefix of the image's file name, and ``<a>_<b>_<c>`` is the first
    ``digits_digits_digits`` group in that name.

    Args:
        image_file: ``Path`` of the input image.
        p2c: ``{class_index: pixel_value}`` mapping (as built by ``n_codes``).

    Returns:
        ``PILMask`` whose pixels equal ``class_index`` wherever the raw mask equalled
        ``p2c[class_index]``; pixels whose value is not in ``p2c`` are left unchanged.

    Raises:
        ValueError: if the file name lacks a ``training``/``testing`` prefix or an
            ``N_N_N`` sample id.
    """
    base_path = image_file.parent.parent.parent
    name = image_file.name

    # Use an alternation, not a character class: `[training|testing|]*` would match any
    # run of those letters and silently match "" for an unexpected name.
    split_match = re.match(r"(training|testing)", name)
    sample_match = re.search(r"\d+_\d+_\d+", name)
    if split_match is None or sample_match is None:
        raise ValueError(
            f"Unexpected image file name (need a training/testing prefix and an "
            f"N_N_N sample id): {image_file}")
    split = split_match.group(1)

    mask_name = f"{split}_groundtruth_{sample_match.group(0)}{image_file.suffix}"
    mask_path = base_path / split / 'masks' / mask_name

    raw = np.array(PILMask.create(mask_path))
    # Select pixels from the *original* values. Remapping in place one value at a time
    # lets a later step overwrite an earlier step's output, e.g. {0: 0, 1: 2, 2: 1}
    # would turn both the 1s and the 2s into 2.
    remapped = raw.copy()
    for code, pixel in p2c.items():
        remapped[raw == pixel] = code
    return PILMask.create(remapped)


def get_y(o):
    """fastai ``get_y`` hook: map an image path to its class-index mask.

    Binds the module-level ``p2c`` lookup so the DataBlock can call this with a
    single argument.
    """
    codes = p2c
    return get_mask_file(o, codes)

# Quick visual check of the first image and its decoded mask.
fig, ax = plt.subplots(1, 2, figsize=(5, 5))
im = PILImage.create(image_files[0])
im.show(ax[0])
ax[0].set_title("Image")

msk = get_y(image_files[0])
msk.show(ax[1])
ax[1].set_title("Mask")
# Render explicitly; otherwise the cell echoes the `Text(...)` object returned by set_title.
plt.show()


In [None]:
def show_mask(img_fn):
    """Plot an image beside its decoded mask and print the mask's unique values.

    Args:
        img_fn: path of the input image; the mask is located via ``get_mask_file``
            using the module-level ``p2c`` lookup.
    """
    img = PILImage.create(img_fn)
    # get_mask_file already returns a PILMask, so use it directly instead of
    # passing it back through PILMask.create.
    msk = get_mask_file(img_fn, p2c)

    fig, (img_ax, msk_ax) = plt.subplots(1, 2, figsize=(5, 5))
    img.show(ax=img_ax)
    img_ax.set_title("Image")
    msk.show(ax=msk_ax, alpha=1)
    msk_ax.set_title("Mask")

    print(f"Unique values in the mask: {np.unique(np.array(msk))}")
    # Render now so each figure appears right after its printout rather than all
    # figures being shown together when the cell finishes.
    plt.show()


# Show a few masks and their unique values
for image_path in image_files[:5]:
    show_mask(image_path)


In [None]:

# Set up the DataBlock. `get_items` is left unset, so the DataBlock uses the list passed
# to `dataloaders()` / `summary()` as its items. (A `lambda x: image_files` would ignore
# that argument and always return the global list.)
mitos = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes=np.array(['not_mito', 'mito']))),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=get_y,  # image path -> class-index mask
    # Per-item random crop (at least 30% of the image area), resized to 512x512; a new
    # crop is drawn every time an item is loaded.
    item_tfms=[RandomResizedCrop(512, min_scale=0.3)],
    # GPU batch augmentation, output resized to 224.
    batch_tfms=[*aug_transforms(size=224,
                                flip_vert=True,
                                max_rotate=30,
                                min_zoom=0.8,
                                max_zoom=1.15,
                                max_warp=0.3)],
    n_inp=1,
)

batch_size = 16
dls = mitos.dataloaders(image_files, bs=batch_size)
mitos.summary(image_files, bs=batch_size)


In [None]:
# vmin/vmax pin the colour scale used for the (0/1 class-index) masks.
dls.show_batch(
    max_n=batch_size,
    figsize=(batch_size / 2,) * 2,
    vmin=0,
    vmax=1,
)

In [None]:
# Pull one training batch to confirm the input and target tensor shapes.
x, y = dls.one_batch()
print(f"Input shape: {x.shape}")
print(f"Target shape: {y.shape}")


In [None]:
# U-Net on a ResNet-34 encoder with the Ranger optimizer and Dice metric (used for lr_find).
opt = ranger
learn = unet_learner(dls, arch=resnet34, metrics=Dice, opt_func=opt)

# learn.summary()  # uncomment to inspect the architecture

In [None]:
# LR range test; each suggestion is unpacked so they can be compared (valley is the usual pick).
(lr_min, lr_steep,
 lr_valley, lr_slide) = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
(lr_min, lr_steep, lr_valley, lr_slide)

In [None]:
# Optional quick sanity run: fit a few epochs at the `valley` suggestion and eyeball results.
# learn.fit_one_cycle(3, lr_valley)
# learn.show_results(max_n=3, figsize=(2, 3))

# Learning rate picked from the lr_find plot above (lr_valley is usually a good default).
my_lr = 1e-3
n_epochs = 50

# Fresh learner for the real run: U-Net on a ResNet-34 backbone, Dice metric, Ranger optimizer.
opt = ranger
learn = unet_learner(dls, resnet34, metrics=Dice, opt_func=opt)

# Separate prints: a triple-quoted f-string would also print the source indentation.
print(f"Learning rate = {my_lr}")
print(f"Epochs = {n_epochs}")
print(f'Employing loss function: {learn.loss_func}')


In [None]:
# Fine-tune the model for up to `n_epochs`.
# EarlyStoppingCallback stops once valid_loss fails to improve by 0.001 for 5 epochs.
# On its own that leaves the weights from the last (already worse) epoch.
# SaveModelCallback therefore saves the best-valid_loss weights (under
# learn.path/learn.model_dir) and reloads them when training ends.
learn.fine_tune(
    n_epochs,
    my_lr,
    cbs=[
        EarlyStoppingCallback(monitor='valid_loss', min_delta=0.001, patience=5),
        SaveModelCallback(monitor='valid_loss', min_delta=0.001, fname='best_valid_loss'),
    ],
)

In [None]:
# Training/validation loss curves from the recorder (presumably only the last fit of
# fine_tune, since fastai's Recorder resets at each fit — confirm).
learn.recorder.plot_loss()

In [None]:
# Metrics recorded for the last epoch:
#  - train_loss / valid_loss: the loss (mismatch between predictions and targets) on the
#    training and validation sets; lower is better.
#  - dice: the Sørensen–Dice coefficient between predicted and ground-truth masks, in
#    [0, 1]; 1 means the predicted and ground-truth masks are identical.
rec_vals = learn.recorder.values[-1]

# Label values by name rather than by position. fastai's `Recorder.metric_names` starts
# with 'epoch' (which `values` rows do not contain) and may end with 'time';
# zip stops at the shorter sequence, so the trailing name is dropped.
final_metrics = dict(zip(learn.recorder.metric_names[1:], rec_vals))
for metric_name, value in final_metrics.items():
    print(f'{metric_name}: {value}')


In [None]:
# Targets vs. predictions for a handful of samples.
learn.show_results(figsize=(2, 6), max_n=6)


In [None]:
# Save the model. Build the directory and the output file from the same base path so the
# two cannot point at different places.
model_dir = project_directory / "segmentation_model_dir"
model_dir.mkdir(parents=True, exist_ok=True)
# NOTE: `n_epochs` is the configured maximum; early stopping may have ended training sooner.
fname = f"dynamic_unet_seg_model-e{n_epochs}_b{batch_size}.pkl"
print(fname)

output_file = model_dir / fname

# Saves the whole Learner (model plus data pipeline), not just the weights.
# NOTE(review): the pickle presumably references `get_y` (and through it
# `get_mask_file`/`p2c`) by name, so they likely need to be defined where the
# file is loaded — confirm.
learn.export(output_file)


In [4]:
import gc 
import torch
# Free memory after training. NOTE: objects still bound to a name (e.g. `learn`, `dls`,
# `mitos`) cannot be collected, and torch.cuda.empty_cache() only releases *unoccupied*
# cached CUDA memory. Uncomment the `del` lines below to actually free the model and
# data loaders (they are then unusable in later cells).
# # Delete the objects
# del learn
# del dls
# del mitos


# Call the garbage collector
gc.collect()

# clear the GPU cache
torch.cuda.empty_cache()
