In [2]:
import timm
import torch
from torch import nn as nn
from torchvision import models
from collections import OrderedDict

In [6]:
# Load an ImageNet-pretrained ResNet18 and append nn.LogSoftmax(dim=1) so the
# network returns log-probabilities instead of raw logits.
# `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13);
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` resolved to.
model_name = 'resnet18'
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model = nn.Sequential(model, nn.LogSoftmax(dim=1))

## Convert the last FC layer to a 1x1 convolution
Now let us derive a ResNet18 variant in which the last FC layer is converted to a 1x1 convolution and the global average pooling (GAP) layer is skipped.
- Feature embedding size: `512`
- Number of classes: `1000`

In [24]:
# Build a fully-convolutional ResNet18: drop the global average pooling and
# replace the final FC layer with an equivalent 1x1 convolution.
# `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13).
res18_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
fc = res18_model.fc.state_dict()

# Read both sizes from the FC weight, stored as (out_features, in_features),
# instead of hard-coding the embedding size.
out_ch, in_ch = fc["weight"].shape
finalConv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1)

### get the weights from the fc layer
# Linear weight (out, in) -> conv kernel (out, in, 1, 1); the bias is reused as is.
finalConv.load_state_dict({"weight": fc["weight"].view(out_ch, in_ch, 1, 1), "bias": fc["bias"]})
# [:-2] drops the last two children (the pooling layer and the FC layer).
res18_conv = nn.Sequential(*list(res18_model.children())[:-2], finalConv)

res18_model.eval()
res18_conv.eval()

print()




In [18]:
# Number of tensors stored in the FC layer's state dict
print(len(fc))

# Shape of the FC weight matrix
fc['weight'].shape

2


torch.Size([1000, 512])

## RetFound
- Patch size: `16x16`
- Patches do not overlap (the stride equals the patch size)

In [3]:
# Fetch the pretrained RETFound (MAE) checkpoint from the Hugging Face hub via timm,
# then move it to the GPU and switch it to evaluation mode.
model = timm.create_model("hf_hub:bitfount/RETFound_MAE", pretrained=True)
model = model.to('cuda').eval()

print()

  return self.fget.__get__(instance, owner)()





In [19]:
# Smoke test: run one random 3x224x224 input through the model on the GPU
ts_img = torch.randn((1, 3, 224, 224), device='cuda')
out = model(ts_img)

In [20]:
# named_children() yields (name, module) pairs lazily, so materialise it as a
# list to be able to count and slice the top-level modules.
ll = list(model.named_children()) # generator
print(len(ll))
ll[:3]

9


[('patch_embed',
  PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )),
 ('pos_drop', Dropout(p=0.0, inplace=False)),
 ('patch_drop', Identity())]

In [21]:
# Inspect the structure of the first block
model.blocks[0]

Block(
  (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
  (attn): Attention(
    (qkv): Linear(in_features=1024, out_features=3072, bias=True)
    (q_norm): Identity()
    (k_norm): Identity()
    (attn_drop): Dropout(p=0.0, inplace=False)
    (proj): Linear(in_features=1024, out_features=1024, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (ls1): Identity()
  (drop_path1): Identity()
  (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=1024, out_features=4096, bias=True)
    (act): GELU(approximate='none')
    (drop1): Dropout(p=0.0, inplace=False)
    (norm): Identity()
    (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    (drop2): Dropout(p=0.0, inplace=False)
  )
  (ls2): Identity()
  (drop_path2): Identity()
)

In [22]:
# The qkv projection of the first block's attention module
model.blocks[0].attn.qkv

Linear(in_features=1024, out_features=3072, bias=True)

In [24]:
# Replace the qkv projection of the first block with a 1x1 convolution that
# carries the same weight and bias.
qkv_linear = model.blocks[0].attn.qkv
in_ch, out_ch = qkv_linear.in_features, qkv_linear.out_features

qkv = qkv_linear.state_dict()

conv_qkv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1)

# Linear weight (out, in) -> conv kernel (out, in, 1, 1); the bias is reused as is
conv_weights = {"weight": qkv["weight"].view(out_ch, in_ch, 1, 1), "bias": qkv["bias"]}
conv_qkv.load_state_dict(conv_weights)  # init

model.blocks[0].attn.qkv = conv_qkv

print(in_ch, out_ch)

1024 3072


## Explainable RetFound

In [196]:
# Load a fresh pre-trained model from Hugging Face. This rebinds `model`, so the
# in-place qkv replacement made on the previous instance is discarded.
model = timm.create_model("hf_hub:bitfount/RETFound_MAE", pretrained=True)

In [197]:
# Index the model's top-level modules by name
layers = dict(model.named_children())

# Named groups of top-level modules that are reused unchanged
layer_name_before = ['patch_embed', 'pos_drop', 'patch_drop', 'norm_pre']
layers_before = {key: layers[key] for key in layer_name_before}

layer_name_after = ['norm', 'fc_norm', 'head_drop']
layers_after = {key: layers[key] for key in layer_name_after}

# classification head: a 1x1 convolution sized from the original head's input width
n_classes = 5
in_ftrs = model.head.in_features
classifier_ = nn.Conv2d(in_ftrs, n_classes, kernel_size=1)
classifier = {'classifier': classifier_}

In [198]:
# Names of all top-level modules, in registration order
layers.keys()

dict_keys(['patch_embed', 'pos_drop', 'patch_drop', 'norm_pre', 'blocks', 'norm', 'fc_norm', 'head_drop', 'head'])

In [199]:
main_block = layers['blocks']

# Swap every nn.Linear that sits directly inside a block's `attn` or `mlp`
# sub-module for a 1x1 nn.Conv2d carrying the same weight and bias.
for block in main_block.children():
    for sub_name, sub_module in block.named_children():
        if sub_name not in ('attn', 'mlp'):
            continue

        # Snapshot the children first, since they are replaced inside the loop
        for child_name, child in list(sub_module.named_children()):
            if not isinstance(child, nn.Linear):
                continue

            in_ftrs, out_ftrs = child.in_features, child.out_features

            # Create a 1x1 convolutional layer and copy the linear parameters:
            # weight (out, in) -> kernel (out, in, 1, 1), bias unchanged
            conv_layer = nn.Conv2d(in_ftrs, out_ftrs, kernel_size=1)
            weight_bias = child.state_dict()
            conv_layer.load_state_dict({"weight": weight_bias["weight"].view(out_ftrs, in_ftrs, 1, 1), "bias": weight_bias["bias"]})

            # dynamic assignment: replace the linear layer under the same name
            setattr(sub_module, child_name, conv_layer)

In [200]:
# Show the modules selected to run ahead of the blocks
layers_before

{'patch_embed': PatchEmbed(
   (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
   (norm): Identity()
 ),
 'pos_drop': Dropout(p=0.0, inplace=False),
 'patch_drop': Identity(),
 'norm_pre': Identity()}

In [201]:
# Collect the blocks by name, then wrap them in a named Sequential stored under
# the key 'blocks' so they can be merged with the other layer dicts.
layers_block = dict(main_block.named_children())

layers_block_ = {'blocks': nn.Sequential(OrderedDict(layers_block))}

In [220]:
#main_block

In [204]:
#layers_block_

In [202]:
# Stitch the parts together in order: layers_before -> blocks -> layers_after -> classifier
merged_layers = OrderedDict()
for part in (layers_before, layers_block_, layers_after, classifier):
    merged_layers.update(part)

new_model = nn.Sequential(merged_layers)

## Class

In [1]:
import timm
import torch
from torch import nn as nn
from torchvision import models
from collections import OrderedDict

In [44]:
class ExplainRetFound(nn.Module):
    """Convolutional variant of RETFound that also returns a per-patch class map.

    The pretrained RETFound (MAE) model is loaded through timm and every
    ``nn.Linear`` found directly inside the ``attn`` and ``mlp`` sub-modules of
    its blocks is replaced by a 1x1 ``nn.Conv2d`` holding the same parameters,
    so the blocks operate on ``(batch, channels, h, w)`` feature maps.

    What ``forward`` does differently from the original blocks:
      * the per-block ``norm1`` / ``norm2`` layers are not called; one shared,
        newly initialised ``GroupNorm`` (``self.norm``) is applied instead;
      * no attention weights are computed; Q, K and V are simply summed;
      * no residual additions are made.

    Args:
        input_size: Side length of the (square) input image in pixels.
        num_classes: Number of output channels of the 1x1 classification head.
    """

    # Patch side length used to derive the feature-map grid from the input size.
    PATCH_SIZE = 16

    def __init__(self, input_size, num_classes):
        super().__init__()
        # Side length of the patch grid, e.g. 224 // 16 = 14
        grid = input_size // self.PATCH_SIZE
        self.size = grid, grid

        # Load pre-trained model from hugging face
        model = timm.create_model("hf_hub:bitfount/RETFound_MAE", pretrained=True)

        layers_before, layers_after = self.get_layer_before_after(model)
        main_block_layer = self.from_linear_to_conv_layers(model)

        # The classification head is a freshly initialised 1x1 convolution sized
        # from the original head's input width.
        in_ftrs = model.head.in_features
        cls_head = nn.Conv2d(in_channels=in_ftrs, out_channels=num_classes, kernel_size=1)

        # Newly initialised channel-first norm (nothing is copied from the
        # checkpoint). The same module is shared by every block in `forward`
        # and also replaces the model's final `norm`.
        self.norm = nn.GroupNorm(1, in_ftrs, eps=1e-06)  # default eps is 1e-05
        layers_after['norm'] = self.norm

        self.layers_before = nn.Sequential(OrderedDict(layers_before))
        self.main_block_layer = nn.Sequential(OrderedDict(main_block_layer))
        self.layers_after = nn.Sequential(OrderedDict(layers_after))
        self.classifier = cls_head
        # The kernel covers the whole grid, so this averages each class map to one value
        self.avgpool = nn.AvgPool2d(kernel_size=self.size, stride=(1, 1), padding=0)

    def get_layer_before_after(self, model):
        """Pick the named top-level modules that surround the blocks.

        Args:
            model: Module exposing the named children listed below.

        Returns:
            Tuple ``(layers_before, layers_after)`` of name -> module dicts.

        Raises:
            KeyError: If ``model`` lacks one of the expected child names.
        """
        layers = dict(model.named_children())  # all top-level layers

        layer_name_before = ['patch_embed', 'pos_drop', 'patch_drop', 'norm_pre']
        layers_before = {name: layers[name] for name in layer_name_before}

        layer_name_after = ['norm', 'fc_norm', 'head_drop']
        layers_after = {name: layers[name] for name in layer_name_after}

        return layers_before, layers_after

    @staticmethod
    def _linear_to_conv1x1(linear):
        """Return a 1x1 ``nn.Conv2d`` holding the same parameters as ``linear``."""
        has_bias = linear.bias is not None
        conv = nn.Conv2d(in_channels=linear.in_features, out_channels=linear.out_features,
                         kernel_size=1, bias=has_bias)
        conv = conv.to(device=linear.weight.device, dtype=linear.weight.dtype)

        # Linear weight (out, in) -> conv kernel (out, in, 1, 1)
        state = {"weight": linear.weight.detach().reshape(linear.out_features, linear.in_features, 1, 1)}
        if has_bias:
            state["bias"] = linear.bias.detach()
        conv.load_state_dict(state)
        return conv

    def from_linear_to_conv_layers(self, model):
        """Replace the linear layers of every block by 1x1 convolutions, in place.

        Only ``nn.Linear`` modules that are direct children of a block's
        ``attn`` or ``mlp`` sub-module are converted.

        Args:
            model: Module whose ``blocks`` attribute holds the blocks.

        Returns:
            Dict mapping each block's name to the (modified) block.
        """
        blocks = model.blocks

        for block in blocks.children():
            for sub_name, sub_module in block.named_children():
                if sub_name not in ('attn', 'mlp'):
                    continue
                # Snapshot the children first, since they are replaced inside the loop
                for child_name, child in list(sub_module.named_children()):
                    if isinstance(child, nn.Linear):
                        setattr(sub_module, child_name, self._linear_to_conv1x1(child))

        return dict(blocks.named_children())

    def forward(self, x):
        """Run the convolutional network.

        Args:
            x: Input batch; ``layers_before`` must turn it into tokens of shape
                ``(batch, h * w, channels)`` with ``(h, w) == self.size``.

        Returns:
            Tuple ``(out, activation)``: ``activation`` is the classifier output
            on the feature map and ``out`` is that map after average pooling,
            flattened to ``(batch, -1)``.
        """
        h, w = self.size

        # Initial processing
        x = self.layers_before(x)
        bs, hw, c = x.shape  # tokens: (batch_size, height * width, channels)

        # Tokens -> feature map (bs, c, h, w)
        x = x.reshape(bs, h, w, c).permute(0, 3, 1, 2)

        # Loop through all blocks
        for block in self.main_block_layer:
            x = self.norm(x)  # shared norm, used instead of block.norm1

            # Attention sub-module, without any attention weights
            x = block.attn.qkv(x)
            x = block.attn.q_norm(x)
            x = block.attn.k_norm(x)
            # timm's Attention splits the qkv output into three consecutive
            # chunks [Q | K | V], so the "3" axis comes first. Summing over it
            # adds Q + K + V channel by channel.
            x = x.reshape(bs, 3, c, h, w).sum(dim=1)
            x = block.attn.proj(x)
            x = block.attn.proj_drop(x)

            # Layer scale & drop path (NOTE: no residual addition is performed)
            x = block.ls1(x)
            x = block.drop_path1(x)

            # Feedforward MLP block
            x = self.norm(x)  # shared norm, used instead of block.norm2
            x = block.mlp.fc1(x)
            x = block.mlp.act(x)
            x = block.mlp.drop1(x)
            x = block.mlp.norm(x)
            x = block.mlp.fc2(x)
            x = block.mlp.drop2(x)

            # Layer scale & drop path (NOTE: no residual addition is performed)
            x = block.ls2(x)
            x = block.drop_path2(x)

        x = self.layers_after(x)
        activation = self.classifier(x)
        out = self.avgpool(activation)
        out = out.view(out.shape[0], -1)    # (bs, n_class)

        return out, activation

In [45]:
# Build the model for 224x224 inputs and 5 classes, then move it to the GPU
tmp_model = ExplainRetFound(input_size=224, num_classes=5).to('cuda')
print()




In [41]:
#tmp_model.layers_before
#tmp_model.layers_after

In [42]:
# Random 3x224x224 test input on the GPU
ts_img = torch.randn((1, 3, 224, 224), device='cuda')

In [47]:
# The model returns the pooled output and the activation map it was pooled from
out, acts = tmp_model(ts_img)
out.shape, acts.shape

(torch.Size([1, 5]), torch.Size([1, 5, 14, 14]))

In [36]:
# Keep a handle on the converted blocks for inspection
aaa = tmp_model.main_block_layer
#aaa

In [18]:
# Reload an unmodified pretrained model (rebinds `model`)
model = timm.create_model("hf_hub:bitfount/RETFound_MAE", pretrained=True)

In [30]:
# The 3072 qkv output features split three ways
3072/3

1024.0

In [114]:
#model

## test

In [10]:
# Random test input, plus the merged Sequential model moved to the GPU in eval mode
ts_img = torch.randn((1, 3, 224, 224), device='cuda')

new_model.to('cuda')
new_model.eval()

# print() is the last statement so the cell does not display the model repr
print()




In [11]:
#out = new_model(ts_img)

In [108]:
#layers_block['0']

## Test debug

In [34]:
# Debug helper: a 'blocks' Sequential that contains only block '0'
layers1 = {'0': layers_block['0']}
layers1_ = {'blocks': nn.Sequential(OrderedDict(layers1))}

In [14]:
# Debug model made of the `layers_before` modules only (the single-block part
# is currently commented out).
test_layers = OrderedDict(list(layers_before.items()) ) #+ list(layers1_.items())
test_model = nn.Sequential(test_layers)
#test_model.blocks

In [16]:
# Move the debug model to the GPU and switch it to evaluation mode
test_model.to('cuda')
test_model.eval()
print()




In [24]:
# Show the structure of the debug model
test_model

Sequential(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
)

In [38]:
ts_img = torch.randn((1, 3, 224, 224), device='cuda')

out = test_model(ts_img)

# Tokens (1, 196, 1024) -> grid (1, 14, 14, 1024) -> channel-first (1, 1024, 14, 14)
out2 = out.view(1, 14, 14, 1024).permute(0, 3, 1, 2)
out.shape, out2.shape

(torch.Size([1, 196, 1024]), torch.Size([1, 1024, 14, 14]))

In [39]:
#model.blocks[0]

In [21]:
# Number of tokens in a 14x14 grid
14*14

196

In [41]:
# Stand-alone 1x1 convolution with the qkv layer's channel sizes (1024 -> 3072)
self_conv = nn.Conv2d(1024, 3072, kernel_size=1)
self_conv.to('cuda')

Conv2d(1024, 3072, kernel_size=(1, 1), stride=(1, 1))

In [42]:
# Display the convolution's configuration
self_conv

Conv2d(1024, 3072, kernel_size=(1, 1), stride=(1, 1))

In [45]:
# Apply the 1x1 convolution to the channel-first feature map and check the shape
bb = self_conv(out2)
bb.shape

torch.Size([1, 3072, 14, 14])

In [None]:
def forward(self, x):
    """Debug version of the forward pass that runs the first block only.

    Prints the tensor shape after ``layers_before`` and after the attention
    part, which makes the token -> feature-map reshaping easy to follow.

    Args:
        self: Object exposing ``size`` (an ``(h, w)`` pair), ``layers_before``,
            ``main_block_layer`` and ``norm``.
        x: Input batch; ``layers_before`` must turn it into tokens of shape
            ``(batch, h * w, channels)``.

    Returns:
        The feature map produced by block 0, shaped ``(batch, channels, h, w)``
        when the block's layers keep the channel count.
    """
    h, w = self.size

    x = self.layers_before(x)
    bs, hw, c = x.shape  # tokens: (batch, h * w, channels)

    # block 0
    block0 = self.main_block_layer[0]
    print(x.shape)
    # Tokens -> channel-first feature map (bs, c, h, w)
    x = x.reshape(bs, h, w, c)
    x = x.permute(0, 3, 1, 2)
    x = self.norm(x)  # used instead of block0.norm1
    x = block0.attn.qkv(x)
    x = block0.attn.q_norm(x)
    x = block0.attn.k_norm(x)
    # timm's Attention splits the qkv output into three consecutive chunks
    # [Q | K | V], so the "3" axis comes first. Summing over it adds Q + K + V.
    x = x.reshape(bs, 3, c, h, w).sum(dim=1)
    x = block0.attn.proj(x)
    x = block0.attn.proj_drop(x)

    x = block0.ls1(x)
    x = block0.drop_path1(x)
    print(x.shape)
    x = self.norm(x)  # used instead of block0.norm2

    x = block0.mlp.fc1(x)
    x = block0.mlp.act(x)
    x = block0.mlp.drop1(x)
    x = block0.mlp.norm(x)
    x = block0.mlp.fc2(x)
    x = block0.mlp.drop2(x)

    x = block0.ls2(x)
    # Fixed attribute name: the original `drop_path2drop_path2` does not exist
    x = block0.drop_path2(x)
    return x