# Learning PyTorch for Multimodality

In [50]:
import torch

In [51]:
device = torch.device('cpu')

## Basic Model of Sequential Modules

Example from [here](https://github.com/jcjohnson/pytorch-examples#pytorch-optim)

In [27]:
N, D_in, Z_in, H, H2, D_out = 64, 400, 20, 300, 200, 10

x = torch.randn(N, D_in, device=device)
z = torch.randn(N, Z_in, device=device)
y = torch.randn(N, D_out, device=device)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
).to(device)

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(50):
    y_pred = model(x)
    
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    optimizer.zero_grad()
    
    loss.backward()
    
    optimizer.step()

0 640.2095947265625
1 627.4595336914062
2 614.9490966796875
3 602.6520385742188
4 590.566650390625
5 578.705078125
6 567.0449829101562
7 555.6082153320312
8 544.3845825195312
9 533.3677368164062
10 522.5365600585938
11 511.8955078125
12 501.4302062988281
13 491.1579284667969
14 481.06689453125
15 471.17181396484375
16 461.476806640625
17 451.9712219238281
18 442.61285400390625
19 433.43414306640625
20 424.41656494140625
21 415.5590515136719
22 406.8478088378906
23 398.2962341308594
24 389.88116455078125
25 381.60052490234375
26 373.45965576171875
27 365.4594421386719
28 357.6005859375
29 349.86651611328125
30 342.2503356933594
31 334.7592468261719
32 327.3841857910156
33 320.13360595703125
34 312.9990234375
35 305.9814147949219
36 299.0978698730469
37 292.3316345214844
38 285.6947021484375
39 279.1699523925781
40 272.75714111328125
41 266.4468688964844
42 260.23211669921875
43 254.1072540283203
44 248.07635498046875
45 242.1517791748047
46 236.3173828125
47 230.57749938964844
48 224.92

In [28]:
model.state_dict().keys()

odict_keys(['0.weight', '0.bias', '2.weight', '2.bias'])

In [34]:
model.state_dict().get('0.bias').shape

torch.Size([300])

## Expanding to "Multi-Modal" with Custom Class

So rather than stringing together modules like above, we have to create a custom class for our model. Mainly, this is because we are concatenating multiple inputs.

This still hasn't been tested on real data, and I don't know all the "gotchas" of these modules yet, but this works as expected thus far, and creates weight matrices as expected.

In [5]:
class Multimodal(torch.nn.Module):
    def __init__(self, D_in, H, Z_in, H2, D_out):
        super(Multimodal, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.multi = torch.nn.Linear(H + Z_in, H2)
        self.linear2 = torch.nn.Linear(H2, D_out)
        
    def forward(self, x, z):
        h1 = self.linear1(x)
        h1_z = torch.cat([h1, z], dim=1)
        h2 = self.multi(h1_z)
        out = self.linear2(h2)
        return out

In [7]:
N, D_in, Z_in, H, H2, D_out = 64, 400, 20, 300, 200, 10

x = torch.randn(N, D_in, device=device)
z = torch.randn(N, Z_in, device=device)
y = torch.randn(N, D_out, device=device)

model = Multimodal(D_in, H, Z_in, H2, D_out)

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(50):
    y_pred = model(x, z)
    
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    optimizer.zero_grad()
    
    loss.backward()
    
    optimizer.step()

0 656.3359985351562
1 640.2750244140625
2 624.4983520507812
3 609.0003662109375
4 593.7750244140625
5 578.8157958984375
6 564.115478515625
7 549.6661987304688
8 535.4600830078125
9 521.4890747070312
10 507.74517822265625
11 494.2205810546875
12 480.9076843261719
13 467.7993469238281
14 454.88873291015625
15 442.1695251464844
16 429.6359558105469
17 417.2827453613281
18 405.105224609375
19 393.09930419921875
20 381.2615661621094
21 369.58917236328125
22 358.0799560546875
23 346.7322998046875
24 335.5452575683594
25 324.5184326171875
26 313.6520080566406
27 302.9466552734375
28 292.403564453125
29 282.0243225097656
30 271.8110046386719
31 261.7658996582031
32 251.8917236328125
33 242.19137573242188
34 232.66799926757812
35 223.32484436035156
36 214.16529846191406
37 205.19277954101562
38 196.4107208251953
39 187.8224639892578
40 179.43133544921875
41 171.240478515625
42 163.2528839111328
43 155.47132873535156
44 147.8983917236328
45 140.536376953125
46 133.3872833251953
47 126.4528579711

For example, the "multi" weights have input weights of length 320 because we take 300 from the output of `linear1` + 20 from the auxillary input (e.g. z) and then we map this to the output of 200

In [8]:
model.state_dict().keys()

odict_keys(['linear1.weight', 'linear1.bias', 'multi.weight', 'multi.bias', 'linear2.weight', 'linear2.bias'])

In [9]:
model.state_dict().get('multi.weight').shape

torch.Size([200, 320])

In [10]:
print(model)

Multimodal(
  (linear1): Linear(in_features=400, out_features=300, bias=True)
  (multi): Linear(in_features=320, out_features=200, bias=True)
  (linear2): Linear(in_features=200, out_features=10, bias=True)
)


## Add Multimodality to FastAI

In [52]:
from fastai import *
from fastai.text import *

Bring in some sample data using fastai imdb

In [53]:
path = untar_data(URLs.IMDB_SAMPLE)
data_lm = TextLMDataBunch.from_csv(path)

for x, y in list(data_lm.train_dl): # just testing with one batch
    x, y
    
z = torch.randn(x.shape[0], x.shape[1], 10,
                device=device, requires_grad=True) # making up 10 "audio features"
z.shape

torch.Size([69, 64, 10])

In [54]:
x.shape

torch.Size([69, 64])

First, let's make sure we can run the same model as fastai manually...

In [55]:
vocab_sz = 20000
emb_sz = 400
n_hid = 1150
n_layers = 3
pad_token= 1
qrnn = False
bidir = False

dps = np.array([0.25, 0.1, 0.2, 0.02, 0.15])

hidden_p = dps[4]
input_p = dps[0]
embed_p = dps[3]
weight_p = dps[2]

tie_weights = True
output_p = dps[1]
bias = True

audio_sz = 10

# Create a full AWD-LSTM.
rnn_enc = RNNCore(vocab_sz=vocab_sz,
                  emb_sz=emb_sz,
                  n_hid=n_hid,
                  n_layers=n_layers,
                  pad_token=pad_token,
                  qrnn=qrnn,
                  bidir=bidir,
                  hidden_p=hidden_p,
                  input_p=input_p,
                  embed_p=embed_p,
                  weight_p=weight_p)

enc = rnn_enc.encoder if tie_weights else None
model = SequentialRNN(rnn_enc, LinearDecoder(vocab_sz, emb_sz, output_p, tie_encoder=enc, bias=bias))

In [56]:
out = model(x)

In [18]:
model.train()

loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

learning_rate = 1e-4

# opt_params = multimodal_rnn.parameters()

# this is a hack for now... not sure if this messes up the graph somewhere by doing this
opt_params = [par for par in model.parameters() if par.is_leaf]
optimizer = torch.optim.Adam(opt_params, lr=learning_rate)

y_pred = model(x)

for t in range(5):
    y_pred = model(x)[0]
    
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    optimizer.zero_grad()
    
    loss.backward()
    
    optimizer.step()

0 9.899279594421387
1 9.893121719360352
2 9.886876106262207
3 9.880671501159668
4 9.872952461242676


Next, try porting this into the fastai learner object

In [61]:
model.reset()
learn = RNNLearner(data_lm, model)
learn.fit(1)

epoch,train_loss,valid_loss,accuracy
,,,


KeyboardInterrupt: 

In [19]:
# class MultiModalRNN(RNNCore):
#     def __init__(self, audio_sz, **kwargs):
#         super(MultiModalRNN, self).__init__(**kwargs)
#         self.rnns = None
#         self.audio_sz = audio_sz
#         self.multimode = [nn.LSTM(emb_sz + audio_sz if l == 0 else n_hid,
#                                   (n_hid if l != n_layers - 1 else emb_sz + audio_sz)//self.ndir,
#                                   1, bidirectional=bidir) for l in range(n_layers)]
#         self.multimode = [WeightDropout(rnn, weight_p) for rnn in self.multimode]
#         self.multimode = torch.nn.ModuleList(self.multimode)
        
#     def forward(self, input:LongTensor, input_audio:Tensor)->Tuple[Tensor,Tensor]:
#         sl,bs = input.size()
#         if bs!=self.bs:
#             self.bs=bs
#             self.reset()
#         raw_output = self.input_dp(self.encoder_dp(input))
#         raw_output = torch.cat([raw_output, input_audio], dim=2)
#         new_hidden,raw_outputs,outputs = [],[],[]
#         for l, (rnn,hid_dp) in enumerate(zip(self.multimode, self.hidden_dps)):
#             raw_output, new_h = rnn(raw_output, self.hidden[l])
#             new_hidden.append(new_h)
#             raw_outputs.append(raw_output)
#             if l != self.n_layers - 1: raw_output = hid_dp(raw_output)
#             outputs.append(raw_output)
#         self.hidden = to_detach(new_hidden)
#         return raw_outputs, outputs
    
#     def _one_hidden(self, l:int)->Tensor:
#         "Return one hidden state."
#         nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz + self.audio_sz)//self.ndir
#         return self.weights.new(self.ndir, self.bs, nh).zero_()

#     def reset(self):
#         "Reset the hidden states."
#         [r.reset() for r in self.multimode if hasattr(r, 'reset')]
#         self.weights = next(self.parameters()).data
#         if self.qrnn: self.hidden = [self._one_hidden(l) for l in range(self.n_layers)]
#         else: self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]
    
# multimodal_rnn = MultiModalRNN(audio_sz=audio_sz,
#                          vocab_sz=vocab_sz,
#                          emb_sz=emb_sz,
#                          n_hid=n_hid,
#                          n_layers=n_layers,
#                          pad_token=pad_token,
#                          qrnn=qrnn,
#                          bidir=bidir,
#                          hidden_p=hidden_p,
#                          input_p=input_p,
#                          embed_p=embed_p,
#                          weight_p=weight_p)

# enc = multimodal_rnn.encoder if tie_weights else None
# model = SequentialRNN(multimodal_rnn,
#                       LinearDecoder(vocab_sz,
#                                     emb_sz + audio_sz,
#                                     output_p,
#                                     tie_encoder=enc,
#                                     bias=bias)).to(device)
# model

In [20]:
# out = multimodal_rnn(x, z)

In [21]:
# out[0][2].shape

Thus far, I'm able to get multimodal_rnn to work, but it doesnt work when used with SequentialRNN. Pretty sure this is because `forward` is not registered properly with SequentialRNN.

Below, I attempt to include the decoder directly into the custom module

In [49]:
class MultiModalRNN(RNNCore):
    def __init__(self, audio_sz, output_p, bias, **kwargs):
        super(MultiModalRNN, self).__init__(**kwargs)
        self.rnns = None
        self.audio_sz = audio_sz
        self.multimode = [nn.LSTM(emb_sz + audio_sz if l == 0 else n_hid,
                                  (n_hid if l != n_layers - 1 else emb_sz + audio_sz)//self.ndir,
                                  1, bidirectional=bidir) for l in range(n_layers)]
        self.multimode = [WeightDropout(rnn, weight_p) for rnn in self.multimode]
        self.multimode = torch.nn.ModuleList(self.multimode)
        
        self.multidecoder = LinearDecoder(vocab_sz,
                                          emb_sz + audio_sz,
                                          output_p,
                                          tie_encoder=None,
                                          bias=bias)
        
    def forward(self, input:LongTensor, input_audio:Tensor)->Tuple[Tensor,Tensor,Tensor]:
        sl,bs = input.size()
        if bs!=self.bs:
            self.bs=bs
            self.reset()
        raw_output = self.input_dp(self.encoder_dp(input))
        raw_output = torch.cat([raw_output, input_audio], dim=2)
        new_hidden,raw_outputs,outputs = [],[],[]
        for l, (rnn,hid_dp) in enumerate(zip(self.multimode, self.hidden_dps)):
            raw_output, new_h = rnn(raw_output, self.hidden[l])
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            if l != self.n_layers - 1: raw_output = hid_dp(raw_output)
            outputs.append(raw_output)
        self.hidden = to_detach(new_hidden)
        
        output = self.multidecoder.output_dp(outputs[-1])
        decoded = self.multidecoder.decoder(output.view(output.size(0)*output.size(1),
                                                        output.size(2)))
        
        return decoded, raw_outputs, outputs
    
    def _one_hidden(self, l:int)->Tensor:
        "Return one hidden state."
        nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz + self.audio_sz)//self.ndir
        return self.weights.new(self.ndir, self.bs, nh).zero_()

    def reset(self):
        "Reset the hidden states."
        [r.reset() for r in self.multimode if hasattr(r, 'reset')]
        self.weights = next(self.parameters()).data
        if self.qrnn: self.hidden = [self._one_hidden(l) for l in range(self.n_layers)]
        else: self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]
    
multimodal_rnn = MultiModalRNN(audio_sz=audio_sz,
                              vocab_sz=vocab_sz,
                              emb_sz=emb_sz,
                              n_hid=n_hid,
                              n_layers=n_layers,
                              pad_token=pad_token,
                              qrnn=qrnn,
                              bidir=bidir,
                              hidden_p=hidden_p,
                              input_p=input_p,
                              embed_p=embed_p,
                              weight_p=weight_p,
                              output_p=output_p,
                              bias=bias).to(device)

multimodal_rnn

MultiModalRNN(
  (encoder): Embedding(20000, 400, padding_idx=1)
  (encoder_dp): EmbeddingDropout(
    (emb): Embedding(20000, 400, padding_idx=1)
  )
  (rnns): None
  (input_dp): RNNDropout()
  (hidden_dps): ModuleList(
    (0): RNNDropout()
    (1): RNNDropout()
    (2): RNNDropout()
  )
  (multimode): ModuleList(
    (0): WeightDropout(
      (module): LSTM(410, 1150)
    )
    (1): WeightDropout(
      (module): LSTM(1150, 1150)
    )
    (2): WeightDropout(
      (module): LSTM(1150, 410)
    )
  )
  (multidecoder): LinearDecoder(
    (decoder): Linear(in_features=410, out_features=20000, bias=True)
    (output_dp): RNNDropout()
  )
)

In [44]:
out = multimodal_rnn(x, z)

In [45]:
out[0].shape

torch.Size([4480, 20000])

Now, we should be able to train the data here

In [48]:
multimodal_rnn.train()
multimode

loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

learning_rate = 1e-4

# opt_params = multimodal_rnn.parameters()

# this is a hack for now... not sure if this messes up the graph somewhere by doing this
opt_params = [par for par in multimodal_rnn.parameters() if par.is_leaf]
optimizer = torch.optim.Adam(opt_params, lr=learning_rate)

y_pred = multimodal_rnn(x, z)

for t in range(5):
    y_pred = multimodal_rnn(x, z)[0]
    
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    optimizer.zero_grad()
    
    loss.backward()
    
    optimizer.step()

0 9.90432357788086
1 9.89766788482666
2 9.891155242919922
3 9.88436222076416
4 9.87718677520752


So we get an error for those that are non-leaf nodes (tensors). Doing some investigating shows those that have a `grad_fn` are the culprits. These are most likely the tensors that are dropouts. Need to look at fastai and see how they handle these tensors...

For now I used the code above to subset to `is_leaf == False`. I'm not sure if this "is leaf" is a symptom of something that's not working right or if this is, in fact, what we should do.

In the original `RNNCore` class I'm a bit confused by what is happening for `self.rnns`

In [282]:
xx = list(multimodal_rnn.parameters())[0]
xx

Parameter containing:
tensor([[ 0.0392, -0.0167,  0.0536,  ..., -0.0786,  0.0768, -0.0682],
        [ 0.0976,  0.0382,  0.0584,  ...,  0.0168, -0.0589,  0.0689],
        [-0.0717, -0.0809, -0.0149,  ...,  0.0354, -0.0986, -0.0407],
        ...,
        [-0.0886,  0.0420, -0.0151,  ..., -0.0790,  0.0397,  0.0285],
        [ 0.0292,  0.0304,  0.0135,  ...,  0.0153, -0.0956, -0.0052],
        [-0.0429, -0.0725, -0.0067,  ...,  0.0879,  0.0801,  0.0359]],
       requires_grad=True)

In [281]:
for idx, thing in enumerate(multimodal_rnn.parameters()):
    print(idx, thing.shape, thing.is_leaf, thing.requires_grad)

0 torch.Size([20000, 400]) True True
1 torch.Size([4600, 1150]) True True
2 torch.Size([4600, 410]) True True
3 torch.Size([4600, 1150]) False True
4 torch.Size([4600]) True True
5 torch.Size([4600]) True True
6 torch.Size([4600, 1150]) True True
7 torch.Size([4600, 1150]) True True
8 torch.Size([4600, 1150]) False True
9 torch.Size([4600]) True True
10 torch.Size([4600]) True True
11 torch.Size([1640, 410]) True True
12 torch.Size([1640, 1150]) True True
13 torch.Size([1640, 410]) False True
14 torch.Size([1640]) True True
15 torch.Size([1640]) True True
16 torch.Size([20000, 410]) True True
17 torch.Size([20000]) True True


In [304]:
xx = list(multimodal_rnn.parameters())[13]
xx

tensor([[ 0.0504, -0.0437, -0.0204,  ...,  0.0562,  0.0344, -0.0210],
        [ 0.0370,  0.0000, -0.0000,  ..., -0.0072,  0.0238,  0.0026],
        [-0.0064,  0.0480, -0.0527,  ..., -0.0084,  0.0398, -0.0314],
        ...,
        [ 0.0469, -0.0468, -0.0339,  ...,  0.0566, -0.0047,  0.0132],
        [ 0.0168,  0.0000, -0.0549,  ...,  0.0170,  0.0000,  0.0153],
        [-0.0044,  0.0383, -0.0571,  ...,  0.0059, -0.0031,  0.0284]],
       grad_fn=<MulBackward0>)

In [300]:
multimodal_rnn.multimode.state_dict()['2.module.weight_hh_l0'].shape

torch.Size([1640, 410])

In [294]:
multimodal_rnn.state_dict()
{k:v.shape for (k,v) in multimodal_rnn.state_dict().items()}

{'encoder.weight': torch.Size([20000, 400]),
 'encoder_dp.emb.weight': torch.Size([20000, 400]),
 'multidecoder.decoder.bias': torch.Size([20000]),
 'multidecoder.decoder.weight': torch.Size([20000, 410]),
 'multimode.0.module.bias_hh_l0': torch.Size([4600]),
 'multimode.0.module.bias_ih_l0': torch.Size([4600]),
 'multimode.0.module.weight_hh_l0': torch.Size([4600, 1150]),
 'multimode.0.module.weight_ih_l0': torch.Size([4600, 410]),
 'multimode.0.weight_hh_l0_raw': torch.Size([4600, 1150]),
 'multimode.1.module.bias_hh_l0': torch.Size([4600]),
 'multimode.1.module.bias_ih_l0': torch.Size([4600]),
 'multimode.1.module.weight_hh_l0': torch.Size([4600, 1150]),
 'multimode.1.module.weight_ih_l0': torch.Size([4600, 1150]),
 'multimode.1.weight_hh_l0_raw': torch.Size([4600, 1150]),
 'multimode.2.module.bias_hh_l0': torch.Size([1640]),
 'multimode.2.module.bias_ih_l0': torch.Size([1640]),
 'multimode.2.module.weight_hh_l0': torch.Size([1640, 410]),
 'multimode.2.module.weight_ih_l0': torch.Si

In [306]:
multimodal_rnn.state_dict()['multimode.2.module.weight_hh_l0']

tensor([[ 0.0504, -0.0437, -0.0204,  ...,  0.0562,  0.0344, -0.0210],
        [ 0.0370,  0.0000, -0.0000,  ..., -0.0072,  0.0238,  0.0026],
        [-0.0064,  0.0480, -0.0527,  ..., -0.0084,  0.0398, -0.0314],
        ...,
        [ 0.0469, -0.0468, -0.0339,  ...,  0.0566, -0.0047,  0.0132],
        [ 0.0168,  0.0000, -0.0549,  ...,  0.0170,  0.0000,  0.0153],
        [-0.0044,  0.0383, -0.0571,  ...,  0.0059, -0.0031,  0.0284]])

In [None]:
rnn()