In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn

In [13]:
class mask_learning_case(nn.Module):
    def __init__(self, n_mask=10, n_output=2, gene_cnt=550):
        super().__init__()
        self.n_mask = n_mask
        self.n_output = n_output
        self.gene_cnt = gene_cnt
        self.logits = nn.Parameter(torch.randn(self.n_mask, self.gene_cnt))
        self.lin1 = nn.Linear(self.n_mask, self.n_output)

    def forward(self, x, temp):
        mask = F.gumbel_softmax(self.logits, tau=temp, hard=True)  # samples 1 for each row
        mask = mask[:, 0]  # 0th idx is "selected" and 1th idx is "not selectd"
        # print(self.logits.shape)
        # print(mask.shape, x.shape)

        y_pred = self.lin1(
            mask.reshape(
                -1,
            )
            * x
        )

        return y_pred, [self.logits]

    def softmax(self):
        return F.softmax(self.logits, dim=1)

In [14]:
# in this first case, we use each gumbel_softmask to select a unique feature (# gumbel-softmax = number of features to choose). concrete distribution chooses 1 out of m (where m is total # of features)
# in the second cae, we use a gumbel_softmask for each feature (# gumbel-softmax - total # of features): yes vs no. concrete distribution chooses 1 out of 2
class mask_learning(nn.Module):
    def __init__(self):
        super().__init__()
        self.logits1 = nn.Parameter(torch.randn(3))
        self.logits2 = nn.Parameter(torch.randn(3))
        self.lin1 = nn.Linear(3, 2)

    def forward(self, x, temp):
        mask1 = F.gumbel_softmax(self.logits1, tau=temp, hard=True)
        mask2 = F.gumbel_softmax(self.logits2, tau=temp, hard=True)
        mask = (
            mask1 + mask2
        )  # masks are used to select 2 indices out of 3; we want them to select the first and second.
        mask = torch.clamp(
            mask, min=0, max=1
        )  # clamping is done so that if both masks pick the same element, 2 becomes 1
        y_pred = self.lin1(
            mask * x
        )  # elementwise multiplication, mask is broadcasted for all rows of x. self.lin1 converts 3 dim into 2 dim. We want lin1 to just select the positive elements from mask * x.
        # but lin1 could just learn a function to map an "incorrect" mask*x to the correct y labels (since task is not complex either)
        return y_pred, [self.logits1, self.logits2]

    def softmax(self):
        return F.softmax(self.logits1), F.softmax(self.logits2)

In [15]:
x = torch.randn(1000, 3)
y = x[:, 1:]

In [16]:
model = mask_learning()
print("initial distribution")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)
optimizer = torch.optim.Adam(model.parameters())

loss_fn = nn.MSELoss()

initial distribution
logits1 tensor([-1.0162, -1.0681,  0.0481])
logits2 tensor([ 1.1409,  0.1846, -0.3736])
lin1.weight tensor([[ 0.1511,  0.1861,  0.5604],
        [ 0.3632, -0.0440, -0.1327]])
lin1.bias tensor([-0.4298, -0.0884])


In [17]:
# Using MAE loss, distribution converges to target
for epoch in range(101):
    optimizer.zero_grad()
    y_pred, logits = model(x, 1)
    # loss = kl_loss(out, ground_truth)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()

In [18]:
# Temperature annealing

model2 = mask_learning()
print("initial distribution")
for name, param in model2.named_parameters():
    if param.requires_grad:
        print(name, param.data)
optimizer2 = torch.optim.Adam(model2.parameters())

temp = 10
max_epoch = 101


def temp_schedule(epoch):
    return 10 * (1 - epoch / max_epoch) + 1e-5


for epoch in range(max_epoch):
    optimizer2.zero_grad()
    y_pred, logits = model2(x, temp_schedule(epoch))
    # loss = kl_loss(out, ground_truth)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer2.step()

initial distribution
logits1 tensor([0.0624, 1.1732, 0.1909])
logits2 tensor([-0.4508,  0.3173, -1.1861])
lin1.weight tensor([[-0.3245,  0.4395, -0.5165],
        [-0.0170, -0.3961,  0.1580]])
lin1.bias tensor([-0.1179, -0.2119])


In [19]:
# n_input  == n_mask
f


class mask_learning_case_2(nn.Module):
    def __init__(self, n_mask=3, n_output=2):
        super().__init__()
        self.n_mask = n_mask
        self.n_output = n_output
        self.logits = nn.Parameter(torch.randn(self.n_mask, 2))
        self.lin1 = nn.Linear(self.n_mask, self.n_output)

    def forward(self, x, temp):
        mask = F.gumbel_softmax(self.logits, tau=temp, hard=True)  # samples 1 for each row
        mask = mask[:, 0]  # 0th idx is "selected" and 1th idx is "not selectd"
        # print(self.logits.shape)
        print(mask.shape, x.shape)

        y_pred = self.lin1(
            mask.reshape(
                -1,
            )
            * x
        )

        return y_pred, [self.logits]

    def softmax(self):
        return F.softmax(self.logits, dim=1)

In [20]:
# Temperature annealing

model3 = mask_learning_case_2()
print("initial distribution")
for name, param in model3.named_parameters():
    if param.requires_grad:
        print(name, param.data)
optimizer3 = torch.optim.Adam(model3.parameters())

temp = 10
max_epoch = 101


for epoch in range(max_epoch):
    optimizer3.zero_grad()
    y_pred, logits = model3(x, 0.1)
    # loss = kl_loss(out, ground_truth)
    loss = loss_fn(y_pred, y)
    # if epoch % 1000 == 0:
    #     with torch.no_grad():
    #         print(model3.softmax())
    #         print(float(loss))
    loss.backward()
    optimizer3.step()

initial distribution
logits tensor([[ 0.9418,  0.3991],
        [ 1.4392, -0.6753],
        [ 0.4566, -1.4735]])
lin1.weight tensor([[ 0.3486, -0.2031, -0.0548],
        [-0.0998,  0.4586, -0.2397]])
lin1.bias tensor([ 0.0207, -0.4857])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
torch.Size([3]) torch.Size([1000, 3])
tor