In [1]:
# Imports
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons, CheckButtons
from mpl_toolkits.mplot3d import Axes3D
import math
import os
import time
import itertools

import warnings
warnings.filterwarnings('ignore')

from matplotlib.patches import Circle, Wedge
from matplotlib.collections import PatchCollection

from scipy.spatial import ConvexHull

# Matplotlib configuration applied to every figure drawn later in the notebook.
# NOTE(review): text.usetex renders all figure text through an external LaTeX
# installation — confirm one is available on the machine, or set this to False.
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'monospace'
# Load amsmath in the LaTeX preamble so labels can use its commands.
plt.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

# IPython magic selecting the interactive 'notebook' backend (Jupyter-only syntax).
# NOTE(review): this backend is not supported in JupyterLab / Notebook 7;
# `%matplotlib widget` (ipympl) is the usual replacement — confirm target frontend.
%matplotlib notebook

# Default figure size in inches (width, height).
plt.rcParams['figure.figsize'] = [7, 7]

import torch
import torch.nn
import torch.utils.data
import torchvision

In [2]:
class SimpleNet(torch.nn.Module):
    """Small fully-connected classifier: 2 inputs -> 8 -> 8 -> 8 -> 2 outputs.

    The three hidden layers are followed by ReLU; the last layer is linear with
    no activation, so ``forward`` returns raw scores for the two classes.
    Extra leading dimensions are allowed (the training batches are shaped
    (batch, 1, 2)); each ``torch.nn.Linear`` acts on the last axis.
    """

    def __init__(self):
        super().__init__()
        # Created in this order so parameter initialisation consumes the RNG the
        # same way every time; the attribute names fix the state_dict keys.
        self.fc1 = torch.nn.Linear(2, 8)
        self.fc2 = torch.nn.Linear(8, 8)
        self.fc3 = torch.nn.Linear(8, 8)
        self.fc4 = torch.nn.Linear(8, 2)

    def forward(self, x):
        activate = torch.nn.functional.relu
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = activate(hidden_layer(x))
        # Output layer: no activation.
        return self.fc4(x)

In [3]:
# The four quadrant offsets (-1,-1), (-1,1), (1,-1), (1,1), in the same order as
# itertools.product([-1, 1], repeat=2) (second coordinate varies fastest).
shifts = np.array([[sx, sy] for sx in (-1, 1) for sy in (-1, 1)])

In [4]:
def _ring_label(sq_norm, r1sqr, r2sqr, r3sqr):
    """Return the one-hot label for a point with squared norm ``sq_norm``, or None.

    Inside the inner disc (``sq_norm < r1sqr``) the label is ``[[0, 1]]``;
    strictly inside the annulus ``r2sqr < sq_norm < r3sqr`` it is ``[[1, 0]]``;
    anything else (the gap between them, or outside) is rejected with None.
    """
    if sq_norm < r1sqr:
        return np.array([[0, 1]])
    if r2sqr < sq_norm < r3sqr:
        return np.array([[1, 0]])
    return None


def _shift_to_quadrants(xs, ys, which, shifts):
    """Replace ``xs[i]`` by ``xs[i] + shifts[which[i]]`` for every i (lists updated in place).

    Points moved by offset index 1 or 2 -- (-1, 1) and (1, -1) -- also get their
    one-hot label reversed, which yields a checkerboard of the two classes.
    """
    for i, k in enumerate(np.asarray(which).reshape(-1)):
        # Build a new array instead of ``+=`` so no caller-owned buffer is mutated.
        xs[i] = xs[i] + shifts[k]
        if k in (1, 2):
            ys[i] = np.flip(ys[i])


def get_dataloader(train=True, r1=0.5, batch_size=100, shuffle=False, n_train=100000):
    """Build a DataLoader for the four-quadrant "disc and ring" toy problem.

    Base points live in the square [-1, 1]^2:
      * inner disc  |p|^2 < r1^2           -> label [[0, 1]]
      * annulus     2 r1^2 < |p|^2 < 3 r1^2 -> label [[1, 0]]
    Every point is then moved by one of the offsets (+-1, +-1), chosen
    uniformly at random; for offsets (-1, 1) and (1, -1) the label is reversed.

    Training set: ``n_train`` points rejection-sampled from [-1, 1)^2;
    inputs shaped (n_train, 1, 2), targets (n_train, 1, 2).
    Test set: grid points with spacing 0.02 covering [-1, 1]^2; inputs shaped
    (M, 2), targets (M, 1, 2). ``test()`` relies on these test shapes.

    Generation is deterministic (seed 0). The training set is always generated,
    even for ``train=False``, so the random quadrant assignment of the test set
    comes from the same point in the RNG stream for both calls.

    Args:
        train: return the training set if True, the test grid otherwise.
        r1: radius of the inner disc; the annulus spans radii
            sqrt(2)*r1 .. sqrt(3)*r1.
        batch_size: DataLoader batch size.
        shuffle: DataLoader ``shuffle`` flag.
        n_train: number of training points to generate.

    Returns:
        torch.utils.data.DataLoader over a float32 TensorDataset ``(x, y)``.
    """
    # RandomState(0) yields the same stream as the legacy ``np.random.seed(0)``
    # followed by ``np.random.*`` calls, without reseeding the global generator.
    rng = np.random.RandomState(0)

    r1sqr = r1**2
    r2sqr = 2*r1sqr
    r3sqr = 3*r1sqr
    shifts = np.array(list(itertools.product([-1, 1], repeat=2)))

    # Training points: draw one point at a time, keeping those in the disc or annulus.
    train_x, train_y = [], []
    while len(train_x) < n_train:
        sample = 2*rng.uniform(size=(1, 2)) - 1
        label = _ring_label(np.sum(np.square(sample)), r1sqr, r2sqr, r3sqr)
        if label is not None:
            train_x.append(sample)
            train_y.append(label)
    _shift_to_quadrants(train_x, train_y, rng.randint(0, 4, (len(train_x), 1)), shifts)

    # Test points: 101 values per axis -> spacing 0.02 covering exactly [-1, 1],
    # the same square the training points come from. (Previously range(201)
    # produced coordinates up to 3, i.e. an asymmetric [-1, 3] grid.)
    grid = [2*i/100 - 1 for i in range(101)]
    test_x, test_y = [], []
    for a in grid:
        for b in grid:
            elem = np.array([a, b])
            label = _ring_label(np.sum(np.square(elem)), r1sqr, r2sqr, r3sqr)
            if label is not None:
                test_x.append(elem)
                test_y.append(label)
    _shift_to_quadrants(test_x, test_y, rng.randint(0, 4, len(test_x)), shifts)

    xs, ys = (train_x, train_y) if train else (test_x, test_y)
    # One stack + cast instead of building one tensor per point.
    tensor_x = torch.from_numpy(np.stack(xs).astype(np.float32))
    tensor_y = torch.from_numpy(np.stack(ys).astype(np.float32))

    my_dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y)
    return torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, shuffle=shuffle)

In [5]:
# Training batches are shuffled. The test loader is not: accuracy is
# order-independent. test() runs after every training batch, and a shuffled
# loader would draw from torch's global RNG on each pass, perturbing the
# training shuffle stream.
train_loader = get_dataloader(train=True, batch_size=100, shuffle=True)
test_loader = get_dataloader(train=False, batch_size=100, shuffle=False)

In [6]:
# Notebook-level bookkeeping updated by train():
#   timeSeen   -- number used for the next checkpoint file name
#   accuracies -- test accuracy recorded after each optimizer step
timeSeen, accuracies = 0, []

In [7]:
def train(network, optimizer, lossFunc, train_loader, epoch, eval_loader=None, save_dir='trainedNets'):
    """Train ``network`` for one pass over ``train_loader``, evaluating and
    checkpointing after every optimizer step.

    After each batch, the accuracy returned by ``test`` is appended to the
    global ``accuracies`` list and printed as ``epoch batch_idx accuracy``.
    The whole module is then saved with ``torch.save`` to
    ``<save_dir>/network<timeSeen:09d>.pt``, and the global ``timeSeen`` is
    incremented.

    Args:
        network: module being trained.
        optimizer: optimizer over ``network``'s parameters.
        lossFunc: callable ``lossFunc(output, target)`` returning a scalar loss.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        epoch: epoch number, used only in the printed log line.
        eval_loader: loader passed to ``test`` after each step; defaults to
            the notebook-level ``test_loader``.
        save_dir: checkpoint directory; created if it does not exist.
    """
    global timeSeen
    global accuracies
    if eval_loader is None:
        eval_loader = test_loader
    os.makedirs(save_dir, exist_ok=True)
    for batch_idx, (data, target) in enumerate(train_loader):
        # test() switches the network to eval mode, so re-enable training mode
        # before every step, not just once before the loop.
        network.train()
        optimizer.zero_grad()
        output = network(data)
        loss = lossFunc(output, target)
        loss.backward()
        optimizer.step()

        accuracies.append(test(network, eval_loader))
        print(epoch, batch_idx, accuracies[-1])

        torch.save(network, os.path.join(save_dir, 'network{:09d}.pt'.format(timeSeen)))
        timeSeen += 1
        
def test(network, test_loader):
    network.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            pred = output.data.max(1, keepdim=True)[1]
            tar = target.data.max(2, keepdim=True)[1]
            correct += pred.eq(tar.data.view_as(pred)).sum()
    return correct.item()/len(test_loader.dataset)

In [8]:
# Retrain from a freshly initialised network until some per-batch test
# accuracy of the current run reaches TARGET_ACCURACY. Each attempt runs
# N_EPOCHS epochs and writes a checkpoint after every step (plus the untrained
# network as checkpoint 0). Gives up after MAX_ATTEMPTS fresh starts instead of
# looping forever.
TARGET_ACCURACY = 0.98
MAX_ATTEMPTS = 10
N_EPOCHS = 4

torch.manual_seed(0)  # reproducible initialisation and shuffling
os.makedirs('trainedNets', exist_ok=True)

attempt = 0
while max(accuracies, default=0) < TARGET_ACCURACY:
    if attempt == MAX_ATTEMPTS:
        raise RuntimeError('no run reached {} test accuracy in {} attempts'.format(TARGET_ACCURACY, MAX_ATTEMPTS))
    attempt += 1

    criterion = torch.nn.MSELoss()
    network = SimpleNet()
    optimizer = torch.optim.SGD(network.parameters(), 0.1, momentum=0.9, weight_decay=1e-4)
    timeSeen = 0
    accuracies = []

    # Checkpoint 0 is the untrained network.
    torch.save(network, 'trainedNets/network{:09d}.pt'.format(timeSeen))
    timeSeen += 1
    for i in range(N_EPOCHS):
        train(network, optimizer, criterion, train_loader, i)
        print(i, accuracies[-1])

0 0 0.5065067619290635
0 1 0.5065067619290635
0 2 0.4830313855575402
0 3 0.5154376116356213
0 4 0.4934932380709365
0 5 0.4934932380709365
0 6 0.4934932380709365
0 7 0.4934932380709365
0 8 0.4562388364378668
0 9 0.5325338096453177
0 10 0.5065067619290635
0 11 0.5065067619290635
0 12 0.5065067619290635
0 13 0.5065067619290635
0 14 0.5065067619290635
0 15 0.5065067619290635
0 16 0.5065067619290635
0 17 0.5266649655524368
0 18 0.4934932380709365
0 19 0.4934932380709365
0 20 0.4934932380709365
0 21 0.4934932380709365
0 22 0.4934932380709365
0 23 0.4934932380709365
0 24 0.4934932380709365
0 25 0.4934932380709365
0 26 0.4934932380709365
0 27 0.4934932380709365
0 28 0.5065067619290635
0 29 0.5065067619290635
0 30 0.5065067619290635
0 31 0.5065067619290635
0 32 0.5065067619290635
0 33 0.5065067619290635
0 34 0.5065067619290635
0 35 0.5065067619290635
0 36 0.5065067619290635
0 37 0.5154376116356213
0 38 0.4934932380709365
0 39 0.4934932380709365
0 40 0.4934932380709365
0 41 0.4934932380709365
0 

0 335 0.7777494258739475
0 336 0.7930594539423322
0 337 0.817810665986221
0 338 0.8193416687930595
0 339 0.8216381730033172
0 340 0.7994386323041592
0 341 0.7973972952283745
0 342 0.7943352896146977
0 343 0.8065833120694055
0 344 0.8165348303138555
0 345 0.7968869609594285
0 346 0.7833631028323552
0 347 0.7933146210768053
0 348 0.7971421280939015
0 349 0.8022454707833631
0 350 0.8063281449349324
0 351 0.8060729778004593
0 352 0.7887216126562898
0 353 0.763460066343455
0 354 0.79867313090074
0 355 0.8195968359275325
0 356 0.8099004848175555
0 357 0.8070936463383516
0 358 0.8014799693799438
0 359 0.8042868078591477
0 360 0.8323551926511865
0 361 0.840010206685379
0 362 0.8328655269201327
0 363 0.8282725184996172
0 364 0.8328655269201327
0 365 0.8392447052819597
0 366 0.8351620311303904
0 367 0.8384792038785405
0 368 0.8292931870375095
0 369 0.8185761673896402
0 370 0.8216381730033172
0 371 0.834141362592498
0 372 0.8474100535850982
0 373 0.849706557795356
0 374 0.8573615718295483
0 375 0

0 667 0.9132431742791528
0 668 0.8933401377902526
0 669 0.891553967848941
0 670 0.914008675682572
0 671 0.8828782852768563
0 672 0.8581270732329676
0 673 0.8645062515947946
0 674 0.888236795100791
0 675 0.9020158203623373
0 676 0.9119673386067875
0 677 0.9208981883133452
0 678 0.8997193161520796
0 679 0.9004848175554988
0 680 0.9183465169686145
0 681 0.913498341413626
0 682 0.907629497320745
0 683 0.9025261546312835
0 684 0.907119163051799
0 685 0.9122225057412605
0 686 0.9175810155651952
0 687 0.9091605001275835
0 688 0.9127328400102067
0 689 0.8833886195458025
0 690 0.8634855830569023
0 691 0.8634855830569023
0 692 0.8739474355702985
0 693 0.8869609594284256
0 694 0.891298800714468
0 695 0.8935953049247257
0 696 0.8895126307731565
0 697 0.9183465169686145
0 698 0.9214085225822914
0 699 0.907629497320745
0 700 0.9142638428170452
0 701 0.9060984945139066
0 702 0.8823679510079102
0 703 0.8823679510079102
0 704 0.8890022965042103
0 705 0.9078846644552182
0 706 0.8997193161520796
0 707 0.

0 998 0.945394233222761
0 999 0.9737177851492728
0 0.9737177851492728
1 0 0.9688696095942843
1 1 0.9510079101811687
1 2 0.951773411584588
1 3 0.965042102577188
1 4 0.9686144424598112
1 5 0.9704006124011227
1 6 0.9658076039806073
1 7 0.955090584332738
1 8 0.9635110997703495
1 9 0.9752487879561113
1 10 0.9709109466700689
1 11 0.9635110997703495
1 12 0.9709109466700689
1 13 0.9675937739219188
1 14 0.9624904312324573
1 15 0.9573870885429957
1 16 0.9637662669048227
1 17 0.9752487879561113
1 18 0.9726971166113805
1 19 0.9655524368461342
1 20 0.9571319214085225
1 21 0.9548354171982648
1 22 0.968359275325338
1 23 0.9798417963766267
1 24 0.9831589691247767
1 25 0.9737177851492728
1 26 0.9663179382495535
1 27 0.9645317683082419
1 28 0.9760142893595305
1 29 0.9729522837458535
1 30 0.9640214340392957
1 31 0.9635110997703495
1 32 0.9658076039806073
1 33 0.9752487879561113
1 34 0.980862464914519
1 35 0.9706557795355958
1 36 0.9635110997703495
1 37 0.9670834396529727
1 38 0.9783107935697882
1 39 0.98

1 332 0.9982138300586885
1 333 0.9977034957897423
1 334 0.993620821638173
1 335 0.9923449859658076
1 336 0.9943863230415922
1 337 0.9954069915794845
1 338 0.9977034957897423
1 339 0.99668282725185
1 340 0.99668282725185
1 341 0.996937994386323
1 342 0.9959173258484307
1 343 0.9951518244450115
1 344 0.9974483286552692
1 345 0.9974483286552692
1 346 0.9961724929829038
1 347 0.9948966573105384
1 348 0.9910691502934422
1 349 0.996937994386323
1 350 0.9992344985965808
1 351 0.9948966573105384
1 352 0.9992344985965808
1 353 0.9928553202347538
1 354 0.9915794845623883
1 355 0.996937994386323
1 356 0.99668282725185
1 357 0.9928553202347538
1 358 0.993620821638173
1 359 0.9974483286552692
1 360 0.9951518244450115
1 361 0.9918346516968615
1 362 0.9926001531002807
1 363 0.9931104873692268
1 364 0.9987241643276346
1 365 0.996937994386323
1 366 0.9941311559071192
1 367 0.9989793314621077
1 368 0.9994896657310538
1 369 0.9959173258484307
1 370 0.9918346516968615
1 371 0.9971931615207961
1 372 0.9943

1 666 0.9982138300586885
1 667 0.9984689971931615
1 668 0.9992344985965808
1 669 0.9982138300586885
1 670 0.99668282725185
1 671 0.9943863230415922
1 672 0.9984689971931615
1 673 0.999744832865527
1 674 0.999744832865527
1 675 0.9994896657310538
1 676 0.9977034957897423
1 677 0.9974483286552692
1 678 0.9971931615207961
1 679 0.9979586629242153
1 680 0.9982138300586885
1 681 0.9984689971931615
1 682 0.9977034957897423
1 683 0.9989793314621077
1 684 0.9989793314621077
1 685 0.9989793314621077
1 686 0.9992344985965808
1 687 0.9971931615207961
1 688 0.9961724929829038
1 689 0.9987241643276346
1 690 0.996937994386323
1 691 0.9961724929829038
1 692 0.9931104873692268
1 693 0.9920898188313345
1 694 0.9946414901760653
1 695 0.9971931615207961
1 696 0.9977034957897423
1 697 0.9977034957897423
1 698 0.9974483286552692
1 699 0.9971931615207961
1 700 0.9974483286552692
1 701 0.9987241643276346
1 702 0.9984689971931615
1 703 0.9982138300586885
1 704 0.9989793314621077
1 705 0.9992344985965808
1 706

2 2 0.999744832865527
2 3 0.9982138300586885
2 4 0.9984689971931615
2 5 0.9994896657310538
2 6 0.9994896657310538
2 7 0.999744832865527
2 8 1.0
2 9 0.9984689971931615
2 10 0.9979586629242153
2 11 0.999744832865527
2 12 0.9989793314621077
2 13 0.9987241643276346
2 14 0.9989793314621077
2 15 0.9987241643276346
2 16 0.9992344985965808
2 17 0.9989793314621077
2 18 0.9989793314621077
2 19 0.9982138300586885
2 20 0.9977034957897423
2 21 0.996937994386323
2 22 0.9987241643276346
2 23 0.9984689971931615
2 24 0.9971931615207961
2 25 0.9989793314621077
2 26 0.993620821638173
2 27 0.9920898188313345
2 28 0.9987241643276346
2 29 0.9928553202347538
2 30 0.9989793314621077
2 31 0.9964276601173769
2 32 0.9959173258484307
2 33 0.9987241643276346
2 34 0.9977034957897423
2 35 0.9959173258484307
2 36 0.9977034957897423
2 37 0.9951518244450115
2 38 0.9961724929829038
2 39 0.9984689971931615
2 40 0.9994896657310538
2 41 0.999744832865527
2 42 0.9984689971931615
2 43 1.0
2 44 0.999744832865527
2 45 1.0
2 46

2 347 0.9941311559071192
2 348 0.9974483286552692
2 349 0.9974483286552692
2 350 0.9977034957897423
2 351 0.9979586629242153
2 352 0.9982138300586885
2 353 0.99668282725185
2 354 0.9954069915794845
2 355 0.996937994386323
2 356 0.9971931615207961
2 357 0.9890278132176575
2 358 0.993620821638173
2 359 0.9977034957897423
2 360 0.9836693033937229
2 361 0.9933656545037
2 362 0.9964276601173769
2 363 0.9788211278387343
2 364 0.9831589691247767
2 365 0.987496810410819
2 366 0.9867313090073998
2 367 0.993875988772646
2 368 0.9836693033937229
2 369 0.9818831334524113
2 370 0.9854554733350345
2 371 0.9928553202347538
2 372 0.9982138300586885
2 373 0.9854554733350345
2 374 0.9706557795355958
2 375 0.9831589691247767
2 376 0.9974483286552692
2 377 0.990303648890023
2 378 0.9852003062005613
2 379 0.9877519775452922
2 380 0.9887726460831845
2 381 0.993875988772646
2 382 0.9915794845623883
2 383 0.990303648890023
2 384 0.9885174789487114
2 385 0.9946414901760653
2 386 0.9918346516968615
2 387 0.9954

2 681 0.9989793314621077
2 682 0.9992344985965808
2 683 0.9984689971931615
2 684 0.996937994386323
2 685 0.9974483286552692
2 686 0.9987241643276346
2 687 0.9982138300586885
2 688 0.9987241643276346
2 689 0.9984689971931615
2 690 0.9982138300586885
2 691 0.9989793314621077
2 692 0.9987241643276346
2 693 0.9992344985965808
2 694 0.9989793314621077
2 695 0.9987241643276346
2 696 0.9989793314621077
2 697 1.0
2 698 0.9994896657310538
2 699 0.9987241643276346
2 700 0.9987241643276346
2 701 0.9992344985965808
2 702 0.9989793314621077
2 703 0.9989793314621077
2 704 0.9987241643276346
2 705 0.9987241643276346
2 706 0.9987241643276346
2 707 0.9989793314621077
2 708 0.9989793314621077
2 709 0.9989793314621077
2 710 0.9992344985965808
2 711 0.9984689971931615
2 712 0.9979586629242153
2 713 0.9989793314621077
2 714 0.9994896657310538
2 715 0.9992344985965808
2 716 0.9994896657310538
2 717 0.9992344985965808
2 718 0.9987241643276346
2 719 0.996937994386323
2 720 0.996937994386323
2 721 0.9982138300

3 24 0.9954069915794845
3 25 0.986986476141873
3 26 0.9931104873692268
3 27 0.9982138300586885
3 28 0.9964276601173769
3 29 0.9954069915794845
3 30 0.9964276601173769
3 31 0.9989793314621077
3 32 0.9984689971931615
3 33 0.99668282725185
3 34 0.9915794845623883
3 35 0.9895381474866037
3 36 0.9895381474866037
3 37 0.9984689971931615
3 38 0.9979586629242153
3 39 0.983924470528196
3 40 0.9798417963766267
3 41 0.987496810410819
3 42 0.9971931615207961
3 43 0.9977034957897423
3 44 0.9982138300586885
3 45 0.9977034957897423
3 46 0.9954069915794845
3 47 0.9895381474866037
3 48 0.9892829803521307
3 49 0.9948966573105384
3 50 0.999744832865527
3 51 0.9992344985965808
3 52 0.9984689971931615
3 53 0.9956621587139577
3 54 0.9959173258484307
3 55 0.9974483286552692
3 56 0.9977034957897423
3 57 0.9979586629242153
3 58 0.9984689971931615
3 59 0.9987241643276346
3 60 0.9994896657310538
3 61 0.9989793314621077
3 62 0.9951518244450115
3 63 0.9989793314621077
3 64 0.9992344985965808
3 65 0.999234498596580

3 375 0.9956621587139577
3 376 0.9977034957897423
3 377 0.99668282725185
3 378 0.9943863230415922
3 379 0.990558816024496
3 380 0.9885174789487114
3 381 0.987241643276346
3 382 0.9877519775452922
3 383 0.990303648890023
3 384 0.99668282725185
3 385 0.9992344985965808
3 386 0.9992344985965808
3 387 0.9984689971931615
3 388 0.9994896657310538
3 389 0.9959173258484307
3 390 0.9920898188313345
3 391 0.9974483286552692
3 392 0.99668282725185
3 393 0.9971931615207961
3 394 0.993620821638173
3 395 0.9895381474866037
3 396 0.9867313090073998
3 397 0.9854554733350345
3 398 0.9910691502934422
3 399 0.9951518244450115
3 400 0.9979586629242153
3 401 0.9954069915794845
3 402 0.9941311559071192
3 403 0.9928553202347538
3 404 0.990558816024496
3 405 0.9987241643276346
3 406 0.9956621587139577
3 407 0.9946414901760653
3 408 0.9979586629242153
3 409 0.9982138300586885
3 410 0.9951518244450115
3 411 0.9977034957897423
3 412 0.9954069915794845
3 413 0.9877519775452922
3 414 0.9954069915794845
3 415 0.998

3 713 0.9989793314621077
3 714 0.9989793314621077
3 715 0.9989793314621077
3 716 0.9989793314621077
3 717 0.9984689971931615
3 718 0.9982138300586885
3 719 0.9987241643276346
3 720 0.9977034957897423
3 721 0.9987241643276346
3 722 0.9992344985965808
3 723 0.999744832865527
3 724 0.9994896657310538
3 725 0.9987241643276346
3 726 0.9959173258484307
3 727 0.9923449859658076
3 728 0.9910691502934422
3 729 0.9982138300586885
3 730 0.9994896657310538
3 731 0.9994896657310538
3 732 1.0
3 733 0.9994896657310538
3 734 0.9989793314621077
3 735 0.9994896657310538
3 736 0.9992344985965808
3 737 0.999744832865527
3 738 0.999744832865527
3 739 0.9994896657310538
3 740 0.9982138300586885
3 741 0.9992344985965808
3 742 0.9992344985965808
3 743 0.999744832865527
3 744 0.999744832865527
3 745 1.0
3 746 0.999744832865527
3 747 0.999744832865527
3 748 0.999744832865527
3 749 0.9994896657310538
3 750 0.9987241643276346
3 751 0.9987241643276346
3 752 0.9987241643276346
3 753 0.9984689971931615
3 754 0.99846

In [9]:
# Persist the learning curve: one test accuracy per optimizer step, plus an
# integer index per entry. After the training cell, timeSeen - 1 equals the
# number of recorded steps, because timeSeen also counted checkpoint 0
# (the untrained network).
# NOTE(review): train() appends accuracies[k] and then saves checkpoint k+1, so
# these indices are offset by one from the checkpoint file numbers — confirm
# which convention downstream code expects.
for out_path, payload in (('accuracies.npy', accuracies),
                          ('indices.npy', range(timeSeen - 1))):
    np.save(out_path, payload)