In [None]:
import matplotlib.pyplot as plt
import os
import sys

# Make the parent directory importable (presumably where `models` and
# `analysis` live) and restrict CUDA to physical GPU 3. The variable is set
# before torch is imported so CUDA only ever sees that single device.
sys.path.append('..')
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tqdm import tqdm

from torch.distributions import MultivariateNormal
from models.Res import ResNet, resnet50

from analysis import *

def hook(module, args, output):
    """Forward hook that caches a module's output on the module itself.

    The output is stored as ``module._output`` *with* its autograd history,
    so a loss computed from it can be back-propagated to the network input
    (e.g. a learnable image). Storing ``output.data`` instead would detach
    the tensor, and ``loss.backward()`` on such a loss fails.

    Args:
        module: the module the hook is registered on; receives ``_output``.
        args: positional inputs of the forward call (unused).
        output: the module's forward output.

    Returns:
        ``output`` unchanged, so the forward pass is not altered.
    """
    module._output = output
    return output

class ResNetWrapper(nn.Module):
    def __init__(self, model):
        super(ResNetWrapper, self).__init__()
        self.model = model
        self.model.requires_grad_(False)
        self.image = torch.autograd.Variable(torch.randn(1, 3, 224, 224).cuda(), requires_grad=True)
    
    def forward(self):
        return self.model(self.image)

In [None]:
# Pretrained ResNet-50 on the GPU, frozen and wrapped with a learnable input image.
model = ResNetWrapper(resnet50(pretrained=True).cuda())
model.eval()  # eval mode for the whole wrapper while inverting features

# After every forward pass, layer3's output is cached as `layer3._output`.
model.model.layer3.register_forward_hook(hook)

In [None]:
# Stored feature maps used as inversion targets. Entry 2 of 'features' is
# presumably the layer3 activations (they are reshaped to (-1, 1024, 14, 14)
# before use) -- TODO confirm against the script that wrote this file.
FEATURES_PATH = '/ssd1/tta/inc/inc_all_resnet50_bn_INC0-5_00.pth'
# map_location='cpu': tensors saved from a specific GPU cannot be deserialised
# when CUDA_VISIBLE_DEVICES exposes fewer devices, and it keeps the whole
# 'features' list off the GPU; only the selected entry is moved there later.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
L3 = torch.load(FEATURES_PATH, map_location='cpu')['features'][2]

In [None]:
# Lay the stored features out as (N, 1024, 14, 14) maps on the current GPU.
L3 = L3.reshape((-1, 1024, 14, 14)).to('cuda')
L3.shape

In [None]:
# Feature inversion: optimise the wrapper's input image so that the hooked
# layer3 output matches one stored feature map under MSE.
N_STEPS = 1000      # optimisation iterations
LR = 1.0            # Adam learning rate for the image
LR_DECAY = 0.999    # per-step multiplicative LR decay (ExponentialLR gamma)
TARGET_INDEX = 1    # which stored feature map in L3 to invert

model.requires_grad_(False)  # network weights stay frozen; model.image is a plain tensor, unaffected
model = model.cuda()
optimizer = torch.optim.Adam((model.image,), lr=LR)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, LR_DECAY)

ft = L3[TARGET_INDEX].unsqueeze(0)  # target, shape (1, 1024, 14, 14)
criterion = nn.MSELoss()
progress = tqdm(range(N_STEPS))
for _ in progress:
    optimizer.zero_grad()
    model()  # forward pass only to trigger the layer3 hook; logits are unused
    features = model.model.layer3._output
    if not features.requires_grad:
        raise RuntimeError(
            'layer3._output is detached from the autograd graph; the forward '
            'hook must store `output`, not `output.data`, for the image to '
            'receive gradients'
        )
    loss = criterion(features, ft)  # nn.MSELoss expects (input, target)
    progress.set_postfix_str(f'{loss.item():.4f}')
    loss.backward()
    optimizer.step()
    scheduler.step()


In [None]:
import torchvision.transforms as T

# Undo the ImageNet input normalisation x_norm = (x - mean) / std:
# the first Normalize multiplies by std, the second adds the mean back.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = T.Compose([
    T.Normalize(mean=[0., 0., 0.], std=[1 / s for s in IMAGENET_STD]),
    T.Normalize(mean=[-m for m in IMAGENET_MEAN], std=[1., 1., 1.]),
])

# plt.imshow needs an (H, W, 3) array with floats in [0, 1]. The optimised image
# is (3, H, W) and unconstrained, so clamp and move channels last before plotting.
image_rgb = transform(model.image[0].detach().cpu()).clamp(0, 1).permute(1, 2, 0)
plt.imshow(image_rgb.numpy())
plt.axis('off');

In [None]:
# Summarise the de-normalised image instead of dumping the whole tensor.
# Values outside [0, 1] are not valid display intensities.
denormalized = transform(model.image[0].detach().cpu())
denormalized.shape, denormalized.min().item(), denormalized.max().item()