Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions examples/resnet/prepare_model_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,13 @@ def update_lr(optimizer, lr):


def prepare_model(num_epochs=1, models_dir="models", data_dir="data"):
seed = 0
# seed everything to 0 for reproducibility, https://pytorch.org/docs/stable/notes/randomness.html
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# the following are needed only for GPU
torch.cuda.manual_seed(0)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Expand Down
35 changes: 27 additions & 8 deletions examples/resnet/user_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import torch
import torchmetrics
from onnxruntime.quantization.calibrate import CalibrationDataReader
Expand All @@ -17,6 +18,15 @@
# Common Dataset
# -------------------------------------------------------------------------

seed = 0
# seed everything to 0 for reproducibility, https://pytorch.org/docs/stable/notes/randomness.html
# do not set random seed and np.random.seed for aml test, since it will cause aml job name conflict
torch.manual_seed(seed)
# the following are needed only for GPU
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class CIFAR10DataSet:
def __init__(
Expand All @@ -31,10 +41,15 @@ def __init__(

def setup(self, stage: str):
transform = transforms.Compose(
[transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()]
[
transforms.Pad(4),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32),
transforms.ToTensor(),
]
)
self.train_dataset = CIFAR10(root=self.train_path, train=True, transform=transform, download=False)
self.val_dataset = CIFAR10(root=self.vld_path, train=True, transform=transform, download=False)
self.val_dataset = CIFAR10(root=self.vld_path, train=False, transform=transform, download=False)


class PytorchResNetDataset(Dataset):
Expand Down Expand Up @@ -71,8 +86,7 @@ def post_process(output):

def create_dataloader(data_dir, batch_size, *args, **kwargs):
cifar10_dataset = CIFAR10DataSet(data_dir)
_, val_set = torch.utils.data.random_split(cifar10_dataset.val_dataset, [49000, 1000])
return DataLoader(PytorchResNetDataset(val_set), batch_size=batch_size, drop_last=True)
return DataLoader(PytorchResNetDataset(cifar10_dataset.val_dataset), batch_size=batch_size, drop_last=True)


# -------------------------------------------------------------------------
Expand All @@ -83,11 +97,17 @@ def create_dataloader(data_dir, batch_size, *args, **kwargs):
class ResnetCalibrationDataReader(CalibrationDataReader):
def __init__(self, data_dir: str, batch_size: int = 16):
super().__init__()
self.iterator = iter(create_dataloader(data_dir, batch_size))
self.iterator = iter(create_train_dataloader(data_dir, batch_size))
self.sample_counter = 500

def get_next(self) -> dict:
if self.sample_counter <= 0:
return None

try:
return {"input": next(self.iterator)[0].numpy()}
item = {"input": next(self.iterator)[0].numpy()}
self.sample_counter -= 1
return item
except Exception:
return None

Expand Down Expand Up @@ -161,8 +181,7 @@ def create_qat_config():

def create_train_dataloader(data_dir, batchsize, *args, **kwargs):
cifar10_dataset = CIFAR10DataSet(data_dir)
train_dataset, _ = torch.utils.data.random_split(cifar10_dataset.train_dataset, [40000, 10000])
return DataLoader(PytorchResNetDataset(train_dataset), batch_size=batchsize, drop_last=True)
return DataLoader(PytorchResNetDataset(cifar10_dataset.train_dataset), batch_size=batchsize, drop_last=True)


# -------------------------------------------------------------------------
Expand Down