In [4]:
import torch
import pickle
import torch.nn as nn
from backbones import ResNet18Enc, ResNet18Dec

class MixedModel(nn.Module):
    def __init__(self, z_dim):
        super().__init__()
        
        self.wave_encoder = ResNet18Enc(z_dim=z_dim)
        self.time_encoder = ResNet18Enc(z_dim=z_dim)

        self.wave_decoder = ResNet18Dec(z_dim=z_dim)
        self.time_decoder = ResNet18Dec(z_dim=z_dim)
        
    def forward(self, wave, time):
        e_wave, e_time = self.wave_encoder(wave), self.time_encoder(time)
        d_wave, d_time = self.wave_decoder(e_wave), self.time_decoder(e_time)

        return e_wave, e_time, d_wave, d_time

In [16]:
loaded_model = torch.jit.load("traced_joint_model.pt")
loaded_model.eval()

RecursiveScriptModule(
  original_name=MixedModel
  (wave_encoder): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(
      original_name=ResNet18Enc
      (conv1): RecursiveScriptModule(original_name=Conv1d)
      (bn1): RecursiveScriptModule(original_name=BatchNorm1d)
      (layer1): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=BasicBlockEnc
          (conv1): RecursiveScriptModule(original_name=Conv1d)
          (bn1): RecursiveScriptModule(original_name=BatchNorm1d)
          (conv2): RecursiveScriptModule(original_name=Conv1d)
          (bn2): RecursiveScriptModule(original_name=BatchNorm1d)
          (shortcut): RecursiveScriptModule(original_name=Sequential)
        )
        (1): RecursiveScriptModule(
          original_name=BasicBlockEnc
          (conv1): RecursiveScriptModule(original_name=Conv1d)
          (bn1): RecursiveScriptModule(original_name=BatchNorm1d)
        

In [21]:
x,y,z,w = loaded_model(*torch.ones(2, 8, 1, 64).unbind(0))

In [22]:
wave, time = torch.randn(2, 8, 1, 64).unbind(0)

with open("joint_model.pkl", "rb") as f:
    model = pickle.load(f)

In [23]:
model.eval()
xx, yy, zz, ww = model(*torch.ones(2, 8, 1, 64).unbind(0))

In [24]:
xx - x

tensor([[ -68.8670,   23.9643,  -45.3882,  -10.9773, -101.8772],
        [ -68.8670,   23.9643,  -45.3882,  -10.9773, -101.8772],
        [ -68.8670,   23.9643,  -45.3882,  -10.9773, -101.8772],
        [ -68.8670,   23.9643,  -45.3882,  -10.9773, -101.8772],
        [ -68.8670,   23.9643,  -45.3882,  -10.9773, -101.8772],
        [ -68.8670,   23.9643,  -45.3882,  -10.9773, -101.8772],
        [ -68.8670,   23.9643,  -45.3882,  -10.9773, -101.8772],
        [ -68.8670,   23.9643,  -45.3882,  -10.9773, -101.8772]],
       grad_fn=<SubBackward0>)

In [8]:
wave, time = torch.randn(2, 8, 1, 64).unbind(0)
traced_model = torch.jit.trace(model, (wave, time))
traced_model.save("traced_joint_model.pt")