In [1]:
import pathlib
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from torch import nn
import tqdm
import copy

In [2]:
import timm

## Setup

In [7]:
# Create a pytorch dataset
data_dir = pathlib.Path('/data/tiny-imagenet/')
image_count = len(list(data_dir.glob('**/*.JPEG')))
CLASS_NAMES = np.array([item.name for item in (data_dir / 'train').glob('*')])
print('Discovered {} images'.format(image_count))

# Create the training data generator
batch_size = 32
im_height = 64
im_width = 64

# data_transforms = transforms.Compose([
#     transforms.ToTensor(),
#     # transforms.Normalize((0, 0, 0), tuple(np.sqrt((255, 255, 255)))),
# ])

data_transforms = transforms.Compose([
    transforms.Resize((224,224)),
#     transforms.RandomCrop(64, padding=8),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

transform_test = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
train_set = torchvision.datasets.ImageFolder(data_dir / 'train', data_transforms)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, num_workers=4, pin_memory=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Discovered 120000 images


## Domain Adaptation

In [8]:
# ## Test time batchnorm update with no prior

def update_bn_params(model, val_loader, num_bn_updates, gpu):
    val_loader = torch.utils.data.DataLoader(val_loader.dataset,
                                             batch_size=val_loader.batch_size,
                                             shuffle=True, num_workers=val_loader.num_workers)
    def use_test_statistics(module):
        if isinstance(module, nn.BatchNorm2d):
            module.train()
    model = copy.deepcopy(model)
    model.eval()
    model.apply(use_test_statistics)
    print("Updating BN params (num updates:{})".format(num_bn_updates))
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if i<num_bn_updates:
                images = images.cuda(gpu, non_blocking=True)
                output = model(images)
    print("Done.")
    return model


## Test Particular Model
Run to load particular model onto computer

In [9]:
model = timm.create_model('inception_resnet_v2', pretrained=True)
num_ftrs = model.classif.in_features
model.fc =  nn.Sequential(
                  nn.Dropout(0.4),
                  nn.Linear(num_ftrs, 1024), 
                  nn.ReLU(),
                  nn.Linear(1024, 256),
                  nn.ReLU(),
                  nn.Linear(256, 200))
checkpoint = torch.load('results/inceptionresnetv2_biggerFCN_linear_epoch0.pt')
model.load_state_dict(checkpoint['net'])
model = model.to(device)
# model = torchvision.models.resnet50(pretrained=True)
# num_ftrs = model.fc.in_features

# model.fc =  nn.Sequential(
#                   nn.Linear(num_ftrs, 256), 
#                   nn.ReLU(),
#                   nn.Linear(256, 256),
#                   nn.ReLU(),
#                   nn.Linear(256, 200))
# checkpoint = torch.load('resnet50-pretrainedaugmix-20epochs-nofreeze-aug.pt')
# model.load_state_dict(checkpoint['net'])
# model = model.to(device)


Run below to evaluate on validation set

In [10]:
validation_set = torchvision.datasets.ImageFolder(data_dir / 'val', transform_test)
val_loader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size,
                                           shuffle=True, num_workers=4, pin_memory=True)


# model = update_bn_params(model, val_loader, 32, device)
model.eval()
all_preds = []
all_labels = []
all_losses = []
with torch.no_grad():
    index = 0
    for batch in tqdm.tqdm(val_loader):
        inputs = batch[0]
        targets = batch[1]
        targets = targets.cuda()
        inputs = inputs.cuda()
        preds = model(inputs)
        loss = nn.CrossEntropyLoss()(preds, targets)
        all_losses.append(loss.cpu())
        all_preds.append(preds.cpu())
        all_labels.append(targets.cpu())

100%|██████████| 313/313 [00:51<00:00,  6.12it/s]


## Top 1 Accuracy

In [11]:
top_preds = [x.argsort(dim=-1)[:,-1:].squeeze() for x in all_preds]
correct = 0
for idx, batch_preds in enumerate(top_preds):
    correct += torch.eq(all_labels[idx], batch_preds).sum()
accuracy = correct.item() / (32 * len(all_labels))
print(f"Top 1 Validation Accuracy: {accuracy}")

Top 1 Validation Accuracy: 0.6638378594249201


## Top 3 Accuracy

In [12]:
top_preds = [x.argsort(dim=-1)[:,-3:] for x in all_preds]
correct = 0
for idx, batch_preds in enumerate(top_preds):
    correct += torch.eq(all_labels[idx], batch_preds[:,0:1].squeeze()).sum()
    correct += torch.eq(all_labels[idx], batch_preds[:,1:2].squeeze()).sum()

    correct += torch.eq(all_labels[idx], batch_preds[:,2:3].squeeze()).sum()

accuracy = correct.item() / (32 * len(all_labels))
print(f"Top 3 Validation Accuracy: {accuracy}")

Top 3 Validation Accuracy: 0.8087060702875399


In [None]:
## with BN update: top 1: 39.3% top 3: 57.96%
## without BN update