In [5]:
%load_ext autoreload
%autoreload 2

import sys
import os
# Put the parent of the current working directory on the import path so that
# project-local packages can be imported from this notebook.
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

In [6]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.sklearns import R2Score

import pandas as pd
import numpy as np
import joblib
from pathlib import Path

from sklearn.preprocessing import StandardScaler

# Custom
from project.film_model import LinearBlock, FiLMGenerator

## FiLM Structures

In [11]:
class FiLMNetwork(pl.LightningModule):
    """Regression model that modulates an input embedding with two FiLM stages.

    The inputs are embedded by ``inputs_emb`` and the conditioning features by
    ``conds_emb``. Each ``FiLMGenerator`` maps the conditioning embedding to a
    ``(gamma, beta)`` pair that is applied as ``x * gamma + beta`` before the
    following ``LinearBlock``. Training minimises the MSE between the output of
    ``block_2`` and the target.

    Args:
        inputs_sz: Feature size passed as ``in_sz`` to the input embedding.
        conds_sz: Feature size passed as ``in_sz`` to the conditioning embedding.
        learning_rate: Learning rate handed to Adam in ``configure_optimizers``.
        metric: Callable invoked as ``metric(y_hat, y)`` during validation and
            test to produce the logged ``*_r2`` value.
    """

    # NOTE(review): the default ``R2Score()`` is created once at class-definition
    # time and is therefore shared by every instance built without an explicit
    # ``metric``; it is also captured by ``save_hyperparameters()``. Left as-is
    # to keep the constructor interface and recorded hparams unchanged.
    def __init__(self, inputs_sz, conds_sz, learning_rate=1e-2, metric=R2Score()):
        super().__init__()
        self.save_hyperparameters()
        self.metric = metric
        self.inputs_emb = LinearBlock(in_sz=inputs_sz, layers=[512,256,128,64], out_sz=32, ps=None, use_bn=True, bn_final=True)
        self.conds_emb = LinearBlock(in_sz=conds_sz, layers=[], out_sz=32, ps=None, use_bn=True, bn_final=False)
        self.film_1 = FiLMGenerator(in_sz=self.conds_emb.out_sz, layers=[], out_sz=32, ps=None, use_bn=False, bn_final=False)
        self.block_1 = LinearBlock(in_sz=self.film_1.out_sz, layers=[16], out_sz=16, ps=None, use_bn=True, bn_final=True)
        self.film_2 = FiLMGenerator(in_sz=self.conds_emb.out_sz, layers=[], out_sz=16, ps=None, use_bn=False, bn_final=False)
        self.block_2 = LinearBlock(in_sz=self.film_2.out_sz, layers=[8], out_sz=1, ps=None, use_bn=True, bn_final=False)

    def _predict(self, inputs, conds):
        """Run the full network and return the prediction for ``(inputs, conds)``."""
        input_emb = self.inputs_emb(inputs)
        conds_emb = self.conds_emb(conds)
        # First FiLM stage: scale and shift the input embedding.
        gamma_1, beta_1 = self.film_1(conds_emb)
        x = input_emb * gamma_1 + beta_1
        x = self.block_1(x)
        # Second FiLM stage reuses the same conditioning embedding.
        gamma_2, beta_2 = self.film_2(conds_emb)
        x = x * gamma_2 + beta_2
        return self.block_2(x)

    def _shared_step(self, batch):
        """Unpack ``batch`` and return ``(loss, y_hat, y)``.

        Validation and test need the predictions and targets as well as the
        loss so that they can evaluate ``self.metric``.
        """
        inputs, conds, y = batch
        y_hat = self._predict(inputs, conds)
        # NOTE(review): block_2 is built with out_sz=1; this assumes y has a
        # shape compatible with y_hat — confirm against the dataset.
        loss = F.mse_loss(y_hat, y)
        return loss, y_hat, y

    def _forward(self, batch, batch_idx):
        """Return only the MSE loss for ``batch`` (``batch_idx`` is unused)."""
        loss, _, _ = self._shared_step(batch)
        return loss

    def forward(self, conds):
        """Return the conditioning embedding for ``conds`` (not a prediction)."""
        return self.conds_emb(conds)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it as ``train_loss``."""
        loss = self._forward(batch, batch_idx)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_r2`` for one validation batch."""
        # y_hat and y must come from the shared step; previously they were
        # referenced here without being defined, raising NameError.
        loss, y_hat, y = self._shared_step(batch)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss, on_step=True)
        result.log('val_r2', self.metric(y_hat, y), on_step=True)
        return result

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_r2`` for one test batch."""
        loss, y_hat, y = self._shared_step(batch)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('test_loss', loss, on_step=True)
        result.log('test_r2', self.metric(y_hat, y), on_step=True)
        return result

    def configure_optimizers(self):
        """Return an Adam optimizer using the saved ``learning_rate`` hparam."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

In [12]:
# Build a FiLM model with explicit input / conditioning feature sizes.
bar = FiLMNetwork(inputs_sz=978, conds_sz=513)

In [13]:
# Echo the model so the notebook renders its module tree.
bar

FiLMNetwork(
  (metric): R2Score()
  (inputs_emb): LinearBlock(
    (block): Sequential(
      (0): Linear(in_features=978, out_features=512, bias=True)
      (1): ReLU(inplace=True)
      (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Linear(in_features=512, out_features=256, bias=True)
      (4): ReLU(inplace=True)
      (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): Linear(in_features=256, out_features=128, bias=True)
      (7): ReLU(inplace=True)
      (8): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): Linear(in_features=128, out_features=64, bias=True)
      (10): ReLU(inplace=True)
      (11): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): Linear(in_features=64, out_features=32, bias=True)
      (13): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  

## Concat Structures

In [1]:
# NOTE(review): scratch arithmetic with no stated purpose (presumably sizing a
# layer — confirm). The cell was run out of order; document it or remove it.
32 * 16

512

In [2]:
# NOTE(review): scratch arithmetic with no stated purpose (presumably sizing a
# layer — confirm). Document it or remove it before sharing the notebook.
512 / 64

8.0

In [None]:
class ConcatNetwork(pl.LightningModule):
    """Regression model that concatenates input and conditioning embeddings.

    The inputs are embedded by ``inputs_emb`` and the conditioning features by
    ``conds_emb``; the two embeddings are concatenated along ``dim=1`` and fed
    through ``block_1`` and ``block_2``. Training minimises the MSE between the
    output of ``block_2`` and the target.

    Args:
        inputs_sz: Feature size passed as ``in_sz`` to the input embedding.
        conds_sz: Feature size passed as ``in_sz`` to the conditioning embedding.
        learning_rate: Learning rate handed to Adam in ``configure_optimizers``.
        metric: Callable invoked as ``metric(y_hat, y)`` during validation and
            test to produce the logged ``*_r2`` value.
    """

    # NOTE(review): the default ``R2Score()`` is created once at class-definition
    # time and is therefore shared by every instance built without an explicit
    # ``metric``; it is also captured by ``save_hyperparameters()``. Left as-is
    # to keep the constructor interface and recorded hparams unchanged.
    def __init__(self, inputs_sz, conds_sz, learning_rate=1e-2, metric=R2Score()):
        super().__init__()
        self.save_hyperparameters()
        self.metric = metric
        self.inputs_emb = LinearBlock(in_sz=inputs_sz, layers=[512,256,128,64], out_sz=32, ps=None, use_bn=True, bn_final=True)
        self.conds_emb = LinearBlock(in_sz=conds_sz, layers=[], out_sz=32, ps=None, use_bn=True, bn_final=False)
        self.block_1 = LinearBlock(in_sz=self.inputs_emb.out_sz + self.conds_emb.out_sz, layers=[16], out_sz=16, ps=None, use_bn=True, bn_final=True)
        self.block_2 = LinearBlock(in_sz=self.block_1.out_sz, layers=[8], out_sz=1, ps=None, use_bn=True, bn_final=False)

    def _predict(self, inputs, conds):
        """Run the full network and return the prediction for ``(inputs, conds)``."""
        input_emb = self.inputs_emb(inputs)
        conds_emb = self.conds_emb(conds)
        # Join the two embeddings along dim=1 before the prediction blocks.
        x = torch.cat([input_emb, conds_emb], dim=1)
        x = self.block_1(x)
        return self.block_2(x)

    def _shared_step(self, batch):
        """Unpack ``batch`` and return ``(loss, y_hat, y)``.

        Validation and test need the predictions and targets as well as the
        loss so that they can evaluate ``self.metric``.
        """
        inputs, conds, y = batch
        y_hat = self._predict(inputs, conds)
        # NOTE(review): block_2 is built with out_sz=1; this assumes y has a
        # shape compatible with y_hat — confirm against the dataset.
        loss = F.mse_loss(y_hat, y)
        return loss, y_hat, y

    def _forward(self, batch, batch_idx):
        """Return only the MSE loss for ``batch`` (``batch_idx`` is unused)."""
        loss, _, _ = self._shared_step(batch)
        return loss

    def forward(self, conds):
        """Return the conditioning embedding for ``conds`` (not a prediction)."""
        return self.conds_emb(conds)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it as ``train_loss``."""
        loss = self._forward(batch, batch_idx)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_r2`` for one validation batch."""
        # y_hat and y must come from the shared step; previously they were
        # referenced here without being defined, raising NameError.
        loss, y_hat, y = self._shared_step(batch)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss, on_step=True)
        result.log('val_r2', self.metric(y_hat, y), on_step=True)
        # The result was previously built but never returned, so nothing logged
        # here (nor checkpoint_on) reached the trainer.
        return result

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_r2`` for one test batch."""
        loss, y_hat, y = self._shared_step(batch)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('test_loss', loss, on_step=True)
        result.log('test_r2', self.metric(y_hat, y), on_step=True)
        return result

    def configure_optimizers(self):
        """Return an Adam optimizer using the saved ``learning_rate`` hparam."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)