# Instance Segmentation
This notebook generates all the test data required by the instance segmentation tests.

In [None]:
# Imports
import pickle
import shutil

from icevision.all import *
import icedata

from icevision_dashboards.data import InstanceSegmentationResultsDataset

## Setup data

In [None]:
# Load the PennFudan dataset
path = icedata.pennfudan.load_data()
# Parse the data; `parse()` returns the records already split into train and valid
parser = icedata.pennfudan.parser(data_dir=path)
train_records, valid_records = parser.parse()
# Take the class map from the first parsed training record
class_map = train_records[0].detection.class_map

## Train model for data generation

In [None]:
# Transforms: random augmentation (size 384, presize 512) + normalization for training,
# resize-and-pad to 384 + normalization for validation
train_tfms = tfms.A.Adapter(
    [
        *tfms.A.aug_tfms(size=384, presize=512),
        tfms.A.Normalize(),
    ]
)
valid_tfms = tfms.A.Adapter(
    [
        *tfms.A.resize_and_pad(384),
        tfms.A.Normalize(),
    ]
)

In [None]:
# Pair each record split with its transforms
train_ds, valid_ds = Dataset(train_records, train_tfms), Dataset(valid_records, valid_tfms)

In [None]:
# Mask R-CNN from icevision's mmdet integration, with the `resnet50_fpn_1x` backbone
# initialised from pretrained weights
model_type = models.mmdet.mask_rcnn
backbone = model_type.backbones.resnet50_fpn_1x(pretrained=True)

In [None]:
# COCO metric evaluated on the predicted masks
metrics = [COCOMetric(metric_type=COCOMetricType.mask)]
# Model sized to the number of classes in the parsed class map
model = model_type.model(backbone=backbone, num_classes=len(class_map))

In [None]:
# Batch size 2 and 4 workers for both loaders; only the training loader is shuffled
train_dl = model_type.train_dl(
    train_ds, batch_size=2, num_workers=4, shuffle=True
)
valid_dl = model_type.valid_dl(
    valid_ds, batch_size=2, num_workers=4, shuffle=False
)

In [None]:
class LightModel(model_type.lightning.ModelAdapter):
    """Lightning adapter for the Mask R-CNN model that optimizes with Adam."""

    def configure_optimizers(self):
        # A single Adam optimizer over all parameters, fixed learning rate 1e-4
        optimizer = Adam(self.parameters(), lr=1e-4)
        return optimizer


light_model = LightModel(model, metrics=metrics)

In [None]:
# Train for 15 epochs on one GPU, with valid_dl passed as the validation loader
trainer = pl.Trainer(gpus=1, max_epochs=15)
trainer.fit(light_model, train_dl, valid_dl)

## Create preds and samples files

In [None]:
# Move the trained model to the first CUDA device for inference; the result is bound
# to `_` so the notebook cell does not echo the returned module.
_ = model.to("cuda:0")

In [None]:
# Move the Lightning adapter to the same device.
# NOTE(review): the adapter wraps `model`, so this presumably also moves it and makes the
# previous cell redundant — TODO confirm.
_ = light_model.to("cuda:0")

In [None]:
# Display the device the Lightning adapter reports, as a quick check of the move above
light_model.device

In [None]:
# Per-sample losses of the trained model on the validation set. `samples` is stripped of
# images and pickled as test data further down; `losses_stats` is not used afterwards.
samples, losses_stats = model_type.interp.get_losses(model, valid_ds)

In [None]:
# Run inference over the validation set (batch size 2), keeping the images on the predictions
dl = model_type.interp.infer_dl(valid_ds, batch_size=2)
preds = model_type.interp.predict_from_dl(
    model=model,
    infer_dl=dl,
    keep_images=True,
)

In [None]:
def get_updated_mask(pred):
    """Merge the per-instance masks of *pred* into a single encoded mask object.

    The ``[0, :, :]`` slice of ``data`` is taken from every mask in
    ``pred.detection.masks``. The slices are stacked along a new leading axis,
    wrapped in a ``MaskArray`` and converted with ``to_erles(None, None)``.

    NOTE(review): assumes each mask's ``data`` holds one instance on a leading
    axis of size 1 — TODO confirm. ``np.stack`` raises ``ValueError`` when
    *pred* has no masks.
    """
    per_instance = [mask.data[0, :, :] for mask in pred.detection.masks]
    return MaskArray(np.stack(per_instance)).to_erles(None, None)

In [None]:
def get_compnent(components, component_type):
    """Return the first element of *components* that is an instance of *component_type*.

    Returns ``None`` when nothing matches, including when *components* is empty.
    """
    matches = (
        component for component in components if isinstance(component, component_type)
    )
    return next(matches, None)

In [None]:
def remove_image(components):
    """Drop the stored image from every image-holding component, in place.

    Sets ``img`` to ``None`` on each element of *components* that is a
    ``FilepathRecordComponent`` or an ``ImageRecordComponent``. All other
    elements are left untouched. Returns ``None``.
    """
    # Only attributes of the elements are mutated (not the container), so the
    # iterable can be walked directly without first copying it into a list.
    for component in components:
        if isinstance(component, (FilepathRecordComponent, ImageRecordComponent)):
            component.img = None

In [None]:
def cleanup_preds(preds):
    """Return a stripped-down deep copy of *preds* for storing as test data.

    For each prediction in the copy:
      * ``set_masks`` is called on the predicted record's
        ``InstanceMasksRecordComponent`` with a one-element list holding the
        merged mask built by ``get_updated_mask``;
      * ``remove_image`` clears the image of the common components of both the
        ground-truth and the predicted record.

    *preds* itself is not modified.
    """
    stripped_preds = deepcopy(preds)
    for prediction in stripped_preds:
        merged_mask = get_updated_mask(prediction)
        detection_components = list(prediction.pred.detection.components)
        masks_component = get_compnent(detection_components, InstanceMasksRecordComponent)
        masks_component.set_masks([merged_mask])
        remove_image(prediction.ground_truth.common.components)
        remove_image(prediction.pred.common.components)
    return stripped_preds

In [None]:
# Remove the data that is not required (images) and convert the masks to a smaller data format
clean_preds = cleanup_preds(preds)
preds_path = "test_data/instance_segmentation_preds.pkl"
# Use context managers so the file is flushed and closed before it is read back
with open(preds_path, "wb") as preds_file:
    pickle.dump(clean_preds, preds_file)
# Reload, so the results dataset built below uses exactly the pickled (cleaned) predictions.
# The file was written just above, so unpickling it is safe here.
with open(preds_path, "rb") as preds_file:
    preds = pickle.load(preds_file)

In [None]:
def cleanup_samples(samples):
    """Return a deep copy of *samples* with the image removed from every record.

    ``remove_image`` is applied to each record's ``components``. *samples*
    itself is not modified.
    """
    samples_copy = deepcopy(samples)
    for record in samples_copy:
        remove_image(record.components)
    return samples_copy

In [None]:
# Remove the data that is not required (images)
clean_samples = cleanup_samples(samples)
samples_path = "test_data/instance_segmentation_samples.pkl"
# Use context managers so the file is flushed and closed before it is read back
with open(samples_path, "wb") as samples_file:
    pickle.dump(clean_samples, samples_file)
# Reload, so the results dataset built below uses exactly the pickled (cleaned) samples.
# The file was written just above, so unpickling it is safe here.
with open(samples_path, "rb") as samples_file:
    samples = pickle.load(samples_file)

In [None]:
# Build the dataset consumed by the analysis dashboard from the cleaned preds and samples,
# then write it to the test data folder
valid_result_ds = InstanceSegmentationResultsDataset.init_from_preds_and_samples(
    preds,
    samples,
    class_map=class_map,
)
valid_result_ds.save("test_data/instance_segmentation_result_ds_valid.dat")

## Cleanup

In [None]:
import contextlib

# Remove the directories left behind by training. A directory that does not exist
# (e.g. when this cell is re-run) is skipped instead of raising FileNotFoundError;
# any other error (permissions, etc.) still propagates.
for artefact_dir in ("checkpoints/", "lightning_logs/"):
    with contextlib.suppress(FileNotFoundError):
        shutil.rmtree(artefact_dir)