Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a data augmentation method PatchMix for Contrastive Learning and ViT #1873

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added tests/data/0/0.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/0/1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/0/2.JPEG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/0/3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/0/4.JPEG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/0/5.JPEG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/0/6.JPEG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/0/7.JPEG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
48 changes: 48 additions & 0 deletions tests/test_patchmix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from torchvision.transforms.functional import to_pil_image
import torch
from timm.data.patchmix import PatchMix


def cpu_and_cuda():
    """Return the device values used to parametrize tests.

    Returns:
        tuple: ``"cpu"`` as a plain string, followed by a ``pytest.param``
        wrapping ``"cuda"`` that carries the ``needs_cuda`` mark.
    """
    import pytest  # noqa

    cuda_case = pytest.param("cuda", marks=pytest.mark.needs_cuda)
    return ("cpu", cuda_case)


def needs_cuda(test_func):
    """Decorator that attaches the ``needs_cuda`` pytest mark to ``test_func``.

    Args:
        test_func: the test function (or class) to mark.

    Returns:
        Whatever ``pytest.mark.needs_cuda(test_func)`` returns.
    """
    import pytest  # noqa

    cuda_marker = pytest.mark.needs_cuda
    return cuda_marker(test_func)


@pytest.mark.parametrize("batch_size", (4, 7))
@pytest.mark.parametrize("prob", (1.0, 0.5, 0.0))
@pytest.mark.parametrize("mix_num", (1, 2, 3))
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_patchmix(batch_size, prob, mix_num, device, tmp_path):
    """Run PatchMix on one batch of the bundled test images and check its outputs.

    The ``cuda`` variant is already marked by ``cpu_and_cuda()``; the whole test
    is deliberately NOT wrapped in ``@needs_cuda`` so that CPU variants are never
    treated as CUDA-only. The side-by-side visualisation is written to pytest's
    ``tmp_path`` instead of the current working directory.
    """
    import os  # local import keeps this block self-contained

    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    num_classes = 10
    # Resolve the image folder relative to this file so the test does not
    # depend on the directory pytest is launched from.
    data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
    data_set = datasets.ImageFolder(
        root=data_root,
        transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
    )
    # A handful of images does not justify spawning worker processes.
    data_loader = DataLoader(dataset=data_set, batch_size=batch_size, num_workers=0, shuffle=True)

    patchmix = PatchMix(num_classes, prob, mix_num, 16)

    # Only the first batch is needed.
    images, _ = next(iter(data_loader))
    b, c, h, w = images.shape
    images = images.to(device)
    # Size the target from the real batch, which may be smaller than batch_size.
    target = torch.arange(b).to(device)

    org_img = images.permute(1, 0, 2, 3).reshape(c, b * h, w)
    mix_img, mo_target, mm_target = patchmix(images, target)

    assert mix_img.shape == images.shape
    assert mo_target.shape == (b, num_classes)
    assert mm_target.shape == (b, num_classes)
    # With smoothing disabled and mix_num <= batch size, every mix-to-original
    # row distributes a total weight of 1 over its source images.
    assert torch.allclose(mo_target.sum(dim=1), torch.ones(b, device=mo_target.device))

    # Stack originals (left) and mixed images (right) for visual inspection.
    stacked_mix = mix_img.permute(1, 0, 2, 3).reshape(c, b * h, w)
    result = torch.cat([org_img, stacked_mix], dim=-1).cpu()
    out_file = tmp_path / f"bs{batch_size}_p{prob}_n{mix_num}_{device}.png"
    to_pil_image(result).save(str(out_file))
    assert out_file.exists()
122 changes: 122 additions & 0 deletions timm/data/patchmix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""PatchMix

Papers:
Inter-Instance Similarity Modeling for Contrastive Learning (https://arxiv.org/abs/2306.12243)

Code Reference:
PatchMix: https://github.com/visresearch/patchmix
"""
import torch
import random
import numpy as np


def one_hot(x, num_classes, on_value=1.0, off_value=0.0):
    """Scatter ``on_value`` into a ``(rows, num_classes)`` tensor filled with ``off_value``.

    Args:
        x (Tensor): integer index tensor used as the ``scatter_`` index along
            dim 1; its first dimension gives the number of output rows.
        num_classes (int): width of the output tensor.
        on_value: scalar or tensor written at the positions listed in ``x``.
        off_value: fill value for every other position.

    Returns:
        Tensor: the filled tensor, created on ``x.device``.
    """
    num_rows = x.size()[0]
    encoded = torch.full((num_rows, num_classes), off_value, device=x.device)
    encoded.scatter_(1, x, on_value)
    return encoded


def random_indexes(size):
    """Draw a random permutation of ``range(size)`` together with its inverse.

    The permutation is produced with NumPy's global random state.

    Args:
        size (int): number of indexes to permute.

    Returns:
        tuple: ``(forward, backward)`` NumPy arrays, where ``forward`` is the
        shuffled index array and ``backward`` is ``argsort(forward)``, so that
        ``forward[backward]`` is the identity ordering.
    """
    permutation = np.arange(size)
    np.random.shuffle(permutation)
    inverse = np.argsort(permutation)
    return permutation, inverse


def take_indexes(x, indexes):
    """Gather entries of ``x`` along dim 1 according to ``indexes``.

    Args:
        x (Tensor): 3-D source tensor.
        indexes (Tensor): 2-D integer tensor; each entry selects a position
            along dim 1 of ``x`` and is repeated across the last dimension.

    Returns:
        Tensor: the gathered tensor.
    """
    last_dim = x.shape[-1]
    # Repeat every index across the last dimension so gather picks whole rows.
    gather_index = indexes.unsqueeze(-1).repeat(1, 1, last_dim)
    return torch.gather(x, 1, gather_index)


class PatchMix:
    """PatchMix that applies the same mixing parameters to a whole batch.

    Each image is cut into ``patch_size`` x ``patch_size`` patches and the patch
    order is shuffled with one permutation shared by the batch. The shuffled
    sequence is split into ``mix_num`` groups, and group ``k`` of image ``i`` is
    replaced by group ``k`` of image ``(i + k) % batch``. The shuffle is then
    undone and the patches are folded back into images.

    Args:
        num_classes (int): the number of categories in contrastive learning, it is
            generally the sum of batch sizes across all nodes.
        prob (float): the probability of performing patch mix. Default is ``0.5``.
        mix_num (int): the number of original images included in each set of mixed
            images. Default is ``2``.
        patch_size (int): size of an image patch. Default is ``16``.
        smoothing (float): coefficient of label smoothing. Default is ``0.``.
    """

    def __init__(self, num_classes, prob=0.5, mix_num=2, patch_size=16, smoothing=0.0):
        super().__init__()
        self.prob = prob
        self.mix_num = mix_num
        self.patch_size = patch_size
        self.smoothing = smoothing
        self.num_classes = num_classes

    def _shuffle(self, x):
        """Shuffle dim 1 of ``x`` with a single permutation shared by the batch.

        Returns:
            tuple: the shuffled tensor, the forward indexes and the backward
            (inverse) indexes, both repeated to shape ``(batch, length)``.
        """
        batch, length = x.shape[:2]
        forward, backward = random_indexes(length)
        forward_indexes = torch.as_tensor(forward, dtype=torch.long).to(x.device).repeat(batch, 1)
        backward_indexes = torch.as_tensor(backward, dtype=torch.long).to(x.device).repeat(batch, 1)
        x = take_indexes(x, forward_indexes)
        return x, forward_indexes, backward_indexes

    def _mix(self, x, m):
        """Exchange patch groups between the images of a batch.

        Args:
            x (Tensor): patch sequences of shape ``(b, l, c)``; modified in place.
            m (int): number of groups each sequence is split into.

        Returns:
            tuple: the mixed sequences, the mix-to-original source positions of
            shape ``(b, m)`` and the mix-to-mix positions of shape ``(b, 2m - 1)``.
        """
        b, l, c = x.shape
        s = l // m  # patches per group
        d = b * m  # total number of groups in the batch
        l_ = int(s * m)
        # get the image sequence that needs to be mixed, and drop the last.
        mix_x = x[:, :l_].reshape(d, s, c)
        # generate the mix index for mixing patch group: group k of image i is
        # taken from group k of image (i + k) % b.
        ids = torch.arange(d, device=x.device)
        mix_indexes = (ids + ids % m * m) % d
        # Selecting whole rows gives the same result as gathering with a fully
        # expanded (d, s, c) index, without materialising that index.
        mix_x = mix_x.index_select(0, mix_indexes)
        x[:, :l_] = mix_x.reshape(b, l_, c)
        # generate the mix index for mixing target.
        ids = torch.arange(b, device=x.device).view(-1, 1)
        m2o_indexes = (ids + torch.arange(m, device=x.device)) % b
        m2m_indexes = ((ids - m + 1) + torch.arange(m * 2 - 1, device=x.device) + b) % b
        return x, m2o_indexes, m2m_indexes

    def __call__(self, x, target):
        """Mix a batch of images and build the matching soft targets.

        Args:
            x (Tensor): images of shape ``(b, c, h, w)``. When mixing is applied,
                ``h`` and ``w`` must be divisible by ``patch_size``, otherwise
                the patch reshape fails.
            target (Tensor): 1-D target for contrastive learning, one label per
                image; labels index into ``num_classes``.

        Returns:
            x (Tensor): mixed image (the input itself when no mixing is applied).
            m2o_target (Tensor): ``(b, num_classes)`` target between mixed images
                and original images in infoNCE loss.
            m2m_target (Tensor): ``(b, num_classes)`` target between mixed images
                and mixed images in infoNCE loss.
        """
        b, c, h, w = x.shape
        m = self.mix_num
        # We only use patch mix when m is greater than 1
        use_mix = random.random() < self.prob and m > 1
        if use_mix:
            p = self.patch_size
            n_h = h // p
            n_w = w // p
            # b c (h p1) (w p2) -> b (h w) (c p1 p2)
            x = x.reshape(b, c, n_h, p, n_w, p).permute(0, 2, 4, 1, 3, 5).reshape(b, n_h * n_w, c * p * p)
            x, _, backward_indexes = self._shuffle(x)
            x, m2o_indexes, m2m_indexes = self._mix(x, m)
            x = take_indexes(x, backward_indexes)
            # b (h w) (c p1 p2) -> b c (h p1) (w p2)
            x = x.reshape(b, n_h, n_w, c, p, p).permute(0, 3, 1, 4, 2, 5).reshape(b, c, n_h * p, n_w * p)
        else:
            m = 1
            # Every image maps to itself. These are batch positions, not label
            # values: indexing ``target`` with its own labels is only correct
            # when target == arange(b) and goes out of bounds for offset labels.
            m2o_indexes = torch.arange(b, device=x.device).view(-1, 1)
            m2m_indexes = m2o_indexes

        # get mixed target for mix-to-org loss and mix-to-mix loss
        m2o_target = target[m2o_indexes]
        m2m_target = target[m2m_indexes]

        off_value = self.smoothing / self.num_classes
        true_num = m2o_target.shape[1]
        on_value = (1.0 - self.smoothing) / true_num + off_value
        m2o_target = one_hot(m2o_target, self.num_classes, on_value, off_value)

        # Mix-to-mix weights rise linearly to 1 at the centre entry (the image
        # itself) and fall off symmetrically on both sides.
        ids = torch.arange(m2m_target.shape[1], device=x.device)
        weights = 1.0 - torch.abs(m - ids - 1) / m
        on_value = (1.0 - self.smoothing) * weights / m + off_value
        m2m_target = one_hot(m2m_target, self.num_classes, on_value.expand([m2m_target.shape[0], -1]), off_value)

        return x, m2o_target, m2m_target