In [1]:
import torch
from torch import nn, optim
import sys
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set_style('white')
sys.path.insert(0, '../')
#from tta_agg_models import TTARegression, TTAPartialRegression, GPS
from utils.gpu_utils import restrict_GPU_pytorch
import numpy as np
from utils.aug_utils import invert_aug_list
# Font sizes used for every figure drawn in this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 14, 18, 20

# Push the sizes into matplotlib's rc settings, one call per distinct setting group.
plt.rc('font', size=SMALL_SIZE)                                # default text size
plt.rc('axes', titlesize=BIGGER_SIZE, labelsize=MEDIUM_SIZE)   # axes title / x and y labels
plt.rc(('xtick', 'ytick'), labelsize=SMALL_SIZE)               # tick labels on both axes
plt.rc('legend', fontsize=SMALL_SIZE)                          # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)                        # figure-level title

# Project helper from utils.gpu_utils; the recorded cell output shows it
# printing "Using GPU:0" for this argument.
restrict_GPU_pytorch('0')
class TTAPartialRegression(nn.Module):
    """Learn one aggregation weight per test-time augmentation, shared by all classes.

    The forward pass rescales the weights so they sum to one and uses them to
    form a weighted combination of the per-augmentation inputs.

    Args:
        n_augs: Number of augmentations, i.e. rows of the weight column.
        n_classes: Not used by this class (presumably kept so the signature
            matches ``TTARegression`` -- confirm before removing).
        temp_scale: Constant the inputs are divided by before aggregation.
        initialization: ``'even'`` gives every augmentation the weight
            ``1 / n_augs``; ``'original'`` puts all weight on augmentation 0
            and zero on every other one. Any other value leaves the weights
            at their random normal draw.
        coeffs: Optional explicit initial weights. When non-empty they take
            precedence over ``initialization``.
    """

    def __init__(self, n_augs, n_classes, temp_scale=1, initialization='even', coeffs=()):
        super().__init__()
        self.temperature = temp_scale

        if len(coeffs):
            # Caller-supplied weights win over any named initialization scheme.
            weights = torch.Tensor(coeffs)
        elif initialization == 'even':
            weights = torch.full((n_augs, 1), 1.0 / n_augs, dtype=torch.float)
        elif initialization == 'original':
            # Zero everything first: previously only rows 0 and 1 were set, so
            # rows 2.. kept random values and n_augs == 1 raised IndexError.
            weights = torch.zeros((n_augs, 1), dtype=torch.float)
            weights[0] = 1.0
        else:
            # Unrecognised scheme: keep the historical random-normal start.
            weights = torch.randn((n_augs, 1), dtype=torch.float)

        # nn.Parameter registers the weights with the module so optimizers and
        # state_dict() see them under the name "coeffs".
        self.coeffs = nn.Parameter(weights, requires_grad=True)

    def forward(self, x):
        """Aggregate ``x`` over its augmentation axis with the normalized weights.

        Args:
            x: 3-D tensor whose dim 1 has length ``n_augs`` (assumed layout
                ``(batch, n_augs, n_classes)`` -- TODO confirm with callers).

        Returns:
            Tensor of shape ``(x.shape[0], x.shape[2])`` for the usual
            ``(n_augs, 1)`` weight column.
        """
        x = x / self.temperature
        # NOTE: the weights are divided by their sum, so a zero sum yields inf/nan.
        weights = self.coeffs / torch.sum(self.coeffs, dim=0)
        mult = torch.matmul(x.transpose(1, 2), weights)
        # Drop only the trailing weight axis. A bare squeeze() would also remove
        # the batch axis for a batch of one (or the class axis when it has
        # length one), changing the output rank.
        return mult.squeeze(-1) if weights.dim() > 1 else mult

class TTARegression(nn.Module):
    """Learn a separate aggregation weight for every (augmentation, class) pair.

    Args:
        n_augs: Number of rows in the weight matrix.
        n_classes: Number of columns in the weight matrix.
        temp_scale: Constant the inputs are divided by in ``forward``.
        initialization: ``'even'`` fills the matrix with ``1 / n_augs``.
            Anything else is converted with ``torch.Tensor`` and repeated
            ``n_classes`` times along dim 1 (so a 2-D column is expected --
            TODO confirm with callers).
    """

    def __init__(self, n_augs, n_classes, temp_scale=1, initialization='even'):
        super().__init__()

        if initialization == 'even':
            # The random draw is overwritten immediately; it is kept so the
            # global RNG advances exactly as it always has.
            start = torch.randn((n_augs, n_classes), dtype=torch.float)
            start.fill_(1.0 / n_augs)
        else:
            column = torch.Tensor(initialization)
            start = torch.cat([column] * n_classes, dim=1)

        # Wrapping in nn.Parameter registers the weights with the module.
        self.coeffs = nn.Parameter(start, requires_grad=True)
        self.temperature = temp_scale

    def forward(self, x):
        """Weight ``x`` element-wise and sum over dim 1.

        ``x`` must broadcast against the ``(n_augs, n_classes)`` weights
        (assumed layout ``(batch, n_augs, n_classes)`` -- TODO confirm).
        """
        scaled = x / self.temperature
        weighted = self.coeffs * scaled
        return torch.sum(weighted, dim=1)
# Experiment selection: swap the commented line in to inspect the other setup.
dataset, n_classes, model_name = 'stl10', 10, 'stl10_cnn'
#dataset, n_classes, model_name = 'cifar100', 100, 'cifar100_cnn'
policy = 'five_crop_hflip_scale'
n_augs = 30
agg_models_dir = '../' + dataset + '/' + policy + '/agg_models'

aug_name = 'combo'

# NOTE(security): torch.load unpickles the file, so only load checkpoints from
# a trusted source.
# map_location='cpu' lets checkpoints saved from a CUDA device load on any
# machine; the weights are only read back as NumPy arrays on the CPU below.

# Per-augmentation weights (one column of n_augs values).
aug_model_path = agg_models_dir + '/'+model_name+'/'+aug_name + '/partial_lr.pth'
aug_model = TTAPartialRegression(n_augs,n_classes,1,'even')
aug_model.load_state_dict(torch.load(aug_model_path, map_location='cpu'))
aug_model.eval()
aug_coeffs = aug_model.coeffs.detach().cpu().numpy()

# Per-augmentation, per-class weights (n_augs x n_classes matrix).
class_model_path = agg_models_dir + '/'+model_name+'/'+aug_name + '/full_lr.pth'
class_model = TTARegression(n_augs,n_classes,1,'even')
class_model.load_state_dict(torch.load(class_model_path, map_location='cpu'))
class_model.eval()
class_coeffs = class_model.coeffs.detach().cpu().numpy()

Using GPU:0


In [3]:
# Display the loaded weights: aug_coeffs has one value per augmentation,
# class_coeffs has one value per (augmentation, class) pair.
aug_coeffs, class_coeffs

(array([[0.078957  ],
        [0.        ],
        [0.        ],
        [0.078957  ],
        [0.        ],
        [0.        ],
        [0.078957  ],
        [0.        ],
        [0.        ],
        [0.078957  ],
        [0.        ],
        [0.        ],
        [0.078957  ],
        [0.        ],
        [0.        ],
        [0.04132787],
        [0.00155485],
        [0.0460134 ],
        [0.04132787],
        [0.00155485],
        [0.0460134 ],
        [0.04132787],
        [0.00155485],
        [0.0460134 ],
        [0.04132787],
        [0.00155485],
        [0.0460134 ],
        [0.04132787],
        [0.00155485],
        [0.0460134 ]], dtype=float32),
 array([[5.87668642e-02, 3.77280079e-02, 5.17027862e-02, 2.57664509e-02,
         4.55461666e-02, 2.34682858e-02, 3.94460186e-02, 3.06818392e-02,
         6.33720076e-03, 3.29748504e-02],
        [5.43353474e-03, 3.78886750e-03, 6.74395543e-03, 9.94405895e-03,
         0.00000000e+00, 0.00000000e+00, 2.58740336e-02, 0.000