In [1]:
!pip install einops



In [9]:
import torch
import torch.nn as nn
from einops import rearrange
import math

In [10]:
def unit_vector(a, dim=None):
    '''
    Returns the unit vector with the direction of a
    :param dim:
    :param a:
    :return:
    '''
    return a / a.norm(dim=dim).unsqueeze(-1)

def random_on_unit_sphere(size, device='cpu'):
    # We use the method in https://stats.stackexchange.com/questions/7977/how-to-generate-uniformly-distributed-points-on-the-surface-of-the-3-d-unit-sphe
    # to produce vectors on the surface of a unit sphere

    x = torch.randn(size)
    l = torch.sqrt(torch.sum(torch.pow(x, 2), dim=-1)).unsqueeze(1)
    x = (x / l).to(device)

    return x

In [11]:
N_samples = 64
near = 2.
far = 6.
N_rays = 1024
EPOCHS = 50000
LR = 0.01
print(f'Batch_size: {N_rays * N_samples}')

Batch_size: 65536


In [12]:
device = torch.device('cuda')

In [13]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [32]:
class CurveModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.pos_encoder = PositionalEncoding(20)
        self.fc_1 = nn.Linear(6, 256)
        self.fc_enc = nn.Linear(20, 256)
        self.relu = nn.ReLU()
        self.fc_final = nn.Linear(256, 3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, p):
        x = self.fc_1(x)
        p = self.pos_encoder(p)
        p = self.fc_enc(p)
        x = self.relu(x + p)
        x = self.fc_final(x)
        
        return x

In [33]:
model = CurveModel()
model.train()
model.to(device)

CurveModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (fc_1): Linear(in_features=6, out_features=256, bias=True)
  (fc_enc): Linear(in_features=20, out_features=256, bias=True)
  (relu): ReLU()
  (fc_final): Linear(in_features=256, out_features=3, bias=True)
  (sigmoid): Sigmoid()
)

In [34]:
opt = torch.optim.SGD(model.parameters(), lr=LR)
criterion = nn.MSELoss()
lmbda = lambda epoch: 0.98
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(opt, lr_lambda=lmbda)

In [35]:
loss_history = []

In [None]:
for i in range(EPOCHS):
    rays_o = random_on_unit_sphere((N_rays, 3))
    rays_d = -1 * torch.rand((N_rays, 3)) + 1 # This is distributed -1 to 1
    rays_d = unit_vector(rays_d)
    z_vals = torch.rand((N_rays, N_samples, 1))

    target = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :]  # [N_rays, N_samples, 3]
    target = target.to(device)
    
    input_batch = torch.cat((rays_o, rays_d), dim=-1).unsqueeze(1)
    input_batch = input_batch.expand(N_rays, N_samples, 6)
    #input_batch = torch.cat((input_batch, z_vals.unsqueeze(-1)), dim=-1)
    z_vals = z_vals.to(device)
    input_batch = input_batch.to(device)
    
    output = model(input_batch, z_vals)
    
    loss = criterion(output, target)
    
    opt.zero_grad()
    loss.backward()
    opt.step()
    if i % 500 == 0:
        #scheduler.step()
        print(f'Epoch: {i}. Loss: {loss.item()}')
        loss_history.append(loss.item())
        print(output[1,0].detach().cpu().numpy(), target[1,0].detach().cpu().numpy())


Epoch: 0. Loss: 0.0026099744718521833
[-0.54483694 -0.450928    0.6316081 ] [-0.52307534 -0.5470368   0.6533295 ]
Epoch: 500. Loss: 0.002454501809552312
[-0.16613536  0.7168212  -0.48061556] [-0.1913549  0.7369966 -0.6513281]
Epoch: 1000. Loss: 0.002320326864719391
[ 0.28074166 -0.5798911  -0.6983054 ] [ 0.31694487 -0.604796   -0.7234049 ]
Epoch: 1500. Loss: 0.0021373755298554897
[ 0.00872508 -0.97128135  0.28719175] [-5.7892432e-04 -9.4649714e-01  3.1966606e-01]


In [None]:
output[2,0]

In [40]:
target[2,0]

tensor([ 0.7021, -0.7014,  0.1878])

In [17]:
torch.save(model.state_dict(), './straight_model.pth')