In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [77]:
class Auxiliary(nn.Module):
    def __init__(self):
        super().__init__()
        
        kernel_size = 4
        out_channels1 = 6
        in_channels1 = 1
        out_channels2 = 16
        in_channels2 = 6
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.conv1_1 = nn.Conv2d(in_channels1, out_channels1, kernel_size)
        self.conv1_2 = nn.Conv2d(in_channels2, out_channels2, kernel_size)
        
        self.conv2_1 = nn.Conv2d(in_channels1, out_channels1, kernel_size)
        self.conv2_2 = nn.Conv2d(in_channels2, out_channels2, kernel_size)
        
        in_features1 = 16
        out_features1 = 120
        out_features2 = 10
        
        self.fc1_1 = nn.Linear(in_features1, out_features1)
        self.fc1_2 = nn.Linear(out_features1, out_features2)

        self.fc2_1 = nn.Linear(in_features1, out_features1)
        self.fc2_2 = nn.Linear(out_features1, out_features2)

        
        self.fc3 = nn.Linear(2 * out_features2, 1)
        
    def forward(self, x):
        x1 = x[:,:1,:]
        x2 = x[:,1:2,:]
        
        x1 = self.pool(F.relu(self.conv1_1(x1)))
        x1 = self.pool(F.relu(self.conv1_2(x1)))
        x1 = torch.flatten(x1, 1)
        x1 = F.relu(self.fc1_1(x1))
        x1 = F.relu(self.fc1_2(x1))
        
        x2 = self.pool(F.relu(self.conv2_1(x2)))
        x2 = self.pool(F.relu(self.conv2_2(x2)))
        x2 = torch.flatten(x2, 1)
        x2 = F.relu(self.fc2_1(x2))
        x2 = F.relu(self.fc2_2(x2))
        
        x = self.fc3(torch.cat((x1, x2), dim=1))
        
        return x, x1, x2

In [78]:
model = Auxiliary()

In [79]:
from dataloading import load_data

trainloader, testloader = load_data()

In [80]:
learning_rate = .001

criterion = nn.BCEWithLogitsLoss()
class_criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
#optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [81]:
epochs = 20

losses = []

for epoch in range(epochs):
    running_loss = 0.0
    epoch_losses = []
    
    for i, data in enumerate(trainloader, 0):
        inputs, target, classes = data
        
        optimizer.zero_grad()
        
        outputs, x1, x2 = model(inputs)
        outputs = outputs.squeeze()
        
        loss = criterion(outputs, target.float())
        
        class_loss1 = class_criterion(x1, classes[:,0])
        class_loss2 = class_criterion(x2, classes[:,1])
        
        loss += class_loss1 + class_loss2
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        epoch_losses.append(loss.item())
    
    epoch_loss = torch.mean(torch.tensor(epoch_losses))
    print(f'epoch: {epoch}, loss: {epoch_loss}')
    losses.append(epoch_loss)  

epoch: 0, loss: 5.936705589294434
epoch: 1, loss: 5.288024425506592
epoch: 2, loss: 5.293978214263916
epoch: 3, loss: 5.293136119842529
epoch: 4, loss: 5.292878150939941
epoch: 5, loss: 5.292753219604492
epoch: 6, loss: 5.292789936065674
epoch: 7, loss: 5.292542457580566
epoch: 8, loss: 5.291450023651123
epoch: 9, loss: 5.291220188140869
epoch: 10, loss: 5.291980266571045
epoch: 11, loss: 5.2918381690979
epoch: 12, loss: 5.290717124938965
epoch: 13, loss: 5.291734218597412
epoch: 14, loss: 5.2904839515686035
epoch: 15, loss: 5.290371417999268
epoch: 16, loss: 5.2902913093566895
epoch: 17, loss: 5.290802478790283
epoch: 18, loss: 5.290775299072266
epoch: 19, loss: 5.290101051330566


In [82]:
from evaluate import evaluate_model

def model_fn(x):
    outputs, _, _ = model(x)
    
    return outputs
    

evaluate_model(model_fn, testloader)

Accuracy: 0.5529999732971191
