In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm


### hyper_parameter

In [10]:
batch_size = 32
test_batch_size = 1000
iters = 10000
lr = 0.01
momentum = 0.9
alpha = 1.0
xi = 10.0
eps = 1.0
ip = 1
workers = 8 
seed = 1
log_interval = 100

### vat.py

In [11]:
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F


@contextlib.contextmanager
def _disable_tracking_bn_stats(model):

    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True
            
    model.apply(switch_attr)
    yield
    model.apply(switch_attr)


def _l2_normalize(d):
    d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
    d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
    return d


class VATLoss(nn.Module):

    def __init__(self, xi=10.0, eps=1.0, ip=1):
        """VAT loss
        :param xi: hyperparameter of VAT (default: 10.0)
        :param eps: hyperparameter of VAT (default: 1.0)
        :param ip: iteration times of computing adv noise (default: 1)
        """
        super(VATLoss, self).__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip

    def forward(self, model, x):
        with torch.no_grad():
            pred = F.softmax(model(x), dim=1)

        # prepare random unit tensor
        d = torch.rand(x.shape).sub(0.5).to(x.device)
        d = _l2_normalize(d)

        with _disable_tracking_bn_stats(model):
            # calc adversarial direction
            for _ in range(self.ip):
                d.requires_grad_()
                pred_hat = model(x + self.xi * d)
                logp_hat = F.log_softmax(pred_hat, dim=1)
                adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
                adv_distance.backward()
                d = _l2_normalize(d.grad)
                model.zero_grad()
    
            # calc LDS
            r_adv = d * self.eps
            pred_hat = model(x + r_adv)
            logp_hat = F.log_softmax(pred_hat, dim=1)
            lds = F.kl_div(logp_hat, pred, reduction='batchmean')

        return lds

### data_utils.py 

In [67]:
import numpy as np
from sklearn.preprocessing import LabelBinarizer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import sampler
from torchvision import datasets
from torchvision import transforms


class SimpleDataset(Dataset):

    def __init__(self, x, y, transform):
        self.x = x
        self.y = y
        self.transform = transform

    def __getitem__(self, index):
        img = self.x[index]
        if self.transform is not None:
            img = self.transform(img)
        target = self.y[index]
        return img, target

    def __len__(self):
        return len(self.x)


class InfiniteSampler(sampler.Sampler):

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        while True:
            order = np.random.permutation(self.num_samples)
            for i in range(self.num_samples):
                yield order[i]

    def __len__(self):
        return None


def get_iters(
        dataset='CIFAR10', root_path='.', data_transforms=None,
        n_labeled=4000, valid_size=1000,
        l_batch_size=32, ul_batch_size=128, test_batch_size=256,
        workers=0, pseudo_label=None):
    
    train_path = f'{root_path}/data/{dataset}/train/'
    test_path = f'{root_path}/data/{dataset}/test/'

    if dataset == 'CIFAR10':
        train_dataset = datasets.CIFAR10(train_path, download=True, train=True, transform=None)
        test_dataset = datasets.CIFAR10(test_path, download=True, train=False, transform=None)
    elif dataset == 'CIFAR100':
        train_dataset = datasets.CIFAR100(train_path, download=True, train=True, transform=None)
        test_dataset = datasets.CIFAR100(test_path, download=True, train=False, transform=None)
    else:
        raise ValueError

    if data_transforms is None:
        data_transforms = {
            'train': transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]),
            'eval': transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]),
        }

    x_train, y_train = train_dataset.data, np.array(train_dataset.targets)
    x_test, y_test = test_dataset.data, np.array(test_dataset.targets)

    randperm = np.random.permutation(len(x_train))
    labeled_idx = randperm[:n_labeled]
    validation_idx = randperm[n_labeled:n_labeled + valid_size]
    unlabeled_idx = randperm[n_labeled + valid_size:]

    x_labeled = x_train[labeled_idx]
    x_validation = x_train[validation_idx]
    x_unlabeled = x_train[unlabeled_idx]

    y_labeled = y_train[labeled_idx]
    y_validation = y_train[validation_idx]
    if pseudo_label is None:
        y_unlabeled = y_train[unlabeled_idx]
    else:
        assert isinstance(pseudo_label, np.ndarray)
        y_unlabeled = pseudo_label
    
    data_iterators = {
        'labeled': iter(DataLoader(
            SimpleDataset(x_labeled, y_labeled, data_transforms['train']),
            batch_size=l_batch_size, num_workers=workers,
            sampler=InfiniteSampler(len(x_labeled)),
        )),
        'unlabeled': iter(DataLoader(
            SimpleDataset(x_unlabeled, y_unlabeled, data_transforms['train']),
            batch_size=ul_batch_size, num_workers=workers,
            sampler=InfiniteSampler(len(x_unlabeled)),
        )),
        'make_pl': iter(DataLoader(
            SimpleDataset(x_unlabeled, y_unlabeled, data_transforms['eval']),
            batch_size=ul_batch_size, num_workers=workers, shuffle=False
        )),
        'val': iter(DataLoader(
            SimpleDataset(x_validation, y_validation, data_transforms['eval']),
            batch_size=len(x_validation), num_workers=workers, shuffle=False
        )),
        'test': iter(DataLoader(
            SimpleDataset(x_test, y_test, data_transforms['eval']),
            batch_size=test_batch_size, num_workers=workers, shuffle=False
        ))
    }

    return data_iterators


### utils.py

In [68]:
from collections import OrderedDict
import logging
from pathlib import Path
import torch


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, top_k=(1,)):
    """Computes the precision@k for the specified values of k"""
    max_k = max(top_k)
    batch_size = target.size(0)

    _, pred = output.topk(max_k, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in top_k:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))

    if len(res) == 1:
        res = res[0]

    return res

### vat.py

In [99]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3)
        self.fc1 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(-1, 128)
        x = self.fc1(x)
        return x
    
def train(model, device, data_iterators, optimizer):
    model.train()
    for i in tqdm(range(1000)):
        
        # reset
        if i % 100 == 0:
            ce_losses = AverageMeter()
            vat_losses = AverageMeter()
            prec1 = AverageMeter()
        
        x_l, y_l = next(data_iterators['labeled'])
        x_ul, _ = next(data_iterators['unlabeled'])

        x_l, y_l = x_l.to(device), y_l.to(device)
        x_ul = x_ul.to(device)

        optimizer.zero_grad()

        vat_loss = VATLoss(xi=10.0 , eps=1.0, ip=1)
        cross_entropy = nn.CrossEntropyLoss()

        lds = vat_loss(model, x_ul)
        output = model(x_l)
        classification_loss = cross_entropy(output, y_l)
        loss = classification_loss + 1.0 * lds
        loss.backward()
        optimizer.step()

        acc = accuracy(output, y_l)
        ce_losses.update(classification_loss.item(), x_l.shape[0])
        vat_losses.update(lds.item(), x_ul.shape[0])
        prec1.update(acc.item(), x_l.shape[0])

        if i % 100 == 0:
            print(f'\nIteration: {i}\t'
                  f'CrossEntropyLoss {ce_losses.val:.4f} ({ce_losses.avg:.4f})\t'
                  f'VATLoss {vat_losses.val:.4f} ({vat_losses.avg:.4f})\t'
                  f'Prec@1 {prec1.val:.3f} ({prec1.avg:.3f})')
            


### main

In [103]:
device = torch.device('mps')

data_iterators = get_iters(
        root_path='.',
        l_batch_size=32,
        ul_batch_size=32,
        test_batch_size=1000,
        workers=0
    )

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

train(model, device, data_iterators, optimizer)

Files already downloaded and verified
Files already downloaded and verified


  0%|          | 5/10000 [00:00<13:21, 12.47it/s]  


Iteration: 0	CrossEntropyLoss 2.3205 (2.3205)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


  1%|          | 106/10000 [00:03<04:18, 38.26it/s]


Iteration: 100	CrossEntropyLoss 2.2977 (2.2977)	VATLoss 0.0000 (0.0000)	Prec@1 21.875 (21.875)


  2%|▏         | 208/10000 [00:05<04:17, 38.04it/s]


Iteration: 200	CrossEntropyLoss 2.2995 (2.2995)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


  3%|▎         | 308/10000 [00:08<04:11, 38.52it/s]


Iteration: 300	CrossEntropyLoss 2.3130 (2.3130)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


  4%|▍         | 404/10000 [00:11<03:57, 40.44it/s]


Iteration: 400	CrossEntropyLoss 2.3075 (2.3075)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


  5%|▌         | 506/10000 [00:13<04:24, 35.94it/s]


Iteration: 500	CrossEntropyLoss 2.2973 (2.2973)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


  6%|▌         | 608/10000 [00:16<04:06, 38.06it/s]


Iteration: 600	CrossEntropyLoss 2.3258 (2.3258)	VATLoss 0.0000 (0.0000)	Prec@1 0.000 (0.000)


  7%|▋         | 705/10000 [00:19<03:51, 40.17it/s]


Iteration: 700	CrossEntropyLoss 2.3067 (2.3067)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


  8%|▊         | 806/10000 [00:21<04:01, 38.09it/s]


Iteration: 800	CrossEntropyLoss 2.3106 (2.3106)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


  9%|▉         | 906/10000 [00:24<04:10, 36.37it/s]


Iteration: 900	CrossEntropyLoss 2.3134 (2.3134)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 10%|█         | 1008/10000 [00:27<03:46, 39.71it/s]


Iteration: 1000	CrossEntropyLoss 2.3017 (2.3017)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 11%|█         | 1108/10000 [00:29<03:52, 38.22it/s]


Iteration: 1100	CrossEntropyLoss 2.2992 (2.2992)	VATLoss 0.0000 (0.0000)	Prec@1 18.750 (18.750)


 12%|█▏        | 1208/10000 [00:32<03:35, 40.77it/s]


Iteration: 1200	CrossEntropyLoss 2.3048 (2.3048)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 13%|█▎        | 1306/10000 [00:34<03:42, 39.05it/s]


Iteration: 1300	CrossEntropyLoss 2.2991 (2.2991)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 14%|█▍        | 1407/10000 [00:37<03:44, 38.34it/s]


Iteration: 1400	CrossEntropyLoss 2.3091 (2.3091)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 15%|█▌        | 1506/10000 [00:40<03:48, 37.24it/s]


Iteration: 1500	CrossEntropyLoss 2.3128 (2.3128)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 16%|█▌        | 1605/10000 [00:42<03:33, 39.32it/s]


Iteration: 1600	CrossEntropyLoss 2.2993 (2.2993)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 17%|█▋        | 1705/10000 [00:45<03:17, 42.06it/s]


Iteration: 1700	CrossEntropyLoss 2.3068 (2.3068)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 18%|█▊        | 1808/10000 [00:47<03:46, 36.16it/s]


Iteration: 1800	CrossEntropyLoss 2.3430 (2.3430)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 19%|█▉        | 1907/10000 [00:50<03:35, 37.53it/s]


Iteration: 1900	CrossEntropyLoss 2.2868 (2.2868)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 20%|██        | 2009/10000 [00:52<03:23, 39.32it/s]


Iteration: 2000	CrossEntropyLoss 2.3134 (2.3134)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 21%|██        | 2105/10000 [00:55<03:23, 38.82it/s]


Iteration: 2100	CrossEntropyLoss 2.3035 (2.3035)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 22%|██▏       | 2205/10000 [00:57<03:23, 38.34it/s]


Iteration: 2200	CrossEntropyLoss 2.3196 (2.3196)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 23%|██▎       | 2305/10000 [01:00<03:22, 37.99it/s]


Iteration: 2300	CrossEntropyLoss 2.3022 (2.3022)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 24%|██▍       | 2408/10000 [01:03<03:08, 40.31it/s]


Iteration: 2400	CrossEntropyLoss 2.2988 (2.2988)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 25%|██▌       | 2507/10000 [01:05<03:13, 38.81it/s]


Iteration: 2500	CrossEntropyLoss 2.2971 (2.2971)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 26%|██▌       | 2609/10000 [01:08<03:08, 39.16it/s]


Iteration: 2600	CrossEntropyLoss 2.3063 (2.3063)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 27%|██▋       | 2706/10000 [01:11<03:02, 40.04it/s]


Iteration: 2700	CrossEntropyLoss 2.2956 (2.2956)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 28%|██▊       | 2805/10000 [01:13<03:10, 37.76it/s]


Iteration: 2800	CrossEntropyLoss 2.3281 (2.3281)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 29%|██▉       | 2908/10000 [01:16<03:02, 38.89it/s]


Iteration: 2900	CrossEntropyLoss 2.3257 (2.3257)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 30%|███       | 3008/10000 [01:18<03:04, 37.80it/s]


Iteration: 3000	CrossEntropyLoss 2.3044 (2.3044)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 31%|███       | 3108/10000 [01:21<02:51, 40.11it/s]


Iteration: 3100	CrossEntropyLoss 2.2908 (2.2908)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 32%|███▏      | 3205/10000 [01:23<02:49, 40.00it/s]


Iteration: 3200	CrossEntropyLoss 2.3052 (2.3052)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 33%|███▎      | 3307/10000 [01:26<02:42, 41.31it/s]


Iteration: 3300	CrossEntropyLoss 2.3129 (2.3129)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 34%|███▍      | 3408/10000 [01:29<02:43, 40.24it/s]


Iteration: 3400	CrossEntropyLoss 2.2923 (2.2923)	VATLoss 0.0000 (0.0000)	Prec@1 21.875 (21.875)


 35%|███▌      | 3509/10000 [01:31<02:46, 38.88it/s]


Iteration: 3500	CrossEntropyLoss 2.3032 (2.3032)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 36%|███▌      | 3605/10000 [01:34<02:33, 41.57it/s]


Iteration: 3600	CrossEntropyLoss 2.2980 (2.2980)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 37%|███▋      | 3708/10000 [01:36<02:48, 37.36it/s]


Iteration: 3700	CrossEntropyLoss 2.2998 (2.2998)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 38%|███▊      | 3806/10000 [01:39<02:44, 37.73it/s]


Iteration: 3800	CrossEntropyLoss 2.3025 (2.3025)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 39%|███▉      | 3906/10000 [01:42<02:30, 40.41it/s]


Iteration: 3900	CrossEntropyLoss 2.3101 (2.3101)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 40%|████      | 4005/10000 [01:44<02:39, 37.60it/s]


Iteration: 4000	CrossEntropyLoss 2.3010 (2.3010)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 41%|████      | 4107/10000 [01:47<02:32, 38.57it/s]


Iteration: 4100	CrossEntropyLoss 2.3126 (2.3126)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 42%|████▏     | 4207/10000 [01:49<02:23, 40.31it/s]


Iteration: 4200	CrossEntropyLoss 2.2839 (2.2839)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 43%|████▎     | 4306/10000 [01:52<02:20, 40.49it/s]


Iteration: 4300	CrossEntropyLoss 2.3169 (2.3169)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 44%|████▍     | 4407/10000 [01:54<02:25, 38.54it/s]


Iteration: 4400	CrossEntropyLoss 2.2931 (2.2931)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 45%|████▌     | 4508/10000 [01:57<02:16, 40.22it/s]


Iteration: 4500	CrossEntropyLoss 2.2791 (2.2791)	VATLoss 0.0000 (0.0000)	Prec@1 25.000 (25.000)


 46%|████▌     | 4606/10000 [02:00<02:27, 36.54it/s]


Iteration: 4600	CrossEntropyLoss 2.3039 (2.3039)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 47%|████▋     | 4706/10000 [02:02<02:21, 37.35it/s]


Iteration: 4700	CrossEntropyLoss 2.3384 (2.3384)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 48%|████▊     | 4806/10000 [02:05<02:13, 38.91it/s]


Iteration: 4800	CrossEntropyLoss 2.3152 (2.3152)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 49%|████▉     | 4907/10000 [02:07<02:21, 35.91it/s]


Iteration: 4900	CrossEntropyLoss 2.3009 (2.3009)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 50%|█████     | 5008/10000 [02:10<02:09, 38.52it/s]


Iteration: 5000	CrossEntropyLoss 2.2931 (2.2931)	VATLoss 0.0000 (0.0000)	Prec@1 18.750 (18.750)


 51%|█████     | 5109/10000 [02:13<02:02, 40.00it/s]


Iteration: 5100	CrossEntropyLoss 2.3010 (2.3010)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 52%|█████▏    | 5205/10000 [02:15<02:01, 39.61it/s]


Iteration: 5200	CrossEntropyLoss 2.2974 (2.2974)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 53%|█████▎    | 5307/10000 [02:18<02:03, 38.06it/s]


Iteration: 5300	CrossEntropyLoss 2.2992 (2.2992)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 54%|█████▍    | 5407/10000 [02:20<02:07, 36.07it/s]


Iteration: 5400	CrossEntropyLoss 2.3020 (2.3020)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 55%|█████▌    | 5506/10000 [02:23<01:56, 38.72it/s]


Iteration: 5500	CrossEntropyLoss 2.2992 (2.2992)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 56%|█████▌    | 5607/10000 [02:26<01:52, 39.10it/s]


Iteration: 5600	CrossEntropyLoss 2.3098 (2.3098)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 57%|█████▋    | 5709/10000 [02:28<01:49, 39.19it/s]


Iteration: 5700	CrossEntropyLoss 2.3086 (2.3086)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 58%|█████▊    | 5807/10000 [02:31<01:47, 39.13it/s]


Iteration: 5800	CrossEntropyLoss 2.3169 (2.3169)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 59%|█████▉    | 5906/10000 [02:33<01:39, 41.20it/s]


Iteration: 5900	CrossEntropyLoss 2.2712 (2.2712)	VATLoss 0.0000 (0.0000)	Prec@1 25.000 (25.000)


 60%|██████    | 6006/10000 [02:36<01:42, 38.92it/s]


Iteration: 6000	CrossEntropyLoss 2.2954 (2.2954)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 61%|██████    | 6105/10000 [02:38<01:35, 40.71it/s]


Iteration: 6100	CrossEntropyLoss 2.3041 (2.3041)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 62%|██████▏   | 6206/10000 [02:41<01:39, 38.31it/s]


Iteration: 6200	CrossEntropyLoss 2.3097 (2.3097)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 63%|██████▎   | 6305/10000 [02:43<01:33, 39.50it/s]


Iteration: 6300	CrossEntropyLoss 2.3079 (2.3079)	VATLoss 0.0000 (0.0000)	Prec@1 0.000 (0.000)


 64%|██████▍   | 6408/10000 [02:46<01:29, 40.21it/s]


Iteration: 6400	CrossEntropyLoss 2.3100 (2.3100)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 65%|██████▌   | 6505/10000 [02:49<01:25, 40.66it/s]


Iteration: 6500	CrossEntropyLoss 2.2967 (2.2967)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 66%|██████▌   | 6605/10000 [02:51<01:32, 36.51it/s]


Iteration: 6600	CrossEntropyLoss 2.2998 (2.2998)	VATLoss 0.0000 (0.0000)	Prec@1 0.000 (0.000)


 67%|██████▋   | 6705/10000 [02:54<01:20, 40.68it/s]


Iteration: 6700	CrossEntropyLoss 2.3076 (2.3076)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 68%|██████▊   | 6808/10000 [02:57<01:19, 40.06it/s]


Iteration: 6800	CrossEntropyLoss 2.2808 (2.2808)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 69%|██████▉   | 6907/10000 [02:59<01:24, 36.57it/s]


Iteration: 6900	CrossEntropyLoss 2.3087 (2.3087)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 70%|███████   | 7007/10000 [03:02<01:17, 38.72it/s]


Iteration: 7000	CrossEntropyLoss 2.3154 (2.3154)	VATLoss 0.0000 (0.0000)	Prec@1 0.000 (0.000)


 71%|███████   | 7105/10000 [03:04<01:19, 36.54it/s]


Iteration: 7100	CrossEntropyLoss 2.3023 (2.3023)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 72%|███████▏  | 7207/10000 [03:07<01:14, 37.61it/s]


Iteration: 7200	CrossEntropyLoss 2.2992 (2.2992)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 73%|███████▎  | 7307/10000 [03:10<01:09, 38.52it/s]


Iteration: 7300	CrossEntropyLoss 2.3015 (2.3015)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 74%|███████▍  | 7408/10000 [03:12<01:11, 36.50it/s]


Iteration: 7400	CrossEntropyLoss 2.3010 (2.3010)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 75%|███████▌  | 7508/10000 [03:15<01:05, 37.77it/s]


Iteration: 7500	CrossEntropyLoss 2.3041 (2.3041)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 76%|███████▌  | 7609/10000 [03:18<00:59, 40.11it/s]


Iteration: 7600	CrossEntropyLoss 2.3147 (2.3147)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 77%|███████▋  | 7707/10000 [03:20<00:58, 38.88it/s]


Iteration: 7700	CrossEntropyLoss 2.3088 (2.3088)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 78%|███████▊  | 7807/10000 [03:23<00:58, 37.26it/s]


Iteration: 7800	CrossEntropyLoss 2.2970 (2.2970)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 79%|███████▉  | 7905/10000 [03:25<00:53, 39.45it/s]


Iteration: 7900	CrossEntropyLoss 2.3210 (2.3210)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 80%|████████  | 8004/10000 [03:28<00:53, 37.54it/s]


Iteration: 8000	CrossEntropyLoss 2.3204 (2.3204)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 81%|████████  | 8106/10000 [03:31<00:50, 37.71it/s]


Iteration: 8100	CrossEntropyLoss 2.3139 (2.3139)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 82%|████████▏ | 8205/10000 [03:33<00:46, 38.37it/s]


Iteration: 8200	CrossEntropyLoss 2.3055 (2.3055)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 83%|████████▎ | 8307/10000 [03:36<00:44, 38.12it/s]


Iteration: 8300	CrossEntropyLoss 2.3096 (2.3096)	VATLoss 0.0000 (0.0000)	Prec@1 0.000 (0.000)


 84%|████████▍ | 8406/10000 [03:38<00:43, 36.94it/s]


Iteration: 8400	CrossEntropyLoss 2.2967 (2.2967)	VATLoss 0.0000 (0.0000)	Prec@1 18.750 (18.750)


 85%|████████▌ | 8506/10000 [03:41<00:38, 38.51it/s]


Iteration: 8500	CrossEntropyLoss 2.3216 (2.3216)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 86%|████████▌ | 8607/10000 [03:44<00:37, 36.67it/s]


Iteration: 8600	CrossEntropyLoss 2.3025 (2.3025)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 87%|████████▋ | 8709/10000 [03:46<00:32, 40.29it/s]


Iteration: 8700	CrossEntropyLoss 2.3188 (2.3188)	VATLoss 0.0000 (0.0000)	Prec@1 3.125 (3.125)


 88%|████████▊ | 8808/10000 [03:49<00:30, 39.32it/s]


Iteration: 8800	CrossEntropyLoss 2.3004 (2.3004)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 89%|████████▉ | 8906/10000 [03:51<00:29, 36.55it/s]


Iteration: 8900	CrossEntropyLoss 2.3107 (2.3107)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 90%|█████████ | 9005/10000 [03:54<00:26, 37.98it/s]


Iteration: 9000	CrossEntropyLoss 2.2901 (2.2901)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 91%|█████████ | 9105/10000 [03:57<00:23, 37.56it/s]


Iteration: 9100	CrossEntropyLoss 2.3068 (2.3068)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 92%|█████████▏| 9206/10000 [03:59<00:20, 38.17it/s]


Iteration: 9200	CrossEntropyLoss 2.3062 (2.3062)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 93%|█████████▎| 9308/10000 [04:02<00:17, 40.63it/s]


Iteration: 9300	CrossEntropyLoss 2.3307 (2.3307)	VATLoss 0.0000 (0.0000)	Prec@1 0.000 (0.000)


 94%|█████████▍| 9404/10000 [04:04<00:14, 40.09it/s]


Iteration: 9400	CrossEntropyLoss 2.2987 (2.2987)	VATLoss 0.0000 (0.0000)	Prec@1 12.500 (12.500)


 95%|█████████▌| 9506/10000 [04:07<00:13, 37.38it/s]


Iteration: 9500	CrossEntropyLoss 2.2851 (2.2851)	VATLoss 0.0000 (0.0000)	Prec@1 15.625 (15.625)


 96%|█████████▌| 9605/10000 [04:10<00:10, 38.42it/s]


Iteration: 9600	CrossEntropyLoss 2.3178 (2.3178)	VATLoss 0.0000 (0.0000)	Prec@1 6.250 (6.250)


 97%|█████████▋| 9708/10000 [04:12<00:07, 39.75it/s]


Iteration: 9700	CrossEntropyLoss 2.3060 (2.3060)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 98%|█████████▊| 9806/10000 [04:15<00:05, 37.99it/s]


Iteration: 9800	CrossEntropyLoss 2.3071 (2.3071)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


 99%|█████████▉| 9907/10000 [04:17<00:02, 37.45it/s]


Iteration: 9900	CrossEntropyLoss 2.3026 (2.3026)	VATLoss 0.0000 (0.0000)	Prec@1 9.375 (9.375)


100%|██████████| 10000/10000 [04:20<00:00, 38.42it/s]


In [104]:
def test(model, device, data_iterators):
    model.eval()
    correct = 0
    
    with torch.no_grad():
        for index , (x, y) in enumerate(data_iterators['test']):
            print(index , x.shape)
            with torch.no_grad():
                x, y = x.to(device), y.to(device)
                outputs = model(x)
            correct += torch.eq(outputs.max(dim=1)[1], y).detach().cpu().float().sum()
        test_acc = correct / len(data_iterators['test']._dataset) * 100

    print(f'\nTest Accuracy: {test_acc:.4f}%\n')

In [105]:
test(model, device, data_iterators)

0 torch.Size([1000, 3, 32, 32])
1 torch.Size([1000, 3, 32, 32])
2 torch.Size([1000, 3, 32, 32])
3 torch.Size([1000, 3, 32, 32])
4 torch.Size([1000, 3, 32, 32])
5 torch.Size([1000, 3, 32, 32])
6 torch.Size([1000, 3, 32, 32])
7 torch.Size([1000, 3, 32, 32])
8 torch.Size([1000, 3, 32, 32])
9 torch.Size([1000, 3, 32, 32])

Test Accuracy: 10.0000%



In [98]:
data_iterators['test']._next_index

<bound method _BaseDataLoaderIter._next_index of <torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x177434490>>