In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import argoverse

import sys
import os
sys.path.append('./')
sys.path.append('../')

from EquiCtsConv import *
from EquiLinear import *

In [24]:
torch.utils.backcompat

TypeError: get_enabled() missing 1 required positional argument: 'self'

In [2]:
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')

In [3]:
class VehicleEncoder(nn.Module):
    def __init__(self, 
                 num_radii = 3, 
                 num_theta = 16, 
                 reg_dim = 8,
                 radius_scale = 40,
                 timestep = 0.1,
                 in_channel = 19, 
                 layer_channels = [8, 16, 16]
                 ):
        super(VehicleEncoder, self).__init__()
        
        # init parameters
        
        self.num_radii = num_radii
        self.num_theta = num_theta
        self.reg_dim = reg_dim
        self.radius_scale = radius_scale
        self.timestep = timestep
        self.layer_channels = layer_channels
        
        self.in_channel = in_channel
        self.activation = F.relu
        # self.relu_shift = torch.nn.parameter.Parameter(torch.tensor(0.2))
        # relu_shift = torch.tensor(0.2)
        # self.register_buffer('relu_shift', relu_shift)
        
        # create continuous convolution and fully-connected layers
        
        convs = []
        denses = []
        # c_in, c_out, radius, num_radii, num_theta
        self.conv_vehicle = EquiCtsConv2dRho1ToReg(in_channels = self.in_channel, 
                                                 out_channels = self.layer_channels[0],
                                                 num_radii = self.num_radii, 
                                                 num_theta = self.num_theta,
                                                 radius = self.radius_scale, 
                                                 k = self.reg_dim)

        self.dense_vehicle = nn.Sequential(
            EquiLinearRho1ToReg(self.reg_dim), 
            EquiLinearRegToReg(self.in_channel, self.layer_channels[0], self.reg_dim)
        )
        
        # concat conv_obstacle, conv_fluid, dense_fluid
        in_ch = 2 * self.layer_channels[0] 
        for i in range(1, len(self.layer_channels)):
            out_ch = self.layer_channels[i]
            dense = EquiLinearRegToReg(in_ch, out_ch, self.reg_dim)
            denses.append(dense)
            conv = EquiCtsConv2dRegToReg(in_channels = in_ch, 
                                         out_channels = out_ch,
                                         num_radii = self.num_radii, 
                                         num_theta = self.num_theta,
                                         radius = self.radius_scale, 
                                         k = self.reg_dim)
            convs.append(conv)
            in_ch = self.layer_channels[i]
        
        self.convs = nn.ModuleList(convs)
        self.denses = nn.ModuleList(denses)
        

    def encode(self, p, vehicle_feats, car_mask):

        output_conv_vehicle = self.conv_vehicle(p, p, vehicle_feats, car_mask)
        output_dense_vehicle = self.dense_vehicle(vehicle_feats)
        
        output = torch.cat((output_conv_vehicle, output_dense_vehicle), -2)
        
        for conv, dense in zip(self.convs, self.denses):
            # pass input features to conv and fully-connected layers
            # mags = (torch.sum(output**2,axis=-1) + 1e-6).unsqueeze(-1)
            # in_feats = output/mags * self.activation(mags - self.relu_shift)
            in_feats = self.activation(output)
            output_conv = conv(p, p, in_feats, car_mask)
            output_dense = dense(in_feats)
            
            if output_dense.shape[-2] == output.shape[-2]:
                output = output_conv + output_dense + output
            else:
                output = output_conv + output_dense

        output = self.activation(output)
        
        return output
    
    def forward(self, inputs):
        """ inputs: 8 elems tuple
        p0_enc, v0_enc, p0, v0, feats, box, box_feats
        v0_enc: [batch, num_part, timestamps, 2]
        """
        p0_enc, v0_enc, p0, v0, car_mask = inputs
            
        feats = torch.cat((v0.unsqueeze(-2), v0_enc), -2)

        hidden_feature = self.encode(p0, feats, car_mask)

        return hidden_feature

In [4]:
class MapEncoder(nn.Module):
    def __init__(self, 
                 num_radii = 3, 
                 num_theta = 16, 
                 reg_dim = 8,
                 radius_scale = 40,
                 hidden_size: int = 8
                 ):
        super(MapEncoder, self).__init__()
        
        # init parameters
        
        self.num_radii = num_radii
        self.num_theta = num_theta
        self.reg_dim = reg_dim
        self.radius_scale = radius_scale
        self.hidden_size = hidden_size
        
        self.activation = F.relu
        # self.relu_shift = torch.nn.parameter.Parameter(torch.tensor(0.2))
        # relu_shift = torch.tensor(0.2)
        # self.register_buffer('relu_shift', relu_shift)
        
        # create continuous convolution and fully-connected layers
        
        # c_in, c_out, radius, num_radii, num_theta
        self.conv_map = EquiCtsConv2dRho1ToReg(in_channels = 1, 
                                               out_channels = self.hidden_size,
                                               num_radii = self.num_radii, 
                                               num_theta = self.num_theta,
                                               radius = self.radius_scale, 
                                               k = self.reg_dim)

    def forward(self, p, map_p, map_feat, map_mask):
        
        output = self.conv_map(map_p, p, map_feat.unsqueeze(-2), map_mask)
        
        output = self.activation(output)

        return output

In [5]:
class ModeDecoder(nn.Module):
    def __init__(self, vehicle_hidden=16, map_hidden=8, reg_dim=8, modes=6):
        super(ModeDecoder, self).__init__()
        in_channel = vehicle_hidden + map_hidden
        self.modes = modes
        if modes == 1:
            self.mode_decoder = lambda x: None
        else:
            self.mode_decoder = EquiLinearRegToReg(in_channel, modes, reg_dim)
        
    def forward(self, feat):
        """
        feat: shape (batch, num_vehicles, v+m, reg_dim)
        
        return: shape (batch, modes)
        """
        if self.modes == 1:
            return self.mode_decoder(feat)
        else:
            # mode_pred, _ = self.mode_decoder(feat).norm(dim=-1).topk(k=1, dim=1)
            # mode_pred = mode_pred.squeeze(1)
            mode_pred = self.mode_decoder(feat).norm(dim=-1).permute(0,2,1)
            mode_pred = F.max_pool1d(mode_pred, mode_pred.shape[-1]).squeeze(-1)
            return F.softmax(mode_pred, -1)

In [6]:
class TrajectoryDecoder(nn.Module):
    def __init__(self, 
                 num_radii = 3, 
                 num_theta = 16, 
                 reg_dim = 8,
                 radius_scale = 40,
                 timestep = 0.1,
                 vehicle_hidden = 16, 
                 map_hidden = 8,
                 layer_channels = [8, 8, 3], 
                 predict_window = 30, 
                 map_encoder = None
                 ):
        super(TrajectoryDecoder, self).__init__()
        
        # init parameters
        
        self.num_radii = num_radii
        self.num_theta = num_theta
        self.reg_dim = reg_dim
        self.radius_scale = radius_scale
        self.timestep = timestep
        self.predict_window = predict_window
        self.vehicle_hidden = vehicle_hidden
        self.map_hidden = map_hidden
        self.layer_channels = layer_channels
        
        self.in_channel = vehicle_hidden + map_hidden
        self.activation = F.relu
        # self.relu_shift = torch.nn.parameter.Parameter(torch.tensor(0.2))
        # relu_shift = torch.tensor(0.2)
        # self.register_buffer('relu_shift', relu_shift)
        
        # create continuous convolution and fully-connected layers
        
        convs = []
        denses = []
        # c_in, c_out, radius, num_radii, num_theta
        self.conv_vehicle = EquiCtsConv2dRegToReg(in_channels = self.in_channel, 
                                                   out_channels = self.layer_channels[0],
                                                   num_radii = self.num_radii, 
                                                   num_theta = self.num_theta,
                                                   radius = self.radius_scale, 
                                                   k = self.reg_dim)
        
        if map_encoder:
            self.map_encoder = map_encoder
        else:
            self.map_encoder = MapEncoder(num_radii = num_radii, 
                                          num_theta = num_theta, 
                                          reg_dim = reg_dim,
                                          radius_scale = radius_scale,
                                          hidden_size = map_hidden)

        self.dense_vehicle = EquiLinearRegToReg(self.in_channel, self.layer_channels[0], self.reg_dim)
        
        # concat conv_obstacle, conv_fluid, dense_fluid
        in_ch = self.layer_channels[0] 
        for i in range(1, len(self.layer_channels)):
            out_ch = self.layer_channels[i]
            dense = EquiLinearRegToReg(in_ch, out_ch, self.reg_dim)
            denses.append(dense)
            conv = EquiCtsConv2dRegToReg(in_channels = in_ch, 
                                         out_channels = out_ch,
                                         num_radii = self.num_radii, 
                                         num_theta = self.num_theta,
                                         radius = self.radius_scale, 
                                         k = self.reg_dim)
            convs.append(conv)
            in_ch = self.layer_channels[i]
        
        self.convs = nn.ModuleList(convs)
        self.denses = nn.ModuleList(denses)
        
        self.dense_back = EquiLinearRegToReg(self.layer_channels[-1], vehicle_hidden, self.reg_dim)
        self.dense_reg2rho1 = EquiLinearRegToRho1(self.reg_dim)
        
    def decode(self, p, feat, map_p, map_feat, car_mask, map_mask):
        output_conv_vehicle = self.conv_vehicle(p, p, feat, car_mask)
        output_dense_vehicle = self.dense_vehicle(feat)
        
        output = output_conv_vehicle + output_dense_vehicle
        
        for conv, dense in zip(self.convs, self.denses):
            # pass input features to conv and fully-connected layers
            # mags = (torch.sum(output**2,axis=-1) + 1e-6).unsqueeze(-1)
            # in_feats = output/mags * self.activation(mags - self.relu_shift)
            in_feats = self.activation(output)
            output_conv = conv(p, p, in_feats, car_mask)
            output_dense = dense(in_feats)
            
            if output_dense.shape[-2] == output.shape[-2]:
                output = output_conv + output_dense + output
            else:
                output = output_conv + output_dense

        output = self.activation(output)

        return output
    
    def forward(self, p, feat, map_p, map_feat, car_mask, map_mask):
        outputs = []
        
        output = self.decode(p, feat, map_p, map_feat, car_mask, map_mask)
        delta_p_dist = self.dense_reg2rho1(output)
        pred = delta_p_dist.clone()
        pred[...,0,:] = pred[...,0,:] + p
        outputs.append(pred)
        
        for t in range(1, self.predict_window):
            p = p + delta_p_dist[...,0,:]
            back_feat = torch.tanh(self.dense_back(output))
            
            feat = feat[...,:self.vehicle_hidden,:] * back_feat
            encode_map = self.map_encoder(p, map_p, map_feat, map_mask)
            feat = torch.cat([feat, encode_map], dim=-2)
            
        
            output = self.decode(p, feat, map_p, map_feat, car_mask, map_mask)
            delta_p_dist = self.dense_reg2rho1(output)
            pred = delta_p_dist.clone()
            pred[...,0,:] = pred[...,0,:] + p
            outputs.append(pred)
            
        return outputs
    
    def reset_predict_window(self, window):
        self.predict_window = window

In [19]:
class MultiModePECCO(nn.Module):
    def __init__(self, 
                 num_radii = 3, 
                 num_theta = 16, 
                 reg_dim = 8,
                 radius_scale = 40,
                 timestep = 0.1,
                 in_channel = 19,
                 map_hidden = 8, 
                 encoder_channels = [8, 16, 16],
                 decoder_channels = [8, 3], 
                 predict_window = 30, 
                 modes = 6):
        super(MultiModePECCO, self).__init__()
        
        self.modes = modes
        
        self.vehicle_encoder = VehicleEncoder(num_radii = num_radii, 
                                              num_theta = num_theta, 
                                              reg_dim = reg_dim,
                                              radius_scale = radius_scale,
                                              timestep = timestep,
                                              in_channel = in_channel, 
                                              layer_channels = encoder_channels)
        
        self.map_encoder = MapEncoder(num_radii = num_radii, 
                                      num_theta = num_theta, 
                                      reg_dim = reg_dim,
                                      radius_scale = radius_scale,
                                      hidden_size = map_hidden)
        
        self.mode_decoder = ModeDecoder(vehicle_hidden=encoder_channels[-1], 
                                        map_hidden=map_hidden, 
                                        reg_dim=reg_dim, modes=modes)
        
        self.traj_decoder = []
        for m in range(modes):
            traj_decoder_m = TrajectoryDecoder(num_radii = num_radii, 
                                               num_theta = num_theta, 
                                               reg_dim = reg_dim,
                                               radius_scale = radius_scale,
                                               timestep = timestep,
                                               vehicle_hidden = encoder_channels[-1], 
                                               map_hidden = map_hidden,
                                               layer_channels = decoder_channels, 
                                               predict_window = predict_window, 
                                               map_encoder = self.map_encoder)
            self.traj_decoder.append(traj_decoder_m)
            
        self.traj_decoder = nn.ModuleList(self.traj_decoder)
        
    def forward(self, inputs):
        p_enc, v_enc, p, v, map_p, map_feat, car_mask, map_mask = inputs
        
        vehicle_hidden = self.vehicle_encoder((p_enc, v_enc, p, v, car_mask))
        map_hidden = self.map_encoder(p, map_p, map_feat, map_mask)
        
        feat = torch.cat([vehicle_hidden, map_hidden], dim=-2)
        
        traj_preds = []
        for m in range(self.modes):
            traj_pred = self.traj_decoder[m](p, feat, map_p, map_feat, car_mask, map_mask)
            traj_preds.append(traj_pred)
            
        mode_pred = self.mode_decoder(feat)
        
        return traj_preds, mode_pred
    
    def reset_predict_window(self, window):
        for m in range(self.modes):
            self.traj_decoder[m].reset_predict_window(window)

In [8]:
model = MultiModePECCO(predict_window=10).to(device)

In [9]:
def quadratic_func(x, M):
    part1 = torch.einsum('...x,...xy->...y', x, M)
    return torch.einsum('...x,...x->...', part1, x)

def calc_sigma(M):
    M1 = torch.tanh(M)
    sigma = torch.einsum('...xy,...xz->...yz', M1, M1)
    return torch.matrix_exp(sigma)

def nll_loss(pred, gt, mask):
    mu = pred[...,0,:]
    # sigma = torch.einsum('...xy,...xz->...yz', pred[mask>0][...,1:,:], pred[mask>0][...,1:,:])
    sigma = calc_sigma(pred[...,1:,:])
    nll = quadratic_func(gt - mu, sigma.inverse()) + torch.log(sigma.det())
    return nll * mask

def nll_loss_per_sample(preds, data):
    loss = 0
    for i, pred in enumerate(preds):
        loss = loss + nll_loss(pred, data['pos'+str(i+1)], data['car_mask'][...,0])
    return loss / (i + 1)

def nll_loss_multimodes(preds, data, modes_pred, noise=0.0):
    """NLL loss multimodes for training.
    Args:
        pred is a list (with N modes) of predictions
        data is ground truth    
        noise is optional
    """
    modes = len(preds)
    log_lik = []   
    with torch.no_grad():
        for pred in preds:
            nll = nll_loss_per_sample(pred, data)
            log_lik.append(-nll.unsqueeze(-1))
        log_lik = torch.cat(log_lik, -1)
  
    priors = modes_pred.detach().unsqueeze(1)
    print(log_lik.shape, priors.shape)
    log_posterior_unnorm = log_lik + torch.log(priors)
    log_posterior_unnorm += torch.randn(*log_posterior_unnorm.shape).to(log_lik.device) * noise
    log_posterior = log_posterior_unnorm - torch.logsumexp(log_posterior_unnorm, axis=-1).unsqueeze(-1)
    post_pr = torch.exp(log_posterior)

    loss = 0.0
    for m, pred in enumerate(preds):
        nll_k = nll_loss_per_sample(pred, data) * post_pr[...,m] 
        nll_k = nll_k.sum(-1) / data['car_mask'][...,0].sum(-1)
        loss += nll_k.sum()

    kl_loss = torch.nn.KLDivLoss(reduction='batchmean')
    loss += kl_loss(torch.log(modes_pred.unsqueeze(1)), post_pr) 
    return loss 

def nonsingular_loss(multi_preds, epsilon=0.01):
    loss = 0.
    for i, preds in enumerate(multi_preds):
        for j, pred in enumerate(preds):
            loss = loss + torch.relu(epsilon - pred[...,1:,:].det()).mean()
    return loss / (i + j + 2)

In [10]:
from datasets.argoverse_lane_loader import read_pkl_data
import tqdm

In [11]:
datapath = '/home/leo/particle/argoverse/argoverse_forecasting/train/lane_data/'

In [12]:
dataset = read_pkl_data(datapath, 2, max_car=30)

In [13]:
batch = next(dataset)

In [14]:
convert_keys = (['pos' + str(i) for i in range(31)] + 
                ['vel' + str(i) for i in range(31)] + 
                ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])
for data in tqdm.tqdm(dataset):
    for k in convert_keys:
        if np.isnan(data[k]).any():
            break

2741it [00:03, 723.76it/s]


KeyboardInterrupt: 

In [14]:
batch_size = len(batch['pos0'])

batch_tensor = {}
convert_keys = (['pos' + str(i) for i in range(31)] + 
                ['vel' + str(i) for i in range(31)] + 
                ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])

for k in convert_keys:
    batch_tensor[k] = torch.tensor(np.stack(batch[k]), dtype=torch.float32, device=device)

for k in ['car_mask', 'lane_mask']:
    batch_tensor[k] = torch.tensor(np.stack(batch[k]), dtype=torch.float32, device=device).unsqueeze(-1)

for k in ['track_id' + str(i) for i in range(31)] + ['city', 'agent_id', 'scene_idx']:
    batch_tensor[k] = np.stack(batch[k])

batch_tensor['car_mask'] = batch_tensor['car_mask'].squeeze(-1)
batch_tensor['agent_id'] = batch_tensor['agent_id'][:,np.newaxis]

In [15]:
batch_tensor['car_mask'].shape

torch.Size([2, 30, 1])

In [16]:
# p_enc, v_enc, p, v, map_p, map_feat, car_mask, map_mask
inputs = ([
    batch_tensor['pos_2s'], batch_tensor['vel_2s'], 
    batch_tensor['pos0'], batch_tensor['vel0'], 
    batch_tensor['lane'], batch_tensor['lane_norm'], 
    batch_tensor['car_mask'], batch_tensor['lane_mask']
])

In [17]:
traj_preds, mode_pred = model(inputs)

tensor(9.2967, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4486, device='cuda:0', grad_fn=<MaxBackward1>)


In [26]:
traj_preds[0][9]

tensor([[[[ 3.2643e+03,  1.9516e+03],
          [-5.7332e-01, -2.6134e+00],
          [-9.7433e-01, -7.8999e-01]],

         [[ 3.2783e+03,  1.9783e+03],
          [-1.5892e-01, -1.7729e-01],
          [ 4.2915e-06,  3.8850e-02]],

         [[ 3.2376e+03,  1.9290e+03],
          [-4.9331e+00, -1.0109e+01],
          [-2.5961e+00, -1.7283e+00]],

         [[ 3.2768e+03,  1.9761e+03],
          [-1.5572e-01, -1.7676e-01],
          [-5.9481e-02,  3.0278e-02]],

         [[ 3.2186e+03,  1.9395e+03],
          [ 8.1726e-01, -1.9085e+00],
          [-1.4666e+00,  1.1846e+00]],

         [[ 3.2678e+03,  1.9213e+03],
          [ 3.9941e+00,  4.2584e+00],
          [ 7.6987e-01,  5.9982e-01]],

         [[ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00]],

         [[ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00]],

         [[ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+0

In [18]:
loss = nll_loss_multimodes(traj_preds, batch_tensor, mode_pred)
loss

torch.Size([2, 30, 6]) torch.Size([2, 1, 6])


tensor(20.6208, device='cuda:0', grad_fn=<AddBackward0>)

In [45]:
loss.backward()

In [46]:
traj_pred0 = traj_preds[0]
gt = p_gt['pos1']

In [53]:
nll_loss(traj_pred0[0], gt, car_mask).shape

torch.Size([2, 30])

In [22]:
with torch.no_grad():
    model.reset_predict_window(30)
    traj_preds_eval, mode_pred_eval = model(inputs)

In [25]:
traj_pred_one_mode = traj_preds_eval[0]

In [51]:
batch_tensor.keys()

dict_keys(['pos0', 'pos1', 'pos2', 'pos3', 'pos4', 'pos5', 'pos6', 'pos7', 'pos8', 'pos9', 'pos10', 'vel0', 'vel1', 'vel2', 'vel3', 'vel4', 'vel5', 'vel6', 'vel7', 'vel8', 'vel9', 'vel10', 'pos_2s', 'vel_2s', 'lane', 'lane_norm', 'car_mask', 'lane_mask', 'track_id0', 'track_id1', 'track_id2', 'track_id3', 'track_id4', 'track_id5', 'track_id6', 'track_id7', 'track_id8', 'track_id9', 'track_id10', 'track_id11', 'track_id12', 'track_id13', 'track_id14', 'track_id15', 'track_id16', 'track_id17', 'track_id18', 'track_id19', 'track_id20', 'track_id21', 'track_id22', 'track_id23', 'track_id24', 'track_id25', 'track_id26', 'track_id27', 'track_id28', 'track_id29', 'track_id30', 'city', 'agent_id', 'scene_idx'])

In [58]:
def get_de(traj_pred_one_mode, batch_tensor):
    des = []
    for t, pred in enumerate(traj_pred_one_mode):
        pred_agent, gt_agent = get_agent(pred, batch_tensor['pos'+str(t+1)], 
                               batch_tensor['track_id0'], batch_tensor['track_id'+str(t+1)], 
                               batch_tensor['agent_id'])
        des.append(torch.norm(pred_agent[...,0,:] - gt_agent, dim=-1))
    return torch.stack(des, -1)

def get_de_multi_modes(traj_pred, batch_tensor):
    des = []
    for m, preds in enumerate(traj_pred):
        des.append(get_de(preds, batch_tensor))
    return torch.stack(des, -1)

In [60]:
get_de_multi_modes(traj_preds_eval, batch_tensor)

torch.Size([2, 30, 6])

In [46]:
def get_agent(pr, gt, pr_id, gt_id, agent_id, device='cpu'):
    pr_agent = pr[pr_id == agent_id,:]
    gt_agent = gt[gt_id == agent_id,:]
    
    return pr_agent, gt_agent

In [8]:
batch_size = 2

p_enc = torch.rand(batch_size, 30, 18, 2).to(device)
v_enc = torch.rand(batch_size, 30, 18, 2).to(device)
p = torch.rand(batch_size, 30, 2).to(device)
v = torch.rand(batch_size, 30, 2).to(device)
map_p = torch.rand(batch_size, 650, 2).to(device)
map_feat = torch.rand(batch_size, 650, 2).to(device)
car_mask = torch.tensor([1.] * 30).to(device)
map_mask = torch.tensor([1.] * 650).to(device)

p_gt = {'pos{}'.format(i + 1): torch.rand(batch_size, 30, 2).to(device) for i in range(30)}

In [10]:
inputs = (p_enc, v_enc, p, v, map_p, map_feat, car_mask, map_mask)
traj_preds, mode_pred = model(inputs)

In [13]:
vehicle_encoder = VehicleEncoder().to(device)
map_encoder = MapEncoder().to(device)

traj_decoder = TrajectoryDecoder(predict_window=30, map_encoder=map_encoder).to(device)
mode_decoder = ModeDecoder(modes=6).to(device)

In [17]:
vehicle_hidden = model.vehicle_encoder((p_enc, v_enc, p, v, car_mask))

In [16]:
encode_map = map_encoder(p, map_p, map_feat, map_mask)
feat = torch.cat([vehicle_hidden, encode_map], dim=-2)

In [17]:
traj_pred = traj_decoder(p, feat, map_p, map_feat, car_mask, map_mask)

In [24]:
mode_decoder(feat)

tensor([[0.1376, 0.1893, 0.1734, 0.1731, 0.1294, 0.1972],
        [0.1357, 0.1906, 0.1737, 0.1734, 0.1272, 0.1994],
        [0.1367, 0.1899, 0.1734, 0.1732, 0.1285, 0.1983],
        [0.1365, 0.1900, 0.1736, 0.1732, 0.1284, 0.1983]], device='cuda:0',
       grad_fn=<SoftmaxBackward>)

In [27]:
torch.tensor(
    [[[1],[1],[1],[2],[2],[2],[3],[3],[3]]]
).reshape(3,1,3,1)

tensor([[[[1],
          [1],
          [1]]],


        [[[2],
          [2],
          [2]]],


        [[[3],
          [3],
          [3]]]])

In [54]:
torch.tensor([[-3.0457e+07, -2.2553e+07],
         [-2.2553e+07, -1.6700e+07]]).det()

tensor(-5.8916e+09)