In [7]:
import torch
from simclr import SimCLR
from torchvision import transforms
from torchvision.models import resnet18, resnet50
import torchvision
from types import SimpleNamespace
from linear_evaluation import compute_metrics, inference
from tqdm import tqdm

In [8]:
args = SimpleNamespace()
args.model = 'resnet50'
args.training_method = 'simclr'
args.dataset = 'cifar10'
args.batch_size = 32
args.image_size = 224
args.projection_dim = 64
args.ckpt_path = "/home/levscaut/SimCLR/save_resnet50_supervised/checkpoint_{epoch}.tar"

In [9]:
test_dataset = torchvision.datasets.CIFAR10(download=True, train=False, root="datasets", transform=transforms.Compose([
    torchvision.transforms.Resize(size=args.image_size),
    transforms.ToTensor(),
    ]))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
print('# of test data:', len(test_dataset))

Files already downloaded and verified
# of test data: 10000


In [10]:
model_dict = {
    'resnet18': resnet18,
    'resnet50': resnet50
}
model = model_dict[args.model](pretrained=False, num_classes=len(test_dataset.classes))

In [11]:
state_dict = torch.load(args.ckpt_path)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [15]:
model.eval()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
count = 0
all_metrics = {
    'acc': 0,
    'precision': 0,
    'recall': 0,
    'f1': 0,
}
for i, (images, labels) in tqdm(enumerate(test_loader), total=len(test_loader)):
    images = images.to(device)
    labels = labels.to(device)
    pred = model(images)
    res = compute_metrics(pred, labels)
    for key in all_metrics.keys():
        all_metrics[key] += res[key] * len(labels)
    count += len(labels)
    
for key in all_metrics.keys():
    all_metrics[key] /= count
print(all_metrics)
    

100%|██████████| 313/313 [00:19<00:00, 15.95it/s]

{'acc': tensor(0.9123, device='cuda:0'), 'precision': tensor(0.9072, device='cuda:0'), 'recall': tensor(0.9058, device='cuda:0'), 'f1': tensor(0.8932, device='cuda:0')}





100:
{'acc': tensor(0.9153, device='cuda:0'), 'precision': tensor(0.9054, device='cuda:0'), 'recall': tensor(0.9047, device='cuda:0'), 'f1': tensor(0.8925, device='cuda:0')}