In [11]:
import torch

In [142]:
class CrossStitchBlock(torch.nn.Module):
    """
    """
    def __init__(self,
                 num_tasks: int = 2,
                 num_subspaces: int = 1):
        super(CrossStitchBlock, self).__init__()
        
        self.num_tasks = num_tasks
        self.num_subspaces = num_subspaces
        
        # initialize using random uniform distribution as suggested in Section 5.1
        self.cross_stitch_kernel = torch.FloatTensor(
                                                self.num_tasks * self.num_subspaces,
                                                self.num_tasks * self.num_subspaces).uniform_(
            0.0, 1.0
        )
        # normalize, so that each row will be convex linear combination
        normalizer = torch.sum(self.cross_stitch_kernel, 0, keepdim=True)
        self.cross_stitch_kernel = self.cross_stitch_kernel / normalizer
    

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        print("inputs", inputs)
        print()        

        # concatenate nth element of each input task
        x = torch.cat([torch.unsqueeze(i, -1) for i in inputs], -1)
        print("x", x, x.size())
        print()
        
        print(f"the cross stitch kernel after normalization: \n {self.cross_stitch_kernel}")
        print()

        # multiply every element of the input with the cross-stitch kernel
        stitched_output = torch.matmul(x, self.cross_stitch_kernel)
        print(f"stitched_output:\n{stitched_output} {stitched_output.size()}")

        n, batch, seq, d = x.size()
        # index need to be of the same dimension size as source tensor
        index = torch.zeros(n, batch, seq, 1, dtype=torch.int64)

        # split result into tensors corresponding to specific tasks and return
        # to task-specific lists
        outputs = [
            torch.flatten(torch.gather(e, dim=-1, index=index), start_dim=2)
            for e in torch.split(stitched_output, 1, -1)
        ]

        assert len(outputs) == self.num_tasks
        return outputs


In [144]:
cross_stitch = CrossStitchBlock()

inputs = [torch.randn(1, 2, 2) for _ in range(2)]
output = cross_stitch(inputs)


for i, task_i in enumerate(["A", "B"]):
    for j, task_j in enumerate(["A", "B"]):
        print(f"{task_i} {task_j} {cross_stitch.cross_stitch_kernel[i][j]}")

print(cross_stitch.cross_stitch_kernel)

inputs [tensor([[[-0.3865, -0.4061],
         [-0.7158,  0.4503]]]), tensor([[[-0.5347,  0.0838],
         [-0.6797, -0.9894]]])]

x tensor([[[[-0.3865, -0.5347],
          [-0.4061,  0.0838]],

         [[-0.7158, -0.6797],
          [ 0.4503, -0.9894]]]]) torch.Size([1, 2, 2, 2])

the cross stitch kernel after normalization: 
 tensor([[0.5233, 0.6085],
        [0.4767, 0.3915]])

stitched_output:
tensor([[[[-0.4572, -0.4445],
          [-0.1726, -0.2143]],

         [[-0.6986, -0.7017],
          [-0.2359, -0.1134]]]]) torch.Size([1, 2, 2, 2])
A A 0.5233339667320251
A B 0.608464241027832
B A 0.47666603326797485
B B 0.39153578877449036
tensor([[0.5233, 0.6085],
        [0.4767, 0.3915]])


In [54]:
cross_stitch_kernel = torch.FloatTensor(
                                        2 * 2,
                                        2 * 2).uniform_(
    0.0, 1.0
)
print(cross_stitch_kernel)

tensor([[0.3007, 0.3849, 0.8875, 0.9102],
        [0.0386, 0.0050, 0.3250, 0.7348],
        [0.8492, 0.1911, 0.7196, 0.6900],
        [0.5646, 0.5089, 0.7480, 0.5455]])


In [97]:
inputs = [torch.randn(2, 2, 4) for _ in range(2)]
print(inputs)

[tensor([[[ 0.3692,  0.3376, -0.1420,  0.1280],
         [ 0.7821,  1.6744,  1.9704, -1.8366]],

        [[-0.1301,  1.4170, -0.1446, -0.6735],
         [ 0.3292,  0.7654, -0.4625, -0.0124]]]), tensor([[[ 1.8255, -0.0959, -0.1649,  0.3290],
         [-0.1835,  0.4722,  0.4786,  1.2735]],

        [[ 0.3063, -1.4406,  0.7297, -1.0241],
         [ 0.6807, -0.9482, -0.3396, -0.2354]]])]


In [99]:
o = []

for i in inputs:
    s1, s2 = i.split(2, dim= -1) # split in half on last dim
    x = torch.cat([s1, s2], 0)
    o.append(x)