Skip to content

Commit

Permalink
Upgrade Python and Pytorch versions for some examples (#1906)
Browse files Browse the repository at this point in the history
  • Loading branch information
tenzen-y committed Jun 28, 2022
1 parent 9ee8fda commit cfa2d84
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# TODO (tenzen-y): Upgrade Python version and Pytorch version
FROM python:3.7-slim
FROM python:3.9-slim

ENV TARGET_DIR /opt/darts-cnn-cifar10

Expand Down
6 changes: 3 additions & 3 deletions examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def virtual_step(self, train_x, train_y, xi, w_optim):
gradients = torch.autograd.grad(loss, self.model.getWeights())

# Do virtual step (Update gradient)
# Below opeartions do not need gradient tracking
# Below operations do not need gradient tracking
with torch.no_grad():
# dict key is not the value, but the pointer. So original network weight have to
# be iterated also.
for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
vw.copy_(w - xi * (m + g + self.w_weight_decay * w))
vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))

# Sync alphas
for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
Expand Down Expand Up @@ -85,7 +85,7 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
# Update final gradient = dalpha - xi * hessian
with torch.no_grad():
for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
alpha.grad = da - xi * h
alpha.grad = da - torch.FloatTensor(xi) * h

def compute_hessian(self, dws, train_x, train_y):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch==1.0.0
torchvision==0.2.1
Pillow==6.2.2
torch==1.11.0
torchvision==0.12.0
Pillow>=9.1.1
4 changes: 2 additions & 2 deletions examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,15 @@ def main():
best_top1 = 0.

for epoch in range(num_epochs):
lr_scheduler.step()
lr = lr_scheduler.get_lr()[0]
lr = lr_scheduler.get_last_lr()

model.print_alphas()

# Training
print(">>> Training")
train(train_loader, valid_loader, model, architect, w_optim, alpha_optim,
lr, epoch, num_epochs, device, w_grad_clip, print_step)
lr_scheduler.step()

# Validation
print("\n>>> Validation")
Expand Down
4 changes: 2 additions & 2 deletions examples/v1beta1/trial-images/darts-cnn-cifar10/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def update(self, val, n=1):


def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)

Expand All @@ -53,7 +53,7 @@ def accuracy(output, target, topk=(1,)):

res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
correct_k = correct[:k].contiguous().view(-1).float().sum(0)
res.append(correct_k.mul_(1.0 / batch_size))

return res
Expand Down
3 changes: 1 addition & 2 deletions examples/v1beta1/trial-images/pytorch-mnist/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# TODO (tenzen-y): Upgrade Python version and Pytorch version
FROM python:3.7-slim
FROM python:3.9-slim

ADD examples/v1beta1/trial-images/pytorch-mnist /opt/pytorch-mnist
WORKDIR /opt/pytorch-mnist
Expand Down
6 changes: 3 additions & 3 deletions examples/v1beta1/trial-images/pytorch-mnist/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cloudml-hypertune==0.1.0.dev6
torch==1.0.0
torchvision==0.2.1
Pillow==6.2.2
torch==1.11.0
torchvision==0.12.0
Pillow>=9.1.1
2 changes: 1 addition & 1 deletion hack/verify-yamllint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ if [ -z "$(command -v yamllint)" ]; then
fi

echo 'Running yamllint'
yamllint -d "{extends: default, rules: {line-length: disable}}" examples/* manifests/*
yamllint -d "{extends: default, rules: {line-length: disable}}" examples/* manifests/*

0 comments on commit cfa2d84

Please sign in to comment.