In [None]:
import torch
import torch.nn as nn

In [None]:
class BaseModel(nn.Module):
  """Base class providing shared train / evaluate loops for two-input models.

  Subclasses implement ``forward(image_input, tab_input)``. Each batch yielded
  by a dataloader is indexed as ``(image, tabular, label)``: all three are cast
  to float32 and moved to ``self.device``, and the tabular input is reshaped to
  a ``(-1, 1)`` column before being passed to the model.
  """

  def __init__(self, device):
    """Store the device that batches are moved to.

    Args:
      device: any value accepted by ``torch.Tensor.to`` (e.g. ``'cpu'``,
        ``'cuda'`` or a ``torch.device``).
    """
    super().__init__()
    self.device = device

  def forward(self, X):
    """Placeholder; concrete models must override this method.

    Raises:
      NotImplementedError: always -- returning ``None`` here would only fail
        later with an unhelpful ``AttributeError`` inside the train/eval loops.
    """
    raise NotImplementedError(f'{type(self).__name__} must implement forward()')

  def _unpack_batch(self, instance):
    """Return ``(image_input, tab_input, labels)`` as float32 on ``self.device``."""
    image_input = instance[0].to(self.device, torch.float32)
    # The tabular feature is fed to the model as a single column.
    tab_input = instance[1].to(self.device, torch.float32).reshape(-1, 1)
    labels = instance[2].to(self.device, torch.float32)
    return image_input, tab_input, labels

  @staticmethod
  def _concat_to_numpy(chunks):
    """Concatenate per-batch tensors along dim 0 into a NumPy array (empty if none)."""
    if not chunks:
      return torch.empty(0).numpy()
    return torch.cat(chunks, dim=0).cpu().numpy()

  def train_model(self, criterion, optimizer, dataloader):
    """Run one training epoch over ``dataloader``.

    For every batch: zero the gradients, compute ``criterion`` on outputs and
    labels (both reshaped to ``(-1, 1)``), backpropagate and step the optimizer.
    Every 10th batch the batch loss is printed.

    Args:
      criterion: loss callable taking ``(predictions, targets)``.
      optimizer: a ``torch.optim`` optimizer over this model's parameters.
      dataloader: iterable of indexable ``(image, tabular, label)`` batches.

    Returns:
      ``(labels, predictions, mean_loss)`` -- labels and raw model outputs
      concatenated along dim 0 as NumPy arrays, and the mean per-batch loss
      (``nan`` if the dataloader yielded no batches).
    """
    self.train()

    label_chunks, pred_chunks = [], []
    running_loss = 0.0
    num_batches = 0

    for i, instance in enumerate(dataloader):
      image_input, tab_input, labels = self._unpack_batch(instance)

      optimizer.zero_grad()
      outputs = self(image_input, tab_input).to(torch.float32)

      loss = criterion(outputs.reshape(-1, 1), labels.reshape(-1, 1))
      loss.backward()
      optimizer.step()

      batch_loss = loss.item()
      running_loss += batch_loss
      num_batches += 1

      # Detach so the accumulated predictions do not keep every batch's
      # autograd graph alive until the end of the epoch.
      label_chunks.append(labels.detach())
      pred_chunks.append(outputs.detach())

      if (i + 1) % 10 == 0:
        print(f'Batch {i + 1} Loss = {batch_loss}')

    mean_loss = running_loss / num_batches if num_batches else float('nan')
    return self._concat_to_numpy(label_chunks), self._concat_to_numpy(pred_chunks), mean_loss

  def evaluate_model(self, criterion, dataloader):
    """Evaluate the model over ``dataloader`` without tracking gradients.

    Puts the module in eval mode (it is not switched back afterwards). Model
    outputs are flattened to 1-D before computing ``criterion`` against the
    flattened labels.

    Args:
      criterion: loss callable taking ``(predictions, targets)``.
      dataloader: iterable of indexable ``(image, tabular, label)`` batches.

    Returns:
      ``(labels, predictions, mean_loss)`` -- labels and flattened outputs
      concatenated along dim 0 as NumPy arrays, and the mean per-batch loss
      (``nan`` if the dataloader yielded no batches).
    """
    self.eval()

    label_chunks, pred_chunks = [], []
    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
      for instance in dataloader:
        image_input, tab_input, labels = self._unpack_batch(instance)

        outputs = self(image_input, tab_input).to(torch.float32).reshape(-1)

        loss = criterion(outputs, labels.reshape(-1))
        running_loss += loss.item()
        num_batches += 1

        label_chunks.append(labels)
        pred_chunks.append(outputs)

    mean_loss = running_loss / num_batches if num_batches else float('nan')
    return self._concat_to_numpy(label_chunks), self._concat_to_numpy(pred_chunks), mean_loss