In [3]:
import torch
import torch.nn.functional as F
import torch.nn as nn

In [19]:
class mask_learning(nn.Module):
    """Learn a feature mask as the union of two Gumbel-softmax one-hot samples.

    Two categorical distributions over the input features are parameterised by
    ``logits1`` and ``logits2``. Each forward pass draws one hard
    (straight-through) one-hot sample from each, takes their element-wise union
    (so one or two features are kept), multiplies the input by that mask, and
    feeds the result to a linear head.

    Args:
        n_features: number of input features (length of each logits vector).
        n_output: number of outputs of the linear head.
    """

    def __init__(self, n_features=3, n_output=2):
        super().__init__()
        self.logits1 = nn.Parameter(torch.randn(n_features))
        self.logits2 = nn.Parameter(torch.randn(n_features))
        self.lin1 = nn.Linear(n_features, n_output)

    def forward(self, x, temp=1.0):
        """Mask ``x`` with a sampled union of two one-hot vectors and apply ``lin1``.

        Args:
            x: tensor whose last dimension has size ``n_features``; the mask
                broadcasts over any leading dimensions.
            temp: Gumbel-softmax temperature ``tau`` (defaults to 1.0).

        Returns:
            ``(y_pred, [logits1, logits2])``.
        """
        mask1 = F.gumbel_softmax(self.logits1, tau=temp, hard=True)
        mask2 = F.gumbel_softmax(self.logits2, tau=temp, hard=True)
        # Union of the two one-hot samples. When both pick the same feature the
        # sum is 2; the clamp keeps that entry at 1.
        # NOTE(review): the clamp also zeroes the gradient for that entry.
        mask = torch.clamp(mask1 + mask2, min=0, max=1)
        y_pred = self.lin1(mask * x)

        return y_pred, [self.logits1, self.logits2]

    def softmax(self):
        """Return the selection probabilities of the two categorical distributions."""
        # Pass dim explicitly: calling F.softmax without dim is deprecated and warns.
        return F.softmax(self.logits1, dim=-1), F.softmax(self.logits2, dim=-1)


In [13]:
# Synthetic task: the target is input features 1 and 2, so an ideal mask keeps
# features {1, 2} and drops feature 0.
SEED = 0
N_SAMPLES = 1000
# Seed once so the data and, when cells run in order, later random draws
# (model init, Gumbel noise) are reproducible.
torch.manual_seed(SEED)
x = torch.randn(N_SAMPLES, 3)
y = x[:, 1:]

In [16]:
# Fixed-temperature baseline. The temperature is passed to each forward call;
# mask_learning.__init__ does not accept a `temp` argument.
model = mask_learning()
print("initial distribution")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)
optimizer = torch.optim.Adam(model.parameters())

loss_fn = nn.MSELoss()


initial distribution
logits1 tensor([ 0.1910, -0.5794, -1.2352])
logits2 tensor([-1.0291, -1.2941,  2.6254])
lin1.weight tensor([[-0.3464,  0.3120, -0.4000],
        [ 0.0294, -0.3264, -0.0787]])
lin1.bias tensor([-0.5049, -0.4018])


In [17]:
# Fixed temperature (tau = 1) with MSE loss (loss_fn is nn.MSELoss): the
# learned distributions converge to the target features.
FIXED_TEMP = 1.0
for epoch in range(10001):
    optimizer.zero_grad()
    # forward requires a temperature; the returned logits are not needed here.
    y_pred, _ = model(x, FIXED_TEMP)
    loss = loss_fn(y_pred, y)
    if epoch % 1000 == 0:
        with torch.no_grad():
            print(model.softmax())
            print(float(loss))
    loss.backward()
    optimizer.step()
    

  return F.softmax(self.logits1), F.softmax(self.logits2)


(tensor([0.5872, 0.2717, 0.1411]), tensor([0.0247, 0.0190, 0.9563]))
1.332318902015686
(tensor([0.2416, 0.6363, 0.1220]), tensor([0.0148, 0.0291, 0.9560]))
0.06507055461406708
(tensor([0.1183, 0.7964, 0.0853]), tensor([0.0106, 0.0310, 0.9584]))
0.45338794589042664
(tensor([0.0710, 0.8758, 0.0532]), tensor([0.0090, 0.0297, 0.9614]))
4.3489377276273444e-05
(tensor([0.0483, 0.9126, 0.0391]), tensor([0.0076, 0.0278, 0.9646]))
1.7782284089662426e-07
(tensor([0.0348, 0.9362, 0.0291]), tensor([0.0067, 0.0244, 0.9690]))
6.974802545300918e-06
(tensor([0.0256, 0.9500, 0.0243]), tensor([0.0055, 0.0201, 0.9745]))
1.86763300007442e-05
(tensor([0.0197, 0.9607, 0.0196]), tensor([0.0047, 0.0177, 0.9776]))
9.355631647167684e-08
(tensor([0.0161, 0.9673, 0.0166]), tensor([0.0039, 0.0140, 0.9820]))
7.748096322757192e-06
(tensor([0.0127, 0.9737, 0.0136]), tensor([0.0033, 0.0117, 0.9850]))
7.504349014197942e-06
(tensor([0.0104, 0.9779, 0.0117]), tensor([0.0031, 0.0105, 0.9864]))
9.683490134193562e-07


In [21]:
# Temperature annealing: tau decays linearly from `temp` towards ~0 over
# `max_epoch` epochs.

model2 = mask_learning()
print("initial distribution")
for name, param in model2.named_parameters():
    if param.requires_grad:
        print(name, param.data)
optimizer2 = torch.optim.Adam(model2.parameters())

temp = 10
max_epoch = 10000

def temp_schedule(epoch, start_temp=10.0, total_epochs=None, min_temp=1e-5):
    """Linearly annealed Gumbel-softmax temperature.

    Returns ``start_temp * (1 - epoch / total_epochs) + min_temp``. This equals
    ``start_temp + min_temp`` at epoch 0 and ``min_temp`` at
    ``epoch == total_epochs``. The ``min_temp`` floor keeps tau > 0 at that
    endpoint.

    Args:
        epoch: current epoch index.
        start_temp: temperature scale at epoch 0.
        total_epochs: schedule length. When None, the notebook-level
            ``max_epoch`` is read at call time, so ``temp_schedule(epoch)``
            keeps working.
        min_temp: floor added to every value.
    """
    if total_epochs is None:
        total_epochs = max_epoch
    return start_temp * (1 - epoch / total_epochs) + min_temp

for epoch in range(max_epoch):
    optimizer2.zero_grad()
    tau = temp_schedule(epoch, start_temp=temp, total_epochs=max_epoch)
    y_pred, _ = model2(x, tau)
    loss = loss_fn(y_pred, y)
    if epoch % 1000 == 0:
        with torch.no_grad():
            print(model2.softmax())
            print(float(loss))
    loss.backward()
    optimizer2.step()
    

initial distribution
logits1 tensor([0.5194, 0.3145, 0.9577])
logits2 tensor([ 1.0598,  1.1816, -0.1711])
lin1.weight tensor([[-0.1794, -0.1443, -0.1944],
        [ 0.2465, -0.1393,  0.3290]])
lin1.bias tensor([-0.4425,  0.1930])
(tensor([0.2972, 0.2421, 0.4607]), tensor([0.4130, 0.4664, 0.1206]))
1.2145395278930664


  return F.softmax(self.logits1), F.softmax(self.logits2)


(tensor([0.0897, 0.1742, 0.7361]), tensor([0.1811, 0.5337, 0.2852]))
0.12098657339811325
(tensor([0.0320, 0.2499, 0.7181]), tensor([0.0543, 0.7009, 0.2448]))
0.005308471620082855
(tensor([0.0133, 0.2443, 0.7424]), tensor([0.0229, 0.7195, 0.2577]))
0.00020985625451430678
(tensor([0.0061, 0.2397, 0.7542]), tensor([0.0104, 0.7277, 0.2620]))
0.00015841623826418072
(tensor([0.0030, 0.2373, 0.7597]), tensor([0.0050, 0.7383, 0.2567]))
0.0001733555836835876
(tensor([0.0017, 0.2462, 0.7522]), tensor([0.0026, 0.7602, 0.2372]))
0.00017947357264347374
(tensor([0.0009, 0.2253, 0.7737]), tensor([0.0015, 0.7550, 0.2435]))
0.5191536545753479
(tensor([6.3644e-04, 2.2100e-01, 7.7836e-01]), tensor([0.0010, 0.7833, 0.2157]))
5.352523294277489e-05
(tensor([5.0899e-04, 1.8282e-01, 8.1667e-01]), tensor([7.8806e-04, 7.9873e-01, 2.0048e-01]))
5.673144187312573e-05


In [47]:
# Every input feature gets its own gate, so the number of gates equals the input width.
class mask_learning_case_2(nn.Module):
    """Per-feature binary gate learned with a two-way Gumbel-softmax.

    Each of the ``n_mask`` features owns a row of two logits. Column 0 scores
    "keep the feature" and column 1 scores "drop it". One hard Gumbel-softmax
    draw per row yields a 0/1 gate for every feature, and the gated input is
    fed to a linear head with ``n_output`` outputs.
    """

    def __init__(self, n_mask = 3, n_output = 2):
        super().__init__()
        self.n_mask, self.n_output = n_mask, n_output
        self.logits = nn.Parameter(torch.randn(n_mask, 2))
        self.lin1 = nn.Linear(n_mask, n_output)

    def forward(self, x, temp):
        """Gate ``x`` feature-wise with a sampled 0/1 mask and apply ``lin1``.

        Returns ``(y_pred, [self.logits])``.
        """
        # One hard (straight-through) one-hot sample per row; column 0 means "selected".
        keep = F.gumbel_softmax(self.logits, tau=temp, hard=True)[:, 0]
        gated = x * keep
        return self.lin1(gated), [self.logits]

    def softmax(self):
        """Row-wise keep/drop probabilities, shape (n_mask, 2)."""
        return self.logits.softmax(dim=1)


In [48]:
# Temperature annealing with per-feature keep/drop gates (mask_learning_case_2).

model3 = mask_learning_case_2()
print("initial distribution")
for name, param in model3.named_parameters():
    if param.requires_grad:
        print(name, param.data)
optimizer3 = torch.optim.Adam(model3.parameters())

# NOTE(review): `temp` is not read below; temp_schedule(epoch) starts from 10
# on its own. `max_epoch` IS read, both by the loop and (as a global) by
# temp_schedule.
temp = 10
max_epoch = 10000


for epoch in range(max_epoch):
    optimizer3.zero_grad()
    # The returned logits are not needed here.
    y_pred, _ = model3(x, temp_schedule(epoch))
    loss = loss_fn(y_pred, y)
    if epoch % 1000 == 0:
        with torch.no_grad():
            print(model3.softmax())
            print(float(loss))
    loss.backward()
    optimizer3.step()
    

initial distribution
logits tensor([[-0.3636, -0.8735],
        [-0.3875,  0.7121],
        [-1.0806, -0.2381]])
lin1.weight tensor([[-0.3462, -0.2087, -0.2830],
        [-0.0008,  0.1352,  0.3754]])
lin1.bias tensor([0.3326, 0.5095])
tensor([[0.6248, 0.3752],
        [0.2498, 0.7502],
        [0.3010, 0.6990]])
1.2047051191329956
tensor([[0.5403, 0.4597],
        [0.2996, 0.7004],
        [0.7321, 0.2679]])
0.2965875566005707
tensor([[0.5423, 0.4577],
        [0.8787, 0.1213],
        [0.8682, 0.1318]])
0.012684086337685585
tensor([[0.5424, 0.4576],
        [0.9431, 0.0569],
        [0.9167, 0.0833]])
1.99364385480294e-05
tensor([[0.5424, 0.4576],
        [0.9606, 0.0394],
        [0.9483, 0.0517]])
2.014121309912298e-05
tensor([[0.5422, 0.4578],
        [0.9722, 0.0278],
        [0.9654, 0.0346]])
2.79874602711061e-05
tensor([[0.5415, 0.4585],
        [0.9769, 0.0231],
        [0.9741, 0.0259]])
2.3776821933552128e-07
tensor([[0.5403, 0.4597],
        [0.9819, 0.0181],
        [0.981

In [49]:
# Inspect the learned linear head. If the gates kept features 1 and 2, the rows
# should be close to one-hot, reproducing the target y = x[:, 1:].
model3.lin1.weight

Parameter containing:
tensor([[6.5783e-07, 1.0000e+00, 1.4398e-07],
        [2.8003e-09, 2.0142e-09, 1.0000e+00]], requires_grad=True)