In [37]:
!pip install imgaug



In [38]:
import os
import cv2
import imageio
import numpy as np
import imgaug as ia
from imgaug import augmenters as iaa 
import matplotlib.pyplot as plt
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
%matplotlib inline
ia.seed(1)

In [39]:
# Load the dataset
image_dir = "./sea-turtles-1-test/all/images/"
label_dir = "./sea-turtles-1-test/all/labels/"
image_files = os.listdir(image_dir)

In [40]:
# Only load the first 20 images with annotations
image_files = [image_file for image_file in image_files if os.path.exists(os.path.join(label_dir, image_file.replace(".jpg", ".txt")))]

In [41]:
# Select the first 20 images
image_files = image_files[:20]

In [42]:
class Box_imgaug():
    def __init__(self,x1,x2,y1,y2, class_id):
        self.x1=x1
        self.x2=x2
        self.y1=y1
        self.y2=y2
        self.class_id=int(class_id)

class Box_yolo():
    def __init__(self,x_center,y_center,width,height, class_id):
        self.x_center=x_center
        self.y_center=y_center
        self.width=width
        self.height=height
        self.class_id=class_id



In [43]:
def coords_imgaug(annotation, img_w, img_h):
    # Extract the bounding box coordinates
    class_id, center_x, center_y, b_width, b_height = annotation
    x1 = (center_x - (b_width/2))  *img_w
    x2 = (center_x + (b_width/2))  *img_w
    y1 = (center_y - (b_height/2)) *img_h
    y2 = (center_y + (b_height/2)) *img_h
    
    return Box_imgaug(x1,x2,y1,y2, class_id)

def coords_yolo(x1,x2,y1,y2, class_id, w, h):
    b_width  = abs(x1,x2)
    b_height = abs(y1-y2)
    x_center = min(x1,x2) + b_width
    y_center = min(y1,y2) + b_height

    return Box_yolo(x_center/w,y_center/h, b_width/w, b_height/h, class_id)

In [44]:
seq = iaa.Sequential([
    iaa.Crop(px=(1, 16), keep_size=False),
    iaa.Fliplr(0.5),
    #iaa.GaussianBlur(sigma=(0, 3.0))
], random_order=True)

In [46]:
for image_file in image_files:
    # Load the image
    image_path = os.path.join(image_dir, image_file)
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    h, w, c = image.shape

    # Load the corresponding annotation file
    annotation_file = os.path.join(label_dir, image_file.replace('.jpg', '.txt'))
    with open(annotation_file, 'r') as file:
        list_bbs = []
        for line in file:

            # Assuming each line of the annotation file contains [class, xmin, ymin, xmax, ymax]
            annotation = [float(coord) for coord in line.strip().split()]
            print('img ',image_path,' - label = ', annotation)
            
            bbox_imgaug = coords_imgaug(annotation=annotation, img_w=w, img_h=h)

            bb = BoundingBox(
                    x1=bbox_imgaug.x1,
                    x2=bbox_imgaug.x2,
                    y1=bbox_imgaug.y1,
                    y2=bbox_imgaug.y2,
                    label=bbox_imgaug.class_id
                    )

            list_bbs.append(bb)
        bbs = BoundingBoxesOnImage(list_bbs, shape=image.shape)

    #print(annotation)
    #plt.imshow(bbs.draw_on_image(image, size=2))

    image_aug, bbs_aug = seq(image=image, bounding_boxes=bbs)
    cv2.imwrite(image_dir+'aug_'+image_file, image_aug) 
    result = map(coords_yolo,bbs_aug)
    print(list(result))
    
    plt.imshow(bbs_aug.draw_on_image(image_aug))

img  ./sea-turtles-1-test/all/images/001.jpg  - label =  [0.0, 0.68359375, 0.30625, 0.01796875, 0.03984375]
img  ./sea-turtles-1-test/all/images/001.jpg  - label =  [0.0, 0.6828125, 0.4015625, 0.01484375, 0.03046875]
img  ./sea-turtles-1-test/all/images/001.jpg  - label =  [0.0, 0.65, 0.51328125, 0.01796875, 0.03046875]
img  ./sea-turtles-1-test/all/images/001.jpg  - label =  [0.0, 0.634375, 0.55390625, 0.01875, 0.0328125]
img  ./sea-turtles-1-test/all/images/001.jpg  - label =  [0.0, 0.659375, 0.61484375, 0.021875, 0.03359375]
img  ./sea-turtles-1-test/all/images/001.jpg  - label =  [0.0, 0.6296875, 0.74609375, 0.01328125, 0.0328125]
img  ./sea-turtles-1-test/all/images/001.jpg  - label =  [0.0, 0.634375, 0.8953125, 0.01796875, 0.0296875]


TypeError: coords_yolo() missing 6 required positional arguments: 'x2', 'y1', 'y2', 'class_id', 'w', and 'h'