In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [129]:
# Default Gumbel-softmax temperature (tau) used by GumbelConnection.
# Smaller tau makes the relaxed samples closer to one-hot.
# NOTE(review): very small tau such as 0.01 is commonly associated with
# high-variance gradients through the relaxation — confirm this value is intended.
TEMPERATURE = 0.01

class Model_1(nn.Module):
    """Embedding -> Linear -> Linear network producing per-position outputs.

    With the default sizes it maps token ids in ``[0, 5)`` to 3 values per
    position. The forward pass accepts either integer ids or a weighting over
    the vocabulary (e.g. one-hot or Gumbel-softmax samples).

    Args:
        num_embeddings: Number of rows in the embedding table (default 5).
        embedding_dim: Width of each embedding vector (default 20).
        hidden_dim: Output width of ``layer_1`` (default 40).
        out_dim: Output width of ``layer_2`` (default 3).
    """

    def __init__(self, num_embeddings=5, embedding_dim=20, hidden_dim=40, out_dim=3):
        super().__init__()
        # Submodules are created in the original order so that, under a fixed
        # torch.manual_seed, parameter initialisation is unchanged.
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.layer_1 = nn.Linear(embedding_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, probabilistic=False):
        """Embed ``x`` and apply the two linear layers.

        Args:
            x: When ``probabilistic`` is False, an integer tensor of token ids.
                When True, a float tensor whose last dimension has size
                ``num_embeddings`` (weights over the embedding rows).
            probabilistic: Selects which of the two interpretations of ``x``
                is used.

        Returns:
            Tensor whose last dimension has size ``out_dim``; leading
            dimensions follow the token positions of ``x``.
        """
        if probabilistic:
            # Weighted sum of embedding rows: equals the nn.Embedding lookup for
            # an exactly one-hot x, and stays differentiable with respect to x.
            embeddings = torch.matmul(x, self.embedding.weight)
        else:
            embeddings = self.embedding(x)

        # NOTE(review): there is no activation between layer_1 and layer_2, so
        # the two layers compose to a single affine map.
        return self.layer_2(self.layer_1(embeddings))
            
class Model_2(nn.Module):
    """Embedding -> Linear -> Linear network producing per-position outputs.

    With the default sizes it maps ids in ``[0, 3)`` to a single value per
    position. The forward pass accepts either integer ids or a weighting over
    the 3 embedding rows (e.g. one-hot or Gumbel-softmax samples).

    Args:
        num_embeddings: Number of rows in the embedding table (default 3).
        embedding_dim: Width of each embedding vector (default 10).
        hidden_dim: Output width of ``layer_1`` (default 20).
        out_dim: Output width of ``layer_2`` (default 1).
    """

    def __init__(self, num_embeddings=3, embedding_dim=10, hidden_dim=20, out_dim=1):
        super().__init__()
        # Submodules are created in the original order so that, under a fixed
        # torch.manual_seed, parameter initialisation is unchanged.
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.layer_1 = nn.Linear(embedding_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, probabilistic=False):
        """Embed ``x`` and apply the two linear layers.

        Args:
            x: When ``probabilistic`` is False, an integer tensor of ids.
                When True, a float tensor whose last dimension has size
                ``num_embeddings`` (weights over the embedding rows).
            probabilistic: Selects which of the two interpretations of ``x``
                is used.

        Returns:
            Tensor whose last dimension has size ``out_dim``; leading
            dimensions follow the positions of ``x``.
        """
        if probabilistic:
            # Weighted sum of embedding rows: equals the nn.Embedding lookup for
            # an exactly one-hot x, and stays differentiable with respect to x.
            embeddings = torch.matmul(x, self.embedding.weight)
        else:
            embeddings = self.embedding(x)

        # NOTE(review): there is no activation between layer_1 and layer_2, so
        # the two layers compose to a single affine map.
        return self.layer_2(self.layer_1(embeddings))

class GumbelConnection(nn.Module):
    """Chain two models through a differentiable Gumbel-softmax sample.

    ``model_1`` produces logits; a relaxed categorical sample is drawn over the
    last dimension of those logits and passed to ``model_2`` with
    ``probabilistic=True``, so gradients can flow back into ``model_1``.

    Args:
        model_1: Module mapping the input to logits, categories on the last dim.
        model_2: Module whose forward accepts ``probabilistic=True`` weights
            over those categories.
        temperature: Gumbel-softmax ``tau``; must be positive. Smaller values
            give samples closer to one-hot.
        hard: If True, use the straight-through variant of gumbel_softmax
            (one-hot forward values, soft gradients). Defaults to False, which
            matches the previous behaviour.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """

    def __init__(self, model_1, model_2, temperature=TEMPERATURE, hard=False):
        super().__init__()
        if temperature <= 0:
            # tau is a divisor inside gumbel_softmax; zero or negative values
            # produce inf/NaN or inverted samples instead of an error.
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.model_1 = model_1
        self.model_2 = model_2
        self.temperature = temperature
        self.hard = hard

    def forward(self, x):
        """Run ``model_1``, sample relaxed categories, and feed them to ``model_2``."""
        logits = self.model_1(x)
        # dim=-1 is the same axis as the previous dim=2 for rank-3 logits, but
        # also works for logits of any other rank.
        gumbel_probs = F.gumbel_softmax(
            logits, tau=self.temperature, hard=self.hard, dim=-1
        )
        return self.model_2(gumbel_probs, probabilistic=True)

In [130]:
import torch

# Seed torch's global RNG for reproducibility. This fixes the batch below and,
# when cells run in order, the parameter initialisation of models built in
# later cells; Gumbel noise drawn in forward passes also advances this stream.
SEED = 42
NUM_TOKENS = 5  # must match Model_1's embedding vocabulary size (5)
BATCH_SIZE, SEQ_LEN = 10, 10

torch.manual_seed(SEED)

# Batch of token ids in [0, NUM_TOKENS), shape (BATCH_SIZE, SEQ_LEN).
samples = torch.randint(0, NUM_TOKENS, (BATCH_SIZE, SEQ_LEN))


In [131]:
# Build both sub-networks (Model_1 before Model_2, so seeded initialisation is
# consumed in the same order) and wire them through the Gumbel connection.
model_1, model_2 = Model_1(), Model_2()
model = GumbelConnection(model_1=model_1, model_2=model_2)

In [141]:
# gumbel_softmax draws fresh noise on every call, so a bare `model(samples)`
# shows different values each time this cell is re-run. Fork the RNG and seed
# it locally so the displayed output is reproducible without disturbing the
# global RNG stream used by the other cells.
with torch.random.fork_rng():
    torch.manual_seed(0)
    gumbel_output = model(samples)
gumbel_output

tensor([[[-0.0817],
         [ 0.6894],
         [-0.2650],
         [-0.2650],
         [ 0.6887],
         [-0.2650],
         [-0.2650],
         [ 0.6894],
         [-0.0818],
         [ 0.6894]],

        [[ 0.6894],
         [-0.0817],
         [-0.2650],
         [-0.0817],
         [-0.0817],
         [-0.2650],
         [-0.0817],
         [-0.2650],
         [-0.0817],
         [ 0.6894]],

        [[ 0.6894],
         [-0.2650],
         [-0.2650],
         [-0.2650],
         [ 0.6856],
         [-0.2650],
         [ 0.6894],
         [-0.2650],
         [-0.2650],
         [-0.2650]],

        [[ 0.6894],
         [-0.0817],
         [-0.0817],
         [-0.0817],
         [-0.2650],
         [-0.0816],
         [-0.2650],
         [ 0.6894],
         [-0.0817],
         [ 0.6894]],

        [[ 0.6894],
         [-0.0817],
         [-0.0817],
         [ 0.6894],
         [-0.2650],
         [-0.2650],
         [-0.0817],
         [ 0.6894],
         [ 0.6894],
         [-0