# In [60]:
import torch
import torch.fx as fx
from torch.fx import symbolic_trace
from torch import nn
from torchvision import models
from torchvision import transforms
from PIL import Image

class ATL(nn.Module):
    """Frozen ResNet-50 feature-map tap followed by a trainable linear head.

    A pretrained ResNet-50 is traced with ``torch.fx`` and rewritten so that,
    instead of producing ImageNet logits, it returns a flat vector built from
    selected channels of a few intermediate layers (see ``_transform``). The
    rewritten backbone is frozen; only ``self.fcl`` is trainable.

    Args:
        train: Data intended for choosing layers / feature maps. Currently
            unused by the placeholder selectors ``get_layers`` and
            ``get_featuremaps_idicies``.
        Nclasses: Number of output classes.
        Nlayer: Number of tapped layers (the first ``Nlayer`` matching nodes
            in graph order). NOTE(review): the default of 3 exceeds the two
            names the current ``get_layers`` stub returns, so the default
            raises ``ValueError``.
        Pmax: Currently unused.
    """

    def __init__(self, train, Nclasses, Nlayer=3, Pmax=0.4):
        super().__init__()
        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # releases in favour of ``weights=...``; kept for compatibility.
        self.resnet = models.resnet50(pretrained=True)
        layers = self.get_layers(train)
        fm_indicies, resnet_out_size = self.get_featuremaps_idicies(train)
        self.resnet = self._transform(self.resnet, Nlayer, layers, fm_indicies)
        # Freeze the rewritten backbone; only the classifier head is trained.
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.fcl = nn.Linear(resnet_out_size, Nclasses)
        self.softmax = torch.nn.Softmax(-1)

    def get_layers(self, train):
        """Return the FX node names of the layers to tap.

        Placeholder: must return an iterable with (at least ``Nlayer``)
        convolution-layer node names. Currently hard-coded and ``train`` is
        ignored.
        """
        return ['layer2_0_conv3', 'layer2_1_conv3']

    def get_featuremaps_idicies(self, train):
        """Return ``(indices, out_size)`` describing the tapped feature maps.

        Placeholder: ``indices`` must hold one iterable of channel indices per
        tapped layer; in total the number of indices should equal
        Nclasses * Nfeature (section 2.3 of the paper). ``out_size`` is the
        length of the flattened, concatenated backbone output and becomes the
        input size of ``self.fcl``. Currently hard-coded and ``train`` is
        ignored.
        """
        # NOTE(review): 1568 presumably equals 2 channels x 28x28 spatial
        # positions for a 224x224 input at layer2 — TODO confirm.
        return [[1], [1]], 1568

    def _transform(self, m: torch.nn.Module, n, layers, idx) -> torch.nn.Module:
        """Rewrite ``m`` so it outputs selected channels of tapped layers.

        ``m`` is symbolically traced. The first ``n`` nodes, in graph order,
        whose names appear in ``layers`` are tapped. For the i-th tapped node,
        channels ``idx[i]`` (dim 1) are selected and flattened from dim 1 on.
        The per-layer results are concatenated along dim 1 and become the new
        graph output. Nodes no longer needed for that output are removed.

        Args:
            m: Module to trace.
            n: Number of layers to tap.
            layers: Node names eligible for tapping.
            idx: Sequence of channel-index sequences; ``idx[i]`` applies to the
                i-th tapped node in graph order.

        Returns:
            The rewritten, recompiled ``torch.fx.GraphModule``.

        Raises:
            ValueError: if ``n < 1``, ``idx`` has fewer than ``n`` entries, or
                fewer than ``n`` nodes of the traced graph match ``layers``.
        """
        if n < 1:
            raise ValueError(f"n must be at least 1, got {n}")
        if len(idx) < n:
            raise ValueError(f"expected at least {n} index lists, got {len(idx)}")

        gm: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
        graph = gm.graph
        layer_names = set(layers)

        final_nodes = []
        last_node = None
        out_node = None
        for node in graph.nodes:
            if node.name in layer_names and len(final_nodes) < n:
                final_nodes.append(node)
                if len(final_nodes) == n:
                    # New nodes go after the last tapped node so every tapped
                    # activation is already defined when they run.
                    last_node = node
            if node.op == 'output':
                out_node = node

        if last_node is None:
            raise ValueError(
                f"only {len(final_nodes)} of the {n} requested layers were "
                f"found in the traced graph"
            )

        flat_nodes = []
        for tapped, channels in zip(final_nodes, idx):
            # Build the index tensor from the tapped activation so it is
            # created on the same device (index_select needs input and index
            # on one device; a bare torch.tensor(...) would always be CPU).
            with graph.inserting_after(last_node):
                index = graph.call_method(
                    'new_tensor',
                    args=(tapped, [int(c) for c in channels]),
                    kwargs={'dtype': torch.long},
                )
            with graph.inserting_after(index):
                selected = graph.call_function(
                    torch.index_select, args=(tapped, 1, index))
            with graph.inserting_after(selected):
                flat = graph.call_function(torch.flatten, args=(selected, 1))
            flat_nodes.append(flat)
            last_node = flat

        with graph.inserting_after(last_node):
            features = graph.call_function(torch.cat, args=(flat_nodes, 1))
        out_node.args = (features,)
        # Drops everything after the tapped layers (e.g. the original head).
        graph.eliminate_dead_code()
        graph.lint()
        gm.recompile()

        return gm

    def forward(self, x):
        """Return softmax class probabilities for the input batch ``x``.

        ``x`` is presumably a (batch, 3, H, W) image tensor — TODO confirm.

        NOTE(review): the output is already softmax-normalised; if this is
        trained with ``nn.CrossEntropyLoss`` (which applies log-softmax
        itself) the softmax should be dropped — confirm the intended loss.
        """
        x = self.resnet(x)
        x = self.fcl(x)
        x = self.softmax(x)
        return x

# Smoke test: push a single image through a 5-class, 2-layer ATL model.
# Standard ImageNet preprocessing (the backbone is an ImageNet ResNet-50).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Force 3-channel RGB: grayscale / RGBA / palette images would otherwise break
# the 3-channel Normalize. The context manager closes the file handle.
with Image.open("dog.jpg") as raw_img:
    img = raw_img.convert("RGB")
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)  # add batch dim -> (1, 3, 224, 224)

# `train` is not used by the current placeholder layer / feature-map selectors.
m = ATL(None, 5, 2)
m.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    out = m(batch_t)
print(out.size())
# To inspect the rewritten backbone graph: print(m.resnet.code)

# Out: torch.Size([1, 5])