In [97]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [98]:
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu') 

In [99]:
class BaseLine(nn.Module):
    def __init__(self):
        super(BaseLine, self).__init__()
        
        self.fc1 = nn.Linear(392, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)
        
    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x

In [100]:
baseLine = BaseLine()

In [101]:
from dlc_practical_prologue import generate_pair_sets

In [102]:
train_input, train_target, train_classes, test_input, test_target, test_classes = generate_pair_sets(1000)

In [110]:
batch_size = 32

trainset = torch.utils.data.TensorDataset(train_input, train_target, train_classes)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torch.utils.data.TensorDataset(test_input, test_target, test_classes)
testloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=2)

In [104]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(baseline.parameters(), lr=0.001, momentum=0.9)

In [113]:
epochs = 20

In [115]:
losses = []

for epoch in range(epochs):
    running_loss = 0.0
    epoch_losses = []
    
    for i, data in enumerate(trainloader, 0):
        inputs, target, classes = data
        
        optimizer.zero_grad()
        
        outputs = baseLine(inputs)
        outputs = outputs.squeeze()
        
        loss = criterion(outputs, target.float())
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        epoch_losses.append(loss.item())
    
    epoch_loss = torch.mean(torch.tensor(epoch_losses))
    print(f'epoch: {epoch}, loss: {epoch_loss}')
    losses.append(epoch_loss)  

epoch: 0, loss: 2.3640100955963135
epoch: 1, loss: 2.313304901123047
epoch: 2, loss: 2.3007967472076416
epoch: 3, loss: 2.3176662921905518
epoch: 4, loss: 2.3303298950195312
epoch: 5, loss: 2.2992441654205322
epoch: 6, loss: 2.340909957885742
epoch: 7, loss: 2.3045778274536133
epoch: 8, loss: 2.3672242164611816
epoch: 9, loss: 2.31997013092041
epoch: 10, loss: 2.329522132873535
epoch: 11, loss: 2.3496270179748535
epoch: 12, loss: 2.3083980083465576
epoch: 13, loss: 2.325230360031128
epoch: 14, loss: 2.315294027328491
epoch: 15, loss: 2.304314374923706
epoch: 16, loss: 2.338766098022461
epoch: 17, loss: 2.3266441822052
epoch: 18, loss: 2.369011878967285
epoch: 19, loss: 2.324741840362549


In [133]:
correct = 0
total = 0
t = 0

with torch.no_grad():
    for data in testloader:
        inputs, targets, classes = data
        
        outputs = baseLine(inputs)
        
        predictions = torch.round(torch.sigmoid(outputs))
                
        correct += (predictions.squeeze() == targets.squeeze()).sum()
        total += predictions.size(0)

                
acc = correct / total
print(f'Accuracy: {acc}')

Accuracy: 0.4779999852180481
