In [46]:

import torch
import torch.nn as nn
import torch.nn.functional as F

n_attrs = 8

num = [6, 9, 8, 6, 5, 5, 10, 5]



In [44]:
class MultiTaskMLP(nn.Module):
    def __init__(self, feature_size, attr_nums):
        super(MultiTaskMLP, self).__init__()
        self.feature_size = feature_size
        self.attr_nums = attr_nums

        # Shared layers
        self.shared_fc1 = nn.Linear(self.feature_size, 512)
        self.shared_fc2 = nn.Linear(512, 256)

        # Task-specific output layers
        self.output_layers = nn.ModuleList([nn.Linear(256, val + 1) for val in self.attr_nums])

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, a):
        # Shared layers
        x = F.relu(self.shared_fc1(x))
        x = F.relu(self.shared_fc2(x))

        # Task-specific output
        output = []
        for i in range(x.size(0)):
            task_out = F.softmax(self.output_layers[a[i]](x[i]), dim=0)
            output.append(task_out)
        # output = torch.stack(output)
        return output
    

class ValueCrossEntropy(nn.Module):
    def __init__(self):
        super(ValueCrossEntropy, self).__init__()
    
    def forward(self, y_pred, value):
        loss = 0.
        
        for _o, _v in zip(y_pred, value):
            _loss = F.cross_entropy(_o.unsqueeze(0), torch.tensor([_v], dtype=torch.long))
            print(_loss)
            loss += _loss

        loss /= len(value)

        return loss


mmlp = MultiTaskMLP(feature_size=1024, attr_nums=num)
vce = ValueCrossEntropy()

features = torch.randn(4, 1024)
a = [3, 2, 4, 2]
v = [1, 2, 0, 1]

out = mmlp(features, a)

out


[tensor([0.1481, 0.1437, 0.1637, 0.1235, 0.1505, 0.1375, 0.1331],
        grad_fn=<SoftmaxBackward0>),
 tensor([0.1138, 0.1242, 0.1218, 0.0947, 0.1084, 0.1214, 0.1031, 0.1014, 0.1112],
        grad_fn=<SoftmaxBackward0>),
 tensor([0.1523, 0.1832, 0.1560, 0.2086, 0.1755, 0.1244],
        grad_fn=<SoftmaxBackward0>),
 tensor([0.1171, 0.1472, 0.1190, 0.0989, 0.1022, 0.1124, 0.0898, 0.1096, 0.1037],
        grad_fn=<SoftmaxBackward0>)]

In [45]:
loss = vce(out, v)
loss

tensor(1.9451, grad_fn=<NllLossBackward0>)
tensor(2.1866, grad_fn=<NllLossBackward0>)
tensor(1.8065, grad_fn=<NllLossBackward0>)
tensor(2.1612, grad_fn=<NllLossBackward0>)


tensor(2.0249, grad_fn=<DivBackward0>)

In [52]:
class ValueEmbedding(nn.Module):
    def __init__(self,attr_nums, embed_size):
        super(ValueEmbedding, self).__init__()
        self.attr_nums = attr_nums # [6, 9, 8, 6, 5, 5, 10, 5]
        self.n_attrs = len(self.attr_nums)

        self.value_embedding  = nn.ModuleList(
            [torch.nn.Embedding(value, embed_size) for value in self.attr_nums]
        )

    def forward(self, a, v):
        out = []
        for one_a, one_v in zip(a, v):
            print(one_a, one_v)
            one_out = self.value_embedding[one_a](torch.tensor(one_v, dtype=torch.long))
            out.append(one_out)
            print(one_out.shape)
        return torch.stack(out)
    
vEmbed = ValueEmbedding(num, 512)
a = [3, 2, 4, 2]
v = [1, 2, 0, 1]
out = vEmbed(a, v)
out.shape


3 1
torch.Size([512])
2 2
torch.Size([512])
4 0
torch.Size([512])
2 1
torch.Size([512])


torch.Size([4, 512])