## WIDER Dataset

In [8]:
# !pip install --upgrade torch --quiet

In [2]:
import json
import io
import boto3
from torch.utils.data import Dataset
from PIL import Image
from torchvision.transforms import transforms, RandAugment
import torch

class DataSet(Dataset):
    """Image dataset whose annotation files and images are stored on S3.

    Annotation JSON files are downloaded from S3 and their contents are
    concatenated (``+=``) into ``self.anns``. Each entry is read with the keys
    ``"img_path"`` (an ``s3://`` URI) and ``"target"``. When
    ``dataset == "wider"`` each entry must also provide ``"bbox"`` as
    ``(x, y, w, h)``. That box is cropped out of the image before
    augmentation.

    Args:
        ann_files: ``s3://`` URIs of the JSON annotation files to load.
        augs: names of augmentations to enable (see ``augs_function``).
        img_size: side length of the square images produced by ``augment``.
        dataset: dataset name. ``"wider"`` enables bbox cropping and
            normalization of pixel values to [-1, 1].
    """

    def __init__(self, ann_files, augs, img_size, dataset):
        self.dataset = dataset
        self.ann_files = ann_files
        self.augment = self.augs_function(augs, img_size)
        self.transform = self._build_transform(dataset)
        self.s3_client = boto3.client('s3')
        self.load_anns()
        # Echo the augmentation pipeline so the configuration is visible in the notebook output.
        print(self.augment)

    @staticmethod
    def _build_transform(dataset):
        """Return the tensor-conversion + normalization pipeline for ``dataset``."""
        if dataset == "wider":
            # ToTensor yields values in [0, 1]; (v - 0.5) / 0.5 maps them to [-1, 1].
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            # mean 0 / std 1 is an identity normalization: values stay in [0, 1].
            mean, std = [0, 0, 0], [1, 1, 1]
        return transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

    def augs_function(self, augs, img_size):
        """Compose the PIL-level augmentations named in ``augs``.

        Recognized names, applied in this order: ``'randomflip'``,
        ``'ColorJitter'``, ``'resizedcrop'`` and ``'RandAugment'``. A final
        ``Resize((img_size, img_size))`` is always appended.
        """
        t = []
        if 'randomflip' in augs:
            t.append(transforms.RandomHorizontalFlip())
        if 'ColorJitter' in augs:
            t.append(transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0))
        if 'resizedcrop' in augs:
            t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
        if 'RandAugment' in augs:
            # TODO: review RandAugment() settings (default num_ops / magnitude are used).
            t.append(RandAugment())

        # Always resize so every sample has the same square size, even when
        # 'resizedcrop' is not enabled (e.g. the test pipeline).
        t.append(transforms.Resize((img_size, img_size)))

        return transforms.Compose(t)

    def load_anns(self):
        """(Re)load every annotation file from S3 and concatenate them into ``self.anns``."""
        self.anns = []
        for ann_file in self.ann_files:
            bucket, key = self.parse_s3_path(ann_file)
            response = self.s3_client.get_object(Bucket=bucket, Key=key)
            self.anns += json.loads(response['Body'].read())

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        """Download, crop (for "wider"), augment and normalize sample ``idx``.

        Returns:
            dict with ``"img_path"`` (the S3 URI), ``"target"`` (the annotation's
            target as a tensor) and ``"img"`` (the transformed image tensor).
        """
        # Wrap the index so any integer maps onto an existing annotation.
        idx = idx % len(self)
        ann = self.anns[idx]

        bucket, key = self.parse_s3_path(ann["img_path"])
        response = self.s3_client.get_object(Bucket=bucket, Key=key)
        img = Image.open(io.BytesIO(response['Body'].read())).convert("RGB")

        if self.dataset == "wider":
            # bbox is (x, y, width, height); PIL's crop wants (left, upper, right, lower).
            x, y, w, h = ann['bbox']
            img = img.crop([x, y, x + w, y + h])

        img = self.augment(img)
        img = self.transform(img)

        return {
            "img_path": ann['img_path'],
            "target": torch.tensor(ann['target']),
            "img": img,
        }

    @staticmethod
    def parse_s3_path(s3_path):
        """Split ``s3://bucket/key`` into ``(bucket, key)``.

        Raises:
            ValueError: if the path lacks the ``s3://`` scheme, a bucket or a key.
        """
        prefix = "s3://"
        if not s3_path.startswith(prefix):
            raise ValueError(f"Invalid S3 path: {s3_path}")
        bucket, _, key = s3_path[len(prefix):].partition('/')
        # Reject "s3://bucket", "s3://bucket/" and "s3:///key" explicitly
        # instead of failing later with a confusing error.
        if not bucket or not key:
            raise ValueError(f"Invalid S3 path: {s3_path}")
        return bucket, key


In [3]:
from torch.utils.data import DataLoader


def load_wider(batch_size=64, img_size=256):
    """Build train and test DataLoaders for the WIDER Attribute data on S3.

    The training set uses random horizontal flip, color jitter, random
    resized crop and RandAugment. The test set is only resized. Both sets
    crop each person's bbox and normalize pixels to [-1, 1]
    (``dataset='wider'``).

    Args:
        batch_size: number of samples per batch for both loaders.
        img_size: side length of the square images the datasets produce.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    ann_dir = 's3://210bucket/wider_attribute_annotation'
    train_augs = ['randomflip', 'ColorJitter', 'resizedcrop', 'RandAugment']
    test_augs = []

    train_dataset = DataSet(
        ann_files=[f'{ann_dir}/wider_attribute_trainval.json'],
        augs=train_augs,
        img_size=img_size,
        dataset='wider',
    )
    test_dataset = DataSet(
        ann_files=[f'{ann_dir}/wider_attribute_test.json'],
        augs=test_augs,
        img_size=img_size,
        dataset='wider',
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


In [4]:
# Build the WIDER Attribute train/test DataLoaders with the default batch size.
trainloader, testloader = load_wider(batch_size=64)

Compose(
    RandomHorizontalFlip(p=0.5)
    ColorJitter(brightness=(0.5, 1.5), contrast=(0.5, 1.5), saturation=(0.5, 1.5), hue=None)
    RandomResizedCrop(size=(256, 256), scale=(0.7, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=warn)
    RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=31, interpolation=InterpolationMode.NEAREST, fill=None)
    Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn)
)
Compose(
    Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn)
)
