In [1]:
import os
import random
import time
import json
import warnings 
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from utils import label_accuracy_score
import cv2

import numpy as np
import pandas as pd

# 전처리를 위한 라이브러리
from pycocotools.coco import COCO
import torchvision
import torchvision.transforms as transforms

import albumentations as A
from albumentations import *
from albumentations.pytorch import ToTensorV2
from albumentations.augmentations import functional as F

# 시각화를 위한 라이브러리
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

plt.rcParams['axes.grid'] = False

# Report the runtime environment.
print('pytorch version: {}'.format(torch.__version__))
print('GPU 사용 가능 여부: {}'.format(torch.cuda.is_available()))

# Querying the name of device 0 raises when no CUDA device exists, so only do
# it when CUDA is available; device_count() simply reports 0 in that case.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
print(torch.cuda.device_count())

# Store the device string depending on whether a GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

pytorch version: 1.4.0
GPU 사용 가능 여부: True
Tesla P40
1


In [2]:
# Training hyperparameters shared by the cells below.
batch_size = 16        # mini-batch size
num_epochs = 20        # number of training epochs
learning_rate = 1e-4   # optimizer step size

In [3]:
# Fix every random seed used in this notebook so runs are repeatable.
random_seed = 21

# Python's and NumPy's generators.
random.seed(random_seed)
np.random.seed(random_seed)

# PyTorch generators.
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
# torch.cuda.manual_seed_all(random_seed) # if use multi-GPU

# cuDNN flags: no autotuned kernel selection, deterministic kernels only.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
%matplotlib inline

dataset_path = '/opt/ml/input/data'
anns_file_path = dataset_path + '/' + 'train.json'

# Read the COCO-style annotation file.
with open(anns_file_path, 'r') as f:
    dataset = json.load(f)

categories = dataset['categories']
anns = dataset['annotations']
imgs = dataset['images']
nr_cats = len(categories)
nr_annotations = len(anns)
nr_images = len(imgs)

# Load categories and super categories.
cat_names = []
super_cat_names = []
super_cat_ids = {}
super_cat_last_name = ''
nr_super_cats = 0
for cat_it in categories:
    cat_names.append(cat_it['name'])
    super_cat_name = cat_it['supercategory']
    # Register a supercategory the first time it is seen. Checking membership
    # (instead of comparing with the previous entry only) keeps names unique and
    # ids stable even when categories of one supercategory are not contiguous.
    if super_cat_name not in super_cat_ids:
        super_cat_names.append(super_cat_name)
        super_cat_ids[super_cat_name] = nr_super_cats
        nr_super_cats += 1
    super_cat_last_name = super_cat_name

In [1]:
# Count annotations per category.
# NOTE(review): indexing by ann['category_id'] assumes ids are 0..nr_cats-1 — confirm.
cat_histogram = np.zeros(nr_cats,dtype=int)
for ann in anns:
    cat_histogram[ann['category_id']] += 1

# Initialize the matplotlib figure
f, ax = plt.subplots(figsize=(5,5))

# Convert to DataFrame, most frequent category first.
# `ascending` is passed by keyword: newer pandas releases accept only `by`
# positionally in sort_values.
df = pd.DataFrame({'Categories': cat_names, 'Number of annotations': cat_histogram})
df = df.sort_values('Number of annotations', ascending=False)

# Plot the histogram on the axes created above.
ax.set_title("category distribution of train set ")
plot_1 = sns.barplot(x="Number of annotations", y="Categories", data=df, label="Total", color="b", ax=ax)

In [6]:
# Category labeling: restore the original category order (by row index).
sorted_temp_df = df.sort_index()

# Prepend the label for background (= 0) so every existing category moves to
# label + 1. pd.concat is used because DataFrame.append no longer exists in
# current pandas; the resulting rows and column order are the same.
# NOTE(review): the "Backgroud" spelling is kept as-is because it is a runtime label.
sorted_df = pd.DataFrame(["Backgroud"], columns = ["Categories"])
sorted_df = pd.concat([sorted_df, sorted_temp_df], ignore_index=True)

In [2]:
# Check the index assigned to each class (Categories): 0~11, 12 classes in total
sorted_df

In [8]:
# Class names ordered by label index (position in the list == pixel value).
category_names = sorted_df["Categories"].tolist()

def get_classname(classID, cats):
    """Return the name of the first category in `cats` whose 'id' equals `classID`.

    `cats` is a sequence of dicts carrying 'id' and 'name' keys. When no entry
    matches, the string "None" is returned.
    """
    matches = (cat['name'] for cat in cats if cat['id'] == classID)
    return next(matches, "None")

class CustomDataLoader(Dataset):
    """Segmentation dataset backed by a COCO-format annotation file.

    Args:
        data_dir: path of the COCO annotation json handed to `COCO(...)`.
        mode: 'train' / 'val' return image and mask, 'test' returns image only.
        transform: optional albumentations transform, called with `image=` (and
            `mask=` in train/val mode).

    Items:
        train/val: (path, image, mask, image_info)
        test:      (path, image, image_info)

    Image files are resolved relative to the module-level `dataset_path`, and
    mask pixel values are looked up in the module-level `category_names`.
    """

    _VALID_MODES = ('train', 'val', 'test')

    def __init__(self, data_dir, mode = 'train', transform = None):
        super().__init__()
        # Fail fast: an unknown mode would otherwise make __getitem__ return None.
        if mode not in self._VALID_MODES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._VALID_MODES, mode))
        self.mode = mode
        self.transform = transform
        self.coco = COCO(data_dir)
        # Category id -> name table, built once instead of reloading and
        # scanning the category list for every annotation of every sample.
        self._cat_id_to_name = {
            cat['id']: cat['name']
            for cat in self.coco.loadCats(self.coco.getCatIds())
        }

    def __getitem__(self, index: int):
        # The index is passed straight to COCO as an image id.
        # NOTE(review): this assumes image ids are 0..len-1 — confirm for new json files.
        image_id = self.coco.getImgIds(imgIds=index)
        image_infos = self.coco.loadImgs(image_id)[0]

        # Load the image with cv2 and convert BGR -> RGB.
        paths = os.path.join(dataset_path, image_infos['file_name'])
        images = cv2.imread(paths)
        if images is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError("cannot read image: {}".format(paths))
        images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.uint8)

        if (self.mode in ('train', 'val')):
            ann_ids = self.coco.getAnnIds(imgIds=image_infos['id'])
            anns = self.coco.loadAnns(ann_ids)

            # masks: 2D (height x width) array; each pixel holds the index of
            # its class in `category_names` (Background = 0, Unknown = 1,
            # General trash = 2, ... , Cigarette = 11).
            masks = np.zeros((image_infos["height"], image_infos["width"]), dtype=np.uint8)
            for ann in anns:
                # "None" is used for an unknown category id, as before, so the
                # lookup below raises ValueError for it.
                className = self._cat_id_to_name.get(ann['category_id'], "None")
                pixel_value = category_names.index(className)
                # Where annotations overlap, the larger label value wins.
                masks = np.maximum(self.coco.annToMask(ann)*pixel_value, masks)
            masks = masks.astype(np.uint8)

            # Apply the albumentations transform to image and mask together.
            if self.transform is not None:
                transformed = self.transform(image=images, mask=masks)
                images = transformed["image"]
                masks = transformed["mask"]

            return paths,images, masks, image_infos

        if self.mode == 'test':
            # Apply the albumentations transform to the image only.
            if self.transform is not None:
                transformed = self.transform(image=images)
                images = transformed["image"]

            return paths, images, image_infos

        # Only reachable if `mode` was changed after construction.
        raise ValueError("unsupported mode: {!r}".format(self.mode))

    def __len__(self) -> int:
        # Size of the whole dataset.
        return len(self.coco.getImgIds())

In [9]:
# Annotation file covering the full training split.
train_path = f'{dataset_path}/train_all.json'

def collate_fn(batch):
    """Regroup a list of samples into one tuple per field."""
    fields = zip(*batch)
    return tuple(fields)

# Tensor conversion only, no augmentation.
train_transform = A.Compose([ToTensorV2()])

train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)

# One sample per batch, shuffled, loaded by four worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

loading annotations into memory...
Done (t=4.28s)
creating index...
index created!


In [10]:
# Collect the path of every training image straight from the COCO index.
# Iterating the DataLoader for this would decode every image and rasterize
# every mask only to read the file name. The list now follows COCO's image-id
# order instead of the loader's shuffled order.
img_dir = [
    os.path.join(dataset_path, info['file_name'])
    for info in train_dataset.coco.loadImgs(train_dataset.coco.getImgIds())
]

In [3]:
from tqdm.notebook import tqdm
from PIL import Image
# Show the path of the first collected image, then the image itself.
first_image_path = img_dir[0]
print(first_image_path)

Image.open(first_image_path)

# Normalize작업 - test set정보를 이용

In [4]:
test_path = dataset_path + '/test.json'

def collate_fn(batch):
    """Regroup a list of samples into one tuple per field."""
    return tuple(zip(*batch))

test_transform = A.Compose([
                            ToTensorV2()
                            ])

test_dataset = CustomDataLoader(data_dir=test_path, mode='test', transform=test_transform)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                           batch_size=1,
                                           shuffle=True,
                                           num_workers=4,
                                           collate_fn=collate_fn)

# Take the test image paths from the COCO index. The previous loop decoded every
# image through the DataLoader just to read its path, and unpacked the third
# field as `masks` although test-mode samples are (path, image, image_info).
img_dir = [
    os.path.join(dataset_path, info['file_name'])
    for info in test_dataset.coco.loadImgs(test_dataset.coco.getImgIds())
]

# Per-image size and per-channel statistics (pixel values on the 0-255 scale).
img_info = dict(heights=[], widths=[], means=[], stds=[])

for path in tqdm(img_dir):
    # The context manager closes the file handle right away; convert('RGB')
    # guarantees the three-channel layout the unpacking below relies on.
    with Image.open(path) as pil_img:
        img = np.array(pil_img.convert('RGB'))
    h, w, _ = img.shape
    img_info['heights'].append(h)
    img_info['widths'].append(w)
    img_info['means'].append(img.mean(axis=(0,1)))
    img_info['stds'].append(img.std(axis=(0,1)))


In [5]:
# Report how many images the statistics cover. The previous "people" and
# "x 7" messages printed the number of category rows in `df`, which is
# unrelated to this dataset's image count.
print(f'Total number of images is {len(img_info["heights"])}')

print(f'Minimum height for dataset is {np.min(img_info["heights"])}')
print(f'Maximum height for dataset is {np.max(img_info["heights"])}')
print(f'Average height for dataset is {int(np.mean(img_info["heights"]))}')
print(f'Minimum width for dataset is {np.min(img_info["widths"])}')
print(f'Maximum width for dataset is {np.max(img_info["widths"])}')
print(f'Average width for dataset is {int(np.mean(img_info["widths"]))}')

# Channel statistics rescaled to [0, 1]. Both values are averages of the
# per-image statistics, so the deviation is the mean per-image deviation.
print(f'RGB Mean: {np.mean(img_info["means"], axis=0) / 255.}')
print(f'RGB Standard Deviation: {np.mean(img_info["stds"], axis=0) / 255.}')

In [6]:
# Annotation file covering the full training split.
train_path = f'{dataset_path}/train_all.json'

def collate_fn(batch):
    """Regroup a list of samples into one tuple per field."""
    fields = zip(*batch)
    return tuple(fields)

# Per-channel normalization constants (expressed on the [0, 1] scale).
normalize = Normalize(
    mean=(0.46009655,0.43957878,0.41827092),
    std=(0.2108204,0.20766491,0.21656131),
    max_pixel_value=255.0,
    p=1.0,
)
train_transform = A.Compose([normalize, ToTensorV2()])

train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)

# One sample per batch, shuffled, loaded by four worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

In [16]:
# Keep the image of the first four batches from the loader.
temp_images = []
for step, (paths, images, masks, _) in enumerate(train_loader):
    if step != 4:
        temp_images.append(images[0])
        continue
    break
    
    
# temp_images

In [7]:
# Largest value in the first collected image.
temp_images[0].max()

In [8]:
# Smallest value in the first collected image.
temp_images[0].min()

In [9]:
# Annotation file used by this cell.
train_path = f'{dataset_path}/train.json'

def collate_fn(batch):
    """Regroup a list of samples into one tuple per field (needed for batching)."""
    fields = zip(*batch)
    return tuple(fields)

# Tensor conversion only, no augmentation.
train_transform = A.Compose([ToTensorV2()])

# Training dataset and its DataLoader.
train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Pull one batch to inspect the loader output (image and mask).
_, imgs, masks, image_infos = next(iter(train_loader))
image_infos = image_infos[0]
temp_images, temp_masks = imgs, masks

fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

# The tensor is channel-first; permute to height x width x channel for imshow.
ax1.imshow(temp_images[0].permute([1,2,0]))
ax1.grid(False)
ax1.set_title("input image : {}".format(image_infos['file_name']), fontsize = 15)

plt.show()

In [10]:
# Annotation file used by this cell.
train_path = f'{dataset_path}/train.json'

def collate_fn(batch):
    """Regroup a list of samples into one tuple per field (needed for batching)."""
    fields = zip(*batch)
    return tuple(fields)

# CLAHE (always applied) followed by a random resized crop to 256 x 256.
train_transform = A.Compose([
    CLAHE(p=1.0),
    RandomResizedCrop(256,256),
    ToTensorV2(),
])

# Training dataset and its DataLoader.
train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Pull one batch to inspect the loader output (image and mask).
_, imgs, masks, image_infos = next(iter(train_loader))
image_infos = image_infos[0]
temp_images, temp_masks = imgs, masks

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12))

# Left: augmented image (permuted to height x width x channel for imshow).
ax1.imshow(temp_images[0].permute([1,2,0]))
ax1.grid(False)
ax1.set_title("input image : {}".format(image_infos['file_name']), fontsize = 15)

# Right: the mask that went through the same transform.
ax2.imshow(temp_masks[0])
ax2.grid(False)

plt.show()

In [11]:
# Annotation file used by this cell.
train_path = f'{dataset_path}/train.json'

def collate_fn(batch):
    """Regroup a list of samples into one tuple per field (needed for batching)."""
    fields = zip(*batch)
    return tuple(fields)

# Vertical flip, always applied.
train_transform = A.Compose([
    VerticalFlip(p=1.0),
    ToTensorV2(),
])

# Training dataset and its DataLoader.
train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Pull one batch to inspect the loader output (image and mask).
_, imgs, masks, image_infos = next(iter(train_loader))
image_infos = image_infos[0]
temp_images, temp_masks = imgs, masks

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12))

# Left: augmented image (permuted to height x width x channel for imshow).
ax1.imshow(temp_images[0].permute([1,2,0]))
ax1.grid(False)
ax1.set_title("input image : {}".format(image_infos['file_name']), fontsize = 15)

# Right: the mask that went through the same transform.
ax2.imshow(temp_masks[0])
ax2.grid(False)

plt.show()

In [12]:
# Annotation file used by this cell.
train_path = f'{dataset_path}/train.json'

def collate_fn(batch):
    """Regroup a list of samples into one tuple per field (needed for batching)."""
    fields = zip(*batch)
    return tuple(fields)

# Blur with the library's default arguments.
train_transform = A.Compose([
    Blur(),
    ToTensorV2(),
])

# Training dataset and its DataLoader.
train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Pull one batch to inspect the loader output (image and mask).
_, imgs, masks, image_infos = next(iter(train_loader))
image_infos = image_infos[0]
temp_images, temp_masks = imgs, masks

fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

# The tensor is channel-first; permute to height x width x channel for imshow.
ax1.imshow(temp_images[0].permute([1,2,0]))
ax1.grid(False)
ax1.set_title("input image : {}".format(image_infos['file_name']), fontsize = 15)

plt.show()

In [13]:
# Annotation file used by this cell.
train_path = f'{dataset_path}/train.json'

def collate_fn(batch):
    """Regroup a list of samples into one tuple per field (needed for batching)."""
    fields = zip(*batch)
    return tuple(fields)

# ChannelDropout with the library's default arguments.
train_transform = A.Compose([
    ChannelDropout(),
    ToTensorV2(),
])

# Training dataset and its DataLoader.
train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Pull one batch to inspect the loader output (image and mask).
_, imgs, masks, image_infos = next(iter(train_loader))
image_infos = image_infos[0]
temp_images, temp_masks = imgs, masks

fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

# The tensor is channel-first; permute to height x width x channel for imshow.
ax1.imshow(temp_images[0].permute([1,2,0]))
ax1.grid(False)
ax1.set_title("input image : {}".format(image_infos['file_name']), fontsize = 15)

plt.show()

In [14]:
# ChannelDropout demo on a single image.
# The Compose call is restored: its opening line had been commented out while
# the body was left in place, which made the whole cell a syntax error.
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    ChannelDropout(),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [15]:
# ChannelShuffle demo on a single image.
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    ChannelShuffle(),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [46]:
# CLAHE demo on a single image.
# CLAHE did not work on the float32 [0, 1] image used by the other demo cells,
# so this cell keeps the image in uint8 format and skips the /255 scaling.
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    CLAHE(p =1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.uint8)


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [16]:
# CoarseDropout demo on a single image (always applied).
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    CoarseDropout(max_holes=20, max_height=20,p =1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [17]:
# ColorJitter demo on a single image (always applied).
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    ColorJitter(p =1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [18]:
# Cutout demo on a single image: 10 holes of at most 30 x 30 pixels, always applied.
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    Cutout(num_holes=10, max_h_size=30, max_w_size=30,p =1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [19]:
# Downscale demo on a single image (always applied).
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    Downscale(p =1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [20]:
# Flip demo on a single image (always applied).
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    Flip(p =1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [21]:
# GaussianBlur demo on a single image (always applied).
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    GaussianBlur(p =1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [22]:
# GaussianBlur demo that also reports the type and shape of the transformed image.
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    GaussianBlur(p =1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]
print(type(images))
print(images.shape)

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [23]:
# Functional add_snow demo on a single image.
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0

# Use the explicitly imported alias `F` (albumentations.augmentations.functional)
# rather than the `augmentations` name that only exists through `import *`.
images = F.add_snow(images,1,1.5)
images = torch.tensor(images)
print(images.shape)


fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))
# No permute here: the functional API was applied to the numpy image directly,
# without the ToTensorV2 step used in the Compose-based cells.
ax1.imshow(images)
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [24]:
# Functional add_rain demo on a single image.
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0

# Coordinates passed to add_rain as the drop positions.
rain_drops = ((100,100), (200,200))

# Use the explicitly imported alias `F` (albumentations.augmentations.functional)
# rather than the `augmentations` name that only exists through `import *`.
images = F.add_rain(
    images,
    slant = 10,
    drop_length = 10,
    drop_width = 3,
    rain_drops = rain_drops,
    blur_value = 1, 
    brightness_coefficient = 1.0,
    drop_color = (200,200,200)
)
images = torch.tensor(images)
print(images.shape)


fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))
# No permute here: the functional API was applied to the numpy image directly,
# without the ToTensorV2 step used in the Compose-based cells.
ax1.imshow(images)
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [25]:
# ToGray demo on a single image (always applied).
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    ToGray(p = 1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

In [None]:
# Cutout demo on a single image: 10 holes whose maximum side is 10% of 400 pixels.
sample_path = "/opt/ml/input/data/batch_02_vt/1500.jpg"

train_transform = A.Compose([
    Cutout(num_holes=10, 
                        max_h_size=int(.1 * 400), max_w_size=int(.1 * 400), 
                        p=1.0),
    ToTensorV2()
])

images = cv2.imread(sample_path)
if images is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(sample_path)
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0


transformed = train_transform(image=images)
images = transformed["image"]

fig, (ax1) = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)
# Title with the file actually shown, not the file name of an earlier loader batch.
ax1.set_title("input image : {}".format(sample_path), fontsize = 15)

plt.show()

# GridMask Test

In [26]:
class GridMask(DualTransform):
    """Multiply an image by a regular grid of blocks.

    Args:
        num_grid: int, or (low, high) pair, number of grid cells per side. One
            mask is built for every integer in the inclusive range and one of
            them is picked at random per call.
        fill_value: value written into the blocked part of the grid mask. The
            image is multiplied by the mask, so 0 blanks those pixels.
        rotate: int, or (low, high) pair, rotation angle range for the grid. The
            grid is rotated only when the upper bound is > 0.
        mode: 0 blocks the first half (rows and columns) of every cell,
            2 additionally blocks the opposite half-cell, 1 inverts the mode-0 mask.
        always_apply, p: forwarded to the base transform.
    """

    def __init__(self, num_grid=3, fill_value=0, rotate=0, mode=0, always_apply=False, p=0.5):
        super(GridMask, self).__init__(always_apply, p)
        if isinstance(num_grid, int):
            num_grid = (num_grid, num_grid)
        if isinstance(rotate, int):
            rotate = (-rotate, rotate)
        self.num_grid = num_grid
        self.fill_value = fill_value
        self.rotate = rotate
        self.mode = mode
        self.masks = None
        self.rand_h_max = []
        self.rand_w_max = []
        # (height, width) the cached masks were built for.
        self._mask_shape = None

    def init_masks(self, height, width):
        """Build the grid masks for a `height` x `width` image (cached per size)."""
        cached_shape = getattr(self, '_mask_shape', None)
        # Reuse the cache when it matches this image size. Masks without a
        # recorded size (set from outside) are kept untouched.
        if self.masks is not None and cached_shape in (None, (height, width)):
            return

        # Rebuild from scratch: masks made for another image size would be
        # cropped to the wrong shape in apply().
        self.masks = []
        self.rand_h_max = []
        self.rand_w_max = []
        self._mask_shape = (height, width)

        for n_g in range(self.num_grid[0], self.num_grid[1] + 1, 1):
            grid_h = height / n_g
            grid_w = width / n_g
            # One extra cell per side so a randomly shifted window still fits.
            this_mask = np.ones((int((n_g + 1) * grid_h), int((n_g + 1) * grid_w))).astype(np.uint8)
            for i in range(n_g + 1):
                for j in range(n_g + 1):
                    this_mask[
                         int(i * grid_h) : int(i * grid_h + grid_h / 2),
                         int(j * grid_w) : int(j * grid_w + grid_w / 2)
                    ] = self.fill_value
                    if self.mode == 2:
                        this_mask[
                             int(i * grid_h + grid_h / 2) : int(i * grid_h + grid_h),
                             int(j * grid_w + grid_w / 2) : int(j * grid_w + grid_w)
                        ] = self.fill_value

            if self.mode == 1:
                this_mask = 1 - this_mask

            self.masks.append(this_mask)
            self.rand_h_max.append(grid_h)
            self.rand_w_max.append(grid_w)

    def apply(self, image, mask, rand_h, rand_w, angle, **params):
        """Return `image` multiplied by the window of `mask` at the random offset."""
        h, w = image.shape[:2]
        mask = F.rotate(mask, angle) if self.rotate[1] > 0 else mask
        mask = mask[:,:,np.newaxis] if image.ndim == 3 else mask
        window = mask[rand_h:rand_h+h, rand_w:rand_w+w].astype(image.dtype)
        # Multiply into a new array so the caller's image is not modified in place.
        return image * window

    def get_params_dependent_on_targets(self, params):
        """Pick a grid mask, a random offset into it and a rotation angle."""
        img = params['image']
        height, width = img.shape[:2]
        self.init_masks(height, width)

        mid = np.random.randint(len(self.masks))
        mask = self.masks[mid]
        # The stored cell sizes are floats; truncate them to get integer bounds.
        rand_h = np.random.randint(int(self.rand_h_max[mid]))
        rand_w = np.random.randint(int(self.rand_w_max[mid]))
        angle = np.random.randint(self.rotate[0], self.rotate[1]) if self.rotate[1] > 0 else 0

        return {'mask': mask, 'rand_h': rand_h, 'rand_w': rand_w, 'angle': angle}

    @property
    def targets_as_params(self):
        return ['image']

    def get_transform_init_args_names(self):
        return ('num_grid', 'fill_value', 'rotate', 'mode')

# GridMask demo on a single image: 3 x 3 grid, no rotation, always applied.
grid_mask = GridMask(num_grid=3, p=1, rotate=0)
train_transform = A.Compose([grid_mask, ToTensorV2()])

# Load the sample as RGB float32 scaled to [0, 1].
images = cv2.imread("/opt/ml/input/data/batch_02_vt/1500.jpg")
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0

images = train_transform(image=images)["image"]

fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))

# The tensor is channel-first; permute to height x width x channel for imshow.
ax1.imshow(images.permute([1,2,0]))
ax1.grid(False)

plt.show()