In [None]:
# Mount to Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Define project folder
FOLDERNAME = 'Colab\ Notebooks/flowers102'

%cd drive/MyDrive/$FOLDERNAME

In [None]:

# Pick the compute device: a CUDA GPU when one is visible, otherwise the CPU.
import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# This function calculates mean and std of the dataset
def compute_mean_std(loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    Statistics are pooled over every pixel of every image, so ``std`` captures the
    variation *between* images as well as within them. Averaging per-image stds
    would miss the between-image part and underestimate the dataset std.

    Args:
        loader: Iterable of ``(images, labels)`` batches where ``images`` is a
            tensor shaped N x C x H x W (any trailing dims are flattened);
            labels are ignored.

    Returns:
        Tuple ``(mean, std)`` of float32 tensors, each of shape ``(C,)``. ``std``
        is the population standard deviation over all pixels.

    Raises:
        ValueError: If ``loader`` yields no pixels at all.
    """
    channel_sum = 0.0
    channel_sq_sum = 0.0
    num_pixels = 0  # pixels seen per channel

    for images, _ in loader:
        # N x C x H x W -> C x (N*H*W). float64 avoids precision loss when summing
        # tens of millions of (squared) pixel values.
        pixels = images.transpose(0, 1).flatten(start_dim=1).double()
        channel_sum = channel_sum + pixels.sum(dim=1)
        channel_sq_sum = channel_sq_sum + pixels.pow(2).sum(dim=1)
        num_pixels += pixels.size(1)

    if num_pixels == 0:
        raise ValueError('compute_mean_std: loader yielded no pixels')

    mean = channel_sum / num_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    var = (channel_sq_sum / num_pixels - mean.pow(2)).clamp(min=0)
    return mean.float(), var.sqrt().float()

# Statistics come from the training split only (no val/test leakage) and without
# augmentation: images are just resized to 224x224 and converted by ToTensor.
basic_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

train_dataset = datasets.Flowers102(
    root='./train',
    split='train',
    transform=basic_transform,
    download=True,
)

# Sample order does not matter for aggregate statistics, so no shuffling.
loader = DataLoader(train_dataset, shuffle=False, batch_size=16)
mean, std = compute_mean_std(loader)

print(f"Dataset mean: {mean}")
print(f"Dataset std: {std}")

In [None]:
from torchvision.transforms import v2

# Normalization with the training-set statistics computed in the previous cell;
# shared by every split so all images are scaled identically.
normalize = transforms.Normalize(mean=mean.tolist(), std=std.tolist())

# Training pipeline: AutoAugment (ImageNet policy) adds random augmentation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    v2.AutoAugment(policy=v2.AutoAugmentPolicy.IMAGENET),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: same resize/normalization but NO random augmentation, so
# validation/test metrics are deterministic and measured on unaltered images.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

# Kept so any later cell that refers to `transform` still gets the training pipeline.
transform = train_transform

# Flowers102 downloads one image archive plus label/split files covering all three
# splits, so a single root serves train/val/test. './train' is reused because the
# mean/std cell already downloaded the data there (avoids two extra downloads).
DATA_ROOT = './train'

train_dataset = datasets.Flowers102(
    root=DATA_ROOT,
    split='train',
    transform=train_transform,
    download=True,
)
val_dataset = datasets.Flowers102(
    root=DATA_ROOT,
    split='val',
    transform=eval_transform,
    download=True,
)
test_dataset = datasets.Flowers102(
    root=DATA_ROOT,
    split='test',
    transform=eval_transform,
    download=True,
)

