Skip to content

Commit

Permalink
test(resnet9): update to 91.6% model artifact
Browse files Browse the repository at this point in the history
  • Loading branch information
joennlae committed Oct 10, 2023
1 parent 021a663 commit 645bfd8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/resnet9_validation.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
name: ResNet9 - 90%+ accuracy
name: ResNet9 - 91.6%+ accuracy
on: [push]

jobs:
changes:
uses: ./.github/workflows/filter.yaml
resnet9:
name: ResNet9 - 90%+ accuracy
name: ResNet9 - 91.6%+ accuracy
needs: changes
if: ${{ needs.changes.outputs.algorithmic == 'true' }}
runs-on: ubuntu-latest
Expand Down
47 changes: 24 additions & 23 deletions src/python/test/model_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
from halutmatmul.modules import HalutConv2d, HalutLinear


model_name_file = "resnet9-best-int8.pth"


def download_model(
path: str,
is_ci: bool = False,
url: str = "https://iis-people.ee.ethz.ch/~janniss/resnet9-best.pth",
url: str = f"https://iis-people.ee.ethz.ch/~janniss/{model_name_file}",
) -> None:
# Streaming, so we can iterate over the response.
# pylint: disable=missing-timeout
Expand All @@ -26,7 +29,7 @@ def download_model(

dl = 0
print("Download ResNet-9 CIFAR-10 luts, thresholds and dims")
with open(path + "/" + "resnet9-best.pth", "wb") as f:
with open(path + "/" + model_name_file, "wb") as f:
for data in r.iter_content(block_size):
f.write(data)
dl += len(data)
Expand All @@ -42,9 +45,7 @@ def download_model(
download_model(tmpdirname, is_ci=False)

script_dir = tmpdirname
state_dict = torch.load(
script_dir + "/" + "resnet9-best" + ".pth", map_location="cpu"
)
state_dict = torch.load(script_dir + "/" + model_name_file, map_location="cpu")

(
model_name,
Expand All @@ -55,36 +56,36 @@ def download_model(
args_checkpoint,
halut_modules,
checkpoint,
) = load_model(script_dir + "/" + "resnet9-best" + ".pth")
) = load_model(script_dir + "/" + model_name_file)
print(model)
print(halut_modules)

# model.to("cuda")
# criterion = nn.CrossEntropyLoss()
# evaluate(
# model,
# criterion=criterion,
# data_loader=data_loader_val,
# device="cuda",
# )

# int8 quantized model
model_int8 = torch.ao.quantization.quantize_dynamic(
model.to("cuda")
criterion = nn.CrossEntropyLoss()
evaluate(
model,
{HalutConv2d, HalutLinear},
dtype=torch.qint8,
criterion=criterion,
data_loader=data_loader_val,
device="cuda",
)

model_int8.to("cpu")
criterion = nn.CrossEntropyLoss()
# int8 quantized model
# model_int8 = torch.ao.quantization.quantize_dynamic(
# model,
# {HalutConv2d, HalutLinear},
# dtype=torch.qint8,
# )

# model_int8.to("cpu")
# criterion = nn.CrossEntropyLoss()

acc1, acc5, loss = evaluate(
model_int8,
model,
criterion=criterion,
data_loader=data_loader_val,
device="cpu",
)

assert acc1 > 0.90
assert acc1 > 0.916
assert acc5 > 0.99
assert loss < 0.5

0 comments on commit 645bfd8

Please sign in to comment.