In [None]:
import os
import sys
import copy
import collections
from typing import Tuple
from pathlib import Path
from natsort import natsorted

import torch
from torch import nn
from torchvision import models
import pytorch_lightning as pl


import numpy as np
from PIL import Image
from torch.utils.data import Dataset, random_split, DataLoader, dataloader

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import cv2
from matplotlib import pyplot as plt

In [None]:
# Root folder scanned by the dataset: every non-hidden sub-directory is read as
# one object with a "color/" image folder and a "groundtruth.txt" file.
DATA_DIR = "./data/"

# Side length, in pixels, of the square frames produced by the transforms; also
# the upper bound of the model's predicted box coordinates.
IMG_SIZE = 224

# Fractions of the dataset set aside for validation and for testing.
VAL_SIZE = 0.3
TEST_SIZE = 0.2

BATCH_SIZE = 64  # samples per DataLoader batch
LEARNING_RATE = 2e-4  # default Adam learning rate of SOTModel
EPOCHS = 50  # max_epochs handed to the Trainer
VERSION = 0  # names the TensorBoard run and the saved checkpoint file

# Built from VERSION so it cannot drift from the file name the training cell
# saves under (sot_<VERSION>.ckpt).
CHECKPOINT = f"./checkpoints/sot_{VERSION}.ckpt"

# Utils

In [None]:
def visualize(image):
    """Display one image on a 30x30 inch figure with the axes switched off."""
    canvas = plt.figure(figsize=(30, 30))
    axes = canvas.gca()
    axes.axis("off")
    axes.imshow(image)
    plt.show()


def plot_examples(images, bboxes=None):
    """Plot images in a four-column grid, optionally drawing one box on each.

    Args:
        images: sequence of image arrays; each must support ``.astype``.
        bboxes: optional sequence parallel to ``images`` holding one
            ``(x_min, y_min, x_max, y_max)`` box per image.

    Note:
        Boxes are drawn through ``visualize_bbox``, which paints onto the array
        it receives and returns it, so the passed images are modified.
    """
    fig = plt.figure(figsize=(30, 30))
    columns = 4
    # Keep the 4x4 layout for up to 16 images and add rows beyond that; a fixed
    # 4x4 grid makes add_subplot fail on the 17th image.
    rows = max(4, -(-len(images) // columns))

    for position, img in enumerate(images, start=1):
        if bboxes is not None:
            img = visualize_bbox(img, bboxes[position - 1], class_name="Elon")
        fig.add_subplot(rows, columns, position)
        plt.imshow(img.astype(np.uint8))

    plt.show()


# From https://albumentations.ai/docs/examples/example_bboxes/
def visualize_bbox(img, bbox, class_name, color=(255, 0, 0), thickness=5):
    """Draw one rectangle on ``img`` and hand the same image object back.

    ``bbox`` is read as ``(x_min, y_min, x_max, y_max)`` and every value is
    converted with ``int`` before drawing. ``class_name`` is accepted but not
    used by the body.
    """
    left, top, right, bottom = (int(coord) for coord in bbox)
    top_left = (left, top)
    bottom_right = (right, bottom)
    cv2.rectangle(img, top_left, bottom_right, color, thickness)
    return img


# Data

In [None]:
# Lightweight geometric records used while converting annotations.
Rectangle = collections.namedtuple("Rectangle", "x y width height")  # corner + size
Point = collections.namedtuple("Point", "x y")  # a single 2D coordinate
Polygon = collections.namedtuple("Polygon", "points")  # sequence of vertices

In [None]:
def convert_region(region, to):
    """Convert a region between its ``Rectangle`` and ``Polygon`` forms.

    Args:
        region: a ``Rectangle`` or a ``Polygon``.
        to: target representation, either ``"rectangle"`` or ``"polygon"``.

    Returns:
        A region of the requested kind; when ``region`` already has that kind a
        shallow copy is returned. ``None`` is returned for an unknown target,
        an unsupported ``region`` type, or a polygon without any points.
    """
    if to == "rectangle":
        if isinstance(region, Rectangle):
            return copy.copy(region)

        if isinstance(region, Polygon):
            points = list(region.points)
            if not points:
                # No vertices means there is no extent to report.
                return None

            # Take the true min/max of the coordinates. Seeding the running
            # maximum with sys.float_info.min (the smallest *positive* float)
            # gives wrong bounds for zero or negative coordinates.
            xs = [point.x for point in points]
            ys = [point.y for point in points]
            left, top = min(xs), min(ys)
            return Rectangle(left, top, max(xs) - left, max(ys) - top)

        return None

    if to == "polygon":
        if isinstance(region, Rectangle):
            right = region.x + region.width
            bottom = region.y + region.height
            # Build Point records (still tuple-compatible) so the result can be
            # fed back into the "rectangle" branch, which reads .x / .y.
            return Polygon(
                [
                    Point(region.x, region.y),
                    Point(right, region.y),
                    Point(right, bottom),
                    Point(region.x, bottom),
                ]
            )

        if isinstance(region, Polygon):
            return copy.copy(region)

        return None

    return None

In [None]:
def convert_to_bbox(bbox):
    """Turn a flat annotation into ``[x_min, y_min, x_max, y_max]``.

    Exactly four values are read as ``x, y, width, height``. More than four
    values are read as interleaved polygon coordinates ``x0, y0, x1, y1, ...``
    and reduced to the rectangle that ``convert_region`` reports for them.
    Fewer than four values yield ``None``.
    """
    count = len(bbox)
    if count < 4:
        return None

    if count == 4:
        left, top, width, height = bbox
    else:
        vertices = [Point(bbox[i], bbox[i + 1]) for i in range(0, count, 2)]
        bounds = convert_region(Polygon(vertices), "rectangle")
        left, top, width, height = bounds.x, bounds.y, bounds.width, bounds.height

    return [left, top, left + width, top + height]

In [None]:
class Data(Dataset):
    """Dataset of consecutive-frame pairs read from per-object directories.

    Every non-hidden sub-directory of ``data_dir`` is treated as one object. Its
    naturally sorted files under ``color/`` are the frames and the lines of
    ``groundtruth.txt`` are the matching annotations, converted with
    ``convert_to_bbox``. Frame ``i`` is paired with frame ``i + 1``.
    """

    def __init__(self, data_dir: Path = Path(DATA_DIR)):
        """Index all frame pairs under ``data_dir`` and build the transforms.

        Args:
            data_dir: root directory to scan; a plain string is accepted too.
        """
        # Callers may pass a str (DATA_DIR itself is one); .glob needs a Path.
        data_dir = Path(data_dir)

        objects = [
            obj
            for obj in data_dir.glob("*")
            if obj.is_dir() and not obj.name.startswith(".")
        ]

        data = []

        for obj in objects:
            img_path = obj / "color"
            annot_path = obj / "groundtruth.txt"

            images = natsorted(list(img_path.glob("*")))

            with open(str(annot_path), "r") as fl:
                # One comma-separated annotation per line. Lines holding only
                # whitespace (e.g. a stray "\r") are skipped instead of
                # crashing float().
                annots = [
                    [float(coord) for coord in annot.split(",")]
                    for annot in fl.read().split("\n")
                    if annot.strip() != ""
                ]

            annots = list(map(convert_to_bbox, annots))

            # (previous frame, previous box, current frame, current box, name);
            # zip stops at the shortest input, so surplus frames or
            # annotations are ignored.
            data += list(
                zip(
                    images[:-1],
                    annots[:-1],
                    images[1:],
                    annots[1:],
                    [obj.name] * (len(images) - 1),
                )
            )

        self.data = data

        # Previous frame: crop around its box, then resize to the model input.
        self.transform_x = A.Compose(
            [
                A.RandomCropNearBBox(always_apply=True),
                A.Resize(IMG_SIZE, IMG_SIZE),
                ToTensorV2(),
            ],
            p=1.0,
        )

        # Current frame: box-preserving random crop; the box is transformed
        # along with the image and becomes the regression target.
        self.transform_y = A.Compose(
            [
                A.RandomSizedBBoxSafeCrop(IMG_SIZE, IMG_SIZE),
                ToTensorV2(),
            ],
            p=1.0,
            bbox_params=A.BboxParams(
                format="pascal_voc", label_fields=[], min_visibility=0.3
            ),
        )

    def __len__(self):
        """Return the number of indexed frame pairs."""
        return len(self.data)

    @staticmethod
    def _load_image(path):
        """Read an image file into a NumPy array, closing the file afterwards."""
        with Image.open(path) as handle:
            return np.array(handle)

    def __getitem__(self, index):
        """Load and transform one frame pair.

        Returns:
            A dict with ``previous_frame``, ``current_frame``, ``bbox`` and
            ``name``, or ``None`` when a transform step fails for this sample
            (``collate`` filters such entries out of the batch).
        """
        image_x, bbox_x, image_y, bbox_y, obj = self.data[index]

        image_x = self._load_image(image_x)
        bbox_x = np.array(bbox_x, dtype=np.float32)

        image_y = self._load_image(image_y)
        bbox_y = np.array(bbox_y, dtype=np.float32)

        try:
            if self.transform_x:
                transformed = self.transform_x(image=image_x, cropping_bbox=bbox_x)
                image_x = transformed["image"]

            if self.transform_y:
                transformed = self.transform_y(image=image_y, bboxes=[bbox_y])
                image_y = transformed["image"]
                bbox_y = torch.tensor(transformed["bboxes"][0]).float()

            return {
                "previous_frame": image_x.float(),  # previous frame
                "current_frame": image_y.float(),  # current frame
                "bbox": bbox_y,  # target
                "name": obj,  # object name
            }
        except Exception:
            # Best-effort skip of samples the augmentation cannot handle (for
            # example when no box comes back from the transform). Catching
            # Exception instead of a bare except keeps KeyboardInterrupt and
            # SystemExit working.
            return None

In [None]:
def collate(batch):
    """Batch samples with the default collate after discarding ``None`` entries."""
    kept = [sample for sample in batch if sample is not None]
    return dataloader.default_collate(kept)


def reverse_transform(
    img: torch.Tensor,
    bbox: torch.Tensor,
    width: int = None,
    height: int = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Resize a processed image and its bounding box back to a target size.

    Args:
        img: image tensor; it is permuted with ``(1, 2, 0)`` before resizing,
            i.e. it is expected channel-first.
        bbox: box tensor, interpreted in ``pascal_voc`` format.
        width: first size argument forwarded to ``A.Resize``; defaults to
            ``img.shape[1]``.
        height: second size argument forwarded to ``A.Resize``; defaults to
            ``img.shape[2]``.

    Returns:
        The resized image and the first transformed bounding box.

    NOTE(review): albumentations documents ``Resize(height, width)``, so
    ``width`` appears to set the number of rows and ``height`` the number of
    columns. The defaults and the existing positional call ``(480, 720)`` are
    consistent with that order, so it is kept as is; confirm before passing
    these by keyword.
    """
    # Identity check is the idiomatic test for "argument not supplied".
    if width is None:
        width = img.shape[1]
    if height is None:
        height = img.shape[2]

    img = img.permute(1, 2, 0).numpy()
    bbox = bbox.numpy()

    transform = A.Compose(
        [
            A.Resize(width, height),
        ],
        p=1.0,
        bbox_params=A.BboxParams(
            format="pascal_voc", label_fields=[], min_visibility=0.3
        ),
    )

    transformed = transform(image=img, bboxes=[bbox])

    return transformed["image"], transformed["bboxes"][0]

# Model

In [None]:
class SOTModel(pl.LightningModule):
    """Two-branch box regressor over a (previous frame, current frame) pair.

    Each frame goes through its own pretrained ResNet-34 with the last child
    module removed. The two flattened feature vectors are concatenated and fed
    to a small fully connected head that emits four values, squashed into the
    range ``(0, IMG_SIZE)`` by a scaled sigmoid. Training minimises the MSE
    against the target box.
    """

    def __init__(self, lr=LEARNING_RATE):
        super().__init__()

        # One independent backbone per input frame.
        self.x_cnn = self._make_backbone()
        self.y_cnn = self._make_backbone()

        self.flatten = nn.Flatten()

        head_layers = [
            nn.Linear(512 * 2, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 4),
        ]
        self.fc = nn.Sequential(*head_layers)

        self.lr = lr
        self.loss = nn.MSELoss()

    @staticmethod
    def _make_backbone():
        """Build a pretrained ResNet-34 without its last child module."""
        trunk = models.resnet34(pretrained=True)
        return nn.Sequential(*(list(trunk.children())[:-1]))

    def _batch_mse(self, batch):
        """Run the model on one batch and return the MSE against ``batch["bbox"]``."""
        prediction = self(batch["previous_frame"], batch["current_frame"])
        return self.loss(prediction, batch["bbox"])

    def forward(self, previous_frame, current_frame):
        """Predict four box values, each within ``(0, IMG_SIZE)``, per sample."""
        prev_vec = self.flatten(self.x_cnn(previous_frame))
        curr_vec = self.flatten(self.y_cnn(current_frame))

        joint = torch.cat([prev_vec, curr_vec], dim=1)

        return self.sigmoid_scale(self.fc(joint), 0, IMG_SIZE)

    def training_step(self, batch, batch_idx):
        """Return the MSE loss for one training batch."""
        return self._batch_mse(batch)

    def validation_step(self, batch, batch_idx):
        """Log the batch MSE as ``val_mse`` (per step and per epoch)."""
        self.log("val_mse", self._batch_mse(batch), on_step=True, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Log the batch MSE as ``test_mse`` (per step and per epoch)."""
        self.log("test_mse", self._batch_mse(batch), on_step=True, on_epoch=True)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), self.lr)

    @staticmethod
    def sigmoid_scale(x, lo, hi):
        """Map ``x`` through a sigmoid stretched to the interval ``(lo, hi)``."""
        span = hi - lo
        return torch.sigmoid(x) * span + lo

In [None]:
def eval():
    """Load the saved checkpoint and plot target vs. predicted boxes for one batch.

    A single batch of eight samples is drawn from the dataset, the model
    predicts a box for each current frame, and both the target box and the
    predicted box are resized back with ``reverse_transform`` and plotted side
    by side through ``plot_examples``.

    Note: this function's name shadows the builtin ``eval`` in the notebook
    namespace; it is kept so existing calls continue to work.
    """
    # SubsetRandomSampler is not among the notebook's top-level imports, so it
    # is brought into scope here.
    from torch.utils.data import SubsetRandomSampler

    # load_from_checkpoint is a classmethod: call it on the class instead of
    # first building a throwaway instance.
    model = SOTModel.load_from_checkpoint(CHECKPOINT).eval()
    print("Model loaded successfully...")

    ds = Data()
    split_idx = int(len(ds) * VAL_SIZE)
    # NOTE(review): samples are drawn from everything *after* the first
    # VAL_SIZE fraction of the index range, which is not the random split used
    # by the training cell - confirm this is the intended subset.
    val_indices = list(range(len(ds)))[split_idx:]
    val_sampler = SubsetRandomSampler(val_indices)

    val_dl = DataLoader(
        ds,
        batch_size=8,
        sampler=val_sampler,
        num_workers=4,
        collate_fn=collate,
        shuffle=False,
    )

    val_batch = next(iter(val_dl))

    with torch.no_grad():
        out = model(val_batch["previous_frame"], val_batch["current_frame"])

    # Interleaved as [target, prediction, target, prediction, ...].
    imgs = []
    bboxes = []

    for idx in range(len(out)):
        frame = val_batch["current_frame"][idx]

        org_img, org_bbox = reverse_transform(
            frame,
            val_batch["bbox"][idx],
            480,
            720,
        )

        pred_img, pred_bbox = reverse_transform(
            frame,
            out[idx],
            480,
            720,
        )

        imgs += [org_img, pred_img]
        bboxes += [org_bbox, pred_bbox]

    plot_examples(imgs, bboxes)

# Execution

In [None]:

# Fixed seed so the train/val/test partition is the same on every run.
SPLIT_SEED = 0

ds = Data()
val_sz = int(len(ds) * VAL_SIZE)
test_sz = int(len(ds) * TEST_SIZE)
train_sz = len(ds) - val_sz - test_sz

train_ds, val_ds, test_ds = random_split(
    ds,
    [train_sz, val_sz, test_sz],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Settings shared by all three loaders. os.cpu_count() may return None, which
# DataLoader rejects, so fall back to loading in the main process.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=os.cpu_count() or 0,
    collate_fn=collate,
)

train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, **loader_kwargs)
test_dl = DataLoader(test_ds, **loader_kwargs)

use_gpu = torch.cuda.is_available()

model = SOTModel()
trainer = pl.Trainer(
    default_root_dir="logs",
    gpus=(1 if use_gpu else 0),
    max_epochs=EPOCHS,
    # Half precision is only requested when a GPU is actually used.
    precision=(16 if use_gpu else 32),
    logger=pl.loggers.TensorBoardLogger("logs/", name="sot", version=VERSION),
)

# Dataloaders are passed positionally: the keyword names of these arguments
# were renamed between Lightning releases, while their positions are the same.
trainer.fit(model, train_dl, val_dl)
# First argument is the model; None keeps the original "no model passed" call.
trainer.test(None, test_dl)

# Make sure the target folder exists before writing the checkpoint.
os.makedirs("checkpoints", exist_ok=True)
trainer.save_checkpoint(f"checkpoints/sot_{VERSION}.ckpt")