In [6]:
import torch
from torch import nn
import torch.fx
from torch.fx.node import Node

from typing import Dict
from torchvision.models import inception_v3, Inception_V3_Weights
from torchvision import transforms, datasets

In [7]:
weights = Inception_V3_Weights.IMAGENET1K_V1
model = inception_v3(weights=weights)
traced_graph = torch.fx.symbolic_trace(model)
x = torch.rand(3, 3, 299, 299)    

In [8]:
class GraphInterperterWithGamma(nn.Module): 
    """
    класс, но с гаммами для нод
    """
    def __init__(self, mod):
        super(GraphInterperterWithGamma, self).__init__()
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())
        gammas = []
        self.gammas_name = {}
        i = 0
        for node in self.graph.nodes:
            if node.op == 'call_module':
                gammas.append(1.0)
                self.gammas_name[str(node)] = i# перевод в str тут для удобства. в реалньых методах это не нужно
                i+=1                        # да и вообще, тут по идее должен быть тензор/параметр
        self.gammas = nn.Parameter(torch.as_tensor(gammas), requires_grad = False)

    def forward(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr
        

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) * self.gammas[self.gammas_name[str(node)]]
            if node.op == 'output':
                return result
                        
            env[node.name] = result
        
        return result
        

In [9]:
gamma_graph = GraphInterperterWithGamma(traced_graph).to('cpu')
for i in range(len(gamma_graph.gammas)):
    gamma_graph.gammas[i] = 1 #0

In [10]:
gamma_graph.gammas_name

{'conv2d_1a_3x3_conv': 0,
 'conv2d_1a_3x3_bn': 1,
 'conv2d_2a_3x3_conv': 2,
 'conv2d_2a_3x3_bn': 3,
 'conv2d_2b_3x3_conv': 4,
 'conv2d_2b_3x3_bn': 5,
 'maxpool1': 6,
 'conv2d_3b_1x1_conv': 7,
 'conv2d_3b_1x1_bn': 8,
 'conv2d_4a_3x3_conv': 9,
 'conv2d_4a_3x3_bn': 10,
 'maxpool2': 11,
 'mixed_5b_branch1x1_conv': 12,
 'mixed_5b_branch1x1_bn': 13,
 'mixed_5b_branch5x5_1_conv': 14,
 'mixed_5b_branch5x5_1_bn': 15,
 'mixed_5b_branch5x5_2_conv': 16,
 'mixed_5b_branch5x5_2_bn': 17,
 'mixed_5b_branch3x3dbl_1_conv': 18,
 'mixed_5b_branch3x3dbl_1_bn': 19,
 'mixed_5b_branch3x3dbl_2_conv': 20,
 'mixed_5b_branch3x3dbl_2_bn': 21,
 'mixed_5b_branch3x3dbl_3_conv': 22,
 'mixed_5b_branch3x3dbl_3_bn': 23,
 'mixed_5b_branch_pool_conv': 24,
 'mixed_5b_branch_pool_bn': 25,
 'mixed_5c_branch1x1_conv': 26,
 'mixed_5c_branch1x1_bn': 27,
 'mixed_5c_branch5x5_1_conv': 28,
 'mixed_5c_branch5x5_1_bn': 29,
 'mixed_5c_branch5x5_2_conv': 30,
 'mixed_5c_branch5x5_2_bn': 31,
 'mixed_5c_branch3x3dbl_1_conv': 32,
 'mixed_5

In [11]:
gamma_graph(x)

InceptionOutputs(logits=tensor([[ 0.2396,  0.4675,  0.6938,  ...,  0.5704,  1.0240,  0.3124],
        [-0.9802, -0.3262, -1.2678,  ..., -0.2163, -0.5276,  0.8134],
        [ 0.6286, -0.3925,  0.7923,  ..., -0.7243, -0.2120, -0.9829]],
       grad_fn=<MulBackward0>), aux_logits=tensor([[-1.7601,  0.7197,  0.0937,  ..., -1.5120,  1.0705,  0.1076],
        [ 1.0559,  0.6736, -1.2077,  ...,  0.0332,  1.0180,  0.7502],
        [ 0.3578, -1.5966,  1.2425,  ...,  0.7635, -1.6698, -0.5169]],
       grad_fn=<MulBackward0>))

In [12]:
# Load data
WORK_DIR = './data'
BATCH_SIZE = 12
NUM_EPOCHS = 5

preprocess = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


dataset = datasets.FakeData(24, (3, 400, 400), 2, transform=preprocess)

dataset_loader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size= BATCH_SIZE,
                                             shuffle=False)

In [22]:
gamma_graph = GraphInterperterWithGamma(traced_graph).to('cpu')
device = "cuda:0"
optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.01,
        momentum=0.0001)
criterion = torch.nn.CrossEntropyLoss()
gamma_graph = gamma_graph.to(device)
step = 0
for epoch in range(NUM_EPOCHS):
    gamma_graph.train()
    running_loss = 0.0
    total = 0
    correct = 0
    with torch.no_grad():
        gamma_graph.gammas.copy_(torch.bernoulli(gamma_graph.gammas))
    
    for images, labels in dataset_loader:
          step += 1
          images = images.to(device)
          labels = labels.to(device)
          print(images.shape)
          outputs, aux_outputs = gamma_graph(images)
          loss1 = criterion(outputs, labels)
          loss2 = criterion(aux_outputs, labels)
          loss = loss1 + 0.4*loss2
          running_loss =+ loss.item() * images.size(0)
        
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          print("epoch: ", epoch)
          print(f"Step [{step * BATCH_SIZE}/{NUM_EPOCHS * len(dataset)}], "
            f"Loss: {loss.item():.8f}.")
          print("Running Loss=",running_loss)
    
          # equal prediction and acc
          _, predicted = torch.max(outputs.data, 1)
          # val_loader total
          total += labels.size(0)
          # add correct
          correct += (predicted == labels).sum().item()

          print(f"Acc: {correct / total:.4f}.")
     
  

torch.Size([12, 3, 299, 299])
epoch:  0
Step [12/120], Loss: 11.06143951.
Running Loss= 132.73727416992188
Acc: 0.0000.
torch.Size([12, 3, 299, 299])
epoch:  0
Step [24/120], Loss: 9.68844223.
Running Loss= 116.26130676269531
Acc: 0.0417.
torch.Size([12, 3, 299, 299])
epoch:  1
Step [36/120], Loss: 6.87118244.
Running Loss= 82.45418930053711
Acc: 0.1667.
torch.Size([12, 3, 299, 299])
epoch:  1
Step [48/120], Loss: 4.87286329.
Running Loss= 58.4743595123291
Acc: 0.4167.
torch.Size([12, 3, 299, 299])
epoch:  2
Step [60/120], Loss: 3.70966291.
Running Loss= 44.51595497131348
Acc: 1.0000.
torch.Size([12, 3, 299, 299])
epoch:  2
Step [72/120], Loss: 1.97249496.
Running Loss= 23.669939517974854
Acc: 0.9583.
torch.Size([12, 3, 299, 299])
epoch:  3
Step [84/120], Loss: 1.12656462.
Running Loss= 13.518775463104248
Acc: 1.0000.
torch.Size([12, 3, 299, 299])
epoch:  3
Step [96/120], Loss: 0.53211695.
Running Loss= 6.385403394699097
Acc: 1.0000.
torch.Size([12, 3, 299, 299])
epoch:  4
Step [108/12