In [None]:
import os
import os.path as osp

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as pd
import torch
from PIL import Image
from skimage import io
from torchvision.transforms import Compose, ToTensor

In [None]:
from main_dino import DataAugmentationDINO

In [None]:
# Configuration: every value a reader may want to tune lives in this cell.
config = 'configs/config.yaml'  # augmentation config passed to DataAugmentationDINO
# Root of the single-cell dataset. Override it with the DINO4CELL_DATA_ROOT environment
# variable instead of editing the notebook; the default is the original location.
data_root = os.environ.get(
    'DINO4CELL_DATA_ROOT',
    '/checkpoint/bojanowski/data/DINO4CELL/2022_02_09_single_cell',
)
# The trailing '' component keeps the trailing path separator of the original value.
data_path = osp.join(data_root, 'fixed_size_masked_single_cells', '')
index_im = 1  # row of the metadata CSV whose image is visualised below

In [None]:
# Per-cell metadata table. It sits next to the image folder and shares its name, so its
# path is derived from `data_path` instead of repeating the absolute path literal.
csv_path = osp.normpath(data_path) + '.csv'
data = pd.read_csv(csv_path)

In [None]:
# Build the DINO augmentation pipelines from the config path set in the configuration cell.
transform = DataAugmentationDINO(config=config)

In [None]:
# Drop the final step of each pipeline (the to-tensor / normalize stage, per the original
# note) so the outputs can be displayed directly.
# NOTE(review): this cell is not idempotent -- re-running it drops one more step each time.
for pipeline_name in ('global_transfo1', 'global_transfo2', 'local_transfo'):
    steps = getattr(transform, pipeline_name).transforms
    setattr(transform, pipeline_name, Compose(steps[:-1]))

In [None]:
# Only the file name of the CSV's 'file' entry is kept; it is joined onto `data_path`.
file_name = data['file'].iloc[index_im].rsplit('/', 1)[-1]
path = osp.join(data_path, file_name)

In [None]:
# Source image that every augmentation below is applied to.
im = io.imread(fname=path)

## Full image

In [None]:
# Untransformed image as loaded from disk, shown as the reference for the augmentations.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(im)
ax.set_title('Full image')
plt.show()  # render the figure instead of echoing the AxesImage repr

## First Global data augmentation

In [None]:
# 12 x 12 grid: each panel is a fresh call of the first global pipeline on the same image.
fig, axs = plt.subplots(12, 12, figsize=(50, 50))
for panel in axs.flat:  # row-major, same order as nested row/column loops
    augmented = transform.global_transfo1(im)
    panel.imshow(augmented.permute(1, 2, 0).numpy())
plt.show()

## Second Global data augmentation

In [None]:
# 12 x 12 grid: each panel is a fresh call of the second global pipeline on the same image.
fig, axs = plt.subplots(12, 12, figsize=(50, 50))
for panel in axs.flat:  # row-major, same order as nested row/column loops
    augmented = transform.global_transfo2(im)
    panel.imshow(augmented.permute(1, 2, 0).numpy())
plt.show()

## Local data augmentation

In [None]:
# 12 x 12 grid: each panel is a fresh call of the local pipeline on the same image.
fig, axs = plt.subplots(12, 12, figsize=(50, 50))
for panel in axs.flat:  # row-major, same order as nested row/column loops
    augmented = transform.local_transfo(im)
    panel.imshow(augmented.permute(1, 2, 0).numpy())
plt.show()

## Global Crop

In [None]:
# Split the first global pipeline at index 4: steps 0-3 are replayed as-is, and step 4
# (presumably the RandomResizedCrop, given how it is used below) is sampled separately.
global_steps = transform.global_transfo1.transforms
rrc = global_steps[4]
before_rrc = Compose(global_steps[:4])

In [None]:
# Replay every step that precedes the crop, then overlay boxes sampled exactly as the crop
# transform samples them, to show where global crops land on the image.
img = io.imread(path)
img = before_rrc(img)
if not isinstance(img, torch.Tensor):
    # Convert PIL / ndarray output to a CxHxW tensor so it can be permuted for display.
    img = ToTensor()(img)

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(img.permute(1, 2, 0).numpy())

n_boxes = 20
colors = ['r', 'g', 'b', 'pink', 'orange']
for k in range(n_boxes):
    # get_params returns (top, left, height, width); Rectangle takes its (x, y) corner,
    # i.e. (left, top).
    top, left, height, width = rrc.get_params(img, rrc.scale, rrc.ratio)
    box = patches.Rectangle((left, top), width, height, linewidth=1,
                            edgecolor=colors[k % len(colors)], facecolor='none')
    ax.add_patch(box)

ax.set_title(f'{n_boxes} sampled global crop boxes')
plt.show()

## Local Crop

In [None]:
# Same split for the local pipeline: steps 0-3 replayed, step 4 (presumably the
# RandomResizedCrop) sampled separately.
local_steps = transform.local_transfo.transforms
rrc = local_steps[4]
before_rrc = Compose(local_steps[:4])

In [None]:
# Narrow the sampled crop-area range for the local preview.
# NOTE: `rrc` is the same object stored in transform.local_transfo.transforms[4], so this
# also changes every later local augmentation drawn in this notebook.
local_crop_scale = [0.05, 0.4]
rrc.scale = local_crop_scale

In [None]:
# Replay every step that precedes the local crop, then overlay boxes sampled exactly as the
# crop transform samples them (using rrc.scale as overridden in this section).
img = io.imread(path)
img = before_rrc(img)
if not isinstance(img, torch.Tensor):
    # Convert PIL / ndarray output to a CxHxW tensor so it can be permuted for display.
    img = ToTensor()(img)

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(img.permute(1, 2, 0).numpy())

n_boxes = 20
colors = ['r', 'g', 'b', 'pink', 'orange']
for k in range(n_boxes):
    # get_params returns (top, left, height, width); Rectangle takes its (x, y) corner,
    # i.e. (left, top).
    top, left, height, width = rrc.get_params(img, rrc.scale, rrc.ratio)
    box = patches.Rectangle((left, top), width, height, linewidth=1,
                            edgecolor=colors[k % len(colors)], facecolor='none')
    ax.add_patch(box)

ax.set_title(f'{n_boxes} sampled local crop boxes')
plt.show()

## Remove Brightness and contrast

In [None]:
# Keep only the first five steps of each pipeline, dropping the trailing brightness /
# contrast steps named in the section header. Slicing to [:5] is safe to re-run.
for pipeline_name in ('global_transfo1', 'global_transfo2', 'local_transfo'):
    steps = getattr(transform, pipeline_name).transforms
    setattr(transform, pipeline_name, Compose(steps[:5]))

## First Global data augmentation without brightness and contrast

In [None]:
# 12 x 12 grid of fresh draws from the truncated first global pipeline.
fig, axs = plt.subplots(12, 12, figsize=(50, 50))
for panel in axs.flat:  # row-major, same order as nested row/column loops
    augmented = transform.global_transfo1(im)
    panel.imshow(augmented.permute(1, 2, 0).numpy())
plt.show()

## Local data augmentation without brightness and contrast (crop scale [0.05, 0.4])

In [None]:
# 12 x 12 grid of fresh draws from the truncated local pipeline (its crop step still
# carries the scale override applied earlier).
fig, axs = plt.subplots(12, 12, figsize=(50, 50))
for panel in axs.flat:  # row-major, same order as nested row/column loops
    augmented = transform.local_transfo(im)
    panel.imshow(augmented.permute(1, 2, 0).numpy())
plt.show()

## Remove RRC

In [None]:
# Keep only the first four steps of each pipeline, i.e. also drop step 4 (the crop).
# Slicing to [:4] is safe to re-run.
for pipeline_name in ('global_transfo1', 'global_transfo2', 'local_transfo'):
    steps = getattr(transform, pipeline_name).transforms
    setattr(transform, pipeline_name, Compose(steps[:4]))

In [None]:
# Display the first global pipeline as it stands after truncation to its first four steps.
transform.global_transfo1

## First Global data augmentation without RRC, brightness and contrast

In [None]:
# 12 x 12 grid of fresh draws from the truncated first global pipeline. Per-sample value
# statistics are collected into one table instead of printing four lines per panel.
# matplotlib clips float RGB data outside [0, 1], so the min/max columns show whether a
# panel is being clipped.
fig, axs = plt.subplots(12, 12, figsize=(50, 50))
records = []
for i in range(12):
    for j in range(12):
        tensor_im = transform.global_transfo1(im)
        records.append({
            'row': i,
            'col': j,
            'channel_mean': tensor_im.mean(dim=(-1, -2)).tolist(),
            'max': tensor_im.max().item(),
            'min': tensor_im.min().item(),
        })
        axs[i, j].imshow(tensor_im.permute(1, 2, 0).numpy())
plt.show()
aug_stats = pd.DataFrame(records)
aug_stats[['min', 'max']].describe()