In [35]:
import json
import torch
import torch.nn as nn
from datasets import ImputationDataset
from torch.utils.data import DataLoader
from models import TransformerEncoderInputter
from datasets import find_padding_masks

In [84]:
transformer_model = TransformerEncoderInputter(feat_dim=35,
                                    max_len=40,
                                    d_model=64, 
                                    n_heads=8, 
                                    num_layers=1,
                                    dim_feedforward=256, 
                                    dropout=0.1, 
                                    freeze=False)
transformer_model.float()

# Load pretrained weights 
transformer_model.load_state_dict(torch.load('../models/inputting_unity_norm.pt'))

<All keys matched successfully>

In [85]:
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(64 * 35 * 40, 20)
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        logits = self.fc2(x)
        return logits

In [86]:
cnn_model = CNNModel()
cnn_model

CNNModel(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (fc1): Linear(in_features=89600, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=2, bias=True)
)

In [87]:
class CombinedModel(nn.Module):
    def __init__(self, transformer_model, cnn_model):
        super(CombinedModel, self).__init__()
        self.transformer_model = transformer_model
        self.cnn_model = cnn_model

    def forward(self, x, padding_mask):
        transformer_output = self.transformer_model(x, padding_mask)
        transformed_output = transformer_output.unsqueeze(1)
        logits_output = self.cnn_model(transformed_output)
        return logits_output

In [88]:
main_model = CombinedModel(transformer_model, cnn_model)
main_model

CombinedModel(
  (transformer_model): TransformerEncoderInputter(
    (project_inp): Linear(in_features=35, out_features=64, bias=True)
    (pos_enc): LearnablePositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (linear1): Linear(in_features=64, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=256, out_features=64, bias=True)
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (output_layer): Linear(in_

In [89]:
with open('../data/data_indices.json', 'r') as f: data_indices = json.load(f)
train_indices = data_indices['train_indices']
train_dataloader = DataLoader(ImputationDataset(train_indices, norm_type='unity', mean_mask_length=3, masking_ratio=0.15), batch_size=10, shuffle=True, drop_last=True)

In [90]:
model.eval()
x, _, _ = next(iter(train_dataloader))
padding_mask = find_padding_masks(x)
x = torch.nan_to_num(x) # replace nan with 0 (since needs to be processed by the model)
probabilities = main_model(x, padding_mask)
probabilities

torch.Size([10, 40, 35])
torch.Size([10, 1, 40, 35])


tensor([[ 0.1553, -0.0554],
        [ 0.1723, -0.0285],
        [ 0.1419, -0.0572],
        [ 0.1556, -0.0286],
        [ 0.1580, -0.0458],
        [ 0.1482, -0.0473],
        [ 0.1514, -0.0428],
        [ 0.1457, -0.0441],
        [ 0.1711, -0.0435],
        [ 0.1557, -0.0394]], grad_fn=<AddmmBackward0>)