In [1]:
# Imports
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons, CheckButtons
from mpl_toolkits.mplot3d import Axes3D
import math
import os
import time
import itertools

import warnings
# NOTE(review): this suppresses *every* warning for the whole session
# (NumPy/PyTorch deprecation and numerical warnings included); consider
# narrowing it to the specific categories/modules that are actually noisy.
warnings.filterwarnings('ignore')

from matplotlib.patches import Circle, Wedge
from matplotlib.collections import PatchCollection

from scipy.spatial import ConvexHull

# Plot styling: render figure text through LaTeX with amsmath available in
# labels, using a monospace font family.
# NOTE(review): matplotlib's usetex mode relies on an external LaTeX install —
# confirm one is available wherever this notebook is run.
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'monospace'
plt.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

# Interactive (nbagg) figure backend.
# NOTE(review): this backend may not render in JupyterLab / Notebook 7;
# `%matplotlib widget` (ipympl) is the usual replacement — confirm target env.
%matplotlib notebook

# Default figure size, set after the backend has been selected.
plt.rcParams['figure.figsize'] = [7, 7]

import torch
import torch.nn
import torch.utils.data
import torchvision

In [2]:
class SimpleNet(torch.nn.Module):
    """Small MLP: 2 inputs -> 3 ReLU hidden units -> 2 outputs.

    The attribute names ``fc1`` and ``fc4`` are deliberately kept: whole
    modules are pickled with ``torch.save``, so they are part of the saved
    checkpoint format.
    """

    def __init__(self):
        super().__init__()
        # Creation order matters: it fixes how parameter initialisation
        # consumes the torch RNG.
        self.fc1 = torch.nn.Linear(2, 3)
        self.fc4 = torch.nn.Linear(3, 2)

    def forward(self, x):
        """Map inputs whose last axis has size 2 to outputs of size 2."""
        hidden = torch.relu(self.fc1(x))
        return self.fc4(hidden)

In [3]:
def _circle_labels(points, r1sqr):
    """Select the points belonging to either class and one-hot label them.

    A point is in the inner disc when its squared norm is < r1sqr (label
    [0, 1]) and in the ring when 2*r1sqr < squared norm < 3*r1sqr (label
    [1, 0]); points in neither region are dropped.

    Args:
        points: float array of shape (K, 2).
        r1sqr: squared radius of the inner disc.

    Returns:
        (keep, labels): boolean mask of shape (K,), and a float64 array of
        shape (keep.sum(), 2) holding the labels of ``points[keep]`` in order.
    """
    sq_norm = np.sum(np.square(points), axis=1)
    inner = sq_norm < r1sqr
    ring = ~inner & (2 * r1sqr < sq_norm) & (sq_norm < 3 * r1sqr)
    keep = inner | ring
    labels = np.where(inner[keep][:, None], [0.0, 1.0], [1.0, 0.0])
    return keep, labels


def _sample_train_points(n_train, r1sqr, seed, chunk=4096):
    """Rejection-sample ``n_train`` labelled points uniformly from [-1, 1]^2.

    Uses a private ``np.random.RandomState(seed)`` and draws candidates in
    chunks; the accepted points come out in the same order as drawing one
    (1, 2) sample at a time after ``np.random.seed(seed)``, but the global
    NumPy RNG state is left untouched.
    """
    rng = np.random.RandomState(seed)
    kept_x, kept_y, n_kept = [], [], 0
    while n_kept < n_train:
        candidates = 2 * rng.uniform(size=(chunk, 2)) - 1
        keep, labels = _circle_labels(candidates, r1sqr)
        kept_x.append(candidates[keep])
        kept_y.append(labels)
        n_kept += len(labels)
    if not kept_x:
        return np.empty((0, 2)), np.empty((0, 2))
    return np.concatenate(kept_x)[:n_train], np.concatenate(kept_y)[:n_train]


def _grid_test_points(grid_points, r1sqr):
    """Labelled points of a regular grid_points x grid_points grid on [-1, 1]^2.

    Rows are ordered with the first coordinate varying slowest.
    """
    coords = 2 * np.arange(grid_points) / (grid_points - 1) - 1
    xx, yy = np.meshgrid(coords, coords, indexing='ij')
    grid = np.stack([xx.ravel(), yy.ravel()], axis=1)
    keep, labels = _circle_labels(grid, r1sqr)
    return grid[keep], labels


def get_dataloader(train=True, r1=0.5, batch_size=100, shuffle=False,
                   n_train=100000, grid_points=101, seed=0):
    """Build a DataLoader for the disc-vs-ring 2-D classification problem.

    Class [0, 1] is the disc |p|^2 < r1^2; class [1, 0] is the ring
    2*r1^2 < |p|^2 < 3*r1^2. Points outside both regions are never produced.

    Args:
        train: if True, rejection-sample ``n_train`` random points from
            [-1, 1]^2; otherwise use the points of a regular
            ``grid_points`` x ``grid_points`` grid that fall in either class.
            Only the requested split is generated.
        r1: inner-disc radius; must be non-zero.
        batch_size, shuffle: forwarded to ``torch.utils.data.DataLoader``.
        n_train: number of training points (default 100000).
        grid_points: grid resolution per axis for the test split (>= 2).
        seed: seed of the private NumPy generator used for training points;
            the global NumPy RNG is not reseeded.

    Returns:
        DataLoader over a TensorDataset of float32 tensors. Inputs have shape
        (N, 1, 2) for the train split and (N, 2) for the test split; targets
        are one-hot with shape (N, 1, 2) for both.

    Raises:
        ValueError: if r1 == 0 (both classes would be empty and sampling would
            never terminate) or, for the test split, if grid_points < 2.
    """
    r1sqr = r1**2
    if not r1sqr > 0:
        raise ValueError('r1 must be non-zero, got {!r}'.format(r1))

    if train:
        points, labels = _sample_train_points(n_train, r1sqr, seed)
        # Training inputs keep a singleton middle axis: (N, 1, 2).
        points = points[:, None, :]
    else:
        if grid_points < 2:
            raise ValueError('grid_points must be >= 2, got {!r}'.format(grid_points))
        points, labels = _grid_test_points(grid_points, r1sqr)

    tensor_x = torch.tensor(points.astype(np.float32))
    tensor_y = torch.tensor(labels.astype(np.float32)[:, None, :])

    my_dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y)
    return torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, shuffle=shuffle)

In [4]:
# Training batches are reshuffled every epoch. The evaluation loader is NOT
# shuffled: accuracy is a count over the whole set, so order is irrelevant,
# and a shuffling sampler would draw from the global torch RNG on every
# evaluation, perturbing the RNG that drives the training shuffle.
train_loader = get_dataloader(train=True, batch_size=100, shuffle=True)
test_loader = get_dataloader(train=False, batch_size=100, shuffle=False)

In [5]:
# Global training log updated by train(): the index of the next checkpoint
# file to write, and the test accuracy recorded after each optimisation step.
timeSeen, accuracies = 0, []

In [6]:
def train(network, optimizer, lossFunc, train_loader, epoch,
          eval_loader=None, save_dir='trainedNetsCircle'):
    """Run one epoch of training, evaluating and checkpointing after every step.

    After each optimiser step the network is evaluated with ``test`` on
    ``eval_loader``; the accuracy is appended to the global ``accuracies``
    list and printed as ``epoch batch_idx accuracy``. The whole module is then
    saved with ``torch.save`` as ``<save_dir>/network{timeSeen:09d}.pt`` and
    the global ``timeSeen`` counter is incremented.

    Args:
        network: the ``torch.nn.Module`` being trained.
        optimizer: optimiser over ``network``'s parameters.
        lossFunc: callable ``(output, target) -> scalar loss tensor``.
        train_loader: iterable of ``(data, target)`` batches.
        epoch: epoch number, used only in the progress print.
        eval_loader: loader passed to ``test``; defaults to the notebook's
            global ``test_loader``.
        save_dir: checkpoint directory; created if it does not exist.
    """
    global timeSeen
    global accuracies
    if eval_loader is None:
        eval_loader = test_loader
    os.makedirs(save_dir, exist_ok=True)

    for batch_idx, (data, target) in enumerate(train_loader):
        # test() puts the network in eval mode, so re-enable training mode
        # before every step instead of only once before the loop.
        network.train()
        optimizer.zero_grad()
        output = network(data)
        loss = lossFunc(output, target)
        loss.backward()
        optimizer.step()

        accuracies.append(test(network, eval_loader))
        print(epoch, batch_idx, accuracies[-1])

        torch.save(network, os.path.join(save_dir, 'network{:09d}.pt'.format(timeSeen)))
        timeSeen += 1
        
def test(network, test_loader):
    network.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            pred = output.data.max(1, keepdim=True)[1]
            tar = target.data.max(2, keepdim=True)[1]
            correct += pred.eq(tar.data.view_as(pred)).sum()
    return correct.item()/len(test_loader.dataset)

In [7]:
# Retrain from a fresh random initialisation until some evaluation during the
# run reaches TARGET_ACCURACY, giving up after MAX_ATTEMPTS. Each attempt
# resets the global checkpoint counter and accuracy log; every attempt runs the
# same number of steps, so the successful run overwrites all earlier files.
TARGET_ACCURACY = 0.98
N_EPOCHS = 1
MAX_ATTEMPTS = 50

torch.manual_seed(0)  # reproducible initialisations and shuffles across restarts
os.makedirs('trainedNetsCircle', exist_ok=True)
criterion = torch.nn.MSELoss()

for attempt in range(1, MAX_ATTEMPTS + 1):
    network = SimpleNet()
    optimizer = torch.optim.SGD(network.parameters(), 0.1, momentum=0.9, weight_decay=1e-4)
    timeSeen = 0
    accuracies = []

    # Checkpoint 0 is the untrained network.
    torch.save(network, os.path.join('trainedNetsCircle', 'network{:09d}.pt'.format(timeSeen)))
    timeSeen += 1
    for epoch in range(N_EPOCHS):
        train(network, optimizer, criterion, train_loader, epoch)
        print(epoch, accuracies[-1])

    if max(accuracies + [0]) >= TARGET_ACCURACY:
        break
else:
    raise RuntimeError('no run reached accuracy {} within {} attempts'.format(
        TARGET_ACCURACY, MAX_ATTEMPTS))

0 0 0.594284256187803
0 1 0.626180147996938
0 2 0.5787190609849452
0 3 0.46746619035468234
0 4 0.38581270732329676
0 5 0.3689716764480735
0 6 0.3763715233477928
0 7 0.38428170451645827
0 8 0.4092880836948201
0 9 0.4940035723398826
0 10 0.613166624138811
0 11 0.718550650676193
0 12 0.6014289359530492
0 13 0.6057667772390916
0 14 0.632814493493238
0 15 0.6560347027302883
0 16 0.6585863740750192
0 17 0.6404695075274305
0 18 0.6180147996937995
0 19 0.6432763460066343
0 20 0.6481245215616228
0 21 0.6381730033171727
0 22 0.6330696606277112
0 23 0.6126562898698648
0 24 0.6139321255422302
0 25 0.6157182954835417
0 26 0.632559326358765
0 27 0.6282214850727227
0 28 0.6346006634345497
0 29 0.639193671855065
0 30 0.6629242153610615
0 31 0.6639448838989538
0 32 0.7042612911457005
0 33 0.7203368206175045
0 34 0.7333503444756315
0 35 0.744322531257974
0 36 0.747894871140597
0 37 0.744322531257974
0 38 0.7430466955856085
0 39 0.7420260270477163
0 40 0.7448328655269202
0 41 0.7550395509058433
0 42 0.76

0 578 1.0
0 579 1.0
0 580 1.0
0 581 1.0
0 582 1.0
0 583 0.999744832865527
0 584 1.0
0 585 1.0
0 586 1.0
0 587 1.0
0 588 1.0
0 589 1.0
0 590 1.0
0 591 1.0
0 592 1.0
0 593 1.0
0 594 1.0
0 595 1.0
0 596 1.0
0 597 1.0
0 598 1.0
0 599 1.0
0 600 1.0
0 601 1.0
0 602 1.0
0 603 1.0
0 604 1.0
0 605 1.0
0 606 1.0
0 607 1.0
0 608 1.0
0 609 1.0
0 610 1.0
0 611 1.0
0 612 1.0
0 613 1.0
0 614 1.0
0 615 1.0
0 616 1.0
0 617 1.0
0 618 1.0
0 619 1.0
0 620 1.0
0 621 1.0
0 622 1.0
0 623 1.0
0 624 1.0
0 625 1.0
0 626 1.0
0 627 1.0
0 628 1.0
0 629 1.0
0 630 1.0
0 631 1.0
0 632 1.0
0 633 1.0
0 634 1.0
0 635 1.0
0 636 1.0
0 637 1.0
0 638 1.0
0 639 1.0
0 640 1.0
0 641 1.0
0 642 1.0
0 643 1.0
0 644 1.0
0 645 1.0
0 646 1.0
0 647 1.0
0 648 1.0
0 649 1.0
0 650 1.0
0 651 1.0
0 652 1.0
0 653 1.0
0 654 1.0
0 655 1.0
0 656 1.0
0 657 0.9977034957897423
0 658 0.996937994386323
0 659 1.0
0 660 1.0
0 661 1.0
0 662 1.0
0 663 1.0
0 664 1.0
0 665 1.0
0 666 1.0
0 667 1.0
0 668 1.0
0 669 1.0
0 670 1.0
0 671 1.0
0 672 1.0
0 673 1

In [8]:
# Persist the per-step accuracy log and one index per accuracy entry.
# NOTE(review): accuracies[k] is measured on the network that train() then
# saves as checkpoint file k+1 (file 0 is the untrained net), while the saved
# indices run 0..len(accuracies)-1 — confirm the consumer expects this pairing.
accuracy_log = np.asanyarray(accuracies)
step_indices = np.asanyarray(range(timeSeen - 1))
np.save('accuraciesCircle.npy', accuracy_log)
np.save('indicesCircle.npy', step_indices)