In [125]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [126]:
class Auxiliary(nn.Module):
    def __init__(self):
        super().__init__()
        
        kernel_size = 3
        out_channels1 = 32
        in_channels1 = 1
        out_channels2 = 64
        in_channels2 = 32
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.conv1_1 = nn.Conv2d(in_channels1, out_channels1, kernel_size)
        self.conv1_2 = nn.Conv2d(in_channels2, out_channels2, kernel_size)
        
        self.conv2_1 = nn.Conv2d(in_channels1, out_channels1, kernel_size)
        self.conv2_2 = nn.Conv2d(in_channels2, out_channels2, kernel_size)
        
        in_features1 = 256
        out_features1 = 120
        out_features2 = 10
        
        self.fc1_1 = nn.Linear(in_features1, out_features1)
        self.fc1_2 = nn.Linear(out_features1, out_features2)

        self.fc2_1 = nn.Linear(in_features1, out_features1)
        self.fc2_2 = nn.Linear(out_features1, out_features2)

        
        self.fc3 = nn.Linear(2 * out_features2, 1)
        
    def forward(self, x):
        x1 = x[:,:1,:]
        x2 = x[:,1:2,:]
        
        x1 = self.pool(F.relu(self.conv1_1(x1)))
        x1 = self.pool(F.relu(self.conv1_2(x1)))
        x1 = torch.flatten(x1, 1)
        x1 = F.relu(self.fc1_1(x1))
        x1 = F.relu(self.fc1_2(x1))
        
        x2 = self.pool(F.relu(self.conv2_1(x2)))
        x2 = self.pool(F.relu(self.conv2_2(x2)))
        x2 = torch.flatten(x2, 1)
        x2 = F.relu(self.fc2_1(x2))
        x2 = F.relu(self.fc2_2(x2))
        
        x = self.fc3(torch.cat((x1, x2), dim=1))
        
        return x, x1, x2

In [127]:
model = Auxiliary()

In [128]:
from dataloading import load_data

trainloader, testloader = load_data()

In [129]:
learning_rate = .001

criterion = nn.BCEWithLogitsLoss()
class_criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
#optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [132]:
epochs = 30

losses = []

for epoch in range(epochs):
    running_loss = 0.0
    epoch_losses = []
    
    for i, data in enumerate(trainloader, 0):
        inputs, target, classes = data
        
        outputs, x1, x2 = model(inputs)
        outputs = outputs.squeeze()
        
        loss = criterion(outputs, target.float())
        
        epoch_losses.append(loss.item())
        
        class_loss1 = class_criterion(x1, classes[:,0])
        class_loss2 = class_criterion(x2, classes[:,1])
        
        loss += class_loss1 + class_loss2
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        
    
    epoch_loss = torch.mean(torch.tensor(epoch_losses))
    print(f'epoch: {epoch}, loss: {epoch_loss}')
    losses.append(epoch_loss)  

epoch: 0, loss: 0.6235746145248413
epoch: 1, loss: 0.6248899102210999
epoch: 2, loss: 0.6175503134727478
epoch: 3, loss: 0.6174079179763794
epoch: 4, loss: 0.6148945093154907
epoch: 5, loss: 0.6070784330368042
epoch: 6, loss: 0.6035206317901611
epoch: 7, loss: 0.6073626279830933
epoch: 8, loss: 0.5984597206115723
epoch: 9, loss: 0.5994621515274048
epoch: 10, loss: 0.5949071049690247
epoch: 11, loss: 0.5884535908699036
epoch: 12, loss: 0.5888424515724182
epoch: 13, loss: 0.5880634784698486
epoch: 14, loss: 0.5843179821968079
epoch: 15, loss: 0.5839698314666748
epoch: 16, loss: 0.580315887928009
epoch: 17, loss: 0.5808995962142944
epoch: 18, loss: 0.5784263610839844
epoch: 19, loss: 0.5759315490722656
epoch: 20, loss: 0.5730262994766235
epoch: 21, loss: 0.578160285949707
epoch: 22, loss: 0.5752938389778137
epoch: 23, loss: 0.5673414468765259
epoch: 24, loss: 0.5658857226371765
epoch: 25, loss: 0.5645912885665894
epoch: 26, loss: 0.5601229071617126
epoch: 27, loss: 0.5627739429473877
epoc

In [131]:
from evaluate import evaluate_model

def model_fn(x):
    outputs, _, _ = model(x)
    
    return outputs
    

evaluate_model(model_fn, testloader)

Accuracy: 0.5559999942779541
