In [120]:
import pandas as pd
import torch.nn as nn
import tsl
from tsl.nn.blocks.encoders import TemporalConvNet, SpatioTemporalConvNet, Transformer, SpatioTemporalTransformerLayer
from tsl.nn.blocks.encoders.recurrent import RNN, GraphConvRNN#, MultiRNN

from tsl.nn.blocks.decoders import GCNDecoder

from torch.nn.utils.parametrizations import spectral_norm


In [2]:
def apply_spectral_norm(m):
    for _, module in m.named_children():
        if isinstance(module, nn.Linear):
            spectral_norm(module)
        elif isinstance(module, nn.LSTM):
            for p in module._flat_weights_names:
                if 'weight' in p:
                    spectral_norm(module, name=p)
    return m
    

In [74]:
mlp_block = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
)

print(mlp_block)

Sequential(
  (0): Linear(in_features=10, out_features=20, bias=True)
  (1): ReLU()
  (2): Linear(in_features=20, out_features=1, bias=True)
)


In [75]:
for name, module in mlp_block.named_modules():
    if isinstance(module, nn.Linear):
        spectral_norm(module)
        

Sequential(
  (0): ParametrizedLinear(
    in_features=10, out_features=20, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
      )
    )
  )
  (1): ReLU()
  (2): ParametrizedLinear(
    in_features=20, out_features=1, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
      )
    )
  )
)


In [132]:
def apply_spectral_norm(m):

    class_trans_layer = tsl.nn.blocks.encoders.transformer.Transformer
    decoder = tsl.nn.blocks.decoders.GCNDecoder

    if isinstance(m, class_trans_layer) or isinstance(m, decoder):
        for name, module in m.named_modules():
            if isinstance(module, nn.Linear):
                spectral_norm(module)

    else:

        for _, module in m.named_children():
            
            if isinstance(module, nn.Linear):
                spectral_norm(module)

            elif isinstance(module, nn.LSTM) or isinstance(module, nn.GRU):
                print(module.__class__)
                for p in module.state_dict().keys():
                    if 'weight' in p:
                        spectral_norm(module, name=p)
    return m
    
    
mlp_block = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
)



trans_block = Transformer(
    input_size=10,
    hidden_size=20,
    ff_size=30,
    n_layers=3
    )

rnn_block = RNN(
    input_size = 10, 
    hidden_size = 20, 
    cell='lstm',
    n_layers=3
    )

gcn_block = GCNDecoder(
    input_size = 10, 
    hidden_size = 20, 
    n_layers=3,
    output_size=1
    )



In [117]:
print(mlp_block)
print('-------------------- ')
apply_spectral_norm(mlp_block)
print(mlp_block)

Sequential(
  (0): Linear(in_features=10, out_features=20, bias=True)
  (1): ReLU()
  (2): Linear(in_features=20, out_features=1, bias=True)
)
-------------------- 
Sequential(
  (0): ParametrizedLinear(
    in_features=10, out_features=20, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
      )
    )
  )
  (1): ReLU()
  (2): ParametrizedLinear(
    in_features=20, out_features=1, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
      )
    )
  )
)


In [118]:
print(rnn_block)
print('-------------------- ')
apply_spectral_norm(rnn_block)

RNN(
  (rnn): LSTM(10, 20, num_layers=3)
)
-------------------- 
<class 'torch.nn.modules.rnn.LSTM'>


RNN(
  (rnn): ParametrizedLSTM(
    10, 20, num_layers=3
    (parametrizations): ModuleDict(
      (weight_ih_l0): ParametrizationList(
        (0): _SpectralNorm()
      )
      (weight_hh_l0): ParametrizationList(
        (0): _SpectralNorm()
      )
      (weight_ih_l1): ParametrizationList(
        (0): _SpectralNorm()
      )
      (weight_hh_l1): ParametrizationList(
        (0): _SpectralNorm()
      )
      (weight_ih_l2): ParametrizationList(
        (0): _SpectralNorm()
      )
      (weight_hh_l2): ParametrizationList(
        (0): _SpectralNorm()
      )
    )
  )
)

In [119]:
print('-------------------- ')
apply_spectral_norm(trans_block)

-------------------- 


Transformer(
  (net): Sequential(
    (0): TransformerLayer(
      (att): MultiHeadAttention(
        (out_proj): ParametrizedNonDynamicallyQuantizableLinear(
          in_features=20, out_features=20, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        (q_proj): Linear(10, 20, bias=True)
      )
      (skip_conn): ParametrizedLinear(
        in_features=10, out_features=20, bias=True
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): _SpectralNorm()
          )
        )
      )
      (norm1): LayerNorm(10)
      (mlp): Sequential(
        (0): LayerNorm(20)
        (1): ParametrizedLinear(
          in_features=20, out_features=30, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        (2): ELU(alpha=

In [131]:
gcn_block
apply_spectral_norm(gcn_block)

GCNDecoder(
  (convs): ModuleList(
    (0): GraphConv(10, 20)
    (1): GraphConv(20, 20)
    (2): GraphConv(20, 20)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (readout): MLPDecoder(
    (readout): MLP(
      (mlp): Sequential(
        (0): Dense(
          (affinity): ParametrizedLinear(
            in_features=20, out_features=20, bias=True
            (parametrizations): ModuleDict(
              (weight): ParametrizationList(
                (0): _SpectralNorm()
              )
            )
          )
          (activation): ReLU()
          (dropout): Identity()
        )
      )
      (readout): ParametrizedLinear(
        in_features=20, out_features=1, bias=True
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): _SpectralNorm()
          )
        )
      )
    )
    (rearrange): Rearrange('b n (h f) -> b h n f', f=1, h=1)
  )
)