In [None]:
import random

import numpy as np
import torch

from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from collidium import *

In [None]:
# Seed every RNG the pipeline may draw from, so that Restart & Run All reproduces
# the same datasets, weight initialisation and shuffling order.
# NOTE(review): which RNG collidium's generators use is not visible from here;
# seeding Python's `random`, NumPy and torch covers the common cases — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Write the train / test CSVs (8000 and 2000 samples).
# NOTE(review): assumes the generator signature is (n_samples, output_path) — confirm.
circles_proximity_queries_dataset_generator(8000, 'train.csv')
circles_proximity_queries_dataset_generator(2000, 'test.csv')
# NOTE(review): presumably renders an example figure from the training CSV — confirm.
random_figure_generator('train.csv')

In [None]:
# Wrap the generated CSVs as datasets and batch them (64 samples per batch).
training_data = ProximityQueriesDataset('train.csv')
test_data = ProximityQueriesDataset('test.csv')
# Only the training loader shuffles. Evaluation metrics are aggregated over the
# whole set, so a fixed test order adds nothing except run-to-run nondeterminism.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
# Use the current accelerator's device type when one is available, otherwise the CPU.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

In [None]:
# Optimisation hyper-parameters.
learning_rate = 1e-3
weight_decay = 1e-5
# NOTE(review): the DataLoaders are built in an earlier cell with a literal batch
# size of 64. This value is not read by them, so keep the two in sync by hand.
batch_size = 64
epochs = 10

# NOTE(review): input_size=6 presumably matches the number of feature columns
# produced by ProximityQueriesDataset — confirm against collidium.
model = ShallowNet(input_size=6, hidden_size=32, output_size=1).to(device)
# BCEWithLogitsLoss applies the sigmoid internally; this assumes ShallowNet
# returns raw logits (no final sigmoid) — confirm.
loss_fn = nn.BCEWithLogitsLoss()
# weight_decay must be passed explicitly: Adam's default is 0, i.e. no
# regularisation. Adam applies it as an L2 term added to the gradient
# (use AdamW for decoupled weight decay).
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch; return the per-sample mean loss and accuracy.

    Args:
        dataloader: iterable of ``(X, y)`` batches whose ``.dataset`` has a length,
            used as the sample count for averaging.
        model: module producing logits; it is put in training mode.
        loss_fn: callable ``loss_fn(logits, y)`` returning a scalar tensor. It is
            re-weighted by batch size, so it is assumed to be a batch *mean*
            (the default reduction of ``nn.BCEWithLogitsLoss``).
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter, so the function does not depend on a notebook global.

    Returns:
        ``(avg_loss, avg_acc)`` as Python floats.

    Raises:
        ValueError: if ``dataloader.dataset`` is empty.
    """
    if device is None:
        device = next(model.parameters()).device
    n_samples = len(dataloader.dataset)
    if n_samples == 0:
        raise ValueError("train_loop: dataloader.dataset is empty")

    model.train()
    total_loss, total_correct = 0.0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        logits = model(X)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Undo the batch-mean so the epoch average is per sample, even when the
        # last batch is smaller.
        total_loss += loss.item() * X.size(0)
        # Threshold the sigmoid at 0.5. Counting is element-wise, which equals
        # per-sample accuracy only for a single output per sample, with y holding
        # 0/1 labels shaped like the logits (assumed here — confirm).
        preds = (torch.sigmoid(logits) > 0.5).float()
        total_correct += (preds == y).sum().item()

    return total_loss / n_samples, total_correct / n_samples

def test_loop(dataloader, model, loss_fn):
    model.eval()
    total_loss, total_acc = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            logits = model(X)

            loss = loss_fn(logits, y)
            total_loss += loss.item() * X.size(0)
            y_hat = (torch.sigmoid(logits) > 0.5).float()
            total_acc += (y_hat == y).sum().item()

        avg_loss = total_loss / len(dataloader.dataset)
        avg_acc = total_acc / len(dataloader.dataset)

        return avg_loss, avg_acc

In [None]:
# Per-epoch metric histories (one float appended per epoch).
train_loss, train_acc = [], []
test_loss, test_acc = [], []

for epoch in range(1, epochs + 1):
    # Keep the per-epoch scalars under their own names. Unpacking into
    # train_loss/train_acc/... would rebind the history lists to floats, and
    # the .append calls would then raise AttributeError.
    epoch_train_loss, epoch_train_acc = train_loop(train_dataloader, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test_loop(test_dataloader, model, loss_fn)

    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

    print(
        f"epoch {epoch:>3}/{epochs}: "
        f"train loss {epoch_train_loss:.4f}, acc {epoch_train_acc:.3f} | "
        f"test loss {epoch_test_loss:.4f}, acc {epoch_test_acc:.3f}"
    )


In [None]:
# Pickle the entire module object (class reference + parameters), not just a state_dict.
# NOTE(review): loading this file needs ShallowNet importable and, on
# torch >= 2.6, torch.load(..., weights_only=False). Saving model.state_dict()
# is the format PyTorch recommends if consumers can be updated.
torch.save(obj=model, f='proximity_queries_model.pth')