In [18]:
import torch
print(torch.__version__)
print(torch.cuda.is_available())

import torchvision
print(torchvision.__version__)


2.3.0+cu121
True
0.18.0+cu121


In [19]:
from os import listdir
from pathlib import Path

def listdir_nohidden(path: Path | str) -> list[str]:
    """
    List files in a directory excluding hidden files.
    """
    return [file for file in listdir(path) if not file.startswith(".")]

def lazy_stoi(s: str) -> int:
    """
    Convert a string to an integer.
    """
    return int(s) if s.replace(",", "").isdigit() else s

def lazy_stof(s: str) -> float:
    """
    Convert a string to a float.
    """
    return float(s) if s.replace(",", "").replace(".", "", 1).isdigit() else s


In [20]:
import os
import requests
from pathlib import Path
from zipfile import ZipFile
from typing import Optional

import gdown
from tqdm import tqdm

def download_file_from_gdrive_gdown(
        url: str,
        file_name: Path | str,
        overwrite: bool = False,
        postprocess: callable = None
    ):
    print(f"Downloading {url} to {file_name}")
    # check if file exists or has been unzipped
    is_file = os.path.exists(file_name)
    is_unzipped = os.path.exists(file_name.removesuffix('.zip'))
    if not overwrite and (is_file or is_unzipped):
        raise FileExistsError(f"{file_name} exists.")

    # Make parent directory
    if not os.path.isdir(os.path.dirname(file_name)):
        os.makedirs(os.path.dirname(file_name))

    gdown.download(url, output=str(file_name))

    if postprocess:
        try:
            postprocess(file_name)
        except Exception as e:
            os.remove(file_name)
            raise e

def download_folder_from_gdrive_gdown(
    url: str,
    file_name: Path | str,
    overwrite: bool = False,
    postprocess: callable = None
):
    print(f"Downloading {url} to {file_name}")
    # Check if folder exists
    if not overwrite and os.path.isdir(file_name) and listdir_nohidden(file_name):
        raise FileExistsError(f"{file_name} exists and is not empty.")

    # Check if folder exists as a file
    if not overwrite and os.path.exists(file_name):
        raise FileExistsError(f"{file_name} exists and is not a directory.")

    # Make parent directory
    if not os.path.isdir(file_name):
        os.makedirs(file_name)

    gdown.download_folder(url, output=str(file_name))

    if postprocess:
        try:
            postprocess(file_name)
        except Exception as e:
            os.removedirs(file_name)
            raise e

def download_file_from_dropbox(
    url: str, file_name: Path | str, postprocess: Optional[callable] = None,
    dry_run: bool = False
):
    print(f"Downloading {url} to {file_name}.")

    if dry_run:
        return

    # Make parent directory
    parent = os.path.dirname(file_name)
    if not os.path.isdir(parent):
        os.makedirs(parent)

    # headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
    r = requests.get(url, stream=True)

    total_size = int(r.headers.get('content-length', 0))
    block_size = 1024

    with tqdm(total=total_size, unit='B', unit_scale=True) as t:
        with open(file_name, 'wb') as f:
            for chunk in r.iter_content(chunk_size=block_size):
                if chunk:
                    t.update(len(chunk))
                    f.write(chunk)

    if postprocess:
        try:
            postprocess(file_name)
        except Exception as e:
            os.remove(file_name)
            raise e

def unzip_contents(path: Path | str):
    # Check if we are directly given a zip file
    if os.path.exists(path) and os.path.splitext(path)[1] == ".zip":
        print(f"Unzipping {path}")
        with ZipFile(path, "r") as zip_ref:
            zip_ref.extractall(os.path.dirname(path))
        os.remove(path)

        unzip_contents(path.removesuffix(".zip"))

    # Otherwise, search directory for zip files
    elif os.path.isdir(path):
        print(f"Searching {path} for zip files...")
        for d in tqdm(listdir_nohidden(path), desc=f"Unzipping {os.path.basename(path)}"):
            zip_path = os.path.join(path, d)
            if zip_path.endswith(".zip"):
                print(f"Unzipping {d}")
                with ZipFile(zip_path, "r") as zip_ref:
                    zip_ref.extractall(path)
                os.remove(zip_path)

            # Recursively unzip the contents
                unzip_contents(path.removesuffix(".zip"))
            elif os.path.isdir(zip_path):
                unzip_contents(zip_path)


In [22]:
import os
import json

from enum import Enum, auto
from zipfile import ZipFile
from pathlib import Path
from typing import Optional



import torch
import torchvision.transforms.functional as F

from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import verify_str_arg
from torch import tensor, Tensor
from torchvision.tv_tensors._bounding_boxes import BoundingBoxes
from torchvision.utils import draw_bounding_boxes

# Download links for the DOTA dataset

DOTA_TRAIN_HBB_URL = (
    "https://www.dropbox.com/scl/fi/k5e9wdfdu4qppyz283nss/"
    "train_hbb.zip?rlkey=wrlr3fpqk8x02r5xzph8uuedk&st=kqaeuv3x&dl=1"
)
DOTA_VAL_HBB_URL = (
    "https://www.dropbox.com/scl/fi/1820a7bhat8b5esv73u6h/"
    "val_hbb.zip?rlkey=0ekae5kjspsq68cbkww8qjkl3&st=3z3ij651&dl=1"
)

DOTA_TRAIN_OBB_URL = (
    "https://www.dropbox.com/scl/fi/zu3p9wqzlu86v0kuu4v0p/"
    "train.zip?rlkey=bwih2x8xd3zpldj7l1s4owy9a&st=zrisaio3&dl=1"
)
DOTA_VAL_OBB_URL = (
    "https://www.dropbox.com/scl/fi/k3d45d22iz1op3gifazw4/"
    "val.zip?rlkey=zv7fgnkf3yqztj93cztzsfsgd&st=ghtu1ntz&dl=1"
)

# Dataset defaults

DEFAULT_DOTA_PATH = Path(os.getcwd()) / "data" / "dota"

DATASET_DICT = {
    ("train", "hbb"): {
        "url": DOTA_TRAIN_HBB_URL,
        "base_dir": os.path.join("hbb", "train_hbb")  # Adjusted to the correct folder name
    },
    ("val", "hbb"): {
        "url": DOTA_VAL_HBB_URL,
        "base_dir": os.path.join("hbb", "val_hbb")  # Make sure this is correct
    },
    ("train", "obb"): {
        "url": DOTA_TRAIN_OBB_URL,
        "base_dir": os.path.join("obb", "train")
    },
    ("val", "obb"): {
        "url": DOTA_VAL_OBB_URL,
        "base_dir": os.path.join("obb", "val") 
    }
}

IMAGES_DIRNAME = "images"
LABELS_DIRNAME = "labels"

class Label(Enum):
    """
    Enum class for the labels in the DOTA dataset. Automatically assigns
    integers to each label.
    """
    large_vehicle = auto()
    small_vehicle = auto()
    plane = auto()
    ship = auto()
    storage_tank = auto()
    baseball_diamond = auto()
    tennis_court = auto()
    basketball_court = auto()
    ground_track_field = auto()
    harbor = auto()
    bridge = auto()
    helicopter = auto()
    roundabout = auto()
    soccer_ball_field = auto()
    swimming_pool = auto()
    container_crane = auto()
    airport = auto()
    helipad = auto()

class Target(dict):
    """
    Dict subclass that mostly acts the same but has a few extra methods for
    convenience and overwrites __str__ and __repr__ to give a more informative
    string representation.
    """
    def __init__(self, boxes = [], labels = [], difficult = [],  **kwargs):
        super().__init__(**kwargs)
        self["boxes"] = boxes
        self["labels"] = labels
        self["difficult"] = difficult

    def __str__(self):
        return json.dumps(
            {
                "attributes": [
                    k for k in self.keys()
                    if k not in ["boxes", "labels", "difficult"]
                ],
                "features": ["boxes", "labels", "difficult"],
                "n_features": len(self['boxes'])
            }
        )

    def __repr__(self):
        return self.__str__()

    def add_attribute(self, key: str, value: str | int | float):
        self[key] = value

    def add_box(self, box: list[float], label: str, difficult: int):
        self["boxes"].append(box)
        self["labels"].append(label)
        self["difficult"].append(difficult)

class DOTA(VisionDataset):
    """
    Class for the DOTA dataset. The dataset is split into train and val
    sets, and the annotations can be either horizontal bounding boxes (hbb)
    or oriented bounding boxes (obb).

    Once downloaded, indexing the dataset will return a tuple of an image
    and a target dictionary. The target dictionary contains the following
    keys:
        - boxes: A BoundingBoxes object containing the bounding boxes.
        - labels: A tensor containing the labels for each bounding box.
        - difficult: A tensor containing the difficulty of each bounding box.
    and any other annotation keys that are present in the target file.
    """
    def __init__(
        self,
        root: Path | str = DEFAULT_DOTA_PATH,
        split: str = "train",
        annotation_type: Optional[str] = "hbb",
        to_tensor: Optional[bool] = True,
        transforms: Optional[callable] = None,
        download: Optional[bool] = True
    ):
        """
        Initialises the DOTA dataset. Gives you either the train or val split
        with either horizontal bounding boxes (hbb) or oriented bounding boxes
        and optional conversion to tensors (required for transforms).

        Args:
            root (Path | str): The root directory to save the dataset. Defaults
                to <current working directory>/data/dota.
            split (str): The split of the dataset to use. Either "train" or
                "val". Defaults to "train".
            annotation_type (str): The type of annotation to use. Either "hbb"
                or "obb". Defaults to "hbb".
            to_tensor (bool): Whether to convert the images and targets to
                tensors. Defaults to True.
            transforms (callable): A callable transform to apply to the images
                and targets. The transform should take in an image and a target
                dictionary and return the transformed image and target. Defaults
                to None. If not none, to_tensor is set to True.
            download (bool): Whether to download the dataset if it doesn't
                exist. Defaults to True.
        """
        super().__init__(root, transforms)

        self.split = verify_str_arg(split, "split", ("train", "val"))
        self.annotation_type = verify_str_arg(
            annotation_type, "annotation_type", ("hbb", "obb")
        )

        # Transforms require the targets to be tensors
        self.to_tensor = to_tensor
        if transforms:
            self.to_tensor = True
        self.transforms = transforms

        # Cannot do transforms or tensors for obb annotations
        if self.annotation_type == "obb":
            self.to_tensor = False
            self.transforms = None

        # Set up the download
        dataset_dict = DATASET_DICT[(self.split, self.annotation_type)]

        self.url = dataset_dict["url"]
        file_path = Path(root) / dataset_dict["base_dir"]

        if download:
            self.download(
                self.url, file_path, postprocess=self.unzip_contents
            )
        if not self.val_files(file_path):
            msg = (
                "The files in the directory are not in the correct format. ",
                "Consider setting download=True to download the files."
            )
            raise ValueError(msg)

        image_dir = file_path / IMAGES_DIRNAME
        label_dir = file_path / LABELS_DIRNAME
        self.images = sorted([
            os.path.join(image_dir, img) for img in listdir_nohidden(image_dir)
        ])
        self.targets = sorted([
            os.path.join(label_dir, tgt) for tgt in listdir_nohidden(label_dir)
        ])

        if len(self.images) != len(self.targets):
            raise ValueError(
                "Number of images and labels do not match.",
                f"Images: {len(self.images)}",
                f"Labels: {len(self.targets)}"
            )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load image and label
        img = Image.open(self.images[idx]).convert("RGB")
        target = self.parse_dota_targets(self.targets[idx])

        # Convert to tensors
        if self.to_tensor:
            img, target = self.to_tensors(img, target)

        # Apply transforms
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    @staticmethod
    def to_tensors(img: Image, target: Target) -> tuple[Tensor, Target]:
        """
        Convert images and targets to tensors.
        """
        img = F.to_tensor(img)
        target = dict(
            boxes=BoundingBoxes(
                target["boxes"],
                format="XYXY",
                canvas_size=img.shape[1:],
                dtype=img.dtype,
                device=img.device
            ),
            labels=tensor([x.value for x  in target["labels"]]),
            difficult=tensor(target["difficult"])
        )
        return img, target

    def draw_bounding_boxes(self, idx: int, width: int = 5) -> Image.Image:
        """
        Draw the bounding boxes on the image at index idx with the given width.
        """
        img, target = self.__getitem__(idx)

        # guarantee we get tensors
        if not self.to_tensor:
            img, target = self.to_tensors(img, target)

        uint8_img = (img * 255).to(torch.uint8)
        return F.to_pil_image(
            draw_bounding_boxes(uint8_img, target["boxes"], width=width)
        )

    @staticmethod
    def val_files(file_path: Path) -> bool:
        """
        Validate that the necessary files are in the correct structure.
        
        Args:
            file_path (Path): The directory path to validate.
            
        Returns:
            bool: True if the validation is successful, False otherwise.
        """
        if not os.path.isdir(file_path):
            print("Directory does not exist:", file_path)
            return False
    
        image_dir = file_path / IMAGES_DIRNAME
        label_dir = file_path / LABELS_DIRNAME
    
        # Check directories exist
        if not os.path.isdir(image_dir) or not os.path.isdir(label_dir):
            print("Required subdirectories do not exist.")
            return False
    
        # Check that all files in images have a corresponding file in labels
        image_files = {os.path.splitext(f)[0]: f for f in os.listdir(image_dir) if f.endswith(".png")}
        label_files = {os.path.splitext(f)[0]: f for f in os.listdir(label_dir) if f.endswith(".txt")}
    
        if not image_files:
            print("No image files found in", image_dir)
            return False
    
        if not label_files:
            print("No label files found in", label_dir)
            return False
    
        # Ensure each image has a corresponding label file
        for base_name in image_files:
            if base_name not in label_files:
                print("Missing label for image:", base_name)
                return False
    
        return True

    @staticmethod
    def unzip_contents(path: Path | str, delete: Optional[bool] = True):
        """
        Directly unzip a single zip file and into a directory with the same name
        """
        dir_path = Path(path).parent
        if os.path.isfile(path) and os.path.splitext(path)[1] == ".zip":
            print(f"Unzipping {path}")
            with ZipFile(path, "r") as zip_ref:
                zip_ref.extractall(dir_path)
            zip_dir = dir_path / listdir_nohidden(dir_path)[0]
            os.rename(zip_dir, path.removesuffix(".zip"))
            if delete:
                os.remove(path)

    @staticmethod
    def download(url: str, path: Path | str, **kwargs: Optional[dict]):
        """
        Download the zip file from dropbox and unzip it into path if path
        doesn't have the correct directory structure.
        """
        if not DOTA.val_files(path):
            download_file_from_dropbox(
                url, f"{path}.zip", **kwargs
            )

    def parse_dota_targets(self, target: str) -> Target:
        """
        Parse the DOTA target file into a Target object.
        """
        res = Target()
        with open(target, "r") as f:
            lines = f.readlines()
            for line in lines:
                ws_split = line.split()
                if len(ws_split) == 1:
                    l0, l1 = ws_split[0].split(":")
                    res.add_attribute(lazy_stof(l0), lazy_stof(l1))
                elif len(ws_split) == 10:
                    if self.annotation_type == "hbb":
                        # if annotation type is hbb, then the box is in the form
                        # [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]
                        box = [float(ws_split[i]) for i in [0, 1, 2, 5]]
                    else:
                        box = [float(x) for x in ws_split[:-2]]
                    label = ws_split[-2].replace("-", "_")
                    label = getattr(Label, label)
                    difficult = int(ws_split[-1])
                    res.add_box(box, label, difficult)
        return res


In [23]:
import torch
from torch.utils.data import DataLoader
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def create_model():
    num_classes = len(Label) + 1  # +1 for the background class
    model = fasterrcnn_resnet50_fpn(pretrained=True)

    # Replace the classifier with a new one that has the right number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

model = create_model()
# Setup device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [33]:
import torch
from torchvision.transforms import v2

transforms = v2.Compose([
    v2.ToImage(),
    v2.RandomResizedCrop(300),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomRotation(180),
    v2.SanitizeBoundingBoxes(),
    v2.ToDtype(torch.float32)
])

In [35]:
# Data Loader
def collate_fn(batch):
    return tuple(zip(*batch))
    
train_dataset = DOTA(split='train', annotation_type='hbb', transforms=transforms)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

In [48]:
val_dataset = DOTA(split='val', annotation_type='hbb', transforms=transforms)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

In [37]:
# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

# Learning rate scheduler
from torch.optim.lr_scheduler import StepLR
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

# Training Loop
def train_one_epoch(model, optimizer, data_loader, device):
    model.train()
    running_loss = 0.0
    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
    return running_loss / len(data_loader)

# Validation Function
def validate(model, data_loader, device):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for images, targets in data_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            running_loss += losses.item()
    return running_loss / len(data_loader)

In [49]:
# Main training and validation loop
num_epochs = 1
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, optimizer, val_loader, device)
    val_loss = validate(model, val_loader, device)
    print(f"Epoch {epoch+1}, Training loss: {train_loss}, Validation loss: {val_loss}")
    scheduler.step()

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.89 GiB. GPU 

In [50]:
accumulation_steps = 4  # Accumulate gradients over 4 forward passes

for epoch in range(num_epochs):
    model.zero_grad()
    for i, (images, targets) in enumerate(val_loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values()) / accumulation_steps

        losses.backward()
        
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            model.zero_grad()

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.79 GiB. GPU 

In [51]:
torch.cuda.memory_summary(device=None, abbreviated=False)

