In [1]:
from torch import nn as tnn
from torch_geometric import nn as gnn
from typing import Union, List
from collections import OrderedDict
import torch, os, pickle

class BaseNetwork(tnn.Module):
    """Base ``torch.nn.Module`` with pickle-based checkpoint helpers.

    ``_model_attrs`` lists the targets that ``save``/``load`` iterate over:
    ``None`` stands for the module itself, a string names an attribute of
    the module (typically a sub-network). Every saved target must expose a
    ``_model_param`` dict; on load it is passed back to the target's
    ``__init__`` as keyword arguments before the weights are restored.

    A checkpoint file is a pickled dict with the keys ``'model_param'`` and
    ``'state_dict'``; the state dict values are stored as numpy arrays.
    """
    def __init__(self):
        super(BaseNetwork, self).__init__()
        self._model_attrs = [None]

    def _resolve_target(self, attr):
        """Return the module addressed by ``attr`` (``self`` when ``attr`` is None).

        Raises:
            AttributeError: ``attr`` is neither ``None`` nor the name of an
                existing attribute.
        """
        if attr is None:
            return self
        if isinstance(attr, str) and hasattr(self, attr):
            return getattr(self, attr)
        raise AttributeError(attr)

    def _save(self, path, file_name, attr=None, overwrite=True):
        """Pickle one target's constructor parameters and weights.

        Args:
            path: Directory the file is written to.
            file_name: Name of the checkpoint file inside ``path``.
            attr: ``None`` to save this module, or the name of the attribute
                to save.
            overwrite: When False, refuse to replace an existing file.

        Raises:
            FileExistsError: ``overwrite`` is False and the file exists.
            AttributeError: ``attr`` is invalid or the target has no
                ``_model_param``.
        """
        model_path = os.path.join(path, file_name)
        if not overwrite and os.path.isfile(model_path):
            raise FileExistsError(model_path)

        target = self._resolve_target(attr)
        model_param = getattr(target, '_model_param')
        # Stored as CPU numpy arrays so the file does not hold torch tensors.
        state_dict = OrderedDict(
            (k, v.detach().cpu().numpy()) for k, v in target.state_dict().items()
        )
        with open(model_path, 'wb') as f:
            pickle.dump({'model_param': model_param, 'state_dict': state_dict}, f)

    def _load(self, path, file_name, attr=None, requires_grad=False):
        """Re-initialise one target from a checkpoint and restore its weights.

        The target's ``__init__`` is re-run with the stored ``model_param``,
        so its architecture follows the checkpoint rather than the current
        instance.

        Args:
            path: Directory containing the checkpoint.
            file_name: Name of the checkpoint file inside ``path``.
            attr: ``None`` to load into this module, or the name of the
                attribute to load into.
            requires_grad: Passed to ``requires_grad_`` on the loaded target;
                the default ``False`` leaves the loaded parameters frozen.

        Returns:
            ``self``.

        Raises:
            AttributeError: ``attr`` is neither ``None`` nor an existing
                attribute name (previously this case was silently ignored).
        """
        # Validate before touching the file so a bad name never goes unnoticed.
        target = self._resolve_target(attr)
        model_path = os.path.join(path, file_name)

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load checkpoints from trusted sources.
        with open(model_path, 'rb') as f:
            obj = pickle.load(f)
        model_param = obj['model_param']
        model_state_dict = OrderedDict(
            (k, torch.from_numpy(v)) for k, v in obj['state_dict'].items()
        )

        target.__init__(**model_param)
        target.load_state_dict(model_state_dict)
        target.requires_grad_(requires_grad=requires_grad)
        return self

    def save(self, path, pfx, sfx='model', overwrite=True):
        """Save every target in ``_model_attrs``.

        The module itself goes to ``{pfx}.{sfx}``; a named attribute goes to
        ``{pfx}_{attr}.{sfx}``.
        """
        for attr in self._model_attrs:
            file_name = f'{pfx}.{sfx}' if attr is None else f'{pfx}_{attr}.{sfx}'
            self._save(path, file_name, attr, overwrite)

    def load_module(self, path, pfx, attr, sfx='model', requires_grad=False):
        """Load a single named attribute from ``{pfx}_{attr}.{sfx}``.

        Returns:
            ``self``.

        Raises:
            AttributeError: this module has no attribute ``attr``.
        """
        if not hasattr(self, attr):
            raise AttributeError(attr, self._model_attrs)
        return self._load(path, file_name=f'{pfx}_{attr}.{sfx}', attr=attr,
                          requires_grad=requires_grad)

    def load(self, path, pfx, sfx='model', requires_grad=False):
        """Load every target in ``_model_attrs``, using the file names of ``save``.

        Returns:
            ``self``.
        """
        for attr in self._model_attrs:
            file_name = f'{pfx}.{sfx}' if attr is None else f'{pfx}_{attr}.{sfx}'
            self._load(path, file_name, attr, requires_grad)
        return self

class DNNBlock(BaseNetwork):
    """Fully connected network: embed Linear -> hidden stacks -> output Linear.

    Each hidden stack is ``Linear(hidden_dim, hidden_dim)``, optionally
    followed by ``BatchNorm1d`` and ``Dropout``, and ends with the chosen
    activation.

    Args:
        input_dim: Size of the input features.
        output_dim: Size of the output features.
        hidden_dim: Width of the embedding and of every hidden stack.
        hidden_layers: Number of hidden stacks.
        batch_norm: Add ``BatchNorm1d`` to every hidden stack.
        dropout: Dropout probability; a ``Dropout`` layer is added only when
            this is greater than 0.
        activation: Name of an activation class in ``torch.nn``
            (e.g. ``'LeakyReLU'``); it is instantiated without arguments.
        **kwargs: Accepted and ignored.

    Raises:
        AttributeError: ``activation`` is not an attribute of ``torch.nn``
            (only raised when at least one hidden stack is built).
    """
    def __init__(self, 
                 input_dim:int, 
                 output_dim:int, 
                 hidden_dim:int = 32,
                 hidden_layers:int = 2,
                 batch_norm:bool = True, 
                 dropout:float = 0,
                 activation:str = 'LeakyReLU',
                 **kwargs): 
        super(DNNBlock, self).__init__()
        # Constructor arguments, kept so the block can be rebuilt from a checkpoint.
        self._model_param = {
                 'input_dim':input_dim,
                 'output_dim':output_dim,
                 'hidden_dim':hidden_dim,
                 'hidden_layers':hidden_layers,
                 'batch_norm':batch_norm,
                 'dropout':dropout,
                 'activation':activation
        }
        
        # NOTE(review): the embedding output feeds the first hidden Linear
        # directly, with no activation in between.
        self.embed_layer = tnn.Linear(input_dim, hidden_dim)
        
        self.hidden_layer = tnn.ModuleList()
        for _ in range(hidden_layers):
            layer = [tnn.Linear(hidden_dim, hidden_dim)]
            if batch_norm:
                layer.append(tnn.BatchNorm1d(hidden_dim))
            if dropout > 0:
                layer.append(tnn.Dropout(dropout))
            # Plain attribute lookup instead of eval(): same result for a
            # class name, without executing an arbitrary string.
            layer.append(getattr(tnn, activation)())
            self.hidden_layer.append(tnn.Sequential(*layer))
        
        self.output_layer = tnn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        """Apply the embedding, every hidden stack in order, then the output layer."""
        h = self.embed_layer(x)
        for hidden_layer in self.hidden_layer:
            h = hidden_layer(h)
        out = self.output_layer(h)
        return out

class AE(BaseNetwork):
    """Autoencoder assembled from two ``DNNBlock`` networks.

    ``encoder`` maps ``input_dim`` to ``latent_dim`` and ``decoder`` maps
    ``latent_dim`` back to ``input_dim``. All other hyper-parameters, and any
    extra keyword arguments, are forwarded unchanged to both blocks.
    ``_model_attrs`` is set to ``['encoder','decoder']``.
    """
    def __init__(self, 
                 input_dim:int, 
                 latent_dim:int, 
                 hidden_dim:int = 32,
                 hidden_layers:int = 2,
                 batch_norm:bool = True, 
                 dropout:float = 0,
                 activation:str = 'LeakyReLU',
                 **kwargs): 
        super().__init__()
        self._model_attrs = ['encoder','decoder']
        # Settings shared by both halves of the autoencoder.
        shared = dict(
            hidden_dim=hidden_dim,
            hidden_layers=hidden_layers,
            batch_norm=batch_norm,
            dropout=dropout,
            activation=activation,
        )
        self.encoder = DNNBlock(input_dim, latent_dim, **shared, **kwargs)
        self.decoder = DNNBlock(latent_dim, input_dim, **shared, **kwargs)

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for the input ``x``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction

class CVAE(BaseNetwork):
    """Stub network; presumably a conditional VAE (name suggests so — TODO confirm).

    The constructor accepts its hyper-parameters but currently only runs the
    ``BaseNetwork`` initialiser: no layers are built, none of the arguments
    are stored, and no ``forward`` is defined here.
    """
    def __init__(self, 
                 input_dim:int, 
                 latent_dim:int, 
                 condition_vector_dim:int, 
                 hidden_dim:int = 32,
                 hidden_layers:int = 2,
                 batch_norm:bool = True, 
                 dropout:float = 0,
                 activation:str = 'LeakyReLU',
                 **kwargs): 
        # Only the base initialiser runs; all arguments are unused for now.
        super().__init__()


In [18]:
# Sample a (2, 8, 4) uniform tensor and split it along the first axis into
# two (8, 4) matrices.
a, b = torch.rand(2, 8, 4).unbind(0)
# Divide each row of `a` by its Euclidean norm taken over the last axis.
na = a / a.pow(2).sum(dim=-1, keepdim=True).sqrt()


In [17]:
# Cosine similarity of `a` and `b` along the last axis: dot product divided
# by the product of the two Euclidean norms.
# NOTE: `a` and `b` are not defined in this cell; they must already exist in
# the kernel (another cell assigns them).
# The dot product now reduces over dim=-1, the same axis as the norms below;
# the previous dim=1 matched only for 2-D inputs.
x = torch.sum(a * b, dim=-1, keepdim=True)
y = (
    torch.sqrt(torch.sum(torch.square(a), dim=-1, keepdim=True))
    * torch.sqrt(torch.sum(torch.square(b), dim=-1, keepdim=True))
)
x / y


tensor([[0.9451],
        [0.9343],
        [0.9404],
        [0.9912],
        [0.7958],
        [0.7436],
        [0.8241],
        [0.9621]])

In [38]:
# NOTE: `p` is not defined in any visible cell; it must already exist in the kernel.
p.size()

torch.Size([25, 4, 5, 16])

In [40]:
# Split `p` into two chunks along its last axis.
# NOTE: this rebinds `a` and `b`, which another cell also assigns.
a, b = p.chunk(2, dim=-1)

In [43]:
# This first expression is evaluated but never displayed: a cell only shows
# the value of its last expression.
a.size(), b.size()
# Round-trip check: joining the two chunks along the last axis and
# subtracting `p` gives the element-wise difference.
torch.cat((a, b), dim=-1) - p

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  