Regarding inference with pretrained weights. #53

Open
sorobedio opened this issue Jan 9, 2023 · 0 comments

Comments

sorobedio commented Jan 9, 2023

Hello.
I am currently working with the checkpoint of a model queried for the cifar10-valid dataset. I am unable to reproduce the accuracies reported in the architecture information, even though I use the same dataloader.

Here is part of the code I used:
```python
import os
import argparse
import numpy as np
import pandas as pd
import torchvision.datasets as dset
import torchvision.transforms as transforms

from tqdm import tqdm
import torch.nn as nn
import torch

from nats_bench import create
from nas_201_api import NASBench201API as API
from nas_201_api import ResultsCount
from models import get_cell_based_tiny_net

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# CIFAR-10 per-channel mean/std, rescaled from [0, 255] to [0, 1]
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]

root = '../../../../Datasets/NASBench/'
batch_size = 32
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
test_data = dset.CIFAR10(root, train=False, transform=test_transform, download=True)

testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                         shuffle=False, num_workers=4)


def test():
    correct = 0
    total = 0

    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            # the network returns a tuple; the second element holds the class logits
            _, outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')


if __name__ == '__main__':
    base_nasdir = '../../../../Datasets/NASBench/NATS-tss-v1_0-3ffb9-full/'
    api = create(base_nasdir, 'tss', fast_mode=True, verbose=False)

    # architecture config and trained weights of index 0 on cifar10-valid
    # (seed 777, hp="200" training schedule)
    config = api.get_net_config(0, 'cifar10-valid')
    model = get_cell_based_tiny_net(config)
    params = api.get_net_param(0, 'cifar10-valid', seed=777, hp="200")
    model.load_state_dict(params)

    model = model.to(device)
    test()
```

This outputs 75 %.

The expected results should lie in the range of the following values:

82.092, 81.616, 82.240
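Because `100 * correct // total` uses integer floor division, the printed figure drops the fractional part, so a result printed as 75 % can mean anything from 75.00 up to (but not including) 76.00. A minimal stdlib sketch contrasting the two calculations; the helper names and the counts 7583/10000 are hypothetical, chosen only for illustration:

```python
def accuracy_percent(correct: int, total: int) -> float:
    """Top-1 accuracy in percent, keeping the fractional part."""
    if total <= 0:
        raise ValueError("total must be positive")
    return 100.0 * correct / total


def accuracy_percent_floored(correct: int, total: int) -> int:
    """Same ratio computed with integer floor division, as in `100 * correct // total`."""
    if total <= 0:
        raise ValueError("total must be positive")
    return 100 * correct // total


if __name__ == "__main__":
    # Hypothetical counts for a 10,000-image test set (illustration only).
    correct, total = 7583, 10000
    print(f"floored: {accuracy_percent_floored(correct, total)} %")  # floored: 75 %
    print(f"float:   {accuracy_percent(correct, total):.3f} %")      # float:   75.830 %
```

The truncation costs less than one percentage point, so on its own it cannot account for a gap of about 6 to 7 points to the expected values; printing a float mainly makes the comparison with three-decimal benchmark numbers easier.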
I wonder whether my procedure is correct.
Is there a better way to reproduce the results given by the following?

```python
results = api.query_by_index(0, 'cifar10-valid', hp="200")
for seed, result in results.items():
    vacc = result.get_eval('x-valid')['accuracy']
    tacc = result.get_eval('ori-test')['accuracy']
```
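To compare a locally measured accuracy with the values returned by the query, one option is to collect the per-seed accuracies (for example the `tacc` values read from the `ori-test` split inside the loop) into a list and check where the local number falls. A minimal stdlib sketch under that assumption; `summarize_seeds` and `within_range` are hypothetical helpers, and the inputs are the three expected values quoted in this issue rather than a live NATS-Bench call:

```python
from statistics import mean


def summarize_seeds(accuracies):
    """Return (mean, min, max) of per-seed accuracies given in percent."""
    values = list(accuracies)
    if not values:
        raise ValueError("need at least one accuracy")
    return mean(values), min(values), max(values)


def within_range(measured, accuracies, tolerance=0.0):
    """True if `measured` lies in [min - tolerance, max + tolerance]."""
    _, lo, hi = summarize_seeds(accuracies)
    return lo - tolerance <= measured <= hi + tolerance


if __name__ == "__main__":
    # The three expected values quoted in the issue, hard-coded here.
    expected = [82.092, 81.616, 82.240]
    avg, lo, hi = summarize_seeds(expected)
    print(f"mean={avg:.3f} min={lo:.3f} max={hi:.3f}")  # mean=81.983 min=81.616 max=82.240
    print(within_range(75.0, expected, tolerance=0.5))  # False
```

A small tolerance is useful because seed-to-seed spread (here about 0.6 points) is expected even when the evaluation procedure is identical.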

Thank you in advance.
