In [1]:
import json
import torch
import torch.nn as nn
from datasets import ImputationDataset
from torch.utils.data import DataLoader
from models import TransformerEncoder
from datasets import find_padding_masks

In [2]:
transformer_model = TransformerEncoder(feat_dim=35,
                                    max_len=40,
                                    d_model=64, 
                                    n_heads=8, 
                                    num_layers=1,
                                    dim_feedforward=256, 
                                    dropout=0.1, 
                                    freeze=False)
transformer_model.float()

# Load pretrained weights 
transformer_model.load_state_dict(torch.load('../models/inputting_unity_norm.pt'))

<All keys matched successfully>

In [3]:
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(64 * 35 * 40, 20)
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        logits = self.fc2(x)
        return logits

In [4]:
cnn_model = CNNModel()
cnn_model

CNNModel(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (fc1): Linear(in_features=89600, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=2, bias=True)
)

In [5]:
class CombinedModel(nn.Module):
    def __init__(self, transformer_model, cnn_model):
        super(CombinedModel, self).__init__()
        self.transformer_model = transformer_model
        self.cnn_model = cnn_model

    def forward(self, x, padding_mask):
        transformer_output = self.transformer_model(x, padding_mask)
        transformed_output = transformer_output.unsqueeze(1)
        logits_output = self.cnn_model(transformed_output)
        return logits_output

In [6]:
main_model = CombinedModel(transformer_model, cnn_model)
main_model

CombinedModel(
  (transformer_model): TransformerEncoder(
    (project_inp): Linear(in_features=35, out_features=64, bias=True)
    (pos_enc): LearnablePositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (linear1): Linear(in_features=64, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=256, out_features=64, bias=True)
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (output_layer): Linear(in_features

In [7]:
with open('../data/data_indices.json', 'r') as f: data_indices = json.load(f)
train_indices = data_indices['train_indices']
train_dataloader = DataLoader(ImputationDataset(train_indices, norm_type='unity', mean_mask_length=3, masking_ratio=0.15), batch_size=10, shuffle=True, drop_last=True)

In [8]:
main_model.eval()
x, _, _ = next(iter(train_dataloader))
padding_mask = find_padding_masks(x)
x = torch.nan_to_num(x) # replace nan with 0 (since needs to be processed by the model)
probabilities = main_model(x, padding_mask)
probabilities

tensor([[0.2484, 0.2286],
        [0.2265, 0.1931],
        [0.1994, 0.1556],
        [0.2204, 0.1553],
        [0.2171, 0.1671],
        [0.2285, 0.2240],
        [0.2230, 0.1962],
        [0.2366, 0.2181],
        [0.2256, 0.1924],
        [0.2259, 0.2230]], grad_fn=<AddmmBackward0>)

In [11]:
def train_and_validate_classifier(model, train_loader, test_loader, n_epoch)

    best_test_loss = 1e20
    running_batch_loss_train = []
    running_batch_loss_test = []
    save_path = '../models/classification_unity_norm.pt'

    for epoch in range(n_epoch):
        print(f"Epoch: {epoch + 1}")

        # Train loop 
        for x, _, y in train_loader:
            padding_mask = find_padding_masks(x)
            x = torch.nan_to_num(x).to(device)
            y_hat = main_model(x, padding_mask).to(device)
            optimizer.zero_grad()
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            running_batch_loss_train.append(loss.item())

        # Test loop
        for x, _, y in test_loader:
            padding_mask = find_padding_masks(x)
            x = torch.nan_to_num(x).to(device)
            y_hat = main_model(x, padding_mask).to(device)
            loss = criterion(y_hat, y)
            running_batch_loss_test.append(loss.item())

        # Save model if test loss is lower than best test loss
        if running_batch_loss_test[-1] < best_test_loss:
            best_test_loss = running_batch_loss_test[-1]
            torch.save(main_model.state_dict(), save_path)
            print(f"Saved model at epoch {epoch + 1}")

        # Print loss
        print(f"Train loss: {running_batch_loss_train[-1]}")
        print(f"Test loss: {running_batch_loss_test[-1]}")

    return running_batch_loss_train, running_batch_loss_test

n_epoch = 100
train_loader = train_dataloader
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(main_model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
main_model.to(device)

train_loss, test_loss = train_and_validate_classifier(main_model, train_loader, train_loader, n_epoch)

Epoch: 1
tensor([1, 1, 1, 1, 0, 1, 1, 1, 0, 1])
