In [1]:
%matplotlib tk
# %matplotlib widget
import matplotlib as mpl
import matplotlib.pyplot as plt

import numpy as np

from itertools import cycle

import ipywidgets as widgets
from IPython.display import display

# Apply matplotlib's built-in 'dark_background' style to every figure created after this point.
plt.style.use(['dark_background'])

In [46]:
class Line:
    """A line in the plane described by the equation ``w1*x + w2*y + b = 0``.

    The weights are plain public attributes, so callers may adjust them in place.
    """

    def __init__(self, w1, w2, b):
        self.w1 = w1  # coefficient of x
        self.w2 = w2  # coefficient of y
        self.b = b    # constant (bias) term

    @staticmethod
    def _signed_term(value):
        """Render a coefficient as '+ v' or '- v' so it reads naturally after the previous term."""
        # abs() instead of str(value).replace('-', '- '): replace() would also mangle the
        # exponent of values like -1e-05 into '- 1e- 05', and would print -0.0 as '+ -0.0'.
        return f'- {abs(value)}' if value < 0 else f'+ {abs(value)}'

    def __str__(self):
        """Human-readable equation, e.g. ``3*x + 4*y - 10``."""
        return f'{self.w1}*x {self._signed_term(self.w2)}*y {self._signed_term(self.b)}'

    def linear_comb(self, x, y):
        """Return ``w1*x + w2*y + b`` for the point (x, y).

        The sign indicates which side of the line the point is on; 0 means it lies on the line.
        """
        return self.w1 * x + self.w2 * y + self.b

    def calculate(self, x):
        """Return the y coordinate on the line for ``x`` (a scalar or a numpy array).

        Solves ``w1*x + w2*y + b = 0`` for y. When ``w2 == 0`` the line is vertical and y is
        undetermined, so NaN is returned (same shape as ``x``); matplotlib draws nothing for
        NaN coordinates instead of plotting a wrong line.
        """
        if self.w2 == 0:
            return x * float('nan')
        return -((self.w1 * x + self.b) / self.w2)

In [47]:
# Initial line w1*x + w2*y + b = 0; its weights are adjusted until the bad point is classified correctly
LINE_W1 = 3
LINE_W2 = 4
LINE_B = -10

# The misclassified point that the line is corrected for
BAD_POINT = (-3, -3)

# Marker colours double as class labels
POSITIVE_COLOR = 'b'  # blue = positive class
NEGATIVE_COLOR = 'r'  # red = negative class
MARKER = 'X'

# Step size applied to every weight update
LEARNING_RATE = 0.1

# Seconds to pause between animation frames
PLOT_PAUSE = 0.1

In [48]:
# Create figure
fig, axs = plt.subplots()
axs.set_xlabel('X')
axs.set_ylabel('y')
axs.set_xlim(-50, 50)
axs.set_ylim(-50, 50)

# Create the line equation
line = Line(LINE_W1, LINE_W2, LINE_B)
print('Input linear equation: ' + str(line))

# Region convention used in this cell: linear_comb >= 0 is the positive side of the line,
# linear_comb < 0 the negative side. The bad point is given the label of the side it does
# NOT lie on, so it starts out misclassified.
bp_marker = POSITIVE_COLOR if line.linear_comb(BAD_POINT[0], BAD_POINT[1]) < 0 else NEGATIVE_COLOR
bad_point_plt = axs.plot(BAD_POINT[0], BAD_POINT[1], bp_marker + MARKER)

# Plot the initial line
line_coords = np.array([-50, 50])
axs.plot(line_coords, line.calculate(line_coords), color='cyan')  # original line

# Perceptron trick: add (positive label) or subtract (negative label) LEARNING_RATE * (x, y, 1)
# to (w1, w2, b) until the bad point falls on the side of the line matching its label.
# Each step raises koeff * linear_comb(bad point) by LEARNING_RATE * (x**2 + y**2 + 1),
# so the loop only terminates for a positive learning rate.
assert LEARNING_RATE > 0, 'LEARNING_RATE must be positive, otherwise the correction loop never ends'
bad_point_lr = [BAD_POINT[0] * LEARNING_RATE, BAD_POINT[1] * LEARNING_RATE]  # the coords of the bad point adjusted by learning rate
koeff = -1 if bp_marker == NEGATIVE_COLOR else 1
while True:
    # Update the line
    line.w1 += bad_point_lr[0] * koeff
    line.w2 += bad_point_lr[1] * koeff
    line.b += LEARNING_RATE * koeff
    # Draw the candidate line faintly; it is kept (and made opaque) only once it is correct
    corrected_line = axs.plot(line_coords, line.calculate(line_coords), color='green', alpha=0.4)[0]
    plt.pause(PLOT_PAUSE)
    score = koeff * line.linear_comb(BAD_POINT[0], BAD_POINT[1])
    print(str(score) + ' >= 0? ', line)
    # The boundary belongs to the positive side: a positive point is correct when
    # linear_comb >= 0, a negative point only when linear_comb < 0 (strictly positive score).
    is_classified = score >= 0 if koeff > 0 else score > 0
    if is_classified:  # Found the right equation, break
        corrected_line.set_alpha(1)
        break
    corrected_line.remove()

print('New linear equation: ' + str(line))

Input linear equation: 3*x + 4*y - 10
-29.1 >= 0?  2.7*x + 3.7*y - 9.9
-27.200000000000003 >= 0?  2.4000000000000004*x + 3.4000000000000004*y - 9.8
-25.300000000000004 >= 0?  2.1000000000000005*x + 3.1000000000000005*y - 9.700000000000001
-23.400000000000006 >= 0?  1.8000000000000005*x + 2.8000000000000007*y - 9.600000000000001
-21.500000000000007 >= 0?  1.5000000000000004*x + 2.500000000000001*y - 9.500000000000002
-19.60000000000001 >= 0?  1.2000000000000004*x + 2.200000000000001*y - 9.400000000000002
-17.700000000000006 >= 0?  0.9000000000000004*x + 1.900000000000001*y - 9.300000000000002
-15.800000000000006 >= 0?  0.6000000000000003*x + 1.600000000000001*y - 9.200000000000003
-13.900000000000007 >= 0?  0.30000000000000027*x + 1.300000000000001*y - 9.100000000000003
-12.000000000000007 >= 0?  2.220446049250313e-16*x + 1.0000000000000009*y - 9.000000000000004
-10.100000000000007 >= 0?  -0.2999999999999998*x + 0.7000000000000008*y - 8.900000000000004
-8.200000000000006 >= 0?  -0.59999