-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocessing.py
75 lines (61 loc) · 2.12 KB
/
preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
The image preprocessor script
Author: Abdelkarim eljandoubi
date: Nov 2023
"""
from transformers import AutoImageProcessor
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)
import torch
def preprocessing(model_checkpoint):
    """
    Build the train and validation batch preprocessors for a model checkpoint.

    The image processor attached to ``model_checkpoint`` is loaded through
    ``AutoImageProcessor.from_pretrained``; its ``image_mean``/``image_std``
    drive the normalization step and its ``size`` gives the target edge
    length shared by both pipelines.

    Args:
        model_checkpoint: identifier forwarded unchanged to
            ``AutoImageProcessor.from_pretrained``.

    Returns:
        A 3-tuple ``(preprocess_train, preprocess_val, image_processor)``.
        The two callables take a batch dict holding an ``"image"`` entry,
        add a ``"pixel_values"`` entry (one transformed image per input
        image) to that same dict and return it.

    Raises:
        KeyError: if the processor's ``size`` provides neither a
            ``"height"`` nor a ``"shortest_edge"`` value.
    """
    # load the image preprocessor
    image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)

    # Resolve the target edge length. "height" keeps the historical
    # behaviour; "shortest_edge" is the fallback for processors that do not
    # describe their size as height/width. Only subscripting is used so any
    # mapping-like ``size`` object is accepted.
    size_config = image_processor.size
    target_size = None
    for size_key in ("height", "shortest_edge"):
        try:
            candidate = size_config[size_key]
        except KeyError:
            continue
        if candidate is not None:
            target_size = candidate
            break
    if target_size is None:
        raise KeyError(
            "image processor size must define 'height' or 'shortest_edge', "
            f"got {size_config!r}"
        )

    # define the normalization
    normalize = Normalize(
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    )

    # define the train preprocessing pipeline (random crop + random flip)
    train_transforms = Compose(
        [
            RandomResizedCrop(target_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

    # define the validation preprocessing pipeline (resize + center crop,
    # no random transform)
    val_transforms = Compose(
        [
            Resize(target_size),
            CenterCrop(target_size),
            ToTensor(),
            normalize,
        ]
    )

    def preprocess_train(example_batch: dict) -> dict:
        """Apply train_transforms across a batch."""
        # Images are converted to RGB before being transformed.
        example_batch["pixel_values"] = [
            train_transforms(image.convert("RGB"))
            for image in example_batch["image"]
        ]
        return example_batch

    def preprocess_val(example_batch: dict) -> dict:
        """Apply val_transforms across a batch."""
        # Images are converted to RGB before being transformed.
        example_batch["pixel_values"] = [
            val_transforms(image.convert("RGB"))
            for image in example_batch["image"]
        ]
        return example_batch

    return preprocess_train, preprocess_val, image_processor
def collate_fn(examples: list) -> dict:
    """
    Collate per-example dicts into a single batch dict.

    Args:
        examples: sequence of dicts, each holding a ``"pixel_values"``
            tensor and a ``"label"`` value. The tensors are combined with
            ``torch.stack``, so they are expected to share one shape.

    Returns:
        A dict with ``"pixel_values"`` (the per-example tensors stacked
        along a new leading dimension) and ``"labels"`` (a tensor built
        from the per-example ``"label"`` values, in input order).
    """
    pixel_values = torch.stack(
        [example["pixel_values"] for example in examples]
    )
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}