In [1]:
#get training dataset
import fiftyone as fo
from fiftyone.utils.data.importers import ImageSegmentationDirectoryImporter
from pathlib import Path

data_path = Path("datasets/optotaxis")

imfolder = data_path/"images";#'C:\\Users\\miner\\OneDrive - University of North Carolina at Chapel Hill\\Bear Lab\\optotaxis calibration\\data\\segmentation_iteration_testing\\iter4\\round1\\images'
maskfolder = data_path/"masks";#Path('C:/Users/miner/OneDrive - University of North Carolina at Chapel Hill/Bear Lab/optotaxis calibration/data/segmentation_iteration_testing/iter4/round1/masks')
imp = ImageSegmentationDirectoryImporter(data_path=imfolder,labels_path=maskfolder)
dataset = fo.Dataset.from_importer(imp)

import fiftyone.utils.random as four

four.random_split(dataset, {"train": 0.8, "val": 0.2})

 100% |███████████████████| 56/56 [78.7ms elapsed, 0s remaining, 711.6 samples/s]  


In [2]:
dataset

Name:        2025.04.20.22.40.30
Media type:  image
Num samples: 56
Persistent:  False
Tags:        []
Sample fields:
    id:               fiftyone.core.fields.ObjectIdField
    filepath:         fiftyone.core.fields.StringField
    tags:             fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)
    metadata:         fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.ImageMetadata)
    created_at:       fiftyone.core.fields.DateTimeField
    last_modified_at: fiftyone.core.fields.DateTimeField
    ground_truth:     fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Segmentation)

In [None]:
import json
from imageio.v3 import imread
import numpy as np
from skimage.exposure import rescale_intensity

meanstdpath = data_path/"meanstd.json"
if meanstdpath.exists():
    with open(meanstdpath,"r") as f:
        MEAN,STD = json.load(f)
else:
    from running_stats import RunningStats
    def get_im_mean_std(dataset:fo.Dataset):
        stats = RunningStats(n=np.array(0),m=np.array(0),s=np.array(0))
        for samp in dataset.iter_samples(progress=True):
            im = imread(samp.filepath)

            #image transforms - obviated by segmentation pipeline
            im = rescale_intensity(im)
            if im.dtype != np.uint8:
                im = np.astype(im/255,np.uint8)
            
            stats += im;

        return stats.mean, stats.std
    MEAN,STD = get_im_mean_std(dataset)
    print(MEAN,STD)
    meanstdpath = data_path/"meanstd.json"
    with open(meanstdpath,"w") as f:
        json.dump((MEAN,STD),f)

In [None]:
import numpy as np
import torch
from PIL import Image
from imageio.v3 import imread
# from bidict import bidict


class FOTorchSegmentationDataset(torch.utils.data.Dataset):
    """A class to construct a PyTorch dataset from a FiftyOne dataset containing segmantic segmentation masks.
    
    Args:
        fiftyone_dataset: a FiftyOne dataset or view that will be used for 
            training or testing
        transforms (None): a list of PyTorch transforms to apply to images 
            and targets when loading
        gt_field ("ground_truth"): the name of the field in fiftyone_dataset 
            that contains the desired labels to load
        classes (None): a list of class strings that are used to define the 
            mapping between class names and indices. If None, it will use 
            all classes present in the given fiftyone_dataset.
    """

    def __init__(
        self,
        fiftyone_dataset:fo.Dataset,
        classes:list[str]|str,
        transforms=None,
        gt_field="ground_truth",
    ):
        self.samples = fiftyone_dataset
        self.transforms = transforms
        self.gt_field = gt_field

        self.img_paths = self.samples.values("filepath")
        self.mask_paths = self.samples.values(f"{gt_field}.mask_path")
        self.classes = [classes] if isinstance(classes,str) else classes

        if self.classes[0] != "background":
            self.classes = ["background"] + self.classes

        self.labels_map = {i: c for i, c in enumerate(self.classes)}
    
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = imread(img_path)

        img = np.stack([img,img,img],axis=-1)

        mask_path = self.mask_paths[idx]
        mask = imread(mask_path)
        mask = mask > 0 #reads mask 0,255 otherwise


        if self.transforms is not None:
            t = self.transforms(image=img, mask=mask)
            timg = t["image"]
            tmask = t["mask"]

        return timg, tmask, img, mask

    def __len__(self):
        return len(self.img_paths)

    def get_classes(self):
        return self.classes

In [6]:
import albumentations as A
transform = A.Normalize(mean=MEAN, std=STD)
train_dataset = FOTorchSegmentationDataset(dataset.match_tags("train"),classes="cytoplasm",transforms=transform)
test_dataset = FOTorchSegmentationDataset(dataset.match_tags("val"),classes="cytoplasm",transforms=transform)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
len(test_dataset)

11

In [None]:
#turn training dataset into something mask2former can see
from transformers import MaskFormerImageProcessor

# Create a preprocessor
mask2processor = MaskFormerImageProcessor(reduce_labels=False, do_resize=False, do_rescale=False, do_normalize=False)

##mask2PROCESSOR USAGE: Call the preprocessor on a *batch* of images and segmentation maps. Images are RGB, masks are segmentation #s.
# eg. mask2processor(image_batch,mask_batch)

In [None]:
import numpy as np
from torch.utils.data import DataLoader

def collate_fn(batch):
    inputs = list(zip(*batch))
    images = inputs[0]
    segmentation_maps = inputs[1]
    # this function pads the inputs to the same size,
    # and creates a pixel mask
    # actually padding isn't required here since we are cropping
    # print(np.array(images).shape)
    batch = mask2processor( #creates a dict with various inputs to the model. Key (get it) arguments: "pixel_values" - raw image input. "mask_labels" - per-class binary masks. "class_labels" - integer label per mask
        images,
        segmentation_maps=segmentation_maps,
        return_tensors="pt",
    )

    batch["original_images"] = inputs[2]
    batch["original_segmentation_maps"] = inputs[3]
    
    return batch

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

In [10]:
from transformers import MaskFormerForInstanceSegmentation
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade",
                                                          id2label=train_dataset.labels_map,
                                                          ignore_mismatched_sizes=True)

Some weights of MaskFormerForInstanceSegmentation were not initialized from the model checkpoint at facebook/maskformer-swin-base-ade and are newly initialized because the shapes did not match:
- class_predictor.weight: found shape torch.Size([151, 256]) in the checkpoint and torch.Size([3, 256]) in the model instantiated
- class_predictor.bias: found shape torch.Size([151]) in the checkpoint and torch.Size([3]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([151]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
import evaluate

metric = evaluate.load("mean_iou")

In [12]:
#Fine-Tuning Loop (Semantic)
import torch
from tqdm.auto import tqdm

debug_cpu = False
device = torch.device("cuda" if not debug_cpu and torch.cuda.is_available() else "cpu")
print("torch device:",device)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

running_loss = 0.0
num_samples = 0


torch device: cuda


In [None]:
for epoch in tqdm(range(100)):
  print("Epoch:", epoch)
  model.train()
  for idx, batch in enumerate(tqdm(train_dataloader)):
      # Reset the parameter gradients
      optimizer.zero_grad()

      # print("mask:",batch["mask_labels"])
      # print("class:",batch["class_labels"])

      # Forward pass
      outputs = model(
          pixel_values=batch["pixel_values"].to(device,torch.float),
          mask_labels=[labels.to(device) for labels in batch["mask_labels"]],
          class_labels=[labels.to(device) for labels in batch["class_labels"]],
      )

      # Backward propagation
      loss = outputs.loss
      loss.backward()

      batch_size = batch["pixel_values"].size(0)
      running_loss += loss.item()
      num_samples += batch_size

      if idx % 100 == 0:
        print("Loss:", running_loss/num_samples)

      # Optimization
      optimizer.step()

  model.eval()
  for idx, batch in enumerate(tqdm(test_dataloader)):
    if idx > 5:
      break

    pixel_values = batch["pixel_values"]
    
    # Forward pass
    with torch.no_grad():
      outputs = model(pixel_values=pixel_values.to(device))

    # get original images
    original_images = batch["original_images"]
    target_sizes = [(image.shape[0], image.shape[1]) for image in original_images]
    # predict segmentation maps
    predicted_segmentation_maps = mask2processor.post_process_semantic_segmentation(outputs,
                                                                                  target_sizes=target_sizes)

    # get ground truth segmentation maps
    ground_truth_segmentation_maps = batch["original_segmentation_maps"]

    metric.add_batch(references=ground_truth_segmentation_maps, predictions=predicted_segmentation_maps)
  
  # NOTE this metric outputs a dict that also includes the mIoU per category as keys
  # so if you're interested, feel free to print them as well
  print("Mean IoU:", metric.compute(num_labels = 1, ignore_index = 0)['mean_iou'])

Epoch: 0


  0%|          | 0/23 [00:00<?, ?it/s]

Loss: 1.1711831092834473


100%|██████████| 23/23 [09:56<00:00, 25.94s/it]
100%|██████████| 6/6 [00:28<00:00,  4.77s/it]
  area_label = np.histogram(label, bins=num_labels, range=(0, num_labels - 1))[0]
  all_acc = total_area_intersect.sum() / total_area_label.sum()
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  metrics["mean_iou"] = np.nanmean(iou)
  metrics["mean_accuracy"] = np.nanmean(acc)


Mean IoU: nan
Epoch: 1


  4%|▍         | 1/23 [00:26<09:39, 26.35s/it]

Loss: 0.5246903249557983


  4%|▍         | 1/23 [00:53<19:36, 53.48s/it]


KeyboardInterrupt: 

In [14]:
model.eval()
for idx, batch in enumerate(tqdm(test_dataloader)):
    if idx > 5:
      break

    pixel_values = batch["pixel_values"]
    
    # Forward pass
    with torch.no_grad():
      outputs = model(pixel_values=pixel_values.to(device,torch.float))

    # get original images
    original_images = batch["original_images"]
    target_sizes = [(image.shape[0], image.shape[1]) for image in original_images]
    # predict segmentation maps
    predicted_segmentation_maps = mask2processor.post_process_semantic_segmentation(outputs,
                                                                                  target_sizes=target_sizes)

    # get ground truth segmentation maps
    ground_truth_segmentation_maps = batch["original_segmentation_maps"]

    metric.add_batch(references=ground_truth_segmentation_maps, predictions=predicted_segmentation_maps)

100%|██████████| 6/6 [00:33<00:00,  5.57s/it]


In [15]:
print("Mean IoU:", metric.compute(num_labels = 2, ignore_index = 0)['mean_iou'])

Mean IoU: 1.0


  area_label = np.histogram(label, bins=num_labels, range=(0, num_labels - 1))[0]
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label


In [16]:
np.unique(batch["original_segmentation_maps"])

array([False,  True])

In [17]:
predicted_segmentation_maps

[tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')]

In [18]:
torch.unique(predicted_segmentation_maps[0])

tensor([1], device='cuda:0')