In [13]:
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import resnet34 as resnet
from torchvision import transforms
import zarr
from tqdm.notebook import tqdm
import pickle

In [2]:
# Preprocessing building blocks. The feature loop below applies only
# `normalize`; `resize` is defined but not part of `composed_transform`.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # ImageNet per-channel means
    std=[0.229, 0.224, 0.225],   # ImageNet per-channel stds
)
resize = transforms.Resize(1024)
composed_transform = transforms.Compose([to_tensor, normalize])

In [3]:
# ImageNet-pretrained ResNet-34 (imported as `resnet`). eval() switches
# BatchNorm layers to their running statistics for inference; as the last
# expression it also displays the model architecture.
resnet_model: nn.Module = resnet(pretrained=True)
resnet_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
class Identity(nn.Module):
    """Pass-through module used to stub out layers of a torchvision model.

    ``forward`` returns its argument unchanged (the very same object).
    """

    def forward(self, x):
        return x
    
# Replace layer4, the global pooling layer and the classifier with
# pass-through modules; features are read from layer1-3 via forward hooks.
resnet_model.layer4, resnet_model.avgpool, resnet_model.fc = (
    Identity(),
    Identity(),
    Identity(),
)

In [5]:
# Handles to the residual stages whose outputs the forward hooks pool.
# Public attribute access (instead of the private `_modules` dict with .get)
# raises immediately on a typo rather than silently yielding None.
layer1 = resnet_model.layer1  # 64 channels
layer2 = resnet_model.layer2  # 128 channels
layer3 = resnet_model.layer3  # 256 channels

In [6]:
# Forward hooks that store the mean over dims (2, 3) -- the spatial dims of
# the conv feature maps -- of layer1/2/3 outputs in module-level globals on
# every forward pass of `resnet_model`.
layer1_gap = None
layer2_gap = None
layer3_gap = None  # previously a duplicated `layer2_gap = None`

# Re-running this cell would otherwise stack a second copy of each hook.
for _name in ("h1", "h2", "h3"):
    if _name in globals():
        globals()[_name].remove()


def gap1(m, i, o):
    """Forward hook for layer1: store mean of its output over dims (2, 3)."""
    global layer1_gap
    layer1_gap = torch.mean(o.detach(), (2, 3))
h1 = layer1.register_forward_hook(gap1)


def gap2(m, i, o):
    """Forward hook for layer2: store mean of its output over dims (2, 3)."""
    global layer2_gap
    layer2_gap = torch.mean(o.detach(), (2, 3))
h2 = layer2.register_forward_hook(gap2)


def gap3(m, i, o):
    """Forward hook for layer3: store mean of its output over dims (2, 3)."""
    global layer3_gap
    layer3_gap = torch.mean(o.detach(), (2, 3))
h3 = layer3.register_forward_hook(gap3)

In [7]:
# Uncomment to detach the forward hooks registered above (h1/h2/h3).
# h1.remove()     
# h2.remove()     
# h3.remove()     



In [8]:
# Source dataset, opened read-only: this notebook only reads `images`, and
# "r" mode guards the store against accidental writes.
# NOTE: re-enabling the create_dataset line below requires mode="r+".
idr0017 = zarr.open('data/idr0017.zarr', mode="r")
images = idr0017["images"]
#resnet34_2048 = idr0017.create_dataset('resnet34_2048_1', shape=(32928, 64+128+256), chunks=False)
#resnet34_1024 = idr0017["resnet34_1024"]

In [9]:
# Number of images in the store (length of the first axis of `images`).
images.shape[0]


32928

In [10]:
# Preallocate the feature matrix: one row per image (taken from the data
# rather than hard-coded), and 3 + 64 + 128 + 256 = 451 columns holding the
# input-crop channel means followed by pooled layer1/layer2/layer3 features.
resnet34_512 = np.zeros((images.shape[0], 3 + 64 + 128 + 256))

In [11]:
# Fill one feature row per image: per-channel mean of the normalized crop,
# then the layer1/2/3 means captured by the forward hooks during the pass.
# NOTE(review): assumes `images` is (N, C, H, W) with H, W >= 1280 and values
# already in [0, 1] as `normalize` (ImageNet stats) expects -- confirm.
crop = slice(768, 768 + 512)  # 512x512 window starting at pixel (768, 768)
with torch.no_grad():  # inference only: don't build an autograd graph
    for i in tqdm(range(images.shape[0])):
        t_img = normalize(torch.tensor(
            images[i, :, crop, crop], dtype=torch.float32)).unsqueeze(0)
        resnet_model(t_img)  # output unused; hooks update layerN_gap
        resnet34_512[i, :] = torch.cat([
            torch.mean(t_img, (2, 3)),
            layer1_gap,
            layer2_gap,
            layer3_gap],
            dim=1).squeeze(0).numpy()

HBox(children=(FloatProgress(value=0.0, max=32928.0), HTML(value='')))




In [12]:
# Persist the feature matrix in NumPy's .npy format.
np.save(file='resnet34_512.npy', arr=resnet34_512)

In [15]:
# Second export of the same feature matrix, as a pickle (the .npy written
# above holds identical data). Only unpickle this file from trusted sources.
with open("resnet34_512.pkl", "wb") as pkl_out:
    pickle.dump(resnet34_512, pkl_out)