In [8]:

import argparse
import os
import torch
import numpy as np
import yaml
from trainer import Trainer
import wandb


# Seeding is currently switched off; lines kept for reference.
# np.random.seed(0)
# torch.manual_seed(42)

# TF32 is disabled everywhere and float32 matmul precision is set to 'highest'.
# Alternative (TF32 enabled) settings, kept for reference:
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
# torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision('highest')

# cuDNN: benchmark mode on, deterministic mode off.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False



# Experiment settings written inline as YAML. The three top-level sections are
# unpacked later into the dataset factory, the model factory and the Trainer.
CONFIG_YAML = r"""
data:
    dataset: old_noisy_cifar10
    noise_type: symmetric
    noise_rate: 0.5
    random_seed: 42

model:
    architecture: resnet34
    num_classes: 10

trainer:
    num_workers: 4
    batch_size: 128
"""

config = yaml.safe_load(CONFIG_YAML)

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs (slowly) on machines without a usable CUDA device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


## Load model
import models

# W&B run id that load_checkpoint() interpolates into the run path.
wandb_run_id = "2cvre1ep"

# The `model:` section of the config is passed as keyword arguments.
model_options = config["model"]
model = models.get_model(**model_options)

def load_checkpoint(name="model_199.pth"):
    """Restore a checkpoint file from the W&B run and load it into ``model``.

    Reads the module-level ``wandb_run_id`` to build the run path and mutates
    the module-level ``model`` in place via ``load_state_dict``.

    Args:
        name: File name of the checkpoint stored in the W&B run.

    Raises:
        FileNotFoundError: If ``wandb.restore`` returns ``None``, i.e. the file
            could not be found for that run.
    """
    run_path = f"hyounguk-shon/noisy-label/{wandb_run_id}"
    # replace=False asks wandb not to overwrite a same-named local file under ./temp.
    checkpoint = wandb.restore(name, run_path=run_path, replace=False, root='./temp')
    if checkpoint is None:
        raise FileNotFoundError(
            f"Checkpoint {name!r} could not be restored from W&B run {run_path!r}"
        )

    # wandb.restore hands back an open file object; only its path is needed,
    # so release the handle right away instead of leaking it.
    checkpoint_path = checkpoint.name
    checkpoint.close()

    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # runs you trust, or pass weights_only=True if the installed torch has it.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    print(f"Loaded checkpoint: {checkpoint_path}")

# Load the stored weights into `model` before handing it to the Trainer.
load_checkpoint(name="model_199.pth")

# The `trainer:` section of the config is passed through as-is.
trainer_options = config['trainer']
trainer = Trainer(model=model, config=trainer_options, device=device)

Loaded checkpoint: ./temp/model_199.pth


In [9]:
## Load dataset
import datasets

# The `data:` section of the config is passed as keyword arguments.
data_options = config["data"]
train_dataset, test_dataset = datasets.get_dataset(**data_options)

# `dataset` is an alias of `train_dataset`, so the transform assigned below is
# set on the training split object itself.
dataset = train_dataset
dataset.transform = datasets.get_transform('none', dataset)
print(dataset)

Files already downloaded and verified
True noise rate: 0.4501
Files already downloaded and verified
Dataset OldNoisyCIFAR10
    Number of datapoints: 50000
    Root location: /dev/shm/data/
    Split: Train


In [10]:
## Run inference
# One pass of trainer.inference over `dataset` with whatever transform is
# currently assigned to it. The recorded output below shows a dict with the
# tensors 'logits', 'target' and 'target_gt'.
results = trainer.inference(dataset)
# Bare last expression so the notebook displays the returned dict.
results

{'logits': tensor([[-0.5065, -1.4981, -4.9593,  ..., -1.0978,  1.9800, -2.1671],
         [-3.7119, -3.9473,  1.2110,  ...,  0.0594, -9.7636, 17.0981],
         [-3.4245, -2.2070, -4.2968,  ..., -4.9050,  2.7233, 13.2970],
         ...,
         [15.5436, -5.0751, -3.2363,  ..., -0.0580, -0.1786,  0.4635],
         [ 3.4131, 16.4452, -1.4655,  ..., -1.4378, -3.9767, -6.4341],
         [-2.5247, 13.6361,  1.0942,  ...,  2.2141,  5.3886,  0.7559]]),
 'target': tensor([4, 6, 9,  ..., 5, 8, 1]),
 'target_gt': tensor([6, 9, 9,  ..., 9, 1, 1])}

In [11]:
# Fraction of samples whose highest-scoring class equals the 'target_gt' entry.
predicted_class = results['logits'].max(-1).indices
is_correct = predicted_class == results['target_gt']
is_correct.float().mean()

tensor(0.5511)

In [18]:
## Run multiple inferences to aggregate results over random augmentations

from collections import defaultdict
import tqdm

# Switch the dataset over to the 'autoaugment' transform for the repeated passes.
dataset.transform = datasets.get_transform('autoaugment', dataset)

n_repeat = 20

# Collect every output tensor of every pass under its result key.
per_key_runs = defaultdict(list)
for _ in tqdm.trange(n_repeat):
    results = trainer.inference(dataset)
    for key, tensor in results.items():
        per_key_runs[key].append(tensor)

# Stack the passes along a new axis 1: shape is (n_samples, n_repeat, ...)
many_results = {
    key: torch.stack(runs, dim=1)
    for key, runs in per_key_runs.items()
}
many_results

100%|██████████| 20/20 [02:28<00:00,  7.43s/it]


{'logits': tensor([[[  3.1371,  -7.2377,  -3.8441,  ...,  -5.2848,   0.5542,   2.5441],
          [  1.0813,  -2.6185,  -0.9506,  ...,  -3.1691,  -2.2103,  -2.6586],
          [  4.0872,  -5.1103,   3.2242,  ...,   1.0411,  -0.2065,  -3.1807],
          ...,
          [ -0.5065,  -1.4981,  -4.9593,  ...,  -1.0978,   1.9800,  -2.1671],
          [ -2.5059,  -1.8346,  -6.5934,  ...,   0.4153,   2.7932,  -0.9255],
          [  2.7689,   1.0047,  -6.2310,  ...,   5.8537,   0.7352,   3.2017]],
 
         [[ -3.9351,  -4.2672,   1.2846,  ...,  -2.4506,  -4.8029,  13.6650],
          [ -3.7119,  -3.9473,   1.2110,  ...,   0.0594,  -9.7636,  17.0981],
          [ -3.5406,  -3.8940,   1.7314,  ...,  -0.4538,  -9.5749,  16.9001],
          ...,
          [ -3.7119,  -3.9473,   1.2110,  ...,   0.0594,  -9.7636,  17.0981],
          [  4.7519,   4.2068,  -6.3249,  ...,   2.9229, -10.9448,  16.8736],
          [  2.0416,   2.5977,   0.1642,  ...,  -3.4261,  -5.0222,   4.4338]],
 
         [[ -3.424

In [21]:
# Average the logits over the repeat axis (dim 1), take the highest-scoring
# class, and compare with the 'target_gt' column of the first repeat.
mean_logits = many_results['logits'].mean(1)
ensemble_class = mean_logits.max(-1).indices
reference_labels = many_results['target_gt'][:, 0]
(ensemble_class == reference_labels).float().mean()

tensor(0.5801)

In [None]:
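## Save results
# NOTE: when the cells are run in order, `results` here is only the last
# augmented pass from the aggregation loop; the stacked outputs of all passes
# live in `many_results`.
# torch.save(results, "results.pth")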