In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from collections import OrderedDict
import torch.nn.functional as F
from torch.optim import Adam
import time
from tqdm import tqdm

In [3]:
class MyNet(nn.Module):
    """Tiny two-conv backbone followed by a single-logit sigmoid classifier.

    ``features`` holds two bias-free 2x2 convolutions (1 -> 2 -> 2 channels,
    stride 1). ``classifier`` maps the flattened features to one logit.
    The classifier takes exactly 2 flattened features per sample, so the conv
    stack must reduce the spatial size to 1x1 (true for 3x3 inputs).
    """

    def __init__(self):
        super().__init__()
        # Same layer names (and creation order) as before, so state_dict keys
        # features.conv1.weight / features.conv2.weight / classifier.weight hold.
        self.features = nn.Sequential(OrderedDict([
            ("conv1", nn.Conv2d(1, 2, kernel_size=2, stride=1, bias=False)),
            ("conv2", nn.Conv2d(2, 2, kernel_size=2, stride=1, bias=False)),
        ]))
        self.classifier = nn.Linear(2, 1, bias=False)

    def forward(self, inp):
        """Return sigmoid probabilities of shape (batch, 1).

        Args:
            inp: image batch of shape (batch, 1, H, W). With H = W = 3 the
                conv stack yields (batch, 2, 1, 1).
        """
        features = self.features(inp)
        # Flatten per sample. The old view(1, -1) merged the whole batch into
        # a single row and only worked for batch size 1.
        features = features.flatten(1)
        logits = self.classifier(features)
        # torch.sigmoid replaces the deprecated F.sigmoid (same result).
        return torch.sigmoid(logits)

In [4]:
# Seed so weight init and the random toy data are reproducible, and fall back
# to CPU when no GPU is available instead of failing on .cuda().
SEED = 0
torch.manual_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
my_net = MyNet().to(DEVICE)

In [5]:
class ExtractFeatures(nn.Module):
    """Identity layer that records a global-average-pooled copy of its input.

    ``forward`` returns ``inp`` unchanged. As a side effect it stores the
    per-channel mean over the two spatial axes in ``self.pooled_features``
    (shape (N, C, 1, 1)). A parent module can then read the value back after
    running an ``nn.Sequential`` that contains this layer.
    """

    def forward(self, inp):
        # Pool over the full (height, width) extent -> one value per channel.
        height, width = inp.size(2), inp.size(3)
        self.pooled_features = F.avg_pool2d(inp, (height, width))
        return inp
    
class MyNet2(nn.Module):
    """Reuses the layers of ``net`` and predicts attribute scores from their pooled outputs.

    Every module found inside an ``nn.Sequential`` child of ``net`` is added
    (by reference, so weights are shared with ``net``) to ``self.features``.
    Each one is followed by an ``ExtractFeatures`` tap. ``forward`` runs that
    stack, concatenates the taps' pooled features along the channel axis, and
    maps them through a bias-free linear layer.

    Args:
        net: source model; only its ``nn.Sequential`` children are used.
        in_features: total number of pooled channels across all taps. The
            default 4 matches ``MyNet`` (two convs with 2 output channels each).
        num_attributes: number of output scores.
    """

    def __init__(self, net, in_features=4, num_attributes=100000):
        super().__init__()
        self.features = nn.Sequential(OrderedDict())
        # Plain list, deliberately not registered. The taps are already
        # submodules of self.features; this only keeps handles to read them back.
        self.all_pooled_features = []
        self.attribute_weights = nn.Linear(in_features, num_attributes, bias=False)
        count = 0
        for child in net.children():
            if not isinstance(child, nn.Sequential):
                continue
            for mod in child:
                self.features.add_module(f"conv_{count}", mod)
                tap = ExtractFeatures()
                self.features.add_module(f"global_pooled_features_{count}", tap)
                self.all_pooled_features.append(tap)
                count += 1

    def forward(self, inp):
        """Return attribute scores of shape (batch, num_attributes)."""
        # The stack's output is unused; the taps record the pooled features.
        self.features(inp)
        pooled = [tap.pooled_features for tap in self.all_pooled_features]
        # (batch, C_total, 1, 1) -> (batch, C_total). flatten(1) keeps the
        # batch axis. The former squeeze().unsqueeze(0) pair only produced the
        # right shape for batch size 1.
        pooled = torch.cat(pooled, dim=1).flatten(1)
        return self.attribute_weights(pooled)

In [6]:
# Build the attribute model on top of my_net's layers (weights are shared) and
# move the new linear head to whatever device the backbone already lives on.
my_net2 = MyNet2(my_net).to(next(my_net.parameters()).device)

In [7]:
# Only the attribute head is optimised; the shared conv layers are never stepped.
optimizer = Adam(my_net2.attribute_weights.parameters())
# Element-wise (unreduced) squared error. reduction="none" replaces the
# deprecated reduce=False and gives the same result.
criterion = nn.MSELoss(reduction="none")

In [9]:
# Toy data, generated and then moved to the model's device: one single-channel
# 3x3 image, plus a random target and per-element loss weights with one value
# per attribute score.
device = next(my_net2.parameters()).device
num_attributes = my_net2.attribute_weights.out_features
inp = torch.rand(1, 1, 3, 3).to(device)
target = torch.rand(1, num_attributes).to(device)
weights = torch.rand(1, num_attributes).to(device)

# NOTE(review): artificial per-step pause (1000 steps ~ 8 min). Keep only if
# the delay is intentional; set to 0 for actual training.
STEP_DELAY_S = 0.5

for _ in tqdm(range(1000)):
    # Clear gradients from the previous step; without this they accumulate.
    # The optimizer only owns the head, so zero the whole model to also reset
    # the (never-stepped) backbone gradients.
    my_net2.zero_grad()
    output = my_net2(inp)
    loss = criterion(output, target)          # element-wise, (1, num_attributes)
    loss = loss.mean(dim=0, keepdim=True)     # average over the batch axis
    # Back-propagate sum(weights * loss); same as torch.autograd.backward(loss, weights).
    loss.backward(weights)
    optimizer.step()
    time.sleep(STEP_DELAY_S)

  2%|▏         | 24/1000 [00:12<08:12,  1.98it/s]

KeyboardInterrupt: 