In [3]:
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image
import requests
import torch
from utils.dataset_utils import CadisDataset,Cataract101Dataset
import torch.optim as optim
import numpy as np
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torch.nn.functional as F
import evaluate
from torch.utils.tensorboard import SummaryWriter
import os

In [None]:
!pip install evaluate
!pip install tensorboard

In [6]:
out_dir="outputs/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
%load_ext tensorboard
%tensorboard --logdir '{out_dir}'runs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 2817), started 0:08:32 ago. (Use '!kill 2817' to kill it.)

# Initializing Datasets and Device

In [16]:
# --- Cataract-101: a single dataset object, split 80/10/10 with a fixed seed ---
cataract_101_dataset = Cataract101Dataset(root_folder="data/cataract-101", split="train")
total_size = len(cataract_101_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
# The test split absorbs whatever is left after integer truncation above.
test_size = total_size - train_size - val_size
generator1 = torch.Generator().manual_seed(42)
(
    cataract_101_train_dataset,
    cataract_101_val_dataset,
    cataract_101_test_dataset,
) = random_split(
    cataract_101_dataset,
    [train_size, val_size, test_size],
    generator=generator1,
)

# --- CaDIS: the dataset class provides its own train/val/test splits ---
cadis_train_dataset, cadis_val_dataset, cadis_test_dataset = (
    CadisDataset(root_folder="data/cadis", split=split_name)
    for split_name in ("train", "val", "test")
)

# Displayed as the cell output: sizes ordered (CaDIS, Cataract-101) for train, val, test.
(
    len(cadis_train_dataset),
    len(cataract_101_train_dataset),
    len(cadis_val_dataset),
    len(cataract_101_val_dataset),
    len(cadis_test_dataset),
    len(cataract_101_test_dataset),
)

# TODO: merge A+B, rand(A)+B train and val sets

(3584, 674, 540, 84, 614, 85)

In [8]:
# Prefer the GPU whenever PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cpu


# Initializing The M2F Model, Configs and The Dataloaders

In [9]:
# Number of categories defined by the CaDIS dataset, including background.
num_classes = len(cadis_train_dataset.categories)

# Pre-processing settings: 255 is the ignore index and labels are reduced.
image_processor = AutoImageProcessor.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic",
    ignore_index=255,
    reduce_labels=True,
)

# Same ADE20K checkpoint for the weights; the classification head is re-created
# for `num_classes - 1` labels, hence `ignore_mismatched_sizes=True`.
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic",
    num_labels=num_classes - 1,
    ignore_mismatched_sizes=True,
)

Some weights of Mask2FormerForUniversalSegmentation were not initialized from the model checkpoint at facebook/mask2former-swin-large-ade-semantic and are newly initialized because the shapes did not match:
- class_predictor.bias: found shape torch.Size([151]) in the checkpoint and torch.Size([23]) in the model instantiated
- class_predictor.weight: found shape torch.Size([151, 256]) in the checkpoint and torch.Size([23, 256]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([151]) in the checkpoint and torch.Size([23]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Training hyper-parameters
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
learning_rate_multiplier=0.1
# The backbone (pixel-level encoder) is trained with a reduced learning rate.
backbone_lr=LEARNING_RATE*learning_rate_multiplier
weight_decay=0.5
#dice = Dice(average='micro')

#lambda_CE=5.0
#lambda_dice=5.0
metric = evaluate.load("mean_iou")

# Parameter-name prefixes used to split the model into optimizer groups.
_ENCODER_PREFIX = "model.pixel_level_module.encoder"
_DECODER_PREFIX = "model.pixel_level_module.decoder"
_TRANSFORMER_PREFIX = "model.transformer_module"

encoder_params=[param for name, param in model.named_parameters() if name.startswith(_ENCODER_PREFIX)]
decoder_params=[param for name, param in model.named_parameters() if name.startswith(_DECODER_PREFIX)]
transformer_params=[param for name, param in model.named_parameters() if name.startswith(_TRANSFORMER_PREFIX)]
# Parameters matched by none of the prefixes above. The checkpoint-loading message
# lists `class_predictor.weight` / `class_predictor.bias` as newly initialised; their
# names do not start with "model.", so without this group they would never be updated.
remaining_params=[
    param for name, param in model.named_parameters()
    if not name.startswith((_ENCODER_PREFIX, _DECODER_PREFIX, _TRANSFORMER_PREFIX))
]

param_groups = [
    {'params': encoder_params, 'lr': backbone_lr},
    {'params': decoder_params},
    {'params': transformer_params},
]
if remaining_params:
    param_groups.append({'params': remaining_params})

# Groups without an explicit 'lr' fall back to LEARNING_RATE.
optimizer = optim.AdamW(param_groups, lr=LEARNING_RATE, weight_decay=weight_decay)

# The scheduler is stepped once per epoch, so total_iters is the number of epochs.
scheduler = optim.lr_scheduler.PolynomialLR(optimizer,total_iters=NUM_EPOCHS,power=0.9)
# Alternative setup kept for reference:
#   optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
#   scheduler = MultiStepLR(
#           optimizer, milestones=50, gamma=0.1, verbose=True
#       )

#CE_weight = torch.ones(num_classes)*2.0
#CE_weight[0] = 0.1

# Dataloading
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

In [17]:
# Naming used below: A = CaDIS, B = Cataract-101.
# Exactly one option is active; the others are kept commented out for reference.
# Validation loaders use drop_last=False so that no validation sample is silently
# skipped when the dataset size is not a multiple of the batch size.

# Option 1: model pretrained on ADE20K, finetune on only A, test A and B
train_loader_A = DataLoader(cadis_train_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=N_WORKERS, drop_last=DROP_LAST, pin_memory=True)
val_loader_A=DataLoader(cadis_val_dataset,batch_size=2,shuffle=False,num_workers=N_WORKERS,drop_last=False)
test_loader_A=DataLoader(cadis_test_dataset,batch_size=1,shuffle=False,num_workers=N_WORKERS,drop_last=False)
test_loader_B=DataLoader(cataract_101_test_dataset,batch_size=1,shuffle=False,num_workers=N_WORKERS,drop_last=False)
model_name="m2f_cadis"
val_dataset_name="cadis"

# Option 2: model pretrained on A, finetune on  A+B, test A and B
# train_loader_fully_merged = DataLoader(fully_merged_train_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=N_WORKERS, drop_last=DROP_LAST, pin_memory=True)
# val_loader_fully_merged=DataLoader(fully_merged_val_dataset,batch_size=2,shuffle=False,num_workers=N_WORKERS,drop_last=False)
# test_loader_A=DataLoader(cadis_test_dataset,batch_size=1,shuffle=False,num_workers=N_WORKERS,drop_last=False)
# test_loader_B=DataLoader(cataract_101_test_dataset,batch_size=1,shuffle=False,num_workers=N_WORKERS,drop_last=False)
# model_name="m2f_cadis+cataract101"
# val_dataset_name="cadis+cataract101"

# Option 3: model pretrained on A, finetune on B, test A and B
# train_loader_B = DataLoader(cataract_101_train_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=N_WORKERS, drop_last=DROP_LAST, pin_memory=True)
# val_loader_B=DataLoader(cataract_101_val_dataset,batch_size=2,shuffle=False,num_workers=N_WORKERS,drop_last=False)
# test_loader_A=DataLoader(cadis_test_dataset,batch_size=1,shuffle=False,num_workers=N_WORKERS,drop_last=False)
# test_loader_B=DataLoader(cataract_101_test_dataset,batch_size=1,shuffle=False,num_workers=N_WORKERS,drop_last=False)
# model_name="m2f_cataract101"
# val_dataset_name="cataract101"

# Option 4: model pretrained on A, finetune on rand(A)+B, test A and B
# train_loader_replayA_B = DataLoader(replayA_B_train_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=N_WORKERS, drop_last=DROP_LAST, pin_memory=True)
# val_loader_replayA_B=DataLoader(replayA_B_val_dataset,batch_size=2,shuffle=False,num_workers=N_WORKERS,drop_last=False)
# test_loader_A=DataLoader(cadis_test_dataset,batch_size=1,shuffle=False,num_workers=N_WORKERS,drop_last=False)
# test_loader_B=DataLoader(cataract_101_test_dataset,batch_size=1,shuffle=False,num_workers=N_WORKERS,drop_last=False)
# model_name="m2f_replayCadis+cataract101"
# val_dataset_name="replayCadis+cataract101"

# Loaders consumed by the training and test cells (currently Option 1).
train_loader=train_loader_A
val_loader=val_loader_A
test_loaders=[test_loader_A,test_loader_B]

# Train

In [None]:
# Event files are written to <out_dir>/runs, which is the directory the
# `%tensorboard --logdir` call earlier in the notebook is watching.
writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs"))

best_val_metric=-np.inf
# Validation (and best-model checkpointing) runs on every 5th epoch, counted from 1,
# so with NUM_EPOCHS = 10 the final epoch is validated as well.
VAL_EVERY_N_EPOCHS = 5

model_dir=out_dir+"models/"
best_model_dir=model_dir+f"{model_name}/best_models/"
final_model_dir=model_dir+f"{model_name}/final_model/"


def _ensure_dir(path, description):
    """Create `path` if it does not exist yet and report where `description` is stored."""
    if not os.path.exists(path):
        print(f"Store {description} in: ", path)
        os.makedirs(path, exist_ok=True)


def _forward_with_labels(batch_images, batch_masks):
    """Run the model on one batch, passing ground-truth labels so `outputs.loss` is set.

    The image processor turns the segmentation maps into the `mask_labels` /
    `class_labels` the model needs to compute its loss. Pre-processing is done on
    the tensors as delivered by the DataLoader; only the encoded tensors are moved
    to `device`.
    NOTE(review): assumes `batch_masks` holds one 2-D label map per image - confirm
    against the dataset classes.
    """
    encoded = image_processor(
        batch_images,
        segmentation_maps=[mask for mask in batch_masks],
        return_tensors="pt",
        do_rescale=False,
    )
    pixel_mask = encoded.get("pixel_mask")
    return model(
        pixel_values=encoded["pixel_values"].to(device),
        pixel_mask=pixel_mask.to(device) if pixel_mask is not None else None,
        mask_labels=[labels.to(device) for labels in encoded["mask_labels"]],
        class_labels=[labels.to(device) for labels in encoded["class_labels"]],
    )


def _add_batch_to_metric(batch_outputs, batch_masks):
    """Convert model outputs to label maps and add them, with the references, to `metric`."""
    # Predictions are resized to the spatial size of their reference mask.
    target_sizes = [tuple(mask.shape[-2:]) for mask in batch_masks]
    # No gradients are needed for metric bookkeeping.
    with torch.no_grad():
        pred_maps = image_processor.post_process_semantic_segmentation(
            batch_outputs, target_sizes=target_sizes
        )
    metric.add_batch(
        references=[mask.cpu().numpy() for mask in batch_masks],
        predictions=[pred.cpu().numpy() for pred in pred_maps],
    )


# Each directory is checked individually before being created.
_ensure_dir(model_dir, "weights")
_ensure_dir(best_model_dir, "best model weights")
_ensure_dir(final_model_dir, "final model weights")

# The model has to live on the same device as the tensors fed to it.
model.to(device)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    train_samples_seen = 0

    for images, masks in train_loader:
        optimizer.zero_grad()
        outputs = _forward_with_labels(images, masks)
        loss=outputs.loss

        loss.backward()
        optimizer.step()

        # Weight by batch size; normalise by the samples actually processed,
        # because the train loader may drop an incomplete last batch.
        running_loss += loss.item() * images.size(0)
        train_samples_seen += images.size(0)

        _add_batch_to_metric(outputs, masks)

    train_epoch_miou = metric.compute(num_labels=num_classes, ignore_index=255, reduce_labels=True)['mean_iou']
    train_epoch_loss = running_loss / max(train_samples_seen, 1)
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {train_epoch_loss:.4f}')
    writer.add_scalar(f'Loss/train_{model_name}', train_epoch_loss, epoch+1)
    writer.add_scalar(f'mIoU/train_{model_name}', train_epoch_miou, epoch+1)
    print(f'mIoU/train_{model_name}:{train_epoch_miou}, Epoch:{epoch+1}')

    if (epoch + 1) % VAL_EVERY_N_EPOCHS == 0:
        running_loss_val=0.
        val_samples_seen = 0
        model.eval()
        with torch.no_grad():

            for val_images,val_masks in val_loader:
                outputs_val = _forward_with_labels(val_images, val_masks)
                val_loss=outputs_val.loss
                running_loss_val += val_loss.item() * val_images.size(0)
                val_samples_seen += val_images.size(0)
                _add_batch_to_metric(outputs_val, val_masks)

        val_epoch_loss = running_loss_val / max(val_samples_seen, 1)
        val_epoch_miou = metric.compute(num_labels=num_classes, ignore_index=255, reduce_labels=True)['mean_iou']

        print(f'Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {val_epoch_loss:.4f}')
        writer.add_scalar(f'mIoU/val_{val_dataset_name}', val_epoch_miou, epoch+1)
        writer.add_scalar(f'Loss/val_{val_dataset_name}', val_epoch_loss, epoch+1)
        print(f'mIoU/val_{model_name}:{val_epoch_miou}, Epoch:{epoch+1}')

        # Keep the checkpoint with the best validation mIoU seen so far.
        if val_epoch_miou>best_val_metric:
            best_val_metric=val_epoch_miou
            model.save_pretrained(best_model_dir)

    # One scheduler step per epoch (the scheduler's total_iters is NUM_EPOCHS).
    scheduler.step()

# Make sure everything logged so far reaches disk; the writer is closed after testing.
writer.flush()

# Save final model.
model.save_pretrained(final_model_dir)
print('TRAINING COMPLETE')


# Test

In [None]:
# Labels for logging, in the same order as `test_loaders`
# ([test_loader_A, test_loader_B] = [CaDIS test set, Cataract-101 test split]).
test_dataset_names = ["cadis", "cataract101"]
assert len(test_dataset_names) == len(test_loaders), "one name per test loader is required"

# The model has to live on the same device as the tensors fed to it.
model.to(device)
model.eval()
with torch.no_grad():
    for test_dataset_name, test_loader in zip(test_dataset_names, test_loaders):

        for test_images,test_masks in test_loader:
            # Pre-process the batch as delivered by the DataLoader, then move the
            # encoded tensors to the model's device.
            test_inputs = image_processor(test_images, return_tensors="pt",do_rescale=False).to(device)
            outputs_test = model(**test_inputs)
            # Predictions are resized to the spatial size of their reference mask.
            target_sizes = [tuple(mask.shape[-2:]) for mask in test_masks]
            pred_maps = image_processor.post_process_semantic_segmentation(outputs_test, target_sizes=target_sizes)
            metric.add_batch(
                references=[mask.cpu().numpy() for mask in test_masks],
                predictions=[pred.cpu().numpy() for pred in pred_maps],
            )

        # compute() also resets the metric, so each dataset gets its own score.
        test_miou = metric.compute(num_labels=num_classes, ignore_index=255, reduce_labels=True)['mean_iou']

        writer.add_scalar(f'mIoU/test_{test_dataset_name}', test_miou)
        print(f'mIoU/test_{test_dataset_name}:', test_miou)


writer.close()

In [31]:
# Sample ADE20K validation image used for a quick sanity check.
url = (
    "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
)
# A timeout prevents the cell from hanging forever, and raise_for_status() turns an
# HTTP error into a clear exception instead of a confusing image-decoding failure.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image=Image.open(response.raw)
image.size

(683, 512)

In [33]:

from torchvision import transforms

# One-step pipeline: ToTensor converts the PIL image into a torch tensor
# (the displayed output shows float values between 0 and 1).
pipeline_steps = [transforms.ToTensor()]
transform = transforms.Compose(pipeline_steps)
image = transform(image)
image

tensor([[[0.3373, 0.3333, 0.3333,  ..., 0.3059, 0.3059, 0.3059],
         [0.3373, 0.3333, 0.3333,  ..., 0.3059, 0.3059, 0.3059],
         [0.3373, 0.3333, 0.3333,  ..., 0.3059, 0.3059, 0.3059],
         ...,
         [0.3490, 0.3608, 0.3608,  ..., 0.4471, 0.4471, 0.4510],
         [0.3686, 0.3725, 0.3686,  ..., 0.4588, 0.4627, 0.4667],
         [0.3686, 0.3686, 0.3569,  ..., 0.4588, 0.4627, 0.4667]],

        [[0.5412, 0.5373, 0.5373,  ..., 0.5059, 0.5059, 0.5059],
         [0.5412, 0.5373, 0.5373,  ..., 0.5059, 0.5059, 0.5059],
         [0.5412, 0.5373, 0.5373,  ..., 0.5059, 0.5059, 0.5059],
         ...,
         [0.3765, 0.3882, 0.3882,  ..., 0.4353, 0.4353, 0.4392],
         [0.3922, 0.3961, 0.3922,  ..., 0.4471, 0.4510, 0.4549],
         [0.3922, 0.3922, 0.3804,  ..., 0.4471, 0.4510, 0.4549]],

        [[0.7294, 0.7255, 0.7333,  ..., 0.7294, 0.7294, 0.7294],
         [0.7294, 0.7255, 0.7333,  ..., 0.7294, 0.7294, 0.7294],
         [0.7294, 0.7255, 0.7333,  ..., 0.7294, 0.7294, 0.