In [None]:
import torch
import e3cnn.nn as enn
import e3cnn.gspaces as gspaces

# 3D rotation-equivariant gspace on R^3.
gspace = gspaces.rot3dOnR3()

# Input: a single irrep-0 field. Output: one field each for irreps 0, 1 and 2.
in_type = enn.FieldType(gspace, [gspace.irrep(0)])
out_type = enn.FieldType(gspace, [gspace.irrep(freq) for freq in range(3)])

model = enn.R3Conv(in_type, out_type, kernel_size=5)

In [None]:
# gspace built by gspaces.octaOnR3(); its fiber group is inspected further below.
cube = gspaces.octaOnR3()

In [None]:
import numpy as np
def kernel_so3(L: int) -> np.ndarray:
    """Concatenate flattened, sqrt(dim)-scaled identity blocks for l = 0..L.

    For each frequency ``l`` in ``0..L``, the ``(2l+1) x (2l+1)`` identity
    matrix is flattened (row-major) and multiplied by ``sqrt(2l+1)``. The
    pieces are joined into one 1-D array of length ``sum((2l+1)**2)``.
    """
    blocks = []
    for freq in range(L + 1):
        size = 2 * freq + 1
        blocks.append(np.eye(size).flatten() * np.sqrt(size))
    return np.concatenate(blocks)

def kernel_sphere(gspace: gspaces.GSpace3D, L: int) -> np.ndarray:
    """Concatenate scaled homogeneous-space basis vectors for l = 0..L.

    Uses the homogeneous space ``gspace.fibergroup.homspace((False, -1))``
    (named ``sphere`` here; presumably SO(3)/SO(2) - confirm). For each ``l``
    in ``0..L``, its basis for irrep ``(l,)`` and subgroup irrep ``(0,)`` is
    evaluated at the group identity, flattened and scaled by ``sqrt(2l+1)``.
    """
    group = gspace.fibergroup
    sphere = group.homspace((False, -1))
    identity = group.identity
    pieces = []
    for freq in range(L + 1):
        basis_at_identity = sphere.basis(identity, (freq,), (0,)).flatten()
        pieces.append(basis_at_identity * np.sqrt(2 * freq + 1))
    return np.concatenate(pieces)

In [None]:
# Unscaled variant of kernel_so3(1): flattened identity blocks for l = 0, 1.
dims = [2 * l + 1 for l in range(2)]
V = np.concatenate([np.eye(d).ravel() for d in dims])

In [None]:
# Display the unscaled concatenation of flattened identity blocks (l = 0, 1).
V

In [None]:
# Sphere (quotient-space) variant, band-limited at l <= 1.
# NOTE(review): kernel, grid and rho are all overwritten by the next cell.
kernel = kernel_sphere(gspace, 1)
grid = gspace.fibergroup.sphere_grid(type='tetra')
rho = gspace.fibergroup.bl_quotient_representation(1, (False, -1))

In [None]:
# Regular-representation variant on the group itself, band-limited at l <= 1.
kernel = kernel_so3(1)
grid = gspace.fibergroup.grid(type='tetra')
rho = gspace.fibergroup.bl_regular_representation(1)

In [None]:
# One row per grid element: the kernel vector multiplied by rho(elem).T.
# Normalization by sqrt(number of grid elements) happens in a later cell.
A = np.stack([kernel @ rho(elem).T for elem in grid])

In [None]:
# Display the raw (not yet normalized) stack A.
A

In [None]:
# Normalize by sqrt(number of grid elements). Rebuild A from its inputs rather
# than dividing in place, so re-running this cell does not compound the
# scaling. len(grid) equals len(A) because np.stack adds one row per element.
A = np.stack([kernel @ rho(g).T for g in grid]) / np.sqrt(len(grid))

In [None]:
# Display A after division by sqrt(len(A)).
A

In [None]:
# Compact array printing: 2 decimals, fixed-point (no scientific notation).
np.set_printoptions(precision=2, suppress=True)

In [None]:
# Gram matrix of the rows of A: pairwise inner products between the
# transformed kernels of the grid elements (assumes A is 2-D, one row per element).
(A @ A.T)

In [None]:
# `k` was never defined in this notebook, so this cell raised NameError on a
# fresh kernel. Build it as the sphere kernel with the same maximum frequency
# (2) as the quotient representation it is multiplied with.
k = kernel_sphere(gspace, 2)
# NOTE(review): `.round()` applies to the representation matrix, not to the
# product - confirm this is intended.
k @ gspace.fibergroup.bl_quotient_representation(2, (False, -1))(grid[-5]).T.round()

In [None]:
# Inspect the `cube_vertices_representation` attribute of cube's fiber group.
cube.fibergroup.cube_vertices_representation

In [None]:
# Shape of every buffer registered on the model's basis-expansion module.
[buf.shape for _, buf in model.basisexpansion.named_buffers()]

In [None]:
# Dimension reported by the basis expansion (presumably the number of
# learnable basis coefficients - confirm).
model.basisexpansion.dimension()

In [None]:
# One line per entry from get_basis_info(): input irrep -> output irrep,
# plus its radius and shape.
for basis_info in model.basisexpansion.get_basis_info():
    src = basis_info['in_irrep'][0]
    dst = basis_info['out_irrep'][0]
    print(f"{src} -> {dst} at radius {basis_info['radius']} shape: {basis_info['shape']}")

In [None]:
# (name, shape) for each learnable parameter of the model.
[(pname, list(param.shape)) for pname, param in model.named_parameters()]

In [None]:
import sys
# Make the parent (project) directory importable. Guarded so that re-running
# the cell does not append duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')
from data_utils.dataset import AsocaClassificationDataset, AsocaSegmentationDataset

In [None]:
# Segmentation dataset read from the processed-data directory.
ds_seg = AsocaSegmentationDataset(ds_path='../dataset/processed/')

In [None]:
# Classification dataset. Single trailing slash, consistent with the data_dir
# passed to AsocaClassificationDataModule below.
ds = AsocaClassificationDataset(ds_path='../dataset/classification/')

In [None]:
# Size of the classification dataset.
len(ds)

In [None]:
# Shape of element 0 of the last sample (presumably the input volume - confirm).
ds[-1][0].shape

In [None]:
from data_utils.datamodule import AsocaClassificationDataModule

In [None]:
# DataModule over the classification data directory.
adm = AsocaClassificationDataModule(data_dir='../dataset/classification/')

In [None]:
# Training dataloader from the datamodule.
dl = adm.train_dataloader()

In [None]:
# Pull one batch to sanity-check the training dataloader.
next(iter(dl))

In [None]:
# File ids held by the validation dataloader's dataset.
adm.val_dataloader().dataset.file_ids

In [None]:
def get_req(batch_size, channels, spatial_sizes=(32, 14, 5), extra_values=512 * 2, float_bits=32):
    """Rough memory estimate (in GiB) for the feature maps of one training step.

    For every cubic resolution ``s`` in ``spatial_sizes``, two maps of
    ``batch_size * channels * s**3`` values are counted. The original estimate
    listed each resolution twice; presumably these are two feature maps per
    stage - TODO confirm. ``extra_values`` (512 * 2 by default) is then added
    as a term that does not scale with batch size or channels. The total is
    doubled for forward + backward and converted from ``float_bits``-bit
    values to GiB (1024**3 bytes).

    The counted values scale with ``batch_size``, so they are activations,
    not model parameters.

    Args:
        batch_size: samples per batch.
        channels: feature channels at every resolution.
        spatial_sizes: edge length of each cubic feature map; defaults to the
            original hard-coded (32, 14, 5).
        extra_values: values added once, independent of batch/channels.
        float_bits: bits per stored value (32 = float32, 16 = half precision).

    Returns:
        ``(gib, n_values)``: the memory estimate in GiB and the number of
        values counted before the forward/backward doubling.
    """
    maps_per_size = 2
    n_values = sum(maps_per_size * batch_size * channels * size**3 for size in spatial_sizes)
    n_values += extra_values
    x = n_values * 2       # forward + backward
    x = x * float_bits     # total bits
    x = x / 8              # in bytes
    x = x / 1024**3        # in GiB
    return x, n_values

In [None]:
# Estimate for batch size 32 and 480 channels; returns (GB, value count).
get_req(32, 480)

In [None]:
%load_ext autoreload
%autoreload 2
import sys
# Make the parent (project) directory importable. Guarded so that this cell,
# or any earlier cell that already added '..', does not create duplicate
# sys.path entries.
if '..' not in sys.path:
    sys.path.append('..')
from train import AsocaClassificationConfig, AsocaSegmentationConfig

In [None]:
# Classification config with (68, 68, 68) as its first argument
# (presumably the input patch size - confirm).
config = AsocaClassificationConfig((68,68,68))

In [None]:
# Segmentation config: (68, 68, 68) with patch_stride (60, 60, 60). If the
# first argument is the patch size, neighbouring patches overlap - confirm.
config_seg = AsocaSegmentationConfig((68,68,68), patch_stride=(60,60,60))

In [None]:
# Display the segmentation config.
config_seg

In [None]:
# Display the classification config.
config

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def kernel_so3(L: int) -> np.ndarray:
    """Flattened identity blocks of size 2l+1, each scaled by sqrt(2l+1), for l = 0..L.

    NOTE(review): this cell re-defines kernel_so3 with the same behaviour as an
    earlier cell; consider keeping a single definition.
    """
    sizes = (2 * l + 1 for l in range(L + 1))
    return np.concatenate([np.sqrt(n) * np.eye(n).ravel() for n in sizes])

def kernel_sphere(gspace: gspaces.GSpace3D, L: int) -> np.ndarray:
    """Concatenate the identity-evaluated homspace((False, -1)) basis for l = 0..L.

    Each piece is ``sphere.basis(identity, (l,), (0,))``, flattened and scaled
    by ``sqrt(2l+1)``.

    NOTE(review): this cell re-defines kernel_sphere with the same behaviour as
    an earlier cell; consider keeping a single definition.
    """
    sphere = gspace.fibergroup.homspace((False, -1))
    e = gspace.fibergroup.identity

    def _scaled_basis(l):
        return sphere.basis(e, (l,), (0,)).flatten() * np.sqrt(2 * l + 1)

    return np.concatenate(list(map(_scaled_basis, range(L + 1))))

In [None]:
# Display the rot3dOnR3 gspace created at the top of the notebook.
gspace

In [None]:
# Band limit (maximum frequency l) shared by the kernels and representations below.
max_freq = 1

In [None]:
# Band-limited kernels up to max_freq: on the quotient space (sphere) and
# on the group itself (regular).
sph_kernel = kernel_sphere(gspace, max_freq)
reg_kernel = kernel_so3(max_freq)

In [None]:
# Sample points of type 'thomson_cube' with N=2 (N's exact meaning is
# defined by e3cnn - confirm): on the sphere and on the group.
sph_grid = gspace.fibergroup.sphere_grid(type='thomson_cube', N=2)
reg_grid = gspace.fibergroup.grid(type='thomson_cube', N=2)

In [None]:
# Band-limited (up to max_freq) quotient and regular representations,
# matching the kernels above.
sph_rho = gspace.fibergroup.bl_quotient_representation(max_freq, (False,-1))
reg_rho = gspace.fibergroup.bl_regular_representation(max_freq)

In [None]:
# Number of elements in the regular grid (idiomatic len() instead of __len__()).
len(reg_grid)

In [None]:
# Display the regular grid's elements.
reg_grid

In [None]:
# Heatmap of the band-limited regular representation evaluated at reg_grid[1].
# Uses the explicit figure/axes API, adds a title and colorbar, and ends with
# plt.show() so the AxesImage repr is not echoed.
fig, ax = plt.subplots()
im = ax.imshow(reg_rho(reg_grid[1]))
ax.set_title(f"bl_regular_representation(max_freq={max_freq}) at reg_grid[1]")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Quotient representation evaluated at the first sphere grid point.
sph_rho(sph_grid[0])