In [1]:
!pip install einops



In [2]:
import torch
import torch.nn as nn
from einops import rearrange

In [3]:
def unit_vector(a, dim=None):
    '''
    Returns the unit vector with the direction of a
    :param dim:
    :param a:
    :return:
    '''
    return a / a.norm(dim=dim).unsqueeze(-1)

def random_on_unit_sphere(size, device='cpu'):
    # We use the method in https://stats.stackexchange.com/questions/7977/how-to-generate-uniformly-distributed-points-on-the-surface-of-the-3-d-unit-sphe
    # to produce vectors on the surface of a unit sphere

    x = torch.randn(size)
    l = torch.sqrt(torch.sum(torch.pow(x, 2), dim=-1)).unsqueeze(1)
    x = (x / l).to(device)

    return x

In [4]:
N_samples = 64
near = 2.
far = 6.
N_rays = 1024
EPOCHS = 50000
LR = 0.01
print(f'Batch_size: {N_rays * N_samples}')

Batch_size: 65536


In [5]:
device = torch.device('cuda')

In [6]:
class CurveModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(7, 128)
        self.relu = nn.LeakyReLU()
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 3)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [7]:
model = CurveModel()
model.train()
model.to(device)

CurveModel(
  (fc1): Linear(in_features=7, out_features=128, bias=True)
  (relu): LeakyReLU(negative_slope=0.01)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=3, bias=True)
)

In [8]:
opt = torch.optim.SGD(model.parameters(), lr=LR)
criterion = nn.MSELoss()
lmbda = lambda epoch: 0.98
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(opt, lr_lambda=lmbda)

In [9]:
loss_history = []

In [10]:
for i in range(EPOCHS):
    rays_o = random_on_unit_sphere((N_rays, 3)) * 4
    rays_d = -1 * torch.rand((N_rays, 3)) + 1 # This is distributed -1 to 1
    rays_d = unit_vector(rays_d)

    t_vals = torch.linspace(0., 1., steps=N_samples)
    z_vals = near * (1. - t_vals) + far * t_vals
    z_vals = z_vals.expand([N_rays, N_samples])
    target = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]  # [N_rays, N_samples, 3]
    target = target.to(device)
    
    input_batch = torch.cat((rays_o, rays_d), dim=-1).unsqueeze(1)
    input_batch = input_batch.expand(N_rays, N_samples, 6)
    input_batch = torch.cat((input_batch, z_vals.unsqueeze(-1)), dim=-1)
    input_batch = input_batch.to(device)
    
    output = model(input_batch)
    
    loss = criterion(output, target)
    
    opt.zero_grad()
    loss.backward()
    opt.step()
    if i % 500 == 0:
        #scheduler.step()
        print(f'Epoch: {i}. Loss: {loss.item()}')
        loss_history.append(loss.item())


Epoch: 0. Loss: 5.369784355163574
Epoch: 500. Loss: 0.011314112693071365
Epoch: 1000. Loss: 0.005781370215117931
Epoch: 1500. Loss: 0.004164324142038822
Epoch: 2000. Loss: 0.003451303346082568
Epoch: 2500. Loss: 0.003086695447564125
Epoch: 3000. Loss: 0.002840508706867695
Epoch: 3500. Loss: 0.002643626183271408
Epoch: 4000. Loss: 0.002531772945076227
Epoch: 4500. Loss: 0.0025178538635373116
Epoch: 5000. Loss: 0.002326932270079851
Epoch: 5500. Loss: 0.00232268706895411
Epoch: 6000. Loss: 0.0021695788018405437
Epoch: 6500. Loss: 0.0022157086059451103
Epoch: 7000. Loss: 0.002162887481972575
Epoch: 7500. Loss: 0.002142961136996746
Epoch: 8000. Loss: 0.0021078824065625668
Epoch: 8500. Loss: 0.0020528368186205626
Epoch: 9000. Loss: 0.0020586387254297733
Epoch: 9500. Loss: 0.002013945020735264
Epoch: 10000. Loss: 0.0020280135795474052
Epoch: 10500. Loss: 0.0019237319938838482
Epoch: 11000. Loss: 0.0019508841214701533
Epoch: 11500. Loss: 0.0019506517564877868
Epoch: 12000. Loss: 0.001910810358

In [15]:
output[1,0]

tensor([ 1.5020, -2.7223,  2.5142], device='cuda:0', grad_fn=<SelectBackward>)

In [16]:
target[1,0]

tensor([ 1.4858, -2.7177,  2.5315], device='cuda:0')

In [17]:
torch.save(model.state_dict(), './straight_model.pth')