In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.shared.wavelet import (
    dwt, 
    iwt,
)

torch.set_printoptions(precision=4, sci_mode=False)

In [2]:
n = 32
entry = torch.randn(6, 3, n, n)
dwted =  dwt(dwt(entry))
iwted = iwt(iwt(dwted))
dwted.shape, (entry - iwted).abs().max()

(torch.Size([6, 48, 8, 8]), tensor(    0.0000))

In [3]:
from collections import OrderedDict

from src.shared.faces import make_cube_faces
from src.shared.sides import (    
    to_vertices,
    make_phi_theta,
    sphered_vertices,
    to_spherical,
)



class Coarse(nn.Module):
    def __init__(self, n, r=0.5):
        super(Coarse, self).__init__()        
        self.n = n
        phi, theta = make_phi_theta(n)
        self.register_buffer('phi', phi)
        self.register_buffer('theta', theta)
        self.register_buffer('faces', make_cube_faces(n))                
        self.register_buffer('sphere', to_spherical(sphered_vertices(n, r)))
        self.net = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 32, 3, padding=1, padding_mode='reflect')),
            ('relu1', nn.LeakyReLU(0.1)),
            ('conv2', nn.Conv2d(32,3, 3, padding=1, padding_mode='reflect')),
            ('relu2', nn.Sigmoid())
        ]))
        self.radii = torch.nn.Parameter(torch.zeros(3, *phi.shape))    
   
    def get_ellipsoidal(self, radii):
        x = radii[:, 0, :, :] * torch.sin(self.theta) * torch.cos(self.phi)
        y = radii[:, 1, :, :] * torch.sin(self.theta) * torch.sin(self.phi)
        z = radii[:, 2, :, :] * torch.cos(self.theta) 
        return torch.stack((x, y, z), dim=1)   
    
    def forward(self):        
        radii = self.net(self.sphere) +0.5
        ellipsoidal = self.get_ellipsoidal(radii) + self.sphere
        vert = to_vertices(ellipsoidal)
        return vert, self.faces
n = 16
coarse = Coarse(n)

v, f = coarse()
print(v.shape, f.shape)


meshplot.plot(v.detach().numpy(), f.numpy())

torch.Size([1536, 3]) torch.Size([3068, 3])


NameError: name 'meshplot' is not defined

In [4]:
import meshplot

meshplot.plot(v.detach().numpy(), f.numpy())

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.566085…

<meshplot.Viewer.Viewer at 0x7fe27073e350>

tensor(1.5821, device='cuda:0', grad_fn=<MaxBackward1>)

In [95]:
t = torch.randn(4, 4).cuda()
t.requires_grad_(True)

optimizer = torch.optim.Adam([t], lr=0.01)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.01
    weight_decay: 0
)

In [96]:
def outside_bound(x, minima, maxima):
    mx = torch.tensor(maxima).cuda()
    mi = torch.tensor(minima).cuda()
    mx_mask =  torch.where(x > mx, 1., 0.).detach()
    mi_mask =  torch.where(x < mi, 1., 0.).detach()
    mask = mi_mask + mx_mask
    return (x * mask).abs().mean()

start = t.detach().clone()
print(t)
for step in range(1000):
    optimizer.zero_grad()
    loss = outside_bound(t, -0.1, 0.3)    
    loss.backward()
    print(step, loss.item())
    optimizer.step()    
t-start

tensor([[-0.7088, -0.7346,  0.8056, -0.2323],
        [ 0.7319, -0.3856, -0.7708, -1.6826],
        [ 1.6372,  0.3474, -0.2518,  0.5446],
        [-1.6594,  0.9039, -0.8347, -0.4844]], device='cuda:0',
       requires_grad=True)
0 0.7947365045547485
1 0.7847365140914917
2 0.7747365236282349
3 0.764736533164978
4 0.7547365427017212
5 0.7261481285095215
6 0.7167730927467346
7 0.7073980569839478
8 0.6980230808258057
9 0.6886481046676636
10 0.6792731285095215
11 0.6698981523513794
12 0.6605231761932373
13 0.6511481404304504
14 0.6360033750534058
15 0.6272534132003784
16 0.6127679347991943
17 0.6046428680419922
18 0.5965179204940796
19 0.588392972946167
20 0.5802679061889648
21 0.5721429586410522
22 0.5640179514884949
23 0.5558929443359375
24 0.5477679967880249
25 0.5212278366088867
26 0.5137278437614441
27 0.5062278509140015
28 0.49872785806655884
29 0.4852498173713684
30 0.4783748388290405
31 0.47149980068206787
32 0.46462482213974
33 0.4577498435974121
34 0.45087486505508423
35 0.4439998

875 0.0
876 0.0
877 0.0
878 0.0
879 0.0
880 0.0
881 0.0
882 0.0
883 0.0
884 0.0
885 0.0
886 0.0
887 0.0
888 0.0
889 0.0
890 0.0
891 0.0
892 0.0
893 0.0
894 0.0
895 0.0
896 0.0
897 0.0
898 0.0
899 0.0
900 0.0
901 0.0
902 0.0
903 0.0
904 0.0
905 0.0
906 0.0
907 0.0
908 0.0
909 0.0
910 0.0
911 0.0
912 0.0
913 0.0
914 0.0
915 0.0
916 0.0
917 0.0
918 0.0
919 0.0
920 0.0
921 0.0
922 0.0
923 0.0
924 0.0
925 0.0
926 0.0
927 0.0
928 0.0
929 0.0
930 0.0
931 0.0
932 0.0
933 0.0
934 0.0
935 0.0
936 0.0
937 0.0
938 0.0
939 0.0
940 0.0
941 0.0
942 0.0
943 0.0
944 0.0
945 0.0
946 0.0
947 0.0
948 0.0
949 0.0
950 0.0
951 0.0
952 0.0
953 0.0
954 0.0
955 0.0
956 0.0
957 0.0
958 0.0
959 0.0
960 0.0
961 0.0
962 0.0
963 0.0
964 0.0
965 0.0
966 0.0
967 0.0
968 0.0
969 0.0
970 0.0
971 0.0
972 0.0
973 0.0
974 0.0
975 0.0
976 0.0
977 0.0
978 0.0
979 0.0
980 0.0
981 0.0
982 0.0
983 0.0
984 0.0
985 0.0
986 0.0
987 0.0
988 0.0
989 0.0
990 0.0
991 0.0
992 0.0
993 0.0
994 0.0
995 0.0
996 0.0
997 0.0
998 0.0
999 0.0


tensor([[ 0.7071,  0.7368, -0.6082,  0.2399],
        [-0.5391,  0.3912,  0.7764,  1.6830],
        [-1.4335, -0.1346,  0.2608, -0.3516],
        [ 1.6530, -0.7071,  0.8360,  0.4898]], device='cuda:0',
       grad_fn=<SubBackward0>)

In [87]:
start.abs() - t.abs()

tensor([[0.4747, 0.8223, 0.4925, 0.4389],
        [0.1561, 0.4197, 0.0000, 0.7334],
        [0.5435, 0.5205, 0.5587, 0.8802],
        [0.9137, 0.8416, 0.8931, 0.5579]], device='cuda:0',
       grad_fn=<SubBackward0>)

In [88]:
t, start

(tensor([[-0.0495,  0.3893, -0.0524, -0.0438],
         [-0.0063, -0.0389,  0.0975, -0.1351],
         [ 0.2280, -0.0552, -0.0588,  0.8083],
         [-1.3316,  0.4914, -0.9688,  0.2328]], device='cuda:0',
        requires_grad=True), tensor([[-0.5243,  1.2116, -0.5449, -0.4827],
         [-0.1623, -0.4586,  0.0975, -0.8686],
         [ 0.7715, -0.5757, -0.6175,  1.6885],
         [-2.2453,  1.3330, -1.8619,  0.7907]], device='cuda:0'))