diff --git a/examples/resnet/prepare_model_data.py b/examples/resnet/prepare_model_data.py index 87581ee00e..1496148e1e 100644 --- a/examples/resnet/prepare_model_data.py +++ b/examples/resnet/prepare_model_data.py @@ -56,12 +56,13 @@ def update_lr(optimizer, lr): def prepare_model(num_epochs=1, models_dir="models", data_dir="data"): + seed = 0 # seed everything to 0 for reproducibility, https://pytorch.org/docs/stable/notes/randomness.html - random.seed(0) - np.random.seed(0) - torch.manual_seed(0) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) # the following are needed only for GPU - torch.cuda.manual_seed(0) + torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False diff --git a/examples/resnet/user_script.py b/examples/resnet/user_script.py index 82f3ff3851..6278c9f5d2 100644 --- a/examples/resnet/user_script.py +++ b/examples/resnet/user_script.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- + import torch import torchmetrics from onnxruntime.quantization.calibrate import CalibrationDataReader @@ -17,6 +18,15 @@ # Common Dataset # ------------------------------------------------------------------------- +seed = 0 +# seed everything to 0 for reproducibility, https://pytorch.org/docs/stable/notes/randomness.html +# do not set random seed and np.random.seed for aml test, since it will cause aml job name conflict +torch.manual_seed(seed) +# the following are needed only for GPU +torch.cuda.manual_seed(seed) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + class CIFAR10DataSet: def __init__( @@ -31,10 +41,15 @@ def __init__( def setup(self, stage: str): transform = transforms.Compose( - [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()] + [ + transforms.Pad(4), + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(32), + transforms.ToTensor(), + ] ) self.train_dataset = CIFAR10(root=self.train_path, train=True, transform=transform, download=False) - self.val_dataset = CIFAR10(root=self.vld_path, train=True, transform=transform, download=False) + self.val_dataset = CIFAR10(root=self.vld_path, train=False, transform=transform, download=False) class PytorchResNetDataset(Dataset): @@ -71,8 +86,7 @@ def post_process(output): def create_dataloader(data_dir, batch_size, *args, **kwargs): cifar10_dataset = CIFAR10DataSet(data_dir) - _, val_set = torch.utils.data.random_split(cifar10_dataset.val_dataset, [49000, 1000]) - return DataLoader(PytorchResNetDataset(val_set), batch_size=batch_size, drop_last=True) + return DataLoader(PytorchResNetDataset(cifar10_dataset.val_dataset), batch_size=batch_size, drop_last=True) # ------------------------------------------------------------------------- @@ -83,11 +97,17 @@ def create_dataloader(data_dir, batch_size, *args, **kwargs): class ResnetCalibrationDataReader(CalibrationDataReader): def __init__(self, data_dir: str, batch_size: int = 16): super().__init__() - self.iterator = iter(create_dataloader(data_dir, batch_size)) + self.iterator = iter(create_train_dataloader(data_dir, batch_size)) + self.sample_counter = 500 def get_next(self) -> dict: + if self.sample_counter <= 0: + return None + try: - return {"input": next(self.iterator)[0].numpy()} + item = {"input": next(self.iterator)[0].numpy()} + self.sample_counter -= 1 + return item except Exception: return None @@ -161,8 +181,7 @@ def create_qat_config(): def create_train_dataloader(data_dir, batchsize, *args, **kwargs): cifar10_dataset = CIFAR10DataSet(data_dir) - train_dataset, _ = torch.utils.data.random_split(cifar10_dataset.train_dataset, [40000, 10000]) - return DataLoader(PytorchResNetDataset(train_dataset), batch_size=batchsize, drop_last=True) + return DataLoader(PytorchResNetDataset(cifar10_dataset.train_dataset), batch_size=batchsize, drop_last=True) # -------------------------------------------------------------------------