In [1]:
import torch
import torch.fx as fx
import numpy as np
from torch.fx import symbolic_trace
from torch import nn
import torchvision
from torchvision import models
from torchvision import transforms
from PIL import Image
from scipy.stats import ttest_ind
import torch.optim as optim

class ATL(nn.Module):
    def __init__(self, train, Nclasses, Nlayer=3, Pmax=0.4):
        # call constructor from superclass
        super().__init__()
        self.resnet = models.resnet50(pretrained=True)
        layers = self.get_layers(train)
        
        fm_indicies, resnet_out_size = self.get_featuremaps_idicies(train,layers)        
        self.resnet = self._transform(self.resnet, Nlayer, layers, fm_indicies)
        
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.relu = nn.ReLU()
        self.fcl = nn.Linear(resnet_out_size, Nclasses)
        self.softmax = torch.nn.Softmax(-1)
        
    def get_layers(self, train):
        # Musi zwracać coś w tym stylu, tj iterowalny obiekt z N conv layers
        
        return [3,9,20] # najlepiej indeksy warstw konwolucyjnych 
    
    def get_featuremaps_idicies(self, train, layers, N_feature  = 3):
        
        # obecnie bez pilnowania finalnej liczyby wybranych feature map
        # bez pilnownaia zeby z każdej klasy było n_feature map 
        # bez obliczania progu p value dla każdej klasy w warstwie (ustalona sztywna wartość 0.05)
        
        images, labels = trainset 
        labels_set = set(np.asarray(labels))
                
        conv_layers,model_weights = self.get_resnet_conv_layers()
    
        choose_fm = []
        choose_fm_output_len=0
        # iteruje po warstwach konwolucyjnuch 
        for conv_ind in layers:
            choose_fm_curent_conv_layer = []
            
            # iteruje po mapach             
            for fm_ind in range(len(model_weights[conv_ind])):                
                LAV_vec = []
                
                
                for image in images:
                    out_fm = self.get_feature_map_outputs(image)
                    
                    LAV_vec.append(self.LAV(out_fm[conv_ind][fm_ind]))
                 
                for _class in labels_set:
                    LAV_vec_curent_class = []
                    LAV_vec_other_class = []
                    for i,fm in enumerate(LAV_vec):
                        if labels[i]==_class:
                            LAV_vec_curent_class.append(fm)
                        else:     
                            LAV_vec_other_class.append(fm)
                    
                    # zwraca nan przy LAV_vec_current_class długości 1 !!
                    t_stat, p = ttest_ind(LAV_vec_curent_class, LAV_vec_other_class)
                    #print("t_stat, p", t_stat, p )
                    
                    # TODO pl = p_max * Rl/Rmax
                    p_treshold = 0.05 # testowo
                    
                    if p < p_treshold:
                        choose_fm_curent_conv_layer.append(fm_ind)
                        
                        map_dim = out_fm[conv_ind][fm_ind]
                        #print(map_dim)
                        
                        choose_fm_output_len+= map_dim[0].size()[0] **2
                        print(map_dim[0].size()[0] **2)
                        
                        print((conv_ind,fm_ind))
            choose_fm.append(choose_fm_curent_conv_layer)
        
        print (choose_fm)
        print (choose_fm_output_len)
        
        # Musi zwracać obiekt który ma N iterowalnych rzeczy, z których każda ma ileś indeksów feature map. 
        # W sumie ilość indeksów musi być równa ilości klas * Nfeature z 2.3 w artykule
        #return [[1],[1]], 1568
        
        return choose_fm, choose_fm_output_len  # na razie zwraca coś takiego: [[30, 94], [8, 100, 155, 187, 188], [49, 51, 90]]
        

    
    def LAV(self,featureMap):
        return featureMap.max().detach().numpy().item(0)
    
    def _transform(self, m: torch.nn.Module, n, layers, idx) -> torch.nn.Module:
        gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
        graph=gm.graph
        blocks_in_layers = [1,3,4,6,3]
        idx_to_layer_name = ['conv1'] + [f"layer{i}_{j}_conv{k}" for i in range(1,5) for j in range(blocks_in_layers[i]) for k in range(1,4)]
        layer_names = [idx_to_layer_name[idx] for idx in layers]
        final_nodes=[]
        last_node=None

        for node in graph.nodes:
            if node.name in layer_names:
                final_nodes.append(node)
            if not last_node and len(final_nodes)==n:
                last_node = node
            if node.name == 'output':
                out_node = node

        i=0
        nodes_to_output=[]
        for i in range(n):
            with graph.inserting_after(last_node):
            # Insert a new `call_function` node calling `torch.relu`
                new_node = graph.call_function(torch.tensor,
                                               args=(idx[i],),
                                              kwargs={"dtype":torch.int32})
                last_node = new_node

            with graph.inserting_after(last_node):
                new_node = graph.call_function(torch.index_select,
                                              args=(final_nodes[i], 1, last_node))
                last_node = new_node

            with graph.inserting_after(last_node):
                new_node = graph.call_function(torch.flatten,
                                              args=(last_node,1))
                nodes_to_output.append(new_node)
                last_node = new_node
        with graph.inserting_after(last_node):
                new_node = graph.call_function(torch.cat,
                                              args=(nodes_to_output,1))
        out_node.args=(new_node,)
        graph.eliminate_dead_code()
        graph.lint() 
        gm.recompile()

        return gm
        
        
    
    def forward(self, x):
        x = self.resnet(x)
        x = self.relu(x)
        x = self.fcl(x)
        x = self.softmax(x)
        return x
    
    
    def get_resnet_conv_layers(self):
        model = self.resnet
        # we will save the conv layer weights in this list
        model_weights =[]

        #we will save the 49 conv layers in this list
        conv_layers = []

        # get all the model children as list
        model_children = list(model.children())

        #counter to keep count of the conv layers
        counter = 0
        #append all the conv layers and their respective wights to the list

        for i in range(len(model_children)):

            if type(model_children[i]) == nn.Conv2d:
                counter+=1
                model_weights.append(model_children[i].weight)
                conv_layers.append(model_children[i])

            elif type(model_children[i]) == nn.Sequential:
                for j in range(len(model_children[i])):
                    for child in model_children[i][j].children():
                        if type(child) == nn.Conv2d:
                            counter+=1
                            model_weights.append(child.weight)
                            conv_layers.append(child)

        #print(f"Total convolution layers: {counter}")
        #print("conv_layers")
        return (conv_layers,model_weights)
    
    def get_feature_map_outputs(self,image):
        conv_layers, _ = self.get_resnet_conv_layers()
        outputs = []
        names = []

        for layer in conv_layers[0:]:
            image = layer(image)
            outputs.append(image)
            names.append(str(layer))
            
        #print(len(outputs))

        # print feature_maps
        #for feature_map in outputs:
            #print(feature_map.shape)
            
        return outputs;   


# testy
transform = transforms.Compose([            #[1]
             transforms.Resize(256),                    #[2]
             transforms.CenterCrop(224),                #[3]
             transforms.ToTensor(),                     #[4]
             transforms.Normalize(                      #[5]
             mean=[0.485, 0.456, 0.406],                #[6]
             std=[0.229, 0.224, 0.225]                  #[7]
             )])


#img = Image.open("dog.jpg")
#img = Image.open('img/cat.jpg')
#img_t = transform(img)
#batch_t = torch.unsqueeze(img_t, 0)



# CIFAR10 
import torchvision
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=19,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


torch.manual_seed(0)
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
trainset = images, labels
#print(trainset)


m = ATL(trainset, 10, 2)
m.eval()







#optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, *, maximize=False)





# out=m(batch_t)
# out.size()

#print(symbolic_trace(m.resnet).code)

Files already downloaded and verified
Files already downloaded and verified


  return _methods._var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  ret = ret.dtype.type(ret / rcount)


256
(3, 3)
256
(3, 5)
256
(3, 5)
256
(3, 6)
256
(3, 7)
256
(3, 8)
256
(3, 8)
256
(3, 9)
256
(3, 10)
256
(3, 10)
256
(3, 12)
256
(3, 17)
256
(3, 18)
256
(3, 24)
256
(3, 29)
256
(3, 34)
256
(3, 35)
256
(3, 43)
256
(3, 43)
256
(3, 45)
256
(3, 46)
256
(3, 49)
256
(3, 51)
256
(3, 52)
256
(3, 54)
256
(3, 55)
256
(3, 59)
256
(3, 59)
256
(3, 60)
256
(3, 63)
256
(3, 64)
256
(3, 67)
256
(3, 70)
256
(3, 73)
256
(3, 73)
256
(3, 74)
256
(3, 83)
256
(3, 87)
256
(3, 87)
256
(3, 89)
256
(3, 92)
256
(3, 93)
256
(3, 93)
256
(3, 94)
256
(3, 94)
256
(3, 98)
256
(3, 98)
256
(3, 100)
256
(3, 101)
256
(3, 102)
256
(3, 105)
256
(3, 107)
256
(3, 109)
256
(3, 109)
256
(3, 110)
256
(3, 110)
256
(3, 118)
256
(3, 120)
256
(3, 123)
256
(3, 124)
256
(3, 125)
256
(3, 125)
256
(3, 126)
256
(3, 128)
256
(3, 129)
256
(3, 129)
256
(3, 130)
256
(3, 131)
256
(3, 132)
256
(3, 134)
256
(3, 136)
256
(3, 139)
256
(3, 141)
256
(3, 145)
256
(3, 145)
256
(3, 147)
256
(3, 148)
256
(3, 152)
256
(3, 155)
256
(3, 155)
256
(3, 157)
25

ATL(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Module(
      (0): Module(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Module(
          (0): Conv2d(64, 2

In [2]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(m.parameters(), lr=0.001, momentum=0.9)

In [8]:
img_cat = Image.open('img/cat.jpg')
img_cat_t = transform(img_cat)
batch_cat_t = torch.unsqueeze(img_cat_t, 0)

In [12]:




for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        
        inputs, labels = trainset #data
        #print(labels)
    

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        #inputs = transform(inputs)
        
        inputs = batch_cat_t
        labels = torch.tensor([0])
        
        
        outputs = m(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x18624000 and 53696x10)

In [None]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        
        inputs, labels = trainset #data
    

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        inputs = transform(inputs)
        outputs = m(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

In [4]:
images.size()

torch.Size([19, 3, 32, 32])

In [11]:
labels = torch.tensor([0])
labels

tensor([0])