# Experiment to see if VISSL helps with segmentation

1) Take a ResNet from torchvision and run SSL with some pretext task on balloons for a few iterations. Export the model for Detectron2
2) Take a ResNet from torchvision and run SSL with some pretext task on balloons for MANY iterations. Export the model for Detectron2
3) Take the same ResNet from torchvision, train for x epochs on balloons, evaluate on balloons
4) Take the ResNet from SSL(1), train for x epochs on balloons, evaluate on balloons
5) Take the ResNet from SSL(2), train for x epochs on balloons, evaluate on balloons

Ideally, 3 ~= 4 and 4 > 5. But it will be nice if 4 > 5 and 3 > 4. If not . . .

## Results
**SPOILER ALERT**

The performance of the torchvision backbone is lower than the performance of the model_zoo backbone. Something like AP=80 (model_zoo) versus 50 (torchvision).

The performance of the lightly SSL-trained torchvision backbone is slightly worse than that of the original one.

The performance of the intensely SSL-trained torchvision backbone is **much worse** than that of the original one. Guesses about why are in the blog post.

### Some path setup and download pretrained weights

In [None]:
from pathlib import Path
import shutil
import os

# Folder layout shared by every cell below.
MODELS_FOLDER = Path("../models")
BALLONS_FOLDER = Path("../balloon")

# File names (inside MODELS_FOLDER) of the three backbones being compared.
ResNet_base_weights = "resnet50-11ad3fa6.pth"
ResNet_light_ssl = "res50_vissl_small_count.torch"
ResNet_large_ssl = "res50_vissl_large_count.torch"

# Torchvision checkpoint: where it is downloaded from, where it is stored,
# and where its Detectron2 conversion is written.
resnet_torchvision_url = "https://download.pytorch.org/models/resnet50-11ad3fa6.pth"
resnet_torchvision_fname = MODELS_FOLDER / ResNet_base_weights
resnet_tv2D2_fname = MODELS_FOLDER / "resnet50-tv-D2.pkl"

# These are rookie numbers. Feel free to bump them, especially
# SSL_large_iter_count and D2_iter_count. But on my "SOTA" GTX 980Ti this is it.
# The SSL counts are passed as the number of VISSL epochs; D2_iter_count is
# passed as the number of Detectron2 training iterations.
SSL_small_iter_count, SSL_large_iter_count = 5, 10
D2_iter_count = 100

# Get a torchvision ResNet50 checkpoint weights


In [None]:
if not Path(resnet_tv2D2_fname).exists():
    # Download the torchvision weights
    import requests
    with requests.get(resnet_torchvision_url, stream=True) as r:
        with open(resnet_torchvision_fname, 'wb') as f:
            shutil.copyfileobj(r.raw, f)
    # Convert to Detectron2
    ! python convert-torchvision-to-d2.py {resnet_torchvision_fname} {resnet_tv2D2_fname}

# ViSSL training

In [None]:
def generate_balloons_config_for_vissl():
    crt_path = Path.cwd()
    configs_path = crt_path / "configs"/ "config"
    configs_path.mkdir(exist_ok=True, parents=True)
    init_file_path = crt_path / "configs" / "__init__.py"
    with open(init_file_path, "wt") as f:
        f.write("")

    dataset_catalog_js = crt_path / "configs" / "config" / "dataset_catalog.json"
    dataset_catalog_js.unlink(missing_ok=True)

    # We will override settings from command line, later
    !wget -q -O configs/config/quick_1gpu_resnet50_simclr.yaml https://dl.fbaipublicfiles.com/vissl/tutorials/configs/quick_1gpu_resnet50_simclr.yaml

    # https://vissl.readthedocs.io/en/v0.1.5/getting_started.html
    dataset_name = "balloons_train_full"
    json_data = {
            dataset_name: {
                "train": [os.path.join(BALLONS_FOLDER,"trainmodel_name.parent"), os.path.join(BALLONS_FOLDER,"train/via_region_data.json") ],
            }
        }
    from vissl.utils.io import save_file
    save_file(json_data, "configs/config/dataset_catalog.json")

    # Run only once to register the config.
    from vissl.data.dataset_catalog import VisslDatasetCatalog
    if dataset_name not in VisslDatasetCatalog.list():
        VisslDatasetCatalog.register_data(name="balloons_train_full", data_dict=json_data["balloons_train_full"])
    print(f"Known datasets: {VisslDatasetCatalog.list()}")
    print(VisslDatasetCatalog.get("balloons_train_full"))

def del_checkpoint_folder():
    """Remove the ``checkpoints`` folder found in the current working directory.

    A missing folder is not an error, and failures while deleting are ignored.
    """
    shutil.rmtree(Path.cwd() / "checkpoints", ignore_errors=True)


def train_vissl_from_existing_model(model_file, no_epochs=5, data_limit=61, destination_D4_model_name="resnet50_fined_vissl2detectron.torch",
                                    checpoint_folder="./checkpoints", model_folder=MODELS_FOLDER):
    """Run a VISSL SimCLR training from existing weights and export the result for Detectron2.

    The training is launched through ``run_distributed_engines.py`` with the
    ``quick_1gpu_resnet50_simclr`` config, overriding its settings on the
    command line. Afterwards the first ``model_final*.torch`` checkpoint is
    converted with ``convert_vissl_to_detectron2.py`` and the training log
    ``stdout.json`` is copied next to the converted model.

    Args:
        model_file: weights file passed as ``config.MODEL.WEIGHTS_INIT.PARAMS_FILE``.
        no_epochs: value passed as ``config.OPTIMIZER.num_epochs``.
        data_limit: value passed as ``config.DATA.TRAIN.DATA_LIMIT``.
        destination_D4_model_name: file name of the converted model inside
            ``model_folder``.
        checpoint_folder: VISSL checkpoint directory (``config.CHECKPOINT.DIR``).
        model_folder: folder receiving the converted model and the log copy;
            it is combined with ``/``, so it has to be a ``Path``.

    Raises:
        AssertionError: if no ``model_final*.torch`` file is found in
            ``checpoint_folder`` after training.
    """
    # The training images are always read from the module-level BALLONS_FOLDER.
    ! python3 run_distributed_engines.py hydra.verbose=true  \
      config.DATA.TRAIN.DATASET_NAMES=[balloons_train_full]  \
      config.DATA.TRAIN.DATA_SOURCES=[disk_folder]           \
      config.DATA.TRAIN.DATA_PATHS=[{BALLONS_FOLDER}] \
      config=quick_1gpu_resnet50_simclr                               \
      config.MODEL.WEIGHTS_INIT.PARAMS_FILE={model_file}   \
      config.MODEL.WEIGHTS_INIT.APPEND_PREFIX="trunk._feature_blocks."   \
      config.MODEL.WEIGHTS_INIT.STATE_DICT_KEY_NAME=''   \
      config.MODEL.TRUNK.RESNETS.DEPTH=50                \
      config.CHECKPOINT.DIR={checpoint_folder}   \
      config.DATA.TRAIN.DATA_LIMIT={data_limit} \
      config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4 \
      config.OPTIMIZER.num_epochs={no_epochs} \
      config.CHECKPOINT.CHECKPOINT_FREQUENCY=20 \
      config.TEST_EVERY_NUM_EPOCH=20 \
      config.DISTRIBUTED.NUM_PROC_PER_NODE=1 \
      config.CHECKPOINT.AUTO_RESUME=false \
      +config.TENSORBOARD_SETUP.USE_TENSORBOARD=false

    # The shell command above does not raise on failure, so a missing final
    # checkpoint is the only signal that the training did not complete.
    chkpoints = Path(checpoint_folder)
    model_names  = sorted(chkpoints.glob("model_final*.torch"))
    assert len(model_names) > 0, "No model was exported. Maybe errors while training?"
    model_name = model_names[0]
    out_vi2d2_name = model_folder / destination_D4_model_name
    print(f"Converting {model_name}  to {out_vi2d2_name}")
    ! python convert_vissl_to_detectron2.py                   \
        --input_model_file {model_name}                       \
        --output_model {out_vi2d2_name}   \
        --weights_type torch                                  \
        --state_dict_key_name classy_state_dict
    # copy the stdout logs
    # The copy keeps the converted model's name with a ".json" suffix.
    src_log_file = Path(checpoint_folder) / "stdout.json"
    dst_log_file = out_vi2d2_name.with_suffix(".json")
    shutil.copyfile(src_log_file, dst_log_file)
    
# Create the VISSL config files and register the balloon dataset when this cell runs.
generate_balloons_config_for_vissl()

### Train a ResNet with VISSL for a few epochs

In [None]:
# Light SSL run: start from the torchvision weights, train for
# SSL_small_iter_count epochs with the VISSL data limit set to 4.
# Old checkpoints are removed first so the export picks up this run's model.
del_checkpoint_folder()
train_vissl_from_existing_model(str(MODELS_FOLDER / ResNet_base_weights), no_epochs=SSL_small_iter_count, data_limit=4, destination_D4_model_name=ResNet_light_ssl)

### Train a ResNet with VISSL for more epochs

In [None]:
# Heavy SSL run: start from the same torchvision weights, train for
# SSL_large_iter_count epochs with the VISSL data limit set to 61.
# Old checkpoints are removed first so the export picks up this run's model.
del_checkpoint_folder()
train_vissl_from_existing_model(str(MODELS_FOLDER / ResNet_base_weights), no_epochs=SSL_large_iter_count, data_limit=61, destination_D4_model_name=ResNet_large_ssl)

# Detectron 2

In [None]:
%matplotlib widget
import matplotlib
from matplotlib import pyplot as plt

import torch
import torchvision
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
import shutil
import numpy as np
import os, json, cv2, random
import pprint
from pathlib import Path

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.structures import BoxMode
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader


## Setup balloon dataset for the Detectron2


In [None]:
# Root folder of the balloon dataset; the "train" and "val" splits are read from its sub-folders.
balloon_dir = BALLONS_FOLDER
# Image limit handed to get_balloon_dicts() for the "train" split only.
detecton_trian_limit = 20

def _balloon_region_to_annotation(anno):
    """Convert one VIA region entry into a Detectron2 annotation dict."""
    assert not anno["region_attributes"]
    shape = anno["shape_attributes"]
    px = shape["all_points_x"]
    py = shape["all_points_y"]
    # Shift every vertex by half a pixel and flatten to [x0, y0, x1, y1, ...].
    poly = [coord + 0.5 for point in zip(px, py) for coord in point]
    return {
        "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
        "bbox_mode": BoxMode.XYXY_ABS,
        "segmentation": [poly],
        "category_id": 0,
    }


def get_balloon_dicts(img_dir, limit_data=0):
    """Load the VIA balloon annotations of *img_dir* as Detectron2 dataset dicts.

    Args:
        img_dir: folder holding the images and ``via_region_data.json``.
        limit_data: maximum number of images to load. Zero or a negative
            value loads every image.

    Returns:
        A list with one dict per image, holding ``file_name``, ``image_id``,
        ``height``, ``width`` and ``annotations`` (one entry per region, each
        with an absolute XYXY ``bbox``, a polygon ``segmentation`` and
        ``category_id`` 0).

    Raises:
        FileNotFoundError: if an image listed in the annotation file cannot
            be read.
    """
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        # ">=" so that exactly limit_data images are loaded; the previous ">"
        # test let one extra image through.
        if limit_data > 0 and idx >= limit_data:
            print(f"Limiting data loading at idx {idx}")
            break

        filename = os.path.join(img_dir, v["filename"])
        image = cv2.imread(filename)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image {filename}")
        height, width = image.shape[:2]

        # "regions" is a dict keyed by region index in the balloon dataset;
        # a plain list of regions is accepted as well.
        regions = v["regions"]
        if isinstance(regions, dict):
            regions = regions.values()

        dataset_dicts.append({
            "file_name": filename,
            "image_id": idx,
            "height": height,
            "width": width,
            "annotations": [_balloon_region_to_annotation(region) for region in regions],
        })
    return dataset_dicts

# (Re-)register both balloon splits so that this cell can be run repeatedly.
for d in ["train", "val"]:
    name = "balloon_" + d
    # Check the registered names instead of calling DatasetCatalog.get():
    # get() would load the whole split just to find out whether it exists,
    # and the broad "except: pass" around it hid any real loading error.
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    # d is bound as a default argument so each lambda keeps its own split name.
    # Only the training split is limited; -1 disables the limit for "val".
    DatasetCatalog.register(
        name,
        lambda d=d: get_balloon_dicts(
            os.path.join(balloon_dir, d),
            detecton_trian_limit if d == "train" else -1,
        ),
    )
    MetadataCatalog.get(name).set(thing_classes=["balloon"])

print(MetadataCatalog.get("balloon_train"))
print(MetadataCatalog.get("balloon_val"))


In [None]:
def get_detectron2_configuration(batch_size=2, num_iterations=4):
    """Build the Mask R-CNN R50-FPN config used for every balloon fine-tuning run.

    Args:
        batch_size: images per batch (``SOLVER.IMS_PER_BATCH``).
        num_iterations: number of training iterations (``SOLVER.MAX_ITER``).

    Returns:
        A Detectron2 config that trains on ``balloon_train``, tests on
        ``balloon_val`` and starts from the model-zoo checkpoint.
    """
    # Start from known ResNet segmentation setup:
    zoo_config = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(zoo_config))

    # Data
    cfg.DATASETS.TRAIN = ("balloon_train",)
    cfg.DATASETS.TEST = ("balloon_val",)
    cfg.DATALOADER.NUM_WORKERS = 4

    # Model: let training initialize from model zoo
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(zoo_config)
    # Smaller than the default of 512: faster, and good enough for this toy dataset
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    # Only one class (balloon); see
    # https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

    # Solver
    cfg.SOLVER.IMS_PER_BATCH = batch_size
    cfg.SOLVER.BASE_LR = 0.0025  # pick a good LR
    cfg.SOLVER.MAX_ITER = num_iterations
    cfg.SOLVER.STEPS = []  # do not decay learning rate
    return cfg

def clear_output_folders(out_dir):
    """Leave *out_dir* as an existing, empty directory.

    Whatever the directory held before is deleted (deletion errors are
    ignored), then the directory is created again.
    """
    shutil.rmtree(path=out_dir, ignore_errors=True)
    os.makedirs(name=out_dir, exist_ok=True)
    
def evaluate_model_on_balloons(cfg, predictor=None, model_file="model_final.pth"):
    """Run COCO-style evaluation on the ``balloon_val`` dataset.

    Args:
        cfg: Detectron2 config. Its ``MODEL.WEIGHTS`` entry is overwritten
            with ``<cfg.OUTPUT_DIR>/<model_file>``.
        predictor: any object exposing the model to evaluate as ``.model``
            (the cells in this notebook pass a trainer). When ``None``, a
            ``DefaultPredictor`` is built from ``cfg``.
        model_file: checkpoint file name inside ``cfg.OUTPUT_DIR``.

    Returns:
        Whatever ``inference_on_dataset`` returns for the evaluator. The
        value used to be discarded, so existing callers are unaffected.
    """
    print(f"Evaluating {cfg.OUTPUT_DIR} {model_file}")
    meta_for_val = MetadataCatalog.get('balloon_val')
    print(f"Metadata for balloon_val dataset: {meta_for_val}")
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, model_file)
    if predictor is None:
        predictor = DefaultPredictor(cfg)
    # Pass the output folder as a plain string rather than a Path object.
    evaluator = COCOEvaluator("balloon_val", output_dir=str(MODELS_FOLDER), allow_cached_coco=False)
    val_loader = build_detection_test_loader(cfg, "balloon_val")
    # The old trailing "del predictor" / "del evaluator" only unbound locals
    # that go out of scope on return anyway, so they are dropped.
    return inference_on_dataset(predictor.model, val_loader, evaluator)

## Finetune genuine Detectron2 ResNet and evaluate on balloons

In [None]:
# Baseline run: fine-tune from the model-zoo weights that
# get_detectron2_configuration() puts into cfg.MODEL.WEIGHTS.
cfg = get_detectron2_configuration(2, D2_iter_count)
# Start from an empty output folder so nothing from a previous run is left behind.
clear_output_folders(cfg.OUTPUT_DIR)
# pprint.pprint(cfg)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()
# The trainer is passed in the predictor slot: the helper only reads its `.model`.
evaluate_model_on_balloons(cfg, trainer)
# NOTE(review): presumably dropped here to free memory before the next run - confirm.
del trainer

## Finetune ResNet torchvision model and evaluate on Balloons

1) Download the torchvision model
2) Convert to Detectron2
3) Load, on top of default pretrained config/weights.

Steps 1) and 2) were done in the beginning of this notebook.

Check https://github.com/facebookresearch/detectron2/blob/main/tools/convert-torchvision-to-d2.py

In [None]:
# Torchvision run: same setup as the baseline, with the converted torchvision
# ResNet50 weights loaded on top before training.
cfg = get_detectron2_configuration(2, D2_iter_count)
clear_output_folders(cfg.OUTPUT_DIR)

# NOTE(review): these four overrides look like the settings that
# convert-torchvision-to-d2.py documents for converted torchvision
# weights - confirm against that script.
cfg.MODEL.PIXEL_MEAN = [123.675, 116.280, 103.530]
cfg.MODEL.PIXEL_STD = [58.395, 57.120, 57.375]
cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False
cfg.INPUT.FORMAT = "RGB"

trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
print("Loading weigths on top of existing config")
# Second load: the converted torchvision checkpoint goes on top of whatever
# resume_or_load() set up.
trainer.checkpointer.load(str(resnet_tv2D2_fname))
trainer.train()
# The trainer is passed in the predictor slot: the helper only reads its `.model`.
evaluate_model_on_balloons(cfg, trainer)
del trainer

## Finetune on top of lightly trained SSL model and evaluate

In [None]:
# Light-SSL run: same setup as the torchvision run, but the weights loaded on
# top are the ones exported after the short VISSL training.
cfg = get_detectron2_configuration(2, D2_iter_count)
clear_output_folders(cfg.OUTPUT_DIR)

# Same preprocessing/stride overrides as in the torchvision cell.
# NOTE(review): presumably still correct because the VISSL run started from
# the torchvision weights - confirm.
cfg.MODEL.PIXEL_MEAN = [123.675, 116.280, 103.530]
cfg.MODEL.PIXEL_STD = [58.395, 57.120, 57.375]
cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False
cfg.INPUT.FORMAT = "RGB"

trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
print("Loading weigths on top of existing config")
# Second load: the lightly SSL-trained checkpoint goes on top of whatever
# resume_or_load() set up.
trainer.checkpointer.load(str(MODELS_FOLDER / ResNet_light_ssl))
trainer.train()
# The trainer is passed in the predictor slot: the helper only reads its `.model`.
evaluate_model_on_balloons(cfg, trainer)
del trainer

## Finetune on top of heavy trained SSL model and evaluate

In [None]:
# Heavy-SSL run: same setup as the torchvision run, but the weights loaded on
# top are the ones exported after the longer VISSL training.
cfg = get_detectron2_configuration(2, D2_iter_count)
clear_output_folders(cfg.OUTPUT_DIR)

# Same preprocessing/stride overrides as in the torchvision cell.
# NOTE(review): presumably still correct because the VISSL run started from
# the torchvision weights - confirm.
cfg.MODEL.PIXEL_MEAN = [123.675, 116.280, 103.530]
cfg.MODEL.PIXEL_STD = [58.395, 57.120, 57.375]
cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False
cfg.INPUT.FORMAT = "RGB"

trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
print("Loading weigths on top of existing config")
# Second load: the heavily SSL-trained checkpoint goes on top of whatever
# resume_or_load() set up.
trainer.checkpointer.load(str(MODELS_FOLDER / ResNet_large_ssl))
trainer.train()
# The trainer is passed in the predictor slot: the helper only reads its `.model`.
evaluate_model_on_balloons(cfg, trainer)
del trainer