[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/fga-0/automating_data_augmentation/blob/main/wrn_concat_raw_aug.ipynb)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
import torch.optim as optim

import utils
import importlib
importlib.reload(utils)

In [None]:
from torchvision.transforms import AutoAugment, AutoAugmentPolicy

In [None]:
from google.colab import drive

# Mount Google Drive; the CIFAR-10 root below lives on that mount.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
# Derive the dataset root from the absolute mount point. A relative
# "drive/MyDrive/..." path only resolves to the mounted Drive while the
# working directory is /content; from any other directory it would point at
# an unrelated local folder.
path_to_data = f"{DRIVE_MOUNT_POINT}/MyDrive/SDD/data_augmentation/cifar10"

In [None]:
# ImageNet channel statistics, shared by both preprocessing pipelines.
imagenet_stats = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}

# Plain pipeline: resize to 224, convert to a tensor, apply ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(**imagenet_stats),
])

# Augmented pipeline: identical, with the AutoAugment CIFAR-10 policy applied
# right after the resize (i.e. before the tensor conversion).
transform_augment = transforms.Compose([
    transforms.Resize(224),
    AutoAugment(AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize(**imagenet_stats),
])

# Arguments common to every CIFAR-10 split: same root, download when missing.
cifar_args = {"root": path_to_data, "download": True}

# Training split twice (raw and augmented views) plus the raw test split.
trainset = torchvision.datasets.CIFAR10(train=True, transform=transform, **cifar_args)
trainset_aug = torchvision.datasets.CIFAR10(train=True, transform=transform_augment, **cifar_args)
testset = torchvision.datasets.CIFAR10(train=False, transform=transform, **cifar_args)

In [None]:
# Subset drawn from the AutoAugment view of the training data.
# NOTE(review): percentage=0.01 is presumably a fraction (1%) - confirm against
# utils.get_subset. test_subset_aug is not used in the cells visible here.
train_subset_aug, test_subset_aug = utils.get_subset(
    trainset=trainset_aug,
    testset=testset,
    percentage=0.01,
)

# Subset drawn from the raw training data, with a larger percentage (0.02).
# NOTE(review): the raw/augmented percentages differ - confirm this is intended.
train_subset, test_subset = utils.get_subset(
    trainset=trainset,
    testset=testset,
    percentage=0.02,
)

In [None]:
from torch.utils.data import ConcatDataset

In [None]:
# Training data: the raw subset followed by the AutoAugment subset, exposed as
# a single dataset and served in shuffled mini-batches of 32.
train_subset_raw_and_AA_CIFAR = ConcatDataset([train_subset, train_subset_aug])
train_loader_AA_CIFAR = DataLoader(
    dataset=train_subset_raw_and_AA_CIFAR,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

# Evaluation loader over the test subset, in fixed order.
# NOTE: a dedicated test loader for the augmented run (with num_workers=2) was
# considered here; it is probably unnecessary - TODO confirm.
test_loader = DataLoader(dataset=test_subset, batch_size=32, shuffle=False)

In [None]:
from torchvision.models import wide_resnet50_2

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load an ImageNet-pretrained Wide ResNet-50-2.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# explicit `weights=` argument; IMAGENET1K_V1 is the checkpoint the legacy
# flag maps to. Older torchvision releases fall back to the legacy flag.
try:
    from torchvision.models import Wide_ResNet50_2_Weights
except ImportError:  # torchvision < 0.13
    wrn = wide_resnet50_2(pretrained=True)
else:
    wrn = wide_resnet50_2(weights=Wide_ResNet50_2_Weights.IMAGENET1K_V1)

# Freeze every pretrained parameter; only the replacement head created below
# (whose parameters default to requires_grad=True) will receive gradients.
wrn.requires_grad_(False)

# Swap the final fully-connected layer for a fresh 10-way classifier
# (CIFAR-10 has 10 classes), keeping the original input width.
num_feature = wrn.fc.in_features
wrn.fc = nn.Linear(num_feature, 10)
# Show the structure of the new final layer.
print(wrn.fc)

# Move the whole model to the selected device.
wrn = wrn.to(device)

print(f"Using device: {device}")

In [None]:
# Multi-class classification loss.
criterion = nn.CrossEntropyLoss()

# Adam over the final `fc` layer's parameters only.
optimizer = optim.Adam(params=wrn.fc.parameters(), lr=1e-3)

# Cosine-annealed learning rate over a horizon of 200 scheduler steps.
# NOTE(review): training below runs for num_epochs=5; whether T_max=200 matches
# depends on how often utils.train_WideResNet steps the scheduler - confirm.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=200)

In [None]:
# Train the model on the combined raw + AutoAugment loader for 5 epochs and
# keep whatever utils.train_WideResNet returns (presumably the trained model -
# confirm against utils).
wrn_trained_on_raw_AND_augmented_data = utils.train_WideResNet(
    model=wrn,
    trainloader=train_loader_AA_CIFAR,
    num_epochs=5,
    batch_size=32,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    scheduler=scheduler,
)

In [None]:
# Evaluate the trained model on the test loader using the selected device.
# As the cell's last expression, the call's return value (if any) is displayed
# by the notebook; what it reports is defined in utils.evaluate.
utils.evaluate(wrn_trained_on_raw_AND_augmented_data, test_loader, device)