In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import pytorch_lightning as pl

import pandas as pd
import numpy as np
import joblib
from pathlib import Path

from sklearn.preprocessing import StandardScaler

In [46]:
class FiLMNetwork(pl.LightningModule):
    def __init__(self, inputs_sz, conds_sz):
        super().__init__()
        self.save_hyperparameters()
        self.inputs_emb = self.generate_layers(in_sz=inputs_sz, layers=[512], out_sz=32, ps=None, use_bn=True, bn_final=True)
        self.conds_emb = self.generate_layers(in_sz=conds_sz, layers=[], out_sz=32, ps=None, use_bn=True, bn_final=False)
        self.film_1 = self.film_generator(in_sz=32, layers=[], out_sz=32, ps=None, use_bn=False, bn_final=False)
        self.block_1 = self.generate_layers(in_sz=32, layers=[16], out_sz=16, ps=None, use_bn=True, bn_final=True)
        self.film_2 = self.film_generator(in_sz=16, layers=[], out_sz=16, ps=None, use_bn=False, bn_final=False)
        self.block_2 = self.generate_layers(in_sz=16, layers=[8], out_sz=1, ps=None, use_bn=True, bn_final=False)
    
    def film_generator(self, *args, **kwargs):
        gamma = self.generate_layers(*args, **kwargs)
        beta = self.generate_layers(*args, **kwargs)
        return gamma, beta
    
    def generate_layers(self, in_sz, layers, out_sz, ps, use_bn, bn_final):
        if ps is None: ps = [0]*len(layers) 
        else: ps = ps*len(layers)
        sizes = self.get_sizes(in_sz, layers, out_sz)
        actns = [nn.ReLU(inplace=True) for _ in range(len(sizes)-2)] + [None]
        layers = []
        for i,(n_in,n_out,dp,act) in enumerate(zip(sizes[:-1],sizes[1:],[0.]+ps,actns)):
            layers += self.bn_drop_lin(n_in, n_out, bn=use_bn and i!=0, p=dp, actn=act)
        if bn_final: layers.append(nn.BatchNorm1d(sizes[-1]))
        block = nn.Sequential(*layers)
        return block
    
    def get_sizes(self, in_sz, layers, out_sz):
        return [in_sz] + layers + [out_sz]
    
    def bn_drop_lin(self, n_in:int, n_out:int, bn:bool=True, p:float=0., actn:nn.Module=None):
        "`n_in`->bn->dropout->linear(`n_in`,`n_out`)->`actn`"
        layers = [nn.BatchNorm1d(n_in)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        layers.append(nn.Linear(n_in, n_out))
        if actn is not None: layers.append(actn)
        return layers
    
    def forward(self, inputs, conds):
        return self.conds_emb(conds)

    def training_step(self, batch, batch_idx):
        inputs, conds, y = batch
        input_emb = self.input_emb(inputs)
        conds_emb = self.conds_emb(conds)
        gamma_1, beta_1 = self.film_1(conds_emb)
        x = input_emb * gamma_1 + beta_1
        x = self.block_1(x)
        gamma_2, beta_2 = self.film_2(conds_emb)
        x = x * gamma_2 + beta_2
        y_hat = self.block_2(x)
        loss = F.mse_loss(y_hat, y)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result
    
    def validation_step(self, batch, batch_idx):
        inputs, conds, y = batch
        input_emb = self.input_emb(inputs)
        conds_emb = self.conds_emb(conds)
        gamma_1, beta_1 = self.film_1(conds_emb)
        x = input_emb * gamma_1 + beta_1
        x = self.block_1(x)
        gamma_2, beta_2 = self.film_2(conds_emb)
        x = x * gamma_2 + beta_2
        y_hat = self.block_2(x)
        loss = F.mse_loss(y_hat, y)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        result.log('val_acc', accuracy(y_hat, y))
        return result

    def test_step(self, batch, batch_idx):
        inputs, conds, y = batch
        input_emb = self.input_emb(inputs)
        conds_emb = self.conds_emb(conds)
        gamma_1, beta_1 = self.film_1(conds_emb)
        x = input_emb * gamma_1 + beta_1
        x = self.block_1(x)
        gamma_2, beta_2 = self.film_2(conds_emb)
        x = x * gamma_2 + beta_2
        y_hat = self.block_2(x)
        loss = F.mse_loss(y_hat, y)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('test_loss', loss)
        result.log('test_acc', accuracy(y_hat, y))
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    

In [44]:
bar = FiLMNetwork(978, 513)

In [45]:
bar

FiLMNetwork(
  (inputs_emb): Sequential(
    (0): Linear(in_features=978, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=512, out_features=32, bias=True)
    (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conds_emb): Sequential(
    (0): Linear(in_features=513, out_features=32, bias=True)
  )
  (block_1): Sequential(
    (0): Linear(in_features=32, out_features=16, bias=True)
    (1): ReLU(inplace=True)
    (2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=16, out_features=16, bias=True)
    (4): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block_2): Sequential(
    (0): Linear(in_features=16, out_features=8, bias=True)
    (1): ReLU(inplace=True)
    (2): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, trac