In [None]:
import matplotlib.pyplot as plt
import os
import sys

sys.path.append('..')
os.environ.update(dict(CUDA_VISIBLE_DEVICES='3'))

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tqdm import tqdm

from torch.distributions import MultivariateNormal
from models.Res import ResNet, resnet50

from analysis import *

In [None]:
class ResNetWrapper(nn.Module):
    """Run a ResNet starting from an intermediate stage instead of from images.

    ``i_block`` selects the first stage ``forward`` applies:

    * ``<= 0``: raw images; the conv1/bn1/relu/maxpool stem and every later stage.
    * ``1`` .. ``4``: ``x`` is the output of the preceding stage; ``layer{i_block}``
      and all later stages are applied.
    * ``>= 5``: ``x`` is already past ``layer4``; only global pooling (when ``x`` is
      still a 4-D feature map) and ``fc`` are applied.

    Args:
        net: ResNet whose submodules (conv1, bn1, relu, maxpool, layer1-4,
            avgpool, fc) are reused as-is; it is stored, not copied.
        i_block: index of the first stage to run (see above).
    """

    def __init__(self, net: "ResNet", i_block=0):
        super().__init__()
        self.net = net
        self.i_block = i_block

    def forward(self, x, return_feature=False, return_feature_only=False):
        """Run the network from stage ``self.i_block`` through ``fc``.

        Args:
            x: input for the first stage that is run (images when ``i_block <= 0``).
            return_feature: also return the tensor that is fed into ``fc``.
            return_feature_only: when ``return_feature`` is True, return only that
                tensor. Ignored when ``return_feature`` is False.

        Returns:
            ``logits``; ``(logits, feature)`` if ``return_feature``; or just
            ``feature`` if both flags are set.
        """
        if self.i_block <= 0:
            x = self.net.conv1(x)
            x = self.net.bn1(x)
            x = self.net.relu(x)
            x = self.net.maxpool(x)
        if self.i_block <= 1:
            x = self.net.layer1(x)
        if self.i_block <= 2:
            x = self.net.layer2(x)
        if self.i_block <= 3:
            x = self.net.layer3(x)
        if self.i_block <= 4:
            x = self.net.layer4(x)

        # Always pool after layer4 (unchanged), and also pool a 4-D post-layer4
        # map supplied with i_block >= 5; previously such input reached fc
        # un-pooled and failed. Already-flat (N, C) inputs pass through.
        if self.i_block <= 4 or x.dim() == 4:
            x = self.net.avgpool(x)
            x = x.reshape(x.size(0), -1)

        feature = x
        logits = self.net.fc(feature)

        if not return_feature:
            return logits
        if return_feature_only:
            return feature
        return logits, feature

In [None]:
# Pre-computed ResNet-50 feature statistics (file name suggests ImageNet-val).
# NOTE(review): presumably one bundle per stage (layer1-4) plus FT — confirm.
# NOTE(review): the original note listed per-bundle keys "mean, ncov, ninv, V, L",
# but L4['cov'] is read when building the Gaussian — confirm which keys exist.
# torch.load unpickles the file; only load trusted checkpoints.
DISTRIBUTIONS_PATH = '/ssd1/tta/imagenet_val_resnet50_distributions.pth'
L1, L2, L3, L4, FT = torch.load(DISTRIBUTIONS_PATH)

In [None]:
# Shape of the stored mean for the first statistics bundle.
L1['mean'].size()

In [None]:
# Gaussian over the L4 statistics (presumably the layer4 / pre-fc features).
# NOTE(review): relies on L4 having a 'cov' key — confirm against the saved file.
d = MultivariateNormal(loc=L4['mean'], covariance_matrix=L4['cov'])

In [None]:
# Wrap a pretrained ResNet-50 so that only the classifier head runs:
# i_block=5 skips the stem and layer1-4, so inputs are post-layer4 features.
model = ResNetWrapper(resnet50(pretrained=True), i_block=5)
model.eval()

In [None]:
# Sanity check: classify a few vectors drawn from the fitted Gaussian.
N_SAMPLES = 3
torch.manual_seed(0)  # make the sampled vectors reproducible across runs
with torch.no_grad():
    # Distribution.sample_n is deprecated; sample(sample_shape) is the supported API.
    x = model(d.sample((N_SAMPLES,)).float().view(N_SAMPLES, -1))
x.softmax(1).argmax(1)

In [None]:
# Saved feature dump (NOTE(review): contents inferred only from usage — confirm).
ddd = torch.load('/ssd1/tta/inc/inc_all_resnet50_bn_INC0-5_00.pth')
# First entry of each of the first four 'features' groups.
f0, f1, f2, f3 = (ddd['features'][i][0] for i in range(4))

In [None]:
# List the top-level keys of the loaded feature dump.
ddd.keys()

In [None]:
# Shape of the fourth feature group's first entry.
f3.shape

In [None]:
# Classify f3 directly and after averaging over its last two dims.
# NOTE(review): model was built with i_block=5 (fc head only); model(f3) works only
# if the wrapper pools 4-D feature maps before fc — confirm f3's shape.
with torch.no_grad():
    w1 = model(f3)
    # keepdim=True yields (N, C, 1, 1) for any batch/channel count, instead of
    # the hard-coded view(256, 2048, 1, 1).
    w2 = model(f3.mean((-1, -2), keepdim=True))

In [None]:
# Predicted classes for the first 10 rows of w1.
torch.softmax(w1, dim=1).argmax(dim=1)[:10]

In [None]:
# Predicted classes for the first 10 rows of w2.
torch.softmax(w2, dim=1).argmax(dim=1)[:10]