# Using Networks with Multiple Inputs and Outputs

## Links Used
 - [How to construct a network with two inputs in PyTorch](https://stackoverflow.com/questions/51700729/how-to-construct-a-network-with-two-inputs-in-pytorch)

## Importing the Libraries

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
```

## Creating a Toy Model

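```python
class TwoNet(nn.Module):
    """Toy network with two inputs; each input has its own linear layer."""

    def __init__(self):
        super(TwoNet, self).__init__()
        self.f1 = nn.Linear(3, 2)  # layer for the first input
        self.f2 = nn.Linear(3, 2)  # layer for the second input

    def forward(self, input1, input2):
        out_1 = self.f1(input1)
        out_2 = self.f2(input2)
        out = out_1 + out_2  # merge the two branches element-wise
        return out
```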
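
Adding the two branch outputs requires them to have the same shape. A common alternative is to concatenate them along the feature dimension and pass the result through another layer. The sketch below assumes a hypothetical `TwoNetCat` class with an extra `head` layer; neither is part of the original notebook.

```python
import torch
import torch.nn as nn


class TwoNetCat(nn.Module):
    """Hypothetical variant of TwoNet that concatenates branch outputs instead of adding them."""

    def __init__(self):
        super().__init__()
        self.f1 = nn.Linear(3, 2)        # branch for the first input
        self.f2 = nn.Linear(3, 2)        # branch for the second input
        self.head = nn.Linear(2 + 2, 2)  # maps the concatenated features back to 2 outputs

    def forward(self, input1, input2):
        out_1 = self.f1(input1)
        out_2 = self.f2(input2)
        merged = torch.cat((out_1, out_2), dim=1)  # shape (batch, 4)
        return self.head(merged)


model_cat = TwoNetCat()
out_cat = model_cat(torch.randn(2, 3), torch.randn(2, 3))
print(out_cat.shape)  # torch.Size([2, 2])
```

With concatenation the two branches may have different output widths, since only the batch dimension has to match.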

```python
model = TwoNet()
model.eval()
```

```
TwoNet(
  (f1): Linear(in_features=3, out_features=2, bias=True)
  (f2): Linear(in_features=3, out_features=2, bias=True)
)
```
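
The model above takes multiple inputs but returns a single tensor. For multiple outputs, one option is to have `forward` return a tuple and sum the per-output losses before calling `backward()` once. This is a minimal sketch; the `TwoHeadNet` class, its head sizes, and the use of MSE for both outputs are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadNet(nn.Module):
    """Hypothetical two-input, two-output network (not part of the original notebook)."""

    def __init__(self):
        super().__init__()
        self.f1 = nn.Linear(3, 2)
        self.f2 = nn.Linear(3, 2)
        self.head_a = nn.Linear(2, 2)  # first output head
        self.head_b = nn.Linear(2, 1)  # second output head with a different width

    def forward(self, input1, input2):
        shared = self.f1(input1) + self.f2(input2)       # same merge as TwoNet
        return self.head_a(shared), self.head_b(shared)  # return a tuple of outputs


model2 = TwoHeadNet()
out_a, out_b = model2(torch.randn(4, 3), torch.randn(4, 3))
print(out_a.shape, out_b.shape)  # torch.Size([4, 2]) torch.Size([4, 1])

# Combine per-output losses into one scalar so a single backward() reaches every layer
loss = F.mse_loss(out_a, torch.randn(4, 2)) + F.mse_loss(out_b, torch.randn(4, 1))
loss.backward()
print(all(p.grad is not None for p in model2.parameters()))  # True
```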

## Checking If the Model Works

```python
for i in model.parameters():
    print(i)
```

```
Parameter containing:
tensor([[-0.1770,  0.4861, -0.1690],
        [ 0.4799, -0.5607, -0.1139]], requires_grad=True)
Parameter containing:
tensor([-0.3397, -0.3146], requires_grad=True)
Parameter containing:
tensor([[-0.1165,  0.3833, -0.0081],
        [-0.0513, -0.4427,  0.0206]], requires_grad=True)
Parameter containing:
tensor([-0.0471, -0.4016], requires_grad=True)
```
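
The printout above does not say which tensor belongs to which layer. A small sketch using `named_parameters()` labels each one and confirms the order (weight, then bias, for `f1` and then `f2`); the parameter count is 2 × (3·2 + 2) = 16. The class is redefined here, with a condensed `forward`, so the snippet runs on its own.

```python
import torch.nn as nn


class TwoNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.f1 = nn.Linear(3, 2)
        self.f2 = nn.Linear(3, 2)

    def forward(self, input1, input2):
        return self.f1(input1) + self.f2(input2)


model = TwoNet()
# named_parameters() yields (name, tensor) pairs in registration order
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
# f1.weight (2, 3)
# f1.bias (2,)
# f2.weight (2, 3)
# f2.bias (2,)

n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 16
```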


```python
input_1 = torch.randn(2, 3)
input_2 = torch.randn(2, 3)
```

```python
out = model(input_1, input_2)
print(out.size())
print(out)
```

```
torch.Size([2, 2])
tensor([[-0.5524, -1.3399],
        [ 1.2091, -2.7553]], grad_fn=<AddBackward0>)
```
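
The forward pass can also be checked against the two branches computed by hand. This sketch (a self-contained copy of `TwoNet` with a condensed `forward`) uses a batch of 5 to show that the first dimension is the batch size; it can be anything as long as both inputs share it.

```python
import torch
import torch.nn as nn


class TwoNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.f1 = nn.Linear(3, 2)
        self.f2 = nn.Linear(3, 2)

    def forward(self, input1, input2):
        return self.f1(input1) + self.f2(input2)


model = TwoNet()
x1, x2 = torch.randn(5, 3), torch.randn(5, 3)  # batch of 5 instead of 2
out = model(x1, x2)
print(out.shape)  # torch.Size([5, 2])

# The forward pass should equal the two branch outputs added by hand
expected = model.f1(x1) + model.f2(x2)
print(torch.allclose(out, expected))  # True
```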


## Checking If the Gradients Backpropagate

```python
actual_output = torch.randn((2, 2))
```

```python
loss = torch.sum((actual_output-out)*(actual_output-out))
print(loss)
```

```
tensor(13.1771, grad_fn=<SumBackward0>)
```
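
The hand-written loss is the sum of squared errors, which is what `F.mse_loss` computes with `reduction='sum'`. The sketch below uses random stand-in tensors rather than the notebook's values, and also shows that the default `reduction='mean'` divides by the number of elements.

```python
import torch
import torch.nn.functional as F

out = torch.randn(2, 2, requires_grad=True)  # stand-in for the model output
actual_output = torch.randn(2, 2)

manual = torch.sum((actual_output - out) * (actual_output - out))
builtin = F.mse_loss(out, actual_output, reduction='sum')  # sum of squared errors
print(torch.allclose(manual, builtin))  # True

# The default reduction='mean' divides by the number of elements (4 here)
mean_loss = F.mse_loss(out, actual_output)
print(torch.allclose(mean_loss, manual / out.numel()))  # True
```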


```python
loss.backward()
```
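
Since `out = out_1 + out_2`, both layers receive the same upstream gradient $\partial L/\partial \text{out} = 2(\text{out} - y)$. For a linear layer $xW^\top + b$, the weight gradient is that upstream gradient transposed times the layer's input, and the bias gradient is its sum over the batch. This sketch (seed and tolerance chosen arbitrarily) checks the autograd results against those closed forms.

```python
import torch
import torch.nn as nn


class TwoNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.f1 = nn.Linear(3, 2)
        self.f2 = nn.Linear(3, 2)

    def forward(self, input1, input2):
        return self.f1(input1) + self.f2(input2)


torch.manual_seed(0)
model = TwoNet()
input_1, input_2 = torch.randn(2, 3), torch.randn(2, 3)
actual_output = torch.randn(2, 2)

out = model(input_1, input_2)
loss = torch.sum((actual_output - out) * (actual_output - out))
loss.backward()

# Both branches receive the same upstream gradient dL/dout = 2 * (out - y)
grad_out = 2 * (out - actual_output).detach()
# For y = x W^T + b: dL/dW = grad_out^T @ x and dL/db = grad_out summed over the batch
print(torch.allclose(model.f1.weight.grad, grad_out.T @ input_1, atol=1e-6))  # True
print(torch.allclose(model.f2.weight.grad, grad_out.T @ input_2, atol=1e-6))  # True
print(torch.allclose(model.f1.bias.grad, grad_out.sum(dim=0), atol=1e-6))     # True
```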

```python
print('F1 Grad: {}'.format(model.f1.weight.grad))
print('F2 Grad: {}'.format(model.f2.weight.grad))
```

```
F1 Grad: tensor([[  1.1681,  -2.3881,  -4.3994],
        [  4.5854, -14.3296, -11.7378]])
F2 Grad: tensor([[ 0.3781, -0.3503,  5.1552],
        [ 0.3790, -3.3202, 11.1682]])
```
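
`torch.optim` is imported but never used. A minimal sketch of how an optimizer would consume these gradients: plain SGD on one fixed batch, with the learning rate, step count, and seed chosen arbitrarily. Each step clears old gradients, recomputes the loss, backpropagates into both layers, and updates them.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class TwoNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.f1 = nn.Linear(3, 2)
        self.f2 = nn.Linear(3, 2)

    def forward(self, input1, input2):
        return self.f1(input1) + self.f2(input2)


torch.manual_seed(0)
model = TwoNet()
optimizer = optim.SGD(model.parameters(), lr=0.01)  # lr chosen arbitrarily for the sketch

input_1, input_2 = torch.randn(2, 3), torch.randn(2, 3)
actual_output = torch.randn(2, 2)

losses = []
for step in range(100):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    out = model(input_1, input_2)
    loss = F.mse_loss(out, actual_output, reduction='sum')
    loss.backward()        # fills .grad on f1 and f2
    optimizer.step()       # updates both layers using those gradients
    losses.append(loss.item())

print(losses[0] > losses[-1])  # True: the loss shrinks on this fixed batch
```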
