# Learning PyTorch for Multimodality

In [1]:
import torch

In [4]:
device = torch.device('cpu')

## Basic Model of Sequential Modules

Example from [here](https://github.com/jcjohnson/pytorch-examples#pytorch-optim)

In [52]:
N, D_in, Z_in, H, H2, D_out = 64, 400, 20, 300, 200, 10

x = torch.randn(N, D_in, device=device)
z = torch.randn(N, Z_in, device=device)
y = torch.randn(N, D_out, device=device)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
).to(device)

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(50):
    y_pred = model(x)
    
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    optimizer.zero_grad()
    
    loss.backward()
    
    optimizer.step()

0 707.7402954101562
1 694.0547485351562
2 680.6409301757812
3 667.4664306640625
4 654.5130615234375
5 641.800537109375
6 629.3379516601562
7 617.1375732421875
8 605.156494140625
9 593.393798828125
10 581.8489990234375
11 570.5316162109375
12 559.4170532226562
13 548.4935302734375
14 537.7697143554688
15 527.2457885742188
16 516.91748046875
17 506.77734375
18 496.7988586425781
19 486.9962158203125
20 477.3673095703125
21 467.897216796875
22 458.6047058105469
23 449.4712219238281
24 440.4930114746094
25 431.6700744628906
26 422.9800720214844
27 414.42974853515625
28 406.0163879394531
29 397.7267761230469
30 389.5657653808594
31 381.53240966796875
32 373.6302185058594
33 365.85687255859375
34 358.19830322265625
35 350.66436767578125
36 343.2438659667969
37 335.9500427246094
38 328.7704772949219
39 321.715576171875
40 314.7672119140625
41 307.9325866699219
42 301.1885681152344
43 294.5320129394531
44 287.96746826171875
45 281.50494384765625
46 275.1405944824219
47 268.890869140625
48 262.7

In [53]:
model.state_dict().keys()

odict_keys(['0.weight', '0.bias', '2.weight', '2.bias'])

## Expanding to "Multi-Modal" with Custom Class

So rather than stringing together modules like above, we have to create a custom class for our model. Mainly, this is because we are concatenating multiple inputs.

This still hasn't been tested on real data, and I don't know all the "gotchas" of these modules yet, but this works as expected thus far, and creates weight matrices as expected.

In [43]:
class Multimodal(torch.nn.Module):
    def __init__(self, D_in, H, Z_in, H2, D_out):
        super(Multimodal, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.multi = torch.nn.Linear(H + Z_in, H2)
        self.linear2 = torch.nn.Linear(H2, D_out)
        
    def forward(self, x, z):
        h1 = self.linear1(x)
        h1_z = torch.cat([h1, z], dim=1)
        h2 = self.multi(h1_z)
        out = self.linear2(h2)
        return out

In [55]:
N, D_in, Z_in, H, H2, D_out = 64, 400, 20, 300, 200, 10

x = torch.randn(N, D_in, device=device)
z = torch.randn(N, Z_in, device=device)
y = torch.randn(N, D_out, device=device)

model = Multimodal(D_in, H, Z_in, H2, D_out)

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(50):
    y_pred = model(x, z)
    
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    optimizer.zero_grad()
    
    loss.backward()
    
    optimizer.step()

0 686.3792114257812
1 669.9608764648438
2 653.8302612304688
3 637.9819946289062
4 622.4102783203125
5 607.1083984375
6 592.0684814453125
7 577.2822265625
8 562.7410278320312
9 548.4363403320312
10 534.3598022460938
11 520.5030517578125
12 506.8583068847656
13 493.4178466796875
14 480.1744384765625
15 467.1211853027344
16 454.2517395019531
17 441.5601806640625
18 429.0411682128906
19 416.68994140625
20 404.5023193359375
21 392.4747314453125
22 380.604248046875
23 368.8885498046875
24 357.3259582519531
25 345.9154357910156
26 334.65643310546875
27 323.54913330078125
28 312.59417724609375
29 301.7928161621094
30 291.1467590332031
31 280.6581726074219
32 270.3296813964844
33 260.16436767578125
34 250.16555786132812
35 240.3369598388672
36 230.6824951171875
37 221.20635986328125
38 211.912841796875
39 202.80638122558594
40 193.8914337158203
41 185.17250061035156
42 176.65402221679688
43 168.3402862548828
44 160.23550415039062
45 152.34361267089844
46 144.66835021972656
47 137.21310424804688

For example, the "multi" weights have input weights of length 320 because we take 300 from the output of `linear1` + 20 from the auxillary input (e.g. z) and then we map this to the output of 200

In [56]:
model.state_dict().keys()

odict_keys(['linear1.weight', 'linear1.bias', 'multi.weight', 'multi.bias', 'linear2.weight', 'linear2.bias'])

In [57]:
model.state_dict().get('multi.weight').shape

torch.Size([200, 320])

In [59]:
print(model)

Multimodal(
  (linear1): Linear(in_features=400, out_features=300, bias=True)
  (multi): Linear(in_features=320, out_features=200, bias=True)
  (linear2): Linear(in_features=200, out_features=10, bias=True)
)
