In [1]:

import argparse
import os
import warnings

import numpy as np
import torch
import yaml

from trainer import Trainer


# Reproducibility: seed the global NumPy and PyTorch RNGs so anything stochastic
# downstream (e.g. weight initialisation when no checkpoint is loaded) repeats
# identically under Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# cuDNN autotuning favours speed over bit-exact determinism: convolution
# algorithms are benchmarked and selected at runtime, so outputs may differ
# very slightly between runs even with fixed seeds.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# Disable TF32 so CUDA matmuls and cuDNN convolutions run in full FP32 precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision('highest')



# Experiment configuration, grouped by the component that consumes each section:
# `data` -> datasets.get_dataset, `model` -> models.get_model, `trainer` -> Trainer.
# NOTE(review): noise_rate 0.5 is reported by the dataset as a true noise rate of
# ~0.45 — presumably symmetric noise may resample the original label; confirm in `datasets`.
config = {
    "data": {
        "dataset": "old_noisy_cifar10",
        "noise_type": "symmetric",
        "noise_rate": 0.5,
        "random_seed": 42,
    },
    "model": {
        "architecture": "resnet34",
        "num_classes": 10,
    },
    "trainer": {
        "num_workers": 4,
        "batch_size": 128,
    },
}

# Use the first GPU when CUDA is available; otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


## Load model
import models

# Path to trained weights to evaluate, e.g. "checkpoints/best.pth".
# None keeps whatever weights `models.get_model` returns (presumably an untrained
# initialisation — TODO confirm), so a warning is emitted: inference outputs from
# an unloaded model are usually not meaningful for analysis.
CHECKPOINT_PATH = None

model = models.get_model(**config["model"])
if CHECKPOINT_PATH is None:
    warnings.warn(
        "No checkpoint loaded: running inference with the weights returned by "
        "models.get_model (likely untrained)."
    )
else:
    # torch.load unpickles the file, which can execute arbitrary code; only load
    # trusted checkpoints (torch>=1.13 also accepts weights_only=True).
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))

trainer = Trainer(
    model=model,
    config=config['trainer'],
    device=device,
)

In [2]:
## Load dataset
import datasets

# The project's dataset factory yields a (train, test) pair built from config["data"].
train_dataset, test_dataset = datasets.get_dataset(**config["data"])

# Inference below runs on the training split with the 'none' transform
# (presumably no augmentation — TODO confirm in datasets.get_transform).
# `dataset` is the same object as `train_dataset`, not a copy.
train_dataset.transform = datasets.get_transform('none', train_dataset)
dataset = train_dataset
print(dataset)

Files already downloaded and verified
True noise rate: 0.4501
Files already downloaded and verified
Dataset OldNoisyCIFAR10
    Number of datapoints: 50000
    Root location: /dev/shm/data/
    Split: Train


In [3]:
## Run inference
results = trainer.inference(dataset)

# Sanity checks on the returned dict: one logits row (with num_classes columns)
# and one observed/ground-truth label per datapoint.
assert results['logits'].shape == (len(dataset), config['model']['num_classes'])
assert len(results['target']) == len(dataset)
assert len(results['target_gt']) == len(dataset)

results

  0%|          | 0/391 [00:00<?, ?it/s]

{'logits': tensor([[ 0.0089, -0.0057,  0.0202,  ...,  0.0048,  0.0517, -0.0003],
         [ 0.0134, -0.0040,  0.0246,  ...,  0.0046,  0.0453,  0.0060],
         [ 0.0183,  0.0039,  0.0278,  ..., -0.0040,  0.0541,  0.0219],
         ...,
         [ 0.0157, -0.0049,  0.0215,  ...,  0.0040,  0.0522,  0.0201],
         [ 0.0081,  0.0050,  0.0199,  ..., -0.0157,  0.0447,  0.0129],
         [ 0.0130,  0.0014,  0.0211,  ..., -0.0030,  0.0491,  0.0077]]),
 'target': tensor([4, 6, 9,  ..., 5, 8, 1]),
 'target_gt': tensor([6, 9, 9,  ..., 9, 1, 1])}

In [5]:
# Expected: (number of datapoints, number of classes).
results['logits'].size()

torch.Size([50000, 10])

In [None]:
## Save results
# Set to a file path (e.g. "results.pth") to persist the inference outputs;
# None (the default) skips saving.
RESULTS_PATH = None

if RESULTS_PATH is not None:
    torch.save(results, RESULTS_PATH)