# Benchmarking

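This tutorial benchmarks the performance of various sampling strategies, with and without caching. For each configuration, the code cell prints two numbers: the wall-clock time in seconds needed to iterate over one full epoch, and the number of batches loaded during that epoch.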

## Imports

In [None]:
import os
import sys
import time
from typing import Tuple

# Add the directory two levels above the working directory to the import path
# (presumably the repository root, so that the local torchgeo checkout is
# importable without installation -- TODO confirm).
sys.path.append("../..")

from torch.utils.data import DataLoader

from torchgeo.datasets import ChesapeakeDE, NAIP
from torchgeo.samplers import RandomGeoSampler, GridGeoSampler, RandomBatchGeoSampler


# Base directory under which the "chesapeake/DE" and "naip" dataset folders are
# looked up by the benchmark cells. This path is machine-specific: change it to
# wherever the data lives in your environment.
ROOT = "/mnt/blobfuse/adam-scratch"

## Timing function

In [None]:
def time_epoch(dataloader: DataLoader) -> Tuple[float, int]:
    """Time a single full pass over a data loader.

    Every batch is fetched and immediately discarded, so the measurement
    covers only the cost of producing the batches.

    Args:
        dataloader: iterable of batches to exhaust

    Returns:
        tuple of (elapsed time in seconds, number of batches yielded)
    """
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # so the result cannot be skewed (or go negative) if the system clock
    # is adjusted while the epoch is running.
    tic = time.perf_counter()
    count = 0
    for _ in dataloader:
        count += 1
    toc = time.perf_counter()
    return toc - tic, count

## RandomGeoSampler

In [None]:
# Time one epoch of random sampling, first with caching off and then with it on.
for cache in (False, True):
    # Both datasets are rebuilt on every pass so that each run uses datasets
    # constructed with the cache setting under test.
    chesapeake = ChesapeakeDE(
        os.path.join(ROOT, "chesapeake", "DE"),
        cache=cache,
    )
    # NAIP is created with the CRS and resolution taken from the Chesapeake dataset.
    naip = NAIP(
        os.path.join(ROOT, "naip"),
        crs=chesapeake.crs,
        res=chesapeake.res,
        cache=cache,
    )
    dataset = chesapeake + naip

    # 888 samples at 12 per batch gives the 74 batches reported in the output.
    sampler = RandomGeoSampler(naip.index, size=1000, length=888)
    dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)

    duration, count = time_epoch(dataloader)
    print(duration, count)

296.582683801651 74
54.20210099220276 74


## GridGeoSampler

In [None]:
# Time one epoch of grid sampling, first with caching off and then with it on.
for cache in (False, True):
    # Both datasets are rebuilt on every pass so that each run uses datasets
    # constructed with the cache setting under test.
    chesapeake = ChesapeakeDE(
        os.path.join(ROOT, "chesapeake", "DE"),
        cache=cache,
    )
    # NAIP is created with the CRS and resolution taken from the Chesapeake dataset.
    naip = NAIP(
        os.path.join(ROOT, "naip"),
        crs=chesapeake.crs,
        res=chesapeake.res,
        cache=cache,
    )
    dataset = chesapeake + naip

    # The sampler is built from the NAIP index with size 1000 and stride 500.
    sampler = GridGeoSampler(naip.index, size=1000, stride=500)
    dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)

    duration, count = time_epoch(dataloader)
    print(duration, count)

391.90197944641113 74
118.0611424446106 74


## RandomBatchGeoSampler

In [None]:
# Time one epoch of random batch sampling, first with caching off and then on.
for cache in (False, True):
    # Both datasets are rebuilt on every pass so that each run uses datasets
    # constructed with the cache setting under test.
    chesapeake = ChesapeakeDE(
        os.path.join(ROOT, "chesapeake", "DE"),
        cache=cache,
    )
    # NAIP is created with the CRS and resolution taken from the Chesapeake dataset.
    naip = NAIP(
        os.path.join(ROOT, "naip"),
        crs=chesapeake.crs,
        res=chesapeake.res,
        cache=cache,
    )
    dataset = chesapeake + naip

    # The batch size is given to the sampler here, which is then handed to the
    # DataLoader as `batch_sampler` rather than `sampler`.
    sampler = RandomBatchGeoSampler(
        naip.index,
        size=1000,
        batch_size=12,
        length=888,
    )
    dataloader = DataLoader(dataset, batch_sampler=sampler)

    duration, count = time_epoch(dataloader)
    print(duration, count)

230.51380324363708 74
53.99923872947693 74
