In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [35]:
class Shared(nn.Module):
    """Weight-shared binary classifier over a pair of single-channel images.

    Channels 0 and 1 of the input are each run through the *same*
    convolutional tower (`conv_forward`), so the two branches share all
    weights. The two 84-dimensional embeddings are concatenated and mapped
    to a single raw logit by `fc3` (no sigmoid is applied here).

    NOTE(review): `fc1` expects 256 = 64 * 2 * 2 flattened features, which
    is what two (3x3 conv, 2x2 pool) stages produce from a 14x14 image —
    confirm the loader yields 14x14 inputs.
    """

    def __init__(self):
        super().__init__()

        # Keep the original construction order: it fixes parameter
        # registration order and the RNG draws used for initialisation.
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        # Two 84-d embeddings, concatenated, map to one logit.
        self.fc3 = nn.Linear(2 * 84, 1)

    def conv_forward(self, x):
        """Embed a batch of single-channel images into 84-d feature vectors.

        Args:
            x: image batch with one channel (conv1 has in_channels=1);
               presumably (N, 1, 14, 14) — TODO confirm.

        Returns:
            Tensor of shape (N, 84) after the final ReLU.
        """
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        features = torch.flatten(x, start_dim=1)
        for layer in (self.fc1, self.fc2):
            features = F.relu(layer(features))
        return features

    def forward(self, x):
        """Score a batch of image pairs.

        Args:
            x: batch whose dim 1 holds the two images to compare; only
               channels 0 and 1 are used. Presumably (N, 2, 14, 14).

        Returns:
            Raw logits of shape (N, 1).
        """
        left = self.conv_forward(x[:, 0:1])
        right = self.conv_forward(x[:, 1:2])
        return self.fc3(torch.cat([left, right], dim=1))

In [36]:
# Seed torch's global RNG so weight initialisation (and any torch-based
# shuffling in later cells) is reproducible under Restart & Run All.
# NOTE(review): if `load_data` draws from `random` or numpy RNGs, seed those too.
SEED = 0
torch.manual_seed(SEED)

shared = Shared()

In [37]:
from dataloading import load_data

In [38]:
# `load_data` comes from the project-local `dataloading` module; its result
# is unpacked as the training loader followed by the test loader.
[trainloader, testloader] = load_data()

In [39]:
learning_rate = .001

# `Shared.forward` returns raw logits from `fc3` (no sigmoid), so use the
# combined sigmoid + binary cross-entropy loss, which is more numerically
# stable than applying a sigmoid and then BCELoss.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(shared.parameters(), lr=learning_rate)

In [40]:
epochs = 5

# Mean training loss per epoch, stored as 0-d tensors.
losses = []

# Make sure the model is in training mode even if an earlier run of the
# evaluation cell switched it to eval mode. NOTE(review): the model has no
# dropout/batch-norm today, so this only guards future changes.
shared.train()

for epoch in range(epochs):
    epoch_losses = []

    # Batches unpack into (inputs, target, classes); `classes` is unused here.
    for inputs, target, _classes in trainloader:
        optimizer.zero_grad()

        # The model returns shape (N, 1). Drop only that singleton dim:
        # a bare .squeeze() would turn a final batch of size 1 into a 0-d
        # tensor, and BCEWithLogitsLoss would then reject the (1,) target.
        outputs = shared(inputs).squeeze(1)

        loss = criterion(outputs, target.float())
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

    epoch_loss = torch.mean(torch.tensor(epoch_losses))
    print(f'epoch: {epoch}, loss: {epoch_loss}')
    losses.append(epoch_loss)

epoch: 0, loss: 0.6707589626312256
epoch: 1, loss: 0.3997456431388855
epoch: 2, loss: 0.31657156348228455
epoch: 3, loss: 0.27291736006736755
epoch: 4, loss: 0.259965181350708


In [41]:
from evaluate import evaluate_model

# Evaluate the trained model on `testloader` using the project-local
# `evaluate` helper; per the captured cell output, it prints the accuracy.
# NOTE(review): confirm whether `evaluate_model` also switches `shared` to eval mode.
evaluate_model(shared, testloader)

Accuracy: 0.8429999947547913
