In [2]:
import sys  
sys.path.insert(0, '../')

In [29]:
import torch
import torch.nn as nn
import libs.other_helpers as u
from einops import rearrange

In [68]:
N_samples = 16
near = 2.
far = 6.
N_rays = 8
EPOCHS = 200
LR = 0.001
print(f'Batch_size: {N_rays * N_samples}')

Batch_size: 128


In [56]:
device = torch.device('cuda')

In [57]:
class CurveModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(7, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 3)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [58]:
model = CurveModel()
model.train()
model.to(device)

In [69]:
opt = torch.optim.SGD(model.parameters(), lr=LR)
criterion = nn.MSELoss()

In [70]:
for i in range(EPOCHS):
    rays_o = u.random_on_unit_sphere((N_rays, 3)) * 4
    rays_d = -1 * torch.rand((N_rays, 3)) + 1 # This is distributed -1 to 1
    rays_d = u.unit_vector(rays_d)

    t_vals = torch.linspace(0., 1., steps=N_samples)
    z_vals = near * (1. - t_vals) + far * t_vals
    z_vals = z_vals.expand([N_rays, N_samples])
    target = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]  # [N_rays, N_samples, 3]
    target = target.to(device)
    
    input_batch = torch.cat((rays_o, rays_d), dim=-1).unsqueeze(1)
    input_batch = input_batch.expand(N_rays, N_samples, 6)
    input_batch = torch.cat((input_batch, z_vals.unsqueeze(-1)), dim=-1)
    input_batch = input_batch.to(device)
    
    output = model(input_batch)
    
    loss = criterion(output, target)
    
    opt.zero_grad()
    loss.backward()
    opt.step()
    if i % 10 == 0:
        print(f'Epoch: {i}. Loss: {loss.item()}')


Epoch: 0. Loss: 0.14436160027980804
Epoch: 10. Loss: 0.21754953265190125
Epoch: 20. Loss: 0.21686062216758728
Epoch: 30. Loss: 0.09655234217643738
Epoch: 40. Loss: 0.18292856216430664
Epoch: 50. Loss: 0.17143933475017548
Epoch: 60. Loss: 0.18324628472328186
Epoch: 70. Loss: 0.22196708619594574
Epoch: 80. Loss: 0.18775534629821777
Epoch: 90. Loss: 0.1800617277622223
Epoch: 100. Loss: 0.209504172205925
Epoch: 110. Loss: 0.17791514098644257
Epoch: 120. Loss: 0.17417508363723755
Epoch: 130. Loss: 0.18488049507141113
Epoch: 140. Loss: 0.14253106713294983
Epoch: 150. Loss: 0.1533195972442627
Epoch: 160. Loss: 0.20403504371643066
Epoch: 170. Loss: 0.14890572428703308
Epoch: 180. Loss: 0.19020722806453705
Epoch: 190. Loss: 0.14024381339550018


In [71]:
output[0,0]

tensor([3.5599, 2.3184, 0.9406], device='cuda:0', grad_fn=<SelectBackward>)

In [72]:
target[0,0]

tensor([3.6514, 2.3680, 0.7674], device='cuda:0')

In [12]:
rays_o.shape, rays_d.shape, z_vals.shape

(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 4]))

In [30]:
input_batch.shape

torch.Size([2, 4, 7])

In [33]:
rearrange(input_batch, 'r s d -> (r s) d').shape

torch.Size([8, 7])

In [40]:
model(input_batch).shape

torch.Size([2, 4, 3])