In [22]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

#Batching is done for large dataset to divide into smaller batches
#Using batches is called Stochastic Gradient Descent
#Here input and target have couple of additional copies to simulate a large dataset

# Input (temp, rainfall, humidity)
inputs = np.array([[73, 67, 43], [91, 88, 64], [87, 134, 58], [102, 43, 37], [69, 96, 70], [73, 67, 43], [91, 88, 64], [87, 134, 58], [102, 43, 37], [69, 96, 70], [73, 67, 43], [91, 88, 64], [87, 134, 58], [102, 43, 37], [69, 96, 70]], dtype='float32')
# Targets (apples, oranges)
targets = np.array([[56, 70], [81, 101], [119, 133], [22, 37], [103, 119], 
                    [56, 70], [81, 101], [119, 133], [22, 37], [103, 119], 
                    [56, 70], [81, 101], [119, 133], [22, 37], [103, 119]], dtype='float32')

inputs=torch.from_numpy(inputs)
targets=torch.from_numpy(targets)

#Define dataset for batching
train_ds =TensorDataset(inputs,targets)

#Define data loader
batch_size = 5
train_dl = DataLoader(train_ds,batch_size,shuffle=True)
#next(iter(train_dl))

#Define model
model = nn.Linear(3,2)
#print(model.weight)
#print(model.bias)


#Define optimizer
opt =torch.optim.SGD(model.parameters(),lr=1e-5)

#Loss function
loss_fn = F.mse_loss
loss = loss_fn(model(inputs),targets)
print("Initially loss is:", loss)

#Create a utility function to train the model
def fit(num_epochs, model, loss_fn, opt):
    for epoch in range(num_epochs):
        for xb,yb in train_dl:
            #Generate predictions
            pred = model(xb)
            loss = loss_fn(pred,yb)
            #Perform GD
            loss.backward()
            opt.step()
            opt.zero_grad()
    print('Training loss: ',loss_fn(model(inputs),targets))
    
fit(100,model,loss_fn,opt)
preds = model(inputs)
print(preds)
print(targets)

Initially loss is: tensor(12698.3594, grad_fn=<MseLossBackward>)
Training loss:  tensor(40.0097, grad_fn=<MseLossBackward>)
tensor([[ 58.9914,  72.0505],
        [ 84.1545,  97.9393],
        [111.9324, 136.7225],
        [ 31.0113,  46.9652],
        [ 99.5264, 108.4528],
        [ 58.9914,  72.0505],
        [ 84.1545,  97.9393],
        [111.9324, 136.7225],
        [ 31.0113,  46.9652],
        [ 99.5264, 108.4528],
        [ 58.9914,  72.0505],
        [ 84.1545,  97.9393],
        [111.9324, 136.7225],
        [ 31.0113,  46.9652],
        [ 99.5264, 108.4528]], grad_fn=<AddmmBackward>)
tensor([[ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.],
        [ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.],
        [ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.]])


In [23]:
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3,3)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(3,2)
    def forward(self,x):
        x = self.linear1(x)
        x = self.act1(x)
        x = self.linear2(x)
        return x
    
model = SimpleNN()
opt =torch.optim.SGD(model.parameters(),lr=1e-5)
loss_fn = F.mse_loss
fit(100,model,loss_fn,opt)
preds = model(inputs)
print(preds)
print(targets)

Training loss:  tensor(34.1900, grad_fn=<MseLossBackward>)
tensor([[ 59.0778,  69.5779],
        [ 80.9869,  95.3601],
        [120.0251, 141.2996],
        [ 30.8310,  36.3375],
        [ 93.9535, 110.6189],
        [ 59.0778,  69.5779],
        [ 80.9869,  95.3601],
        [120.0251, 141.2996],
        [ 30.8310,  36.3375],
        [ 93.9535, 110.6189],
        [ 59.0778,  69.5779],
        [ 80.9869,  95.3601],
        [120.0251, 141.2996],
        [ 30.8310,  36.3375],
        [ 93.9535, 110.6189]], grad_fn=<AddmmBackward>)
tensor([[ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.],
        [ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.],
        [ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.]])
