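# Notebook export: trains a torchvision Faster R-CNN detector on the Aleket dataset
# (classes: healthy, nec) and evaluates it with COCO metrics via pycocotools.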

In [None]:
%pip install pillow
%pip install numpy<2.0
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
%pip install matplotlib
%pip install pycocotools
%pip install gdown

from IPython.display import clear_output

clear_output(wait=False)

print("ALL DEPENDENCIES INSTALLED")

In [35]:
# IMPORTS
import copy
import io
import json
import math
import os
import random
import shutil
import sys
import time
from contextlib import redirect_stdout
from typing import Optional, Any

# Third-Party Libraries
import gdown
import numpy as np
import PIL.Image

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

#Torch
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset

# Torchvision
import torchvision.models.detection as tv_detection
import torchvision.transforms.v2 as v2
import torchvision.tv_tensors as tv_tensors




In [36]:
# DATASET UTILS
def load_annotations(file: str) -> dict[str, Any]:
    """Read a JSON annotation file and return its decoded contents.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for that file.
    """
    with open(file, 'r') as handle:
        return json.load(handle)

class AleketDataset(Dataset):
    """Object-detection dataset read from a folder of JPEG images plus one JSON file.

    Layout read by this class:
        ``<dataset_dir>/imgs/<name>.jpeg`` - one image per entry of ``dataset["imgs"]``.
        ``<dataset_dir>/dataset.json``    - dict with parallel lists ``"imgs"`` and
        ``"annotations"``; every annotation holds ``"category_id"`` (labels) and
        ``"boxes"`` (interpreted as XYXY boxes).

    Attributes:
        train: When True (the default) the user-supplied ``transforms`` are applied
            in ``__getitem__``; set it to False to get un-augmented samples.
        transforms: Optional augmentation applied jointly to image, boxes and labels.
        default_transforms: Always applied last; converts the image to float32 and
            rescales its values.
    """

    # NOTE(review): these two names look swapped relative to their contents
    # (NUM_TO_CLASSES maps name -> id, CLASSES_TO_NUM maps id -> name). They are
    # left untouched because code outside this file may already rely on them.
    NUM_TO_CLASSES = {'placeholder': 0, 'healthy': 1, 'nec': 2}
    CLASSES_TO_NUM = {0: 'placeholder', 1: 'healthy', 2: 'nec'}

    def __init__(self,
                 dataset_dir: str,
                 transforms: Optional[nn.Module] = None) -> None:
        """Load the annotation file and remember where the images live.

        Args:
            dataset_dir: Directory containing ``imgs/`` and ``dataset.json``.
            transforms: Optional augmentation, used only while ``self.train`` is True.
        """
        self.img_dir = f"{dataset_dir}/imgs"
        self.dataset = load_annotations(f"{dataset_dir}/dataset.json")

        self.default_transforms = v2.Compose([v2.ToDtype(torch.float32, scale=True)])
        self.train = True
        self.transforms = transforms

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.dataset["imgs"])

    def __getitem__(self, idx: int):
        """Return ``(image, target)`` for the sample at position ``idx``.

        The target dict has the keys ``boxes``, ``labels``, ``area``, ``image_id``
        (equal to ``idx``) and ``iscrowd`` (all zeros).
        """
        img_path = f"{self.img_dir}/{self.dataset['imgs'][idx]}.jpeg"
        # Use a context manager so the underlying file handle is closed right away;
        # convert() returns an independent, fully loaded 3-channel copy.
        with PIL.Image.open(img_path) as raw_img:
            img = tv_tensors.Image(raw_img.convert("RGB"), dtype=torch.uint8)

        ht, wt = img.shape[-2], img.shape[-1]

        annots = self.dataset["annotations"][idx]
        # Explicit dtypes and the reshape keep images without any box valid:
        # labels become an empty int64 tensor and boxes an empty (0, 4) float tensor.
        labels = torch.as_tensor(annots["category_id"], dtype=torch.int64)
        boxes = torch.as_tensor(annots["boxes"], dtype=torch.float32).reshape(-1, 4)
        # canvas_size is (height, width); passing (width, height) makes box-aware
        # transforms such as flips use the wrong extent on non-square images.
        bboxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(ht, wt))

        if self.train and self.transforms is not None:
            img, bboxes, labels = self.transforms(img, bboxes, labels)

        # Area is computed after augmentation so it matches the returned boxes.
        area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
        iscrowd = torch.zeros((bboxes.shape[0],), dtype=torch.int64)

        img, bboxes, labels = self.default_transforms(img, bboxes, labels)
        target = {"boxes": bboxes, "labels": labels, "area": area, "image_id": idx, "iscrowd": iscrowd}

        return img, target

In [37]:
# COCO METRICS UTILS

def convert_to_coco_api(dataset: Dataset):
    """Build an in-memory pycocotools ``COCO`` ground-truth index from a dataset.

    Every sample is fetched through ``dataset[idx]`` and must be an ``(img, target)``
    pair whose target provides ``image_id``, ``boxes`` (XYXY), ``labels``, ``area``
    and ``iscrowd``.

    Args:
        dataset: Detection dataset to convert.

    Returns:
        A ``COCO`` object whose ``dataset`` attribute holds the converted entries
        and whose index has been created.
    """
    images = []
    annotations = []
    seen_categories = set()

    for sample_idx in range(len(dataset)):
        img, target = dataset[sample_idx]
        image_id = target["image_id"]
        images.append({"id": image_id, "height": img.shape[-2], "width": img.shape[-1]})

        # COCO stores boxes as [x, y, width, height]; work on a copy of the target.
        xywh = target["boxes"].clone()
        xywh[:, 2:] -= xywh[:, :2]

        box_list = xywh.tolist()
        label_list = target["labels"].tolist()
        area_list = target["area"].tolist()
        crowd_list = target["iscrowd"].tolist()

        for obj_idx, box in enumerate(box_list):
            category = label_list[obj_idx]
            annotations.append({
                "image_id": image_id,
                "bbox": box,
                "category_id": category,
                "area": area_list[obj_idx],
                "iscrowd": crowd_list[obj_idx],
                # Annotation ids are consecutive and start at 1.
                "id": len(annotations) + 1,
            })
            seen_categories.add(category)

    coco_dict = {
        "images": images,
        "categories": [{"id": category} for category in sorted(seen_categories)],  # TODO add names
        "annotations": annotations,
    }

    # pycocotools prints progress messages; keep them out of the notebook output.
    with redirect_stdout(io.StringIO()):
        coco_ds = COCO()
        coco_ds.dataset = coco_dict
        coco_ds.createIndex()
    return coco_ds


def stats_dict(stats: np.ndarray):
    """Label the first twelve entries of a COCO stats vector with metric names.

    The order follows https://cocodataset.org/#detection-eval.

    Args:
        stats: Indexable sequence with at least twelve entries.

    Returns:
        Dict mapping each metric name to the entry at the matching position.
    """
    metric_names = (
        "AP",
        "AP@0.5",
        "AP@0.75",
        "AP small",
        "AP medium",
        "AP large",
        "AR max=1",
        "AR max=10",
        "AR max=100",
        "AR small",
        "AR medium",
        "AR large",
    )
    return {name: stats[position] for position, name in enumerate(metric_names)}


class CocoEvaluator:
    """Collects detections and scores them against ground truth with pycocotools.

    Typical use: construct with the ground truth, call ``append`` once per batch
    of predictions, then call ``eval`` to obtain the COCO bbox metrics.
    """

    def __init__(self, gt_dataset):
        """Store a private copy of the ground truth.

        Args:
            gt_dataset: Either a torch ``Dataset`` (converted through
                ``convert_to_coco_api``) or an already built ground-truth object.
        """
        if isinstance(gt_dataset, Dataset):
            gt_dataset = convert_to_coco_api(gt_dataset)

        # Deep copy so later changes to the caller's object cannot affect scoring.
        self.coco_gt = copy.deepcopy(gt_dataset)
        self.coco_dt = []

        self.img_ids = set()

    def append(self, predictions):
        """Add the predictions of one or more images.

        Args:
            predictions: Dict mapping image id to a dict with ``boxes`` (XYXY),
                ``scores`` and ``labels`` tensors.

        Raises:
            ValueError: If an image id has already been appended.
        """
        for image_id, prediction in predictions.items():
            if not prediction:
                continue
            if image_id in self.img_ids:
                raise ValueError(f"image with an id: {image_id} is already got predicted, ids must be unique")
            self.img_ids.add(image_id)

            # A prediction dict always has its three keys, so "no detections" has
            # to be detected on the box tensor itself. The image is still recorded
            # as seen; it simply contributes no detection entries.
            if prediction["boxes"].shape[0] == 0:
                continue

            # Convert XYXY to COCO's [x, y, width, height] on a copy.
            boxes = prediction["boxes"].clone()
            boxes[:, 2:] -= boxes[:, :2]
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            self.coco_dt.extend(
                {
                    "image_id": image_id,
                    "category_id": labels[k],
                    "bbox": box,
                    "score": scores[k],
                }
                for k, box in enumerate(boxes.tolist())
            )

    def eval(self):
        """Run the COCO bbox evaluation over everything appended so far.

        Returns:
            The dict built by ``stats_dict``; all zeros when no detection was
            appended (in which case "NO PREDICTIONS" is printed).
        """
        if not self.coco_dt:
            print("NO PREDICTIONS")
            return stats_dict(np.zeros(12))
        # Silence pycocotools' progress output; only the summary table is printed.
        with redirect_stdout(io.StringIO()):
            coco_dt = self.coco_gt.loadRes(self.coco_dt)
            coco = COCOeval(self.coco_gt, coco_dt, iouType="bbox")
            coco.evaluate()
            coco.accumulate()
        coco.summarize()
        return stats_dict(coco.stats)


In [46]:
# TRAIN AND EVAL UTILS
def log_metric(
    batch: int,
    total_batches: int,
    time_passed: Optional[float] = None,
    losses_dict: Optional[dict] = None,
    total_loss: Optional[float] = None,
):
    """Print a single progress line for one batch.

    The line always starts with ``[batch/total_batches]``; the total loss, the
    four per-component losses and the elapsed time follow when they are given.

    Args:
        batch: Index of the current batch.
        total_batches: Number of batches in the epoch.
        time_passed: Optional elapsed time to report.
        losses_dict: Optional mapping that must contain ``loss_classifier``,
            ``loss_box_reg``, ``loss_objectness`` and ``loss_rpn_box_reg``.
        total_loss: Optional overall loss to report.
    """
    component_keys = (
        "loss_classifier",
        "loss_box_reg",
        "loss_objectness",
        "loss_rpn_box_reg",
    )

    pieces = [f"[{batch}/{total_batches}] "]
    if total_loss is not None:
        pieces.append(f"Loss: {total_loss} ")
    if losses_dict is not None:
        pieces.append(" ".join(f"{key}: {losses_dict[key]}" for key in component_keys))
    if time_passed is not None:
        pieces.append(f" time: {time_passed}")

    print("".join(pieces))


def filter_predictions_by_conf(predictions: list, conf_thresh: float):
    """Drop detections whose score is not strictly above ``conf_thresh``.

    Args:
        predictions: List of dicts with ``boxes``, ``labels`` and ``scores`` tensors.
        conf_thresh: Score threshold; a detection is kept only when its score is
            greater than this value.

    Returns:
        A new list with one dict per input prediction, holding only the kept
        entries of ``boxes``, ``labels`` and ``scores``.
    """
    field_names = ("boxes", "labels", "scores")

    kept = []
    for pred in predictions:
        fields = {name: pred[name] for name in field_names}
        confident = fields["scores"] > conf_thresh
        kept.append({name: values[confident] for name, values in fields.items()})
    return kept


def train_one_epoch(
    model: tv_detection.FasterRCNN,
    optimizer: torch.optim.Optimizer,
    dataloader: DataLoader,
    device: str,
    epoch: int,
    printFreq: int = 1,
) -> float:
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Detection model; called as ``model(imgs, targets)`` it must return
            a dict of loss tensors.
        optimizer: Optimizer stepping the model parameters.
        dataloader: Yields ``(imgs, targets)`` batches.
        device: Device the images and target tensors are moved to.
        epoch: Epoch number, used only in the summary line.
        printFreq: Log every ``printFreq`` batches; 0 disables per-batch logging.

    Returns:
        Mean of the per-batch total losses.

    The process exits with status 1 as soon as a batch produces a non-finite loss.
    """
    model.train()
    num_batches = len(dataloader)
    batch_losses = torch.zeros(num_batches, dtype=torch.float32)
    tick = time.time()

    for step, (imgs, targets) in enumerate(dataloader):
        imgs = [image.to(device) for image in imgs]

        # Move tensor entries to the device; leave everything else (e.g. ids) as is.
        device_targets = []
        for target in targets:
            moved = {}
            for key, value in target.items():
                moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
            device_targets.append(moved)

        losses = model(imgs, device_targets)
        loss = sum(losses.values())
        loss_value = loss.item()
        batch_losses[step] = loss_value

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(losses)
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        should_log = printFreq != 0 and step % printFreq == 0
        if should_log:
            # Reported time covers the batches since the previous log line.
            elapsed = time.time() - tick
            log_metric(
                step,
                num_batches,
                time_passed=elapsed,
                losses_dict=losses,
                total_loss=loss_value,
            )
            tick = time.time()

    epoch_loss = batch_losses.mean().item()

    print(f"Epoch {epoch} loss: {epoch_loss}")
    return epoch_loss


def evaluate(
    model: tv_detection.FasterRCNN,
    dataloader: DataLoader,
    device: str,
    conf_thresh: float = 0.1,
    printFreq: int = 1,
) -> Any:
    """Run inference over ``dataloader`` and compute COCO bbox metrics.

    Args:
        model: Detection model; called as ``model(imgs)`` in eval mode it must
            return one prediction dict per image.
        dataloader: Yields ``(imgs, targets)`` batches; ``dataloader.dataset`` is
            used as ground truth and each target must carry ``image_id``.
        device: Device the images are moved to.
        conf_thresh: Only detections scoring strictly above this value are kept.
        printFreq: Log every ``printFreq`` batches; 0 disables per-batch logging.

    Returns:
        The metrics dict returned by ``CocoEvaluator.eval``.
    """
    size = len(dataloader)
    model.eval()

    t = time.time()
    coco_eval = CocoEvaluator(dataloader.dataset)

    for batch_num, (imgs, targets) in enumerate(dataloader):
        imgs = [img.to(device) for img in imgs]
        # Targets stay where they are: only ``image_id`` is read from them, so
        # copying their tensors to the device would be wasted transfer.

        with torch.no_grad():
            predictions = model(imgs)
            predictions = filter_predictions_by_conf(predictions, conf_thresh)
            res = {
                target["image_id"]: output
                for target, output in zip(targets, predictions)
            }
            coco_eval.append(res)

        if printFreq != 0 and batch_num % printFreq == 0:
            # Reported time covers the batches since the previous log line.
            elapsed = time.time() - t
            log_metric(batch_num, size, time_passed=elapsed)
            t = time.time()

    return coco_eval.eval()

In [None]:
# PATH VARIABLES
def get_dataset(save_dir: str):
    """Make sure the dataset exists at ``save_dir``, downloading it if necessary.

    When ``save_dir`` does not exist, the archive is fetched from Google Drive
    with ``gdown`` into a temporary zip file and unpacked into ``save_dir``.
    When it already exists nothing is downloaded.

    Args:
        save_dir: Directory that holds (or will hold) the unpacked dataset.

    Returns:
        ``save_dir`` unchanged.
    """
    patched_dataset_gdrive_id = ""  # FIXME: still empty, so the download cannot succeed yet
    temp_archive = "_temp_.zip"
    if not os.path.exists(save_dir):
        try:
            gdown.download(id=patched_dataset_gdrive_id, output=temp_archive)
            shutil.unpack_archive(temp_archive, save_dir)
        except Exception:
            # save_dir did not exist before this call, so anything there now is a
            # partial extraction; remove it so the next run does not mistake it
            # for a complete dataset.
            shutil.rmtree(save_dir, ignore_errors=True)
            raise
        finally:
            # Never leave the temporary archive behind, even on failure.
            if os.path.exists(temp_archive):
                os.remove(temp_archive)
    print(f"DATASET {save_dir} LOADED")
    return save_dir


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"USING {DEVICE}")
# Runs get_dataset at cell execution time: the dataset is downloaded and unpacked
# only when the "dataset_patched/" directory does not exist yet.
DATASET_ROOT = get_dataset("dataset_patched/")
# Directory for saved models. It already ends with "/", so the two paths below
# contain a double slash ("result//...").
MODEL_DIR = "result/"
# Model saved after every epoch.
LAST_MODEL_PATH = f"{MODEL_DIR}/last_model.pth"
# Model saved whenever the validation AP improves.
BEST_MODEL_PATH = f"{MODEL_DIR}/best_model.pth"

In [44]:
# TRAIN SETTINGS
# Seed passed to the Python, NumPy and PyTorch RNGs by set_seed().
SEED = 1

# When True, get_model() loads BEST_MODEL_PATH (if that file exists) instead of
# building a fresh pretrained model.
LOAD_BEST = False
# Batch size of both the training and the validation DataLoader.
BATCH_SIZE = 8
# Number of train + evaluate rounds run by main().
EPOCHS = 500
# Share of the used samples that main() holds out for validation.
TEST_FRACTION = 0.2
# Share of the dataset that main() uses at all (1 = everything).
DATASET_FRACTION = 1
# Worker processes per DataLoader.
DATALOADER_WORKERS = 8
# Initial learning rate of the AdamW optimizer.
LR = 0.001

# Augmentations handed to AleketDataset; it applies them only while its `train`
# flag is True.
TRANSFORMS = v2.Compose(
    [
        v2.RandomHorizontalFlip(0.5),  # flip probability 0.5
        v2.RandomVerticalFlip(0.5),  # flip probability 0.5
        v2.RandomPerspective(0.1, p=0.1),  # distortion scale 0.1, applied with probability 0.1
    ]
)


def get_model(num_classes) -> tv_detection.FasterRCNN:
    """Return the detection model, placed on ``DEVICE``.

    With ``LOAD_BEST`` set and a file present at ``BEST_MODEL_PATH``, the pickled
    model stored there is loaded and ``num_classes`` is not used. Otherwise a
    pretrained Faster R-CNN (ResNet-50 FPN v2) is created and its box predictor
    is replaced by a fresh head with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes for the new box predictor.

    Returns:
        The model moved to ``DEVICE``.
    """
    if LOAD_BEST and os.path.exists(BEST_MODEL_PATH):
        # SECURITY: weights_only=False unpickles arbitrary Python objects; only
        # load checkpoint files that this notebook itself produced.
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
        model = torch.load(BEST_MODEL_PATH, map_location=DEVICE, weights_only=False)
    else:
        model = tv_detection.fasterrcnn_resnet50_fpn_v2(
            weights="DEFAULT"
        )
        # Swap the pretrained classification head for one sized to our classes.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = (
            tv_detection.faster_rcnn.FastRCNNPredictor(
                in_features, num_classes
            )
        )
    return model.to(DEVICE)

In [None]:
# main
def set_seed():
    """Seed the Python, NumPy and PyTorch random number generators with ``SEED``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(SEED)


def save_model(model, is_best):
    """Save ``model`` as the latest checkpoint and, if ``is_best``, as the best one.

    Args:
        model: Object handed to ``torch.save``.
        is_best: When true, the model is additionally written to ``BEST_MODEL_PATH``.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)

    destinations = [LAST_MODEL_PATH]
    if is_best:
        destinations.append(BEST_MODEL_PATH)

    for path in destinations:
        torch.save(model, path)


def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    A list of ``(img, target)`` pairs becomes ``((img, ...), (target, ...))``, so
    samples of different sizes are grouped without being stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)


def main():
    """Train the detector and evaluate it after every epoch.

    Builds the model and dataset, splits the used samples into training and
    validation subsets, then runs ``EPOCHS`` rounds of training followed by a
    COCO evaluation. The model is saved after each round and additionally as
    the best model whenever the validation AP improves.
    """
    set_seed()
    model = get_model(3)

    dataset = AleketDataset(DATASET_ROOT, transforms=TRANSFORMS)

    # Only the first `size` samples of the dataset are used, in random order.
    size = int(len(dataset) * DATASET_FRACTION)
    indices = torch.randperm(size).tolist()

    test_size = int(size * TEST_FRACTION)
    # Split on an explicit index: negative slicing breaks when test_size is 0
    # (indices[:-0] is empty and indices[-0:] is the whole list), which would
    # train on nothing and validate on everything.
    split = size - test_size
    train_dataset = Subset(dataset, indices[:split])
    val_dataset = Subset(dataset, indices[split:])

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,  # new sample order every epoch instead of one fixed order
        collate_fn=collate_fn,
        num_workers=DATALOADER_WORKERS,
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=DATALOADER_WORKERS,
    )

    optimizer = optim.AdamW(params=model.parameters(), lr=LR)
    # Linearly decay the learning rate to 10% of LR over the first 15 epochs.
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1, end_factor=0.1, total_iters=15
    )
    best_ap = -1
    for epoch in range(EPOCHS):
        print(f"Epoch: {epoch}/{EPOCHS}\tLearning rate: {lr_scheduler.get_last_lr()}")
        # Both subsets share one underlying dataset object, so augmentation is
        # switched on for training and off for evaluation through its flag.
        dataset.train = True
        train_one_epoch(model, optimizer, train_dataloader, DEVICE, epoch, printFreq=20)

        print("Evaluating: ")
        dataset.train = False
        eval_stats = evaluate(
            model, val_dataloader, DEVICE, conf_thresh=0, printFreq=100
        )
        ap = eval_stats["AP"]
        save_model(model, ap > best_ap)
        best_ap = max(best_ap, ap)
        print(f"AP: {ap}")
        lr_scheduler.step()


if __name__ == "__main__":
    main()