In [1]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class MyModelA(pl.LightningModule):
    """Single linear layer trained as a regressor with MSE loss.

    Args:
        hidden_dim: Number of input features of the linear layer. The
            layer always produces 2 outputs.
    """

    def __init__(self, hidden_dim=10):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features=hidden_dim, out_features=2)
        # Store the constructor arguments so they are reachable through
        # ``self.hparams`` later on.
        self.save_hyperparameters()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr = 1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.fc1(x)

    def training_step(self, batch, batch_idx):
        """Return the MSE between the prediction and the target of ``batch``.

        ``batch`` is unpacked as an ``(input, target)`` pair; ``batch_idx``
        is not used.
        """
        inputs, targets = batch
        predictions = self.forward(inputs)
        return F.mse_loss(predictions, targets)
    
class MyModelB(pl.LightningModule):
    """Second base model: one linear layer fitted with an MSE objective.

    Args:
        hidden_dim: Size of the input expected by the linear layer; the
            output size is fixed to 2.
    """

    def __init__(self, hidden_dim=10):
        super().__init__()
        n_inputs, n_outputs = hidden_dim, 2
        self.fc1 = torch.nn.Linear(n_inputs, n_outputs)
        # Keep the init arguments available under ``self.hparams``.
        self.save_hyperparameters()

    def configure_optimizers(self):
        """Build the Adam optimizer (learning rate 1e-3) for this module."""
        trainable = self.parameters()
        return torch.optim.Adam(trainable, lr=1e-3)

    def forward(self, x):
        """Run ``x`` through the linear layer."""
        out = self.fc1(x)
        return out

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss for one ``(input, target)`` batch.

        ``batch_idx`` is accepted for the Lightning interface but unused.
        """
        features, target = batch
        return F.mse_loss(self.forward(features), target)

class MyEnsemble(pl.LightningModule):
    """Ensemble of a frozen ``MyModelA`` and ``MyModelB`` with a linear head.

    The outputs of both sub-models are concatenated along dim 1 and fed to a
    trainable ``Linear(4, 2)`` classifier. Only the classifier is updated
    during training because both sub-models are frozen.

    Args:
        modelA_hparams: Mapping of keyword arguments used to build
            ``MyModelA``.
        modelB_hparams: Mapping of keyword arguments used to build
            ``MyModelB``.
        modelA_params: Optional flattened weights for ``MyModelA``. Maps
            each state-dict key to ``{"value": tensor, "shape": shape}``;
            ``value`` is reshaped to ``shape`` before loading. Skipped when
            ``None`` or empty.
        modelB_params: Same as ``modelA_params`` but for ``MyModelB``.
    """

    def __init__(self,
                 modelA_hparams, modelB_hparams,
                 modelA_params = None, modelB_params = None):
        super().__init__()
        self.modelA = MyModelA(**modelA_hparams)
        # Model B must be built from its own hyperparameters; it was
        # previously constructed from ``modelA_hparams`` by mistake, which
        # silently ignored ``modelB_hparams``.
        self.modelB = MyModelB(**modelB_hparams)

        if modelA_params:
            self.modelA.load_state_dict(self._unflatten(modelA_params))
        if modelB_params:
            self.modelB.load_state_dict(self._unflatten(modelB_params))

        # The pretrained sub-models are kept fixed.
        self.modelA.freeze()
        self.modelB.freeze()
        # 4 inputs = 2 outputs of modelA + 2 outputs of modelB.
        self.classifier = torch.nn.Linear(4, 2)

        # The weight dictionaries are excluded from the saved
        # hyperparameters; only the two hparams mappings are recorded.
        self.save_hyperparameters(ignore=["modelA_params", "modelB_params"])

    @staticmethod
    def _unflatten(flat_params):
        """Rebuild a state dict by reshaping each ``value`` to its ``shape``."""
        return {key: entry["value"].reshape(entry["shape"])
                for key, entry in flat_params.items()}

    def configure_optimizers(self):
        """Return an Adam optimizer over the module parameters (lr = 1e-3)."""
        optimizer = torch.optim.Adam(self.parameters(), lr = 1e-3)
        return optimizer

    def forward(self, x):
        """Concatenate both sub-model outputs and apply the classifier."""
        x1 = self.modelA(x)
        x2 = self.modelB(x)
        x = torch.cat((x1, x2), dim=1)
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx):
        """Return the MSE loss of the ensemble on one ``(input, target)`` batch."""
        x, y = batch
        return F.mse_loss(self.forward(x), y)

def _make_trainer():
    """Create a trainer with ``gpus = 0`` that runs for 5 epochs.

    The progress bar is refreshed every 50 batches. The refresh rate is set
    on the ``TQDMProgressBar`` callback because the
    ``progress_bar_refresh_rate`` Trainer argument is deprecated since v1.5.
    A fresh callback instance is created for every trainer.
    """
    return pl.Trainer(
        gpus = 0,
        max_epochs = 5,
        callbacks = [pl.callbacks.TQDMProgressBar(refresh_rate = 50)],
    )


def _flatten_state_dict(module):
    """Map each state-dict key of ``module`` to its shape and 1D values."""
    return {k: {"shape": v.shape, "value": torch.flatten(v)}
            for k, v in module.state_dict().items()}


# Random regression data: 1000 samples with 10 features and 2 targets.
dl = DataLoader(TensorDataset(torch.randn(1000, 10),
                              torch.randn(1000, 2)),
                batch_size = 10)

modelA = MyModelA(10)
modelB = MyModelB(10)

# pretrained modelA and modelB
trainerA = _make_trainer()
trainerA.fit(modelA, dl)
trainerB = _make_trainer()
trainerB.fit(modelB, dl)

# Reshape parameters/weights such that it is 1D
modelA_params = _flatten_state_dict(modelA)
modelB_params = _flatten_state_dict(modelB)
modelA_hparams = modelA.hparams
modelB_hparams = modelB.hparams

# modelA and modelB contain pretrained weights
model = MyEnsemble(modelA_hparams, modelB_hparams,
                   modelA_params, modelB_params)

trainer = _make_trainer()
trainer.fit(model, dl)


  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name | Type   | Params
--------------------------------
0 | fc1  | Linear | 22    
--------------------------------
22        Trainable params
0         Non-trainable params
22        Total params
0.000     Total estimated model params size (MB)
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Epoch 4: 100%|██████████| 100/100 [00:00<00:00, 456.49it/s, loss=0.948, v_num=0]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name | Type   | Params
--------------------------------
0 | fc1  | Linear | 22    
--------------------------------
22        Trainable params
0         Non-trainable params
22        Total params
0.000     Total estimated model params size (MB)



Epoch 4: 100%|██████████| 100/100 [00:00<00:00, 284.12it/s, loss=0.948, v_num=1]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs






  | Name       | Type     | Params
----------------------------------------
0 | modelA     | MyModelA | 22    
1 | modelB     | MyModelB | 22    
2 | classifier | Linear   | 10    
----------------------------------------
10        Trainable params
44        Non-trainable params
54        Total params
0.000     Total estimated model params size (MB)


Epoch 4: 100%|██████████| 100/100 [00:00<00:00, 348.43it/s, loss=0.943, v_num=2]
