In [111]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [112]:
class Auxiliary(nn.Module):
    def __init__(self):
        super().__init__()
        
        kernel_size = 3
        out_channels1 = 32
        in_channels1 = 1
        out_channels2 = 64
        in_channels2 = 32
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.conv1_1 = nn.Conv2d(in_channels1, out_channels1, kernel_size)
        self.conv1_2 = nn.Conv2d(in_channels2, out_channels2, kernel_size)
        
        self.conv2_1 = nn.Conv2d(in_channels1, out_channels1, kernel_size)
        self.conv2_2 = nn.Conv2d(in_channels2, out_channels2, kernel_size)
        
        in_features1 = 256
        out_features1 = 120
        out_features2 = 10
        
        self.fc1_1 = nn.Linear(in_features1, out_features1)
        self.fc1_2 = nn.Linear(out_features1, out_features2)

        self.fc2_1 = nn.Linear(in_features1, out_features1)
        self.fc2_2 = nn.Linear(out_features1, out_features2)

        
        self.fc3 = nn.Linear(2 * out_features2, 1)
        
    def forward(self, x):
        x1 = x[:,:1,:]
        x2 = x[:,1:2,:]
        
        x1 = self.pool(F.relu(self.conv1_1(x1)))
        x1 = self.pool(F.relu(self.conv1_2(x1)))
        x1 = torch.flatten(x1, 1)
        x1 = F.relu(self.fc1_1(x1))
        x1 = F.relu(self.fc1_2(x1))
        
        x2 = self.pool(F.relu(self.conv2_1(x2)))
        x2 = self.pool(F.relu(self.conv2_2(x2)))
        x2 = torch.flatten(x2, 1)
        x2 = F.relu(self.fc2_1(x2))
        x2 = F.relu(self.fc2_2(x2))
        
        x = self.fc3(torch.cat((x1, x2), dim=1))
        
        return x, x1, x2

In [113]:
model = Auxiliary()

In [114]:
from dataloading import load_data

trainloader, testloader = load_data()

In [115]:
learning_rate = .001

criterion = nn.BCEWithLogitsLoss()
class_criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
#optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [116]:
epochs = 20

losses = []

for epoch in range(epochs):
    running_loss = 0.0
    epoch_losses = []
    
    for i, data in enumerate(trainloader, 0):
        inputs, target, classes = data
        
        outputs, x1, x2 = model(inputs)
        outputs = outputs.squeeze()
        
        loss = criterion(outputs, target.float())
        
        epoch_losses.append(loss.item())
        
        class_loss1 = class_criterion(x1, classes[:,0])
        class_loss2 = class_criterion(x2, classes[:,1])
        
        loss += class_loss1 + class_loss2
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        
    
    epoch_loss = torch.mean(torch.tensor(epoch_losses))
    print(f'epoch: {epoch}, loss: {epoch_loss}')
    losses.append(epoch_loss)  

epoch: 0, loss: 0.8248533606529236
epoch: 1, loss: 0.6929522752761841
epoch: 2, loss: 0.6838962435722351
epoch: 3, loss: 0.681434154510498
epoch: 4, loss: 0.6716891527175903
epoch: 5, loss: 0.6625599265098572
epoch: 6, loss: 0.6620429754257202
epoch: 7, loss: 0.6541255116462708
epoch: 8, loss: 0.6503430008888245
epoch: 9, loss: 0.6267626285552979
epoch: 10, loss: 0.6119005084037781
epoch: 11, loss: 0.6117603182792664
epoch: 12, loss: 0.5827203392982483
epoch: 13, loss: 0.5704416632652283
epoch: 14, loss: 0.5634921193122864
epoch: 15, loss: 0.5546033382415771
epoch: 16, loss: 0.5490549802780151
epoch: 17, loss: 0.5436999797821045
epoch: 18, loss: 0.5415211319923401
epoch: 19, loss: 0.5350039601325989


In [117]:
from evaluate import evaluate_model

def model_fn(x):
    outputs, _, _ = model(x)
    
    return outputs
    

evaluate_model(model_fn, testloader)

Accuracy: 0.6620000004768372
