In [15]:
### import libraries and load data ###

import torch
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim
from pathlib import Path
from datetime import datetime

# data path
df_path = (
    Path.cwd()
    / "data"
    / "raw"
    / "Jun22_2020"
    / "df3_fertsplit.csv"
)
df = pd.read_csv(df_path)[["yield", "NDVI_mean", "SAVI_mean"]]

# load tensor
D = torch.tensor(df.values, dtype=torch.float)

# extract features and transpose
x_dataset = D[:, 1:3].t()

# extract yield as y_dataset and transpose
y_dataset = D[:, 0].t()

# variable to remember input columns
n = 2

In [16]:
### model definition ###
A = torch.randn((1, n), requires_grad=True)
b = torch.randn(1, requires_grad=True)

# define pred model
def model(x_input):
    return A.mm(x_input) + b

# loss function definition
def loss(y_predicted, y_target):
    return ((y_predicted - y_target)**2).sum()

In [18]:
#### train the model ###

# setup the optimizer object, so it optimizes a and b
optimizer = optim.Adam([A, b], lr=0.1)

# main optimzation loop
for t in range(2000):
    optimizer.zero_grad()
    y_predicted = model(x_dataset)
    current_loss = loss(y_predicted, y_dataset)
    current_loss.backward()
    optimizer.step()
    print(f"t= {t}, loss= {current_loss}, A = {A.detach().numpy()}, b = {b.item()}")

t= 0, loss= 41905452.0, A = [[ 1.3849505 -0.8602897]], b = 0.6129608154296875
t= 1, loss= 41882228.0, A = [[ 1.4849498  -0.76029044]], b = 0.7129600644111633
t= 2, loss= 41859008.0, A = [[ 1.5849478  -0.66029245]], b = 0.8129581212997437
t= 3, loss= 41835808.0, A = [[ 1.684944  -0.5602962]], b = 0.9129544496536255
t= 4, loss= 41812600.0, A = [[ 1.7849381  -0.46030214]], b = 1.0129485130310059
t= 5, loss= 41789404.0, A = [[ 1.8849294  -0.36031082]], b = 1.1129399538040161
t= 6, loss= 41766216.0, A = [[ 1.9849175 -0.2603227]], b = 1.2129281759262085
t= 7, loss= 41743036.0, A = [[ 2.084902   -0.16033822]], b = 1.3129127025604248
t= 8, loss= 41719860.0, A = [[ 2.1848824  -0.06035789]], b = 1.4128931760787964
t= 9, loss= 41696700.0, A = [[2.2848582  0.03961784]], b = 1.512869119644165
t= 10, loss= 41673536.0, A = [[2.3848288  0.13958853]], b = 1.6128400564193726
t= 11, loss= 41650388.0, A = [[2.4847941  0.23955376]], b = 1.7128055095672607
t= 12, loss= 41627244.0, A = [[2.5847535 0.3395131]