## Setup

In [2]:
import mlflow

  """Entry point for launching an IPython kernel.


In [3]:
# path_uri = '/content/gdrive/MyDrive/University/FEA_Internship/boxshrink/ML_FLOW'
import os
from getpass import getpass

os.environ['MLFLOW_TRACKING_USERNAME'] = input('Enter your DAGsHub username: ')
os.environ['MLFLOW_TRACKING_PASSWORD'] = getpass('Enter your DAGsHub access token: ')
os.environ['MLFLOW_TRACKING_PROJECTNAME'] = input('Enter your DAGsHub project name: ')
mlflow.set_registry_uri(f'https://dagshub.com/' + os.environ['MLFLOW_TRACKING_USERNAME'] 
                        + '/' + os.environ['MLFLOW_TRACKING_PROJECTNAME'] + '.mlflow')
mlflow.set_tracking_uri(f'https://dagshub.com/' + os.environ['MLFLOW_TRACKING_USERNAME'] 
                        + '/' + os.environ['MLFLOW_TRACKING_PROJECTNAME'] + '.mlflow')

In [3]:
from mlflow.tracking import MlflowClient

print("Tracking Location: {}".format(mlflow.get_tracking_uri()))
print("Registry Location: {}".format(mlflow.get_registry_uri()))

Tracking Location: https://dagshub.com/michaelgroeger/testing.mlflow
Registry Location: https://dagshub.com/michaelgroeger/testing.mlflow


In [4]:
# Import custom function
from tools import (
    return_files_in_directory,
    decode_segmap,
    get_classes_from_mask,
    rgb_to_mask,
    visualize,
    return_batch_information,
    flatten,
    human_sort
    )

from dataset import Colonoscopy_Dataset

In [5]:
# Import libraries
import torch
import os
import cv2
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import compute_unary, unary_from_softmax

from torchmetrics import JaccardIndex
import time

from skimage.color import rgb2gray
from skimage.filters import sobel
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

from torch.utils.data import DataLoader, Subset
from torch.nn import CrossEntropyLoss
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, ToPILImage
from torch.optim.lr_scheduler import StepLR, ExponentialLR
import copy
from tifffile import imread

## Training Parameters

In [12]:
import config

## Data

In [13]:
# load repo with data if it is not exists
if not os.path.exists(DATA_DIR):
    print("Couldn't find data")
else:
    print("Found data")

Found data


In [14]:
image_files = return_files_in_directory(DATA_DIR + "/Original", ".tif")
mask_files = return_files_in_directory(DATA_DIR + "/Ground Truth", ".tif")
box_files = return_files_in_directory(DATA_DIR + "/boxmasks", ".png")
sp_crf_files = return_files_in_directory(DATA_DIR + "rapid_boxshrink", ".png")
embedding_files = return_files_in_directory(DATA_DIR + "robust_boxshrink", ".png")

In [16]:
human_sort(image_files)
human_sort(mask_files)
human_sort(box_files)
# human_sort(crf_files)
# human_sort(sp_crf_files)
# human_sort(embedding_files)

NameError: name 'crf_files' is not defined

In [None]:
from sklearn.model_selection import train_test_split
# TODO: Check if the random state is persistant across restarts
if TRAINING_INPUT == "Boxes":
  X_train, X_test, y_train, y_test = train_test_split(image_files, box_files, test_size=0.1, random_state=1)
elif TRAINING_INPUT == "rapid_boxshrink":
  X_train, X_test, y_train, y_test = train_test_split(image_files, sp_crf_files, test_size=0.1, random_state=1)
elif TRAINING_INPUT == "robust_boxshrink":
  X_train, X_test, y_train, y_test = train_test_split(image_files, embedding_files, test_size=0.1, random_state=1)
else:
  X_train, X_test, y_train, y_test = train_test_split(image_files, mask_files, test_size=0.1, random_state=1)

X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.11111, random_state=1) # 0.1111 x 0.9 = 0.1

In [None]:
TRAINING_INPUT

'embedding_masks'

In [None]:
# Eval on ground truth masks
if TRAINING_INPUT == "Boxes":
    y_val = [i.replace("boxmasks", 'Ground Truth').replace('png', 'tif') for i in y_val]
    y_test = [i.replace("boxmasks", 'Ground Truth').replace('png', 'tif') for i in y_test]
elif TRAINING_INPUT == "rapid_boxshrink":
    y_val = [i.replace("INPUT_PATH", 'Ground Truth').replace('png', 'tif') for i in y_val]
    y_test = [i.replace("INPUT_PATH", 'Ground Truth').replace('png', 'tif') for i in y_test]
elif TRAINING_INPUT == "robust_boxshrink":
    y_val = [i.replace("INPUT_PATH", 'Ground Truth').replace('png', 'tif') for i in y_val]
    y_test = [i.replace("INPUT_PATH", 'Ground Truth').replace('png', 'tif') for i in y_test]

In [None]:
VALIDATE_TRAINING_TARGETS = '_'.join(y_train[0].split('/')[-4:])
VALIDATE_TRAINING_TARGETS

'colonoscopy_crf_masks_adjusted_embeddings_350s_0t_it_03_mot_005_iter_1_the_005_thsp_0_foreground_embeddings_ns250_th01_sp_crf_masks_pbsxy2525_pbsrb10_pgsxy5_519.png'

In [None]:
VALIDATE_VAL_TARGETS = '_'.join(y_val[0].split('/')[-4:])
VALIDATE_VAL_TARGETS

'datasets_colonoscopy_Ground Truth_529.tif'

In [None]:
VALIDATE_TEST_TARGETS = '_'.join(y_test[0].split('/')[-4:])
VALIDATE_TEST_TARGETS

'datasets_colonoscopy_Ground Truth_494.tif'

In [None]:
train_dataset = Colonoscopy_Dataset(
    X_train, 
    y_train,
    limit_dataset_size=256
)

test_dataset = Colonoscopy_Dataset(
    X_test, 
    y_test,
    # limit_dataset_size=64
)

val_dataset = Colonoscopy_Dataset(
    X_val, 
    y_val,
    # limit_dataset_size=64
)

In [None]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

## Training setup

In [None]:
import segmentation_models_pytorch as smp

# create segmentation model with pretrained encoder
if DECODER == 'Unet':
    model = smp.Unet(
        encoder_name=ENCODER, 
        encoder_weights=ENCODER_WEIGHTS, 
        classes=len(CLASSES), 
        activation=ACTIVATION,
        )
    
elif DECODER == 'DeepLabV3+':
    model = smp.Unet(
        encoder_name=ENCODER, 
        encoder_weights=ENCODER_WEIGHTS, 
        classes=len(CLASSES), 
        activation=ACTIVATION,
        )
        
if OPTIMIZER == "SGD":
  optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
if OPTIMIZER == "Adam":
  optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Learning rate scheduling
if LEARNING_RATE_SCHEDULING == True and SCHEDULE_TYPE == "STEP":
  scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA, verbose=True)
elif LEARNING_RATE_SCHEDULING == True and SCHEDULE_TYPE == "EXPONENTIAL":
  STEP_SIZE = 'Not needed'
  scheduler = ExponentialLR(optimizer, gamma=GAMMA, verbose=True)
elif LEARNING_RATE_SCHEDULING == False:
  STEP_SIZE = 'No scheduling'
  GAMMA = 'No scheduling'
  SCHEDULE_TYPE = 'No scheduling'

Adjusting learning rate of group 0 to 1.0000e-04.


In [None]:
%%capture


label_colors = np.array(
    [(0,0,0), (128,128,128)]
)

jaccard = JaccardIndex(num_classes=len(CLASSES), reduction='elementwise_mean').to(device)



def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


# Setup date for model name
from datetime import date
today = date.today()
datestring = today.strftime("%Y-%m-%d")
# Instantiate accuracy and loss tracker
best_train_loss = float('inf')
best_valid_loss = float('inf')
best_valid_iou = 0
best_train_iou = 0

# Build dataframe to collect loss and metric data
df_train = pd.DataFrame(columns=['epoch', 'loss', 'avg_loss', 'mean_iou'])
df_val = pd.DataFrame(columns=['epoch', 'loss', 'avg_loss', 'mean_iou'])
# # Determine column types train
# Dummy entry to prevent visualization bug that large values are plotted as zero
if LOSS == "CrossEntropyLoss":
    criterion = CrossEntropyLoss()
    criterion_double = CrossEntropyLoss()
model.to(device)

## Additional Visualization 

In [None]:
from pytorch_grad_cam.utils.image import show_cam_on_image

# useful function from prior experiments
# used to show image alongside label or prediction mask
def show_grad_cam_on_img(org_image, cam):
    rgb_img = np.float32(org_image) / 255
    cam_rgb = show_cam_on_image(rgb_img, cam, use_rgb=True)
    return cam_rgb

## Train Loop

In [None]:
# Write params to ML Flow
# Start measuring time to train
# End MLflow run if there is one
mlflow.end_run()

# Log params for ML flow
params = {'state': STATE,
          'encoder': ENCODER,
          'decoder': DECODER,
          'activation': ACTIVATION,
          'weights': ENCODER_WEIGHTS,
          'batch_size': train_loader.batch_size, 
          'optimizer_name':OPTIMIZER, 
          'learning_rate': optimizer.defaults['lr'],
          'learning_rate_scheduler': SCHEDULE_TYPE,
          'step_size': STEP_SIZE,
          'gamma': GAMMA,
          'weight_decay': optimizer.defaults['weight_decay'], 
          'train_dataset_size': len(train_loader.dataset), 
          'valid_dataset_size': len(val_loader.dataset),
          'eval_on_test_set': EVAL_ON_MASKS,
          'training_input': TRAINING_INPUT, 
          'mask_occupancy_threshold': MASK_OCCUPANCY_THRESHOLD,
          'iou_threshold': IOU_THRESHOLD,
          'export_best_model': EXPORT_BEST_MODEL,
          'epochs': N_EPOCHS,
          'y_train': VALIDATE_TRAINING_TARGETS,
          'y_test': VALIDATE_TEST_TARGETS,
          'y_val': VALIDATE_VAL_TARGETS,
          }

# Build the run
number_of_runs_in_experiment = len(mlflow.search_runs())
run_name = "_".join([STATE, DECODER, ENCODER, str(number_of_runs_in_experiment+1)])

# with mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name):
with mlflow.start_run(run_name=run_name):
    # Write params to ML Flow
    mlflow.log_params(params)
    torch.backends.cudnn.benchmark = True
    train_start_time = time.time()
    train_iou_score = torch.tensor([0])
    early_stopped = 0 
    for epoch in range(START_EPOCH, N_EPOCHS):
        model.train()
        batch, running_epoch_iou, running_epoch_loss = 0, 0.0, 0.0
        with tqdm(train_loader, unit="batch") as tepoch:
            for train_inputs, train_labels, train_org_images in tepoch:
                batch += 1
                optimizer.zero_grad(set_to_none=True)
                tepoch.set_description(f"Epoch {epoch}")
                train_inputs, train_labels, train_org_images = train_inputs.to(device), train_labels.to(device), train_org_images.to(device)
                # forward
                train_outputs = model(train_inputs).to(device)
                out_max = torch.argmax(train_outputs, dim=1, keepdim=True)[: , -1, :, :].cpu().detach().numpy()
                # Backward
                train_loss = criterion(train_outputs, train_labels)
                train_loss.backward()
                if epoch % PER_X_EPOCH == 0 and batch % PER_X_BATCH == 0:
                    return_batch_information(train_org_images, out_max, train_labels, 1, CLASSES, label_colors=label_colors)
                optimizer.step()
                model.eval()
                train_iou_score =  jaccard(train_outputs, train_labels).to(device).item()
                model.train()
                running_epoch_iou += train_iou_score
                train_loss = float(train_loss.item())
                running_epoch_loss += train_loss
                # print statistics
                tepoch.set_postfix(phase="Training", loss=train_loss, iou=train_iou_score, epoch_iou = running_epoch_iou / batch, epoch_loss = running_epoch_loss / batch)
            train_mean_epoch_iou, train_mean_epoch_loss = running_epoch_iou / batch, running_epoch_loss / batch
        if best_train_loss > train_mean_epoch_loss:
          best_train_loss = train_mean_epoch_loss
        if best_train_iou < train_mean_epoch_iou:
          best_train_iou = train_mean_epoch_iou
        # Save results to dataframe
        if epoch == 0:
          train_row = {'epoch': int(epoch), 'loss': float(train_mean_epoch_loss), 'avg_loss': float(train_mean_epoch_loss),'mean_iou': float(train_mean_epoch_iou)}
        else:
          # Get moving average
          train_avg = df_train['loss'].ewm(com=0.99).mean()
          train_row = {'epoch': int(epoch), 'loss': float(train_loss), 'avg_loss': train_avg[(epoch-1)],'mean_iou': train_mean_epoch_iou}
        
        # Send logs to ML flow
        mlflow.log_metric(key="train_loss", value=train_mean_epoch_loss, step=epoch)
        mlflow.log_metric(key="train_iou", value=train_mean_epoch_iou, step=epoch)
        df_train = df_train.append(train_row, ignore_index=True)
        # Decay Learning Rate at x steps
        if LEARNING_RATE_SCHEDULING == True:
            scheduler.step()
        # Delete variables to free memory
        del running_epoch_iou, running_epoch_loss, train_loss, train_iou_score

    ### Running validation loop
        batch, running_epoch_iou, running_epoch_loss = 0, 0.0, 0.0
        model.eval()
        with torch.no_grad():
            with tqdm(val_loader, unit="batch") as tepoch:
                for val_inputs, val_labels, val_org_images in tepoch:
                    batch += 1
                    tepoch.set_description(f"Epoch {epoch}")
                    val_inputs, val_labels, val_org_images = val_inputs.to(device), val_labels.to(device), val_org_images.to(device)
                    # forward 
                    val_outputs = model(val_inputs)
                    # Collect metrics
                    val_iou_score = jaccard(val_outputs, val_labels).item()
                    val_loss = criterion(val_outputs, val_labels).item()
                    # Collect data for dataframe
                    running_epoch_iou += val_iou_score
                    running_epoch_loss += val_loss
                    # print statistics
                    tepoch.set_postfix(phase="Validation", loss=val_loss, iou=val_iou_score, epoch_iou = running_epoch_iou / batch, epoch_loss = running_epoch_loss / batch)
            val_mean_epoch_iou, val_mean_epoch_loss = running_epoch_iou / batch, running_epoch_loss / batch
            # Save results to dataframe
            if epoch == 0:
              val_row = {'epoch': int(epoch), 'loss': float(val_mean_epoch_loss), 'avg_loss': float(val_mean_epoch_loss),'mean_iou': val_mean_epoch_iou}
            else:
              val_avg = df_val['loss'].ewm(com=0.99).mean()
              val_row = {'epoch': int(epoch), 'loss': float(val_loss), 'avg_loss': val_avg[(epoch-1)],'mean_iou': val_mean_epoch_iou}
            df_val = df_val.append(val_row, ignore_index=True)
            if best_valid_loss > val_mean_epoch_loss:
              # Update best metrics
              best_valid_loss = val_mean_epoch_loss
            if best_valid_iou < val_mean_epoch_iou: 
              best_valid_iou = val_mean_epoch_iou
              best_model = copy.deepcopy(model)
              if epoch > 2 and EXPORT_BEST_MODEL == True:
                model_name = "_".join([datestring, STATE, MODE, DECODER, OPTIMIZER, LOSS, ENCODER, str(len(train_dataset)), "images", LOSS, "loss", str(best_valid_loss).replace(".","_"), "iou", str(val_mean_epoch_iou), "epoch", str(epoch), ".pth"])
                # save model
                path = os.path.join(BEST_MODEL_DIR, model_name)
                torch.save(best_model.state_dict(), path)
                print(f'Model saved! Name is {model_name}')          
            if epoch % PER_X_EPOCH_PLOT == 0:
              plt.plot(df_train['epoch'], df_train['avg_loss'], label = "Train Loss")
              plt.plot(df_val['epoch'], df_val['avg_loss'], label = "Valid Loss")
              plt.plot(df_val['epoch'], df_val['mean_iou'], label = "Mean IoU")
              plt.legend()
              plt.title('Performance')
              plot = plt.gcf()
              plt.show()
            train_df_name = "_".join([datestring, "train", MODE, DECODER, OPTIMIZER, LOSS, ENCODER, str(len(train_dataset)), "images", ".csv"])
            valid_df_name = "_".join([datestring, "valid", MODE, DECODER, OPTIMIZER, LOSS, ENCODER, str(len(train_dataset)), "images", ".csv"])
            df_train.to_csv(os.path.join(EXPORT_CSV_DIR, train_df_name))
            df_val.to_csv(os.path.join(EXPORT_CSV_DIR, valid_df_name))
            mlflow.log_metric(key="valid_loss", value=val_mean_epoch_loss, step=epoch)
            mlflow.log_metric(key="valid_iou", value=val_mean_epoch_iou, step=epoch)
            if epoch > 5:
                last_runs = df_train['loss'][-5:]
                # Get min and max of that window
                min_loss_last_runs = last_runs.min()
                max_loss_last_runs = last_runs.max()
                difference = max_loss_last_runs - min_loss_last_runs
                if difference < 0.001:
                  print("Stopped Training because it doesn't improve anymore.")
                  train_end_time = time.time()
                  # Get minutes and seconds to write to ML flow
                  train_mins, train_secs = epoch_time(train_start_time, train_end_time)
                  mlflow.log_param("train_time", f"{train_mins} min, {train_secs} sec")
                  # Get run information and return to window
                  run = mlflow.active_run()
                  print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))
                  # End run and get status
                  mlflow.log_metric(key="best_valid_loss", value=best_valid_loss)
                  mlflow.log_metric(key="best_train_loss", value=best_train_loss)
                  mlflow.log_metric(key="best_valid_iou", value=best_valid_iou)
                  mlflow.log_metric(key="best_train_iou", value=best_train_iou)
                  batch, running_epoch_iou, running_epoch_loss = 0, 0.0, 0.0
                  best_model.eval()
                  with torch.no_grad():
                      with tqdm(test_loader, unit="batch") as tepoch:
                          for test_inputs, test_labels, test_org_images in tepoch:
                              batch += 1
                              tepoch.set_description(f"Epoch {epoch}")
                              test_inputs, test_labels, test_org_images = test_inputs.to(device), test_labels.to(device), test_org_images.to(device)
                              # forward 
                              test_outputs = best_model(test_inputs)
                              # Collect metrics
                              test_iou_score = jaccard(test_outputs, test_labels).item()
                              test_loss = criterion(test_outputs, test_labels).item()
                              # Collect data for dataframe
                              running_epoch_iou += test_iou_score
                              running_epoch_loss += test_loss
                              # print statistics
                              tepoch.set_postfix(phase="Validation", loss=test_loss, iou=test_iou_score, epoch_iou = running_epoch_iou / batch, epoch_loss = running_epoch_loss / batch)
                      test_mean_epoch_iou, test_mean_epoch_loss = running_epoch_iou / batch, running_epoch_loss / batch
                      mlflow.log_metric(key="test_iou", value=test_mean_epoch_iou)
                      mlflow.log_metric(key="test_loss", value=test_mean_epoch_loss)
                  print("")
                  print(f"Performance on test set: {val_mean_epoch_iou} IoU and {val_mean_epoch_loss} Loss")
                  mlflow.log_params({'best_model': model_name})
                  mlflow.end_run()
                  run = mlflow.get_run(run.info.run_id)
                  print(f"Training time was {train_mins, train_secs}")
                  print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))
                  print("--")

                  # Check for any active runs
                  print("Active run: {}".format(mlflow.active_run()))
                  early_stopped += 1
                  break
    if early_stopped == 0:
        train_end_time = time.time()
        # Get minutes and seconds to write to ML flow
        train_mins, train_secs = epoch_time(train_start_time, train_end_time)
        mlflow.log_param("train_time", f"{train_mins} min, {train_secs} sec")
        # Get run information and return to window
        run = mlflow.active_run()
        print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))
        # End run and get status
        mlflow.log_metric(key="best_valid_loss", value=best_valid_loss)
        mlflow.log_metric(key="best_train_loss", value=best_train_loss)
        mlflow.log_metric(key="best_valid_iou", value=best_valid_iou)
        mlflow.log_metric(key="best_train_iou", value=best_train_iou)
        batch, running_epoch_iou, running_epoch_loss = 0, 0.0, 0.0
        best_model.eval()
        with torch.no_grad():
            with tqdm(test_loader, unit="batch") as tepoch:
                for test_inputs, test_labels, test_org_images in tepoch:
                    batch += 1
                    tepoch.set_description(f"Epoch {epoch}")
                    test_inputs, test_labels, test_org_images = test_inputs.to(device), test_labels.to(device), test_org_images.to(device)
                    # forward 
                    test_outputs = best_model(test_inputs)
                    # Collect metrics
                    test_iou_score = jaccard(test_outputs, test_labels).item()
                    test_loss = criterion(test_outputs, test_labels).item()
                    # Collect data for dataframe
                    running_epoch_iou += test_iou_score
                    running_epoch_loss += test_loss
                    # print statistics
                    tepoch.set_postfix(phase="Validation", loss=test_loss, iou=test_iou_score, epoch_iou = running_epoch_iou / batch, epoch_loss = running_epoch_loss / batch)
            test_mean_epoch_iou, test_mean_epoch_loss = running_epoch_iou / batch, running_epoch_loss / batch
            mlflow.log_metric(key="test_iou", value=test_mean_epoch_iou)
            mlflow.log_metric(key="test_loss", value=test_mean_epoch_loss)
            print("")
            print(f"Performance on test set: {test_mean_epoch_iou} IoU and {test_mean_epoch_loss} Loss")
        mlflow.log_params({'best_model': model_name})
        mlflow.end_run()
        run = mlflow.get_run(run.info.run_id)
        print(f"Training time was {train_mins, train_secs}")
        print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))
        print("--")
    
        # Check for any active runs
        print("Active run: {}".format(mlflow.active_run()))

In [None]:
# Play an audio beep. Any audio URL will do.
from google.colab import output
output.eval_js('new Audio("https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg").play()')

## Embeddings

In [None]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
from torchvision.models.segmentation import deeplabv3_resnet50
import torch
import torch.functional as F
import numpy as np
import requests
import cv2
import torchvision
from PIL import Image
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam import GradCAM

# A model wrapper that gets a resnet model and returns the features before the fully connected layer.
class ResnetFeatureExtractor(torch.nn.Module):
    def __init__(self, model):
        super(ResnetFeatureExtractor, self).__init__()
        self.model = model
        self.feature_extractor = torch.nn.Sequential(*list(self.model.children())[:-1])
                
    def __call__(self, x):
        return self.feature_extractor(x)[:, :, 0, 0]
        
resnet = torchvision.models.resnet50(pretrained=True)
resnet.eval()
feature_extract_model = ResnetFeatureExtractor(resnet)

In [None]:
cos = torch.nn.CosineSimilarity(dim=0)
def get_cosine_sim_score(feat_1, feat_2, cosine_fct=cos):
    return (torch.sum(cos(feat_1.squeeze(), feat_2.squeeze())))

In [None]:
%%capture
feature_extract_model.to(device)

In [None]:
def get_bbox_coordinates_one_box(tensor): 
    all_x, all_y = (tensor.squeeze() == 1).nonzero(as_tuple=True)
    smallest_x, smallest_y = torch.min(all_x).item(), torch.min(all_y).item()
    largest_x, largest_y = torch.max(all_x).item(), torch.max(all_y).item()
    return (smallest_y, smallest_x), (largest_y, largest_x)

def get_foreground_background_embeddings(argmax_prediction_per_class, org_img, train_input, threshold, N_SEGMENTS, class_indx=1,  compactness=10, sigma=1, start_label=1, device=device, model=feature_extract_model):
    # get superpixels
    org_img = org_img.cpu().detach().numpy()
    all_superpixels_mask = torch.from_numpy(slic(org_img, n_segments=N_SEGMENTS, compactness=compactness, sigma=sigma, start_label=start_label))
    visualize_superpixels(all_superpixels_mask.numpy(), img=org_img)
    hadamard = all_superpixels_mask.to(device) * argmax_prediction_per_class.to(device)
    overlap = (hadamard / class_indx).type(torch.IntTensor)
    # Instantiate base mask
    base_mask = torch.zeros(overlap.shape)
    # Get numbers to list, start from second element because first is 0 
    relevant_superpixels = torch.unique(overlap).int().tolist()[1:]
    relevant_superpixels_thresholded = []
    for superpixel in relevant_superpixels:
          temp = overlap.clone()
          org = all_superpixels_mask.clone()
        #   # Check how many are non-zero in superpixel mask
          temp[temp != superpixel] = 0
          org[org != superpixel] = 0
          # Check how many are non-zero in overlap
          # Determine share of pixels
          share = torch.count_nonzero(temp).item() / torch.count_nonzero(org).item()
          # Add superpixel as ones to base mask if share is over threshold
          if share > threshold:
            # bring org values to one
            relevant_superpixels_thresholded.append(superpixel)
    background_superpixels = [i.item() for i in torch.unique(all_superpixels_mask) if i not in relevant_superpixels_thresholded]
    foreground_embeddings = torch.zeros([len(relevant_superpixels_thresholded), 2048])
    background_embeddings = torch.zeros([len(background_superpixels), 2048])
    for i, superpixel in enumerate(relevant_superpixels_thresholded):
      all_superpixels_mask_tmp = all_superpixels_mask.clone()
      all_superpixels_mask_tmp[all_superpixels_mask_tmp != superpixel] = 0
      all_superpixels_mask_tmp[all_superpixels_mask_tmp>0] = 1
      s,l = get_bbox_coordinates_one_box(all_superpixels_mask_tmp)
      print(f"Org shape: {torch.Tensor(org_img).permute(2,0,1).shape}")
      print(f"Input shape: {train_input.shape}")
      base = torch.Tensor(org_img).permute(2,0,1).clone().cpu() / 255#train_input.clone().cpu()
      base_aspm = base.clone()

      base_aspm[0,:,:] = base_aspm[0,:,:] * all_superpixels_mask_tmp
      base_aspm[1,:,:] = base_aspm[1,:,:] * all_superpixels_mask_tmp
      base_aspm[2,:,:] = base_aspm[2,:,:] * all_superpixels_mask_tmp
      cut = base_aspm[:,s[1]:l[1],s[0]:l[0]].unsqueeze(0).to(device)
      visualize(cut=cut.cpu().squeeze(0).permute(2,1,0))
      with torch.no_grad():
        feat_foreground_sp = feature_extract_model(cut)
        foreground_embeddings[i,:] = feat_foreground_sp
    for i, superpixel in enumerate(background_superpixels):
      all_superpixels_mask_tmp = all_superpixels_mask.clone()
      all_superpixels_mask_tmp[all_superpixels_mask_tmp != superpixel] = 0
      all_superpixels_mask_tmp[all_superpixels_mask_tmp>0] = 1
      s,l = get_bbox_coordinates_one_box(all_superpixels_mask_tmp)
      base = train_input.clone().cpu()
      base_aspm = base.clone()

      base_aspm[0,:,:] = base_aspm[0,:,:] * all_superpixels_mask_tmp
      base_aspm[1,:,:] = base_aspm[1,:,:] * all_superpixels_mask_tmp
      base_aspm[2,:,:] = base_aspm[2,:,:] * all_superpixels_mask_tmp
      cut = base_aspm[:,s[1]:l[1],s[0]:l[0]].unsqueeze(0).to(device)
      with torch.no_grad():
        feat_background_sp = model(cut)
        background_embeddings[i,:] = feat_background_sp
    return foreground_embeddings, background_embeddings, relevant_superpixels_thresholded, all_superpixels_mask


In [None]:
## Get mean embeddings

# foreground_embeddings = torch.zeros([len(train_dataset), 2048])
# backgound_embeddings = torch.zeros([len(train_dataset), 2048])
# counter = 0
# for epoch in range(0, 1):
#   batch = 0
#   with tqdm(train_loader, unit="batch") as tepoch:
#     for train_inputs, train_labels, train_org_images in tepoch:
#         batch += 1
#         tepoch.set_description(f"Epoch {epoch}")
#         train_inputs, train_labels, train_org_images = train_inputs.to(device), train_labels.to(device), train_org_images.to(device)
#         for i in range(0,train_inputs.shape[0],1):
#           i = 7
#           N_SEGMENTS = 250
#           threshold = 0.10
#           vis = train_org_images[i].squeeze(0).permute(0,1,2).cpu()
#           print(vis.shape)
#           print(f"train_input: {train_inputs.shape}")
#           print(f"train_org_images: {train_org_images.shape}")
#           visualize(img=vis)
#           embed_f, embed_b, rs, aspm = get_foreground_background_embeddings(train_labels[i], train_org_images[i], train_inputs[i], N_SEGMENTS=N_SEGMENTS, threshold=threshold)
#           break
#         break
#   break
          # mean_f = torch.mean(embed_f, dim=0)
          # mean_b = torch.mean(embed_b, dim=0)

          # foreground_embeddings[counter,:] += mean_f.cpu()
          # backgound_embeddings[counter,:] += mean_b.cpu()
          # counter += 1
          # break
          # if counter == 477:
          #   print("if-break")
          #   # torch.save(foreground_embeddings, 'PATH_TO_EMBEDDINGS')
          #   # torch.save(backgound_embeddings, 'PATH_TO_EMBEDDINGS')
          #   break
        # torch.save(foreground_embeddings, 'PATH_TO_EMBEDDINGS')
        # torch.save(backgound_embeddings, 'PATH_TO_EMBEDDINGS')

f = torch.load('PATH_TO_MEAN_EMBEDDING_VECTOR')
b = torch.load('PATH_TO_MEAN_EMBEDDING_VECTOR')

In [None]:
def assign_foreground_sp(cosine_fct, mean_foreground_embedding, mean_background_embedding, relevant_superpixels_thresholded, foreground_embeddings, threshold):
    close_f_foreground_embeddings = []
    for i in range(foreground_embeddings.shape[0]):
        f_cos = cosine_fct(mean_foreground_embedding, foreground_embeddings[i])
        b_cos = cosine_fct(mean_background_embedding, foreground_embeddings[i])
        diff = abs(b_cos - f_cos)
        if (f_cos > b_cos or f_cos == b_cos) and diff <= threshold:
          close_f_foreground_embeddings.append(relevant_superpixels_thresholded[i])
    return close_f_foreground_embeddings

def scan_outer_boundary(superpixel_mask, flatten_list_fct=flatten):
    # We will collect those superpixels which are on the outer boundary so they 
    # have a zero neighbour.
    superpixel_mask_refined = superpixel_mask.clone()
    outer_superpixel_rows, outer_superpixel_cols, outer_all = [], [], []
    tuples = torch.nonzero(superpixel_mask_refined)
    rows = torch.unique(tuples[:,0])
    columns = torch.unique(tuples[:,1])
    # scan over rows 
    for i in rows:
      current_row = superpixel_mask_refined[i, :]
      unique_non_zeroed_row = torch.unique(current_row[current_row.nonzero(as_tuple=True)], sorted=False)
      first_superpixel = unique_non_zeroed_row[-1]
      last_superpixel = unique_non_zeroed_row[0]
      outer_superpixel_rows.append(first_superpixel.item())
      outer_superpixel_rows.append(last_superpixel.item())
    # scan over columns
    for i in columns:
      current_column = superpixel_mask_refined[:, i]
      unique_non_zeroed_column = torch.unique(current_column[current_column.nonzero(as_tuple=True)], sorted=False)
      first_superpixel = unique_non_zeroed_column[-1]
      last_superpixel = unique_non_zeroed_column[0]
      outer_superpixel_cols.append(first_superpixel.item())
      outer_superpixel_cols.append(last_superpixel.item())
    outer_all.append(outer_superpixel_cols)
    outer_all. append(outer_superpixel_rows)
    outer_all = flatten(outer_all)
    return list(set(outer_all))

def create_embedding_mask(train_label, 
                          train_org_image, 
                          train_input, 
                          N_SEGMENTS, 
                          threshold_embedding=0, 
                          threshold_closeness=0, 
                          mean_foreground_embedding=mean_f, 
                          mean_background_embedding=mean_b, 
                          cosine_function=get_cosine_sim_score, 
                          assign_label_based_on_closeness=assign_foreground_sp, 
                          get_foreground_background_embeddings_function=get_foreground_background_embeddings, 
                          scan_outer_pixels=True,
                          scan_outer_superpixels_function=scan_outer_boundary,
                          postprocess_crf=True,
                          iter=1):
    foreground_embeddings, background_embeddings, relevant_superpixels, all_superpixels_mask = get_foreground_background_embeddings_function(train_label, train_org_image, train_input, N_SEGMENTS=N_SEGMENTS, threshold=threshold_embedding)
    for i in range(0, iter, 1):
      not_in_relevant_superpixels = [i for i in torch.unique(all_superpixels_mask) if i not in relevant_superpixels]
      embedding_mask_relevant_superpixels = all_superpixels_mask.clone()
      for i in not_in_relevant_superpixels:
          embedding_mask_relevant_superpixels[embedding_mask_relevant_superpixels == i] = 0
      if scan_outer_pixels == True:
          outer_superpixels = scan_outer_superpixels_function(embedding_mask_relevant_superpixels)
          indexes_outer = [relevant_superpixels.index(i) for i in outer_superpixels]
          outer_foreground_embedding = foreground_embeddings[indexes_outer, :]
          close_foreground_outer_superpixels = assign_foreground_sp(cosine_function, mean_foreground_embedding, mean_background_embedding, relevant_superpixels_thresholded=outer_superpixels, foreground_embeddings=outer_foreground_embedding, threshold=threshold_closeness)
          to_be_dropped = [i for i in outer_superpixels if i not in close_foreground_outer_superpixels]
          relevant_superpixels = [i for i in relevant_superpixels if i not in to_be_dropped]
          not_in_relevant_superpixels = [i for i in torch.unique(all_superpixels_mask) if i not in relevant_superpixels]
          for i in not_in_relevant_superpixels:
              embedding_mask_relevant_superpixels[embedding_mask_relevant_superpixels == i] = 0
    embedding_mask_relevant_superpixels[embedding_mask_relevant_superpixels > 0] = 1
    if postprocess_crf == True:
      embedding_mask_relevant_superpixels = crf(train_org_image, embedding_mask_relevant_superpixels)
    else:
      embedding_mask_relevant_superpixels[embedding_mask_relevant_superpixels > 0] = 1
      if postprocess_crf == True:
          embedding_mask_relevant_superpixels = crf(train_org_image, embedding_mask_relevant_superpixels)

    return embedding_mask_relevant_superpixels, outer_superpixels, embedding_mask_relevant_superpixels, all_superpixels_mask


In [None]:
def export_superpixel_embedding_masks(dataset, export_path="PATH", img_transform=img_transform, show_grad_cam_on_img=show_grad_cam_on_img):
      images = dataset.X
      masks = dataset.Y
      for i, _ in tqdm(enumerate(images)):
          # load image
          img = torch.tensor(imread(_))
          # load mask
          if ".tif" in masks[i]:
            mask = torch.tensor(imread(masks[i])).long()
          elif ".png" in masks[i]:
            mask = torch.Tensor(np.array(Image.open(masks[i]))).long()
          mask[mask>0] = 1    
          embedding_mask_outer_fct = create_embedding_mask(mask, 
                                                     img, 
                                                     img_transform(imread(_)), 
                                                     N_SEGMENTS=200, 
                                                     iter=1)
          embedding_mask_outer_fct = pass_pseudomask_or_ground_truth(mask.to(device), embedding_mask_outer_fct.to(device), iou_threshold=0.5, mask_occupancy_threshold=0.05)
          embedding_mask_outer_fct = pass_pseudomask_or_ground_truth(mask.to(device), embedding_mask_outer_fct)
          overlay_1 = show_grad_cam_on_img(img, embedding_mask_outer_fct.cpu())
          visualize(img=img.cpu(), gt_box=mask.cpu(), embedding_mask_outer_fct=embedding_mask_outer_fct.cpu(), overlay_1=overlay_1)  
          embedding_mask_outer_fct = Image.fromarray(np.uint8(embedding_mask_outer_fct.cpu().detach() * 255) , 'L')
          output_path_mask = (
          export_path + "/" + _.split('/')[-1]
          ).replace("tif", "png")
          embedding_mask_outer_fct.save(output_path_mask, quality=100, subsampling=0)
# export_superpixel_embedding_masks(test_dataset)