In [None]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import torch
from matplotlib import pyplot as plt
%matplotlib qt5
from IPython import display


def sample_generator(n_sample=200, slope=0.447, intercept=1.241,
                     noise_std=1.241, x_std=3.0):
    """Draw noisy samples around the line ``y = slope * x + intercept``.

    ``x`` is zero-mean Gaussian with standard deviation ``x_std``. ``y`` is the
    line evaluated at ``x`` plus zero-mean Gaussian noise with standard
    deviation ``noise_std``. With the default arguments this is the original
    generator, ``y = 0.447 * x + 1.241 + 1.241 * N(0, 1)`` (up to
    floating-point rounding).

    Args:
        n_sample: number of samples to draw.
        slope: slope of the underlying line.
        intercept: intercept of the underlying line.
        noise_std: standard deviation of the additive noise on ``y``.
        x_std: standard deviation of the sampled ``x`` values.

    Returns:
        Tuple ``(sample_x, sample_y)`` of float tensors, each of shape
        ``(n_sample, 1)``.
    """
    sample_x = torch.randn(n_sample, 1) * x_std
    # Noise is drawn after x, in the same order as before, so a seeded run
    # produces the same samples as the hard-coded version did.
    noise = torch.randn(n_sample, 1) * noise_std
    sample_y = slope * sample_x + intercept + noise
    return sample_x, sample_y


if __name__ == '__main__':
    # Fit y_pred = x @ w + b to synthetic data with plain gradient descent,
    # redrawing the current fit and the loss curve every 5 epochs (and on
    # the final epoch).
    n_sample = 100
    x, y = sample_generator(n_sample)
    w = torch.randn((1, 1), requires_grad=True)
    b = torch.randn((1, 1), requires_grad=True)

    fig = plt.figure("Linear Regression (Dynamic)")
    plt.ion()

    n_epoch = 1000
    lr = 1e-5
    loss_hist = torch.zeros(n_epoch)

    # The x grid for the reference and fitted lines never changes, so build
    # it (and the reference line y = 0.447 * x + 1.241) once, not per redraw.
    x_line = torch.linspace(-10, 10, 100)
    y_ref = x_line * 0.447 + 1.241

    for i_epoch in range(n_epoch):
        y_pred = x @ w + b
        # Sum (not mean) of squared residuals; switching to a mean would
        # shrink the gradients by n_sample and require retuning lr.
        loss = (y - y_pred).pow(2).sum()
        # In-place write; avoids copying the whole history every epoch.
        loss_hist[i_epoch] = loss.item()
        loss.backward()

        # Manual SGD step; no_grad keeps the update out of the autograd graph.
        with torch.no_grad():
            w -= lr * w.grad
            b -= lr * b.grad
            w.grad.zero_()
            b.grad.zero_()

        if i_epoch % 5 == 0 or i_epoch == n_epoch - 1:
            display.clear_output(wait=True)
            plt.clf()

            # Top axes: samples, reference line (blue) and current fit (red).
            ax0 = fig.add_subplot(211)
            ax0.set_title("epoch [%d/%d], lr = %.5f" % (i_epoch + 1, n_epoch, lr), size=10)
            ax0.set_xlim(-12., 12.)
            ax0.set_ylim(-8., 8.)
            ax0.scatter(x.numpy(), y.numpy(), marker='.', color='b')
            ax0.plot(x_line, y_ref, color='b')
            ax0.plot(x_line, x_line * w.item() + b.item(), color='r')

            # Bottom axes: loss history up to and including this epoch.
            ax1 = fig.add_subplot(212)
            ax1.set_title("sse_loss = %.6f" % loss.item(), size=10)
            ax1.set_xlim(0, n_epoch)
            ax1.set_ylim(0, 1000)
            n_done = i_epoch + 1
            ax1.plot(torch.arange(n_done), loss_hist[:n_done].numpy())

            fig.tight_layout(pad=1.0, h_pad=1.0)
            plt.pause(0.02)

    plt.ioff()
    plt.show()
