# Adapters for Transformer -- CV Edition
## Setup

In [None]:
!pip install torchvision==0.14
!pip install matplotlib
!pip install "squirrel-core[torch]"
!pip install "squirrel-datasets-core[torchvision]"

!pip install transformers

### Download evaluation dataset

Download the evaluation dataset from

https://drive.google.com/drive/folders/1bMct3K76RTjycmkmMr8RPw8axRKqz2JD?usp=share_link

and unzip it into the working directory so that the data ends up in `./squirrel_middlebury_patched` (the path read by `get_dataloader_eval` below).

### Download training dataset
Download the Sintel dataset from

http://files.is.tue.mpg.de/jwulff/sintel/MPI-Sintel-stereo-training-20150305.zip

and unzip it into the working directory (`./Sintel`).
The folder name must match exactly, because the training dataloader will only find the data under this name.

## Code

In [2]:
import torch
import torch.utils.data as tud
from torch.utils.data._utils.collate import default_collate as torch_default_collate

import torchvision.transforms as tr

import matplotlib.pyplot as plt

from squirrel.driver import MessagepackDriver
from squirrel_datasets_core.driver import TorchvisionDriver

Matplotlib is building the font cache; this may take a moment.


### Prepare dataloaders

In [3]:
def get_dataloader_eval(batch_size: int, url: str = "./squirrel_middlebury_patched") -> tud.DataLoader:
    """Build a DataLoader over the evaluation/test dataset.

    Args:
        batch_size: Number of samples per batch. The final batch may be smaller,
            because incomplete batches are kept (``drop_last_if_not_full=False``).
        url: Path to the unzipped data folder containing the ``*.gz`` files read
            by ``MessagepackDriver``. Defaults to ``./squirrel_middlebury_patched``.

    Returns:
        A ``torch.utils.data.DataLoader`` that yields already-collated batches.
        Batching happens inside the squirrel pipeline, so the DataLoader is
        created with ``batch_size=None`` (its own automatic batching disabled).
    """
    # Get an iterator over the stored samples from the driver.
    driver = MessagepackDriver(url)
    it = driver.get_iter()

    #############################
    ## YOUR PREPROCESSING HERE ##
    preprocess = tr.Compose([
        lambda x: x  # identity placeholder: replace with real transforms
    ])
    #############################

    dataset = (
        it
        .map(preprocess)
        .batched(batch_size, torch_default_collate, drop_last_if_not_full=False)
        .to_torch_iterable()
    )
    # Batches are already formed by the pipeline above, so the DataLoader must
    # not batch again; it also does not shuffle (shuffle left unspecified).
    return tud.DataLoader(dataset, shuffle=None, batch_size=None)

In [4]:
def get_dataloader_train(
    batch_size: int,
    shuffe_size: int = 100,
    num_workers: int = 0,
    url: str = "./",
) -> tud.DataLoader:
    """Build a DataLoader over the Sintel stereo training data.

    Args:
        batch_size: Number of samples per batch. Incomplete final batches are
            dropped (``drop_last_if_not_full=True``).
        shuffe_size: Size of the squirrel shuffle buffer. (The name is
            misspelled but kept as-is so existing keyword callers keep working.)
        num_workers: Number of DataLoader worker processes.
        url: Path to the folder *containing* the previously downloaded
            ``Sintel`` folder. Defaults to the current working directory.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding already-collated batches.
        Batching happens in the squirrel pipeline, hence ``batch_size=None``.
    """
    driver = TorchvisionDriver("SintelStereo", url=url)
    it = driver.get_iter()

    dataset = (
        it
        # Shard across DataLoader workers *before* shuffling. Each worker runs
        # its own replica of this pipeline; if the stream were shuffled first
        # (with per-worker random state), each worker would pick its share from
        # a differently ordered stream, so samples could be duplicated or
        # skipped within an epoch.
        .split_by_worker_pytorch()
        .shuffle(shuffe_size)
        #############################################################
        ### YOUR PREPROCESSING, COLLATING, AUGMENTATION, ETC. HERE ##
        #############################################################
        .batched(batch_size, torch_default_collate, drop_last_if_not_full=True)
        .to_torch_iterable()
    )
    return tud.DataLoader(dataset, shuffle=None, batch_size=None, num_workers=num_workers)

### Sanity-check data loaders

In [11]:
batch_size = 32  # samples per batch for the evaluation sanity check below
dl_eval = get_dataloader_eval(batch_size=batch_size)

In [12]:
# Plot the first left/right image pair of the first few evaluation batches.
# NOTE(review): assumes each batch is a dict whose "img_l"/"img_r" entries are
# (batch, C, H, W) tensors -- hence the permute to (H, W, C) for imshow.
# Confirm against the dataset schema.
n_preview_batches = 3

for i, d in enumerate(dl_eval):
    img_l = d["img_l"][0]
    img_r = d["img_r"][0]

    fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)
    for axis, img, side in ((ax[0], img_l, "left"), (ax[1], img_r, "right")):
        axis.imshow(img.permute(1, 2, 0).numpy())
        # .item() turns the 0-d min/max tensors into plain numbers, so the title
        # does not contain the verbose "tensor(..., dtype=...)" repr.
        axis.set_title(
            f"{side}: {tuple(img.shape)}, {img.dtype}\n"
            f"min={img.min().item()}, max={img.max().item()}",
            fontsize="small",
        )
    fig.tight_layout()
    plt.show()  # render (and release) each figure as it is produced

    if i + 1 >= n_preview_batches:
        break

## Your Code

### Baseline Evaluation

### Adapter Fine Tuning