In [6]:
import torch
import numpy as np
from src.sumformer import *
from src.data_representation import Batch


In [2]:
"""
Batched version of the sumformer
"""
from typing import Literal
import torch
from torch import nn, Tensor
import torch.nn.functional as F

# from ml_lib.models.layers import MLP, Repeat, ResidualShortcut
from src.basic import MLP
from src.combinators import ResidualShortcut, Repeat
from torch_geometric.nn.aggr import Aggregation, SumAggregation
from torch_geometric.nn.resolver import aggregation_resolver

from src.data_representation import Batch

import functools as ft
import itertools as it # pyright: ignore
from inspect import signature


class ResidualShortcut(nn.Module):
    """Residual shortcut as used in ResNet.

    Wraps another module ``f`` and returns ``x + f(x)``.

    Useful for building residual blocks as originally used in ResNet (and in
    many modern architectures). Note: this definition shadows the
    ``ResidualShortcut`` imported from ``src.combinators`` in this cell.
    """
    inner_module: nn.Module

    def __init__(self, inner_module):
        super().__init__()
        self.inner_module = inner_module

    def forward(self, x):
        """Return ``x + inner_module(x)``."""
        return x + self.inner_module(x)

class Sequential(nn.Module):
    """Sequential implementation that allows for more than one input and output."""
    sub_modules: nn.ModuleList
    def __init__(self, *modules: nn.Module):
        super().__init__()
        self.sub_modules = nn.ModuleList(modules) 
        print(self.sub_modules)

    def forward(self, *args, **kwargs):
        output = self.sub_modules[0](*args, **kwargs)
        for module in self.sub_modules[1:]:
            prev_output = output
            match prev_output, signature(module.forward).parameters.items():
                case x, [(_, p),] if p.kind in [Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD]:
                    #only one parameter, and it's positional
                    output = module(x)
                case {**kwargs}, _:
                    # the dict is interpreted as kwargs
                    output = module(**kwargs)
                case t, _ if hasattr(t, "_asdict"):
                    # the namedtuple is interpreted as kwargs
                    output = module(**prev_output._asdict())
                case (tuple(args), {**kwargs}), _:
                    output = module(*args, **kwargs)
                case tuple(args), _  :
                    output = module(*args)
                case x, _:
                    output = module(x)
        return output

class GlobalEmbedding(nn.Module):
    """Per-set embedding: ``aggregate(LeakyReLU(mlp(x_i)))`` over the elements of each set.

    Every element's features go through an MLP and a LeakyReLU. The results are
    then pooled per set with a torch_geometric aggregation, producing one
    ``embed_dim`` vector per set in the batch. With order-independent
    aggregations such as sum or mean, the result is permutation-invariant.
    """

    input_dim: int
    embed_dim: int

    # The MLP that changes the input features to be summed (\phi in the paper).
    mlp: MLP
    # The last activation after that MLP.
    activation: nn.Module
    # The aggregation function (e.g. sum or mean), resolved with
    # torch_geometric.nn.resolver.aggregation_resolver, so the accepted names are
    # those understood by that resolver.
    aggregation: nn.Module

    def __init__(self, input_dim, embed_dim, hidden_dim=256, n_layers = 3, aggregation:str = "mean", aggregation_args=None):
        """
        Args:
            input_dim: width of each element's input features.
            embed_dim: width of the output set embedding.
            hidden_dim: width of each hidden MLP layer.
            n_layers: number of ``hidden_dim``-wide layers passed to the MLP.
            aggregation: aggregation name for ``aggregation_resolver``.
            aggregation_args: extra keyword arguments for the resolver. The dict
                is copied, so the caller's object is never modified.
        """
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.mlp = MLP(input_dim, *[hidden_dim]*n_layers, embed_dim, batchnorm=False, activation=nn.LeakyReLU)
        self.activation = nn.LeakyReLU()
        # Work on a copy: the entries added below must not leak into a shared
        # default dict or into the caller's dict.
        aggregation_args = dict(aggregation_args or {})
        if "multi" in aggregation.lower():
            # Multi-aggregation: project the concatenated aggregates back to embed_dim.
            aggregation_args["mode"] = "proj"
            aggregation_args["mode_kwargs"] = {"in_channels": embed_dim, "out_channels": embed_dim, **aggregation_args.get("mode_kwargs", {})}
        self.aggregation = aggregation_resolver(aggregation, **aggregation_args)

    def forward(self, x: Batch):
        """Embed each set of the batch; returns one ``embed_dim`` row per set."""
        node_embeddings = self.activation(self.mlp(x.data)) #n_nodes_total, embed_dim
        return self.aggregation(node_embeddings, ptr=x.ptr) #batch_size, embed_dim


class SumformerInnerBlock(nn.Module):
    """
    Sumformer "attention" block (in quotes, because it is not really attention).

    It is permutation-equivariant and almost equivalent to a 2-step MPNN on a
    disconnected graph with a single witness node.

    We implement the MLP-sumformer (not the polynomial sumformer). Why?
        1. Simpler.
        2. The paper reports that the polynomial sumformer tends to train better at
            the beginning, but the MLP version catches up. Those results are on
            synthetic functions, which may behave very differently from real data
            (and favour the polynomial sumformer, which has fewer parameters).

    The block maps per-node features of width ``input_dim`` back to width
    ``input_dim``, so it can be wrapped in a residual shortcut.
    """

    # Dimension of the input (and output) features.
    input_dim: int

    # Dimension of the aggregate sigma.
    key_dim: int

    # Dimension of the hidden layers of the MLPs.
    hidden_dim: int

    # Projects the per-set aggregate sigma to hidden_dim.
    aggreg_linear: nn.Linear

    # Per-node output MLP (hidden_dim -> input_dim).
    psi: MLP

    def __init__(self, input_dim, hidden_dim=512, key_dim = 256 , output_dim=3, aggregation:str = "mean", aggregation_args=None,
                 node_embed_n_layers=3, output_n_layers=3):
        """
        Args:
            input_dim: width of the per-node input features.
            hidden_dim: width of the hidden layers.
            key_dim: width of the aggregated set embedding sigma.
            output_dim: currently unused; kept for backward compatibility. The
                block always outputs ``input_dim`` features, so that the residual
                ``x + block(x)`` used by SumformerBlock is shape-compatible.
            aggregation: aggregation name forwarded to GlobalEmbedding.
            aggregation_args: forwarded to GlobalEmbedding as a fresh copy, so the
                caller's dict is never modified.
            node_embed_n_layers: number of hidden layers in GlobalEmbedding's MLP.
            output_n_layers: number of ``hidden_dim``-wide layers in psi.
        """
        super().__init__()
        self.input_dim = input_dim
        self.key_dim = key_dim
        self.hidden_dim = hidden_dim
        self.global_embedding = GlobalEmbedding(
                input_dim=input_dim, embed_dim=key_dim,
                hidden_dim=hidden_dim, n_layers=node_embed_n_layers,
                aggregation=aggregation, aggregation_args=dict(aggregation_args or {})
        )

        self.input_linear = nn.Linear(input_dim, hidden_dim)
        self.aggreg_linear = nn.Linear(key_dim, hidden_dim)
        # Output width must equal input_dim: forward() returns per-node features
        # that are added back onto the input by the residual wrapper.
        self.psi = MLP(hidden_dim, *[hidden_dim]*output_n_layers, input_dim,
                          batchnorm=False, activation=nn.LeakyReLU)

    def forward(self, x: Tensor|Batch):
        """Apply the block to a batch of sets (a plain tensor is treated as one set).

        This is a faster, equivalent formulation of the sumformer attention block.
        Rather than applying a linear layer to the concatenation ``[x_i, sigma]``,
        ``x_i`` and ``sigma`` are projected separately (``input_linear`` and
        ``aggreg_linear``) and summed. This is algebraically the same as one
        linear layer on the concatenation, but sigma is projected once per set
        instead of once per node.

        Caution! This approximation may not be exact (but should still be universal)
        if the aggregation is not linear (i.e. anything other than sum or mean).

        Returns:
            A Batch with the same structure as ``x`` whose data has shape
            (n_nodes_total, input_dim).
        """
        if isinstance(x, Tensor): x = Batch.from_unbatched(x)
        assert isinstance(x, Batch)
        assert x.n_features == self.input_dim
        sigma = self.global_embedding(x) #batch_size, key_dim

        sigma_hiddendim = self.aggreg_linear(sigma) #batch_size, hidden_dim
        x_hiddendim = self.input_linear(x.data) #n_nodes_total, hidden_dim

        # Broadcast each set's projected aggregate onto its own nodes.
        psi_input = x_hiddendim + sigma_hiddendim[x.batch, :] #n_nodes_total, hidden_dim
        psi_input = F.leaky_relu(psi_input) #n_nodes_total, hidden_dim

        return Batch.from_other(self.psi(psi_input), x) #n_nodes_total, input_dim

class SumformerBlock(nn.Sequential):
    """
    One Sumformer layer: a residual shortcut around a SumformerInnerBlock,
    followed by a LayerNorm over ``input_dim``.

    Every positional and keyword argument is forwarded to SumformerInnerBlock.
    The two stages are registered as ``residual_block`` and ``norm``.
    """

    def __init__(self, *block_args, **block_kwargs):
        super().__init__()
        inner = SumformerInnerBlock(*block_args, **block_kwargs)
        stages = (
            ("residual_block", ResidualShortcut(inner)),
            ("norm", nn.LayerNorm(inner.input_dim)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

class Sumformer(Repeat):
    """A ``Repeat`` of ``num_blocks`` SumformerBlocks.

    Extra positional/keyword arguments are forwarded to every SumformerBlock.
    ``Repeat`` is given a zero-argument factory that builds a fresh block on
    each call.
    """

    def __init__(self, num_blocks: int, *block_args, **block_kwargs):
        def build_block():
            return SumformerBlock(*block_args, **block_kwargs)

        super().__init__(num_blocks, build_block)

In [7]:
# Toy dataset: two point sets, each with two 3-d points.
dataset = [
    torch.tensor([[1, 0, 0],
                  [0, 0, 1]], dtype=torch.float),
    torch.tensor([[1, 1, 1],
                  [1, 2, 3]], dtype=torch.float),
]

# Quick sanity checks of row-wise operations on the first set.
print(dataset[0].norm(dim=1))                           # per-row L2 norm
print(dataset[0] / torch.tensor([2, 3]).unsqueeze(-1))  # divide row i by [2, 3][i]
print(dataset[0].softmax(dim=1))                        # row-wise softmax

batch = Batch.from_list(dataset, order=1)

tensor([1., 1.])
tensor([[0.5000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.3333]])
tensor([[0.5761, 0.2119, 0.2119],
        [0.2119, 0.2119, 0.5761]])


In [8]:
# Show the batch before and after calling normalize().
# NOTE(review): the recorded outputs of this cell are identical, so normalize()
# does not appear to modify `batch` in place. It presumably returns a new Batch,
# whose result is discarded here. Confirm against src.data_representation and,
# if so, bind the result (e.g. `batch_normalized = batch.normalize()`).
print(batch)
batch.normalize()
print(batch)


Batch(data=tensor([[1., 0., 0.],
        [0., 0., 1.],
        [1., 1., 1.],
        [1., 2., 3.]]), order=1, indicator=BatchIndicator(n_nodes=tensor([2, 2]), n_edges=None, ptr1=None, ptr2=None, batch1=None, batch2=None, diagonal=None, transpose_indices=None))
Batch(data=tensor([[1., 0., 0.],
        [0., 0., 1.],
        [1., 1., 1.],
        [1., 2., 3.]]), order=1, indicator=BatchIndicator(n_nodes=tensor([2, 2]), n_edges=None, ptr1=None, ptr2=None, batch1=None, batch2=None, diagonal=None, transpose_indices=None))


In [9]:
# NOTE(review): `Sumformer` resolves to whichever definition was bound last in the
# kernel. With the recorded execution order (In[6] after In[2]), that may be the
# one from `src.sumformer` rather than the class defined earlier in this notebook.
torch.manual_seed(0)  # reproducible weight initialisation
# Fall back to CPU so the notebook also runs on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Sumformer(num_blocks=1, input_dim=3, hidden_dim=20, key_dim=3)
linear = nn.Linear(in_features = 3, out_features = 5)
model.to(DEVICE)
linear.to(DEVICE)

Linear(in_features=3, out_features=5, bias=True)

In [11]:
import torch.nn.functional as F

# Move the batch to wherever the model's parameters live (instead of hard-coding
# 'cuda:0'), so this cell also works on CPU-only machines.
model_device = next(model.parameters()).device
batch = batch.to(model_device)
out = model(batch)
# Row-wise softmax over the 5 outputs of `linear` (one row per point).
probs = F.softmax(linear(out), dim=1)
print(probs)


Batch(data=tensor([[0.0446, 0.1156, 0.2075, 0.5637, 0.0686],
        [0.2569, 0.1081, 0.2178, 0.1546, 0.2626],
        [0.0607, 0.1271, 0.2437, 0.4806, 0.0879],
        [0.4420, 0.0538, 0.0986, 0.0485, 0.3570]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>), order=1, indicator=BatchIndicator(n_nodes=tensor([2, 2], device='cuda:0'), n_edges=None, ptr1=tensor([0, 2, 4], device='cuda:0'), ptr2=None, batch1=tensor([0, 0, 1, 1], device='cuda:0'), batch2=None, diagonal=None, transpose_indices=None))


In [18]:
print(batch)

# Split the flat (n_nodes_total, ...) tensors back into per-set chunks. For each
# set, weight its points by the per-point probabilities:
# (n_i, 5).T @ (n_i, 3) -> (5, 3).
# Splitting by n_nodes advances through the sets correctly, with no manual
# start/end offset bookkeeping.
set_sizes = batch.n_nodes.tolist()
for ptset, b_probs in zip(batch.data.split(set_sizes), probs.data.split(set_sizes)):
    print(ptset.shape)
    print(torch.mm(b_probs.T, ptset).shape)

Batch(data=tensor([[1., 0., 0.],
        [0., 0., 1.],
        [1., 1., 1.],
        [1., 2., 3.]]), order=1, indicator=BatchIndicator(n_nodes=tensor([2, 2]), n_edges=None, ptr1=tensor([0, 2, 4]), ptr2=None, batch1=tensor([0, 0, 1, 1]), batch2=None, diagonal=None, transpose_indices=None))
torch.Size([2, 3])
torch.Size([5, 3])
torch.Size([2, 3])
torch.Size([5, 3])


In [5]:
import numpy as np

# Peek at the first point set, showing every column except the last one.
array = np.load('/data/riley/for_sam/50_uniform_points.npy')
first_set = array[0]
print(first_set[:, :-1])

[[0.58607938 0.1528506 ]
 [0.95847123 0.31168894]
 [0.836457   0.9830779 ]
 [0.32845917 0.45555548]
 [0.83689396 0.35559389]
 [0.3697229  0.86049048]
 [0.8512059  0.22971335]
 [0.92234859 0.0066988 ]
 [0.72984807 0.89840314]
 [0.37502894 0.35430839]
 [0.76617551 0.19779512]
 [0.78968389 0.16864712]
 [0.72171767 0.37874082]
 [0.70169063 0.55751477]
 [0.51761923 0.65631835]
 [0.15445026 0.87729761]
 [0.42434019 0.86491709]
 [0.66051889 0.17870329]
 [0.20330657 0.7728875 ]
 [0.14385089 0.44179211]
 [0.07265356 0.56502268]
 [0.35568253 0.03984401]
 [0.51125517 0.55209259]
 [0.87119226 0.42437854]
 [0.61632771 0.58700833]
 [0.25267784 0.07193735]
 [0.74810602 0.17122641]
 [0.30043845 0.60265021]
 [0.32310421 0.13110836]
 [0.05549359 0.60915485]
 [0.83762695 0.42213959]
 [0.52746938 0.88113292]
 [0.56956314 0.17724093]
 [0.98072008 0.8517689 ]
 [0.42008371 0.40821963]
 [0.70841029 0.88022145]
 [0.82401381 0.52359692]
 [0.65264608 0.27955114]
 [0.76541357 0.77453099]
 [0.72382187 0.7132233 ]


In [18]:
import yaml

# One model config per hidden/output width; all other hyper-parameters are fixed.
model_params = {}
widths = [16, 32, 48, 64, 80, 96]
for w in widths:
    # NOTE(review): the key says depth-2 but the config sets 'depth': 3. One of
    # them is presumably stale; confirm which depth was intended before reusing.
    mname = f'depth-{2}-ed-{16}-hd-{w}-od-{w}'
    model_params[mname] = {
        'depth': 3,
        'embedding_dim': 16,
        'hidden_dim': w,
        'output_dim': w,
        'input_dim': 3,
    }

# Context manager guarantees the YAML is flushed and the handle closed,
# even if dumping raises.
with open("model-configs/change-output-dim-1-in-dim-3.yml", "w") as file:
    yaml.dump(model_params, file)

In [13]:
import numpy as np
from scipy.spatial import ConvexHull

N_POINTS = 50  # points kept per cloud
SEED = 0       # makes the subsampling reproducible

rng = np.random.default_rng(SEED)
array = np.load('/data/riley/for_sam/ply_data_train0_CH.npy')

downsampled_samples = []
for pt_set in array:
    # Sample WITHOUT replacement. A duplicated point sits at the same location,
    # but only one copy is reported in `chull.vertices`, so the other copies
    # would be labelled as interior.
    idx = rng.choice(array.shape[1], size=N_POINTS, replace=False)
    new_pt_set = pt_set[idx, :3]  # first three columns (presumably xyz)
    chull = ConvexHull(new_pt_set)

    # Per-point indicator column: 1.0 if the point is a vertex of the
    # subsample's convex hull, else 0.0.
    chull_col = np.zeros((N_POINTS, 1))
    chull_col[chull.vertices, 0] = 1.0
    sample = np.concatenate((new_pt_set, chull_col), axis=1)  # (N_POINTS, 4)
    downsampled_samples.append(sample)



In [15]:
# Persist the downsampled clouds. np.save converts the list of per-cloud
# arrays into a single array before writing.
out_path = '/data/sam/coreset/data/50_mnet.npy'
np.save(out_path, downsampled_samples)

KeyboardInterrupt: 