In [None]:
import os
import time
import argparse

import torchvision
import torch
import torch.nn as nn

from util import AverageMeter
from encoder import SmallAlexNet
from align_uniform import align_loss, uniform_loss_prelog
from tqdm import tqdm
from collections import defaultdict
import copy

import matplotlib.pyplot as plt


class TwoAugUnsupervisedDatasetLbl(torch.utils.data.Dataset):
    r"""Wrap a labelled dataset so each item yields two transformed views plus a label.

    Args:
        dataset: indexable dataset whose items unpack as ``(image, label)``.
        transform: callable applied to the image twice, once per returned view.
        lblmap: optional mapping from the dataset's label to the label that is
            returned. A deep copy is stored. ``None`` returns the label as-is.
    """

    def __init__(self, dataset, transform, lblmap=None):
        self.dataset, self.transform = dataset, transform
        # Keep a private copy so later edits to the caller's mapping do not leak in.
        self.lblmap = copy.deepcopy(lblmap)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, source_lbl = self.dataset[index]
        # Resolve the label first, then draw the two views of the same image.
        if self.lblmap is None:
            lbl_out = source_lbl
        else:
            lbl_out = self.lblmap[source_lbl]
        view_a = self.transform(image)
        view_b = self.transform(image)
        return view_a, view_b, lbl_out

def parse_option():
    """Build the configuration for the alignment/uniformity training run.

    No command line is read: the parser is given an empty argument list, so
    every option takes the default declared below.

    Returns:
        argparse.Namespace holding the defaults plus derived fields:
        ``lr`` (filled by linear scaling with batch size when left ``None``),
        ``gpus`` (converted from indices to ``torch.device`` objects) and
        ``save_folder`` (result directory name built from the hyperparameters).

    Side effects:
        Creates ``opt.save_folder`` on disk if it does not already exist.
    """
    parser = argparse.ArgumentParser('STL-10 Representation Learning with Alignment and Uniformity Losses')

    parser.add_argument('--align_w', type=float, default=1, help='Alignment loss weight')
    parser.add_argument('--unif_w', type=float, default=1, help='Uniformity loss weight')
    parser.add_argument('--align_alpha', type=float, default=2, help='alpha in alignment loss')
    parser.add_argument('--unif_t', type=float, default=2, help='t in uniformity loss')

    parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
    parser.add_argument('--epochs', type=int, default=400, help='Number of training epochs')
    # The previous help text was a copy of the --epochs help; in this function
    # the value is only used as a suffix of the save-folder name.
    parser.add_argument('--iter', type=int, default=0, help='Run index appended to the save folder name')
    parser.add_argument('--lr', type=float, default=None,
                        help='Learning rate. Default is linear scaling 0.12 per 256 batch size')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='Learning rate decay rate')
    parser.add_argument('--lr_decay_epochs', default=[155, 170, 185], nargs='*', type=int,
                        help='When to decay learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='L2 weight decay')
    parser.add_argument('--feat_dim', type=int, default=128, help='Feature dimensionality')

    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loader workers to use')
    parser.add_argument('--log_interval', type=int, default=40, help='Number of iterations between logs')
    parser.add_argument('--gpus', default=[0], nargs='*', type=int,
                        help='List of GPU indices to use, e.g., --gpus 0 1 2 3')

    parser.add_argument('--data_folder', type=str, default='./data', help='Path to data')
    parser.add_argument('--result_folder', type=str, default='./results', help='Base directory to save model')

    # Explicit empty argv: the defaults are used and the kernel's own sys.argv
    # is never consulted. (An empty string only worked because it iterates to
    # zero arguments.)
    opt = parser.parse_args([])

    if opt.lr is None:
        # Linear scaling rule: 0.12 per 256 samples in the batch.
        opt.lr = 0.12 * (opt.batch_size / 256)

    opt.gpus = [torch.device('cuda', index) for index in opt.gpus]

    opt.save_folder = os.path.join(
        opt.result_folder,
        f"cifar100_{opt.epochs}_4_sideinformation_align{opt.align_w:g}alpha{opt.align_alpha:g}_unif{opt.unif_w:g}t{opt.unif_t:g}_iter{opt.iter}"
    )
    os.makedirs(opt.save_folder, exist_ok=True)

    return opt


# Build the training configuration once; the training cell below reads this
# module-level `opt` (a later evaluation cell rebinds the same name).
opt = parse_option()

In [None]:
# Random augmentation pipeline producing normalised 32x32 tensors.
# NOTE(review): get_data_loader() builds its own identical pipeline, so this
# module-level `transform` appears to be unused by the visible cells.
_augmentation_steps = [
    torchvision.transforms.RandomResizedCrop(32, scale=(0.08, 1)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    torchvision.transforms.RandomGrayscale(p=0.2),
    torchvision.transforms.ToTensor(),
    # Per-channel mean and std handed to Normalize.
    torchvision.transforms.Normalize(
        (0.44087801806139126, 0.42790631331699347, 0.3867879370752931),
        (0.26826768628079806, 0.2610450402318512, 0.26866836876860795),
    ),
]
transform = torchvision.transforms.Compose(_augmentation_steps)

# Label ids 0..9 of the wrapped dataset, and the subset that keeps its own class.
old_lbls = list(range(10))
labels_2_keep = [0,1]
# labels_2_keep = [0,1,2,3]

# Kept labels are renumbered 0..k-1 in `old_lbls` order; every other label is
# collapsed into the single extra class k (= `count`).
kept_lbls = [lbl for lbl in old_lbls if lbl in labels_2_keep]
old2new = {lbl: new_id for new_id, lbl in enumerate(kept_lbls)}
count = len(kept_lbls)
old2new.update((lbl, count) for lbl in old_lbls if lbl not in labels_2_keep)

# All new label ids, including the catch-all class `count`.
new_lbls = list(range(count+1))

In [11]:
def get_data_loader(opt):
    """Create the shuffled training loader of two-view samples with remapped labels.

    Args:
        opt: options object; reads ``data_folder``, ``batch_size`` and
            ``num_workers``.

    Returns:
        torch.utils.data.DataLoader over the STL10 'train' split wrapped in
        ``TwoAugUnsupervisedDatasetLbl``; each batch is
        ``(view_1, view_2, remapped_label)``. Labels are remapped through the
        module-level ``old2new`` dictionary.
    """
    tfm = torchvision.transforms
    augment = tfm.Compose([
        tfm.RandomResizedCrop(32, scale=(0.08, 1)),
        tfm.RandomHorizontalFlip(),
        tfm.ColorJitter(0.4, 0.4, 0.4, 0.4),
        tfm.RandomGrayscale(p=0.2),
        tfm.ToTensor(),
        tfm.Normalize(
            (0.44087801806139126, 0.42790631331699347, 0.3867879370752931),
            (0.26826768628079806, 0.2610450402318512, 0.26866836876860795),
        ),
    ])
    # The base dataset is created without a transform; the wrapper applies
    # `augment` twice per image to produce the two views.
    base_train = torchvision.datasets.STL10(opt.data_folder, 'train', download=True)
    two_view_train = TwoAugUnsupervisedDatasetLbl(base_train, transform=augment, lblmap=old2new)

    return torch.utils.data.DataLoader(
        two_view_train,
        batch_size=opt.batch_size,
        num_workers=opt.num_workers,
        shuffle=True,
        pin_memory=True,
    )


# Training cell: optimise the encoder with the alignment loss plus a uniformity
# loss computed separately inside each remapped-label group, then save the
# encoder weights once after the final epoch.
print(f'Optimize: {opt.align_w:g} * loss_align(alpha={opt.align_alpha:g}) + {opt.unif_w:g} * loss_uniform(t={opt.unif_t:g})')

torch.cuda.set_device(opt.gpus[0])
# NOTE(review): deterministic and benchmark are both switched on here; confirm
# that this combination gives the reproducibility that is intended.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

encoder = SmallAlexNet(feat_dim=opt.feat_dim, cifar=True).to(opt.gpus[0])

# NOTE(review): Adam with a fixed lr=1e-2 is used; opt.lr, opt.momentum and
# opt.weight_decay parsed in parse_option() are not read here.
optim = torch.optim.Adam(encoder.parameters(), lr=1e-2)
# Multiply the learning rate by opt.lr_decay_rate at each epoch in opt.lr_decay_epochs.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, gamma=opt.lr_decay_rate,
                                                 milestones=opt.lr_decay_epochs)

loader = get_data_loader(opt)
align_meter = AverageMeter('align_loss')
unif_meter = AverageMeter('uniform_loss')
loss_meter = AverageMeter('total_loss')
it_time_meter = AverageMeter('iter_time')

for epoch in range(opt.epochs):
    # Meters are per-epoch: cleared at the start of every epoch.
    align_meter.reset()
    unif_meter.reset()
    loss_meter.reset()
    it_time_meter.reset()
    t0 = time.time()
    for ii, (im_x, im_y, lbl) in enumerate(loader):
        optim.zero_grad()
        # Encode both views in a single forward pass, then split back into the two halves.
        x, y = encoder(torch.cat([im_x.to(opt.gpus[0]), im_y.to(opt.gpus[0])])).chunk(2)
        
        align_loss_val = align_loss(x, y, alpha=opt.align_alpha)
        # Group the features of both views by remapped label (new_lbls) and
        # evaluate the pre-log uniformity term inside each group.

        z = torch.cat( [x, y])
        # Labels are duplicated so they line up with the concatenated views;
        # `lbl` is used as yielded by the loader (it is not moved to the GPU).
        lbl_z = torch.cat([lbl, lbl])
        # NOTE(review): opt.unif_t is not passed to uniform_loss_prelog, so the
        # helper's own default applies -- confirm it matches the value printed above.
        unif_losses = torch.cat([uniform_loss_prelog(z[lbl_z==new_lbl]) for new_lbl in new_lbls])
        # One log over the mean of all per-group values concatenated together.
        unif_loss_val = torch.log( torch.mean(unif_losses) )
        
        loss = align_loss_val * opt.align_w + unif_loss_val * opt.unif_w
        align_meter.update(align_loss_val, x.shape[0])
        # Unlike the other meters, no sample count is passed here.
        unif_meter.update(unif_loss_val)
        loss_meter.update(loss, x.shape[0])
        loss.backward()
        optim.step()
        # Iteration time spans from the previous t0 (set after logging) to here,
        # so it includes the wait for the next batch from the loader.
        it_time_meter.update(time.time() - t0)
        if ii % opt.log_interval == 0:
            print(f"Epoch {epoch}/{opt.epochs}\tIt {ii}/{len(loader)}\t" +
                  f"{align_meter}\t{unif_meter}\t{loss_meter}\t{it_time_meter}")
        t0 = time.time()
    # Learning-rate schedule advances once per epoch.
    scheduler.step()

# Only the final encoder weights are written; no intermediate checkpoints are kept.
ckpt_file = os.path.join(opt.save_folder, 'encoder.pth')
torch.save(encoder.state_dict(), ckpt_file)
print(f'Saved to {ckpt_file}')

Optimize: 1 * loss_align(alpha=2) + 1 * loss_uniform(t=2)
Files already downloaded and verified
Epoch 0/400	It 0/20	align_loss 1.298860 (1.298860)	uniform_loss -2.720523 (-2.720523)	total_loss -1.421663 (-1.421663)	iter_time 2.359052 (2.359052)
Epoch 1/400	It 0/20	align_loss 0.323790 (0.323790)	uniform_loss -0.629418 (-0.629418)	total_loss -0.305628 (-0.305628)	iter_time 0.556997 (0.556997)
Epoch 2/400	It 0/20	align_loss 0.487992 (0.487992)	uniform_loss -1.008581 (-1.008581)	total_loss -0.520589 (-0.520589)	iter_time 0.594481 (0.594481)
Epoch 3/400	It 0/20	align_loss 0.610227 (0.610227)	uniform_loss -1.310382 (-1.310382)	total_loss -0.700156 (-0.700156)	iter_time 0.594522 (0.594522)
Epoch 4/400	It 0/20	align_loss 0.673948 (0.673948)	uniform_loss -1.651645 (-1.651645)	total_loss -0.977697 (-0.977697)	iter_time 0.570518 (0.570518)
Epoch 5/400	It 0/20	align_loss 0.702533 (0.702533)	uniform_loss -1.835233 (-1.835233)	total_loss -1.132699 (-1.132699)	iter_time 0.576284 (0.576284)
Epoch 6/40

In [25]:
"""
    Linear evaluation. When USE_MOD_LBL is True, the remapped labels are appended to the
    encoder features as a one-hot vector, and the linear classifier is trained to predict
    the original labels.
"""
import time
import argparse

import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F

from util import AverageMeter
from encoder import SmallAlexNet
# Checkpoint to evaluate. This name used to be assigned six times in a row, with
# only the last assignment taking effect; the other candidates are kept here as
# comments so the active choice is unambiguous:
#   ./results/base_200_sideinformation_align1alpha2_unif1t2_iter0/encoder.pth
#   ./results/base_200_4_sideinformation_align1alpha2_unif1t2_iter0/encoder.pth
#   ./results/manual_labels_align1alpha2_unif1t2_iter0/encoder.pth
#   ./results/cifar100_400_4_sideinformation_align1alpha2_unif1t2_iter0/encoder.pth
MODEL_2_LOAD = "./results/base_cifar100_400_resize_transform_align1alpha2_unif1t2_iter0/encoder.pth"

# True: the linear classifier also receives the one-hot remapped label as input.
# False: the classifier sees encoder features only.
USE_MOD_LBL = True

# `opt` here is whichever options object is currently bound at module level
# (it must provide feat_dim and gpus).
encoder = SmallAlexNet(feat_dim=opt.feat_dim,cifar=True).to(opt.gpus[0])
# Deserialize onto the CPU first so the checkpoint loads regardless of the device
# it was saved from; load_state_dict then copies the values into the encoder's
# parameters on opt.gpus[0].
encoder.load_state_dict(torch.load(MODEL_2_LOAD, map_location='cpu'))

<All keys matched successfully>

In [26]:
def parse_option():
    """Build the configuration for the linear-evaluation run.

    NOTE: this definition has the same name as the training-cell
    ``parse_option`` and replaces it once this cell has been executed.

    No command line is read: the parser is given an empty argument list, so
    every option takes the default declared below. The default encoder
    checkpoint is the module-level ``MODEL_2_LOAD``.

    Returns:
        argparse.Namespace with the defaults plus derived fields: ``gpu`` (a
        ``torch.device`` for the first entry of ``gpus``) and
        ``lr_decay_epochs`` converted from a comma-separated string to a list
        of ints. ``gpus`` itself stays a list of integer indices.
    """
    parser = argparse.ArgumentParser('STL-10 Representation Learning with Alignment and Uniformity Losses')

    parser.add_argument('--encoder_checkpoint', type=str, help='Encoder checkpoint to evaluate', default=MODEL_2_LOAD)
    parser.add_argument('--feat_dim', type=int, default=128, help='Encoder feature dimensionality')
    parser.add_argument('--layer_index', type=int, default=-2, help='Evaluation layer')

    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2, help='Learning rate decay rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='When to decay learning rate')

    parser.add_argument('--num_workers', type=int, default=6, help='Number of data loader workers to use')
    parser.add_argument('--log_interval', type=int, default=40, help='Number of iterations between logs')
    parser.add_argument('--gpus', default=[0], nargs='*', type=int,
                        help='List of GPU indices to use, e.g., --gpus 0 1 2 3')

    parser.add_argument('--data_folder', type=str, default='./data', help='Path to data')

    # Explicit empty argv: the defaults are used and the kernel's own sys.argv
    # is never consulted. (An empty string only worked because it iterates to
    # zero arguments.)
    opt = parser.parse_args([])

    # The former "if opt.lr is None" fallback was unreachable: --lr defaults to
    # 1e-3 and no arguments are parsed, so lr is always set here.

    opt.gpu = torch.device('cuda', opt.gpus[0])
    opt.lr_decay_epochs = [int(epoch) for epoch in opt.lr_decay_epochs.split(',')]

    return opt


class DatasetModifiedLbl(torch.utils.data.Dataset):
    r"""Wrap a labelled dataset and return ``(image, label)`` with the label optionally remapped.

    Args:
        dataset: indexable dataset whose items unpack as ``(image, label)``.
        lblmap: optional mapping from the dataset's label to the label that is
            returned. A deep copy is stored. ``None`` returns the label as-is.
    """

    def __init__(self, dataset, lblmap=None):
        # Keep a private copy so later edits to the caller's mapping do not leak in.
        self.lblmap = copy.deepcopy(lblmap)
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, source_lbl = self.dataset[index]
        if self.lblmap is None:
            return image, source_lbl
        return image, self.lblmap[source_lbl]

In [27]:
class DatasetModifiedLblandLbl(torch.utils.data.Dataset):
    r"""Wrap a labelled dataset and return ``(image, remapped_label, original_label)``.

    Args:
        dataset: indexable dataset whose items unpack as ``(image, label)``.
        lblmap: mapping from the dataset's label to the remapped label; a deep
            copy is stored and it is indexed for every item.
    """

    def __init__(self, dataset, lblmap):
        # Keep a private copy so later edits to the caller's mapping do not leak in.
        self.lblmap = copy.deepcopy(lblmap)
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, source_lbl = self.dataset[index]
        remapped_lbl = self.lblmap[source_lbl]
        return image, remapped_lbl, source_lbl

In [28]:
def get_data_loaders(opt):
    """Create the train and validation loaders used for linear evaluation.

    Args:
        opt: options object; reads ``data_folder``, ``batch_size`` and
            ``num_workers``.

    Returns:
        ``(train_loader, val_loader)``. Both iterate over STL10 wrapped in
        ``DatasetModifiedLblandLbl``, so each batch is
        ``(images, remapped_labels, original_labels)``. Labels are remapped
        through the module-level ``old2new``. Only the train loader shuffles.
    """
    tfm = torchvision.transforms
    channel_mean = (0.44087801806139126, 0.42790631331699347, 0.3867879370752931)
    channel_std = (0.26826768628079806, 0.2610450402318512, 0.26866836876860795)

    # The horizontal flip is applied last, i.e. to the already normalised tensor.
    train_transform = tfm.Compose([
        tfm.RandomResizedCrop(32, scale=(0.08, 1)),
        tfm.ToTensor(),
        tfm.Normalize(channel_mean, channel_std),
        tfm.RandomHorizontalFlip()
    ])
    # NOTE(review): Resize(70) followed by CenterCrop(32) keeps only the central
    # 32 pixels of a 70-pixel image, while training crops are resized to 32 --
    # confirm this scale mismatch is intended.
    val_transform = tfm.Compose([
        tfm.Resize(70),
        tfm.CenterCrop(32),
        tfm.ToTensor(),
        tfm.Normalize(channel_mean, channel_std),
    ])

    stl_train = torchvision.datasets.STL10(opt.data_folder, 'train', download=True, transform=train_transform)
    train_dataset = DatasetModifiedLblandLbl(stl_train, lblmap=old2new)
    stl_test = torchvision.datasets.STL10(opt.data_folder, 'test', transform=val_transform)
    val_dataset = DatasetModifiedLblandLbl(stl_test, lblmap=old2new)

    loader_kwargs = dict(batch_size=opt.batch_size, num_workers=opt.num_workers, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)
    return train_loader, val_loader


def validate_comb(opt, encoder, classifier, val_loader):
    """Accuracy of a classifier fed encoder features plus the one-hot remapped label.

    Args:
        opt: options object; reads ``gpus[0]`` (target device) and ``layer_index``.
        encoder: callable invoked as ``encoder(images, layer_index=...)``.
        classifier: callable mapping the concatenated input to class scores.
        val_loader: iterable of ``(images, remapped_labels, original_labels)``
            batches that also exposes ``.dataset`` for the sample count.

    Returns:
        Fraction of samples whose arg-max prediction equals the original label.
        The one-hot width is ``len(labels_2_keep) + 1`` (module-level list).
    """
    device = opt.gpus[0]
    hits = 0
    with torch.no_grad():
        for images, labels_mod, labels_act in val_loader:
            feats = encoder(images.to(device), layer_index=opt.layer_index).flatten(1)
            side_info = torch.nn.functional.one_hot(
                labels_mod.to(device), num_classes=len(labels_2_keep)+1)
            scores = classifier(torch.cat((feats, side_info), dim=1))
            pred = scores.argmax(dim=1)
            # Compare on the CPU, where the loader's label tensor is used as-is.
            hits += (pred.cpu() == labels_act).sum().item()
    return hits / len(val_loader.dataset)

def validate(opt, encoder, classifier, val_loader):
    """Accuracy of a classifier fed encoder features only.

    Args:
        opt: options object; reads ``gpus[0]`` (target device) and ``layer_index``.
        encoder: callable invoked as ``encoder(images, layer_index=...)``.
        classifier: callable mapping flattened features to class scores.
        val_loader: iterable of ``(images, remapped_labels, original_labels)``
            batches that also exposes ``.dataset`` for the sample count. The
            remapped labels are ignored here.

    Returns:
        Fraction of samples whose arg-max prediction equals the original label.
    """
    device = opt.gpus[0]
    hits = 0
    with torch.no_grad():
        for batch in val_loader:
            images, _unused_mod, labels_act = batch
            feats = encoder(images.to(device), layer_index=opt.layer_index).flatten(1)
            pred = classifier(feats).argmax(dim=1)
            # Compare on the CPU, where the loader's label tensor is used as-is.
            hits += (pred.cpu() == labels_act).sum().item()
    return hits / len(val_loader.dataset)

In [29]:
# Mirror the first GPU entry onto `opt.gpu`.
# NOTE(review): the next cell rebinds `opt` and repeats this assignment, so this
# cell looks redundant -- confirm before removing.
opt.gpu=opt.gpus[0]

In [30]:
# Linear-evaluation cell: train a linear classifier on frozen encoder features
# (optionally concatenated with the one-hot remapped label) and report accuracy
# after every epoch.

# Rebinds the module-level `opt` to the evaluation options returned by the
# parse_option() defined in the evaluation cells.
opt = parse_option()

# Overwrites the torch.device that parse_option() stored in opt.gpu with the
# raw entry of opt.gpus (an integer index in the evaluation options).
opt.gpu=opt.gpus[0]
torch.cuda.set_device(opt.gpus[0])
# NOTE(review): deterministic and benchmark are both switched on here; confirm
# that this combination gives the reproducibility that is intended.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

# The encoder is only used for feature extraction (eval mode, no gradients below).
encoder.eval()
train_loader, val_loader = get_data_loaders(opt)

# Probe one transformed sample from the underlying dataset to find the size of
# the feature vector produced at opt.layer_index.
with torch.no_grad():
    sample, _ = train_loader.dataset.dataset[0]
    eval_numel = encoder(sample.unsqueeze(0).to(opt.gpus[0]), layer_index=opt.layer_index).numel()
print(f'Feature dimension: {eval_numel}')


# 10 output classes. With USE_MOD_LBL the input is widened by len(labels_2_keep) + 1
# to hold the one-hot remapped label.
classifier = nn.Linear(eval_numel, 10).to(opt.gpus[0]) if not USE_MOD_LBL else nn.Linear( eval_numel + len(labels_2_keep) + 1,10).to(opt.gpus[0])

# Only the classifier's parameters are optimised.
optim = torch.optim.Adam(classifier.parameters(), lr=opt.lr, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, gamma=opt.lr_decay_rate,
                                                 milestones=opt.lr_decay_epochs)

val_accs = []
loss_meter = AverageMeter('loss')
it_time_meter = AverageMeter('iter_time')
for epoch in tqdm(range(opt.epochs)):
    # Meters are per-epoch: cleared at the start of every epoch.
    loss_meter.reset()
    it_time_meter.reset()
    t0 = time.time()
    for ii, (images, labels_mod, labels) in enumerate(train_loader):
        optim.zero_grad()
        # Features are computed without gradients, so the encoder stays frozen.
        with torch.no_grad():
            feats = encoder(images.to(opt.gpus[0]), layer_index=opt.layer_index).flatten(1)

        if USE_MOD_LBL:
            # Append the one-hot remapped label to the features as side information.
            logits = classifier(torch.cat( (feats, torch.nn.functional.one_hot(labels_mod.to(opt.gpus[0]), num_classes=len(labels_2_keep)+1 )),dim=1))
        else:
            logits = classifier(feats)
        
        # The training target is the original (un-remapped) label.
        loss = F.cross_entropy(logits, labels.to(opt.gpus[0]))
        loss_meter.update(loss, images.shape[0])
        loss.backward()
        optim.step()
        # Iteration time spans from the previous t0 (set after logging) to here,
        # so it includes the wait for the next batch from the loader.
        it_time_meter.update(time.time() - t0)
        if ii % opt.log_interval == 0:
            print(f"Epoch {epoch}/{opt.epochs}\tIt {ii}/{len(train_loader)}\t{loss_meter}\t{it_time_meter}")
        t0 = time.time()
    # Learning-rate schedule advances once per epoch.
    scheduler.step()
    # Use the evaluation routine that matches the classifier's input layout.
    val_acc = validate_comb(opt,encoder,classifier,val_loader) if USE_MOD_LBL else validate(opt, encoder, classifier, val_loader) 
    val_accs.append(val_acc)
    print(f"Epoch {epoch}/{opt.epochs}\tval_acc {val_acc*100:.4g}%")
# NOTE(review): this is the maximum over per-epoch accuracies measured on the
# loader built from the STL10 'test' split, i.e. the epoch is selected on the
# same data that is reported -- treat it as an optimistic figure.
print(f"Best validation accuracy {max(val_accs)}")

Files already downloaded and verified
Feature dimension: 4096


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0/100	It 0/40	loss 2.317669 (2.317669)	iter_time 0.643054 (0.643054)


  1%|          | 1/100 [00:01<02:55,  1.77s/it]

Epoch 0/100	val_acc 51.59%
Epoch 1/100	It 0/40	loss 1.347377 (1.347377)	iter_time 0.140231 (0.140231)


  2%|▏         | 2/100 [00:03<02:24,  1.47s/it]

Epoch 1/100	val_acc 54.2%
Epoch 2/100	It 0/40	loss 1.179149 (1.179149)	iter_time 0.158372 (0.158372)


  3%|▎         | 3/100 [00:04<02:14,  1.39s/it]

Epoch 2/100	val_acc 55.16%
Epoch 3/100	It 0/40	loss 1.074468 (1.074468)	iter_time 0.188010 (0.188010)


  4%|▍         | 4/100 [00:05<02:09,  1.35s/it]

Epoch 3/100	val_acc 57.35%
Epoch 4/100	It 0/40	loss 1.077220 (1.077220)	iter_time 0.138690 (0.138690)


  5%|▌         | 5/100 [00:06<02:05,  1.32s/it]

Epoch 4/100	val_acc 57.66%
Epoch 5/100	It 0/40	loss 1.094635 (1.094635)	iter_time 0.247242 (0.247242)


  6%|▌         | 6/100 [00:08<02:06,  1.34s/it]

Epoch 5/100	val_acc 59.33%
Epoch 6/100	It 0/40	loss 0.907721 (0.907721)	iter_time 0.168980 (0.168980)


  7%|▋         | 7/100 [00:09<02:04,  1.34s/it]

Epoch 6/100	val_acc 59.41%
Epoch 7/100	It 0/40	loss 0.979732 (0.979732)	iter_time 0.129158 (0.129158)


  8%|▊         | 8/100 [00:10<02:00,  1.31s/it]

Epoch 7/100	val_acc 59.4%
Epoch 8/100	It 0/40	loss 0.831983 (0.831983)	iter_time 0.167577 (0.167577)


  9%|▉         | 9/100 [00:12<01:58,  1.30s/it]

Epoch 8/100	val_acc 60.45%
Epoch 9/100	It 0/40	loss 0.927104 (0.927104)	iter_time 0.186994 (0.186994)


 10%|█         | 10/100 [00:13<01:56,  1.29s/it]

Epoch 9/100	val_acc 60.35%
Epoch 10/100	It 0/40	loss 0.841688 (0.841688)	iter_time 0.120916 (0.120916)


 11%|█         | 11/100 [00:14<01:54,  1.29s/it]

Epoch 10/100	val_acc 61.19%
Epoch 11/100	It 0/40	loss 0.828138 (0.828138)	iter_time 0.176086 (0.176086)


 12%|█▏        | 12/100 [00:16<01:55,  1.31s/it]

Epoch 11/100	val_acc 61.27%
Epoch 12/100	It 0/40	loss 0.907684 (0.907684)	iter_time 0.170757 (0.170757)


 13%|█▎        | 13/100 [00:17<01:54,  1.32s/it]

Epoch 12/100	val_acc 62%
Epoch 13/100	It 0/40	loss 0.797824 (0.797824)	iter_time 0.175897 (0.175897)


 14%|█▍        | 14/100 [00:18<01:55,  1.34s/it]

Epoch 13/100	val_acc 62.18%
Epoch 14/100	It 0/40	loss 0.786929 (0.786929)	iter_time 0.153615 (0.153615)


 15%|█▌        | 15/100 [00:20<01:51,  1.32s/it]

Epoch 14/100	val_acc 62.29%
Epoch 15/100	It 0/40	loss 0.755933 (0.755933)	iter_time 0.194819 (0.194819)


 16%|█▌        | 16/100 [00:21<01:50,  1.32s/it]

Epoch 15/100	val_acc 62.65%
Epoch 16/100	It 0/40	loss 0.743983 (0.743983)	iter_time 0.126901 (0.126901)


 17%|█▋        | 17/100 [00:22<01:48,  1.30s/it]

Epoch 16/100	val_acc 62.4%
Epoch 17/100	It 0/40	loss 0.856573 (0.856573)	iter_time 0.152627 (0.152627)


 18%|█▊        | 18/100 [00:23<01:46,  1.29s/it]

Epoch 17/100	val_acc 63.39%
Epoch 18/100	It 0/40	loss 0.996817 (0.996817)	iter_time 0.149961 (0.149961)


 19%|█▉        | 19/100 [00:25<01:44,  1.29s/it]

Epoch 18/100	val_acc 63.4%
Epoch 19/100	It 0/40	loss 0.748335 (0.748335)	iter_time 0.167104 (0.167104)


 20%|██        | 20/100 [00:26<01:42,  1.28s/it]

Epoch 19/100	val_acc 63.76%
Epoch 20/100	It 0/40	loss 0.665994 (0.665994)	iter_time 0.129084 (0.129084)


 21%|██        | 21/100 [00:27<01:40,  1.27s/it]

Epoch 20/100	val_acc 63.71%
Epoch 21/100	It 0/40	loss 0.765240 (0.765240)	iter_time 0.165035 (0.165035)


 22%|██▏       | 22/100 [00:28<01:39,  1.28s/it]

Epoch 21/100	val_acc 63.85%
Epoch 22/100	It 0/40	loss 0.785096 (0.785096)	iter_time 0.151135 (0.151135)


 23%|██▎       | 23/100 [00:30<01:38,  1.28s/it]

Epoch 22/100	val_acc 63.76%
Epoch 23/100	It 0/40	loss 0.697388 (0.697388)	iter_time 0.141931 (0.141931)


 24%|██▍       | 24/100 [00:31<01:37,  1.28s/it]

Epoch 23/100	val_acc 64.84%
Epoch 24/100	It 0/40	loss 0.618923 (0.618923)	iter_time 0.162297 (0.162297)


 25%|██▌       | 25/100 [00:32<01:35,  1.28s/it]

Epoch 24/100	val_acc 64.34%
Epoch 25/100	It 0/40	loss 0.878883 (0.878883)	iter_time 0.129265 (0.129265)


 26%|██▌       | 26/100 [00:34<01:33,  1.27s/it]

Epoch 25/100	val_acc 64.15%
Epoch 26/100	It 0/40	loss 0.698757 (0.698757)	iter_time 0.156944 (0.156944)


 27%|██▋       | 27/100 [00:35<01:32,  1.26s/it]

Epoch 26/100	val_acc 64.69%
Epoch 27/100	It 0/40	loss 0.676458 (0.676458)	iter_time 0.142217 (0.142217)


 28%|██▊       | 28/100 [00:36<01:31,  1.27s/it]

Epoch 27/100	val_acc 64.26%
Epoch 28/100	It 0/40	loss 0.614755 (0.614755)	iter_time 0.132313 (0.132313)


 29%|██▉       | 29/100 [00:37<01:29,  1.26s/it]

Epoch 28/100	val_acc 64.76%
Epoch 29/100	It 0/40	loss 0.643382 (0.643382)	iter_time 0.161837 (0.161837)


 30%|███       | 30/100 [00:39<01:29,  1.28s/it]

Epoch 29/100	val_acc 65.66%
Epoch 30/100	It 0/40	loss 0.673461 (0.673461)	iter_time 0.129699 (0.129699)


 31%|███       | 31/100 [00:40<01:29,  1.29s/it]

Epoch 30/100	val_acc 65.33%
Epoch 31/100	It 0/40	loss 0.540164 (0.540164)	iter_time 0.156794 (0.156794)


 32%|███▏      | 32/100 [00:41<01:28,  1.30s/it]

Epoch 31/100	val_acc 65.54%
Epoch 32/100	It 0/40	loss 0.606209 (0.606209)	iter_time 0.155382 (0.155382)


 33%|███▎      | 33/100 [00:43<01:28,  1.32s/it]

Epoch 32/100	val_acc 65.4%
Epoch 33/100	It 0/40	loss 0.667767 (0.667767)	iter_time 0.158487 (0.158487)


 34%|███▍      | 34/100 [00:44<01:25,  1.30s/it]

Epoch 33/100	val_acc 65.54%
Epoch 34/100	It 0/40	loss 0.588110 (0.588110)	iter_time 0.164415 (0.164415)


 35%|███▌      | 35/100 [00:45<01:24,  1.30s/it]

Epoch 34/100	val_acc 65.49%
Epoch 35/100	It 0/40	loss 0.628757 (0.628757)	iter_time 0.190547 (0.190547)


 36%|███▌      | 36/100 [00:46<01:22,  1.30s/it]

Epoch 35/100	val_acc 66.3%
Epoch 36/100	It 0/40	loss 0.691093 (0.691093)	iter_time 0.131629 (0.131629)


 37%|███▋      | 37/100 [00:48<01:20,  1.27s/it]

Epoch 36/100	val_acc 66%
Epoch 37/100	It 0/40	loss 0.723042 (0.723042)	iter_time 0.130217 (0.130217)


 38%|███▊      | 38/100 [00:49<01:18,  1.27s/it]

Epoch 37/100	val_acc 65.77%
Epoch 38/100	It 0/40	loss 0.723599 (0.723599)	iter_time 0.127830 (0.127830)


 39%|███▉      | 39/100 [00:50<01:16,  1.26s/it]

Epoch 38/100	val_acc 65.64%
Epoch 39/100	It 0/40	loss 0.694763 (0.694763)	iter_time 0.137837 (0.137837)


 40%|████      | 40/100 [00:51<01:14,  1.25s/it]

Epoch 39/100	val_acc 65.56%
Epoch 40/100	It 0/40	loss 0.586016 (0.586016)	iter_time 0.226323 (0.226323)


 41%|████      | 41/100 [00:53<01:14,  1.27s/it]

Epoch 40/100	val_acc 65.49%
Epoch 41/100	It 0/40	loss 0.681377 (0.681377)	iter_time 0.133758 (0.133758)


 42%|████▏     | 42/100 [00:54<01:13,  1.27s/it]

Epoch 41/100	val_acc 65.96%
Epoch 42/100	It 0/40	loss 0.587108 (0.587108)	iter_time 0.181894 (0.181894)


 43%|████▎     | 43/100 [00:55<01:12,  1.27s/it]

Epoch 42/100	val_acc 66.24%
Epoch 43/100	It 0/40	loss 0.677060 (0.677060)	iter_time 0.173625 (0.173625)


 44%|████▍     | 44/100 [00:57<01:11,  1.28s/it]

Epoch 43/100	val_acc 66.46%
Epoch 44/100	It 0/40	loss 0.560449 (0.560449)	iter_time 0.173075 (0.173075)


 45%|████▌     | 45/100 [00:58<01:11,  1.30s/it]

Epoch 44/100	val_acc 65.77%
Epoch 45/100	It 0/40	loss 0.548829 (0.548829)	iter_time 0.144733 (0.144733)


 46%|████▌     | 46/100 [00:59<01:09,  1.28s/it]

Epoch 45/100	val_acc 66.24%
Epoch 46/100	It 0/40	loss 0.594334 (0.594334)	iter_time 0.178492 (0.178492)


 47%|████▋     | 47/100 [01:00<01:07,  1.28s/it]

Epoch 46/100	val_acc 66.09%
Epoch 47/100	It 0/40	loss 0.408554 (0.408554)	iter_time 0.130188 (0.130188)


 48%|████▊     | 48/100 [01:02<01:05,  1.26s/it]

Epoch 47/100	val_acc 65.88%
Epoch 48/100	It 0/40	loss 0.546549 (0.546549)	iter_time 0.186718 (0.186718)


 49%|████▉     | 49/100 [01:03<01:04,  1.27s/it]

Epoch 48/100	val_acc 66.99%
Epoch 49/100	It 0/40	loss 0.559753 (0.559753)	iter_time 0.196218 (0.196218)


 50%|█████     | 50/100 [01:04<01:03,  1.26s/it]

Epoch 49/100	val_acc 66.38%
Epoch 50/100	It 0/40	loss 0.603024 (0.603024)	iter_time 0.135448 (0.135448)


 51%|█████     | 51/100 [01:05<01:02,  1.27s/it]

Epoch 50/100	val_acc 67.11%
Epoch 51/100	It 0/40	loss 0.630170 (0.630170)	iter_time 0.151525 (0.151525)


 52%|█████▏    | 52/100 [01:07<01:01,  1.27s/it]

Epoch 51/100	val_acc 66.22%
Epoch 52/100	It 0/40	loss 0.589007 (0.589007)	iter_time 0.136746 (0.136746)


 53%|█████▎    | 53/100 [01:08<00:59,  1.27s/it]

Epoch 52/100	val_acc 66.38%
Epoch 53/100	It 0/40	loss 0.489609 (0.489609)	iter_time 0.130348 (0.130348)


 54%|█████▍    | 54/100 [01:09<00:58,  1.26s/it]

Epoch 53/100	val_acc 66.83%
Epoch 54/100	It 0/40	loss 0.639273 (0.639273)	iter_time 0.140299 (0.140299)


 55%|█████▌    | 55/100 [01:11<00:57,  1.27s/it]

Epoch 54/100	val_acc 66.97%
Epoch 55/100	It 0/40	loss 0.630125 (0.630125)	iter_time 0.136673 (0.136673)


 56%|█████▌    | 56/100 [01:12<00:55,  1.26s/it]

Epoch 55/100	val_acc 66.51%
Epoch 56/100	It 0/40	loss 0.640424 (0.640424)	iter_time 0.184545 (0.184545)


 57%|█████▋    | 57/100 [01:13<00:54,  1.27s/it]

Epoch 56/100	val_acc 66.41%
Epoch 57/100	It 0/40	loss 0.451903 (0.451903)	iter_time 0.137591 (0.137591)


 58%|█████▊    | 58/100 [01:14<00:53,  1.27s/it]

Epoch 57/100	val_acc 66.61%
Epoch 58/100	It 0/40	loss 0.540828 (0.540828)	iter_time 0.127867 (0.127867)


 59%|█████▉    | 59/100 [01:16<00:52,  1.28s/it]

Epoch 58/100	val_acc 66.71%
Epoch 59/100	It 0/40	loss 0.462482 (0.462482)	iter_time 0.137296 (0.137296)


 60%|██████    | 60/100 [01:17<00:51,  1.28s/it]

Epoch 59/100	val_acc 66.69%
Epoch 60/100	It 0/40	loss 0.599504 (0.599504)	iter_time 0.162778 (0.162778)


 61%|██████    | 61/100 [01:18<00:50,  1.28s/it]

Epoch 60/100	val_acc 66.71%
Epoch 61/100	It 0/40	loss 0.509995 (0.509995)	iter_time 0.130297 (0.130297)


 62%|██████▏   | 62/100 [01:20<00:48,  1.28s/it]

Epoch 61/100	val_acc 66.89%
Epoch 62/100	It 0/40	loss 0.561208 (0.561208)	iter_time 0.215965 (0.215965)


 63%|██████▎   | 63/100 [01:21<00:47,  1.29s/it]

Epoch 62/100	val_acc 66.95%
Epoch 63/100	It 0/40	loss 0.606518 (0.606518)	iter_time 0.130065 (0.130065)


 64%|██████▍   | 64/100 [01:22<00:46,  1.28s/it]

Epoch 63/100	val_acc 67.01%
Epoch 64/100	It 0/40	loss 0.536094 (0.536094)	iter_time 0.154241 (0.154241)


 65%|██████▌   | 65/100 [01:23<00:44,  1.27s/it]

Epoch 64/100	val_acc 66.96%
Epoch 65/100	It 0/40	loss 0.543198 (0.543198)	iter_time 0.146537 (0.146537)


 66%|██████▌   | 66/100 [01:25<00:43,  1.27s/it]

Epoch 65/100	val_acc 66.94%
Epoch 66/100	It 0/40	loss 0.574198 (0.574198)	iter_time 0.145601 (0.145601)


 67%|██████▋   | 67/100 [01:26<00:41,  1.26s/it]

Epoch 66/100	val_acc 67.04%
Epoch 67/100	It 0/40	loss 0.471654 (0.471654)	iter_time 0.152508 (0.152508)


 68%|██████▊   | 68/100 [01:27<00:40,  1.26s/it]

Epoch 67/100	val_acc 67.15%
Epoch 68/100	It 0/40	loss 0.546541 (0.546541)	iter_time 0.132300 (0.132300)


 69%|██████▉   | 69/100 [01:28<00:39,  1.27s/it]

Epoch 68/100	val_acc 67.17%
Epoch 69/100	It 0/40	loss 0.585553 (0.585553)	iter_time 0.123107 (0.123107)


 70%|███████   | 70/100 [01:30<00:37,  1.26s/it]

Epoch 69/100	val_acc 67.25%
Epoch 70/100	It 0/40	loss 0.528602 (0.528602)	iter_time 0.160890 (0.160890)


 71%|███████   | 71/100 [01:31<00:36,  1.26s/it]

Epoch 70/100	val_acc 67.2%
Epoch 71/100	It 0/40	loss 0.599262 (0.599262)	iter_time 0.207910 (0.207910)


 72%|███████▏  | 72/100 [01:32<00:35,  1.27s/it]

Epoch 71/100	val_acc 67.27%
Epoch 72/100	It 0/40	loss 0.615647 (0.615647)	iter_time 0.166551 (0.166551)


 73%|███████▎  | 73/100 [01:33<00:33,  1.26s/it]

Epoch 72/100	val_acc 67.19%
Epoch 73/100	It 0/40	loss 0.526935 (0.526935)	iter_time 0.126176 (0.126176)


 74%|███████▍  | 74/100 [01:35<00:32,  1.24s/it]

Epoch 73/100	val_acc 67.34%
Epoch 74/100	It 0/40	loss 0.517050 (0.517050)	iter_time 0.196997 (0.196997)


 75%|███████▌  | 75/100 [01:36<00:31,  1.26s/it]

Epoch 74/100	val_acc 67.15%
Epoch 75/100	It 0/40	loss 0.535559 (0.535559)	iter_time 0.130949 (0.130949)


 76%|███████▌  | 76/100 [01:37<00:30,  1.25s/it]

Epoch 75/100	val_acc 67.12%
Epoch 76/100	It 0/40	loss 0.503226 (0.503226)	iter_time 0.156032 (0.156032)


 77%|███████▋  | 77/100 [01:38<00:28,  1.25s/it]

Epoch 76/100	val_acc 67.25%
Epoch 77/100	It 0/40	loss 0.533539 (0.533539)	iter_time 0.148157 (0.148157)


 78%|███████▊  | 78/100 [01:40<00:27,  1.25s/it]

Epoch 77/100	val_acc 67.34%
Epoch 78/100	It 0/40	loss 0.559428 (0.559428)	iter_time 0.154558 (0.154558)


 79%|███████▉  | 79/100 [01:41<00:26,  1.25s/it]

Epoch 78/100	val_acc 67.35%
Epoch 79/100	It 0/40	loss 0.616487 (0.616487)	iter_time 0.143486 (0.143486)


 80%|████████  | 80/100 [01:42<00:25,  1.25s/it]

Epoch 79/100	val_acc 67.42%
Epoch 80/100	It 0/40	loss 0.488645 (0.488645)	iter_time 0.127620 (0.127620)


 81%|████████  | 81/100 [01:43<00:23,  1.26s/it]

Epoch 80/100	val_acc 67.38%
Epoch 81/100	It 0/40	loss 0.515724 (0.515724)	iter_time 0.170518 (0.170518)


 82%|████████▏ | 82/100 [01:45<00:22,  1.27s/it]

Epoch 81/100	val_acc 67.46%
Epoch 82/100	It 0/40	loss 0.625685 (0.625685)	iter_time 0.156350 (0.156350)


 83%|████████▎ | 83/100 [01:46<00:21,  1.28s/it]

Epoch 82/100	val_acc 67.47%
Epoch 83/100	It 0/40	loss 0.552087 (0.552087)	iter_time 0.171879 (0.171879)


 84%|████████▍ | 84/100 [01:47<00:20,  1.30s/it]

Epoch 83/100	val_acc 67.41%
Epoch 84/100	It 0/40	loss 0.641250 (0.641250)	iter_time 0.172864 (0.172864)


 85%|████████▌ | 85/100 [01:49<00:19,  1.29s/it]

Epoch 84/100	val_acc 67.41%
Epoch 85/100	It 0/40	loss 0.424724 (0.424724)	iter_time 0.169298 (0.169298)


 86%|████████▌ | 86/100 [01:50<00:18,  1.31s/it]

Epoch 85/100	val_acc 67.44%
Epoch 86/100	It 0/40	loss 0.629648 (0.629648)	iter_time 0.210585 (0.210585)


 87%|████████▋ | 87/100 [01:51<00:17,  1.34s/it]

Epoch 86/100	val_acc 67.42%
Epoch 87/100	It 0/40	loss 0.625604 (0.625604)	iter_time 0.170725 (0.170725)


 88%|████████▊ | 88/100 [01:53<00:15,  1.31s/it]

Epoch 87/100	val_acc 67.41%
Epoch 88/100	It 0/40	loss 0.661369 (0.661369)	iter_time 0.131286 (0.131286)


 89%|████████▉ | 89/100 [01:54<00:14,  1.30s/it]

Epoch 88/100	val_acc 67.49%
Epoch 89/100	It 0/40	loss 0.514093 (0.514093)	iter_time 0.180659 (0.180659)


 90%|█████████ | 90/100 [01:55<00:12,  1.30s/it]

Epoch 89/100	val_acc 67.47%
Epoch 90/100	It 0/40	loss 0.544140 (0.544140)	iter_time 0.201684 (0.201684)


 91%|█████████ | 91/100 [01:56<00:11,  1.29s/it]

Epoch 90/100	val_acc 67.46%
Epoch 91/100	It 0/40	loss 0.461486 (0.461486)	iter_time 0.162254 (0.162254)


 92%|█████████▏| 92/100 [01:58<00:10,  1.29s/it]

Epoch 91/100	val_acc 67.56%
Epoch 92/100	It 0/40	loss 0.594819 (0.594819)	iter_time 0.130205 (0.130205)


 93%|█████████▎| 93/100 [01:59<00:08,  1.27s/it]

Epoch 92/100	val_acc 67.54%
Epoch 93/100	It 0/40	loss 0.631848 (0.631848)	iter_time 0.133519 (0.133519)


 94%|█████████▍| 94/100 [02:00<00:07,  1.28s/it]

Epoch 93/100	val_acc 67.53%
Epoch 94/100	It 0/40	loss 0.467117 (0.467117)	iter_time 0.159647 (0.159647)


 95%|█████████▌| 95/100 [02:02<00:06,  1.28s/it]

Epoch 94/100	val_acc 67.54%
Epoch 95/100	It 0/40	loss 0.410187 (0.410187)	iter_time 0.160288 (0.160288)


 96%|█████████▌| 96/100 [02:03<00:05,  1.27s/it]

Epoch 95/100	val_acc 67.44%
Epoch 96/100	It 0/40	loss 0.446972 (0.446972)	iter_time 0.134611 (0.134611)


 97%|█████████▋| 97/100 [02:04<00:03,  1.27s/it]

Epoch 96/100	val_acc 67.46%
Epoch 97/100	It 0/40	loss 0.482620 (0.482620)	iter_time 0.161258 (0.161258)


 98%|█████████▊| 98/100 [02:05<00:02,  1.27s/it]

Epoch 97/100	val_acc 67.42%
Epoch 98/100	It 0/40	loss 0.446620 (0.446620)	iter_time 0.173109 (0.173109)


 99%|█████████▉| 99/100 [02:07<00:01,  1.27s/it]

Epoch 98/100	val_acc 67.49%
Epoch 99/100	It 0/40	loss 0.483180 (0.483180)	iter_time 0.170703 (0.170703)


100%|██████████| 100/100 [02:08<00:00,  1.28s/it]

Epoch 99/100	val_acc 67.46%
Best validation accuracy 0.675625



