In [1]:
%load_ext autoreload
%autoreload 2
import sys
import numpy as np
sys.path.insert(0, '../')

import torch
from PIL import Image

from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

from mnist_train import Net
from utils.gpu_utils import restrict_GPU_pytorch
from utils.imagenet_utils import accuracy
from utils.dataloading_utils import MyIter, MyLoader
from new_tta_models import ImageW, ImageS, ImageWS, Original, StandardTTA
import torchvision.transforms.functional as F
from cnn_finetune import make_model
from tqdm.notebook import tqdm

restrict_GPU_pytorch('1')
np.random.seed(42)

Using GPU:0
Using GPU:1


# Loading pre-trained model

In [2]:
def get_flowers_model(model_name):
    m_name = model_name
    if model_name == 'MobileNetV2':
        m_name = 'mobilenet_v2'
    if model_name == 'inceptionv3':
        m_name = 'inception_v3'
    model = make_model(
                    m_name,
                    pretrained=True,
                    num_classes=n_classes,
                    input_size=(224, 224),
                )
    model.load_state_dict(torch.load('../saved_models/flowers102/' + m_name+ '.pth'))
    return model
n_classes = 102
model = get_flowers_model('resnet18')
model.eval()

ResNetWrapper(
  (_features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_sta

In [3]:
class Crop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size, pos=4):
        assert isinstance(output_size, (int, tuple))
        self.pos = pos
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, img):
        return F.five_crop(img, self.output_size)[self.pos]

class Scale(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size, pct=1):
        assert isinstance(output_size, (int, tuple))
        self.pct = pct
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, img):
        h, w = img.size
        new_h, new_w = int(self.pct*h), int(self.pct*w)
        img = F.center_crop(img, (new_h, new_w))
        img = F.resize(img, self.output_size, Image.BILINEAR)
        return img
    
class HFlip(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, img):
        img = F.hflip(img)
        img = F.center_crop(img, self.output_size)
        return img

# Defining dataloaders 

In [15]:
# Figure out val/test index split for these 
def get_datapath(dataset_name):
    if dataset_name == 'flowers102':
        data_path = '../datasets/flowers102/test'
    return data_path

def get_augmentation_transform(augmentation):
    # These could be compositions too
    output_size = (224, 224)
    aug_transform_map = {'hflip': HFlip(output_size),
                         'crop_tl': Crop(output_size, 0),
                         'crop_tr': Crop(output_size, 1),
                         'crop_b1': Crop(output_size, 2),
                         'crop_br': Crop(output_size, 3),
                         'orig': Crop(output_size, 4),
                         'scale_1.04': Scale(output_size, 1.04),
                         'scale_1.10': Scale(output_size, 1.10)}
    return aug_transform_map[augmentation]   

def get_dataloader(dataset_name, augmentation, idxs, batch_size):
    datapath = get_datapath(dataset_name)
    if dataset_name == 'flowers102':
        image_size = 256
        crop_size = 224
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        d_transforms = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            get_augmentation_transform(augmentation),
            transforms.ToTensor(),
            normalize
        ])
    dataset = datasets.ImageFolder(root=datapath, transform=d_transforms)
    dataloader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset, idxs), 
                                             batch_size=batch_size,
                                             num_workers=0, shuffle=False)
        
    # look up augmentation append it to transforms
    return dataloader 

def get_dataloaders(dataset_name, augmentations, bs):
    # Write val idxs
    # Write test idxs
    all_idxs = list(range(6149))
    np.random.shuffle(all_idxs)
    val_idxs = all_idxs[:int(len(all_idxs)/2)]
    test_idxs = all_idxs[int(len(all_idxs)/2):]
    train_dl_list = [get_dataloader(dataset_name, aug, val_idxs, bs) for aug in augmentations]
    test_dl_list = [get_dataloader(dataset_name, aug, test_idxs, bs) for aug in augmentations]
    return MyLoader(train_dl_list), MyLoader(test_dl_list)
n_augs = 8
n_classes = 102
batch_size = 128
data_loader, test_data_loader = get_dataloaders('flowers102', ['orig', 'hflip', 'crop_tl', 'crop_b1', 
                                                               'crop_tr', 'crop_br', 
                                                               'scale_1.04', 'scale_1.10'], 32)

In [5]:
# import matplotlib.pyplot as plt
# fig, axes = plt.subplots(2, 4, sharex=True, sharey=True)
# for i in range(n_augs):
#     img = x[i][1].cpu()
#     ax = axes[int(i/4), i%4]
#     ax.imshow(np.transpose(img, (1, 2, 0)))

# Original model accuracy

In [6]:
orig_model = Original(model, 0)
orig_model.cuda('cuda:0')
orig_model.eval()
test_acc1s = []
test_acc5s = []
for examples, target in tqdm(test_data_loader):
    examples = examples.cuda('cuda:0', non_blocking=True)
    target = target.cuda('cuda:0', non_blocking=True)    
    output = orig_model(examples)
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    test_acc1s.append(acc1.item())
    test_acc5s.append(acc5.item())
print(np.mean(test_acc1s), np.mean(test_acc5s))

HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))


88.14432989690722 97.1971649484536


# Evaluating Standard TTA model

In [7]:
stta_model = StandardTTA(model)
stta_model.cuda('cuda:0')
test_acc1s = []
test_acc5s = []
for examples, target in tqdm(test_data_loader):
    examples = examples.cuda('cuda:0', non_blocking=True)
    target = target.cuda('cuda:0', non_blocking=True)    
    output = stta_model(examples)
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    test_acc1s.append(acc1.item())
    test_acc5s.append(acc5.item())
print(np.mean(test_acc1s), np.mean(test_acc5s))

HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))


88.2409793814433 97.06829896907216


In [5]:
from utils.aug_utils import invert_aug_list

aug_list = np.load('../' + 'flowers102' + '/' + 'five_crop_hflip_scale' + '/aug_list.npy')
aug_order = np.load('../' +'flowers102' + '/' + 'five_crop_hflip_scale' + '/aug_order.npy')
aug_names = invert_aug_list(aug_list, aug_order)

In [15]:
np.where(np.sum(aug_list,axis=1) == 0)[0]

array([12])

# Evaluating... all models together 

In [17]:
def eval_agg_model(agg_model, dataloader):
    agg_model.eval()
    agg_model.cuda('cuda:0')
    model.cuda('cuda:0')
    test_acc1s = []
    test_acc5s = []
    outputs = []
    targets = []
    for examples, target in dataloader:
        examples = examples.cuda('cuda:0', non_blocking=True)
        target = target.cuda('cuda:0', non_blocking=True)    
        output = agg_model(examples)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        test_acc1s.append(acc1.item())
        test_acc5s.append(acc5.item())
        outputs.append(output.cpu())
        targets.append(target.cpu())
    return np.mean(test_acc1s), np.mean(test_acc5s)


def train_agg_model(model, dataloader, n_augs, n_classes, agg_model_name, n_features, orig_idx):
    if agg_model_name == 'image_s':
        agg_model = ImageS(model, n_classes, n_features, orig_idx)
    elif agg_model_name == 'image_w':
        agg_model = ImageW(model, n_augs, n_classes, n_features, orig_idx)

    agg_model.cuda('cuda:0')
    criterion = torch.nn.CrossEntropyLoss()
    criterion.cuda('cuda:0')
    optimizer = torch.optim.SGD(agg_model.parameters(), lr=.01, momentum=.9, weight_decay=1e-4)

    losses = []
    acc1s = []
    acc5s = []
    epochs = 100
    for epoch in tqdm(range(epochs)):
        epoch_loss = []
        for examples, target in tqdm(dataloader):
            examples = examples.cuda('cuda:0', non_blocking=True)
            target = target.cuda('cuda:0', non_blocking=True)
            output = agg_model(examples)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss.append(loss.item())
            acc1s.append(acc1.item())
            acc5s.append(acc5.item())
        losses.append(np.mean(epoch_loss))
    return agg_model, losses, acc1s, acc5s

In [None]:
# Original model
orig_idx = 0
results = []
# orig_model = Original(model, 0)
# orig_acc1, orig_acc5 = eval_agg_model(orig_model, test_data_loader)
# results.append({'agg_model': 'orig',  'acc1': orig_acc1, 'acc5': orig_acc5})
# # stta model
# stta_model = StandardTTA(model)
# stta_acc1, stta_acc5 = eval_agg_model(stta_model, test_data_loader)
# results.append({'agg_model': 'stta', 'acc1': stta_acc1, 'acc5': stta_acc5})

n_features = 25088
# image to s 
s_model, losses, acc1s, acc5s = train_agg_model(model, data_loader, n_augs, n_classes, 
                                                'image_s', n_features, orig_idx)
acc1, acc5, s_outputs, s_targets = eval_agg_model(s_model, test_data_loader)
results.append({'agg_model': 'image_s', 'acc1': acc1, 'acc5': acc5})

# image to w 
w_model, losses, acc1s, acc5s = train_agg_model(model, data_loader, n_augs, n_classes, 
                                                'image_w', n_features, orig_idx)
acc1, acc5, w_outputs, w_targets = eval_agg_model(w_model, test_data_loader)
results.append({'agg_model': 'image_w',  'acc1': acc1, 'acc5': acc5})

# image to s and w 
# ws_model, losses, acc1s, acc5s = train_agg_model(model, data_loader, n_augs, 'image_ws', n_features)
# acc1, acc5, ws_outputs, ws_targets = eval_agg_model(ws_model, test_data_loader)
# results.append({'agg_model': 'image_ws', 'acc1': acc1, 'acc5': acc5})

results_df = pd.DataFrame(results)
results_df.to_csv('flowers102_preliminary_results')

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))