About The Experiment on CIFAR10 #3

Open
mengxianghan123 opened this issue Oct 25, 2022 · 1 comment

@mengxianghan123

Thanks for your great work! The released code really helps a lot.
However, when I tried to replicate the experimental results on CIFAR10, I only got an accuracy of 0.35.
This might be due to an inappropriate hyper-parameter setting, or to my misunderstanding of other experimental details for CIFAR10.
I've tried different beta values (0.1, 1, 5, 10) and d values (5, 10). For the feature extraction on CIFAR10, here is my implementation:

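import numpy as np
import torch
import torchvision
from torchvision import transforms
from tqdm import tqdm

@torch.no_grad()
def CIFAR10_features():
    cifar10 = torchvision.datasets.CIFAR10(root='/home/mxh/datasets/CIFAR10', train=True, download=False)
    # ImageNet-pretrained ResNet-50 with the classifier removed, so it outputs 2048-d features
    model = torchvision.models.resnet50(weights='DEFAULT')
    model.eval()
    model.fc = torch.nn.Identity()
    # Only converts PIL images to tensors in [0, 1]; no resizing or normalization
    to_tensor = transforms.ToTensor()
    img_list = []
    label_list = []
    # Features for the 50,000 training images, one image at a time
    for idx in tqdm(range(len(cifar10))):
        img, label = cifar10[idx]
        img = to_tensor(img).unsqueeze(0)
        img_list.append(model(img))
        label_list.append(label)
    # Features for the 10,000 test images, appended to the same lists
    cifar10 = torchvision.datasets.CIFAR10(root='/home/mxh/datasets/CIFAR10', train=False, download=False)
    for idx in tqdm(range(len(cifar10))):
        img, label = cifar10[idx]
        img = to_tensor(img).unsqueeze(0)
        img_list.append(model(img))
        label_list.append(label)
    img_list = torch.cat(img_list, dim=0).numpy()
    label_list = np.array(label_list)
    # np.save pickles the dict into a single .npy file
    data = {'data': img_list, 'label': label_list}
    np.save("/home/mxh/codes/EDESC/data/CIFAR10/cifar.npy", data)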

Could you please give me some advice on replicating the CIFAR10 results? It would be extremely helpful. Thanks a lot!
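For anyone checking the saved file before clustering: the snippet above passes a dict to `np.save`, which NumPy stores as a pickled 0-d object array, so reading it back needs `allow_pickle=True` plus `.item()`. Below is a minimal sketch of such a check. The helper name `load_feature_file` is hypothetical and not part of the EDESC repository, and how the repository itself reads this file is not shown in this thread.

```python
import numpy as np


def load_feature_file(path):
    """Read back a {'data': ..., 'label': ...} dict written with np.save.

    Hypothetical helper for sanity-checking; not taken from the EDESC code.
    """
    # np.save wraps a dict in a 0-d object array, so pickle loading must be
    # enabled and .item() is needed to recover the dict itself.
    payload = np.load(path, allow_pickle=True).item()
    features = np.asarray(payload["data"], dtype=np.float32)
    labels = np.asarray(payload["label"]).reshape(-1)
    # Every feature row needs exactly one label.
    if features.shape[0] != labels.shape[0]:
        raise ValueError("feature/label count mismatch")
    return features, labels
```

With the extraction code above (50,000 training plus 10,000 test images, ResNet-50 with `fc` replaced by `Identity`), the expected shapes would be `(60000, 2048)` for the features and `(60000,)` for the labels.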

@mengxianghan123

After adding a pre-processing step (normalization with the ImageNet mean and std), the ACC now reaches 0.457. However, there is still a gap to the 0.627 reported in the paper. Could you please share more details? It would be very helpful!

import numpy as np
import torch
import torchvision
from torchvision import transforms

# ImageNet channel statistics, matching the pretrained ResNet-50 weights
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

@torch.no_grad()
def CIFAR10_features():
    # New pre-processing step: normalize after converting to a tensor (still no resizing)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    cifar10 = torchvision.datasets.CIFAR10(root='/home/mxh/datasets/CIFAR10', train=True, download=False, transform=transform)
    trainloader = torch.utils.data.DataLoader(cifar10, batch_size=2048, shuffle=True, num_workers=2)
    # ImageNet-pretrained ResNet-50 with the classifier removed, so it outputs 2048-d features
    model = torchvision.models.resnet50(weights='DEFAULT')
    model.eval()
    model.fc = torch.nn.Identity()
    img_list = []
    label_list = []
    # Features and labels are collected batch by batch, so shuffling keeps them aligned
    for img, label in trainloader:
        img_list.append(model(img))
        label_list.append(label)
    # Same extraction for the test split, appended to the same lists
    cifar10 = torchvision.datasets.CIFAR10(root='/home/mxh/datasets/CIFAR10', train=False, download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(cifar10, batch_size=2048, shuffle=True, num_workers=2)
    for img, label in testloader:
        img_list.append(model(img))
        label_list.append(label)
    img_list = torch.cat(img_list, dim=0).numpy()
    label_list = torch.cat(label_list, dim=0).numpy()
    # np.save pickles the dict into a single .npy file
    data = {'data': img_list, 'label': label_list}
    np.save("/home/mxh/codes/EDESC-pytorch-master/data/CIFAR10/cifar.npy", data)
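Since the numbers being compared here (0.35, 0.457, 0.627) are ACC values, it may help to pin down how ACC is computed. The sketch below assumes ACC means the usual unsupervised clustering accuracy, where cluster ids are matched one-to-one to class labels with the Hungarian algorithm before counting correct samples. Whether the paper's evaluation does exactly this is an assumption on my part, and the function name `clustering_accuracy` is hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def clustering_accuracy(y_true, y_pred):
    """Accuracy under the best one-to-one mapping of cluster ids to classes.

    Assumes labels and cluster ids are non-negative integers.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    n = int(max(y_true.max(), y_pred.max())) + 1
    # counts[i, j] = number of samples placed in cluster i whose true class is j
    counts = np.zeros((n, n), dtype=np.int64)
    np.add.at(counts, (y_pred, y_true), 1)
    # linear_sum_assignment minimizes cost, so negate to maximize matched samples
    rows, cols = linear_sum_assignment(-counts)
    return counts[rows, cols].sum() / y_true.size
```

For example, `clustering_accuracy([0, 0, 1, 1], [1, 1, 0, 0])` gives 1.0, because the cluster ids are only a relabeling of the classes. Running the same function on the labels saved in `cifar.npy` and on the predicted cluster assignments would make sure both setups are scored the same way.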
